In [1]:
# -*- coding: utf-8 -*-

import sys,os
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy import linalg
from numpy import dot
import geomloss as gs

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as D
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.autograd import grad
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.nn.modules import Linear
from torch.autograd.functional import jacobian,hessian,vjp,vhp,hvp

import random
import math

# Directory (relative to the notebook) that holds the raw CSV tables.
FilePath = '../../'

# One CSV per data set; the d0/d2/d4/d7 suffixes presumably denote the
# sampling day -- TODO confirm against the data source.
file_list = [
    'GSM1599494_ES_d0_main.csv',
    'GSM1599497_ES_d2_LIFminus.csv',
    'GSM1599498_ES_d4_LIFminus.csv',
    'GSM1599499_ES_d7_LIFminus.csv',
]

# The files are read without a header row, so the first line is treated as data.
table_list = [pd.read_csv(FilePath + csv_name, header=None) for csv_name in file_list]

# The first column of the first table is used as the gene-name vector; all
# remaining columns of every table are converted to float32 matrices.
gene_names = table_list[0].values[:, 0]
matrix_list = [frame.values[:, 1:].astype('float32') for frame in table_list]

# Number of columns per matrix (presumably one column per cell).
cell_counts = [counts.shape[1] for counts in matrix_list]

def normalize_run(mat):
    """Log-normalise a count matrix column by column.

    For every column the function computes
    ``log(mat * (median(z) / z) / (column_sum / 1e6) + 1)``
    where ``z`` is the fraction of entries in that column that are exactly
    zero.  A column without any zero entry therefore produces a division by
    zero (NumPy emits ``inf``/``nan`` with a warning).

    Parameters
    ----------
    mat : numpy.ndarray
        2-D array; statistics are taken along axis 0 (per column).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``mat``.
    """
    n_rows = float(mat.shape[0])
    # Column totals expressed in units of one million.
    per_million = np.sum(mat, 0) / 1e6
    # Fraction of zero entries in each column.
    zero_frac = np.sum(mat == 0, 0) / n_rows
    # Rescale each column relative to the median zero fraction.
    rescale = np.median(zero_frac) / zero_frac
    return np.log(mat * rescale * 1.0 / per_million + 1.0)

# Normalise every data set independently.
norm_mat = [normalize_run(counts) for counts in matrix_list]

# For each data set: 50 evenly spaced percentiles (0..100) taken along axis 1,
# giving one quantile profile per row of the normalised matrix.
qt_mat = [
    np.percentile(run, q=np.linspace(0, 100, 50), axis=1)
    for run in norm_mat
]

# Squared distance between the quantile profiles of the first and last data
# set, one value per row; rows are then ranked from most to least divergent.
wdiv = np.square(qt_mat[0] - qt_mat[3]).sum(axis=0)
w_order = np.argsort(-wdiv)

# Keep the 100 most divergent rows.
wsub = w_order[:100]


def nmf(X, latent_features, max_iter=100, error_limit=1e-6, fit_error_limit=1e-6, print_iter=200):
    """
    Factorise ``X`` approximately as ``A @ Y`` with masked multiplicative updates.

    Entries of ``X`` whose sign is zero are masked out (``mask = sign(X)``), so
    they do not contribute to the updates or to the reported residuals.  ``A``
    is initialised from the global NumPy random generator and ``Y`` from the
    least-squares solution of ``A Y = X``; both factors are clipped from below
    at ``1e-5`` after every update.

    Progress is printed, and the stopping criteria are evaluated, only on the
    first iteration, the last iteration and every ``print_iter``-th iteration.

    Parameters
    ----------
    X : numpy.ndarray
        Dense 2-D array to factorise.
    latent_features : int
        Inner dimension of the factorisation.
    max_iter : int
        Maximum number of update sweeps.
    error_limit : float
        Stop when the masked Frobenius residual drops below this value.
    fit_error_limit : float
        Stop when the masked change of the estimate since the previous
        evaluation drops below this value.
    print_iter : int
        Reporting / convergence-check period.

    Returns
    -------
    tuple
        ``(A, Y, A @ Y)``.
    """
    eps = 1e-5
    print('Starting NMF decomposition with {} latent features and {} iterations.'.format(latent_features, max_iter))

    # Observation mask: zero entries of X are excluded from the fit.
    mask = np.sign(X)

    # Random non-negative A, then Y as the clipped least-squares solution.
    n_rows = X.shape[0]
    A = np.maximum(np.random.rand(n_rows, latent_features), eps)
    Y = np.maximum(linalg.lstsq(A, X)[0], eps)

    masked_X = mask * X
    X_est_prev = dot(A, Y)

    for i in range(1, max_iter + 1):
        # A <- A .* ((W.*X) Y') ./ ((W.*(A Y)) Y')
        numer_A = dot(masked_X, Y.T)
        denom_A = dot(mask * dot(A, Y), Y.T) + eps
        A = np.maximum(A * (numer_A / denom_A), eps)

        # Y <- Y .* (A' (W.*X)) ./ (A' (W.*(A Y)))
        numer_Y = dot(A.T, masked_X)
        denom_Y = dot(A.T, mask * dot(A, Y)) + eps
        Y = np.maximum(Y * (numer_Y / denom_Y), eps)

        # Report and test for convergence only on selected iterations.
        if not (i % print_iter == 0 or i == 1 or i == max_iter):
            continue

        print('Iteration {}:'.format(i))
        X_est = dot(A, Y)
        delta = mask * (X_est_prev - X_est)
        fit_residual = np.sqrt(np.sum(delta ** 2))
        X_est_prev = X_est

        curRes = linalg.norm(mask * (X - X_est), ord='fro')
        print('fit residual', np.round(fit_residual, 4))
        print('total residual', np.round(curRes, 4))
        if curRes < error_limit or fit_residual < fit_error_limit:
            break

    return A, Y, dot(A, Y)

# nmf() draws its initial factor from the global NumPy RNG, so seed it first.
np.random.seed(0)
# Keep only the reconstruction A @ Y (third return value) of each data set,
# restricted to the rows selected in `wsub`.
norm_imputed = [
    nmf(run[wsub, :], latent_features=len(wsub) * 4, max_iter=500)[2]
    for run in norm_mat
]

