In [1]:
# -*- coding: utf-8 -*-

import sys,os
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy import linalg
from numpy import dot
import geomloss as gs

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as D
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.autograd import grad
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.nn.modules import Linear
from torch.autograd.functional import jacobian,hessian,vjp,vhp,hvp

import random
import math

FilePath = '../../'

# One raw table per run; the file names tag them as day 0, 2, 4 and 7.
file_list = ['GSM1599494_ES_d0_main.csv', 'GSM1599497_ES_d2_LIFminus.csv', 'GSM1599498_ES_d4_LIFminus.csv', 'GSM1599499_ES_d7_LIFminus.csv']

table_list = [pd.read_csv(FilePath + fname, header=None) for fname in file_list]

# Column 0 holds the gene name (taken from the first table); every other
# column is one cell, converted to float32.
gene_names = table_list[0].values[:, 0]
matrix_list = [tbl.values[:, 1:].astype('float32') for tbl in table_list]

cell_counts = [mat.shape[1] for mat in matrix_list]

def normalize_run(mat):
    """Normalise one (genes x cells) count matrix and apply log(x + 1).

    Every column is divided by its total count in millions (reads per
    million), multiplied by ``median(zero_frac) / zero_frac`` where
    ``zero_frac`` is that column's fraction of zero entries, and finally
    ``log(x + 1)`` is taken element-wise.
    """
    n_rows = float(mat.shape[0])
    reads_per_million = np.sum(mat, axis=0) / 1e6
    zero_frac = np.sum(mat == 0, axis=0) / n_rows
    column_scale = np.median(zero_frac) / zero_frac
    return np.log(mat * column_scale * 1.0 / reads_per_million + 1.0)

# Normalise each run, then summarise every gene by 50 evenly spaced
# percentiles taken across cells (axis 1).
norm_mat = list(map(normalize_run, matrix_list))
qt_mat = [np.percentile(run, q=np.linspace(0, 100, 50), axis=1) for run in norm_mat]

# Rank genes by the squared distance between their quantile profiles in the
# first and fourth run (largest change first) and keep the top 100.
wdiv = ((qt_mat[0] - qt_mat[3]) ** 2).sum(axis=0)
w_order = np.argsort(-wdiv)
wsub = w_order[:100]


def nmf(X, latent_features, max_iter=100, error_limit=1e-6, fit_error_limit=1e-6, print_iter=200):
    """
    Masked non-negative matrix factorisation: approximate X by A @ Y.

    Only entries where ``np.sign(X)`` is non-zero take part in the fit, so
    zeros in X act as missing values. A starts uniform-random and Y as the
    least-squares solution of ``A @ Y = X``; both are clipped to at least
    ``eps`` and then refined with multiplicative updates.

    Parameters
    ----------
    X : 2-D array to factorise.
    latent_features : int, inner dimension of the factorisation.
    max_iter : int, number of update sweeps.
    error_limit : float, stop once the masked Frobenius residual is below this.
    fit_error_limit : float, stop once the masked change of the estimate
        between two evaluations is below this.
    print_iter : int, evaluate and print every ``print_iter`` sweeps
        (and on the first and last sweep).

    Returns
    -------
    (A, Y, A @ Y)

    Note: initialisation draws from NumPy's global RNG (``np.random.rand``).
    """
    eps = 1e-5
    print('Starting NMF decomposition with {} latent features and {} iterations.'.format(latent_features, max_iter))

    mask = np.sign(X)
    n_rows, _ = X.shape

    A = np.maximum(np.random.rand(n_rows, latent_features), eps)
    Y = np.maximum(linalg.lstsq(A, X)[0], eps)

    masked_X = mask * X
    last_estimate = dot(A, Y)

    for it in range(1, max_iter + 1):
        # A update -- Matlab: A=A.*(((W.*X)*Y')./((W.*(A*Y))*Y'));
        numer_A = dot(masked_X, Y.T)
        denom_A = dot(mask * dot(A, Y), Y.T) + eps
        A = np.maximum(A * (numer_A / denom_A), eps)

        # Y update (uses the new A) -- Matlab: Y=Y.*((A'*(W.*X))./(A'*(W.*(A*Y))));
        numer_Y = dot(A.T, masked_X)
        denom_Y = dot(A.T, mask * dot(A, Y)) + eps
        Y = np.maximum(Y * (numer_Y / denom_Y), eps)

        if not (it % print_iter == 0 or it == 1 or it == max_iter):
            continue

        print('Iteration {}:'.format(it))
        estimate = dot(A, Y)
        fit_residual = np.sqrt(np.sum((mask * (last_estimate - estimate)) ** 2))
        last_estimate = estimate

        total_residual = linalg.norm(mask * (X - estimate), ord='fro')
        print('fit residual', np.round(fit_residual, 4))
        print('total residual', np.round(total_residual, 4))
        if total_residual < error_limit or fit_residual < fit_error_limit:
            break

    return A, Y, dot(A, Y)

np.random.seed(0)  # nmf initialises A from NumPy's global RNG
# Impute the selected genes of every run with masked NMF; keep only the
# reconstruction A @ Y (third return value).
norm_imputed = [
    nmf(normin[wsub, :], latent_features=4 * len(wsub), max_iter=500)[2]
    for normin in norm_mat
]

# Centre every run on the per-gene mean of run 3, keep the first ten ranked
# genes and scale them by run 3's per-gene standard deviation.
norm_adj = norm_imputed[3].mean(axis=1)[:, np.newaxis]
subvec = np.arange(10)

gnvec = gene_names[w_order[subvec]]

cov_mat = np.cov(norm_imputed[3][subvec, :])
whiten = np.diag(np.diag(cov_mat) ** (-0.5))
unwhiten = np.diag(np.diag(cov_mat) ** (0.5))

norm_imputed2 = [whiten.dot((normin - norm_adj)[subvec, :]) for normin in norm_imputed]


