In [141]:
import os
import torch.utils.data as data
import torch
import torch.nn as nn
from os.path import join
import numpy as np
import numpy.matlib
import sys
from collections import defaultdict
from scipy.io import savemat
from scipy.io import loadmat


In [8]:
#  mlp model
class MLP(nn.Module):
    """Stack of fully connected layers described by a list of layer widths.

    Args:
        num_layer: how many entries of ``num_nodes`` are used; the network
            holds ``num_layer - 1`` linear layers.
        num_nodes: layer widths; entries ``k`` and ``k + 1`` are the input and
            output size of linear layer ``k``.
        relu_final: if true, a ReLU follows every linear layer, including the
            last one. Otherwise a ReLU follows every linear layer except the
            last, so a two-entry network is purely linear.
    """

    def __init__(self, num_layer, num_nodes, relu_final=False):
        super(MLP, self).__init__()
        layers = nn.Sequential()
        last_idx = num_layer - 2  # index of the final linear layer
        for idx in range(num_layer - 1):
            layers.add_module(
                'linear{0}'.format(idx),
                nn.Linear(num_nodes[idx], num_nodes[idx + 1]),
            )
            # Hidden layers always get a ReLU; the output layer only on request.
            if relu_final or idx != last_idx:
                layers.add_module('relu{0}'.format(idx), nn.ReLU())
        self.main = layers

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)

In [9]:
# The graph nodes.
class Data(object):
    """Named node that remembers which other nodes it is linked to."""

    def __init__(self, name):
        self.__links = set()
        self.__name = name

    @property
    def name(self):
        """The name passed at construction (read-only)."""
        return self.__name

    @property
    def links(self):
        """A fresh copy of the neighbour set; editing it leaves the node untouched."""
        return self.__links.copy()

    def add_link(self, other):
        """Link this node and ``other`` in both directions."""
        # First register ``other`` on self, then self on ``other``.
        for src, dst in ((self, other), (other, self)):
            src.__links.add(dst)


# Class to represent a graph, for topological sort of DAG
class Graph:
    """Directed graph on vertices ``0 .. vertices - 1`` with a DFS topological sort."""

    def __init__(self, vertices):
        self.V = vertices  # number of vertices
        self.graph = defaultdict(list)  # vertex -> list of successor vertices

    def addEdge(self, u, v):
        """Append the directed edge ``u -> v`` to the adjacency list."""
        self.graph[u].append(v)

    def topologicalSortUtil(self, v, visited, stack):
        """Depth-first visit from ``v``.

        Marks ``v`` in ``visited``, recurses into every successor that is still
        unvisited, and finally puts ``v`` at the front of ``stack`` (in place).
        """
        visited[v] = True

        for successor in self.graph[v]:
            if visited[successor] is not False:
                continue
            self.topologicalSortUtil(successor, visited, stack)

        # All descendants are already in the list, so v goes in front of them.
        stack[:0] = [v]

    def topologicalSort(self):
        """Return the vertices as a list in which every edge points forward."""
        order = []
        visited = [False for _ in range(self.V)]

        # Start a DFS from each vertex that an earlier DFS did not reach.
        for vertex in range(self.V):
            if visited[vertex] is False:
                self.topologicalSortUtil(vertex, visited, order)

        return order

