In [1]:
# -*- coding: utf-8 -*-

import sys,os
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy import linalg
from numpy import dot
import geomloss as gs

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as D
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.autograd import grad
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.nn.modules import Linear
from torch.autograd.functional import jacobian,hessian,vjp,vhp,hvp

import random
import math

In [2]:
## functions preparation

def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, current CUDA device and all CUDA devices), then puts
    cuDNN into deterministic mode with ``benchmark`` switched off.

    Args:
        seed: integer seed handed unchanged to each library.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then the CUDA ones.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags set for repeatable runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark=False


class MLP(nn.Module):
    """Fully connected network with a configurable number of hidden layers.

    With ``num_hidden == 0`` the network is a single linear map
    ``dim_in -> dim_out``. With ``num_hidden >= 1`` it is
    ``dim_in -> dim_hidden -> ... -> dim_hidden -> dim_out`` containing
    ``num_hidden`` hidden layers. ``activation`` is applied after every layer
    except the last one, so the output is unbounded.

    Weights use Xavier-normal initialisation; biases are drawn uniformly from
    [-0.1, 0.1].

    Args:
        dim_in: size of the input features.
        dim_out: size of the output features.
        dim_hidden: width of every hidden layer (ignored when ``num_hidden == 0``).
        num_hidden: number of hidden layers, a non-negative integer.
        activation: module applied between layers. The default ``nn.Tanh()``
            instance is created once at definition time and shared by all
            networks that rely on the default; this is harmless because Tanh
            holds no state.

    Raises:
        ValueError: if ``num_hidden`` is negative (or otherwise neither 0 nor >= 1).
    """

    def __init__(self, dim_in, dim_out, dim_hidden=64, num_hidden=0, activation=nn.Tanh()):
        super(MLP, self).__init__()

        if num_hidden == 0:
            self.linears = nn.ModuleList([nn.Linear(dim_in, dim_out)])
        elif num_hidden >= 1:
            self.linears = nn.ModuleList()
            self.linears.append(nn.Linear(dim_in, dim_hidden))
            self.linears.extend([nn.Linear(dim_hidden, dim_hidden) for _ in range(num_hidden-1)])
            self.linears.append(nn.Linear(dim_hidden, dim_out))
        else:
            # Zero hidden layers is a valid configuration, so only negative
            # (or fractional values below 1) end up here.
            raise ValueError(
                'number of hidden layers must be a non-negative integer, got {!r}'.format(num_hidden)
            )

        # Re-initialise every layer, in order, after construction.
        for m in self.linears:
            nn.init.xavier_normal_(m.weight)
            nn.init.uniform_(m.bias,a=-0.1,b=0.1)

        self.activation = activation

    def forward(self, x):
        """Apply all layers to ``x``; the final layer has no activation."""
        for m in self.linears[:-1]:
            x = self.activation(m(x))

        return self.linears[-1](x)
    
## WGAN-div term
def compute_gradient_penalty(D, real_sample, fake_sample,k=2,p=6):
    """WGAN-div style penalty on the critic's input gradients.

    For both the real and the fake batch, the gradient of the critic output
    with respect to the input is taken per sample, its squared Euclidean norm
    is raised to the power ``p / 2`` (i.e. ``||grad||^p``), and the result is
    averaged over all real and fake samples together and scaled by ``k``.

    The gradient is built with ``create_graph=True`` so the penalty itself can
    be back-propagated into the critic parameters.

    Args:
        D: the critic, a callable mapping a batch of samples to scores.
        real_sample: batch of observed samples (first dimension = batch).
        fake_sample: batch of generated samples (first dimension = batch).
        k: multiplicative weight of the penalty.
        p: power applied to the gradient norm.

    Returns:
        A scalar tensor holding the penalty.

    Side effects:
        ``requires_grad_(True)`` is called in place on both input tensors.
    """

    def _grad_norm_to_p(samples):
        # Gradients w.r.t. the inputs are needed, so the inputs must track grad.
        samples = samples.requires_grad_(True)
        validity = D(samples)
        # ones_like keeps shape, dtype and device in sync with the critic
        # output, so this works on CPU as well as on any CUDA device.
        grad_out = torch.ones_like(validity)
        gradients = grad(
            validity, samples, grad_out, create_graph=True, retain_graph=True
        )[0]
        return gradients.reshape(gradients.size(0), -1).pow(2).sum(1) ** (p / 2)

    real_grad_norm = _grad_norm_to_p(real_sample)
    fake_grad_norm = _grad_norm_to_p(fake_sample)

    n_total = real_sample.shape[0] + fake_sample.shape[0]
    return (torch.sum(real_grad_norm) + torch.sum(fake_grad_norm)) * k / n_total