class MLP(nn.Module):
    """Fully connected feed-forward network.

    ``num_hidden == 0`` gives a single affine map ``dim_in -> dim_out``.
    Otherwise the layers are
    ``dim_in -> dim_hidden -> ... -> dim_hidden -> dim_out`` with
    ``num_hidden`` hidden layers. ``activation`` is applied after every layer
    except the last, so the output layer is linear.

    Weights are Xavier-normal initialised; biases are drawn from U(-0.1, 0.1).

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: width of each hidden layer.
        num_hidden: number of hidden layers (>= 0).
        activation: module applied between layers; ``None`` (default) creates
            a fresh ``nn.LeakyReLU()`` for this instance.

    Raises:
        ValueError: if ``num_hidden`` is negative.
    """

    def __init__(self, dim_in, dim_out, dim_hidden=64, num_hidden=0, activation=None):
        super(MLP, self).__init__()

        if num_hidden == 0:
            self.linears = nn.ModuleList([nn.Linear(dim_in, dim_out)])
        elif num_hidden >= 1:
            layers = [nn.Linear(dim_in, dim_hidden)]
            layers.extend(nn.Linear(dim_hidden, dim_hidden) for _ in range(num_hidden - 1))
            layers.append(nn.Linear(dim_hidden, dim_out))
            self.linears = nn.ModuleList(layers)
        else:
            raise ValueError('number of hidden layers must be non-negative')

        for layer in self.linears:
            nn.init.xavier_normal_(layer.weight)
            nn.init.uniform_(layer.bias, a=-0.1, b=0.1)

        # A module used as a default argument is built once at definition time
        # and shared by every MLP; create a fresh one per instance instead.
        self.activation = nn.LeakyReLU() if activation is None else activation

    def forward(self, x):
        for layer in self.linears[:-1]:
            x = self.activation(layer(x))
        # Output layer is left linear (no activation).
        return self.linears[-1](x)


def compute_gradient_penalty(D, real_sample, fake_sample, k, p):
    """Gradient penalty on the critic's input gradients at real and fake samples.

    Returns ``k * sum_i ||dD/dx (x_i)||_2 ** p / (n_real + n_fake)``, where the
    sum runs over every row of ``real_sample`` and ``fake_sample``.

    Args:
        D: critic, a callable mapping a batch (first dim = samples) to scores.
            Inside this function the name shadows the module-level
            ``torch.distributions`` alias.
        real_sample, fake_sample: sample batches. ``requires_grad`` is switched
            on for both in place, so the critic can be differentiated with
            respect to its inputs.
        k: penalty weight.
        p: exponent applied to each per-sample gradient L2 norm.

    Returns:
        Scalar tensor built with ``create_graph=True``, so it can be
        backpropagated into the critic's parameters.
    """
    def _grad_norm_pow(samples):
        # Per-sample ||dD/dx||_2 ** p. ones_like puts the seed gradient on the
        # critic output's device/dtype, so this works on CPU as well as CUDA.
        validity = D(samples)
        sample_grad = grad(
            validity, samples, torch.ones_like(validity),
            create_graph=True, retain_graph=True, only_inputs=True,
        )[0]
        return sample_grad.reshape(sample_grad.size(0), -1).pow(2).sum(1) ** (p / 2)

    real_samples = real_sample.requires_grad_(True)
    fake_samples = fake_sample.requires_grad_(True)

    real_grad_norm = _grad_norm_pow(real_samples)
    fake_grad_norm = _grad_norm_pow(fake_samples)

    n_total = real_sample.shape[0] + fake_sample.shape[0]
    return (torch.sum(real_grad_norm) + torch.sum(fake_grad_norm)) * k / n_total

class JumpEulerForwardCuda(nn.Module):
    """Euler-Maruyama simulator of a jump-diffusion with learned drift and jump terms.

    One step of size ``dt = step_size`` maps the state ``x`` of shape ``(Nsim, d)`` to

        x + drift(x) * dt
          + (sqrt(dt) * eps_b) @ diffusion                      eps_b ~ N(0, I), (Nsim, noise_dim)
          + (n * mean + (sqrt(n) * eps_j) @ covHalf) * jump(x)  eps_j ~ N(0, I), (Nsim, d)

    where ``n ~ Poisson(intensity * dt)`` is drawn per trajectory and step.

    Despite the name, simulation runs on the device of ``z0``. Move the module
    with ``.cuda()`` / ``.to()`` to match; the intensity buffer moves with it.

    Args:
        in_features: state dimension ``d``.
        num_hidden, dim_hidden: depth and width of the ``drift`` and ``jump`` MLPs.
        step_size: Euler step ``dt``.
        jump_intensity: Poisson rate of the jump counts. ``None`` (default) reads
            the module-level ``intensity``, as the original constructor did.
        noise_dim: dimension of the Brownian driver. ``None`` (default) reads
            the module-level ``bd``, as the original constructor did.
    """

    def __init__(self, in_features, num_hidden, dim_hidden, step_size,
                 jump_intensity=None, noise_dim=None):
        super(JumpEulerForwardCuda, self).__init__()

        # Backward compatible: fall back to the notebook-level globals.
        if jump_intensity is None:
            jump_intensity = intensity
        if noise_dim is None:
            noise_dim = bd

        self.drift = MLP(in_features, in_features, dim_hidden, num_hidden)
        # Non-persistent buffer: follows .cuda()/.to() but stays out of
        # state_dict, so previously saved checkpoints still load.
        self.register_buffer('intensity', torch.as_tensor(jump_intensity), persistent=False)
        self.mean = nn.Parameter(0.01 * torch.ones(in_features))
        self.covHalf = nn.Parameter(0.08 * torch.eye(in_features))
        # (noise_dim, in_features): maps the Brownian driver into state space.
        # Previously hard-coded to 10 columns, which only worked when in_features == 10.
        self.diffusion = nn.Parameter(torch.ones(noise_dim, in_features))
        self.in_features = in_features
        self.jump = MLP(in_features, in_features, dim_hidden, num_hidden)
        self.step_size = step_size

    def forward(self, z0, Nsim, steps):
        """Simulate ``Nsim`` trajectories for ``steps`` Euler steps from ``z0``.

        ``z0`` must broadcast to ``(Nsim, in_features)``. Returns a tensor of
        shape ``(Nsim, steps + 1, in_features)``; slice ``[:, i, :]`` is the
        state after ``i`` steps, with ``[:, 0, :]`` the initial state.
        """
        device = z0.device
        dt = self.step_size
        sqrt_dt = math.sqrt(dt)
        noise_dim = self.diffusion.shape[0]
        # The jump-count distribution is the same at every step: build it once.
        jump_counts = torch.distributions.Poisson(self.intensity * dt)

        state = z0.expand(Nsim, self.in_features)
        path = [state]
        for _ in range(steps):
            # Draw order (jump counts, Brownian noise, jump noise) matches the
            # original implementation, so seeded runs reproduce.
            pois = jump_counts.sample((Nsim, 1)).to(device)
            brownian = torch.normal(0, 1, size=(Nsim, noise_dim), device=device)
            jump_noise = torch.normal(0, 1, size=(Nsim, self.in_features), device=device)
            state = (state
                     + self.drift(state) * dt
                     + (sqrt_dt * brownian) @ self.diffusion
                     + (pois * self.mean + (pois ** 0.5 * jump_noise) @ self.covHalf) * self.jump(state))
            path.append(state)
        return torch.stack(path, dim=1)


