In [1]:
# -*- coding: utf-8 -*-

import sys,os
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy import linalg
from numpy import dot
import geomloss as gs

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as D
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.autograd import grad
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.nn.modules import Linear
from torch.autograd.functional import jacobian,hessian,vjp,vhp,hvp

import random
import math

FilePath = '../../'

# One raw count table per time point (d0, d2, d4, d7 after LIF withdrawal, per the file names).
file_list = ['GSM1599494_ES_d0_main.csv', 'GSM1599497_ES_d2_LIFminus.csv', 'GSM1599498_ES_d4_LIFminus.csv', 'GSM1599499_ES_d7_LIFminus.csv']

# The CSVs have no header row: column 0 holds the gene names, every other column is one cell.
table_list = [pd.read_csv(FilePath + csv_name, header=None) for csv_name in file_list]

gene_names = table_list[0].values[:, 0]
matrix_list = [tbl.values[:, 1:].astype('float32') for tbl in table_list]

# Number of cells (columns) measured at each time point.
cell_counts = [counts.shape[1] for counts in matrix_list]

def normalize_run(mat):
    """Per-column (per-cell) depth and zero-rate normalisation followed by log(x + 1).

    For every column of ``mat``:
      * divide by the column total expressed in millions (counts per million);
      * multiply by median(z) / z, where z is the column's fraction of entries
        that are exactly zero;
      * add 1 and take the natural logarithm.

    NOTE(review): a column with no zero entries has z == 0 and yields inf/nan,
    and an all-zero column divides by a zero total; neither case is guarded.
    """
    n_rows = float(mat.shape[0])
    zero_frac = np.sum(mat == 0, 0) / n_rows
    per_million = np.sum(mat, 0) / 1e6
    rescale = np.median(zero_frac) / zero_frac
    return np.log(mat * rescale * 1.0 / per_million + 1.0)

norm_mat = [normalize_run(counts) for counts in matrix_list]

# Per time point: 50 evenly spaced quantiles (0..100%) of every gene across cells,
# giving arrays of shape (50, n_genes).
qt_mat = [
    np.percentile(normed, q=np.linspace(0, 100, 50), axis=1)
    for normed in norm_mat
]

# Score each gene by the squared difference between its quantile profiles at the
# first and the fourth (last) time point, then rank genes by decreasing score.
wdiv = np.sum((qt_mat[0] - qt_mat[3]) ** 2, 0)
w_order = np.argsort(-wdiv)

# Indices of the 100 highest-scoring genes.
wsub = w_order[:100]


def nmf(X, latent_features, max_iter=100, error_limit=1e-6, fit_error_limit=1e-6, print_iter=200):
    """Masked non-negative matrix factorisation: approximate X by A @ Y.

    Entries are weighted by ``np.sign(X)``, so zero entries are ignored by the
    fit (this assumes X is non-negative; a negative entry would get weight -1).
    A starts from uniform random values drawn from the global ``np.random``
    state and Y from the least-squares solution of A @ Y = X; both are clipped
    to >= 1e-5 and then refined by weighted multiplicative updates.

    Progress is printed, and convergence checked, only on iteration 1, every
    ``print_iter``-th iteration and the final iteration. The loop stops early
    when the masked Frobenius residual of X - A @ Y falls below ``error_limit``
    or the masked change of A @ Y since the previous check falls below
    ``fit_error_limit``.

    Args:
        X: 2-D array to factorise.
        latent_features: number of columns of A (rows of Y).
        max_iter: maximum number of update iterations.
        error_limit: threshold on the masked reconstruction residual.
        fit_error_limit: threshold on the change of A @ Y between checks.
        print_iter: reporting / convergence-check period.

    Returns:
        Tuple ``(A, Y, A @ Y)``.
    """
    eps = 1e-5
    print('Starting NMF decomposition with {} latent features and {} iterations.'.format(latent_features, max_iter))

    W = np.sign(X)
    n_rows, _ = X.shape

    # Random positive A; Y solves A Y = X in the least-squares sense, clipped positive.
    A = np.maximum(np.random.rand(n_rows, latent_features), eps)
    Y = np.maximum(linalg.lstsq(A, X)[0], eps)

    WX = W * X
    last_checked = dot(A, Y)
    for it in range(1, max_iter + 1):
        # Matlab: A=A.*(((W.*X)*Y')./((W.*(A*Y))*Y'));
        A = np.maximum(A * (dot(WX, Y.T) / (dot(W * dot(A, Y), Y.T) + eps)), eps)
        # Matlab: Y=Y.*((A'*(W.*X))./(A'*(W.*(A*Y))));
        Y = np.maximum(Y * (dot(A.T, WX) / (dot(A.T, W * dot(A, Y)) + eps)), eps)

        # Same short-circuit order as before, so print_iter == 0 fails on iteration 1.
        if not (it % print_iter == 0 or it == 1 or it == max_iter):
            continue

        print('Iteration {}:'.format(it))
        current = dot(A, Y)
        fit_residual = np.sqrt(np.sum((W * (last_checked - current)) ** 2))
        last_checked = current

        total_residual = linalg.norm(W * (X - current), ord='fro')
        print('fit residual', np.round(fit_residual, 4))
        print('total residual', np.round(total_residual, 4))
        if total_residual < error_limit or fit_residual < fit_error_limit:
            break

    return A, Y, dot(A, Y)

# Factorise each time point's top-gene submatrix with the masked NMF and keep the
# reconstruction A @ Y (third return value) as the imputed matrix. Seeded because
# nmf draws its initial factor from the global np.random state.
np.random.seed(0)
norm_imputed = [
    nmf(normin[wsub, :], latent_features=len(wsub) * 4, max_iter=500)[2]
    for normin in norm_mat
]

# Per-gene mean of the last time point, used to centre every time point.
norm_adj = np.mean(norm_imputed[3], 1)[:, np.newaxis]

# Keep only the 10 top-ranked genes (rows 0..9 of the imputed matrices).
subvec = np.arange(10)
gnvec = gene_names[w_order[subvec]]