In [10]:
class DAG_Generator(nn.Module):
    """Generate all variables of a DAG, one small MLP per node.

    Each node ``i`` has its own ``MLP`` whose input is the concatenation of a
    class code, a domain code, a noise slice and the already generated values
    of the node's parents; its output is a single scalar per sample.

    Args:
        i_dim: number of graph nodes (generated variables).
        cl_num: input width of the class network ``cnet`` (unused if ``is_reg``).
        do_num: input width of the domain network / number of rows of ``mu``.
        cl_dim: width of the per-node class code.
        do_dim: width of the per-node domain code.
        z_dim: width of the per-node noise slice.
        num_layer: number of hidden layers in each per-node MLP.
        num_nodes: width of each hidden layer.
        is_reg: if true, no class network is built and ``input_c`` is fed in
            directly as a single column.
        dagMat: required adjacency array indexed as ``dagMat[child, parent]``.
            The last two columns are split off into ``self.yd_sign`` (the
            original comment calls them the "y and d signs"); presumably the
            shape is ``(i_dim, i_dim + 2)`` -- confirm with the file producer.
        prob: if true, domain codes come from learnable ``mu``/``sigma``
            combined with ``noise_d``; otherwise from a bias-free linear layer.

    Raises:
        TypeError: if ``dagMat`` is not supplied.
    """

    def __init__(self, i_dim, cl_num, do_num, cl_dim, do_dim, z_dim, num_layer=1, num_nodes=64, is_reg=False, dagMat=None, prob=True):
        super(DAG_Generator, self).__init__()
        # The default of None only exists to keep the signature; fail with a
        # readable message instead of "'NoneType' object is not subscriptable".
        if dagMat is None:
            raise TypeError('dagMat must be provided: an adjacency array indexed as dagMat[child, parent]')

        # create a dag: a non-zero dagMat[j, i] means an edge i -> j
        dag = Graph(i_dim)

        for i in range(i_dim):
            for j in range(i_dim):
                if dagMat[j, i]:
                    dag.addEdge(i, j)

        # extract y and d signs (last two columns), keep the rest
        self.yd_sign = dagMat[:, -2:]
        dagMat = dagMat[:, :-2]

        # topological sort: parents are generated before their children
        nodeSort = dag.topologicalSort()
        # row sums of the remaining matrix = number of parent inputs per node
        numInput = dagMat.sum(1)

        # define class and domain conditional networks
        self.prob = prob
        if prob:
            # VAE posterior parameters, Gaussian
            self.mu = nn.Parameter(torch.zeros(do_num, do_dim * i_dim))
            self.sigma = nn.Parameter(torch.ones(do_num, do_dim * i_dim))
        else:
            self.dnet = nn.Linear(do_num, do_dim * i_dim, bias=False)
        if not is_reg:
            self.cnet = nn.Linear(cl_num, cl_dim * i_dim, bias=False)

        # construct generative network according to the dag: one MLP per node,
        # input = parents + class code + domain code + noise, output = 1 value
        nets = nn.ModuleList()
        for i in range(i_dim):
            num_nodesIn = int(numInput[i]) + cl_dim + do_dim + z_dim
            num_nodes_i = [num_nodesIn] + [num_nodes] * num_layer + [1]
            nets.append(MLP(num_layer + 2, num_nodes_i))

        self.nets = nets
        self.nodeSort = nodeSort
        self.nodesA = np.array(range(i_dim)).reshape(i_dim, 1).tolist()
        self.i_dim = i_dim
        self.i_dimNew = i_dim
        self.do_num = do_num
        self.cl_num = cl_num
        self.cl_dim = cl_dim
        self.do_dim = do_dim
        self.z_dim = z_dim
        self.dagMat = dagMat
        self.numInput = numInput
        self.is_reg = is_reg
        self.ischain = False

    # inputs: class indicator, domain indicator, noise
    # forward for all factors in a graph
    def forward(self, noise, input_c, input_d, device='cpu', noise_d=None):
        """Generate every node value for a batch.

        Args:
            noise: per-sample noise; node ``i`` reads columns
                ``i * z_dim : (i + 1) * z_dim``.
            input_c: class input; passed through ``cnet``, or reshaped to one
                column when the module was built with ``is_reg=True``.
            input_d: domain input; multiplied with the sampled domain
                parameters (``prob=True``) or passed through ``dnet``.
            device: device on which the output buffer is allocated.
            noise_d: noise multiplied element-wise with ``sigma``; required
                when the module was built with ``prob=True``.

        Returns:
            Tensor with one row per sample and one column per node.

        Raises:
            TypeError: if ``prob=True`` and ``noise_d`` is not supplied.
        """
        if self.prob and noise_d is None:
            raise TypeError('noise_d must be provided when the generator was built with prob=True')

        # class parameter network
        batch_size = input_c.size(0)
        if self.is_reg:
            inputs_c = input_c.view(batch_size, 1)
        else:
            inputs_c = self.cnet(input_c)

        # domain parameter network
        if self.prob:
            theta = self.mu + torch.mul(self.sigma, noise_d)
            inputs_d = torch.matmul(input_d, theta)
        else:
            inputs_d = self.dnet(input_d)

        inputs_n = noise

        # allocate the result directly on the requested device
        output = torch.zeros((batch_size, len(self.nodeSort)), device=device)

        # evaluate the per-node networks in topological order so that parent
        # columns of `output` are filled before a child reads them
        for i in self.nodeSort:
            if not self.is_reg:
                inputs_ci = inputs_c[:, i * self.cl_dim:(i + 1) * self.cl_dim]
            else:
                inputs_ci = inputs_c
            inputs_di = inputs_d[:, i * self.do_dim:(i + 1) * self.do_dim]
            inputs_ni = inputs_n[:, i * self.z_dim:(i + 1) * self.z_dim]

            parts = [inputs_ci, inputs_di, inputs_ni]
            if self.numInput[i] > 0:
                # columns of the non-zero entries in row i are the parents of node i
                index = [int(j) for j in np.argwhere(self.dagMat[i, :]).flatten()]
                parts.append(output[:, index])
            inputs_i = torch.cat(parts, 1)

            output[:, i] = self.nets[i](inputs_i).squeeze()

        return output