# generator, EM scheme
class JumpEulerForwardCuda(nn.Module):
    """Euler-Maruyama generator for a jump-diffusion process.

    One step of size ``dt = step_size`` updates the state ``x`` as

        x <- x + drift(x) * dt
               + sqrt(dt) * (eps_1 @ diffusion)
               + (n * mean + sqrt(n) * (eps_2 @ covHalf)) * jump(x)

    with ``n ~ Poisson(intensity * dt)`` drawn once per sample and step,
    ``eps_1 ~ N(0, I)`` of width ``bd`` and ``eps_2 ~ N(0, I)`` of width
    ``in_features``.

    Despite the class name, simulation runs on whatever device ``z0`` lives
    on; nothing is pinned to CUDA any more.

    Args:
        in_features: dimension of the state.
        num_hidden: number of hidden layers of the drift and jump networks.
        dim_hidden: hidden width of the drift and jump networks.
        step_size: time step ``dt`` of the scheme.
        intensity: Poisson jump intensity (kept fixed, not trained).
        bd: dimension of the driving Brownian motion.
    """

    def __init__(self,in_features,num_hidden,dim_hidden,step_size,intensity,bd):
        super(JumpEulerForwardCuda,self).__init__()

        self.drift = MLP(in_features,in_features,dim_hidden,num_hidden)
        # Plain tensor attribute (not a Parameter/buffer): prefer CUDA as
        # before, but fall back to CPU instead of failing when CUDA is absent.
        self.intensity = torch.tensor(intensity,device="cuda" if torch.cuda.is_available() else "cpu")
        self.mean = nn.Parameter(0.01*torch.ones(in_features))
        self.covHalf = nn.Parameter(0.08*torch.eye(in_features))
        self.diffusion = nn.Parameter(torch.ones(bd,in_features))
        self.in_features = in_features
        self.jump = MLP(in_features,in_features,dim_hidden,num_hidden)
        self.step_size = step_size
        self.bd = bd

    def forward(self,z0,Nsim,steps):
        """Simulate ``Nsim`` paths for ``steps`` Euler steps starting from ``z0``.

        Args:
            z0: initial states, written into time index 0 of the result.
            Nsim: number of simulated paths (rows of ``z0``).
            steps: number of Euler steps.

        Returns:
            Tensor of shape ``(Nsim, steps + 1, in_features)`` with the whole path.
        """
        device = z0.device
        dt = self.step_size
        sqrt_dt = math.sqrt(dt)

        # The jump-count distribution does not change between steps.
        jump_count_dist = D.poisson.Poisson(self.intensity.to(device)*dt)

        PopulationPath = torch.empty(size = (Nsim,steps+1,self.in_features),device=device)
        PopulationPath[:,0,:] = z0
        state = z0

        for i in range(1,steps+1):
            # Draw order (Poisson, Brownian noise, jump noise) matches the
            # original expression so seeded runs reproduce the same paths.
            pois = jump_count_dist.sample((Nsim,1))
            brownian_noise = torch.normal(0,1,size=(Nsim,self.bd),device=device)
            jump_noise = torch.normal(0,1,size=(Nsim,self.in_features),device=device)

            drift_term = self.drift(state)*dt
            diffusion_term = (sqrt_dt*brownian_noise)@self.diffusion
            jump_size = pois*self.mean + (pois**(0.5)*jump_noise)@self.covHalf

            state = state + drift_term + diffusion_term + jump_size*self.jump(state)
            PopulationPath[:,i,:] = state
        return PopulationPath

## Sinkhorn distance
# Module-level loss object from geomloss (SamplesLoss with loss='sinkhorn',
# p=2, blur=0.01). It is called as a(x, y) on two sample tensors, both in
# Train (progress report every 10 epochs) and in the visualization cell.
# It is only used for reporting; the generator and critic losses do not use it.
# NOTE(review): the one-letter name is easy to clobber in a notebook session.
a = gs.SamplesLoss(loss='sinkhorn',p=2,blur=0.01)

In [18]:
def Train(AggregateData,n_steps,bd,intensity,num_hidden,dim_hidden,step_size=0.05,n_epochs=30000,n_critic=4,lr=0.0001,Seed=80):
    """Fit the jump-diffusion generator to population snapshots with one critic per time point.

    Args:
        AggregateData: list of length l; each element is a numpy array holding
            the observations at one time point (rows: variables, columns:
            samples). Each array is transposed to (samples, variables) here.
        n_steps: list of length l-1. ``n_steps[i]`` is the Euler step index of
            the simulated path that is compared with time point ``i + 1``; the
            simulation runs for ``n_steps[-1]`` steps.
        bd: the dimension of the Brownian motion.
        intensity: the jump intensity (lambda).
        num_hidden, dim_hidden: depth and width of all neural networks.
        step_size: Euler step size (delta).
        n_epochs: number of training iterations.
        n_critic: number of critic updates per generator update.
        lr: learning rate of all Adam optimizers.
        Seed: random seed.

    Returns:
        netG: the trained jump-diffusion model.
        AggregateData: the training data as float32 tensors on the training device.
    """
    # Same device as before on a CUDA machine; CPU fallback otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # transform to torch.tensor, shape (samples, variables)
    AggregateData = [torch.tensor(ele,dtype=torch.float32,requires_grad = True,device=device).t() for ele in AggregateData]

    # Named n_times rather than D so the torch.distributions alias is not shadowed.
    n_times = len(AggregateData)
    print("{0} time points have been observed".format(n_times))

    n_sims = AggregateData[0].shape[0]
    in_features = AggregateData[0].shape[1]

    # random seed
    setup_seed(Seed)

    # create generator and one critic per observed time point after the first
    netG = JumpEulerForwardCuda(in_features,num_hidden,dim_hidden,step_size,intensity,bd).to(device)
    netD = [MLP(in_features,1,dim_hidden,num_hidden).to(device) for _ in range(n_times-1)]

    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizerSD = [optim.Adam(netD_i.parameters(), lr=lr, betas=(0.5, 0.999)) for netD_i in netD]

    ### training process
    for epoch in range(n_epochs):

        # -------------------
        # train the critics
        # -------------------
        for _ in range(n_critic):
            # The generator is frozen during critic updates, so its rollout is
            # simulated without an autograd graph. The critics' gradients are
            # unchanged, and no backward pass has to traverse (or retain) the
            # generator graph.
            with torch.no_grad():
                fake_data = netG(AggregateData[0],n_sims,n_steps[-1])

            for SDi in range(n_times-1):
                # Fresh graph-free tensors; compute_gradient_penalty switches
                # requires_grad on for them in place.
                real_batch = AggregateData[SDi+1].detach()
                fake_batch = fake_data[:,n_steps[SDi],:].clone()

                optimizerSD[SDi].zero_grad()
                div_gp = compute_gradient_penalty(netD[SDi],real_batch,fake_batch)
                d_loss = -torch.mean(netD[SDi](real_batch))+torch.mean(netD[SDi](fake_batch))+div_gp
                d_loss.backward()

                optimizerSD[SDi].step()

        # ----------------
        # train the generator
        # ----------------
        optimizerG.zero_grad()

        fake_data = netG(AggregateData[0],n_sims,n_steps[-1])
        fake = [fake_data[:,ele,:] for ele in n_steps]
        g_loss = sum(-torch.mean(netD[el](fake[el])) for el in range(n_times-1))
        g_loss.backward()

        optimizerG.step()

        # Progress report: Sinkhorn loss between the generator-step samples
        # and the observations at every fitted time point.
        if epoch %10==0:
            error = [a(fake[ii],AggregateData[ii+1]).item() for ii in range(n_times-1)]
            print("epoch: ",epoch,"training: ",error)

    return netG,AggregateData

