In [1]:
# -*- coding: utf-8 -*-

import sys,os
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy import linalg
from numpy import dot
import geomloss as gs

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as D
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.autograd import grad
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.nn.modules import Linear
from torch.autograd.functional import jacobian,hessian,vjp,vhp,hvp

import random
import math

In [8]:
## functions preparation

def setup_seed(seed):
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.benchmark=False
     torch.backends.cudnn.deterministic = True


class MLP(nn.Module):

    def __init__(self, dim_in, dim_out, dim_hidden=64, num_hidden=0, activation=nn.Tanh()):
        super(MLP, self).__init__()

        if num_hidden == 0:
            self.linears = nn.ModuleList([nn.Linear(dim_in, dim_out)])
        elif num_hidden >= 1:
            self.linears = nn.ModuleList() 
            self.linears.append(nn.Linear(dim_in, dim_hidden))
            self.linears.extend([nn.Linear(dim_hidden, dim_hidden) for _ in range(num_hidden-1)])
            self.linears.append(nn.Linear(dim_hidden, dim_out))
        else:
            raise Exception('number of hidden layers must be positive')

        for m in self.linears:
            nn.init.xavier_normal_(m.weight)
            nn.init.uniform_(m.bias,a=-0.1,b=0.1)
 
        self.activation = activation

    def forward(self, x):
        for m in self.linears[:-1]:
            x = self.activation(m(x))

        return self.linears[-1](x)
    