In [19]:
def gaussian_weights_init_simul(m):
    """Re-initialise one module in place, chosen by its class name.

    Intended for ``module.apply(...)``:
      * class name starts with ``Conv``: weights ~ N(0, 0.02)
      * class name contains ``BatchNorm``: weights ~ N(1, 0.02), bias set to 0
      * class name contains ``Linear``: weights and bias ~ N(0, 0.5)
    Any other module is left untouched.
    """
    kind = m.__class__.__name__

    if kind.startswith('Conv'):
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in kind:
        # Both tensors are None when the layer was built without affine terms.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
        return

    if 'Linear' in kind:
        m.weight.data.normal_(0.0, 0.5)
        if m.bias is not None:
            m.bias.data.normal_(0.0, 0.5)

In [40]:
# ---- configuration -------------------------------------------------------
root = '../data/DatasetSimuDAG9'
num_class = 2
num_domain = 10
sample_size = 500
dim = 4
dagMatFile = '../data/DatasetSimuDAG9/A1_dag.npy'
dimClass = 1
dimDomain = 1
dimHidden = 1
numLayer = 1
numNodes = 30
seed = 0

# Fix the RNG so that the noise and the random network weights (and hence the
# generated data) are the same on every Restart & Run All.
torch.manual_seed(seed)

# The graph matrix does not depend on the domain: load it once.
dagMat = np.load(dagMatFile)

# np.repeat needs an integer repeat count, so the classes must split evenly.
if sample_size % num_class != 0:
    raise ValueError('sample_size must be a multiple of num_class')