# An example: the stem cell differentiation dataset with two genes Tagln and Gsn

## data preprocessing

In [4]:
## The data processing steps for the stem cell dataset are the same as those in https://github.com/thashim/population-diffusions.
FilePath = './'

file_list = [
    'GSM1599494_ES_d0_main.csv',
    'GSM1599497_ES_d2_LIFminus.csv',
    'GSM1599498_ES_d4_LIFminus.csv',
    'GSM1599499_ES_d7_LIFminus.csv',
]

# One header-less table per file, in the order listed above.
table_list = [pd.read_csv(FilePath+filein, header=None) for filein in file_list]

# The first column of the first table supplies the gene names; every remaining
# column of each table is converted to float32.
gene_names = table_list[0].values[:,0]
matrix_list = [table.values[:,1:].astype('float32') for table in table_list]

# Number of columns of each matrix.
cell_counts = [matrix.shape[1] for matrix in matrix_list]

# Normalization
def normalize_run(mat):
    """Column-wise normalisation followed by a log(1 + x) transform.

    Each column is divided by its total in millions and re-weighted by
    ``median(zero fraction) / (zero fraction of that column)``, where the zero
    fraction is the share of exactly-zero entries in the column.

    Args:
        mat: 2-D numpy array; normalisation is applied per column.
            (Presumably genes x cells, given how the notebook loads the data.)

    Returns:
        Array of the same shape as ``mat``.

    Note:
        A column without any zero entry has zero fraction 0 and therefore
        produces a division by zero in the weight.
    """
    n_rows = float(mat.shape[0])
    column_total_millions = mat.sum(axis=0)/1e6
    zero_fraction = (mat==0).sum(axis=0)/n_rows
    column_weight = np.median(zero_fraction)/zero_fraction
    rescaled = mat*column_weight*1.0/column_total_millions
    return np.log(rescaled + 1.0)

# Normalise every time point separately.
norm_mat = list(map(normalize_run, matrix_list))

# Rank genes by the squared gap between their percentile curves at D0 and D7
# (a quantile-based stand-in for the Wasserstein distance), largest gap first.
quantile_levels = np.linspace(0,100,50)
qt_mat = [np.percentile(norm_in, q=quantile_levels, axis=1) for norm_in in norm_mat]
wdiv = np.square(qt_mat[0]-qt_mat[3]).sum(axis=0)
w_order = np.argsort(-wdiv)

# Keep the 100 top-ranked genes.
wsub = w_order[:100]

def nmf(X, latent_features, max_iter=100, error_limit=1e-6, fit_error_limit=1e-6, print_iter=200):
    """
    Masked non-negative factorisation X ~ A.dot(Y) by multiplicative updates.

    Entries where ``np.sign(X)`` is 0 (the zeros of X) are masked out of the
    fit. A is initialised uniformly at random (NumPy global generator) and Y
    by least squares; both are kept at or above ``eps = 1e-5`` throughout.

    Progress is printed, and the stopping criteria are evaluated, only at
    iteration 1, at every ``print_iter``-th iteration and at ``max_iter``.
    The "fit residual" compares the current estimate with the one from the
    previous report; the "total residual" is the masked Frobenius error.

    Args:
        X: 2-D array to decompose.
        latent_features: inner dimension of the factorisation.
        max_iter: maximum number of update sweeps.
        error_limit: stop once the total residual falls below this value.
        fit_error_limit: stop once the fit residual falls below this value.
        print_iter: reporting / convergence-check interval.

    Returns:
        ``(A, Y, A.dot(Y))``.
    """
    eps = 1e-5
    print('Starting NMF decomposition with {} latent features and {} iterations.'.format(latent_features, max_iter))

    # mask: sign of X, so zero entries drop out of every masked product
    mask = np.sign(X)

    # initial factors: A random in [0, 1), Y the least-squares solution of A Y = X
    n_rows, _ = X.shape
    A = np.maximum(np.random.rand(n_rows, latent_features), eps)
    Y = np.maximum(linalg.lstsq(A, X)[0], eps)

    masked_X = mask * X
    X_est_prev = A.dot(Y)

    for i in range(1, max_iter + 1):
        # ----- update A -----
        # Matlab: A=A.*(((W.*X)*Y')./((W.*(A*Y))*Y'));
        numerator = masked_X.dot(Y.T)
        denominator = (mask * A.dot(Y)).dot(Y.T) + eps
        A = np.maximum(A * (numerator / denominator), eps)

        # ----- update Y -----
        # Matlab: Y=Y.*((A'*(W.*X))./(A'*(W.*(A*Y))));
        numerator = A.T.dot(masked_X)
        denominator = A.T.dot(mask * A.dot(Y)) + eps
        Y = np.maximum(Y * (numerator / denominator), eps)

        # ----- reporting and convergence check -----
        is_report_step = i == 1 or i == max_iter or i % print_iter == 0
        if not is_report_step:
            continue

        print('Iteration {}:'.format(i))
        X_est = A.dot(Y)
        change = mask * (X_est_prev - X_est)
        fit_residual = np.sqrt(np.sum(change ** 2))
        X_est_prev = X_est

        curRes = linalg.norm(mask * (X - X_est), ord='fro')
        print('fit residual', np.round(fit_residual, 4))
        print('total residual', np.round(curRes, 4))
        if curRes < error_limit or fit_residual < fit_error_limit:
            break

    return A, Y, A.dot(Y)