def setup_seed(seed):
    """Seed every RNG the notebook draws from and request deterministic cuDNN.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module. Also selects deterministic cuDNN kernels with
    autotuning disabled.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernel choice per run, which can break
    # run-to-run reproducibility; pin it off explicitly.
    torch.backends.cudnn.benchmark = False


sed = 200
setup_seed(sed)

# Sinkhorn divergence (geomloss); used only to report fit quality during training.
a = gs.SamplesLoss(loss='sinkhorn', p=2, blur=0.01)


train_data = norm_imputed2

# Whitened (genes x cells) matrices -> (cells x genes) tensors, one per run.
# train0 is the starting population fed to the generator; train2/4/7 are the
# later snapshots the critics compare against.
train0, train2, train4, train7 = [
    torch.tensor(mat, dtype=torch.float32, requires_grad=True, device="cuda").t()
    for mat in train_data[:4]
]

# Small Gaussian jitter on every observation (drawn in the same order as before).
train0, train2, train4, train7 = [
    x + 0.1 * torch.normal(0, 1, size=x.shape, device="cuda")
    for x in (train0, train2, train4, train7)
]


intensity = 10      # Poisson jump rate; JumpEulerForwardCuda reads this module-level name
lr = 0.0003         # Adam learning rate for the generator and the critics
step_size = 0.03    # Euler step of the simulator
kuan = 256          # hidden width of every MLP
ceng = 4            # number of hidden layers of every MLP
bd = 2              # Brownian-noise dimension; JumpEulerForwardCuda reads this module-level name
n_critic = 3        # critic updates per generator update
k = 2               # gradient-penalty weight
p = 6               # gradient-penalty exponent

n_sims = train0.shape[0]        # one simulated trajectory per starting cell
in_features = train0.shape[1]   # state dimension = number of selected genes
# Simulator step indices matched against train2, train4 and train7 respectively.
n_steps = [10, 20, 35]


# Size the networks from the data rather than a hard-coded 10, so changing
# the gene subset cannot silently desynchronise them from the inputs.
netG = JumpEulerForwardCuda(in_features, ceng, kuan, step_size).cuda()
netD1, netD2, netD3 = [
    MLP(in_features, 1, dim_hidden=kuan, num_hidden=ceng).cuda() for _ in range(3)
]


adam_betas = (0.5, 0.999)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)
optimizerSD1, optimizerSD2, optimizerSD3 = [
    optim.Adam(net.parameters(), lr=lr, betas=adam_betas) for net in (netD1, netD2, netD3)
]
n_epochs = 20000


def _critic_step(critic, optimizer, real, fake):
    """Run one gradient-penalised critic update; return ``(loss, penalty)``.

    Minimises ``-mean(critic(real)) + mean(critic(fake)) + penalty`` with
    ``optimizer``. Uses the module-level ``k`` and ``p``.
    """
    optimizer.zero_grad()
    penalty = compute_gradient_penalty(critic, real, fake, k, p)
    loss = -torch.mean(critic(real)) + torch.mean(critic(fake)) + penalty
    loss.backward()
    optimizer.step()
    return loss, penalty


wd = []  # summed Sinkhorn divergence over the three snapshots, every 10 epochs
for epoch in range(n_epochs):

    # ---- critic updates ----
    for _ in range(n_critic):
        # The critics use generator samples as fixed data. Simulating without
        # recording the generator graph leaves the critic gradients unchanged,
        # but the critic losses no longer backpropagate through the whole
        # rollout. Previously that happened 3x per critic step with
        # retain_graph=True, only for optimizerG.zero_grad() to discard it.
        with torch.no_grad():
            fake_data = netG(train0, n_sims, n_steps[2])
        # detach() yields fresh leaves on which compute_gradient_penalty can
        # enable requires_grad.
        fake1, fake2, fake3 = [fake_data[:, s, :].detach() for s in n_steps]

        d1_loss, div_gp1 = _critic_step(netD1, optimizerSD1, train2, fake1)
        d2_loss, div_gp2 = _critic_step(netD2, optimizerSD2, train4, fake2)
        d3_loss, div_gp3 = _critic_step(netD3, optimizerSD3, train7, fake3)

    # ---- generator update ----
    optimizerG.zero_grad()
    fake_data = netG(train0, n_sims, n_steps[2])
    fake1, fake2, fake3 = [fake_data[:, s, :] for s in n_steps]
    g_loss = -torch.mean(netD1(fake1)) - torch.mean(netD2(fake2)) - torch.mean(netD3(fake3))
    g_loss.backward()
    optimizerG.step()

    # ---- monitoring ----
    if epoch % 10 == 0:
        # Evaluation only: detached inputs keep geomloss from building a graph.
        x1 = a(fake1.detach(), train2.detach()).item()
        x2 = a(fake2.detach(), train4.detach()).item()
        x3 = a(fake3.detach(), train7.detach()).item()

        wd.append(x1 + x2 + x3)

        print("training error: ",x1," and ",x2," and ",x3)
        print("total error: ",wd[-1])

Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 4000.1233
total residual 138.7959
Iteration 200:
fit residual 123.6643
total residual 31.8464
Iteration 400:
fit residual 21.7894
total residual 12.6115
Iteration 500:
fit residual 4.3442
total residual 8.6686
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2257.8358
total residual 58.2265
Iteration 200:
fit residual 56.8198
total residual 4.0447
Iteration 400:
fit residual 3.4777
total residual 0.6841
Iteration 500:
fit residual 0.3744
total residual 0.3185
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2925.3664
total residual 80.2657
Iteration 200:
fit residual 78.1148
total residual 6.1559
Iteration 400:
fit residual 5.1661
total residual 1.1917
Iteration 500:
fit residual 0.6183
total residual 0.5917
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
f

training error:  1.7011101245880127  and  2.2476353645324707  and  1.7428810596466064
total error:  5.69162654876709
training error:  1.721401333808899  and  2.3279826641082764  and  1.611617088317871
total error:  5.661001086235046
training error:  1.703049659729004  and  2.387953758239746  and  1.9327341318130493
total error:  6.023737549781799
training error:  1.6429194211959839  and  2.144483804702759  and  1.5250988006591797
total error:  5.312502026557922
training error:  1.6946673393249512  and  2.2536120414733887  and  1.5457133054733276
total error:  5.4939926862716675
training error:  1.6849839687347412  and  2.7425713539123535  and  2.1006228923797607
total error:  6.5281782150268555
training error:  1.845886468887329  and  3.0585598945617676  and  1.5953292846679688
total error:  6.499775648117065
training error:  1.7131321430206299  and  3.988844633102417  and  2.4839773178100586
total error:  8.185954093933105
training error:  1.8092446327209473  and  3.6468639373779297  

training error:  1.388885736465454  and  2.369910717010498  and  1.4723564386367798
total error:  5.231152892112732
training error:  1.323105812072754  and  1.9557515382766724  and  1.5739493370056152
total error:  4.8528066873550415
training error:  1.3234117031097412  and  2.0262324810028076  and  1.479233741760254
total error:  4.828877925872803
training error:  1.3494226932525635  and  1.9852122068405151  and  1.4553626775741577
total error:  4.789997577667236
training error:  1.3052940368652344  and  2.2015583515167236  and  1.590705156326294
total error:  5.097557544708252
training error:  1.3390984535217285  and  2.522245407104492  and  1.5170924663543701
total error:  5.378436326980591
training error:  1.3283052444458008  and  2.1996772289276123  and  1.6121844053268433
total error:  5.140166878700256
training error:  1.283340334892273  and  2.2637763023376465  and  1.5765883922576904
total error:  5.12370502948761
training error:  1.3050131797790527  and  2.1493589878082275  a

training error:  1.2382625341415405  and  1.842114806175232  and  1.4792684316635132
total error:  4.559645771980286
training error:  1.2185481786727905  and  1.8434884548187256  and  1.4447506666183472
total error:  4.506787300109863
training error:  1.2255820035934448  and  1.8285335302352905  and  1.4103959798812866
total error:  4.464511513710022
training error:  1.230910062789917  and  1.772304892539978  and  1.4131287336349487
total error:  4.416343688964844
training error:  1.1909915208816528  and  1.8601007461547852  and  1.2593872547149658
total error:  4.310479521751404
training error:  1.204993724822998  and  1.881202220916748  and  1.4790095090866089
total error:  4.565205454826355
training error:  1.2116484642028809  and  1.8271564245224  and  1.379004955291748
total error:  4.417809844017029
training error:  1.1799834966659546  and  1.8964927196502686  and  1.3821077346801758
total error:  4.458583950996399
training error:  1.1827647686004639  and  1.7544901371002197  and

training error:  1.1419776678085327  and  1.8048418760299683  and  1.3770833015441895
total error:  4.32390284538269
training error:  1.1107831001281738  and  1.8029401302337646  and  1.2556047439575195
total error:  4.169327974319458
training error:  1.1197806596755981  and  1.8206379413604736  and  1.3343799114227295
total error:  4.274798512458801
training error:  1.125493049621582  and  1.6406711339950562  and  1.253355860710144
total error:  4.019520044326782
training error:  1.1225330829620361  and  1.6401667594909668  and  1.2924052476882935
total error:  4.055105090141296
training error:  1.0880311727523804  and  1.7003262042999268  and  1.275097131729126
total error:  4.063454508781433
training error:  1.1091110706329346  and  1.6247878074645996  and  1.2767285108566284
total error:  4.010627388954163
training error:  1.1131985187530518  and  1.6765235662460327  and  1.2772393226623535
total error:  4.066961407661438
training error:  1.1287081241607666  and  1.8007919788360596

training error:  1.089461326599121  and  1.6288198232650757  and  1.2641561031341553
total error:  3.982437252998352
training error:  1.1010892391204834  and  1.5396218299865723  and  1.247085690498352
total error:  3.8877967596054077
training error:  1.095976710319519  and  1.5253914594650269  and  1.2092905044555664
total error:  3.8306586742401123
training error:  1.0600839853286743  and  1.8766369819641113  and  1.2565028667449951
total error:  4.193223834037781
training error:  1.0790317058563232  and  1.5622378587722778  and  1.2273986339569092
total error:  3.8686681985855103
training error:  1.0479933023452759  and  1.5503958463668823  and  1.2713379859924316
total error:  3.86972713470459
training error:  1.0633140802383423  and  1.4948043823242188  and  1.2623090744018555
total error:  3.8204275369644165
training error:  1.0474002361297607  and  1.5202728509902954  and  1.1977450847625732
total error:  3.7654181718826294
training error:  1.0461562871932983  and  1.50060343742

training error:  1.012222170829773  and  1.4677387475967407  and  1.2191723585128784
total error:  3.699133276939392
training error:  1.0270413160324097  and  1.5093976259231567  and  1.2333152294158936
total error:  3.76975417137146
training error:  1.0307295322418213  and  1.6513617038726807  and  1.1922566890716553
total error:  3.8743479251861572
training error:  1.0569535493850708  and  1.4838511943817139  and  1.1748242378234863
total error:  3.715628981590271
training error:  1.0282586812973022  and  1.5228501558303833  and  1.1961021423339844
total error:  3.74721097946167
training error:  1.0325318574905396  and  1.5252397060394287  and  1.21171236038208
total error:  3.7694839239120483
training error:  1.0529303550720215  and  1.5137537717819214  and  1.1812458038330078
total error:  3.7479299306869507
training error:  1.032060980796814  and  1.5032000541687012  and  1.2183351516723633
total error:  3.7535961866378784
training error:  1.0354633331298828  and  1.52186036109924

training error:  1.0148245096206665  and  1.4470826387405396  and  1.201918363571167
total error:  3.663825511932373
training error:  1.0097581148147583  and  1.505441427230835  and  1.148956298828125
total error:  3.6641558408737183
training error:  1.0109916925430298  and  1.4744303226470947  and  1.1986525058746338
total error:  3.6840745210647583
training error:  1.0165632963180542  and  1.4660077095031738  and  1.2052339315414429
total error:  3.687804937362671
training error:  0.9944055676460266  and  1.5638887882232666  and  1.1978814601898193
total error:  3.7561758160591125
training error:  1.019361972808838  and  1.4812407493591309  and  1.191866159439087
total error:  3.6924688816070557
training error:  1.016599178314209  and  1.5387153625488281  and  1.2539374828338623
total error:  3.8092520236968994
training error:  0.9859237670898438  and  1.4339845180511475  and  1.1941529512405396
total error:  3.6140612363815308
training error:  1.0287951231002808  and  1.430552959442

training error:  0.9589096903800964  and  1.4166338443756104  and  1.1772722005844116
total error:  3.5528157353401184
training error:  0.9674456119537354  and  1.4316189289093018  and  1.1514384746551514
total error:  3.5505030155181885
training error:  0.9869232177734375  and  1.4312430620193481  and  1.1386833190917969
total error:  3.5568495988845825
training error:  0.9816244840621948  and  1.3975629806518555  and  1.2302182912826538
total error:  3.609405755996704
training error:  0.9758148193359375  and  1.4352954626083374  and  1.1779088973999023
total error:  3.5890191793441772
training error:  0.9913684129714966  and  1.4439455270767212  and  1.2276835441589355
total error:  3.6629974842071533
training error:  0.9576026201248169  and  1.5395710468292236  and  1.1876964569091797
total error:  3.68487012386322
training error:  0.963894248008728  and  1.5276610851287842  and  1.1850167512893677
total error:  3.67657208442688
training error:  1.0142972469329834  and  1.5763292312

training error:  0.9668834209442139  and  1.4016822576522827  and  1.1753065586090088
total error:  3.5438722372055054
training error:  0.9846470355987549  and  1.5263291597366333  and  1.1425807476043701
total error:  3.6535569429397583
training error:  0.9453944563865662  and  1.4156566858291626  and  1.1049774885177612
total error:  3.46602863073349
training error:  0.9497134685516357  and  1.3718472719192505  and  1.1910600662231445
total error:  3.5126208066940308
training error:  0.9278929233551025  and  1.3996753692626953  and  1.1936371326446533
total error:  3.521205425262451
training error:  0.9538551568984985  and  1.4407165050506592  and  1.117540717124939
total error:  3.5121123790740967
training error:  0.9637132287025452  and  1.645376205444336  and  1.1971129179000854
total error:  3.8062023520469666
training error:  0.9519397020339966  and  1.4571940898895264  and  1.1587458848953247
total error:  3.5678796768188477
training error:  0.9455035328865051  and  1.489649176

training error:  0.9029238224029541  and  1.5822498798370361  and  1.167367696762085
total error:  3.652541399002075
training error:  0.9144394397735596  and  1.4793682098388672  and  1.18202543258667
total error:  3.5758330821990967
training error:  0.9162555932998657  and  1.3381714820861816  and  1.1455118656158447
total error:  3.399938941001892
training error:  0.9210065603256226  and  1.4602292776107788  and  1.1437582969665527
total error:  3.524994134902954
training error:  0.9142799377441406  and  1.3954466581344604  and  1.2142126560211182
total error:  3.5239392518997192
training error:  0.9186516404151917  and  1.5287641286849976  and  1.4116857051849365
total error:  3.8591014742851257
training error:  0.9084105491638184  and  1.5992493629455566  and  1.1824543476104736
total error:  3.6901142597198486
training error:  0.9036098718643188  and  1.5086698532104492  and  1.1216332912445068
total error:  3.533913016319275
training error:  0.9177913069725037  and  1.52447068691

training error:  0.909878134727478  and  1.521561861038208  and  1.1022142171859741
total error:  3.53365421295166
training error:  0.899621844291687  and  1.5044662952423096  and  1.134099006652832
total error:  3.5381871461868286
training error:  0.8764908313751221  and  1.2975294589996338  and  1.1571788787841797
total error:  3.3311991691589355
training error:  0.8895023465156555  and  1.3837974071502686  and  1.1069821119308472
total error:  3.3802818655967712
training error:  0.8854333162307739  and  1.3541781902313232  and  1.0785034894943237
total error:  3.318114995956421
training error:  0.879891037940979  and  1.2989654541015625  and  1.1173169612884521
total error:  3.2961734533309937
training error:  0.8993346691131592  and  1.4476220607757568  and  1.1306941509246826
total error:  3.4776508808135986
training error:  0.8838123679161072  and  1.3831844329833984  and  1.0870845317840576
total error:  3.3540813326835632
training error:  0.8771039247512817  and  1.371590495109

training error:  0.855925977230072  and  1.2947547435760498  and  1.099731206893921
total error:  3.2504119277000427
training error:  0.8585669994354248  and  1.3287866115570068  and  1.0876233577728271
total error:  3.274976968765259
training error:  0.8573178052902222  and  1.3453702926635742  and  1.1730711460113525
total error:  3.375759243965149
training error:  0.8658091425895691  and  1.6196496486663818  and  1.1077666282653809
total error:  3.593225419521332
training error:  0.8720710873603821  and  1.3185548782348633  and  1.122924566268921
total error:  3.3135505318641663
training error:  0.8911619782447815  and  1.3267661333084106  and  1.0904691219329834
total error:  3.3083972334861755
training error:  0.8947200179100037  and  1.431880235671997  and  1.2162885665893555
total error:  3.542888820171356
training error:  0.8526381254196167  and  1.2992026805877686  and  1.0858688354492188
total error:  3.237709641456604
training error:  0.8456621766090393  and  1.3250992298126

training error:  0.8747125864028931  and  1.5494056940078735  and  1.0913243293762207
total error:  3.5154426097869873
training error:  0.8833714723587036  and  1.3803167343139648  and  1.094362497329712
total error:  3.3580507040023804
training error:  0.8332467675209045  and  1.3432462215423584  and  1.0363333225250244
total error:  3.2128263115882874
training error:  0.8203883171081543  and  1.3921444416046143  and  1.0895135402679443
total error:  3.302046298980713
training error:  0.869623064994812  and  1.3294974565505981  and  1.0972356796264648
total error:  3.296356201171875
training error:  0.8322729468345642  and  1.30181884765625  and  1.111310362815857
total error:  3.245402157306671
training error:  0.8331079483032227  and  1.3097400665283203  and  1.0663138628005981
total error:  3.209161877632141
training error:  0.8163937330245972  and  1.3706691265106201  and  1.1020909547805786
total error:  3.289153814315796
training error:  0.8212243318557739  and  1.31246960163116

training error:  0.8131216764450073  and  1.3053971529006958  and  1.0718374252319336
total error:  3.1903562545776367
training error:  0.8043358325958252  and  1.2524833679199219  and  1.0625298023223877
total error:  3.1193490028381348
training error:  0.8337961435317993  and  1.336566686630249  and  1.1344645023345947
total error:  3.304827332496643
training error:  0.8177706003189087  and  1.2614351511001587  and  1.0733141899108887
total error:  3.152519941329956
training error:  0.8168355226516724  and  1.260274887084961  and  1.0960193872451782
total error:  3.1731297969818115
training error:  0.8044244050979614  and  1.2340567111968994  and  1.0984965562820435
total error:  3.1369776725769043
training error:  0.8088728785514832  and  1.2652037143707275  and  1.0357060432434082
total error:  3.109782636165619
training error:  0.8046869039535522  and  1.251209020614624  and  1.0425567626953125
total error:  3.0984526872634888
training error:  0.8147050142288208  and  1.3091504573

training error:  0.7766038179397583  and  1.213533878326416  and  1.0024137496948242
total error:  2.9925514459609985
training error:  0.7953242659568787  and  1.233389973640442  and  1.096744179725647
total error:  3.1254584193229675
training error:  0.7921582460403442  and  1.2207938432693481  and  1.0142107009887695
total error:  3.027162790298462
training error:  0.8104466199874878  and  1.2331931591033936  and  1.0302066802978516
total error:  3.073846459388733
training error:  0.8004802465438843  and  1.2384365797042847  and  1.0633448362350464
total error:  3.1022616624832153
training error:  0.8052924871444702  and  1.2809654474258423  and  1.0219964981079102
total error:  3.1082544326782227
training error:  0.8086881041526794  and  1.2566397190093994  and  1.0370454788208008
total error:  3.1023733019828796
training error:  0.8252558708190918  and  1.368607759475708  and  1.0196770429611206
total error:  3.2135406732559204
training error:  0.7632501721382141  and  1.2936196327

training error:  0.7741859555244446  and  1.3445786237716675  and  1.0155394077301025
total error:  3.1343039870262146
training error:  0.7610698938369751  and  1.2547032833099365  and  1.0742056369781494
total error:  3.089978814125061
training error:  0.7809051871299744  and  1.3873664140701294  and  1.0944468975067139
total error:  3.2627184987068176
training error:  0.8062238693237305  and  1.2964341640472412  and  0.9967019557952881
total error:  3.0993599891662598
training error:  0.738538384437561  and  1.2583330869674683  and  1.047328233718872
total error:  3.0441997051239014
training error:  0.7765416502952576  and  1.2318896055221558  and  1.0537350177764893
total error:  3.0621662735939026
training error:  0.7734037637710571  and  1.2290090322494507  and  1.0855075120925903
total error:  3.087920308113098
training error:  0.7423722743988037  and  1.247420072555542  and  1.0297085046768188
total error:  3.0195008516311646
training error:  0.789944052696228  and  1.2723222970

training error:  0.7572708129882812  and  1.1842132806777954  and  1.0089011192321777
total error:  2.9503852128982544
training error:  0.8024380803108215  and  1.4513037204742432  and  0.9971833229064941
total error:  3.250925123691559
training error:  0.7563798427581787  and  1.223292350769043  and  1.0359197854995728
total error:  3.0155919790267944
training error:  0.7622705698013306  and  1.2323957681655884  and  1.0083191394805908
total error:  3.0029854774475098
training error:  0.7806466817855835  and  1.2200120687484741  and  1.0368375778198242
total error:  3.037496328353882
training error:  0.7468177676200867  and  1.2267014980316162  and  1.1091997623443604
total error:  3.0827190279960632
training error:  0.7421633005142212  and  1.2419328689575195  and  1.0515265464782715
total error:  3.035622715950012
training error:  0.7526463270187378  and  1.3444262742996216  and  1.0743911266326904
total error:  3.17146372795105
training error:  0.7715246677398682  and  1.3215953111

training error:  0.7438594102859497  and  1.2121875286102295  and  1.026821494102478
total error:  2.9828684329986572
training error:  0.7304329872131348  and  1.1586649417877197  and  0.9730952382087708
total error:  2.8621931672096252
training error:  0.7272278070449829  and  1.2213084697723389  and  1.0414700508117676
total error:  2.9900063276290894
training error:  0.7256774306297302  and  1.1587879657745361  and  1.0246621370315552
total error:  2.9091275334358215
training error:  0.7344782948493958  and  1.2076518535614014  and  1.0328872203826904
total error:  2.9750173687934875
training error:  0.7326930165290833  and  1.3814903497695923  and  1.1198265552520752
total error:  3.2340099215507507
training error:  0.7736482620239258  and  1.2200469970703125  and  0.9886006116867065
total error:  2.982295870780945
training error:  0.7292425632476807  and  1.2961684465408325  and  0.9719806909561157
total error:  2.997391700744629
training error:  0.7135334610939026  and  1.1651252

training error:  0.6946917772293091  and  1.2223851680755615  and  1.0250508785247803
total error:  2.942127823829651
training error:  0.7258210778236389  and  1.2276211977005005  and  0.9990689158439636
total error:  2.952511191368103
training error:  0.7118434906005859  and  1.209136962890625  and  1.0192358493804932
total error:  2.940216302871704
training error:  0.7030588388442993  and  1.1840662956237793  and  1.0550652742385864
total error:  2.942190408706665
training error:  0.7270798683166504  and  1.3827831745147705  and  1.0190868377685547
total error:  3.1289498805999756
training error:  0.7056884765625  and  1.1433446407318115  and  1.008189082145691
total error:  2.8572221994400024
training error:  0.7047346830368042  and  1.1921842098236084  and  1.1119589805603027
total error:  3.0088778734207153
training error:  0.7222132086753845  and  1.1842172145843506  and  1.0006983280181885
total error:  2.9071287512779236
training error:  0.6972702145576477  and  1.1579103469848

training error:  0.688023567199707  and  1.1463347673416138  and  0.9778316020965576
total error:  2.8121899366378784
training error:  0.7300137281417847  and  1.4611005783081055  and  1.0242273807525635
total error:  3.2153416872024536
training error:  0.6996015310287476  and  1.3157672882080078  and  0.988591730594635
total error:  3.0039605498313904
training error:  0.6768760085105896  and  1.1426864862442017  and  0.9800878167152405
total error:  2.7996503114700317
training error:  0.7215583324432373  and  1.2063705921173096  and  0.9987204670906067
total error:  2.9266493916511536
training error:  0.6894087791442871  and  1.1336733102798462  and  1.015217661857605
total error:  2.8382997512817383
training error:  0.6783287525177002  and  1.1728240251541138  and  1.02507483959198
total error:  2.876227617263794
training error:  0.716657280921936  and  1.195887565612793  and  1.0222958326339722
total error:  2.934840679168701
training error:  0.7094314098358154  and  1.2019894123077

training error:  0.6770259141921997  and  1.1733289957046509  and  0.9903952479362488
total error:  2.8407501578330994
training error:  0.6920128464698792  and  1.1642067432403564  and  1.0188324451446533
total error:  2.875052034854889
training error:  0.688077986240387  and  1.159777045249939  and  0.9933687448501587
total error:  2.8412237763404846
training error:  0.6845200061798096  and  1.1717907190322876  and  0.9424605369567871
total error:  2.7987712621688843
training error:  0.671747088432312  and  1.17996084690094  and  0.9688952565193176
total error:  2.8206031918525696
training error:  0.6733415722846985  and  1.1936962604522705  and  0.9949229955673218
total error:  2.8619608283042908
training error:  0.6694693565368652  and  1.12046480178833  and  0.998383641242981
total error:  2.7883177995681763
training error:  0.6945583820343018  and  1.0972956418991089  and  0.9239950180053711
total error:  2.7158490419387817
training error:  0.6748124361038208  and  1.2062102556228

training error:  0.6513776183128357  and  1.128810167312622  and  0.9866204261779785
total error:  2.7668082118034363
training error:  0.664846658706665  and  1.1464349031448364  and  1.006248950958252
total error:  2.8175305128097534
training error:  0.6776829361915588  and  1.1242761611938477  and  0.9663394689559937
total error:  2.7682985663414
training error:  0.6752476096153259  and  1.0993489027023315  and  0.9691086411476135
total error:  2.743705153465271
training error:  0.6510549783706665  and  1.1217964887619019  and  0.9356532096862793
total error:  2.7085046768188477
training error:  0.6794750690460205  and  1.227914571762085  and  0.9610147476196289
total error:  2.8684043884277344
training error:  0.6897066831588745  and  1.1801632642745972  and  0.9516423344612122
total error:  2.821512281894684
training error:  0.7245989441871643  and  1.1578965187072754  and  1.0359957218170166
total error:  2.9184911847114563
training error:  0.6518348455429077  and  1.1722500324249

training error:  0.694537878036499  and  1.2089664936065674  and  0.9544491171836853
total error:  2.8579534888267517
training error:  0.6677933931350708  and  1.1386878490447998  and  0.9580751657485962
total error:  2.764556407928467
training error:  0.6512227654457092  and  1.1295266151428223  and  0.9552132487297058
total error:  2.7359626293182373
training error:  0.7122535109519958  and  1.6065592765808105  and  1.083182454109192
total error:  3.4019952416419983
training error:  0.6575085520744324  and  1.1011443138122559  and  0.923171877861023
total error:  2.681824743747711
training error:  0.6476731300354004  and  1.145766258239746  and  0.9302589893341064
total error:  2.723698377609253
training error:  0.6500458717346191  and  1.1861763000488281  and  0.9381353855133057
total error:  2.774357557296753
training error:  0.6605310440063477  and  1.1571216583251953  and  0.9549130797386169
total error:  2.77256578207016
training error:  0.6626479625701904  and  1.13396012783050

training error:  0.6454970836639404  and  1.1271156072616577  and  0.8973233103752136
total error:  2.6699360013008118
training error:  0.6250676512718201  and  1.1160988807678223  and  0.9026027917861938
total error:  2.643769323825836
training error:  0.6595950722694397  and  1.1290640830993652  and  0.9199907779693604
total error:  2.7086499333381653
training error:  0.638208270072937  and  1.0928668975830078  and  0.8591485023498535
total error:  2.5902236700057983
training error:  0.6459855437278748  and  1.1854891777038574  and  0.9941840767860413
total error:  2.8256587982177734
training error:  0.680299699306488  and  1.1986448764801025  and  0.9656565189361572
total error:  2.844601094722748
training error:  0.6327516436576843  and  1.1318528652191162  and  0.9232099056243896
total error:  2.68781441450119
training error:  0.6429406404495239  and  1.097536563873291  and  1.0304994583129883
total error:  2.7709766626358032
training error:  0.6334123611450195  and  1.19781851768

training error:  0.6673921346664429  and  1.1232575178146362  and  0.914074718952179
total error:  2.704724371433258
training error:  0.6227135062217712  and  1.1478989124298096  and  0.9519436359405518
total error:  2.7225560545921326
training error:  0.6590138673782349  and  1.1166348457336426  and  0.9409980177879333
total error:  2.716646730899811
training error:  0.6389837265014648  and  1.3119120597839355  and  1.2159051895141602
total error:  3.1668009757995605
training error:  0.6195387840270996  and  1.124267339706421  and  0.8793807029724121
total error:  2.6231868267059326
training error:  0.6435539722442627  and  1.1050307750701904  and  0.9356966614723206
total error:  2.6842814087867737
training error:  0.6181704998016357  and  1.1561086177825928  and  0.9581528306007385
total error:  2.732431948184967
training error:  0.6179817914962769  and  1.149308443069458  and  0.9177509546279907
total error:  2.6850411891937256
training error:  0.6433119773864746  and  1.1120676994

training error:  0.6174153685569763  and  1.1660767793655396  and  0.9799121618270874
total error:  2.7634043097496033
training error:  0.6102837324142456  and  1.1702117919921875  and  0.9301859140396118
total error:  2.710681438446045
training error:  0.6079088449478149  and  1.1288959980010986  and  0.905596911907196
total error:  2.6424017548561096
training error:  0.6355558037757874  and  1.15547513961792  and  1.1132344007492065
total error:  2.904265344142914
training error:  0.6277397871017456  and  1.1290069818496704  and  1.0089654922485352
total error:  2.765712261199951
training error:  0.6334760189056396  and  1.1263059377670288  and  0.9289933443069458
total error:  2.6887753009796143
training error:  0.6268795728683472  and  1.1003894805908203  and  1.1216068267822266
total error:  2.848875880241394
training error:  0.6709077954292297  and  1.1544378995895386  and  0.8812581300735474
total error:  2.7066038250923157
training error:  0.6322187781333923  and  1.06590938568

training error:  0.6060059666633606  and  1.0758438110351562  and  0.8649581074714661
total error:  2.546807885169983
training error:  0.6510473489761353  and  1.2356553077697754  and  0.8993418216705322
total error:  2.786044478416443
training error:  0.6787575483322144  and  1.3075759410858154  and  1.467695713043213
total error:  3.4540292024612427
training error:  0.6203290224075317  and  1.2382779121398926  and  0.9964420795440674
total error:  2.8550490140914917
training error:  0.6161158680915833  and  1.2696583271026611  and  0.8964725732803345
total error:  2.782246768474579
training error:  0.5876244902610779  and  1.054898738861084  and  0.8773679733276367
total error:  2.5198912024497986
training error:  0.5993084907531738  and  1.0615323781967163  and  0.8956220149993896
total error:  2.55646288394928
training error:  0.6003743410110474  and  1.0684823989868164  and  0.8658995032310486
total error:  2.5347562432289124
training error:  0.61202073097229  and  1.0654165744781

training error:  0.6298619508743286  and  1.3310742378234863  and  0.9133249521255493
total error:  2.8742611408233643
training error:  0.6826364398002625  and  1.1444909572601318  and  0.9376972913742065
total error:  2.764824688434601
training error:  0.6214983463287354  and  1.1100032329559326  and  0.9705526828765869
total error:  2.702054262161255
training error:  0.640586793422699  and  1.1488685607910156  and  1.0672321319580078
total error:  2.8566874861717224
training error:  0.610547661781311  and  1.2012039422988892  and  1.088070034980774
total error:  2.899821639060974
training error:  0.6679860353469849  and  1.1231969594955444  and  0.9144800901412964
total error:  2.7056630849838257
training error:  0.6284352540969849  and  1.136749267578125  and  0.953014612197876
total error:  2.718199133872986
training error:  0.7028155326843262  and  1.2173821926116943  and  0.9688256978988647
total error:  2.8890234231948853
training error:  0.630058765411377  and  1.12443947792053

In [2]:
# Persist only the trained generator's parameters (its state_dict), not the
# pickled module object; reload later with netG.load_state_dict(torch.load(...)).
# NOTE(review): "epsilon0.1" presumably encodes a training hyperparameter used
# for this run — confirm against the training cell before reusing the file.
generator_weights = netG.state_dict()
torch.save(generator_weights, "./epsilon0.1.pt")