# Class labels are identical for every domain: 0,...,0,1,...,1
label_y = np.repeat(np.arange(num_class, dtype=np.int64), sample_size // num_class)
label_y_tensor = torch.LongTensor(label_y)

# one-hot class labels
label_y_onehot = torch.zeros(sample_size, num_class)
label_y_onehot.scatter_(1, label_y_tensor.view(sample_size, 1), 1)

# NOTE: `data` rebinds the name of the `torch.utils.data as data` import; the
# name is kept because the cells below index `data[j]` / `label[j]`.
data = []
label = []
for domain_id in range(num_domain):

    # integer domain labels and their one-hot encoding
    label_domain = np.ones(sample_size, dtype=np.int64) * domain_id
    label_domain_tensor = torch.LongTensor(label_domain)
    label_domain_onehot = torch.zeros(sample_size, num_domain)
    label_domain_onehot.scatter_(1, label_domain_tensor.view(sample_size, 1), 1)

    # the generator slices `dimHidden` noise columns per node
    noise = torch.randn(sample_size, dim * dimHidden)

    # a freshly initialised random generator per domain
    net = DAG_Generator(dim, num_class, num_domain, dimClass, dimDomain,
                        dimHidden, numLayer, numNodes, dagMat=dagMat, prob=False)
    net.apply(gaussian_weights_init_simul)
    xg = net(noise, label_y_onehot, label_domain_onehot)
    data.append(xg.detach().numpy())
    # copy so that every domain owns its label array
    label.append(label_y.copy())

In [41]:
# Print the per-domain feature arrays first, then the per-domain label arrays.
for generated in (data, label):
    print(generated)

[array([[ 0.09864143,  2.7052681 ,  1.3831779 ,  1.1127517 ],
       [ 0.21430472,  2.310706  ,  1.5097187 ,  0.5914149 ],
       [ 0.9677466 ,  2.5623677 ,  0.06359386,  1.0196785 ],
       ...,
       [ 0.7578374 ,  2.2240367 ,  2.3796709 , -0.07106432],
       [ 4.7157545 ,  3.5843477 , -2.8514109 ,  0.4660935 ],
       [ 3.4017243 ,  3.2779372 , -0.29269144,  0.26055992]],
      dtype=float32), array([[-0.8527918 ,  2.7224221 ,  0.24441934,  3.3479452 ],
       [-0.22541618,  2.5866091 ,  1.0641435 ,  2.7574959 ],
       [-2.1872716 ,  1.5071819 ,  0.25514698,  0.44290864],
       ...,
       [-0.03508091,  3.2412353 , -0.11489442,  2.3484447 ],
       [-0.9342663 ,  2.773861  , -0.19178483,  3.413465  ],
       [ 0.5457282 ,  2.0244958 ,  1.8420997 ,  2.3609698 ]],
      dtype=float32), array([[-4.1721787 ,  5.5908613 , -0.6142471 ,  0.13159834],
       [-3.246148  ,  5.00941   , -1.481427  ,  0.5329975 ],
       [-1.6193832 ,  2.098598  , -0.5559429 , -1.024453  ],
       ...,
  

9 source domains

In [107]:
num_domain = 10

# One file per held-out domain i: the other domains fill the leading blocks of
# rows (domain index 0, 1, ... in y[:, 1]) and domain i fills the last block
# with domain index num_domain - 1.
# NOTE(review): y[:, 0] is never written for the last block, so the held-out
# domain is saved with class label 0 everywhere -- confirm this is intended.
for i in range(num_domain):
    full_path = join(root,  'to'+str(i)+'_numData'+str(sample_size)+'.npz')
    x = np.zeros((sample_size*num_domain, dim))
    y = np.zeros((sample_size*num_domain, 2))
    cnt = 0
    for j in range(num_domain):
        if j == i:
            continue
        rows = slice(cnt*sample_size, (cnt+1)*sample_size)
        x[rows] = data[j]
        y[rows, 0] = label[j]
        y[rows, 1] = np.ones(sample_size) * cnt
        cnt += 1
    # held-out domain goes last
    x[(num_domain-1)*sample_size:num_domain*sample_size] = data[i]
    y[cnt*sample_size:(cnt+1)*sample_size, 1] = np.ones(sample_size) * (num_domain-1)
    np.savez(full_path, x=x, y=y)


In [116]:
# np.load keeps the .npz archive open; the with-block closes it as soon as the
# two arrays have been read into memory.
with np.load(join(root,  'to'+str(7)+'_numData'+str(sample_size)+'.npz')) as npzfile:
    x = npzfile['x']
    y = npzfile['y']

In [117]:
# Show generated domain 7 next to the last block of the loaded file; the two
# printed arrays should contain the same values.
for shown in (data[7], '\n', x[9*sample_size:10*sample_size]):
    print(shown)

[[-1.1994536   0.1678969  -4.0217404   1.2788496 ]
 [-0.94169134 -2.6127505  -3.5029953   1.7652314 ]
 [-0.7918181  -0.7177076  -2.541016    1.6904633 ]
 ...
 [ 0.594404   -1.6134539  -1.5509926   1.3119018 ]
 [ 0.32000542 -2.9700537  -2.0198355   1.4633005 ]
 [-0.95640475 -1.1285782  -3.0212383   1.4444156 ]]


[[-1.19945359  0.1678969  -4.02174044  1.2788496 ]
 [-0.94169134 -2.61275053 -3.50299525  1.76523137]
 [-0.79181808 -0.71770757 -2.5410161   1.6904633 ]
 ...
 [ 0.59440398 -1.61345387 -1.55099261  1.31190181]
 [ 0.32000542 -2.97005367 -2.01983547  1.46330047]
 [-0.95640475 -1.12857819 -3.02123833  1.44441557]]


5 source domains

In [130]:
num_domain = 6
root = '../data/DatasetSimuDAG5'

# np.savez does not create missing folders, so make sure the target exists.
if not os.path.isdir(root):
    os.makedirs(root)

# Every file needs num_domain - 1 source domains plus the held-out one.
num_generated = len(data)
if num_generated < num_domain:
    raise ValueError('need at least num_domain generated domains')

# One file per held-out domain i: the first num_domain - 1 other domains fill
# the leading blocks of rows, domain i fills the last block.
# NOTE(review): y[:, 0] is never written for the last block, so the held-out
# domain is saved with class label 0 everywhere -- confirm this is intended.
for i in range(num_generated):
    x = np.zeros((sample_size*num_domain, dim))
    y = np.zeros((sample_size*num_domain, 2))
    full_path = join(root,  'to'+str(i)+'_numData'+str(sample_size)+'.npz')
    cnt = 0
    for j in range(num_generated):
        if j != i and cnt < num_domain-1:
            x[cnt*sample_size:(cnt+1)*sample_size] = data[j]
            y[cnt*sample_size:(cnt+1)*sample_size, 0] = label[j]
            y[cnt*sample_size:(cnt+1)*sample_size, 1] = np.ones(sample_size) * cnt
            cnt = cnt + 1
    # address the last block the same way in x and y
    target_rows = slice((num_domain-1)*sample_size, num_domain*sample_size)
    x[target_rows] = data[i]
    y[target_rows, 1] = np.ones(sample_size) * (num_domain-1)
    np.savez(full_path, x=x, y=y)

2 source domains

In [131]:
num_domain = 3
root = '../data/DatasetSimuDAG2'

# np.savez does not create missing folders, so make sure the target exists.
if not os.path.isdir(root):
    os.makedirs(root)

# Every file needs num_domain - 1 source domains plus the held-out one.
num_generated = len(data)
if num_generated < num_domain:
    raise ValueError('need at least num_domain generated domains')

# One file per held-out domain i: the first num_domain - 1 other domains fill
# the leading blocks of rows, domain i fills the last block.
# NOTE(review): y[:, 0] is never written for the last block, so the held-out
# domain is saved with class label 0 everywhere -- confirm this is intended.
for i in range(num_generated):
    x = np.zeros((sample_size*num_domain, dim))
    y = np.zeros((sample_size*num_domain, 2))
    full_path = join(root,  'to'+str(i)+'_numData'+str(sample_size)+'.npz')
    cnt = 0
    for j in range(num_generated):
        if j != i and cnt < num_domain-1:
            x[cnt*sample_size:(cnt+1)*sample_size] = data[j]
            y[cnt*sample_size:(cnt+1)*sample_size, 0] = label[j]
            y[cnt*sample_size:(cnt+1)*sample_size, 1] = np.ones(sample_size) * cnt
            cnt = cnt + 1
    # address the last block the same way in x and y
    target_rows = slice((num_domain-1)*sample_size, num_domain*sample_size)
    x[target_rows] = data[i]
    y[target_rows, 1] = np.ones(sample_size) * (num_domain-1)
    np.savez(full_path, x=x, y=y)

In [132]:
# np.load keeps the .npz archive open; the with-block closes it as soon as the
# two arrays have been read into memory.
with np.load(join(root,  'to'+str(0)+'_numData'+str(sample_size)+'.npz')) as npzfile:
    x = npzfile['x']
    y = npzfile['y']

In [139]:
# Show generated domain 1, then the features and labels loaded from the file.
for shown in (data[1], '\n', x, y):
    print(shown)

[[-0.8527918   2.7224221   0.24441934  3.3479452 ]
 [-0.22541618  2.5866091   1.0641435   2.7574959 ]
 [-2.1872716   1.5071819   0.25514698  0.44290864]
 ...
 [-0.03508091  3.2412353  -0.11489442  2.3484447 ]
 [-0.9342663   2.773861   -0.19178483  3.413465  ]
 [ 0.5457282   2.0244958   1.8420997   2.3609698 ]]


[[-0.85279179  2.72242212  0.24441934  3.34794521]
 [-0.22541618  2.58660913  1.06414354  2.75749588]
 [-2.1872716   1.50718188  0.25514698  0.44290864]
 ...
 [ 0.75783741  2.22403669  2.37967086 -0.07106432]
 [ 4.71575451  3.58434772 -2.85141087  0.46609351]
 [ 3.40172434  3.27793717 -0.29269144  0.26055992]]
[[0. 0.]
 [0. 0.]
 [0. 0.]
 ...
 [0. 2.]
 [0. 2.]
 [0. 2.]]


In [157]:
root = '../data/DatasetSimuDAG2'
# Re-save every .npz split as a MATLAB .mat file with the same 'x' / 'y' keys.
for i in range(10):
    file_path = join(root,  'to'+str(i)+'_numData'+str(sample_size)+'.npz')
    file_path_mat = join(root,  'to'+str(i)+'_numData'+str(sample_size)+'.mat')
    # the with-block closes the archive once both arrays have been written out
    with np.load(file_path) as npzfile:
        savemat(file_path_mat, mdict={'x':npzfile['x'],'y':npzfile['y']})

# sanity check: shape of 'x' in the last .mat file written
loadmat(file_path_mat)['x'].shape

(1500, 4)

In [163]:
# Source (.npy) and destination (.npz) of the graph matrix.
dagMatFile = '../data/DatasetSimuDAG9/A1_dag.npy'
full_mat_path = '../data/DatasetSimuDAG9/A1_dag.npz'

# Load the matrix, show it, and store it again under the key 'mat'.
dagMat = np.load(dagMatFile)
print(dagMat)
np.savez(full_mat_path, mat=dagMat)

[[0 0 1 0 1 1]
 [1 0 0 0 1 1]
 [0 0 0 0 1 1]
 [0 0 0 0 1 0]]
