In [1]:
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn

In [None]:
SHAPES3D_H5_PATH = "/shared/sets/datasets/disentanglement-multitask/3dshapes/3dshapes.h5"
# Opened read-only and intentionally left open: later cells keep references to
# the datasets inside `f` and read them lazily.
f = h5py.File(SHAPES3D_H5_PATH, 'r')
print(f)


In [None]:
hdf5_keys = f.keys()
print(hdf5_keys)

In [None]:
# Lazy handles to the two HDF5 datasets; savez_compressed later materialises them.
shapes3d = {key: f[key] for key in ['images', 'labels']}
# Read the labels into memory with a single HDF5 read. Handing the h5py.Dataset
# directly to torch.tensor() makes torch convert it as a generic Python
# sequence (one row at a time), which is very slow. Casting to float32 on the
# host also avoids copying a wider dtype to the GPU just to cast it there.
labels_array = np.asarray(shapes3d['labels'])
dataset_x = torch.from_numpy(labels_array).float().cuda()

In [11]:
torch.manual_seed(658)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Baseline (unmasked) targets: 50 random tanh MLPs 6 -> 300 x4 -> 1, each
# evaluated on every row of dataset_x to produce one target column.
target_columns = []
for _ in range(50):
    network = nn.Sequential(
        nn.Linear(6, 300),
        nn.Tanh(),
        nn.Linear(300, 300),
        nn.Tanh(),
        nn.Linear(300, 300),
        nn.Tanh(),
        nn.Linear(300, 300),
        nn.Tanh(),
        nn.Linear(300, 1),
    )

    # Re-draw every Linear weight from N(0, 1); biases keep PyTorch's default init.
    for module in network.children():
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight.data, 0, 1.)

    network = network.to(dataset_x.device)
    # Pure inference: without no_grad, autograd keeps every hidden activation
    # of this full-dataset forward pass alive for a backward pass that never runs.
    with torch.no_grad():
        targets = network(dataset_x)
    target_columns.append(targets.cpu().numpy())

# One column per network -> shape (n_rows, 50).
multitask_targets = np.concatenate(target_columns, 1)

In [12]:
# Attach the baseline targets next to the images/labels and write one archive.
shapes3d.update(multitask_targets=multitask_targets)
np.savez_compressed("3dshapes_multitask.npz", **shapes3d)

In [13]:
# np.load on an .npz returns an NpzFile that holds the archive open; the
# context manager closes it once the first target row has been printed.
with np.load('3dshapes_multitask.npz') as A:
    print(A['multitask_targets'][0])

[ -0.31196606  -2.3516924   -1.4602647   -4.7415085  -19.127003
  -1.0143753  -12.447521    13.70324     -6.306644     3.1938343
  -1.9605145   16.70559      1.4269706  -31.091702    16.603064
  -0.6664022    9.763324    -4.485842    27.586367    19.714392
   5.034379    16.705286    12.028307    -8.746036    19.236835
   7.1748137  -12.025241   -22.951971   -15.968634    12.793978
  11.404209     4.051466    13.995813     5.4605536    3.017771
  -3.849686    12.531304    18.862926   -24.60506     -1.8412641
  -1.3432792   -0.98290575 -27.891468    28.465433    32.946815
  -1.4312706   -0.64788693 -17.087461     3.72648      3.6102152 ]


In [3]:
# Machine-specific default location of the dSprites multitask archive; set the
# DSPRITES_MULTITASK_NPZ environment variable to point elsewhere.
DSPRITES_MULTITASK_PATH = os.environ.get(
    "DSPRITES_MULTITASK_NPZ",
    '/Users/stella/projects/disentanglement-multi-task/data/test_dsets/dsprites/dsprites_multitask.npz',
)
A = np.load(DSPRITES_MULTITASK_PATH)

In [4]:
dsprites_targets_shape = A['multitask_targets'].shape
print(dsprites_targets_shape)

(737280, 50)


# Custom splits: per-task masks selecting which of the 6 input factors each target network sees

In [5]:
num_tasks = 50  # number of target tasks (one random network per task)
num_true = 6    # number of input factors each mask ranges over


def _grow_splits(n_tasks, n_factors):
    """Prefix masks of growing length, cycling back after the full mask.

    For 6 factors: [1,1,0,0,0,0], [1,1,1,0,0,0], ..., [1,1,1,1,1,1], then
    [1,1,0,0,0,0] again, etc. Each mask is a float64 numpy array.
    """
    splits = []
    for task in range(n_tasks):
        mask = np.zeros(n_factors)
        # prefix length runs 2, 3, ..., n_factors and then wraps
        mask[:task % (n_factors - 1) + 2] = 1.0
        splits.append(mask)
    return splits


def _independent_splits(n_tasks, n_factors, div_size):
    """Masks over disjoint consecutive blocks of `div_size` factors, cycling.

    For 6 factors and div_size=2: [1,1,0,0,0,0], [0,0,1,1,0,0],
    [0,0,0,0,1,1], then repeat. Factors beyond the last full block are never
    selected when n_factors is not divisible by div_size.
    """
    n_blocks = n_factors // div_size
    splits = []
    for task in range(n_tasks):
        block = task % n_blocks
        mask = np.zeros(n_factors)
        mask[div_size * block:div_size * (block + 1)] = 1.0
        splits.append(mask)
    return splits


