In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn

In [4]:
dsprites = np.load("/home/<user>/projects/disentanglement-multi-task/data/dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
dsprites = {key: dsprites[key] for key in ['imgs', 'latents_classes', 'latents_values']}
dataset_x = dsprites['latents_values']
dataset_x = torch.tensor(dataset_x).cuda().float()

In [3]:
torch.manual_seed(658)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class Masked(nn.Module):
    def __init__(self, split, input_shape):
        super(Mask, self).__init__()
        self.split = split
        self.in = input_shape
        self.mask = torch.zeros(self.in, dtype=torch.float32)
        self.mask[self.split] = 1
    
    def forward(self, input):
        return self.mask*input
        
    


networks = []
multitask_targets = []
for network in range(50):
    network = nn.Sequential(
        nn.Linear(6, 300),
        nn.Tanh(),
        nn.Linear(300, 300),
        nn.Tanh(),
        nn.Linear(300, 300),
        nn.Tanh(),
        nn.Linear(300, 300),
        nn.Tanh(),
        nn.Linear(300, 1),
    )
    
    for module in network.children():
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight.data, 0, 1.)
            
    network = network.cuda()
    targets = network(dataset_x)
    targets = targets.detach().cpu().numpy()
    multitask_targets += [targets]

multitask_targets = np.concatenate(multitask_targets, 1)

dsprites['multitask_targets'] = multitask_targets
np.savez_compressed("dsprites_multitask.npz", **dsprites)

In [5]:
A = np.load('dsprites_multitask.npz')
print(A['multitask_targets'][0])

[ 17.085754     6.195522   -27.487833    15.12213     -0.60677946
  39.41825    -16.609615    -9.520086   -14.257808   -15.026125
 -16.258022   -22.908615    13.826408    12.454214    -1.5997301
 -11.580022   -22.412027    26.544628    19.908724    11.491505
  27.053112    26.539207    -6.8763776    4.025102     1.8178561
  -3.357193    -4.5485506    3.8842144    1.068662    -4.0464535
 -13.159656    -9.261389    -3.2104754    0.96721    -20.876127
   7.3661785    2.0018985  -19.01292     -5.56429     11.124806
 -10.706724   -34.558735    12.645012    -5.399464    21.121984
   4.085326   -17.37117     -0.3847853    8.716465    -9.86611   ]


In [3]:
A = np.load('/Users/stella/projects/disentanglement-multi-task/data/test_dsets/dsprites/dsprites_multitask.npz')

In [4]:
print(A['multitask_targets'].shape)

(737280, 50)


# Custom Splits

In [70]:
num_tasks=50
num_true = 6


#[1,1,0,0,0,0],
#[1,1,1,0,0,0],
#[1,1,1,1,0,0],etc
grow_splits = []
for i in range(num_tasks):
    temp = np.zeros(num_true)
    temp[:((i%(num_true-1))+2)] = 1.0
    grow_splits.append(temp)


#[1,1,0,0,0,0],
#[0,0,1,1,0,0],
#[0,0,0,0,1,1],etc.
independent_splits_v1 = []
div_size=2
for i in range(num_tasks):
    k = num_true//div_size
    temp = np.zeros(num_true)
    j = i%k
    temp[div_size*j:div_size*(j+1)] = 1.0
    independent_splits_v1.append(temp)
    
    
independent_splits_v2 = []
div_size=3
for i in range(num_tasks):
    k = num_true//div_size
    temp = np.zeros(num_true)
    j = i%k
    temp[div_size*j:div_size*(j+1)] = 1.0
    independent_splits_v2.append(temp)    

#random
#[1,0,1,1,0,0]
#[0,0,1,0,1,1], etc. 
rs = np.random.RandomState(11)
random_splits = []
for i in range(num_tasks):
    temp = np.zeros(num_true)
    temp[rs.rand(len(temp))>0.5]=1.0
    if(sum(temp)==0):
        i = rs.randint(num_true)
        temp[i]=1.0
    random_splits.append(temp)
    


torch.manual_seed(658)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class Masked(nn.Module):
    def __init__(self, split, input_shape):
        super(Masked, self).__init__()
        self.split = split
        self.input_shape = input_shape
        self.mask = torch.zeros(self.input_shape, dtype=torch.float32).cuda()
        self.mask[self.split] = 1
    
    def forward(self, input):
        return self.mask*input
        
    
SPLITS={
    "grow":grow_splits,
    "independent_v1":independent_splits_v1,
    "independent_v2":independent_splits_v2,
    "random":random_splits
}
    
    
    
def save_splitted_dataset(split_type):
    networks = []
    multitask_targets = []
    split_array = SPLITS[split_type]
    for network in range(50):
        network = nn.Sequential(
            Masked(split_array[network], 6),
            nn.Linear(6, 300),
            nn.Tanh(),
            nn.Linear(300, 300),
            nn.Tanh(),
            nn.Linear(300, 300),
            nn.Tanh(),
            nn.Linear(300, 300),
            nn.Tanh(),
            nn.Linear(300, 1),
        )

        for module in network.children():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight.data, 0, 1.)

        network = network.cuda()
    
        
        targets = network(dataset_x)
        targets = targets.detach().cpu().numpy()
        multitask_targets += [targets]

    multitask_targets = np.concatenate(multitask_targets, 1)

    dsprites['multitask_targets'] = multitask_targets
    np.savez_compressed("dsprites_multitask_{}_splits.npz".format(split_type), **dsprites)

In [71]:
print("grow")
save_splitted_dataset("grow")
print("independent_v1")
save_splitted_dataset("independent_v1")
print("independent_v2")
save_splitted_dataset("independent_v2")
print("random")
save_splitted_dataset("random")

grow
independent_v1
independent_v2
random


In [72]:
A = np.load('dsprites_multitask_random_splits.npz')
print(A['multitask_targets'][0])

[ 11.628903    -6.9355907  -18.089891   -21.871485   -51.237854
 -40.217728   -15.515591    -7.373314    13.756518    -7.5567546
  16.42033     29.8013      -4.8933244    3.2823796  -23.642212
  27.108505    -2.7973657  -14.442416    11.6041355   -6.3948245
   9.82257     18.88667     13.748902     9.389744    14.644038
  -1.185813   -22.49953    -18.795975   -17.900105   -33.955128
  -8.455559   -11.455919    -0.29316142 -10.250931     6.180991
  -7.1955323  -11.552702   -11.012647    20.333458    19.966867
   8.528882    -8.024888    -1.8292567  -17.510044    -6.521972
  -7.3071947   -6.7310696  -27.390554   -49.917305    -5.293173  ]