# Row means of the last data set, kept as a column vector for broadcasting.
norm_adj = norm_imputed[3].mean(axis=1)[:, np.newaxis]
# Indices of the first ten rows that are carried forward.
subvec = np.arange(10)

# Gene names of the selected rows.
gnvec = gene_names[w_order[subvec]]

# Diagonal (per-row) scaling derived from the covariance of the last data set:
# `whiten` divides by the standard deviation, `unwhiten` multiplies by it.
cov_mat = np.cov(norm_imputed[3][subvec, :])
row_var = np.diag(cov_mat)
whiten = np.diag(row_var ** (-0.5))
unwhiten = np.diag(row_var ** (0.5))

# Centre every data set with `norm_adj`, keep the selected rows and rescale.
norm_imputed2 = [
    np.dot(whiten, (imputed - norm_adj)[subvec, :])
    for imputed in norm_imputed
]


class MLP(nn.Module):
    """Fully connected feed-forward network.

    The network has ``num_hidden`` hidden layers of width ``dim_hidden``.
    With ``num_hidden == 0`` it degenerates to a single linear map from
    ``dim_in`` to ``dim_out``.  The activation is applied after every layer
    except the last one.

    Parameters
    ----------
    dim_in : int
        Size of the input features.
    dim_out : int
        Size of the output features.
    dim_hidden : int
        Width of every hidden layer (ignored when ``num_hidden == 0``).
    num_hidden : int
        Number of hidden layers; must be >= 0.
    activation : callable
        Element-wise non-linearity.  The default instance is created once at
        definition time and shared by all networks that do not pass their own.

    Raises
    ------
    ValueError
        If ``num_hidden`` is negative.
    """

    def __init__(self, dim_in, dim_out, dim_hidden=64, num_hidden=0, activation=nn.LeakyReLU()):
        super(MLP, self).__init__()

        if num_hidden == 0:
            self.linears = nn.ModuleList([nn.Linear(dim_in, dim_out)])
        elif num_hidden >= 1:
            self.linears = nn.ModuleList()
            self.linears.append(nn.Linear(dim_in, dim_hidden))
            self.linears.extend([nn.Linear(dim_hidden, dim_hidden) for _ in range(num_hidden-1)])
            self.linears.append(nn.Linear(dim_hidden, dim_out))
        else:
            # Zero hidden layers is valid, so the constraint is "non-negative".
            # ValueError is a subclass of Exception, so existing handlers still work.
            raise ValueError('number of hidden layers must be non-negative')

        # Re-initialise after all layers exist: Xavier-normal weights and
        # biases drawn uniformly from [-0.1, 0.1].
        for m in self.linears:
            nn.init.xavier_normal_(m.weight)
            nn.init.uniform_(m.bias,a=-0.1,b=0.1)

        self.activation = activation

    def forward(self, x):
        """Apply all layers to ``x``; the last layer has no activation."""
        for m in self.linears[:-1]:
            x = self.activation(m(x))

        return self.linears[-1](x)


def compute_gradient_penalty(D, real_sample, fake_sample,k,p):
    """Gradient penalty of a critic evaluated on real and fake samples.

    For every sample ``x`` the gradient of ``D(x)`` with respect to ``x`` is
    computed and its Euclidean norm raised to the power ``p``.  The penalty is

        k * (sum over real of ||grad||^p + sum over fake of ||grad||^p)
            / (n_real + n_fake)

    The result stays attached to the autograd graph (``create_graph=True``)
    so it can be back-propagated into the critic parameters.

    Parameters
    ----------
    D : callable
        Critic mapping a batch of samples to one score per sample.  Note that
        inside this function the name shadows the module alias ``D``
        (``torch.distributions``) imported at the top of the file.
    real_sample, fake_sample : torch.Tensor
        Batches with the sample index on dimension 0.  ``requires_grad`` is
        switched on in place on both tensors.
    k : float
        Penalty coefficient.
    p : float
        Exponent applied to the gradient norm.

    Returns
    -------
    torch.Tensor
        Scalar penalty.
    """
    def _grad_norm_pow(samples):
        # Per-sample ||dD/dx||^p for one batch.
        samples = samples.requires_grad_(True)
        validity = D(samples)
        # Ones matching the critic output keep device, dtype and shape
        # consistent instead of assuming a CUDA float32 (N, 1) tensor.
        grad_out = torch.ones_like(validity)
        sample_grad = grad(
            validity, samples, grad_out, create_graph=True, retain_graph=True
        )[0]
        return sample_grad.reshape(sample_grad.size(0), -1).pow(2).sum(1) ** (p / 2)

    real_grad_norm = _grad_norm_pow(real_sample)
    fake_grad_norm = _grad_norm_pow(fake_sample)

    n_total = real_sample.shape[0] + fake_sample.shape[0]
    return (torch.sum(real_grad_norm) + torch.sum(fake_grad_norm)) * k / n_total