# Diagonal whitening: scale each gene by its last-time-point standard deviation;
# unwhiten is the inverse scaling.
cov_mat = np.cov(norm_imputed[3][subvec, :])
whiten = np.diag(np.diag(cov_mat) ** (-0.5))
unwhiten = np.diag(np.diag(cov_mat) ** (0.5))

norm_imputed2 = [
    np.dot(whiten, (normin - norm_adj)[subvec, :])
    for normin in norm_imputed
]


class MLP(nn.Module):
    """Fully connected network: a stack of Linear layers with an activation between them.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: width of every hidden layer.
        num_hidden: number of hidden layers; 0 gives a single Linear(dim_in, dim_out).
        activation: module applied after every Linear layer except the last.
            The default instance is created once, when the class is defined, and
            is shared by every MLP that uses the default. That is harmless for the
            parameter-free LeakyReLU but matters if a stateful module is substituted.

    Raises:
        ValueError: if ``num_hidden`` is neither 0 nor >= 1 (e.g. negative).

    Weights use Xavier-normal initialisation; biases are drawn uniformly from [-0.1, 0.1].
    """

    def __init__(self, dim_in, dim_out, dim_hidden=64, num_hidden=0, activation=nn.LeakyReLU()):
        super(MLP, self).__init__()

        # Layers are constructed in input-to-output order, so RNG consumption
        # (nn.Linear draws its default init) matches the previous implementation.
        if num_hidden == 0:
            layers = [nn.Linear(dim_in, dim_out)]
        elif num_hidden >= 1:
            layers = [nn.Linear(dim_in, dim_hidden)]
            layers.extend(nn.Linear(dim_hidden, dim_hidden) for _ in range(num_hidden - 1))
            layers.append(nn.Linear(dim_hidden, dim_out))
        else:
            # 0 is a valid value, so the previous "must be positive" wording was wrong.
            # ValueError subclasses Exception, so existing `except Exception` still works.
            raise ValueError('number of hidden layers must be non-negative, got {}'.format(num_hidden))
        self.linears = nn.ModuleList(layers)

        for m in self.linears:
            nn.init.xavier_normal_(m.weight)
            nn.init.uniform_(m.bias, a=-0.1, b=0.1)

        self.activation = activation

    def forward(self, x):
        """Apply every layer, with the activation after all but the final (linear) output layer."""
        for m in self.linears[:-1]:
            x = self.activation(m(x))
        return self.linears[-1](x)