## WGAN-div term
def compute_gradient_penalty(D, real_sample, fake_sample,k=2,p=6):
    real_samples = real_sample.requires_grad_(True)
    fake_samples = fake_sample.requires_grad_(True)

    real_validity = D(real_samples)
    fake_validity = D(fake_samples)

    real_grad_out = torch.ones((real_samples.shape[0],1),dtype=torch.float32,requires_grad=False,device="cuda")
    real_grad = grad(
        real_validity, real_samples, real_grad_out, create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    real_grad_norm = real_grad.view(real_grad.size(0), -1).pow(2).sum(1) ** (p / 2)

    fake_grad_out = torch.ones((fake_samples.shape[0],1),dtype=torch.float32,requires_grad=False,device="cuda")
    fake_grad = grad(
        fake_validity, fake_samples, fake_grad_out, create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    fake_grad_norm = fake_grad.view(fake_grad.size(0), -1).pow(2).sum(1) ** (p / 2)

    return (torch.sum(real_grad_norm) + torch.sum(fake_grad_norm)) * k / (real_sample.shape[0]+fake_sample.shape[0])

# generator, EM scheme
class JumpEulerForwardCuda(nn.Module):
    def __init__(self,in_features,num_hidden,dim_hidden,step_size,intensity,bd):
        super(JumpEulerForwardCuda,self).__init__()

        self.drift = MLP(in_features,in_features,dim_hidden,num_hidden)
        self.intensity = torch.tensor(intensity,device="cuda")
        self.mean = nn.Parameter(0.01*torch.ones(in_features))
        self.covHalf = nn.Parameter(0.08*torch.eye(in_features))
        self.diffusion = nn.Parameter(torch.ones(bd,in_features))
        self.in_features = in_features
        self.jump = MLP(in_features,in_features,dim_hidden,num_hidden)
        self.step_size = step_size
        self.bd = bd

    def forward(self,z0,Nsim,steps):

        PopulationPath = torch.empty(size = (Nsim,steps+1,self.in_features),device="cuda")
        PopulationPath[:,0,:] = z0
        state = z0

        for i in range(1,steps+1):
            DP = D.poisson.Poisson(self.intensity*self.step_size) 
            pois = DP.sample((Nsim,1)).cuda()
            state = state + self.drift(state)*self.step_size + math.sqrt(self.step_size)*torch.normal(0,1,size=(Nsim,self.bd),device="cuda")@self.diffusion+\
                (pois*self.mean + pois**(0.5)*torch.normal(0,1,size=(Nsim,self.in_features),device="cuda")@self.covHalf)*self.jump(state)
            PopulationPath[:,i,:] = state
        return PopulationPath

## Sinkhorn distance
a = gs.SamplesLoss(loss='sinkhorn',p=2,blur=0.01)

In [12]:
def Train(AggregateData,n_steps,bd,intensity,num_hidden,dim_hidden,step_size=0.05,n_epochs=30000,n_critic=4,lr=0.0001,Seed=80):
    ##########################################################################################################################
    # AggregateData: a list length l, each element is a numpy recording the observation at a time point. row:variables, col:sample size
    # nsteps: a list with length l-1
    # step_size: a numerical value
    ##########################################################################################################################

    # random seed
    setup_seed(Seed)
    
    # transform to torch.tensor
    AggregateData = [torch.tensor(ele,dtype=torch.float32,requires_grad = True,device="cuda").t() for ele in AggregateData]
    
    D = len(AggregateData) ## D is the number of observed time points
    print("{0} time points have been observed".format(D))
    
    n_sims = AggregateData[0].shape[0]
    in_features = AggregateData[0].shape[1]

    # create generator and critic
    netG = JumpEulerForwardCuda(in_features,num_hidden,dim_hidden,step_size,intensity,bd).cuda()
    netD = [MLP(in_features,1,dim_hidden,num_hidden).cuda() for _ in range(D-1)]

    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizerSD = [optim.Adam(netD_i.parameters(), lr=lr, betas=(0.5, 0.999)) for netD_i in netD]


    ### training process
    for epoch in range(n_epochs):


        # -------------------
        # train the critic
        # -------------------


        for _ in range(n_critic): 
            fake_data = netG(AggregateData[0],n_sims,n_steps[-1])
            fake = [fake_data[ele] for ele in n_steps]
            
            for SDi in range(D-1):
                optimizerSD[SDi].zero_grad()
                div_gp = compute_gradient_penalty(netD[SDi],AggregateData[SDi+1],fake[SDi])
                d_loss = -torch.mean(netD[SDi](AggregateData[SDi+1]))+torch.mean(netD[SDi](fake[SDi]))+div_gp
                d_loss.backward(retain_graph=True) # retain_graph=True

                optimizerSD[SDi].step()


        # ----------------
        # train the generator
        # ----------------

        for _ in range(1):
            optimizerG.zero_grad()

            fake_data = netG(AggregateData[0],n_sims,n_steps[-1])
            fake = [fake_data[ele] for ele in n_steps]
            g_loss = [-torch.mean(netD[el](fake[el])) for el in range(D-1)]
            g_loss = sum(g_loss)
            g_loss.backward() 

            optimizerG.step()

        if epoch %10==0:
            error = [a(fake[ii],AggregateData[ii+1]).item() for ii in range(D-1)]
            #print("epoch:",epoch,";", "d1_loss:",(-d1_loss+div_gp1).item(),";","d2_loss:",(-d2_loss+div_gp2).item(),";","g_loss:",g_loss.item())
            print("epoch: ",epoch,"training: ",error)
            
    return netG

# An example: the stem cell differentiation dataset with two genes Tagln and Gsn

## data preprocessing

In [4]:
## The data processing steps for the steam cell dataset is the same as that in https://github.com/thashim/population-diffusions. 
FilePath = './'

file_list = ['GSM1599494_ES_d0_main.csv', 'GSM1599497_ES_d2_LIFminus.csv', 'GSM1599498_ES_d4_LIFminus.csv', 'GSM1599499_ES_d7_LIFminus.csv']

table_list = []
for filein in file_list:
    table_list.append(pd.read_csv(FilePath+filein, header=None))

matrix_list = []
gene_names = table_list[0].values[:,0]
for table in table_list:
    matrix_list.append(table.values[:,1:].astype('float32'))

cell_counts = [matrix.shape[1] for matrix in matrix_list]

# Normalization
def normalize_run(mat):
    rpm = np.sum(mat,0)/1e6
    detect_pr = np.sum(mat==0,0)/float(mat.shape[0])
    return np.log(mat*(np.median(detect_pr)/detect_pr)*1.0/rpm + 1.0)

norm_mat = [normalize_run(matrix) for matrix in matrix_list]

# sort genes by Wasserstein distance between D0 and D7
qt_mat = [np.percentile(norm_in,q=np.linspace(0,100,50),axis=1) for norm_in in norm_mat] 
wdiv=np.sum((qt_mat[0]-qt_mat[3])**2,0)
w_order = np.argsort(-wdiv)

wsub = w_order[0:100]

def nmf(X, latent_features, max_iter=100, error_limit=1e-6, fit_error_limit=1e-6, print_iter=200):
    """
    Decompose X to A*Y
    """
    eps = 1e-5
    print('Starting NMF decomposition with {} latent features and {} iterations.'.format(latent_features, max_iter))
    #X = X.toarray()   I am passing in a scipy sparse matrix

    # mask
    mask = np.sign(X)

    # initial matrices. A is random [0,1] and Y is A\X.
    rows, columns = X.shape
    A = np.random.rand(rows, latent_features) 
    A = np.maximum(A, eps)

    Y = linalg.lstsq(A, X)[0]
    Y = np.maximum(Y, eps)

    masked_X = mask * X
    X_est_prev = dot(A, Y)
    for i in range(1, max_iter + 1):
        # ===== updates =====
        # Matlab: A=A.*(((W.*X)*Y')./((W.*(A*Y))*Y'));
        top = dot(masked_X, Y.T)
        bottom = (dot((mask * dot(A, Y)), Y.T)) + eps
        A *= top / bottom

        A = np.maximum(A, eps)
        # print 'A',  np.round(A, 2)

        # Matlab: Y=Y.*((A'*(W.*X))./(A'*(W.*(A*Y))));
        top = dot(A.T, masked_X)
        bottom = dot(A.T, mask * dot(A, Y)) + eps
        Y *= top / bottom
        Y = np.maximum(Y, eps)
        # print 'Y', np.round(Y, 2)


        # ==== evaluation ====
        if i % print_iter == 0 or i == 1 or i == max_iter:
            print('Iteration {}:'.format(i),)
            X_est = dot(A, Y)
            err = mask * (X_est_prev - X_est)
            fit_residual = np.sqrt(np.sum(err ** 2))
            X_est_prev = X_est

            curRes = linalg.norm(mask * (X - X_est), ord='fro')
            print('fit residual', np.round(fit_residual, 4),)
            print('total residual', np.round(curRes, 4))
            if curRes < error_limit or fit_residual < fit_error_limit:
                break
    return A, Y, dot(A,Y)

np.random.seed(0)
norm_imputed = [nmf(normin[wsub,:], latent_features = len(wsub)*4, max_iter=500)[2] for normin in norm_mat]

## choose the 3rd and 7th genes as the training varibales
norm_adj = np.mean(norm_imputed[3],1)[:,np.newaxis]
subvec = np.array([3,7])

gnvec = gene_names[w_order[subvec]]

cov_mat = np.cov(norm_imputed[3][subvec,:])
whiten = np.diag(np.diag(cov_mat)**(-0.5))
unwhiten = np.diag(np.diag(cov_mat)**(0.5))

norm_imputed2 = [np.dot(whiten,(normin - norm_adj)[subvec,:]) for normin in norm_imputed]

train_data = norm_imputed2

Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 4000.1233
total residual 138.7959
Iteration 200:
fit residual 123.6643
total residual 31.8464
Iteration 400:
fit residual 21.7894
total residual 12.6115
Iteration 500:
fit residual 4.3442
total residual 8.6686
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2257.8358
total residual 58.2265
Iteration 200:
fit residual 56.8198
total residual 4.0447
Iteration 400:
fit residual 3.4777
total residual 0.6841
Iteration 500:
fit residual 0.3744
total residual 0.3185
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2925.3664
total residual 80.2657
Iteration 200:
fit residual 78.1148
total residual 6.1559
Iteration 400:
fit residual 5.1661
total residual 1.1917
Iteration 500:
fit residual 0.6183
total residual 0.5917
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
f

In [13]:
netG = Train(train_data,[10,20,35],2,10,4,256,step_size=0.03,n_epochs=20000,n_critic=3,lr=0.0003,Seed=200)

epoch:  0 training:  [6.6159257888793945, 1.76807701587677, 10.826925277709961]
epoch:  10 training:  [0.3668695092201233, 0.9004901647567749, 1.8247014284133911]
epoch:  20 training:  [2.109130859375, 0.9301775097846985, 1.922403335571289]
epoch:  30 training:  [0.5416212677955627, 0.4822661280632019, 2.4327378273010254]
epoch:  40 training:  [1.9299306869506836, 0.38908737897872925, 1.0038197040557861]
epoch:  50 training:  [0.24575933814048767, 0.3655726909637451, 1.3868253231048584]
epoch:  60 training:  [1.231175184249878, 0.5147644877433777, 1.2681608200073242]
epoch:  70 training:  [1.330683708190918, 0.22446539998054504, 1.4914500713348389]
epoch:  80 training:  [0.42016342282295227, 0.2722369134426117, 1.394891381263733]
epoch:  90 training:  [1.5101299285888672, 0.24985086917877197, 1.2965095043182373]
epoch:  100 training:  [1.619311809539795, 0.3462846875190735, 1.119505524635315]
epoch:  110 training:  [0.4004099369049072, 0.27861979603767395, 1.2140696048736572]
epoch:  1

epoch:  980 training:  [1.1382217407226562, 0.14799968898296356, 1.0838236808776855]
epoch:  990 training:  [0.748303234577179, 0.2805342972278595, 0.9876933097839355]
epoch:  1000 training:  [0.49005451798439026, 0.5145635008811951, 1.3675649166107178]
epoch:  1010 training:  [0.6511986255645752, 0.16599291563034058, 1.1448554992675781]
epoch:  1020 training:  [1.7876673936843872, 0.9409806132316589, 1.2503740787506104]
epoch:  1030 training:  [0.4003203511238098, 0.2076723873615265, 1.1529828310012817]
epoch:  1040 training:  [1.1868730783462524, 0.33310845494270325, 1.3162789344787598]
epoch:  1050 training:  [0.764006495475769, 0.1587028205394745, 1.0145801305770874]
epoch:  1060 training:  [1.3889528512954712, 0.33649206161499023, 0.9869415760040283]
epoch:  1070 training:  [1.3358640670776367, 0.3512333035469055, 0.970906138420105]
epoch:  1080 training:  [0.3769077956676483, 0.17537908256053925, 1.353621244430542]
epoch:  1090 training:  [1.830887794494629, 0.9897688031196594, 1

epoch:  1960 training:  [1.3497488498687744, 3.504181146621704, 11.8164644241333]
epoch:  1970 training:  [4.384837627410889, 4.838074684143066, 0.9690749645233154]
epoch:  1980 training:  [7.251648902893066, 0.4140828847885132, 1.1221404075622559]
epoch:  1990 training:  [6.414519309997559, 26.088565826416016, 1.3211029767990112]
epoch:  2000 training:  [5.702828884124756, 3.2861318588256836, 4.36351203918457]
epoch:  2010 training:  [3.266969680786133, 3.237349510192871, 3.2525863647460938]
epoch:  2020 training:  [3.719040632247925, 2.936323404312134, 6.272946357727051]
epoch:  2030 training:  [4.215874671936035, 1.70421302318573, 20.106225967407227]
epoch:  2040 training:  [2.931222915649414, 1.9860665798187256, 1.1045866012573242]
epoch:  2050 training:  [2.061779022216797, 5.727506160736084, 2.8787758350372314]
epoch:  2060 training:  [1.4405227899551392, 1.5510694980621338, 1.3242607116699219]
epoch:  2070 training:  [3.0627951622009277, 10.48626708984375, 0.9099634289741516]
ep

epoch:  2950 training:  [1.4429216384887695, 5.902612686157227, 2.6287059783935547]
epoch:  2960 training:  [4.178805351257324, 2.962470531463623, 5.123430252075195]
epoch:  2970 training:  [2.4271695613861084, 3.479020357131958, 29.430763244628906]
epoch:  2980 training:  [2.862335681915283, 16.215808868408203, 10.567312240600586]
epoch:  2990 training:  [1.4650304317474365, 2.371962070465088, 96.69694519042969]
epoch:  3000 training:  [1.184456706047058, 1.7917509078979492, 0.9302650094032288]
epoch:  3010 training:  [2.164989709854126, 45.649993896484375, 1.2353218793869019]
epoch:  3020 training:  [1.7240164279937744, 1.0997545719146729, 8.937580108642578]
epoch:  3030 training:  [4.94251823425293, 2.6957192420959473, 9.526959419250488]
epoch:  3040 training:  [1.8747748136520386, 4.5136260986328125, 3.838405132293701]
epoch:  3050 training:  [4.673011302947998, 1.4073067903518677, 4.943495750427246]
epoch:  3060 training:  [9.436933517456055, 8.239765167236328, 9.873579025268555]


epoch:  3940 training:  [1.9947282075881958, 0.7127466797828674, 10.365360260009766]
epoch:  3950 training:  [2.0165762901306152, 0.5164665579795837, 11.943855285644531]
epoch:  3960 training:  [1.4752106666564941, 2.156484603881836, 0.9406416416168213]
epoch:  3970 training:  [1.6280277967453003, 2.537691593170166, 20.213016510009766]
epoch:  3980 training:  [2.6168317794799805, 20.809711456298828, 49.64250183105469]
epoch:  3990 training:  [2.5683164596557617, 0.964505672454834, 1.919560432434082]
epoch:  4000 training:  [0.8163362145423889, 4.081417560577393, 8.110889434814453]
epoch:  4010 training:  [11.321426391601562, 6.366119384765625, 1.1328730583190918]
epoch:  4020 training:  [1.4855331182479858, 2.5595171451568604, 0.5075370073318481]
epoch:  4030 training:  [1.8408584594726562, 22.389375686645508, 11.370306015014648]
epoch:  4040 training:  [2.2198643684387207, 1.9752213954925537, 7.173151969909668]
epoch:  4050 training:  [5.286280632019043, 9.61367130279541, 4.3697767257

epoch:  4930 training:  [1.6825740337371826, 1.2765929698944092, 1.7863823175430298]
epoch:  4940 training:  [0.8230628967285156, 1.5227086544036865, 3.5003206729888916]
epoch:  4950 training:  [11.430389404296875, 1.0817660093307495, 2.7067065238952637]
epoch:  4960 training:  [0.8938184976577759, 3.2701849937438965, 1.0261836051940918]
epoch:  4970 training:  [2.052858829498291, 2.3851542472839355, 1.6555650234222412]
epoch:  4980 training:  [1.0371417999267578, 7.30891227722168, 5.3911027908325195]
epoch:  4990 training:  [4.804998397827148, 10.645040512084961, 2.1043272018432617]
epoch:  5000 training:  [4.929975986480713, 9.519840240478516, 54.52717208862305]
epoch:  5010 training:  [4.243919849395752, 17.185558319091797, 5.473175048828125]
epoch:  5020 training:  [4.796537399291992, 1.050585150718689, 3.29896879196167]
epoch:  5030 training:  [1.3844619989395142, 5.746860504150391, 50.889923095703125]
epoch:  5040 training:  [1.9868884086608887, 1.520075798034668, 2.1903743743896

epoch:  5920 training:  [3.7106058597564697, 11.932275772094727, 5.858913421630859]
epoch:  5930 training:  [0.95655357837677, 1.5113506317138672, 16.760229110717773]
epoch:  5940 training:  [1.4791303873062134, 0.6983115673065186, 2.8260231018066406]
epoch:  5950 training:  [1.8762809038162231, 5.584278106689453, 1.3879072666168213]
epoch:  5960 training:  [12.380517959594727, 0.45505285263061523, 1.0171951055526733]
epoch:  5970 training:  [4.182701110839844, 0.8891176581382751, 7.103206634521484]
epoch:  5980 training:  [1.9227275848388672, 3.233983278274536, 3.5787148475646973]
epoch:  5990 training:  [1.9919458627700806, 0.528354823589325, 4.76118278503418]
epoch:  6000 training:  [1.0959689617156982, 1.1959973573684692, 2.5466184616088867]
epoch:  6010 training:  [1.598425269126892, 0.5782017707824707, 3.0074660778045654]
epoch:  6020 training:  [21.620948791503906, 1.412339687347412, 0.5532914400100708]
epoch:  6030 training:  [1.9050240516662598, 0.4492960274219513, 0.845499455

epoch:  6900 training:  [1.7462869882583618, 0.676484227180481, 0.89695805311203]
epoch:  6910 training:  [3.0565643310546875, 0.460277795791626, 4.706910133361816]
epoch:  6920 training:  [2.0823278427124023, 2.9860827922821045, 7.4249372482299805]
epoch:  6930 training:  [10.34062385559082, 4.631650447845459, 2.0627007484436035]
epoch:  6940 training:  [3.051637887954712, 4.367724895477295, 7.503314971923828]
epoch:  6950 training:  [12.099054336547852, 3.7190444469451904, 12.393540382385254]
epoch:  6960 training:  [2.011002540588379, 3.280388116836548, 2.2497222423553467]
epoch:  6970 training:  [2.5639519691467285, 0.8321906328201294, 15.04002571105957]
epoch:  6980 training:  [2.2372055053710938, 1.1908855438232422, 34.953548431396484]
epoch:  6990 training:  [1.123337984085083, 5.2214155197143555, 5.287186622619629]
epoch:  7000 training:  [54.08306121826172, 44.15300369262695, 7.763449668884277]
epoch:  7010 training:  [6.277371883392334, 1.621683120727539, 2.7892942428588867]


epoch:  7890 training:  [3.3354339599609375, 2.064922332763672, 2.9141364097595215]
epoch:  7900 training:  [3.0488016605377197, 14.53255844116211, 14.66991901397705]
epoch:  7910 training:  [1.4485325813293457, 6.681770324707031, 9.723665237426758]
epoch:  7920 training:  [2.544355630874634, 1.2717398405075073, 3.8894009590148926]
epoch:  7930 training:  [0.7651568651199341, 16.8439884185791, 3.1523332595825195]
epoch:  7940 training:  [1.0153573751449585, 2.735612392425537, 4.67436408996582]
epoch:  7950 training:  [6.974143028259277, 2.1136791706085205, 1.4567599296569824]
epoch:  7960 training:  [5.135624885559082, 12.608113288879395, 2.9571456909179688]
epoch:  7970 training:  [3.246976852416992, 5.12313175201416, 13.687966346740723]
epoch:  7980 training:  [12.372011184692383, 4.045035362243652, 5.262625217437744]
epoch:  7990 training:  [13.311116218566895, 5.9542694091796875, 14.126333236694336]
epoch:  8000 training:  [1.5701302289962769, 2.61326003074646, 1.1004735231399536]


epoch:  8880 training:  [1.8627032041549683, 2.471179962158203, 4.618161201477051]
epoch:  8890 training:  [0.8921136856079102, 5.3868303298950195, 9.900153160095215]
epoch:  8900 training:  [1.301367998123169, 3.3331408500671387, 7.689057350158691]
epoch:  8910 training:  [7.690708637237549, 2.94278621673584, 9.737329483032227]
epoch:  8920 training:  [0.953376293182373, 2.414201021194458, 2.7557849884033203]
epoch:  8930 training:  [3.0301973819732666, 2.7451260089874268, 8.827923774719238]
epoch:  8940 training:  [2.183831214904785, 1.743888258934021, 3.663250684738159]
epoch:  8950 training:  [1.6569836139678955, 0.8554233312606812, 1.2243684530258179]
epoch:  8960 training:  [3.101107597351074, 1.1208508014678955, 18.684967041015625]
epoch:  8970 training:  [1.5479050874710083, 14.146393775939941, 1.8933497667312622]
epoch:  8980 training:  [4.962852478027344, 4.515117168426514, 3.516981601715088]
epoch:  8990 training:  [1.4893114566802979, 0.44432199001312256, 2.389683246612549]

epoch:  9860 training:  [19.88072395324707, 1.1719176769256592, 4.0871477127075195]
epoch:  9870 training:  [2.2346978187561035, 4.774773597717285, 1.2826999425888062]
epoch:  9880 training:  [3.268890857696533, 0.6908183097839355, 3.500967502593994]
epoch:  9890 training:  [3.4761104583740234, 2.8144283294677734, 3.08170485496521]
epoch:  9900 training:  [0.9677826166152954, 7.157406330108643, 4.252724647521973]
epoch:  9910 training:  [1.1643086671829224, 1.999945878982544, 23.180028915405273]
epoch:  9920 training:  [1.8145408630371094, 4.394622802734375, 21.415218353271484]
epoch:  9930 training:  [3.995708465576172, 3.9656448364257812, 3.6548895835876465]
epoch:  9940 training:  [1.6496939659118652, 1.3778795003890991, 6.709253311157227]
epoch:  9950 training:  [17.277278900146484, 0.6734658479690552, 5.153963088989258]
epoch:  9960 training:  [9.184036254882812, 35.90946960449219, 4.0085530281066895]
epoch:  9970 training:  [3.82633638381958, 1.3424828052520752, 5.007400512695312

epoch:  10840 training:  [1.2944989204406738, 6.678375244140625, 43.42818832397461]
epoch:  10850 training:  [1.2450326681137085, 3.240907669067383, 4.5373735427856445]
epoch:  10860 training:  [2.2952558994293213, 4.210056304931641, 2.4957940578460693]
epoch:  10870 training:  [3.1299664974212646, 0.5765343308448792, 2.9922401905059814]
epoch:  10880 training:  [1.1826103925704956, 7.570329666137695, 2.0736782550811768]
epoch:  10890 training:  [2.026790142059326, 0.6230930089950562, 2.8664913177490234]
epoch:  10900 training:  [2.273378372192383, 8.235441207885742, 15.238470077514648]
epoch:  10910 training:  [7.874642372131348, 3.9385247230529785, 57.288475036621094]
epoch:  10920 training:  [0.7438088655471802, 0.8471258282661438, 16.707998275756836]
epoch:  10930 training:  [5.651668071746826, 8.458433151245117, 2.591657876968384]
epoch:  10940 training:  [6.482451915740967, 8.515657424926758, 42.112098693847656]
epoch:  10950 training:  [38.47602844238281, 5.315077781677246, 2.86

epoch:  11820 training:  [2.844926357269287, 22.26758575439453, 11.35707950592041]
epoch:  11830 training:  [3.8233513832092285, 0.3601227402687073, 1.4865446090698242]
epoch:  11840 training:  [2.3146543502807617, 20.917522430419922, 1.8484883308410645]
epoch:  11850 training:  [27.190034866333008, 5.494978904724121, 3.66548490524292]
epoch:  11860 training:  [0.9558472633361816, 2.8286795616149902, 0.85092693567276]
epoch:  11870 training:  [2.018256187438965, 2.583174228668213, 2.190946102142334]
epoch:  11880 training:  [2.256272315979004, 1.1440768241882324, 5.325845718383789]
epoch:  11890 training:  [1.1801636219024658, 23.546764373779297, 8.699136734008789]
epoch:  11900 training:  [1.540901780128479, 4.386199951171875, 44.12799072265625]
epoch:  11910 training:  [2.173576593399048, 6.555751323699951, 13.746191024780273]
epoch:  11920 training:  [2.101752758026123, 2.7281198501586914, 5.2444562911987305]
epoch:  11930 training:  [1.2195019721984863, 7.170230865478516, 4.5972390

epoch:  12800 training:  [0.6134900450706482, 6.098824501037598, 1.7519009113311768]
epoch:  12810 training:  [1.7719659805297852, 32.796634674072266, 27.447765350341797]
epoch:  12820 training:  [1.784053087234497, 5.803743362426758, 40.34601593017578]
epoch:  12830 training:  [4.371877193450928, 5.98539924621582, 3.9009170532226562]
epoch:  12840 training:  [2.9169654846191406, 19.46066665649414, 17.34964370727539]
epoch:  12850 training:  [0.9738113880157471, 3.653334856033325, 7.901237964630127]
epoch:  12860 training:  [2.6200428009033203, 4.336814880371094, 3.722947120666504]
epoch:  12870 training:  [1.447892665863037, 5.0865912437438965, 2.227065324783325]
epoch:  12880 training:  [1.46040678024292, 4.232125282287598, 15.671239852905273]
epoch:  12890 training:  [14.720991134643555, 5.119372844696045, 9.232087135314941]
epoch:  12900 training:  [1.353642463684082, 2.4466896057128906, 9.060722351074219]
epoch:  12910 training:  [2.4566221237182617, 5.427817344665527, 2.350618362

epoch:  13780 training:  [5.543304920196533, 1.408308744430542, 4.2503862380981445]
epoch:  13790 training:  [3.802361011505127, 1.635861873626709, 2.703625202178955]
epoch:  13800 training:  [2.0216081142425537, 2.0752158164978027, 5.0547404289245605]
epoch:  13810 training:  [2.5055935382843018, 1.2299606800079346, 18.638593673706055]
epoch:  13820 training:  [54.63885498046875, 1.6051039695739746, 44.60923767089844]
epoch:  13830 training:  [5.416240692138672, 2.635955333709717, 46.794803619384766]
epoch:  13840 training:  [3.451725482940674, 1.9256592988967896, 2.4745101928710938]
epoch:  13850 training:  [3.2381234169006348, 10.269673347473145, 1.5960144996643066]
epoch:  13860 training:  [58.036277770996094, 36.37858200073242, 5.703909873962402]
epoch:  13870 training:  [1.552304983139038, 45.65479278564453, 1.931274175643921]
epoch:  13880 training:  [1.8486775159835815, 2.2960333824157715, 8.980094909667969]
epoch:  13890 training:  [24.815210342407227, 6.06796407699585, 3.2277

epoch:  14760 training:  [0.6813153624534607, 6.466082572937012, 29.02437973022461]
epoch:  14770 training:  [11.060998916625977, 5.826216220855713, 1.9231748580932617]
epoch:  14780 training:  [33.185176849365234, 1.5560598373413086, 4.458796501159668]
epoch:  14790 training:  [1.8447439670562744, 0.995299756526947, 3.71284556388855]
epoch:  14800 training:  [1.598402738571167, 2.427377223968506, 2.6181774139404297]
epoch:  14810 training:  [11.292984962463379, 0.8814115524291992, 4.119428634643555]
epoch:  14820 training:  [2.976278305053711, 2.8210906982421875, 3.878199338912964]
epoch:  14830 training:  [2.5550546646118164, 43.28765106201172, 4.466610908508301]
epoch:  14840 training:  [1.3552491664886475, 6.4205522537231445, 3.2705085277557373]
epoch:  14850 training:  [30.763404846191406, 3.143348217010498, 70.41323852539062]
epoch:  14860 training:  [0.8519220352172852, 10.440827369689941, 18.219215393066406]
epoch:  14870 training:  [1.1118950843811035, 54.076133728027344, 12.3

epoch:  15740 training:  [11.996467590332031, 4.156796455383301, 6.009592056274414]
epoch:  15750 training:  [1.6797300577163696, 2.6753201484680176, 3.9029862880706787]
epoch:  15760 training:  [2.3516430854797363, 3.968461513519287, 21.869569778442383]
epoch:  15770 training:  [11.668220520019531, 76.48101806640625, 10.246578216552734]
epoch:  15780 training:  [0.8361143469810486, 2.3777475357055664, 10.585258483886719]
epoch:  15790 training:  [24.24291229248047, 22.525146484375, 6.7866973876953125]
epoch:  15800 training:  [1.717206358909607, 16.36799430847168, 78.51799774169922]
epoch:  15810 training:  [3.8843297958374023, 0.8671650290489197, 3.8803932666778564]
epoch:  15820 training:  [2.2365403175354004, 41.940128326416016, 4.979948997497559]
epoch:  15830 training:  [4.1335649490356445, 14.405323028564453, 4.906198501586914]
epoch:  15840 training:  [2.4077138900756836, 1.3951834440231323, 10.482152938842773]
epoch:  15850 training:  [22.50442886352539, 22.48270034790039, 1.8

KeyboardInterrupt: 