# Fix NumPy's seed so the random NMF initialisation is repeatable.
np.random.seed(0)
n_latent = len(wsub)*4
norm_imputed = [nmf(normin[wsub,:], latent_features = n_latent, max_iter=500)[2] for normin in norm_mat]

## choose the genes at (0-based) positions 3 and 7 of the ranking as the training variables
subvec = np.array([3,7])
gnvec = gene_names[w_order[subvec]]

# Every time point is centred on the per-gene mean of the last (D7) matrix and
# divided by the per-gene standard deviation of the two selected genes at D7.
last_imputed = norm_imputed[3]
norm_adj = last_imputed.mean(axis=1)[:,np.newaxis]

cov_mat = np.cov(last_imputed[subvec,:])
gene_variances = np.diag(cov_mat)
whiten = np.diag(gene_variances**(-0.5))
unwhiten = np.diag(gene_variances**(0.5))

norm_imputed2 = [whiten.dot((normin - norm_adj)[subvec,:]) for normin in norm_imputed]

train_data = norm_imputed2

Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 4000.1233
total residual 138.7959
Iteration 200:
fit residual 123.6643
total residual 31.8464
Iteration 400:
fit residual 21.7894
total residual 12.6115
Iteration 500:
fit residual 4.3442
total residual 8.6686
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2257.8358
total residual 58.2265
Iteration 200:
fit residual 56.8198
total residual 4.0447
Iteration 400:
fit residual 3.4777
total residual 0.6841
Iteration 500:
fit residual 0.3744
total residual 0.3185
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2925.3664
total residual 80.2657
Iteration 200:
fit residual 78.1148
total residual 6.1559
Iteration 400:
fit residual 5.1661
total residual 1.1917
Iteration 500:
fit residual 0.6183
total residual 0.5917
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
f

## Training

In [19]:
# Observed time points after the first are matched to Euler steps 10, 20 and 35.
netG, AggregateData = Train(
    train_data,
    n_steps=[10, 20, 35],
    bd=2,
    intensity=10,
    num_hidden=4,
    dim_hidden=256,
    step_size=0.03,
    n_epochs=20000,
    n_critic=3,
    lr=0.0003,
    Seed=200,
)