def compute_gradient_penalty(D, real_sample, fake_sample, k, p):
    """Gradient-norm penalty on critic ``D`` evaluated at real and fake samples.

    Returns
        k * (sum_i ||dD/dx(real_i)||^p + sum_j ||dD/dx(fake_j)||^p) / (n_real + n_fake)
    as a scalar tensor built with ``create_graph=True``, so it can be
    back-propagated into the critic's parameters. This has the form of the
    WGAN-div penalty when k and p are chosen accordingly.

    Args:
        D: critic network (the name shadows ``torch.distributions as D`` only
           inside this function). It is called on each batch separately.
        real_sample, fake_sample: batches whose first dimension indexes samples.
        k: penalty weight.
        p: exponent applied to the per-sample gradient L2 norm.

    Inputs that already track gradients are differentiated directly. Otherwise a
    detached copy is used, so the caller's tensors are no longer mutated via
    ``requires_grad_``. Tensors are created on the input's own device, so the
    function now also works off-GPU.
    """
    def _grad_norm_pow(x):
        # ||dD/dx||^p per sample, keeping the graph for the outer backward.
        x = x if x.requires_grad else x.detach().requires_grad_(True)
        validity = D(x)
        g = grad(
            validity, x, torch.ones_like(validity), create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        return g.reshape(g.size(0), -1).pow(2).sum(1) ** (p / 2)

    real_grad_norm = _grad_norm_pow(real_sample)
    fake_grad_norm = _grad_norm_pow(fake_sample)

    return (torch.sum(real_grad_norm) + torch.sum(fake_grad_norm)) * k / (real_sample.shape[0] + fake_sample.shape[0])

class JumpEulerForwardCuda(nn.Module):
    """Euler-Maruyama simulator of a jump-diffusion with learned drift and jump scaling.

    One step of size dt from the state x (rows = trajectories):
        x <- x + f(x) dt
               + sqrt(dt) * eps_W @ diffusion
               + (N * mean + sqrt(N) * eps_J @ covHalf) * g(x)
    where f = ``self.drift`` and g = ``self.jump`` are MLPs, eps_W ~ N(0, I) has
    ``diffusion.shape[0]`` columns, eps_J ~ N(0, I) has ``in_features`` columns,
    and N ~ Poisson(intensity * dt) is drawn once per trajectory per step and
    broadcast over features. Given N, the bracket is distributed as the sum of
    N i.i.d. Gaussian jumps with mean ``mean`` and covariance covHalf^T covHalf.

    Args:
        in_features: state dimension.
        num_hidden, dim_hidden: depth / width of the drift and jump MLPs.
        step_size: Euler step dt.
        jump_intensity: Poisson jump rate. If None, the notebook-global
            ``intensity`` is read, as earlier versions always did.
        noise_dim: dimension of the Brownian noise. If None, the notebook-global
            ``bd`` is read at construction time.
        device: device of the stored intensity tensor (default "cuda", as before).
    """

    def __init__(self, in_features, num_hidden, dim_hidden, step_size,
                 jump_intensity=None, noise_dim=None, device="cuda"):
        super(JumpEulerForwardCuda, self).__init__()

        rate = intensity if jump_intensity is None else jump_intensity
        n_noise = bd if noise_dim is None else noise_dim

        # Parameter creation order (drift, mean, covHalf, diffusion, jump) is kept
        # so RNG consumption and parameter ordering match earlier versions.
        self.drift = MLP(in_features, in_features, dim_hidden, num_hidden)
        self.intensity = torch.tensor(rate, device=device)
        self.mean = nn.Parameter(0.01 * torch.ones(in_features))
        self.covHalf = nn.Parameter(0.08 * torch.eye(in_features))
        # Previously hard-coded to 10 columns; it must equal in_features for the
        # matmul in forward() to produce a state-sized update.
        self.diffusion = nn.Parameter(torch.ones(n_noise, in_features))
        self.in_features = in_features
        self.jump = MLP(in_features, in_features, dim_hidden, num_hidden)
        self.step_size = step_size

    def forward(self, z0, Nsim, steps):
        """Simulate ``steps`` Euler steps from ``z0`` for ``Nsim`` trajectories.

        Returns a tensor of shape (Nsim, steps + 1, in_features) whose slice
        [:, 0, :] is z0 and [:, i, :] is the state after i steps.
        """
        device = z0.device
        n_noise = self.diffusion.shape[0]
        sqrt_dt = math.sqrt(self.step_size)
        # The jump-count distribution is the same at every step; build it once.
        jump_counts = D.poisson.Poisson(self.intensity * self.step_size)

        PopulationPath = torch.empty(size=(Nsim, steps + 1, self.in_features), device=device)
        PopulationPath[:, 0, :] = z0
        state = z0

        for i in range(1, steps + 1):
            pois = jump_counts.sample((Nsim, 1)).to(device)
            # Term order is unchanged, so random draws happen in the same sequence:
            # Poisson counts, Brownian noise, then jump-size noise.
            state = (
                state
                + self.drift(state) * self.step_size
                + sqrt_dt * torch.normal(0, 1, size=(Nsim, n_noise), device=device) @ self.diffusion
                + (pois * self.mean
                   + pois ** (0.5) * torch.normal(0, 1, size=(Nsim, self.in_features), device=device) @ self.covHalf)
                * self.jump(state)
            )
            PopulationPath[:, i, :] = state
        return PopulationPath


def setup_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and request deterministic cuDNN kernels."""
    # torch first, then every CUDA device explicitly.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Then the NumPy and stdlib global generators.
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
sed = 200  # global random seed
setup_seed(sed)

# Sinkhorn (entropy-regularised optimal transport) distance between point clouds;
# used only to monitor training below, never as a training loss.
a = gs.SamplesLoss(loss='sinkhorn', p=2, blur=0.01)


train_data = norm_imputed2

# Each matrix is (genes, cells); transpose so rows are cells. The data no longer
# carries requires_grad=True: nothing needs gradients with respect to the data, and
# compute_gradient_penalty enables the input gradients it needs by itself.
train0 = torch.tensor(train_data[0], dtype=torch.float32, device="cuda").t()
train2 = torch.tensor(train_data[1], dtype=torch.float32, device="cuda").t()
train4 = torch.tensor(train_data[2], dtype=torch.float32, device="cuda").t()
train7 = torch.tensor(train_data[3], dtype=torch.float32, device="cuda").t()

# Small Gaussian jitter (std 0.01). Keep this order so the seeded draws are unchanged.
train0 = train0 + 0.01 * torch.normal(0, 1, size=train0.shape, device="cuda")
train2 = train2 + 0.01 * torch.normal(0, 1, size=train2.shape, device="cuda")
train4 = train4 + 0.01 * torch.normal(0, 1, size=train4.shape, device="cuda")
train7 = train7 + 0.01 * torch.normal(0, 1, size=train7.shape, device="cuda")


# ---- hyper-parameters ----
intensity = 10    # Poisson jump rate (JumpEulerForwardCuda reads this global)
lr = 0.0003       # Adam learning rate for every network
step_size = 0.03  # Euler step size of the simulator
kuan = 256        # hidden-layer width of every MLP
ceng = 4          # number of hidden layers of every MLP
bd = 2            # Brownian-noise dimension (JumpEulerForwardCuda reads this global)
n_critic = 3      # critic updates per generator update
k = 2             # gradient-penalty weight
p = 6             # gradient-penalty exponent

n_sims = train0.shape[0]       # one simulated trajectory per day-0 cell
in_features = train0.shape[1]  # number of modelled genes
# Steps at which the simulated population is compared with the day-2, day-4 and
# day-7 data (10:20:35 = 2:4:7, i.e. five Euler steps per day).
n_steps = [10, 20, 35]


netG = JumpEulerForwardCuda(in_features, ceng, kuan, step_size).cuda()
netD1 = MLP(in_features, 1, dim_hidden=kuan, num_hidden=ceng).cuda()
netD2 = MLP(in_features, 1, dim_hidden=kuan, num_hidden=ceng).cuda()
netD3 = MLP(in_features, 1, dim_hidden=kuan, num_hidden=ceng).cuda()


optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerSD1 = optim.Adam(netD1.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerSD2 = optim.Adam(netD2.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerSD3 = optim.Adam(netD3.parameters(), lr=lr, betas=(0.5, 0.999))

n_epochs = 20000

wd = []  # summed Sinkhorn distances over the three time points, recorded every 10 epochs
for epoch in range(n_epochs):

    # ---- critic updates ----
    for _ in range(n_critic):
        # The critics only need generated samples, not gradients through the
        # generator. Simulating without autograd stops the critic losses (and the
        # double-backward of the gradient penalty) from flowing through the whole
        # unrolled simulation. Those generator gradients were discarded by
        # optimizerG.zero_grad() anyway, so the parameter updates are unchanged,
        # and retain_graph is no longer needed.
        with torch.no_grad():
            fake_data = netG(train0, n_sims, n_steps[2])
        # clone() gives each snapshot its own storage, so the gradient penalty can
        # safely mark it as requiring grad.
        fake1 = fake_data[:, n_steps[0], :].clone()
        fake2 = fake_data[:, n_steps[1], :].clone()
        fake3 = fake_data[:, n_steps[2], :].clone()

        optimizerSD1.zero_grad()
        div_gp1 = compute_gradient_penalty(netD1, train2, fake1, k, p)
        d1_loss = -torch.mean(netD1(train2)) + torch.mean(netD1(fake1)) + div_gp1
        d1_loss.backward()
        optimizerSD1.step()

        optimizerSD2.zero_grad()
        div_gp2 = compute_gradient_penalty(netD2, train4, fake2, k, p)
        d2_loss = -torch.mean(netD2(train4)) + torch.mean(netD2(fake2)) + div_gp2
        d2_loss.backward()
        optimizerSD2.step()

        optimizerSD3.zero_grad()
        div_gp3 = compute_gradient_penalty(netD3, train7, fake3, k, p)
        d3_loss = -torch.mean(netD3(train7)) + torch.mean(netD3(fake3)) + div_gp3
        d3_loss.backward()
        optimizerSD3.step()

    # ---- generator update (one step per epoch) ----
    optimizerG.zero_grad()
    fake_data = netG(train0, n_sims, n_steps[2])
    fake1 = fake_data[:, n_steps[0], :]
    fake2 = fake_data[:, n_steps[1], :]
    fake3 = fake_data[:, n_steps[2], :]
    g_loss = -torch.mean(netD1(fake1)) - torch.mean(netD2(fake2)) - torch.mean(netD3(fake3))
    g_loss.backward()
    optimizerG.step()

    if epoch % 10 == 0:
        # Monitor with the samples drawn before this generator step; detach so
        # the monitoring loss builds no graph back into netG.
        fake_eval = fake_data.detach()
        x1 = a(fake_eval[:, n_steps[0], :], train2).item()
        x2 = a(fake_eval[:, n_steps[1], :], train4).item()
        x3 = a(fake_eval[:, n_steps[2], :], train7).item()

        wd.append(x1 + x2 + x3)

        print("training error: ", x1, " and ", x2, " and ", x3)
        print("total error: ", wd[-1])

Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 4000.1233
total residual 138.7959
Iteration 200:
fit residual 123.6643
total residual 31.8464
Iteration 400:
fit residual 21.7894
total residual 12.6115
Iteration 500:
fit residual 4.3442
total residual 8.6686
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2257.8358
total residual 58.2265
Iteration 200:
fit residual 56.8198
total residual 4.0447
Iteration 400:
fit residual 3.4777
total residual 0.6841
Iteration 500:
fit residual 0.3744
total residual 0.3185
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
fit residual 2925.3664
total residual 80.2657
Iteration 200:
fit residual 78.1148
total residual 6.1559
Iteration 400:
fit residual 5.1661
total residual 1.1917
Iteration 500:
fit residual 0.6183
total residual 0.5917
Starting NMF decomposition with 400 latent features and 500 iterations.
Iteration 1:
f

training error:  1.6754486560821533  and  2.153059959411621  and  1.6871862411499023
total error:  5.515694856643677
training error:  1.6624603271484375  and  3.046994209289551  and  2.043606996536255
total error:  6.753061532974243
training error:  1.6892871856689453  and  3.1002635955810547  and  1.899970531463623
total error:  6.689521312713623
training error:  1.6894574165344238  and  2.7729454040527344  and  2.2279887199401855
total error:  6.690391540527344
training error:  1.7647335529327393  and  2.6678004264831543  and  1.770745873451233
total error:  6.2032798528671265
training error:  1.6796921491622925  and  2.9960856437683105  and  1.8817791938781738
total error:  6.557556986808777
training error:  1.714460849761963  and  3.060393810272217  and  1.6637136936187744
total error:  6.438568353652954
training error:  1.6232202053070068  and  2.6153311729431152  and  1.8693288564682007
total error:  6.107880234718323
training error:  1.662695050239563  and  2.477398157119751  an

training error:  1.2628190517425537  and  1.9020448923110962  and  1.5356340408325195
total error:  4.700497984886169
training error:  1.2569851875305176  and  1.867966890335083  and  1.5361485481262207
total error:  4.661100625991821
training error:  1.288804292678833  and  2.185131788253784  and  1.5550363063812256
total error:  5.028972387313843
training error:  1.2814550399780273  and  2.1346936225891113  and  1.611405372619629
total error:  5.027554035186768
training error:  1.2846368551254272  and  2.3056764602661133  and  1.430795431137085
total error:  5.0211087465286255
training error:  1.276003122329712  and  2.1782732009887695  and  1.523726463317871
total error:  4.9780027866363525
training error:  1.3605105876922607  and  2.7569808959960938  and  1.4932591915130615
total error:  5.610750675201416
training error:  1.290886640548706  and  3.019918918609619  and  1.7065062522888184
total error:  6.0173118114471436
training error:  1.3054978847503662  and  2.5352983474731445  

training error:  1.1695911884307861  and  1.6925134658813477  and  1.3058351278305054
total error:  4.167939782142639
training error:  1.1502621173858643  and  1.7651057243347168  and  1.3070567846298218
total error:  4.222424626350403
training error:  1.1801626682281494  and  1.8021478652954102  and  1.3474550247192383
total error:  4.329765558242798
training error:  1.1965818405151367  and  1.7112059593200684  and  1.3192580938339233
total error:  4.227045893669128
training error:  1.164451003074646  and  1.6757787466049194  and  1.325300931930542
total error:  4.165530681610107
training error:  1.1668429374694824  and  1.8111616373062134  and  1.4621843099594116
total error:  4.440188884735107
training error:  1.1670546531677246  and  1.730797529220581  and  1.315794825553894
total error:  4.2136470079422
training error:  1.1618108749389648  and  1.720443844795227  and  1.3132541179656982
total error:  4.19550883769989
training error:  1.1579790115356445  and  1.701810359954834  and

training error:  1.1098450422286987  and  1.6112385988235474  and  1.3919973373413086
total error:  4.113080978393555
training error:  1.117152214050293  and  1.5554172992706299  and  1.220350980758667
total error:  3.89292049407959
training error:  1.0991710424423218  and  1.6066737174987793  and  1.2927149534225464
total error:  3.9985597133636475
training error:  1.086097002029419  and  1.681613564491272  and  1.2633841037750244
total error:  4.031094670295715
training error:  1.1161448955535889  and  1.599325180053711  and  1.2905482053756714
total error:  4.006018280982971
training error:  1.097308874130249  and  1.5849332809448242  and  1.257698893547058
total error:  3.9399410486221313
training error:  1.0971806049346924  and  1.5627423524856567  and  1.2240736484527588
total error:  3.883996605873108
training error:  1.0886482000350952  and  1.5613977909088135  and  1.2240285873413086
total error:  3.8740745782852173
training error:  1.1262099742889404  and  1.5627999305725098 

training error:  1.1012204885482788  and  1.5812909603118896  and  1.2585015296936035
total error:  3.941012978553772
training error:  1.0666767358779907  and  1.498076319694519  and  1.1922807693481445
total error:  3.7570338249206543
training error:  1.0409541130065918  and  1.474427580833435  and  1.1929811239242554
total error:  3.7083628177642822
training error:  1.0317583084106445  and  1.529158353805542  and  1.2066552639007568
total error:  3.7675719261169434
training error:  1.0818530321121216  and  1.4647696018218994  and  1.2327828407287598
total error:  3.7794054746627808
training error:  1.0441814661026  and  1.492915391921997  and  1.2076282501220703
total error:  3.7447251081466675
training error:  1.0595569610595703  and  1.502698302268982  and  1.2690715789794922
total error:  3.8313268423080444
training error:  1.0525836944580078  and  1.4895503520965576  and  1.2186743021011353
total error:  3.7608083486557007
training error:  1.0428192615509033  and  1.5114426612854

training error:  0.9940990209579468  and  1.5220742225646973  and  1.2380967140197754
total error:  3.7542699575424194
training error:  1.0140527486801147  and  1.5845811367034912  and  1.1654689311981201
total error:  3.764102816581726
training error:  0.9973326921463013  and  1.5080043077468872  and  1.1570860147476196
total error:  3.662423014640808
training error:  1.0390820503234863  and  1.5296127796173096  and  1.1962090730667114
total error:  3.7649039030075073
training error:  1.0032600164413452  and  1.540157437324524  and  1.1593739986419678
total error:  3.702791452407837
training error:  0.9898293614387512  and  1.4783127307891846  and  1.158970594406128
total error:  3.6271126866340637
training error:  1.0196278095245361  and  1.4398040771484375  and  1.1783955097198486
total error:  3.6378273963928223
training error:  0.9965514540672302  and  1.4242584705352783  and  1.1475634574890137
total error:  3.568373382091522
training error:  1.012423038482666  and  1.54591143131

training error:  0.9889802932739258  and  1.5185110569000244  and  1.1432311534881592
total error:  3.6507225036621094
training error:  0.9621830582618713  and  1.5369091033935547  and  1.1978285312652588
total error:  3.696920692920685
training error:  0.9627043008804321  and  1.4554544687271118  and  1.2123239040374756
total error:  3.6304826736450195
training error:  1.0054206848144531  and  1.466135859489441  and  1.191497802734375
total error:  3.663054347038269
training error:  0.9487791061401367  and  1.5065281391143799  and  1.1137068271636963
total error:  3.569014072418213
training error:  1.0198001861572266  and  1.435736894607544  and  1.1593708992004395
total error:  3.61490797996521
training error:  0.9663941860198975  and  1.4452048540115356  and  1.1367502212524414
total error:  3.5483492612838745
training error:  0.9761618375778198  and  1.4381873607635498  and  1.2015520334243774
total error:  3.615901231765747
training error:  0.9867277145385742  and  1.4407832622528

training error:  0.9404668807983398  and  1.424098253250122  and  1.3067858219146729
total error:  3.6713509559631348
training error:  0.950316309928894  and  1.5022259950637817  and  1.1462339162826538
total error:  3.5987762212753296
training error:  0.9535213112831116  and  1.4670641422271729  and  1.1150217056274414
total error:  3.535607159137726
training error:  0.9312984943389893  and  1.3967548608779907  and  1.1282455921173096
total error:  3.4562989473342896
training error:  0.9446348547935486  and  1.427411675453186  and  1.206246256828308
total error:  3.5782927870750427
training error:  0.925500750541687  and  1.4642219543457031  and  1.116389274597168
total error:  3.506111979484558
training error:  0.9299601912498474  and  1.5186667442321777  and  1.1450270414352417
total error:  3.593653976917267
training error:  0.9543787240982056  and  1.3788928985595703  and  1.085832118988037
total error:  3.419103741645813
training error:  0.939847469329834  and  1.3711957931518555

training error:  0.9181543588638306  and  1.3664069175720215  and  1.1387379169464111
total error:  3.423299193382263
training error:  0.9145143628120422  and  1.4305052757263184  and  1.1377630233764648
total error:  3.4827826619148254
training error:  0.9053983092308044  and  1.4149689674377441  and  1.1045186519622803
total error:  3.424885928630829
training error:  0.9138343334197998  and  1.4044920206069946  and  1.1554741859436035
total error:  3.473800539970398
training error:  0.887093186378479  and  1.4112629890441895  and  1.1863155364990234
total error:  3.484671711921692
training error:  0.9048730134963989  and  1.3664066791534424  and  1.1271274089813232
total error:  3.3984071016311646
training error:  0.8935115337371826  and  1.4126609563827515  and  1.1307246685028076
total error:  3.4368971586227417
training error:  0.9275690317153931  and  1.361552119255066  and  1.1338067054748535
total error:  3.4229278564453125
training error:  0.8925497531890869  and  1.4372965097

training error:  0.8733423352241516  and  1.3460475206375122  and  1.047450065612793
total error:  3.266839921474457
training error:  0.8736662268638611  and  1.412517786026001  and  1.1072731018066406
total error:  3.3934571146965027
training error:  0.8704147934913635  and  1.528760552406311  and  1.0854637622833252
total error:  3.4846391081809998
training error:  0.8604089021682739  and  1.314079999923706  and  1.1149547100067139
total error:  3.289443612098694
training error:  0.8715293407440186  and  1.352165699005127  and  1.0673675537109375
total error:  3.291062593460083
training error:  0.854900062084198  and  1.3262965679168701  and  1.1466400623321533
total error:  3.3278366923332214
training error:  0.8601369857788086  and  1.3169162273406982  and  1.0601744651794434
total error:  3.23722767829895
training error:  0.8648403882980347  and  1.6251816749572754  and  1.1672385931015015
total error:  3.6572606563568115
training error:  0.8837770819664001  and  1.342712879180908

training error:  0.8887952566146851  and  1.4616400003433228  and  1.1018723249435425
total error:  3.4523075819015503
training error:  0.8542017340660095  and  1.3466122150421143  and  1.0942734479904175
total error:  3.2950873970985413
training error:  0.8272950649261475  and  1.3497035503387451  and  1.2198951244354248
total error:  3.3968937397003174
training error:  0.827559769153595  and  1.3239802122116089  and  1.1022276878356934
total error:  3.253767669200897
training error:  0.8226279020309448  and  1.4035546779632568  and  1.2616151571273804
total error:  3.487797737121582
training error:  0.8174234628677368  and  1.3652420043945312  and  1.0715175867080688
total error:  3.254183053970337
training error:  0.8606151342391968  and  1.3173766136169434  and  1.0706169605255127
total error:  3.248608708381653
training error:  0.8379865884780884  and  1.2641453742980957  and  1.1207870244979858
total error:  3.22291898727417
training error:  0.8389513492584229  and  1.40573954582

training error:  0.8320251107215881  and  1.2982938289642334  and  1.0503575801849365
total error:  3.180676519870758
training error:  0.7962765693664551  and  1.3224716186523438  and  1.046181559562683
total error:  3.164929747581482
training error:  0.830367922782898  and  1.2858538627624512  and  1.0505887269973755
total error:  3.1668105125427246
training error:  0.8045377731323242  and  1.3201203346252441  and  1.1112630367279053
total error:  3.2359211444854736
training error:  0.8032628893852234  and  1.282059669494629  and  1.0389354228973389
total error:  3.124257981777191
training error:  0.8475581407546997  and  1.342742681503296  and  1.20307195186615
total error:  3.3933727741241455
training error:  0.8247844576835632  and  1.4558155536651611  and  1.141068935394287
total error:  3.4216689467430115
training error:  0.8122880458831787  and  1.275554895401001  and  1.0571002960205078
total error:  3.1449432373046875
training error:  0.799368679523468  and  1.3322196006774902

training error:  0.8328872919082642  and  1.4739325046539307  and  1.2992701530456543
total error:  3.606089949607849
training error:  0.7959039211273193  and  1.271348476409912  and  1.0765241384506226
total error:  3.143776535987854
training error:  0.8091904520988464  and  1.345489263534546  and  1.1189628839492798
total error:  3.273642599582672
training error:  0.8437836766242981  and  1.3231785297393799  and  1.102539300918579
total error:  3.269501507282257
training error:  0.8123223185539246  and  1.3558140993118286  and  1.0569349527359009
total error:  3.225071370601654
training error:  0.8081828355789185  and  1.3036737442016602  and  1.0146287679672241
total error:  3.1264853477478027
training error:  0.7894676327705383  and  1.3399189710617065  and  1.1554875373840332
total error:  3.284874141216278
training error:  0.7794038653373718  and  1.3236442804336548  and  1.0555336475372314
total error:  3.158581793308258
training error:  0.7844398617744446  and  1.25359547138214

training error:  0.7435681223869324  and  1.245274305343628  and  1.0003864765167236
total error:  2.989228904247284
training error:  0.7498725652694702  and  1.2932714223861694  and  1.0567829608917236
total error:  3.0999269485473633
training error:  0.7924160957336426  and  1.2645986080169678  and  1.0653386116027832
total error:  3.1223533153533936
training error:  0.7511671781539917  and  1.326938271522522  and  0.9964550137519836
total error:  3.0745604634284973
training error:  0.7536033391952515  and  1.250058889389038  and  1.0413758754730225
total error:  3.045038104057312
training error:  0.7621943950653076  and  1.2310644388198853  and  1.055194616317749
total error:  3.048453450202942
training error:  0.7730115652084351  and  1.2368273735046387  and  1.015634298324585
total error:  3.0254732370376587
training error:  0.747424840927124  and  1.310813307762146  and  1.0249176025390625
total error:  3.0831557512283325
training error:  0.7375009059906006  and  1.24995994567871

training error:  0.7317082285881042  and  1.2773699760437012  and  0.9925781488418579
total error:  3.0016563534736633
training error:  0.7605822682380676  and  1.257725477218628  and  1.0208534002304077
total error:  3.0391611456871033
training error:  0.7321813106536865  and  1.188671588897705  and  0.9800540208816528
total error:  2.9009069204330444
training error:  0.7710613012313843  and  1.2818603515625  and  1.0698264837265015
total error:  3.1227481365203857
training error:  0.7514139413833618  and  1.3366175889968872  and  1.1016550064086914
total error:  3.1896865367889404
training error:  0.7483114004135132  and  1.250396966934204  and  1.0316638946533203
total error:  3.0303722620010376
training error:  0.7348713278770447  and  1.2530486583709717  and  1.1150333881378174
total error:  3.1029533743858337
training error:  0.7712981700897217  and  1.2892502546310425  and  1.0409979820251465
total error:  3.1015464067459106
training error:  0.7309648990631104  and  1.2391555309

training error:  0.7109464406967163  and  1.2099096775054932  and  0.9761776924133301
total error:  2.8970338106155396
training error:  0.7111570239067078  and  1.2485308647155762  and  1.0653986930847168
total error:  3.0250865817070007
training error:  0.71346116065979  and  1.2826321125030518  and  1.0549317598342896
total error:  3.0510250329971313
training error:  0.7441014647483826  and  1.2087255716323853  and  1.0247317552566528
total error:  2.9775587916374207
training error:  0.7323365211486816  and  1.3205502033233643  and  1.069115400314331
total error:  3.122002124786377
training error:  0.7186568379402161  and  1.1800098419189453  and  1.0096418857574463
total error:  2.9083085656166077
training error:  0.7331581711769104  and  1.244660496711731  and  1.0774431228637695
total error:  3.055261790752411
training error:  0.7311124801635742  and  1.2959415912628174  and  1.0871460437774658
total error:  3.1142001152038574
training error:  0.712791383266449  and  1.26328277587

training error:  0.7053011655807495  and  1.2533881664276123  and  1.1384609937667847
total error:  3.0971503257751465
training error:  0.7125871181488037  and  1.263728141784668  and  1.0011309385299683
total error:  2.97744619846344
training error:  0.7035617828369141  and  1.1913702487945557  and  1.0382367372512817
total error:  2.9331687688827515
training error:  0.7091390490531921  and  1.2188767194747925  and  1.0015970468521118
total error:  2.9296128153800964
training error:  0.6892918348312378  and  1.2348823547363281  and  0.9791200160980225
total error:  2.9032942056655884
training error:  0.7374359369277954  and  1.2651658058166504  and  0.996547281742096
total error:  2.9991490244865417
training error:  0.7016081213951111  and  1.307435393333435  and  0.951009213924408
total error:  2.960052728652954
training error:  0.7015217542648315  and  1.22553288936615  and  0.9857065677642822
total error:  2.9127612113952637
training error:  0.7226661443710327  and  1.2419526576995

training error:  0.6895118951797485  and  1.2155015468597412  and  0.9877712726593018
total error:  2.8927847146987915
training error:  0.6646513938903809  and  1.1906267404556274  and  1.0009398460388184
total error:  2.8562179803848267
training error:  0.7136592864990234  and  1.2661919593811035  and  0.9628415107727051
total error:  2.942692756652832
training error:  0.6680553555488586  and  1.149858832359314  and  0.9718100428581238
total error:  2.7897242307662964
training error:  0.6598175168037415  and  1.1658356189727783  and  1.0010600090026855
total error:  2.8267131447792053
training error:  0.6613394618034363  and  1.1353416442871094  and  0.9955861568450928
total error:  2.7922672629356384
training error:  0.6794639825820923  and  1.2055861949920654  and  0.9596872329711914
total error:  2.844737410545349
training error:  0.6843738555908203  and  1.2449755668640137  and  1.0141891241073608
total error:  2.943538546562195
training error:  0.6718795299530029  and  1.15001988

training error:  0.7097716927528381  and  1.2030763626098633  and  0.970677375793457
total error:  2.8835254311561584
training error:  0.6559072732925415  and  1.3528039455413818  and  1.006397008895874
total error:  3.0151082277297974
training error:  0.6749445199966431  and  1.2911169528961182  and  0.9928915500640869
total error:  2.958953022956848
training error:  0.665045976638794  and  1.1995317935943604  and  1.0260416269302368
total error:  2.890619397163391
training error:  0.6402813196182251  and  1.1333142518997192  and  0.9105044007301331
total error:  2.6840999722480774
training error:  0.6481939554214478  and  1.1992313861846924  and  0.9412566423416138
total error:  2.788681983947754
training error:  0.6971312165260315  and  1.2477171421051025  and  0.9856301546096802
total error:  2.930478513240814
training error:  0.6472312211990356  and  1.1873719692230225  and  0.9324456453323364
total error:  2.7670488357543945
training error:  0.6635128855705261  and  1.15807437896

training error:  0.6901647448539734  and  1.2427270412445068  and  0.9746440052986145
total error:  2.9075357913970947
training error:  0.646152138710022  and  1.3003041744232178  and  0.9761883616447449
total error:  2.9226446747779846
training error:  0.6844474077224731  and  1.256443738937378  and  1.0504143238067627
total error:  2.9913054704666138
training error:  0.6847940683364868  and  1.193699598312378  and  1.0006481409072876
total error:  2.8791418075561523
training error:  0.640740156173706  and  1.133882999420166  and  0.9328166246414185
total error:  2.7074397802352905
training error:  0.6335743069648743  and  1.129737377166748  and  0.913902759552002
total error:  2.6772144436836243
training error:  0.6251822113990784  and  1.1489068269729614  and  0.9998871088027954
total error:  2.773976147174835
training error:  0.6914986968040466  and  1.1929787397384644  and  0.94221031665802
total error:  2.826687753200531
training error:  0.6511985659599304  and  1.314778566360473

training error:  0.6528568267822266  and  1.2293591499328613  and  0.9455841779708862
total error:  2.827800154685974
training error:  0.6322684288024902  and  1.1228128671646118  and  0.953754186630249
total error:  2.708835482597351
training error:  0.6425156593322754  and  1.1538352966308594  and  0.9618743658065796
total error:  2.7582253217697144
training error:  0.6806371212005615  and  1.189725399017334  and  0.9361158609390259
total error:  2.8064783811569214
training error:  0.6652708649635315  and  1.2302072048187256  and  1.021905779838562
total error:  2.917383849620819
training error:  0.656363308429718  and  1.1718088388442993  and  0.9855109453201294
total error:  2.8136830925941467
training error:  0.6682276725769043  and  1.2502787113189697  and  1.152758002281189
total error:  3.071264386177063
training error:  0.6397720575332642  and  1.2224669456481934  and  0.9515810012817383
total error:  2.813820004463196
training error:  0.629645824432373  and  1.164507865905761

training error:  0.6296592950820923  and  1.1123582124710083  and  0.9012173414230347
total error:  2.6432348489761353
training error:  0.6207695007324219  and  1.1285018920898438  and  0.9541512131690979
total error:  2.7034226059913635
training error:  0.6220389604568481  and  1.1194097995758057  and  0.9633434414863586
total error:  2.7047922015190125
training error:  0.6231241226196289  and  1.1856775283813477  and  1.0114245414733887
total error:  2.8202261924743652
training error:  0.6180993318557739  and  1.3820956945419312  and  0.9581710696220398
total error:  2.958366096019745
training error:  0.6260132789611816  and  1.1586109399795532  and  0.988207221031189
total error:  2.772831439971924
training error:  0.6181572675704956  and  1.1070677042007446  and  0.9029533267021179
total error:  2.628178298473358
training error:  0.6378037929534912  and  1.1876260042190552  and  0.9548255205154419
total error:  2.7802553176879883
training error:  0.620019793510437  and  1.096014022

training error:  0.6189955472946167  and  1.0743584632873535  and  0.9079983830451965
total error:  2.6013523936271667
training error:  0.6312130689620972  and  1.1470919847488403  and  0.8951655626296997
total error:  2.673470616340637
training error:  0.6614522933959961  and  1.2211087942123413  and  0.9311130046844482
total error:  2.8136740922927856
training error:  0.6108990907669067  and  1.1266937255859375  and  0.9494856595993042
total error:  2.6870784759521484
training error:  0.5999575853347778  and  1.090692400932312  and  0.9276531338691711
total error:  2.618303120136261
training error:  0.6165540814399719  and  1.1291528940200806  and  0.9687210917472839
total error:  2.7144280672073364
training error:  0.6345139741897583  and  1.150397777557373  and  0.9373883008956909
total error:  2.7223000526428223
training error:  0.6250115036964417  and  1.1456294059753418  and  0.9294856786727905
total error:  2.700126588344574
training error:  0.604793131351471  and  1.2066844701

training error:  0.590773344039917  and  1.0679402351379395  and  0.9013017416000366
total error:  2.560015320777893
training error:  0.5844342112541199  and  1.2073030471801758  and  0.962181806564331
total error:  2.7539190649986267
training error:  0.5687960386276245  and  1.0959392786026  and  1.0307977199554443
total error:  2.695533037185669
training error:  0.5851587653160095  and  1.0840516090393066  and  0.8800610303878784
total error:  2.5492714047431946
training error:  0.6344858407974243  and  1.683504581451416  and  1.024016261100769
total error:  3.3420066833496094
training error:  0.6080040335655212  and  1.0964322090148926  and  0.9907436370849609
total error:  2.6951798796653748
training error:  0.6123466491699219  and  1.1005594730377197  and  0.9320604205131531
total error:  2.6449665427207947
training error:  0.5832064151763916  and  1.0985394716262817  and  0.9107155799865723
total error:  2.5924614667892456
training error:  0.6010260581970215  and  1.0786491632461

training error:  0.5843263268470764  and  1.070948600769043  and  0.8959866762161255
total error:  2.551261603832245
training error:  0.6053873300552368  and  1.0791934728622437  and  0.8962466716766357
total error:  2.580827474594116
training error:  0.6082632541656494  and  1.1611309051513672  and  0.931553065776825
total error:  2.7009472250938416
training error:  0.5819931030273438  and  1.0872492790222168  and  0.9279971122741699
total error:  2.5972394943237305
training error:  0.6512244939804077  and  1.3085980415344238  and  1.049906611442566
total error:  3.0097291469573975
training error:  0.5941983461380005  and  1.07937490940094  and  1.0505716800689697
total error:  2.72414493560791
training error:  0.6145918369293213  and  1.120802402496338  and  0.980650782585144
total error:  2.7160450220108032
training error:  0.6379356384277344  and  1.661158800125122  and  1.0834890604019165
total error:  3.382583498954773
training error:  0.6112608909606934  and  1.1401499509811401 

training error:  0.6149425506591797  and  1.222476840019226  and  0.8993515968322754
total error:  2.736770987510681
training error:  0.5943453907966614  and  1.1529688835144043  and  1.0544289350509644
total error:  2.80174320936203
training error:  0.605589747428894  and  1.0661239624023438  and  0.8701940774917603
total error:  2.541907787322998
training error:  0.578535795211792  and  1.0792478322982788  and  0.862877368927002
total error:  2.5206609964370728
training error:  0.5929731130599976  and  1.1002919673919678  and  0.9969443082809448
total error:  2.69020938873291
training error:  0.5867084860801697  and  1.1044914722442627  and  0.8831661939620972
total error:  2.5743661522865295
training error:  0.5832327604293823  and  1.089476466178894  and  0.9030553698539734
total error:  2.5757645964622498
training error:  0.5964058637619019  and  1.223304271697998  and  0.9159661531448364
total error:  2.7356762886047363
training error:  0.5807750225067139  and  1.099353313446045 

training error:  0.5691409707069397  and  1.0899503231048584  and  0.9325540065765381
total error:  2.591645300388336
training error:  0.5640130043029785  and  1.0682308673858643  and  0.9696282148361206
total error:  2.6018720865249634
training error:  0.5611695051193237  and  1.0765448808670044  and  0.9462345838546753
total error:  2.5839489698410034
training error:  0.5938529968261719  and  1.1145788431167603  and  0.8866311311721802
total error:  2.5950629711151123
training error:  0.5410531759262085  and  1.0432841777801514  and  0.8847142457962036
total error:  2.4690515995025635
training error:  0.5932449102401733  and  1.1334179639816284  and  0.815707802772522
total error:  2.5423706769943237
training error:  0.5658793449401855  and  1.1663259267807007  and  0.9260292649269104
total error:  2.6582345366477966
training error:  0.5632228255271912  and  1.096134901046753  and  0.8992069959640503
total error:  2.5585647225379944
training error:  0.5478334426879883  and  1.0408061

training error:  0.5747042894363403  and  1.08788001537323  and  0.8941582441329956
total error:  2.556742548942566
training error:  0.5724764466285706  and  1.0812737941741943  and  0.8877294659614563
total error:  2.541479706764221
training error:  0.5449273586273193  and  1.092639446258545  and  0.8873257040977478
total error:  2.524892508983612
training error:  0.5314903259277344  and  1.0682547092437744  and  0.8884127140045166
total error:  2.4881577491760254
training error:  0.5776450634002686  and  1.1643296480178833  and  0.8928090333938599
total error:  2.6347837448120117
training error:  0.5427354574203491  and  1.080774188041687  and  0.8571670055389404
total error:  2.4806766510009766
training error:  0.5484751462936401  and  1.104699730873108  and  0.8764665126800537
total error:  2.5296413898468018
training error:  0.598063588142395  and  1.1469672918319702  and  0.9141020178794861
total error:  2.6591328978538513
training error:  0.6297982335090637  and  1.0761170387268

In [2]:
# Persist only the generator's parameters/buffers (its state_dict), not the
# module object itself, so the checkpoint can be reloaded into a freshly built
# network with `netG.load_state_dict(torch.load(...))`.
# NOTE(review): "epsilon0.01" in the filename presumably records the
# regularization/epsilon setting used for this run — confirm against the
# training cell and keep the name in sync if that value changes.
generator_state = netG.state_dict()
torch.save(generator_state, "./epsilon0.01.pt")