def _random_splits(n_tasks, n_factors, rng):
    """Masks keeping each factor independently when rng.rand() > 0.5.

    An all-zero draw is replaced by a single randomly chosen factor, so every
    task depends on at least one input. Draws from `rng` in a fixed order
    (rand(n_factors), then randint only for empty masks) for reproducibility.
    """
    splits = []
    for _ in range(n_tasks):
        mask = np.zeros(n_factors)
        mask[rng.rand(n_factors) > 0.5] = 1.0
        if mask.sum() == 0:
            mask[rng.randint(n_factors)] = 1.0
        splits.append(mask)
    return splits


grow_splits = _grow_splits(num_tasks, num_true)
independent_splits_v1 = _independent_splits(num_tasks, num_true, div_size=2)
independent_splits_v2 = _independent_splits(num_tasks, num_true, div_size=3)
random_splits = _random_splits(num_tasks, num_true, np.random.RandomState(11))
    


# Deterministic cuDNN plus a fresh torch seed before the split networks are built.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(658)


class Masked(nn.Module):
    """Zero out input features outside a fixed 0/1 mask.

    Used as the first layer of each split network so that it only sees the
    input factors selected by its split.

    Args:
        split: array-like with ``input_shape`` entries; non-zero entries mark
            the features that are kept (e.g. ``[1, 1, 0, 0, 0, 0]``).
        input_shape: number of input features (size of the last input dim).

    Raises:
        ValueError: if ``split`` does not have exactly ``input_shape`` entries.
    """

    def __init__(self, split, input_shape):
        super(Masked, self).__init__()
        self.split = split
        self.input_shape = input_shape
        # Interpret `split` as a boolean mask. Indexing with the raw 0./1.
        # values would use them as integer *positions*, so only features 0 and
        # 1 could ever be kept (or it raises, since float indices are invalid).
        keep = torch.as_tensor(split).cpu() != 0
        mask = torch.zeros(self.input_shape, dtype=torch.float32)
        if keep.shape != mask.shape:
            raise ValueError(
                "split has shape {}, expected {}".format(tuple(keep.shape), tuple(mask.shape))
            )
        mask[keep] = 1
        # A buffer moves together with the module under .cuda() / .to().
        self.register_buffer("mask", mask)

    def forward(self, input):
        # .to() is a no-op when the mask already lives on the input's device.
        return self.mask.to(input.device) * input
        
    
# Split family name -> list of per-task feature masks.
SPLITS = dict(
    grow=grow_splits,
    independent_v1=independent_splits_v1,
    independent_v2=independent_splits_v2,
    random=random_splits,
)
    
    
    
def _build_target_network(split, n_features, hidden=300, n_hidden_layers=4):
    """Build one masked random MLP: Masked -> (Linear, Tanh) x n_hidden_layers -> Linear(hidden, 1).

    All layers are constructed first and every Linear weight is then re-drawn
    from N(0, 1); biases keep PyTorch's default initialisation. Uses the
    global torch RNG.
    """
    layers = [Masked(split, n_features)]
    in_width = n_features
    for _ in range(n_hidden_layers):
        layers += [nn.Linear(in_width, hidden), nn.Tanh()]
        in_width = hidden
    layers.append(nn.Linear(in_width, 1))
    network = nn.Sequential(*layers)
    for module in network.children():
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, 0.0, 1.0)
    return network


def save_splitted_dataset(split_type, seed=658):
    """Compute one target column per mask in SPLITS[split_type] and save them.

    Each mask gets a fresh random MLP evaluated on ``dataset_x``. The columns
    are stacked into ``multitask_targets`` (one column per mask) and written,
    together with the other entries of ``shapes3d``, to
    ``3dshapes_multitask_{split_type}_splits.npz``.

    Args:
        split_type: key into ``SPLITS`` ("grow", "independent_v1", ...).
        seed: torch seed applied before building the networks, so a split type
            gives the same targets no matter which calls ran before it (e.g.
            when retrying a single split). ``None`` keeps the current RNG state.

    Raises:
        KeyError: if ``split_type`` is not a key of ``SPLITS``.
    """
    split_array = SPLITS[split_type]
    if seed is not None:
        torch.manual_seed(seed)
    n_features = dataset_x.shape[-1]

    target_columns = []
    for split in split_array:
        network = _build_target_network(split, n_features).to(dataset_x.device)
        # Inference only: no_grad stops autograd from retaining every hidden
        # activation of the full-dataset forward pass.
        with torch.no_grad():
            targets = network(dataset_x)
        target_columns.append(targets.cpu().numpy())

    multitask_targets = np.concatenate(target_columns, 1)
    # NOTE: updates the notebook-level `shapes3d` dict (as the original did);
    # afterwards its 'multitask_targets' entry holds the last saved split.
    shapes3d['multitask_targets'] = multitask_targets
    np.savez_compressed("3dshapes_multitask_{}_splits.npz".format(split_type), **shapes3d)

In [6]:
# Build and save one multitask archive per split family, announcing each first.
for split_name in ("grow", "independent_v1", "independent_v2", "random"):
    print(split_name)
    save_splitted_dataset(split_name)

grow
independent_v1
independent_v2
random


OSError: [Errno 5] Input/output error

In [None]:
# Retry only the "random" split, whose save hit the I/O error in the cell above.
save_splitted_dataset(split_type="random")