class JumpEulerForwardCuda(nn.Module):
    """Forward Euler simulator of a jump-diffusion with learnable coefficients.

    With step size ``h`` every step updates the state as

        state <- state + drift(state) * h
                       + (sqrt(h) * Z_b) @ diffusion
                       + (n * mean + (sqrt(n) * Z_j) @ covHalf) * jump(state)

    where ``n ~ Poisson(intensity * h)`` is drawn per sample, and ``Z_b`` /
    ``Z_j`` are standard normal draws of width ``noise_dim`` / ``in_features``.

    Parameters
    ----------
    in_features : int
        Dimension of the state.
    num_hidden : int
        Number of hidden layers of the drift and jump networks.
    dim_hidden : int
        Hidden width of the drift and jump networks.
    step_size : float
        Euler step size ``h``.
    jump_intensity : number, optional
        Poisson intensity of the jumps.  When omitted, the module-level
        variable ``intensity`` is used (the original behaviour).
    noise_dim : int, optional
        Number of rows of the diffusion matrix.  When omitted, the
        module-level variable ``bd`` is used (the original behaviour).
    """
    def __init__(self,in_features,num_hidden,dim_hidden,step_size,jump_intensity=None,noise_dim=None):
        super(JumpEulerForwardCuda,self).__init__()

        # Backward compatibility: earlier callers relied on notebook globals.
        if jump_intensity is None:
            jump_intensity = intensity
        if noise_dim is None:
            noise_dim = bd

        self.drift = MLP(in_features,in_features,dim_hidden,num_hidden)
        # Kept as a plain tensor; it is moved to the input's device in forward().
        self.intensity = torch.as_tensor(jump_intensity)
        self.mean = nn.Parameter(0.01*torch.ones(in_features))
        self.covHalf = nn.Parameter(0.08*torch.eye(in_features))
        # Width follows in_features instead of a hard-coded 10.
        self.diffusion = nn.Parameter(torch.ones(noise_dim,in_features))
        self.in_features = in_features
        self.jump = MLP(in_features,in_features,dim_hidden,num_hidden)
        self.step_size = step_size

    def forward(self,z0,Nsim,steps):
        """Simulate ``steps`` Euler steps starting from ``z0``.

        Parameters
        ----------
        z0 : torch.Tensor
            Initial states; written into slot 0 of the returned path, so it
            must be assignable to a ``(Nsim, in_features)`` slice.
        Nsim : int
            Number of simulated trajectories.
        steps : int
            Number of Euler steps.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(Nsim, steps + 1, in_features)`` holding the
            state after every step, allocated on the device of ``z0``.
        """
        device = z0.device
        noise_dim = self.diffusion.shape[0]
        sqrt_h = math.sqrt(self.step_size)

        # The jump-count distribution does not change between steps.
        # The rate is intensity * step_size (an early version forgot the
        # step_size factor here).
        DP = D.poisson.Poisson(self.intensity.to(device)*self.step_size)

        PopulationPath = torch.empty(size = (Nsim,steps+1,self.in_features),device=device)
        PopulationPath[:,0,:] = z0
        state = z0

        for i in range(1,steps+1):
            # Random draws are taken in the same order as before:
            # Poisson counts, Brownian increment, jump noise.
            pois = DP.sample((Nsim,1)).to(device)
            drift_term = self.drift(state)*self.step_size
            diffusion_term = sqrt_h*torch.normal(0,1,size=(Nsim,noise_dim),device=device)@self.diffusion
            jump_size = pois*self.mean + pois**(0.5)*torch.normal(0,1,size=(Nsim,self.in_features),device=device)@self.covHalf
            state = state + drift_term + diffusion_term + jump_size*self.jump(state)
            PopulationPath[:,i,:] = state
        return PopulationPath


def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global generator and
    Python's ``random`` module with the same value, then asks cuDNN to use
    deterministic algorithms.

    Parameters
    ----------
    seed : int
        Seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
# ---------------------------------------------------------------- seeding
sed = 200
setup_seed(sed)

# Sinkhorn divergence, used below only to monitor training progress.
a=gs.SamplesLoss(loss='sinkhorn',p=2,blur=0.01)


# ------------------------------------------------------------------- data
train_data = norm_imputed2

# Transposed so that dimension 0 indexes samples and dimension 1 features.
train0 = torch.tensor(train_data[0],dtype=torch.float32,requires_grad = True,device="cuda").t()
train2 = torch.tensor(train_data[1],dtype=torch.float32,requires_grad = True,device="cuda").t()
train4 = torch.tensor(train_data[2],dtype=torch.float32,requires_grad = True,device="cuda").t()
train7 = torch.tensor(train_data[3],dtype=torch.float32,requires_grad = True,device="cuda").t()

# Small Gaussian jitter (standard deviation 0.001) added to every data set.
train0 = train0+0.001*torch.normal(0,1,size=train0.shape,device="cuda")
train2 = train2+0.001*torch.normal(0,1,size=train2.shape,device="cuda")
train4 = train4+0.001*torch.normal(0,1,size=train4.shape,device="cuda")
train7 = train7+0.001*torch.normal(0,1,size=train7.shape,device="cuda")


# -------------------------------------------------------- hyperparameters
intensity = 10      # Poisson jump intensity (read by JumpEulerForwardCuda)
lr = 0.0003         # Adam learning rate for generator and critics
step_size = 0.03    # Euler step size
kuan = 256          # hidden width of every MLP
ceng = 4            # number of hidden layers of every MLP
bd = 2              # rows of the diffusion matrix (read by JumpEulerForwardCuda)
n_critic = 3        # critic updates per generator update
k = 2               # gradient-penalty coefficient
p = 6               # gradient-penalty exponent

n_sims = train0.shape[0]
in_features = train0.shape[1]
# Indices along the simulated path that are compared with train2/train4/train7.
n_steps = [10,20,35]


# --------------------------------------------------------------- networks
# Construction order is kept (generator first, then the critics) so that the
# seeded parameter initialisation is unchanged.
netG = JumpEulerForwardCuda(in_features,ceng,kuan,step_size).cuda()
netD1 = MLP(in_features,1,dim_hidden=kuan,num_hidden=ceng).cuda()
netD2 = MLP(in_features,1,dim_hidden=kuan,num_hidden=ceng).cuda()
netD3 = MLP(in_features,1,dim_hidden=kuan,num_hidden=ceng).cuda()


optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerSD1 = optim.Adam(netD1.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerSD2 = optim.Adam(netD2.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerSD3 = optim.Adam(netD3.parameters(), lr=lr, betas=(0.5, 0.999))

# One entry per target data set: (critic, its optimiser, real data, path index).
critics = [
    (netD1, optimizerSD1, train2, n_steps[0]),
    (netD2, optimizerSD2, train4, n_steps[1]),
    (netD3, optimizerSD3, train7, n_steps[2]),
]

n_epochs =  20000

# History of the summed Sinkhorn divergence, recorded every 10 epochs.
wd = []
for epoch in range(n_epochs):

    # ---- critic updates ----
    for _ in range(n_critic):
        # The critics only need the values of the simulated samples, so the
        # simulation runs without building the generator's autograd graph.
        # This yields the same critic gradients while avoiding a backward pass
        # through all simulation steps (and the need for retain_graph=True).
        with torch.no_grad():
            fake_data = netG(train0,n_sims,n_steps[2])

        for netD, optimizerSD, real, step_idx in critics:
            # Independent copy: compute_gradient_penalty enables requires_grad on it.
            fake = fake_data[:,step_idx,:].clone()

            optimizerSD.zero_grad()

            div_gp = compute_gradient_penalty(netD,real,fake,k,p)
            d_loss = -torch.mean(netD(real))+torch.mean(netD(fake))+div_gp
            d_loss.backward()

            optimizerSD.step()

    # ---- generator update ----
    optimizerG.zero_grad()
    fake_data = netG(train0,n_sims,n_steps[2])
    fake1 = fake_data[:,n_steps[0],:]
    fake2 = fake_data[:,n_steps[1],:]
    fake3 = fake_data[:,n_steps[2],:]
    g_loss = -torch.mean(netD1(fake1))-torch.mean(netD2(fake2))-torch.mean(netD3(fake3))
    g_loss.backward()

    optimizerG.step()

    # ---- monitoring ----
    if epoch %10==0:
        x1 = a(fake_data[:,n_steps[0],:],train2).item()
        x2 = a(fake_data[:,n_steps[1],:],train4).item()
        x3 = a(fake_data[:,n_steps[2],:],train7).item()

        wd.append(x1+x2+x3)

        print("training error: ",x1," and ",x2," and ",x3)
        print("total error: ",wd[-1])

Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 4000.1233
total residual 138.7959
Iteration 200:
fit residual 123.6643
total residual 31.8464
Iteration 400:
fit residual 21.7894
total residual 12.6115
Iteration 500:
fit residual 4.3442
total residual 8.6686
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2257.8358
total residual 58.2265
Iteration 200:
fit residual 56.8198
total residual 4.0447
Iteration 400:
fit residual 3.4777
total residual 0.6841
Iteration 500:
fit residual 0.3744
total residual 0.3185
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2925.3664
total residual 80.2657
Iteration 200:
fit residual 78.1148
total residual 6.1559
Iteration 400:
fit residual 5.1661
total residual 1.1917
Iteration 500:
fit residual 0.6183
total residual 0.5917
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
f

training error:  1.6631428003311157  and  2.257638454437256  and  1.9392973184585571
total error:  5.860078573226929
training error:  1.6949307918548584  and  2.3249282836914062  and  1.478002905845642
total error:  5.497861981391907
training error:  1.6639182567596436  and  2.250065565109253  and  1.6895270347595215
total error:  5.603510856628418
training error:  1.6989284753799438  and  2.583865165710449  and  1.9762165546417236
total error:  6.259010195732117
training error:  1.633334755897522  and  2.3654940128326416  and  1.926204800605774
total error:  5.9250335693359375
training error:  1.6879818439483643  and  2.5570151805877686  and  1.682377815246582
total error:  5.927374839782715
training error:  1.6364991664886475  and  2.665964365005493  and  2.0688893795013428
total error:  6.371352910995483
training error:  1.8367919921875  and  3.4953718185424805  and  1.7718031406402588
total error:  7.103966951370239
training error:  1.6168009042739868  and  3.5174005031585693  and 

training error:  1.3214328289031982  and  2.006110429763794  and  1.4769978523254395
total error:  4.804541110992432
training error:  1.2692029476165771  and  2.000307083129883  and  1.5152500867843628
total error:  4.784760117530823
training error:  1.3622065782546997  and  2.644624948501587  and  1.4285255670547485
total error:  5.435357093811035
training error:  1.2582054138183594  and  1.9149502515792847  and  1.434206247329712
total error:  4.607361912727356
training error:  1.3171777725219727  and  2.3729639053344727  and  1.6659202575683594
total error:  5.356061935424805
training error:  1.508386254310608  and  4.744759559631348  and  1.836484670639038
total error:  8.089630484580994
training error:  1.3420811891555786  and  1.9973245859146118  and  1.783073902130127
total error:  5.122479677200317
training error:  1.2929071187973022  and  3.9035634994506836  and  2.3464736938476562
total error:  7.542944312095642
training error:  1.291176438331604  and  1.9643149375915527  and

training error:  1.1857914924621582  and  1.9521067142486572  and  1.3319575786590576
total error:  4.469855785369873
training error:  1.1738567352294922  and  1.784622311592102  and  1.3472869396209717
total error:  4.305765986442566
training error:  1.1662921905517578  and  1.9954631328582764  and  1.3303676843643188
total error:  4.492123007774353
training error:  1.1809489727020264  and  1.805793046951294  and  1.321408987045288
total error:  4.308151006698608
training error:  1.186049222946167  and  1.7189922332763672  and  1.327791452407837
total error:  4.232832908630371
training error:  1.1680128574371338  and  1.6822724342346191  and  1.288053274154663
total error:  4.138338565826416
training error:  1.1668903827667236  and  1.847100853919983  and  1.422107219696045
total error:  4.4360984563827515
training error:  1.1537805795669556  and  1.8025716543197632  and  1.3597456216812134
total error:  4.316097855567932
training error:  1.1695475578308105  and  1.767630934715271  an

training error:  1.082651972770691  and  1.5694411993026733  and  1.2398974895477295
total error:  3.8919906616210938
training error:  1.1065235137939453  and  1.6438207626342773  and  1.2793476581573486
total error:  4.029691934585571
training error:  1.0807836055755615  and  1.5614514350891113  and  1.256742000579834
total error:  3.898977041244507
training error:  1.100866675376892  and  1.6280901432037354  and  1.2389788627624512
total error:  3.9679356813430786
training error:  1.118648886680603  and  1.7439281940460205  and  1.266235113143921
total error:  4.128812193870544
training error:  1.1081326007843018  and  1.6104834079742432  and  1.2418253421783447
total error:  3.9604413509368896
training error:  1.0709325075149536  and  1.591498851776123  and  1.266540288925171
total error:  3.9289716482162476
training error:  1.1084856986999512  and  1.6166077852249146  and  1.2851698398590088
total error:  4.0102633237838745
training error:  1.0877759456634521  and  1.58624243736267

training error:  1.029987096786499  and  1.51487398147583  and  1.2212293148040771
total error:  3.7660903930664062
training error:  1.067935585975647  and  1.5122263431549072  and  1.2857176065444946
total error:  3.865879535675049
training error:  1.068639874458313  and  1.5118070840835571  and  1.229353427886963
total error:  3.809800386428833
training error:  1.065548062324524  and  1.5322118997573853  and  1.2237532138824463
total error:  3.8215131759643555
training error:  1.0294502973556519  and  1.5439766645431519  and  1.2547876834869385
total error:  3.828214645385742
training error:  1.042989730834961  and  1.5254201889038086  and  1.2457191944122314
total error:  3.814129114151001
training error:  1.0366590023040771  and  1.5158822536468506  and  1.2146503925323486
total error:  3.7671916484832764
training error:  1.0300216674804688  and  1.5625526905059814  and  1.178154468536377
total error:  3.770728826522827
training error:  1.029659628868103  and  1.5105122327804565  a

training error:  1.0164449214935303  and  1.5324208736419678  and  1.1856513023376465
total error:  3.7345170974731445
training error:  0.9769495129585266  and  1.553206205368042  and  1.2043198347091675
total error:  3.734475553035736
training error:  1.009903073310852  and  1.7243211269378662  and  1.2075533866882324
total error:  3.9417775869369507
training error:  0.9907833337783813  and  1.4781527519226074  and  1.1599719524383545
total error:  3.6289080381393433
training error:  1.0300356149673462  and  1.5102877616882324  and  1.1479480266571045
total error:  3.688271403312683
training error:  1.0111464262008667  and  1.4378809928894043  and  1.1392751932144165
total error:  3.5883026123046875
training error:  1.0282728672027588  and  1.4978820085525513  and  1.175278902053833
total error:  3.701433777809143
training error:  1.0427746772766113  and  1.4995089769363403  and  1.2064015865325928
total error:  3.7486852407455444
training error:  1.0033353567123413  and  1.4717488288

training error:  0.9823002219200134  and  1.4484996795654297  and  1.1520602703094482
total error:  3.5828601717948914
training error:  0.9843050241470337  and  1.5161516666412354  and  1.1573905944824219
total error:  3.657847285270691
training error:  0.9649534821510315  and  1.4747616052627563  and  1.100407600402832
total error:  3.54012268781662
training error:  0.9935953617095947  and  1.4499573707580566  and  1.2484896183013916
total error:  3.692042350769043
training error:  0.978767454624176  and  1.4307787418365479  and  1.1567296981811523
total error:  3.566275894641876
training error:  0.959815263748169  and  1.5248788595199585  and  1.1058937311172485
total error:  3.590587854385376
training error:  0.9860512614250183  and  1.412726879119873  and  1.1626418828964233
total error:  3.5614200234413147
training error:  0.9812066555023193  and  1.5396990776062012  and  1.174234390258789
total error:  3.6951401233673096
training error:  0.9610597491264343  and  1.418902635574340

training error:  0.9297829866409302  and  1.5389668941497803  and  1.1560089588165283
total error:  3.6247588396072388
training error:  0.9142119884490967  and  1.4597601890563965  and  1.1387615203857422
total error:  3.5127336978912354
training error:  0.9421353340148926  and  1.4087321758270264  and  1.1902048587799072
total error:  3.541072368621826
training error:  0.9462657570838928  and  1.4452800750732422  and  1.111417531967163
total error:  3.502963364124298
training error:  0.9698188900947571  and  1.4350814819335938  and  1.168384075164795
total error:  3.5732844471931458
training error:  0.9796668291091919  and  1.3813393115997314  and  1.1421434879302979
total error:  3.503149628639221
training error:  0.936327338218689  and  1.425174355506897  and  1.1411128044128418
total error:  3.5026144981384277
training error:  0.9503235816955566  and  1.3688173294067383  and  1.1193532943725586
total error:  3.4384942054748535
training error:  0.9593188762664795  and  1.45150732994

training error:  0.9306355118751526  and  1.5375537872314453  and  1.099687099456787
total error:  3.567876398563385
training error:  0.9145578145980835  and  1.4067373275756836  and  1.065438985824585
total error:  3.386734127998352
training error:  0.9039539098739624  and  1.3761367797851562  and  1.0959484577178955
total error:  3.376039147377014
training error:  0.9086968898773193  and  1.3790833950042725  and  1.1145988702774048
total error:  3.4023791551589966
training error:  0.9331225752830505  and  1.3588283061981201  and  1.1043044328689575
total error:  3.396255314350128
training error:  0.9008907675743103  and  1.394452452659607  and  1.126993179321289
total error:  3.4223363995552063
training error:  0.9196387529373169  and  1.358971118927002  and  1.1265347003936768
total error:  3.4051445722579956
training error:  0.911771297454834  and  1.3966856002807617  and  1.321079969406128
total error:  3.6295368671417236
training error:  0.9196392297744751  and  1.414931297302246

training error:  0.8967896103858948  and  1.4265923500061035  and  1.1170666217803955
total error:  3.440448582172394
training error:  0.9036704301834106  and  1.3771719932556152  and  1.2388168573379517
total error:  3.5196592807769775
training error:  0.8816125988960266  and  1.452840805053711  and  1.0936747789382935
total error:  3.428128182888031
training error:  0.886893630027771  and  1.4672276973724365  and  1.1117265224456787
total error:  3.4658478498458862
training error:  0.8762737512588501  and  1.3164844512939453  and  1.0498331785202026
total error:  3.242591381072998
training error:  0.8764644861221313  and  1.3305926322937012  and  1.1010825634002686
total error:  3.308139681816101
training error:  0.8605129718780518  and  1.3441441059112549  and  1.112669825553894
total error:  3.3173269033432007
training error:  0.8593778610229492  and  1.3319101333618164  and  1.0618994235992432
total error:  3.253187417984009
training error:  0.8773601055145264  and  1.356617212295

training error:  0.8499705791473389  and  1.3869389295578003  and  1.0714194774627686
total error:  3.3083289861679077
training error:  0.862732470035553  and  1.3158268928527832  and  1.1245216131210327
total error:  3.303080976009369
training error:  0.8458577394485474  and  1.3548166751861572  and  1.0766499042510986
total error:  3.2773243188858032
training error:  0.8642997741699219  and  1.3251093626022339  and  1.0586634874343872
total error:  3.248072624206543
training error:  0.8629542589187622  and  1.3148136138916016  and  1.0730834007263184
total error:  3.250851273536682
training error:  0.8557963967323303  and  1.2727001905441284  and  1.0567522048950195
total error:  3.1852487921714783
training error:  0.8375076055526733  and  1.3437548875808716  and  1.094325065612793
total error:  3.275587558746338
training error:  0.8670086860656738  and  1.3068227767944336  and  1.0481183528900146
total error:  3.221949815750122
training error:  0.8544324636459351  and  1.29605698585

training error:  0.8271832466125488  and  1.258244276046753  and  1.0351170301437378
total error:  3.1205445528030396
training error:  0.8408120274543762  and  1.317200779914856  and  1.0814545154571533
total error:  3.2394673228263855
training error:  0.8096740245819092  and  1.246523380279541  and  1.047299861907959
total error:  3.103497266769409
training error:  0.8160289525985718  and  1.337223768234253  and  1.0319616794586182
total error:  3.185214400291443
training error:  0.8170487284660339  and  1.2751742601394653  and  1.0018768310546875
total error:  3.0940998196601868
training error:  0.8392819166183472  and  1.3202449083328247  and  1.0650169849395752
total error:  3.224543809890747
training error:  0.8224081993103027  and  1.2995898723602295  and  1.0759243965148926
total error:  3.197922468185425
training error:  0.8261072039604187  and  1.3642514944076538  and  1.1204538345336914
total error:  3.310812532901764
training error:  0.8351098299026489  and  1.29554259777069

training error:  0.8311899900436401  and  1.311295509338379  and  1.0885093212127686
total error:  3.2309948205947876
training error:  0.847199559211731  and  1.2836873531341553  and  1.0551663637161255
total error:  3.1860532760620117
training error:  0.8004895448684692  and  1.3188987970352173  and  1.1073946952819824
total error:  3.226783037185669
training error:  0.794191837310791  and  1.2966194152832031  and  1.0332164764404297
total error:  3.124027729034424
training error:  0.7830031514167786  and  1.2648364305496216  and  1.035666584968567
total error:  3.083506166934967
training error:  0.7714688777923584  and  1.263152837753296  and  1.1245713233947754
total error:  3.1591930389404297
training error:  0.7987768054008484  and  1.2806310653686523  and  1.0668412446975708
total error:  3.1462491154670715
training error:  0.790244460105896  and  1.406087875366211  and  1.0677905082702637
total error:  3.2641228437423706
training error:  0.7742235660552979  and  1.21015548706054

training error:  0.7938895225524902  and  1.2651375532150269  and  1.0066745281219482
total error:  3.0657016038894653
training error:  0.7784528732299805  and  1.304269552230835  and  1.0272576808929443
total error:  3.1099801063537598
training error:  0.7715979814529419  and  1.2569568157196045  and  1.025928020477295
total error:  3.0544828176498413
training error:  0.7880673408508301  and  1.2332899570465088  and  1.022209882736206
total error:  3.043567180633545
training error:  0.7726867198944092  and  1.2417716979980469  and  0.9929215908050537
total error:  3.0073800086975098
training error:  0.7796928882598877  and  1.2410051822662354  and  1.0918078422546387
total error:  3.1125059127807617
training error:  0.7716183066368103  and  1.2621980905532837  and  1.0302025079727173
total error:  3.0640189051628113
training error:  0.7746267318725586  and  1.3538544178009033  and  1.0236198902130127
total error:  3.1521010398864746
training error:  0.7620975375175476  and  1.29745078

training error:  0.7348832488059998  and  1.2141796350479126  and  0.9704897403717041
total error:  2.9195526242256165
training error:  0.7746350169181824  and  1.2336339950561523  and  0.9452230930328369
total error:  2.9534921050071716
training error:  0.7507161498069763  and  1.290914535522461  and  1.2111635208129883
total error:  3.2527942061424255
training error:  0.7731152772903442  and  1.3159255981445312  and  1.231755256652832
total error:  3.3207961320877075
training error:  0.7744698524475098  and  1.2486809492111206  and  1.057438611984253
total error:  3.0805894136428833
training error:  0.7568436861038208  and  1.1956672668457031  and  0.9849082827568054
total error:  2.9374192357063293
training error:  0.7379083633422852  and  1.2324882745742798  and  1.0257530212402344
total error:  2.9961496591567993
training error:  0.7217824459075928  and  1.2152082920074463  and  0.9692016839981079
total error:  2.906192421913147
training error:  0.7246823906898499  and  1.33047294

training error:  0.7532034516334534  and  1.2718703746795654  and  1.072311520576477
total error:  3.097385346889496
training error:  0.7021301984786987  and  1.2576329708099365  and  0.9816451668739319
total error:  2.941408336162567
training error:  0.7351552248001099  and  1.3407070636749268  and  1.0910520553588867
total error:  3.1669143438339233
training error:  0.7605706453323364  and  1.2524352073669434  and  1.049459457397461
total error:  3.0624653100967407
training error:  0.7294942736625671  and  1.188984751701355  and  0.9930278658866882
total error:  2.9115068912506104
training error:  0.6911059021949768  and  1.2197957038879395  and  0.9581719040870667
total error:  2.869073510169983
training error:  0.7320231199264526  and  1.2710025310516357  and  0.987063467502594
total error:  2.9900891184806824
training error:  0.7240985631942749  and  1.2315380573272705  and  0.9676340818405151
total error:  2.9232707023620605
training error:  0.7352137565612793  and  1.24631798267

training error:  0.6848692893981934  and  1.1448274850845337  and  0.8983487486839294
total error:  2.7280455231666565
training error:  0.6783990859985352  and  1.1704447269439697  and  0.9696337580680847
total error:  2.8184775710105896
training error:  0.6894669532775879  and  1.1667948961257935  and  0.9951034784317017
total error:  2.851365327835083
training error:  0.6918032169342041  and  1.1534779071807861  and  1.0051100254058838
total error:  2.850391149520874
training error:  0.6805319786071777  and  1.1639091968536377  and  0.9785377383232117
total error:  2.822978913784027
training error:  0.6958842277526855  and  1.1850404739379883  and  1.0405935049057007
total error:  2.9215182065963745
training error:  0.7036113142967224  and  1.221641182899475  and  1.0151909589767456
total error:  2.940443456172943
training error:  0.6865917444229126  and  1.1603293418884277  and  1.0168005228042603
total error:  2.8637216091156006
training error:  0.6767810583114624  and  1.271755218

training error:  0.6521326899528503  and  1.1071079969406128  and  0.9186083078384399
total error:  2.677848994731903
training error:  0.7071826457977295  and  1.1767096519470215  and  0.976219654083252
total error:  2.860111951828003
training error:  0.6671779155731201  and  1.162737250328064  and  0.9391658902168274
total error:  2.7690810561180115
training error:  0.6791041493415833  and  1.1727848052978516  and  1.0083763599395752
total error:  2.86026531457901
training error:  0.6570221185684204  and  1.1750023365020752  and  0.9192061424255371
total error:  2.7512305974960327
training error:  0.6795530319213867  and  1.136615514755249  and  0.9213747978210449
total error:  2.7375433444976807
training error:  0.655441164970398  and  1.178556203842163  and  1.0867505073547363
total error:  2.9207478761672974
training error:  0.6702800989151001  and  1.1331431865692139  and  0.9075943231582642
total error:  2.711017608642578
training error:  0.6648313999176025  and  1.16274333000183

training error:  0.6729460954666138  and  1.1487395763397217  and  0.9916409254074097
total error:  2.813326597213745
training error:  0.6402976512908936  and  1.1000275611877441  and  0.921676516532898
total error:  2.6620017290115356
training error:  0.6797587871551514  and  1.1989924907684326  and  0.9753572940826416
total error:  2.8541085720062256
training error:  0.6522995233535767  and  1.2262341976165771  and  0.9944720268249512
total error:  2.873005747795105
training error:  0.6835119128227234  and  1.1463778018951416  and  0.8701382875442505
total error:  2.7000280022621155
training error:  0.6455821394920349  and  1.162536859512329  and  0.9037045836448669
total error:  2.711823582649231
training error:  0.6550247073173523  and  1.1480371952056885  and  0.9207537770271301
total error:  2.723815679550171
training error:  0.6664907336235046  and  1.1756995916366577  and  0.8918231725692749
total error:  2.7340134978294373
training error:  0.6569106578826904  and  1.3751219511

training error:  0.6641684174537659  and  1.1807677745819092  and  0.9144865274429321
total error:  2.759422719478607
training error:  0.6336236000061035  and  1.0969998836517334  and  0.9388994574546814
total error:  2.6695229411125183
training error:  0.6154282689094543  and  1.2413487434387207  and  0.9403674602508545
total error:  2.7971444725990295
training error:  0.635056734085083  and  1.2275357246398926  and  0.8842095136642456
total error:  2.746801972389221
training error:  0.6360870599746704  and  1.093353509902954  and  0.8903293013572693
total error:  2.619769871234894
training error:  0.6434231400489807  and  1.1235954761505127  and  0.8927925825119019
total error:  2.6598111987113953
training error:  0.6581672430038452  and  1.1026722192764282  and  0.9265464544296265
total error:  2.6873859167099
training error:  0.6409928202629089  and  1.1267788410186768  and  1.0431808233261108
total error:  2.8109524846076965
training error:  0.6273969411849976  and  1.093453049659

training error:  0.6208950281143188  and  1.1797542572021484  and  0.917375385761261
total error:  2.7180246710777283
training error:  0.6229380965232849  and  1.1061586141586304  and  0.9468220472335815
total error:  2.675918757915497
training error:  0.6538971662521362  and  1.1650885343551636  and  0.9713515639305115
total error:  2.7903372645378113
training error:  0.6224360466003418  and  1.0789885520935059  and  0.8889142274856567
total error:  2.5903388261795044
training error:  0.6198209524154663  and  1.0914132595062256  and  0.8878932595252991
total error:  2.599127471446991
training error:  0.6234270334243774  and  1.252968668937683  and  1.0569431781768799
total error:  2.9333388805389404
training error:  0.6267377734184265  and  1.1154086589813232  and  0.8883962631225586
total error:  2.6305426955223083
training error:  0.6304716467857361  and  1.1464505195617676  and  0.9575619697570801
total error:  2.7344841361045837
training error:  0.6206470131874084  and  1.08651614

training error:  0.616778552532196  and  1.074183464050293  and  0.8863476514816284
total error:  2.5773096680641174
training error:  0.6010257005691528  and  1.1126673221588135  and  0.8841454982757568
total error:  2.597838521003723
training error:  0.5903974771499634  and  1.0728521347045898  and  1.0115435123443604
total error:  2.6747931241989136
training error:  0.5904617309570312  and  1.0690412521362305  and  0.8753769993782043
total error:  2.534879982471466
training error:  0.6096457242965698  and  1.0966362953186035  and  0.8716248273849487
total error:  2.577906847000122
training error:  0.6123290061950684  and  1.143498420715332  and  1.0267784595489502
total error:  2.7826058864593506
training error:  0.6363744735717773  and  1.0727508068084717  and  0.9674950838088989
total error:  2.676620364189148
training error:  0.683756947517395  and  1.4054710865020752  and  1.3389204740524292
total error:  3.4281485080718994
training error:  0.6238699555397034  and  1.430182695388

training error:  0.5868736505508423  and  1.0327004194259644  and  0.856196939945221
total error:  2.4757710099220276
training error:  0.6117908954620361  and  1.1156730651855469  and  0.967422366142273
total error:  2.694886326789856
training error:  0.5870959162712097  and  1.0547388792037964  and  0.8965674638748169
total error:  2.538402259349823
training error:  0.6007768511772156  and  1.1104997396469116  and  0.8600520491600037
total error:  2.571328639984131
training error:  0.5926999449729919  and  1.0281925201416016  and  0.8437520265579224
total error:  2.464644491672516
training error:  0.579916775226593  and  1.0609657764434814  and  0.9119912385940552
total error:  2.5528737902641296
training error:  0.5922114849090576  and  1.088904857635498  and  0.9559053182601929
total error:  2.6370216608047485
training error:  0.6039060354232788  and  1.0776445865631104  and  0.8492053747177124
total error:  2.5307559967041016
training error:  0.6288250684738159  and  1.128335714340

training error:  0.5700758695602417  and  1.0548774003982544  and  0.8404555320739746
total error:  2.4654088020324707
training error:  0.5735599994659424  and  1.0478284358978271  and  0.8382349610328674
total error:  2.459623396396637
training error:  0.6026009917259216  and  1.0693466663360596  and  0.886950671672821
total error:  2.5588983297348022
training error:  0.5879980325698853  and  1.0598963499069214  and  0.8522490859031677
total error:  2.5001434683799744
training error:  0.5943492650985718  and  1.0891478061676025  and  0.9259555339813232
total error:  2.6094526052474976
training error:  0.6123688220977783  and  1.0796704292297363  and  0.835111141204834
total error:  2.5271503925323486
training error:  0.6143937110900879  and  1.1297452449798584  and  0.8930840492248535
total error:  2.6372230052948
training error:  0.5644432306289673  and  1.1080245971679688  and  0.8968919515609741
total error:  2.56935977935791
training error:  0.5733357667922974  and  1.068327784538

training error:  0.5690383911132812  and  1.0531363487243652  and  0.8423528075218201
total error:  2.4645275473594666
training error:  0.6423468589782715  and  1.4036061763763428  and  1.004421353340149
total error:  3.050374388694763
training error:  0.5784029960632324  and  1.3391778469085693  and  0.8245582580566406
total error:  2.7421391010284424
training error:  0.5877491235733032  and  1.068476915359497  and  0.9486024975776672
total error:  2.6048285365104675
training error:  0.5786504149436951  and  1.0563125610351562  and  0.900822639465332
total error:  2.5357856154441833
training error:  0.5847354531288147  and  1.2581926584243774  and  0.8896453976631165
total error:  2.7325735092163086
training error:  0.6644133925437927  and  1.1910604238510132  and  0.9751296043395996
total error:  2.8306034207344055
training error:  0.5974355936050415  and  1.0602850914001465  and  0.9427566528320312
total error:  2.6004773378372192
training error:  0.5824794173240662  and  1.06006240

training error:  0.5712482929229736  and  1.1011531352996826  and  0.9648962616920471
total error:  2.6372976899147034
training error:  0.5531517267227173  and  1.0358562469482422  and  0.8858247995376587
total error:  2.474832773208618
training error:  0.5504785776138306  and  1.0594062805175781  and  0.8176162242889404
total error:  2.427501082420349
training error:  0.6481895446777344  and  1.3124055862426758  and  1.0967170000076294
total error:  3.0573121309280396
training error:  0.5789716243743896  and  1.0552022457122803  and  0.8863322138786316
total error:  2.5205060839653015
training error:  0.5574877262115479  and  1.2035374641418457  and  0.914854884147644
total error:  2.6758800745010376
training error:  0.5669169425964355  and  1.2011942863464355  and  0.8040738105773926
total error:  2.5721850395202637
training error:  0.547856330871582  and  1.0747255086898804  and  0.8497315049171448
total error:  2.472313344478607
training error:  0.5463268756866455  and  1.033362388

training error:  0.608500599861145  and  1.1010791063308716  and  0.9055929183959961
total error:  2.6151726245880127
training error:  0.5400739908218384  and  1.0361213684082031  and  0.8413223624229431
total error:  2.4175177216529846
training error:  0.5593010783195496  and  1.0068706274032593  and  0.8395212888717651
total error:  2.405692994594574
training error:  0.5534769296646118  and  1.0620336532592773  and  0.8314172029495239
total error:  2.446927785873413
training error:  0.5737706422805786  and  1.0399667024612427  and  0.8638936877250671
total error:  2.4776310324668884
training error:  0.5536398887634277  and  1.068602204322815  and  0.8353724479675293
total error:  2.457614541053772
training error:  0.5647276043891907  and  1.0584698915481567  and  1.0580612421035767
total error:  2.681258738040924
training error:  0.5578919649124146  and  1.0926432609558105  and  0.934403657913208
total error:  2.584938883781433
training error:  0.5829120874404907  and  1.186242341995

training error:  0.5471341609954834  and  1.0136408805847168  and  0.7887009382247925
total error:  2.3494759798049927
training error:  0.5365278720855713  and  1.100887417793274  and  1.1029349565505981
total error:  2.7403502464294434
training error:  0.558200478553772  and  1.0334367752075195  and  0.8564598560333252
total error:  2.4480971097946167
training error:  0.5337881445884705  and  1.0536854267120361  and  0.8831799626350403
total error:  2.470653533935547
training error:  0.5430501699447632  and  1.0098017454147339  and  0.7806083559989929
total error:  2.33346027135849
training error:  0.5154911875724792  and  0.9818286895751953  and  0.8191725015640259
total error:  2.3164923787117004
training error:  0.5346687436103821  and  1.0221436023712158  and  0.8173761963844299
total error:  2.374188542366028
training error:  0.5401254296302795  and  1.025716781616211  and  0.9112526774406433
total error:  2.477094888687134
training error:  0.5300000905990601  and  0.970676183700

In [2]:
# Serialize netG's state_dict (not the whole module object) to a file in the
# current working directory.
# NOTE(review): the "0.001" in the filename presumably records the epsilon
# setting used for this run -- confirm against the training cell.
torch.save(
    obj=netG.state_dict(),
    f="./epsilon0.001.pt",
)