4 time points have been observed
933 2 tensor([-2.3410, -3.7850], device='cuda:0', grad_fn=<SelectBackward>)
epoch:  0 training:  [1.32451593875885, 4.614095687866211, 7.783506870269775]
epoch:  10 training:  [0.23683612048625946, 0.33031779527664185, 0.04719405621290207]
epoch:  20 training:  [0.16710378229618073, 0.24282582104206085, 0.07422839105129242]
epoch:  30 training:  [0.2877744436264038, 0.1751551330089569, 0.048849303275346756]
epoch:  40 training:  [0.29298001527786255, 0.08472757786512375, 0.2747907042503357]
epoch:  50 training:  [0.15940210223197937, 0.3763771653175354, 0.09240724891424179]
epoch:  60 training:  [0.18050657212734222, 0.15011844038963318, 0.055368080735206604]
epoch:  70 training:  [0.21605640649795532, 0.10083488374948502, 0.10211674869060516]
epoch:  80 training:  [0.17942486703395844, 0.12052836269140244, 0.00978100299835205]
epoch:  90 training:  [0.13572803139686584, 0.3425350487232208, 0.11001603305339813]
epoch:  100 training:  [0.2304874211549759

epoch:  930 training:  [0.032793182879686356, 0.014565370976924896, 0.05449560284614563]
epoch:  940 training:  [0.02832268737256527, 0.009817056357860565, 0.020097212865948677]
epoch:  950 training:  [0.03642052412033081, 0.011674363166093826, 0.046858105808496475]
epoch:  960 training:  [0.04481282830238342, -0.015736058354377747, 0.049625590443611145]
epoch:  970 training:  [0.05344946309924126, 0.07154548913240433, 0.08976475149393082]
epoch:  980 training:  [0.02761252596974373, 0.040257152169942856, 0.06497224420309067]
epoch:  990 training:  [0.05715734511613846, 0.010157670825719833, 0.04515911266207695]
epoch:  1000 training:  [0.030996818095445633, 0.0065995994955301285, 0.06136413663625717]
epoch:  1010 training:  [0.03241056948900223, 0.012193076312541962, 0.048627886921167374]
epoch:  1020 training:  [0.060212258249521255, 0.0667809545993805, 0.08462502807378769]
epoch:  1030 training:  [0.04181305319070816, 0.010978959500789642, 0.06522159278392792]
epoch:  1040 training:

epoch:  1850 training:  [0.006973080337047577, 0.017855454236268997, 0.0559210367500782]
epoch:  1860 training:  [0.0065390244126319885, 0.06795811653137207, 0.12372131645679474]
epoch:  1870 training:  [0.016286468133330345, 0.026805385947227478, 0.044765230268239975]
epoch:  1880 training:  [0.010330302640795708, 0.008847996592521667, 0.03628695011138916]
epoch:  1890 training:  [0.01432364247739315, 0.054821498692035675, 0.06492931395769119]
epoch:  1900 training:  [0.02367084100842476, 0.10729160904884338, 0.0568460188806057]
epoch:  1910 training:  [0.01862134411931038, 0.020668577402830124, 0.015585611574351788]
epoch:  1920 training:  [0.011328469961881638, 0.0061258673667907715, 0.034870557487010956]
epoch:  1930 training:  [0.008364027366042137, 0.01092321053147316, 0.07346336543560028]
epoch:  1940 training:  [0.014184268191456795, 0.005559124052524567, 0.018604308366775513]
epoch:  1950 training:  [0.006963448598980904, 0.011600784957408905, 0.028196323662996292]
epoch:  196

epoch:  2760 training:  [0.012330565601587296, 0.015370678156614304, 0.03830689936876297]
epoch:  2770 training:  [0.012301363050937653, 0.0076060667634010315, 0.03996308147907257]
epoch:  2780 training:  [0.021229347214102745, 0.025635823607444763, 0.022101737558841705]
epoch:  2790 training:  [0.013555431738495827, -0.0011057443916797638, 0.05951905995607376]
epoch:  2800 training:  [0.013531045988202095, -0.007452338933944702, 0.019014030694961548]
epoch:  2810 training:  [0.010328955948352814, 0.005980063229799271, 0.01847117394208908]
epoch:  2820 training:  [0.008055157959461212, 0.0016368776559829712, 0.016589198261499405]
epoch:  2830 training:  [0.014667998999357224, 0.026677973568439484, 0.04487504065036774]
epoch:  2840 training:  [0.01106935366988182, 0.016098812222480774, 0.07239707559347153]
epoch:  2850 training:  [0.006292965263128281, 0.04233580455183983, 0.02791139855980873]
epoch:  2860 training:  [0.011967789381742477, 0.01890932023525238, 0.02968434803187847]
epoch

epoch:  3670 training:  [0.005728365853428841, 0.0004261434078216553, 0.009812891483306885]
epoch:  3680 training:  [0.0023781675845384598, -0.003974303603172302, 0.027809370309114456]
epoch:  3690 training:  [0.005373004823923111, 0.0075819529592990875, 0.027623549103736877]
epoch:  3700 training:  [0.005721655674278736, 0.007194120436906815, 0.04508807510137558]
epoch:  3710 training:  [0.004670598544180393, 0.009105071425437927, 0.04840580374002457]
epoch:  3720 training:  [0.0034453272819519043, 0.0009968318045139313, 0.012265969067811966]
epoch:  3730 training:  [0.013204891234636307, 0.0005998127162456512, 0.0016507506370544434]
epoch:  3740 training:  [0.0033029038459062576, 0.004992540925741196, 0.0486859455704689]
epoch:  3750 training:  [-0.0018662279471755028, 0.0003766566514968872, 0.023477332666516304]
epoch:  3760 training:  [0.012055226601660252, 0.031221045181155205, 0.04067142307758331]
epoch:  3770 training:  [0.02309131622314453, 0.029926005750894547, 0.0594025328755

epoch:  4580 training:  [0.01884627714753151, 0.037878476083278656, 0.08367260545492172]
epoch:  4590 training:  [0.009058183059096336, 0.023271046578884125, 0.06722207367420197]
epoch:  4600 training:  [0.009389450773596764, 0.0031278282403945923, 0.012172643095254898]
epoch:  4610 training:  [0.020591584965586662, 0.018954038619995117, 0.06039438396692276]
epoch:  4620 training:  [0.00963791273534298, 0.004891064018011093, 0.029116716235876083]
epoch:  4630 training:  [0.004677146673202515, 0.00936858355998993, 0.037444960325956345]
epoch:  4640 training:  [0.007834233343601227, 0.016708742827177048, 0.017455147579312325]
epoch:  4650 training:  [0.011141810566186905, 0.027693919837474823, 0.06588390469551086]
epoch:  4660 training:  [0.008896180428564548, 0.012890055775642395, 0.021397965028882027]
epoch:  4670 training:  [0.003028072416782379, -0.004061639308929443, 0.004449591040611267]
epoch:  4680 training:  [0.009375471621751785, 0.004770427942276001, 0.02318500727415085]
epoch

epoch:  5490 training:  [0.004976317286491394, 0.0034876689314842224, 0.009755147621035576]
epoch:  5500 training:  [0.005745317786931992, 0.0037778951227664948, 0.015789654105901718]
epoch:  5510 training:  [0.010787718929350376, 0.029218921437859535, 0.04383781924843788]
epoch:  5520 training:  [0.008355516009032726, 0.007677685469388962, 0.011583656072616577]
epoch:  5530 training:  [0.004819036461412907, -0.005839407444000244, 0.01180979609489441]
epoch:  5540 training:  [0.006873925216495991, 0.011250555515289307, 0.015317827463150024]
epoch:  5550 training:  [0.0071716015227139, -0.0051339007914066315, 0.00647938996553421]
epoch:  5560 training:  [0.005893852561712265, 0.008803647011518478, 0.05554334819316864]
epoch:  5570 training:  [0.016029052436351776, 0.0194280706346035, 0.08222019672393799]
epoch:  5580 training:  [0.012844890356063843, 0.004819206893444061, 0.017067084088921547]
epoch:  5590 training:  [0.009352274239063263, 0.016955099999904633, 0.03917904943227768]
epoc

epoch:  6400 training:  [0.005945183336734772, -0.00186823308467865, 0.008921835571527481]
epoch:  6410 training:  [0.007030405104160309, -0.004828270524740219, 0.02091626077890396]
epoch:  6420 training:  [0.0022013895213603973, -0.0020813755691051483, 0.028050530701875687]
epoch:  6430 training:  [0.009535808116197586, 0.0009705610573291779, 0.005396362394094467]
epoch:  6440 training:  [0.004389575682580471, -0.001167062669992447, 0.006407037377357483]
epoch:  6450 training:  [-7.750256918370724e-05, 0.011410720646381378, 0.017075348645448685]
epoch:  6460 training:  [0.004770344123244286, -0.00036294758319854736, 0.014822715893387794]
epoch:  6470 training:  [0.00862930715084076, -0.008254598826169968, 0.011670105159282684]
epoch:  6480 training:  [0.006848193239420652, 0.0012676939368247986, 0.009342636913061142]
epoch:  6490 training:  [0.0063864691182971, -0.0077655985951423645, 0.012020917609333992]
epoch:  6500 training:  [0.0057397219352424145, -0.008248087018728256, 0.013081

epoch:  7310 training:  [0.005890618544071913, 0.010082900524139404, 0.020968126133084297]
epoch:  7320 training:  [0.0063730692490935326, -0.001373153179883957, 0.013765616342425346]
epoch:  7330 training:  [0.007942816242575645, 0.005243424326181412, 0.02190428227186203]
epoch:  7340 training:  [0.007212920114398003, -0.012301869690418243, 0.009088100865483284]
epoch:  7350 training:  [0.007604809477925301, 0.007089890539646149, 0.02595124952495098]
epoch:  7360 training:  [0.003246057778596878, -0.0015097446739673615, 0.04823035001754761]
epoch:  7370 training:  [0.0023728152737021446, -4.4930726289749146e-05, 0.00698448158800602]
epoch:  7380 training:  [0.010975482873618603, 0.02673671394586563, 0.050693221390247345]
epoch:  7390 training:  [0.008527224883437157, 0.004163149744272232, 0.01139771193265915]
epoch:  7400 training:  [0.030942775309085846, 0.10154944658279419, 0.05478862673044205]
epoch:  7410 training:  [0.01072391215711832, 0.011706776916980743, 0.028104711323976517]

epoch:  8210 training:  [-5.7213008403778076e-05, 0.0095624178647995, 0.014674107544124126]
epoch:  8220 training:  [0.0036231400445103645, 0.0068156179040670395, 0.015498293563723564]
epoch:  8230 training:  [0.0014820594806224108, -0.002166435122489929, 0.014201968908309937]
epoch:  8240 training:  [0.007108610589057207, -0.0008702129125595093, -0.002384886145591736]
epoch:  8250 training:  [0.00949124339967966, -0.0008346512913703918, 0.015209324657917023]
epoch:  8260 training:  [0.0035068076103925705, -0.004930291324853897, 0.016914110630750656]
epoch:  8270 training:  [0.005655445158481598, 0.00029375962913036346, 0.01863902062177658]
epoch:  8280 training:  [0.0031988490372896194, 0.0008484832942485809, 0.01472623273730278]
epoch:  8290 training:  [0.010725751519203186, 0.013959303498268127, 0.01077151857316494]
epoch:  8300 training:  [0.006763029843568802, 0.011217277497053146, 0.023688968271017075]
epoch:  8310 training:  [-0.0012415405362844467, 0.0070552825927734375, 0.0383

epoch:  9110 training:  [0.004104383289813995, -0.003806479275226593, 0.02808769792318344]
epoch:  9120 training:  [0.007639472372829914, 0.0014281868934631348, 0.06251034885644913]
epoch:  9130 training:  [-0.0023503918200731277, 0.0012405291199684143, 0.030529240146279335]
epoch:  9140 training:  [0.0009766267612576485, 0.0029128268361091614, 0.07644138485193253]
epoch:  9150 training:  [-0.0013024769723415375, -0.004170779138803482, -0.019804365932941437]
epoch:  9160 training:  [0.003377344459295273, 0.003838673233985901, 0.018161065876483917]
epoch:  9170 training:  [0.0025435006245970726, 0.007150474935770035, 0.03852391242980957]
epoch:  9180 training:  [0.005924563389271498, 0.012373041361570358, 0.01705927401781082]
epoch:  9190 training:  [0.0012844651937484741, 0.00935000367462635, 0.03619740158319473]
epoch:  9200 training:  [0.005126042291522026, -0.00740446150302887, 0.023850662633776665]
epoch:  9210 training:  [0.0046427687630057335, 0.0010163169354200363, 0.01326697692

epoch:  10010 training:  [0.004943943582475185, -0.01453610509634018, 0.008921608328819275]
epoch:  10020 training:  [0.003387789474800229, 0.006080232560634613, 0.032891325652599335]
epoch:  10030 training:  [0.010408295318484306, 0.02489619143307209, 0.024177556857466698]
epoch:  10040 training:  [0.006859591230750084, 0.00473962165415287, 0.013626154512166977]
epoch:  10050 training:  [0.006114325486123562, 0.0022168196737766266, 0.025240249931812286]
epoch:  10060 training:  [0.004251587204635143, 0.005317842587828636, 0.009555086493492126]
epoch:  10070 training:  [-0.0023012589663267136, -0.0021792612969875336, -0.006239354610443115]
epoch:  10080 training:  [0.0035457946360111237, 0.0051682740449905396, 0.04014141485095024]
epoch:  10090 training:  [0.005386459641158581, 0.014940984547138214, 0.05282285809516907]
epoch:  10100 training:  [0.002648906782269478, 0.0033187195658683777, 0.0245568435639143]
epoch:  10110 training:  [0.012286238372325897, 0.00993461161851883, 0.023166

epoch:  10910 training:  [0.008913140743970871, 0.005869373679161072, 0.03266327828168869]
epoch:  10920 training:  [0.00884589459747076, 0.0009146071970462799, 0.014866746962070465]
epoch:  10930 training:  [0.007093079388141632, 0.00646982342004776, 0.025669436901807785]
epoch:  10940 training:  [0.007550298236310482, 0.005299631506204605, 0.008012592792510986]
epoch:  10950 training:  [0.00020786747336387634, 0.008902417495846748, 0.022154677659273148]
epoch:  10960 training:  [0.011489310301840305, 0.03379450738430023, 0.02605423331260681]
epoch:  10970 training:  [0.010475628077983856, -0.0011925771832466125, 0.017074119299650192]
epoch:  10980 training:  [0.024229807779192924, 0.01754888892173767, 0.019201116636395454]
epoch:  10990 training:  [-0.0021824948489665985, -0.0012876428663730621, 0.020407911390066147]
epoch:  11000 training:  [0.004882393404841423, 0.002987578511238098, 0.02474183589220047]
epoch:  11010 training:  [-0.007008334621787071, -0.013164274394512177, 0.0133

epoch:  11800 training:  [0.0113845095038414, 0.01539980061352253, 0.054301917552948]
epoch:  11810 training:  [0.00890866108238697, 0.006008755415678024, -0.001812577247619629]
epoch:  11820 training:  [0.0042508989572525024, 0.015345979481935501, 0.002497635781764984]
epoch:  11830 training:  [0.0024270564317703247, 0.007851876318454742, 0.019731596112251282]
epoch:  11840 training:  [0.008614814840257168, 0.002501577138900757, 0.03717944025993347]
epoch:  11850 training:  [0.0031334715895354748, 0.016663599759340286, 0.022629983723163605]
epoch:  11860 training:  [0.005677295848727226, -0.006702356040477753, 0.008465401828289032]
epoch:  11870 training:  [0.015444694086909294, 0.010643616318702698, 0.04152859002351761]
epoch:  11880 training:  [0.002886701375246048, 0.0179428793489933, 0.04789550602436066]
epoch:  11890 training:  [0.010538242757320404, 0.01727796159684658, 0.024820469319820404]
epoch:  11900 training:  [0.0023515173234045506, 0.0038644541054964066, -0.0017904266715

epoch:  12700 training:  [0.005290934816002846, 0.005692293867468834, 0.03458026051521301]
epoch:  12710 training:  [0.007454621605575085, -0.007509306073188782, 0.010178787633776665]
epoch:  12720 training:  [0.0063688792288303375, 0.0007297024130821228, 0.01584881916642189]
epoch:  12730 training:  [0.009522100910544395, 0.00939573347568512, 0.017286548390984535]
epoch:  12740 training:  [0.011039325967431068, 0.005103558301925659, 0.01820765808224678]
epoch:  12750 training:  [0.004019483923912048, 0.0011397898197174072, 0.007327606435865164]
epoch:  12760 training:  [0.006758246570825577, -0.042547039687633514, 0.00356413796544075]
epoch:  12770 training:  [0.014507599174976349, 0.004688907414674759, 0.02227954752743244]
epoch:  12780 training:  [0.012561287730932236, 0.01409432664513588, 0.010174544528126717]
epoch:  12790 training:  [0.006397917866706848, 0.00504109263420105, 0.00722329318523407]
epoch:  12800 training:  [0.006390244700014591, 0.005491364747285843, 0.012098489329

epoch:  13600 training:  [0.014230367727577686, 0.006239522248506546, 0.0394924134016037]
epoch:  13610 training:  [0.01075243204832077, -0.004063326865434647, 0.021727412939071655]
epoch:  13620 training:  [0.12177062779664993, 0.27099186182022095, 0.4705240726470947]
epoch:  13630 training:  [0.3510120213031769, 0.48148345947265625, 0.6868048906326294]
epoch:  13640 training:  [0.5357881784439087, 0.385509192943573, 0.28769880533218384]
epoch:  13650 training:  [0.3544826805591583, 0.18116174638271332, 0.33813387155532837]
epoch:  13660 training:  [0.5671097040176392, 0.3608956038951874, 0.2500133514404297]
epoch:  13670 training:  [0.5707197785377502, 0.3535971939563751, 0.40471988916397095]
epoch:  13680 training:  [0.48613566160202026, 0.3008081912994385, 0.31364864110946655]
epoch:  13690 training:  [0.41467320919036865, 0.2716352343559265, 0.30276739597320557]
epoch:  13700 training:  [0.2922760248184204, 0.24385195970535278, 0.20511290431022644]
epoch:  13710 training:  [0.5066

epoch:  14540 training:  [0.1071319729089737, 0.08432646840810776, 0.24572056531906128]
epoch:  14550 training:  [0.16429737210273743, 0.03433641791343689, 0.24839527904987335]
epoch:  14560 training:  [0.12415195256471634, 0.039758265018463135, 0.23262682557106018]
epoch:  14570 training:  [0.1285201907157898, 0.03608466684818268, 0.22275635600090027]
epoch:  14580 training:  [0.14039620757102966, 0.123603455722332, 0.21664366126060486]
epoch:  14590 training:  [0.14630207419395447, 0.03625396639108658, 0.2198704183101654]
epoch:  14600 training:  [0.11555606126785278, 0.06893341988325119, 0.213828444480896]
epoch:  14610 training:  [0.13364142179489136, 0.07481120526790619, 0.22696927189826965]
epoch:  14620 training:  [0.12627995014190674, 0.05328689143061638, 0.250166654586792]
epoch:  14630 training:  [0.10851223766803741, 0.047565121203660965, 0.23509031534194946]
epoch:  14640 training:  [0.10299244523048401, 0.06696799397468567, 0.2298501431941986]
epoch:  14650 training:  [0.1

epoch:  15470 training:  [0.3747072219848633, 0.27142179012298584, 0.6294237971305847]
epoch:  15480 training:  [0.07851645350456238, 0.2981310188770294, 0.12660107016563416]
epoch:  15490 training:  [0.19913369417190552, 0.7211155891418457, 0.48209524154663086]
epoch:  15500 training:  [0.125849187374115, 0.40391814708709717, 0.24296411871910095]
epoch:  15510 training:  [0.15380559861660004, 0.17943650484085083, 0.046559520065784454]
epoch:  15520 training:  [0.13657821714878082, 0.1330501288175583, 0.0023761987686157227]
epoch:  15530 training:  [0.11180196702480316, 0.12635289132595062, 0.024290598928928375]
epoch:  15540 training:  [0.13134092092514038, 0.18467310070991516, 0.05760291963815689]
epoch:  15550 training:  [0.13113103806972504, 0.1243838518857956, 0.0317947119474411]
epoch:  15560 training:  [0.141945943236351, 0.10727758705615997, 0.04354492947459221]
epoch:  15570 training:  [0.10520128905773163, 0.1268424242734909, 0.020882144570350647]
epoch:  15580 training:  [0.

epoch:  16400 training:  [0.05227595567703247, 0.12371111661195755, 0.02784419059753418]
epoch:  16410 training:  [0.0556604340672493, 0.11212997883558273, 0.018577940762043]
epoch:  16420 training:  [0.06203350052237511, 0.12892432510852814, 0.02492784708738327]
epoch:  16430 training:  [0.04991636425256729, 0.12841089069843292, 0.017668467015028]
epoch:  16440 training:  [0.04786491394042969, 0.11713629961013794, 0.01340075209736824]
epoch:  16450 training:  [0.05634023994207382, 0.11666470021009445, 0.014608338475227356]
epoch:  16460 training:  [0.0489199161529541, 0.11915792524814606, 0.01196671649813652]
epoch:  16470 training:  [0.0411372184753418, 0.12501884996891022, 0.02397039532661438]
epoch:  16480 training:  [0.05294078588485718, 0.11968865990638733, 0.02130740135908127]
epoch:  16490 training:  [0.04950297623872757, 0.13584929704666138, 0.019918549805879593]
epoch:  16500 training:  [0.047268666326999664, 0.1427176594734192, 0.01764826476573944]
epoch:  16510 training:  [

epoch:  17320 training:  [0.031131960451602936, 0.05191957950592041, 0.025376828387379646]
epoch:  17330 training:  [0.043612681329250336, 0.05892264097929001, 0.02028524875640869]
epoch:  17340 training:  [0.04671599715948105, 0.05712028592824936, 0.023989010602235794]
epoch:  17350 training:  [0.04487836733460426, 0.06060806289315224, 0.03769192099571228]
epoch:  17360 training:  [0.04848135635256767, 0.055820293724536896, 0.028824618086218834]
epoch:  17370 training:  [0.053331080824136734, 0.05496525019407272, 0.02215355634689331]
epoch:  17380 training:  [0.05557442456483841, 0.04874057695269585, 0.013872720301151276]
epoch:  17390 training:  [0.0492052361369133, 0.045500706881284714, 0.030770648270845413]
epoch:  17400 training:  [0.04558064788579941, 0.045476481318473816, 0.014075696468353271]
epoch:  17410 training:  [0.040606673806905746, 0.0607142448425293, 0.018927760422229767]
epoch:  17420 training:  [0.054520804435014725, 0.05661356821656227, 0.021604709327220917]
epoch: 

epoch:  18230 training:  [0.05307544767856598, 0.02662789449095726, 0.018292594701051712]
epoch:  18240 training:  [0.05781986936926842, 0.03712949901819229, 0.024060387164354324]
epoch:  18250 training:  [0.07370837032794952, 0.03656449541449547, 0.016907939687371254]
epoch:  18260 training:  [0.06343390792608261, 0.031839288771152496, 0.020905226469039917]
epoch:  18270 training:  [0.053290169686079025, 0.03820333629846573, 0.01974428817629814]
epoch:  18280 training:  [0.05604754388332367, 0.03882746025919914, 0.019379211589694023]
epoch:  18290 training:  [0.06388583779335022, 0.025721659883856773, 0.017061306163668633]
epoch:  18300 training:  [0.05861509591341019, 0.03855788707733154, 0.02465720847249031]
epoch:  18310 training:  [0.06045969948172569, 0.017531704157590866, 0.025806263089179993]
epoch:  18320 training:  [0.05900038778781891, 0.03574215620756149, 0.012416115030646324]
epoch:  18330 training:  [0.05889410525560379, 0.0242107305675745, 0.041191790252923965]
epoch:  1

epoch:  19140 training:  [0.055627837777137756, 0.05345918983221054, 0.025074534118175507]
epoch:  19150 training:  [0.07422847300767899, 0.032869089394807816, 0.035968273878097534]
epoch:  19160 training:  [0.0671294778585434, 0.04063253104686737, 0.027992650866508484]
epoch:  19170 training:  [0.05633245408535004, 0.061025671660900116, 0.05207977071404457]
epoch:  19180 training:  [0.043544963002204895, 0.04482750967144966, 0.016855081543326378]
epoch:  19190 training:  [0.04638803377747536, 0.04945966601371765, 0.016691338270902634]
epoch:  19200 training:  [0.0679575502872467, 0.027004607021808624, 0.03843115642666817]
epoch:  19210 training:  [0.0686725378036499, 0.02867666259407997, 0.021459383890032768]
epoch:  19220 training:  [0.05946354940533638, 0.04282984137535095, 0.01717453822493553]
epoch:  19230 training:  [0.055176664143800735, 0.034974370151758194, 0.024474019184708595]
epoch:  19240 training:  [0.08459188044071198, 0.025400858372449875, 0.06530033051967621]
epoch:  1

## Visualization

In [21]:
import scipy
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 18})

setup_seed(435)

# Put the observations on the device the trained generator lives on instead of
# assuming "cuda".
eval_device = next(netG.parameters()).device
AggregateData = [torch.tensor(ele,dtype=torch.float32,requires_grad = True,device=eval_device).t() for ele in train_data]

# Euler step indices compared with the 2nd, 3rd and 4th observed time points
# (the same values passed as n_steps in the training call above).
obs_steps = [10, 20, 35]

# Pure evaluation: simulate once without recording an autograd graph.
with torch.no_grad():
    path = netG(AggregateData[0],AggregateData[0].shape[0],obs_steps[-1])

G2, G4, G7 = (path[:,step,:] for step in obs_steps)


####### The Sinkhorn distance between predicted and observed distribution at D2,D4 and D7
print(a(G2,AggregateData[1]),a(G4,AggregateData[2]),a(G7,AggregateData[3]))

# Observed samples as numpy arrays.
T0, T2, T4, T7 = (obs.detach().cpu().numpy() for obs in AggregateData)

# Simulated samples as numpy arrays (the names are rebound from tensors).
G2 = G2.detach().cpu().numpy()
G4 = G4.detach().cpu().numpy()
G7 = G7.detach().cpu().numpy()

tensor(0.0570, device='cuda:0', grad_fn=<SelectBackward>) tensor(0.0148, device='cuda:0', grad_fn=<SelectBackward>) tensor(0.0217, device='cuda:0', grad_fn=<SelectBackward>)
