In [1]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn import metrics
import torch.optim as optim

In [2]:
sub_features = ['42 tGravityAcc-mean()-Y',
 '43 tGravityAcc-mean()-Z',
 '51 tGravityAcc-max()-Y',
 '52 tGravityAcc-max()-Z',
 '54 tGravityAcc-min()-Y',
 '55 tGravityAcc-min()-Z',
 '56 tGravityAcc-sma()',
 '58 tGravityAcc-energy()-Y',
 '59 tGravityAcc-energy()-Z',
 '128 tBodyGyro-mad()-Y',
 '141 tBodyGyro-iqr()-Y',
 '428 fBodyGyro-std()-Y',
 '434 fBodyGyro-max()-Y',
 '475 fBodyGyro-bandsEnergy()-1,8',
 '483 fBodyGyro-bandsEnergy()-1,16',
 '487 fBodyGyro-bandsEnergy()-1,24',
 '559 angle(X,gravityMean)',
 '560 angle(Y,gravityMean)',
 '561 angle(Z,gravityMean)']

act_features = ['4 tBodyAcc-std()-X',
 '7 tBodyAcc-mad()-X',
 '10 tBodyAcc-max()-X',
 '17 tBodyAcc-energy()-X',
 '202 tBodyAccMag-std()',
 '203 tBodyAccMag-mad()',
 '215 tGravityAccMag-std()',
 '216 tGravityAccMag-mad()',
 '266 fBodyAcc-mean()-X',
 '269 fBodyAcc-std()-X',
 '282 fBodyAcc-energy()-X',
 '303 fBodyAcc-bandsEnergy()-1,8',
 '311 fBodyAcc-bandsEnergy()-1,16',
 '315 fBodyAcc-bandsEnergy()-1,24',
 '382 fBodyAccJerk-bandsEnergy()-1,8',
 '504 fBodyAccMag-std()',
 '505 fBodyAccMag-mad()',
 '509 fBodyAccMag-energy()']

input_shape = len(sub_features) + len(act_features)

In [3]:
def classifier_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.LeakyReLU(0.05)
    )

class Classifier(nn.Module):
    def __init__(self, feature_dim = input_shape):
        super(Classifier, self).__init__()
        self.network = nn.Sequential(
            classifier_block(feature_dim, 30),
            classifier_block(30, 25),
            classifier_block(25, 20),
            classifier_block(20, 15),
            nn.Linear(15, 12)
        )
    def forward(self, x):
        return self.network(x)

In [4]:
#defines each generator layer
#input and output dimensions needed
def generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace = True)
    )

#returns n_samples of z_dim (number of dimensions of latent space) noise
def get_noise(n_samples, z_dim):
    return torch.randn(n_samples, z_dim)

#defines generator class
class Generator(nn.Module):
    def __init__(self, z_dim = 10, feature_dim = input_shape, hidden_dim = 128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            generator_block(z_dim, 80),
            generator_block(80, 60),
            generator_block(60, 50),
            nn.Linear(50, feature_dim),
            nn.Tanh()
        )
    def forward(self, noise):
        return self.gen(noise)

def get_act_matrix(batch_size, a_dim):
    indexes = np.random.randint(a_dim, size = batch_size)
    
    one_hot = np.zeros((len(indexes), indexes.max()+1))
    one_hot[np.arange(len(indexes)),indexes] = 1
    return torch.Tensor(indexes).long(), torch.Tensor(one_hot)
    
def get_usr_matrix(batch_size, u_dim):
    indexes = np.random.randint(u_dim, size = batch_size)
    
    one_hot = np.zeros((indexes.size, indexes.max()+1))
    one_hot[np.arange(indexes.size),indexes] = 1
    return torch.Tensor(indexes).long(), torch.Tensor(one_hot)

def load_model(model, model_name):
    model.load_state_dict(torch.load(f'../../../saved_models/{model_name}'))

# Train on Real Test on Real

In [5]:
#label is a list of integers specifying which labels to filter by
#users is a list of integers specifying which users to filter by
#y_label is a string, either "Activity" or "Subject" depending on what y output needs to be returned
def start_data(label, users, y_label, sub_features, act_features):
    #get the dataframe column names
    name_dataframe = pd.read_csv('../../../data/features.txt', delimiter = '\n', header = None)
    names = name_dataframe.values.tolist()
    names = [k for row in names for k in row] #List of column names

    data = pd.read_csv('../../../data/X_train.txt', delim_whitespace = True, header = None) #Read in dataframe
    data.columns = names #Setting column names
    
    X_train_1 = data[sub_features]
    X_train_2 = data[act_features]
    X_train = pd.concat([X_train_1, X_train_2], axis = 1)
    
    y_train_activity = pd.read_csv('../../../data/y_train.txt', header = None)
    y_train_activity.columns = ['Activity']
    
    y_train_subject = pd.read_csv('../../../data/subject_train.txt', header = None)
    y_train_subject.columns = ['Subject']
    
    GAN_data = pd.concat([X_train, y_train_activity, y_train_subject], axis = 1)
    GAN_data = GAN_data[GAN_data['Activity'].isin(label)]
    GAN_data = GAN_data[GAN_data['Subject'].isin(users)]
    
    X_1 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_2 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_3 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_4 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_5 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_6 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_7 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_8 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_9 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_10 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_11 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_12 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    
    X_train = np.concatenate((X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8, X_9, X_10, X_11, X_12))
    y_train = [0] * len(X_1) + [1] * len(X_2) + [2] * len(X_3) + [3] * len(X_4) + [4] * len(X_5) + [5] * len(X_6) + [6] * len(X_7) + [7] * len(X_8) + [8] * len(X_9) + [9] * len(X_10) + [10] * len(X_11) + [11] * len(X_12)
    
    return X_train, np.asarray(y_train)

In [6]:
activities = [1, 3, 4]
users = [1, 3, 5, 7]

X, y = start_data(activities, users, "Activity", sub_features, act_features)

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, shuffle = True)

model = Classifier()
lr = 0.001
n_epochs = 5000
batch_size = 250

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

train_features = torch.tensor(X_train)
train_labels = torch.tensor(y_train)
test_features = torch.tensor(X_test)
test_labels = torch.tensor(y_test)

train_data = torch.utils.data.TensorDataset(train_features, train_labels)
test_data = torch.utils.data.TensorDataset(test_features, test_labels)

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = len(test_labels), shuffle = True)

In [8]:
for epoch in range(n_epochs):
    total_loss = 0
    for batch in train_loader:
        features, labels = batch
        
        optimizer.zero_grad()
        preds = model(features.float())
        
        loss = criterion(preds, labels.long()) 
        loss.backward()
        
        optimizer.step()
        total_loss += loss.item()
        
    print(f'Epoch {epoch + 1}, Loss: {total_loss}, Final Batch Loss: {loss.item()}')

Epoch 1, Loss: 7.4565019607543945, Final Batch Loss: 2.4861161708831787
Epoch 2, Loss: 7.389077186584473, Final Batch Loss: 2.4213759899139404
Epoch 3, Loss: 7.438206195831299, Final Batch Loss: 2.476147413253784
Epoch 4, Loss: 7.425340890884399, Final Batch Loss: 2.4706366062164307
Epoch 5, Loss: 7.407440423965454, Final Batch Loss: 2.455002784729004
Epoch 6, Loss: 7.455641031265259, Final Batch Loss: 2.507023572921753
Epoch 7, Loss: 7.404532432556152, Final Batch Loss: 2.455768346786499
Epoch 8, Loss: 7.409439325332642, Final Batch Loss: 2.464600086212158
Epoch 9, Loss: 7.442714214324951, Final Batch Loss: 2.5060012340545654
Epoch 10, Loss: 7.377971410751343, Final Batch Loss: 2.4446983337402344
Epoch 11, Loss: 7.463738679885864, Final Batch Loss: 2.538994312286377
Epoch 12, Loss: 7.381815671920776, Final Batch Loss: 2.4655027389526367
Epoch 13, Loss: 7.37760591506958, Final Batch Loss: 2.46934175491333
Epoch 14, Loss: 7.332941293716431, Final Batch Loss: 2.4351930618286133
Epoch 15,

Epoch 121, Loss: 3.4227999448776245, Final Batch Loss: 1.2566736936569214
Epoch 122, Loss: 3.445485472679138, Final Batch Loss: 1.2329602241516113
Epoch 123, Loss: 3.1802684664726257, Final Batch Loss: 0.9533560872077942
Epoch 124, Loss: 3.2421780824661255, Final Batch Loss: 0.9654021263122559
Epoch 125, Loss: 3.391984462738037, Final Batch Loss: 1.1992532014846802
Epoch 126, Loss: 3.08952397108078, Final Batch Loss: 0.9206394553184509
Epoch 127, Loss: 3.0704381465911865, Final Batch Loss: 0.9501303434371948
Epoch 128, Loss: 3.337642788887024, Final Batch Loss: 1.1260581016540527
Epoch 129, Loss: 3.213936448097229, Final Batch Loss: 1.022255301475525
Epoch 130, Loss: 3.31439745426178, Final Batch Loss: 1.1407952308654785
Epoch 131, Loss: 3.3022472858428955, Final Batch Loss: 1.132822036743164
Epoch 132, Loss: 3.1774948835372925, Final Batch Loss: 0.9923214912414551
Epoch 133, Loss: 3.4500350952148438, Final Batch Loss: 1.2308024168014526
Epoch 134, Loss: 3.1075358390808105, Final Batch

Epoch 232, Loss: 2.794485032558441, Final Batch Loss: 1.0647879838943481
Epoch 233, Loss: 2.5321959853172302, Final Batch Loss: 0.786387026309967
Epoch 234, Loss: 1.9754333198070526, Final Batch Loss: 0.3347020447254181
Epoch 235, Loss: 2.54498028755188, Final Batch Loss: 0.8374310731887817
Epoch 236, Loss: 2.537276029586792, Final Batch Loss: 0.9168505668640137
Epoch 237, Loss: 2.389365315437317, Final Batch Loss: 0.6080032587051392
Epoch 238, Loss: 2.187043458223343, Final Batch Loss: 0.4386216104030609
Epoch 239, Loss: 2.3405430912971497, Final Batch Loss: 0.7358778119087219
Epoch 240, Loss: 2.3713969588279724, Final Batch Loss: 0.7175435423851013
Epoch 241, Loss: 2.6247350573539734, Final Batch Loss: 0.8955686688423157
Epoch 242, Loss: 2.455954134464264, Final Batch Loss: 0.7595963478088379
Epoch 243, Loss: 2.3565760254859924, Final Batch Loss: 0.739057719707489
Epoch 244, Loss: 2.495826303958893, Final Batch Loss: 0.7519140839576721
Epoch 245, Loss: 2.4173234701156616, Final Batch

Epoch 351, Loss: 2.013450562953949, Final Batch Loss: 0.6400045156478882
Epoch 352, Loss: 2.1342252492904663, Final Batch Loss: 0.6904661655426025
Epoch 353, Loss: 2.0439324378967285, Final Batch Loss: 0.5866657495498657
Epoch 354, Loss: 1.9390186667442322, Final Batch Loss: 0.5529820322990417
Epoch 355, Loss: 1.9993765950202942, Final Batch Loss: 0.5824297666549683
Epoch 356, Loss: 1.7692166864871979, Final Batch Loss: 0.2674235999584198
Epoch 357, Loss: 2.4637480974197388, Final Batch Loss: 0.9410055875778198
Epoch 358, Loss: 2.076961040496826, Final Batch Loss: 0.7111289501190186
Epoch 359, Loss: 2.3985010385513306, Final Batch Loss: 0.9990618824958801
Epoch 360, Loss: 1.955083817243576, Final Batch Loss: 0.4821673333644867
Epoch 361, Loss: 1.8687787353992462, Final Batch Loss: 0.4787910282611847
Epoch 362, Loss: 2.514367163181305, Final Batch Loss: 1.039547324180603
Epoch 363, Loss: 1.9633955359458923, Final Batch Loss: 0.6130227446556091
Epoch 364, Loss: 2.2619903683662415, Final 

Epoch 470, Loss: 1.9586249589920044, Final Batch Loss: 0.6220993399620056
Epoch 471, Loss: 1.8814269304275513, Final Batch Loss: 0.5689801573753357
Epoch 472, Loss: 1.8783403038978577, Final Batch Loss: 0.5932577848434448
Epoch 473, Loss: 2.3281983733177185, Final Batch Loss: 1.061290979385376
Epoch 474, Loss: 2.292081296443939, Final Batch Loss: 0.9729374051094055
Epoch 475, Loss: 2.084551751613617, Final Batch Loss: 0.7967334985733032
Epoch 476, Loss: 1.7897415161132812, Final Batch Loss: 0.5070258975028992
Epoch 477, Loss: 1.723101019859314, Final Batch Loss: 0.4327154755592346
Epoch 478, Loss: 1.5935910940170288, Final Batch Loss: 0.3171360492706299
Epoch 479, Loss: 2.260531723499298, Final Batch Loss: 0.991112470626831
Epoch 480, Loss: 1.7151660025119781, Final Batch Loss: 0.4271680414676666
Epoch 481, Loss: 1.959616482257843, Final Batch Loss: 0.6650979518890381
Epoch 482, Loss: 1.8413841128349304, Final Batch Loss: 0.5836924910545349
Epoch 483, Loss: 2.0854429602622986, Final Ba

Epoch 592, Loss: 1.69326913356781, Final Batch Loss: 0.4686644673347473
Epoch 593, Loss: 1.9077742099761963, Final Batch Loss: 0.7457810640335083
Epoch 594, Loss: 1.6730315685272217, Final Batch Loss: 0.4802761673927307
Epoch 595, Loss: 1.8142898082733154, Final Batch Loss: 0.6509081125259399
Epoch 596, Loss: 1.5487203299999237, Final Batch Loss: 0.3358323276042938
Epoch 597, Loss: 1.6283623278141022, Final Batch Loss: 0.4968093931674957
Epoch 598, Loss: 1.7959765791893005, Final Batch Loss: 0.6620500683784485
Epoch 599, Loss: 1.8955330848693848, Final Batch Loss: 0.7218726277351379
Epoch 600, Loss: 1.6426363587379456, Final Batch Loss: 0.45590901374816895
Epoch 601, Loss: 1.7832172513008118, Final Batch Loss: 0.6326601505279541
Epoch 602, Loss: 1.6320799887180328, Final Batch Loss: 0.4860135018825531
Epoch 603, Loss: 1.7057815790176392, Final Batch Loss: 0.6133989691734314
Epoch 604, Loss: 1.727394938468933, Final Batch Loss: 0.675423800945282
Epoch 605, Loss: 1.6163692772388458, Fina

Epoch 712, Loss: 1.3604874014854431, Final Batch Loss: 0.3702690601348877
Epoch 713, Loss: 1.4938486516475677, Final Batch Loss: 0.5340778231620789
Epoch 714, Loss: 1.3767221868038177, Final Batch Loss: 0.41569191217422485
Epoch 715, Loss: 1.4216744303703308, Final Batch Loss: 0.4121338427066803
Epoch 716, Loss: 1.6443625092506409, Final Batch Loss: 0.6093217730522156
Epoch 717, Loss: 1.5765307545661926, Final Batch Loss: 0.5296295285224915
Epoch 718, Loss: 1.48597052693367, Final Batch Loss: 0.4985570013523102
Epoch 719, Loss: 1.4395391643047333, Final Batch Loss: 0.4265119731426239
Epoch 720, Loss: 1.4471289813518524, Final Batch Loss: 0.44619980454444885
Epoch 721, Loss: 1.3910677134990692, Final Batch Loss: 0.32238802313804626
Epoch 722, Loss: 1.5264371037483215, Final Batch Loss: 0.44726496934890747
Epoch 723, Loss: 1.4387918412685394, Final Batch Loss: 0.44875070452690125
Epoch 724, Loss: 1.314142107963562, Final Batch Loss: 0.3284420073032379
Epoch 725, Loss: 1.567865937948227, 

Epoch 833, Loss: 1.470682293176651, Final Batch Loss: 0.5155634880065918
Epoch 834, Loss: 1.410218507051468, Final Batch Loss: 0.4960308074951172
Epoch 835, Loss: 1.4570623934268951, Final Batch Loss: 0.5344264507293701
Epoch 836, Loss: 1.3133151829242706, Final Batch Loss: 0.4102727770805359
Epoch 837, Loss: 1.259329617023468, Final Batch Loss: 0.345157265663147
Epoch 838, Loss: 1.172047957777977, Final Batch Loss: 0.19939281046390533
Epoch 839, Loss: 1.6252447664737701, Final Batch Loss: 0.7337371706962585
Epoch 840, Loss: 1.5271019637584686, Final Batch Loss: 0.6065645813941956
Epoch 841, Loss: 1.423428624868393, Final Batch Loss: 0.4991820454597473
Epoch 842, Loss: 1.611154466867447, Final Batch Loss: 0.7026229500770569
Epoch 843, Loss: 1.439673662185669, Final Batch Loss: 0.5549708008766174
Epoch 844, Loss: 1.3150800168514252, Final Batch Loss: 0.40172523260116577
Epoch 845, Loss: 1.3338307738304138, Final Batch Loss: 0.35686376690864563
Epoch 846, Loss: 1.4344549775123596, Final 

Epoch 955, Loss: 1.5196517407894135, Final Batch Loss: 0.6454699635505676
Epoch 956, Loss: 1.4424238502979279, Final Batch Loss: 0.48761168122291565
Epoch 957, Loss: 1.413595050573349, Final Batch Loss: 0.39649781584739685
Epoch 958, Loss: 1.337997853755951, Final Batch Loss: 0.41922083497047424
Epoch 959, Loss: 1.1498874127864838, Final Batch Loss: 0.29849082231521606
Epoch 960, Loss: 1.2952357530593872, Final Batch Loss: 0.4441763162612915
Epoch 961, Loss: 1.2279143035411835, Final Batch Loss: 0.3579261302947998
Epoch 962, Loss: 1.3247837126255035, Final Batch Loss: 0.43515315651893616
Epoch 963, Loss: 1.4294835329055786, Final Batch Loss: 0.5881605744361877
Epoch 964, Loss: 1.5590317845344543, Final Batch Loss: 0.7682259678840637
Epoch 965, Loss: 1.2587236762046814, Final Batch Loss: 0.3943637013435364
Epoch 966, Loss: 1.291303664445877, Final Batch Loss: 0.4496114253997803
Epoch 967, Loss: 1.4650084972381592, Final Batch Loss: 0.6133897304534912
Epoch 968, Loss: 1.3441638350486755,

Epoch 1077, Loss: 1.327939510345459, Final Batch Loss: 0.49086710810661316
Epoch 1078, Loss: 1.2435298562049866, Final Batch Loss: 0.3936282992362976
Epoch 1079, Loss: 1.0416203290224075, Final Batch Loss: 0.2397276908159256
Epoch 1080, Loss: 1.2104957103729248, Final Batch Loss: 0.35244083404541016
Epoch 1081, Loss: 1.2080605030059814, Final Batch Loss: 0.33318188786506653
Epoch 1082, Loss: 1.347434937953949, Final Batch Loss: 0.48266395926475525
Epoch 1083, Loss: 1.235037088394165, Final Batch Loss: 0.3323505222797394
Epoch 1084, Loss: 1.4008136689662933, Final Batch Loss: 0.6172818541526794
Epoch 1085, Loss: 1.1778971552848816, Final Batch Loss: 0.34654173254966736
Epoch 1086, Loss: 1.1345899999141693, Final Batch Loss: 0.35266634821891785
Epoch 1087, Loss: 1.4490914940834045, Final Batch Loss: 0.5513786673545837
Epoch 1088, Loss: 1.2285272777080536, Final Batch Loss: 0.4093921184539795
Epoch 1089, Loss: 1.2693987488746643, Final Batch Loss: 0.4619610011577606
Epoch 1090, Loss: 1.66

Epoch 1189, Loss: 1.0033676624298096, Final Batch Loss: 0.21760019659996033
Epoch 1190, Loss: 1.1708939373493195, Final Batch Loss: 0.34505507349967957
Epoch 1191, Loss: 1.4176709055900574, Final Batch Loss: 0.5587289929389954
Epoch 1192, Loss: 1.3553505837917328, Final Batch Loss: 0.5813755989074707
Epoch 1193, Loss: 1.2929305732250214, Final Batch Loss: 0.4259875416755676
Epoch 1194, Loss: 1.1062619090080261, Final Batch Loss: 0.3026600778102875
Epoch 1195, Loss: 1.473622888326645, Final Batch Loss: 0.6839519739151001
Epoch 1196, Loss: 0.9648405313491821, Final Batch Loss: 0.18951785564422607
Epoch 1197, Loss: 1.038310021162033, Final Batch Loss: 0.21991398930549622
Epoch 1198, Loss: 1.0288339406251907, Final Batch Loss: 0.22165478765964508
Epoch 1199, Loss: 1.2105428874492645, Final Batch Loss: 0.46924805641174316
Epoch 1200, Loss: 1.1443953812122345, Final Batch Loss: 0.3593771755695343
Epoch 1201, Loss: 1.2912573218345642, Final Batch Loss: 0.5466058850288391
Epoch 1202, Loss: 1.2

Epoch 1307, Loss: 1.3072474300861359, Final Batch Loss: 0.5326068997383118
Epoch 1308, Loss: 1.0564275681972504, Final Batch Loss: 0.3283202648162842
Epoch 1309, Loss: 1.2168425023555756, Final Batch Loss: 0.48669251799583435
Epoch 1310, Loss: 1.3159209489822388, Final Batch Loss: 0.44013726711273193
Epoch 1311, Loss: 1.2328728437423706, Final Batch Loss: 0.4863797724246979
Epoch 1312, Loss: 1.2568134367465973, Final Batch Loss: 0.4769265055656433
Epoch 1313, Loss: 1.5046598017215729, Final Batch Loss: 0.7686900496482849
Epoch 1314, Loss: 1.2429843544960022, Final Batch Loss: 0.48695921897888184
Epoch 1315, Loss: 1.0970387756824493, Final Batch Loss: 0.3271529972553253
Epoch 1316, Loss: 1.1675329208374023, Final Batch Loss: 0.3857846260070801
Epoch 1317, Loss: 1.2877053320407867, Final Batch Loss: 0.5722139477729797
Epoch 1318, Loss: 1.2907248735427856, Final Batch Loss: 0.5683863162994385
Epoch 1319, Loss: 1.4039615988731384, Final Batch Loss: 0.5751543045043945
Epoch 1320, Loss: 1.14

Epoch 1426, Loss: 1.0886410176753998, Final Batch Loss: 0.3556775748729706
Epoch 1427, Loss: 0.8918674439191818, Final Batch Loss: 0.14383848011493683
Epoch 1428, Loss: 1.1006512939929962, Final Batch Loss: 0.35579389333724976
Epoch 1429, Loss: 0.8611798286437988, Final Batch Loss: 0.1306520700454712
Epoch 1430, Loss: 1.450405329465866, Final Batch Loss: 0.7800040245056152
Epoch 1431, Loss: 1.0535339415073395, Final Batch Loss: 0.34312114119529724
Epoch 1432, Loss: 1.0713087916374207, Final Batch Loss: 0.39511123299598694
Epoch 1433, Loss: 1.072282463312149, Final Batch Loss: 0.3846088945865631
Epoch 1434, Loss: 0.9838964194059372, Final Batch Loss: 0.22699888050556183
Epoch 1435, Loss: 1.299311876296997, Final Batch Loss: 0.5551758408546448
Epoch 1436, Loss: 1.1112126111984253, Final Batch Loss: 0.43034055829048157
Epoch 1437, Loss: 1.1132056415081024, Final Batch Loss: 0.3448439836502075
Epoch 1438, Loss: 1.1157569289207458, Final Batch Loss: 0.3878211975097656
Epoch 1439, Loss: 1.21

Epoch 1543, Loss: 0.9781745821237564, Final Batch Loss: 0.24896584451198578
Epoch 1544, Loss: 1.123850405216217, Final Batch Loss: 0.48651623725891113
Epoch 1545, Loss: 0.9339016377925873, Final Batch Loss: 0.15915656089782715
Epoch 1546, Loss: 0.9819650053977966, Final Batch Loss: 0.3043406009674072
Epoch 1547, Loss: 1.1937699615955353, Final Batch Loss: 0.5062872171401978
Epoch 1548, Loss: 1.192251741886139, Final Batch Loss: 0.5070564150810242
Epoch 1549, Loss: 1.196245789527893, Final Batch Loss: 0.47511935234069824
Epoch 1550, Loss: 1.0190660655498505, Final Batch Loss: 0.26497703790664673
Epoch 1551, Loss: 1.0737581253051758, Final Batch Loss: 0.32042476534843445
Epoch 1552, Loss: 1.064243197441101, Final Batch Loss: 0.30275893211364746
Epoch 1553, Loss: 0.9886417239904404, Final Batch Loss: 0.22771702706813812
Epoch 1554, Loss: 0.9609031677246094, Final Batch Loss: 0.1938348114490509
Epoch 1555, Loss: 0.8493582606315613, Final Batch Loss: 0.17192232608795166
Epoch 1556, Loss: 1.

Epoch 1657, Loss: 1.108250081539154, Final Batch Loss: 0.45247143507003784
Epoch 1658, Loss: 1.037992775440216, Final Batch Loss: 0.3540152609348297
Epoch 1659, Loss: 0.849588006734848, Final Batch Loss: 0.20581045746803284
Epoch 1660, Loss: 0.9612293243408203, Final Batch Loss: 0.2810381352901459
Epoch 1661, Loss: 1.117911696434021, Final Batch Loss: 0.3671298921108246
Epoch 1662, Loss: 0.897187665104866, Final Batch Loss: 0.1610742062330246
Epoch 1663, Loss: 0.9651276171207428, Final Batch Loss: 0.2909499406814575
Epoch 1664, Loss: 0.8987091556191444, Final Batch Loss: 0.09982345253229141
Epoch 1665, Loss: 0.9692788273096085, Final Batch Loss: 0.23493270576000214
Epoch 1666, Loss: 0.9971007853746414, Final Batch Loss: 0.20761381089687347
Epoch 1667, Loss: 1.2252178192138672, Final Batch Loss: 0.47054392099380493
Epoch 1668, Loss: 0.8620901256799698, Final Batch Loss: 0.16827274858951569
Epoch 1669, Loss: 0.941439300775528, Final Batch Loss: 0.25566697120666504
Epoch 1670, Loss: 0.949

Epoch 1776, Loss: 0.999047726392746, Final Batch Loss: 0.3387008011341095
Epoch 1777, Loss: 1.06328547000885, Final Batch Loss: 0.3669305741786957
Epoch 1778, Loss: 1.1429382860660553, Final Batch Loss: 0.4587636888027191
Epoch 1779, Loss: 1.090624064207077, Final Batch Loss: 0.4002494215965271
Epoch 1780, Loss: 1.1227708756923676, Final Batch Loss: 0.40253809094429016
Epoch 1781, Loss: 1.1464858651161194, Final Batch Loss: 0.43445447087287903
Epoch 1782, Loss: 0.9722220599651337, Final Batch Loss: 0.2966083586215973
Epoch 1783, Loss: 1.048597365617752, Final Batch Loss: 0.3950671851634979
Epoch 1784, Loss: 1.1810934245586395, Final Batch Loss: 0.4999314248561859
Epoch 1785, Loss: 0.8284646570682526, Final Batch Loss: 0.19209325313568115
Epoch 1786, Loss: 1.060299426317215, Final Batch Loss: 0.3852297365665436
Epoch 1787, Loss: 1.2161397337913513, Final Batch Loss: 0.5260770916938782
Epoch 1788, Loss: 1.308099627494812, Final Batch Loss: 0.6306250691413879
Epoch 1789, Loss: 1.009910136

Epoch 1899, Loss: 1.098759114742279, Final Batch Loss: 0.46602606773376465
Epoch 1900, Loss: 0.8945351392030716, Final Batch Loss: 0.22643350064754486
Epoch 1901, Loss: 0.9954308569431305, Final Batch Loss: 0.34090903401374817
Epoch 1902, Loss: 0.7488192543387413, Final Batch Loss: 0.09345542639493942
Epoch 1903, Loss: 0.9275643527507782, Final Batch Loss: 0.27908995747566223
Epoch 1904, Loss: 0.9271690249443054, Final Batch Loss: 0.26929569244384766
Epoch 1905, Loss: 0.929902657866478, Final Batch Loss: 0.24700461328029633
Epoch 1906, Loss: 0.9039938598871231, Final Batch Loss: 0.23308812081813812
Epoch 1907, Loss: 1.0299493968486786, Final Batch Loss: 0.4231427311897278
Epoch 1908, Loss: 0.9743348360061646, Final Batch Loss: 0.358271062374115
Epoch 1909, Loss: 1.0555935204029083, Final Batch Loss: 0.46534863114356995
Epoch 1910, Loss: 1.1056521236896515, Final Batch Loss: 0.47161152958869934
Epoch 1911, Loss: 0.9457659721374512, Final Batch Loss: 0.29069676995277405
Epoch 1912, Loss:

Epoch 2010, Loss: 1.0023800432682037, Final Batch Loss: 0.3258413076400757
Epoch 2011, Loss: 0.9074434340000153, Final Batch Loss: 0.3210393190383911
Epoch 2012, Loss: 1.033296376466751, Final Batch Loss: 0.3540628254413605
Epoch 2013, Loss: 1.2299789190292358, Final Batch Loss: 0.5615839958190918
Epoch 2014, Loss: 0.8750750422477722, Final Batch Loss: 0.311292439699173
Epoch 2015, Loss: 0.9550526142120361, Final Batch Loss: 0.28797677159309387
Epoch 2016, Loss: 1.1707079112529755, Final Batch Loss: 0.4937925934791565
Epoch 2017, Loss: 1.1209132373332977, Final Batch Loss: 0.5005034804344177
Epoch 2018, Loss: 0.9892841279506683, Final Batch Loss: 0.32010963559150696
Epoch 2019, Loss: 1.0433287620544434, Final Batch Loss: 0.45932769775390625
Epoch 2020, Loss: 0.9082143306732178, Final Batch Loss: 0.23144114017486572
Epoch 2021, Loss: 1.1437968015670776, Final Batch Loss: 0.5468708276748657
Epoch 2022, Loss: 1.0117157101631165, Final Batch Loss: 0.40263205766677856
Epoch 2023, Loss: 0.91

Epoch 2126, Loss: 0.8856310546398163, Final Batch Loss: 0.2943977117538452
Epoch 2127, Loss: 0.946929931640625, Final Batch Loss: 0.30498093366622925
Epoch 2128, Loss: 0.8235775530338287, Final Batch Loss: 0.1806895136833191
Epoch 2129, Loss: 0.9599853456020355, Final Batch Loss: 0.35935330390930176
Epoch 2130, Loss: 0.8452766537666321, Final Batch Loss: 0.2881045639514923
Epoch 2131, Loss: 0.9231488257646561, Final Batch Loss: 0.2760295867919922
Epoch 2132, Loss: 1.254666954278946, Final Batch Loss: 0.5867373943328857
Epoch 2133, Loss: 0.9430772662162781, Final Batch Loss: 0.3261212110519409
Epoch 2134, Loss: 0.8134166151285172, Final Batch Loss: 0.13324479758739471
Epoch 2135, Loss: 0.9981550574302673, Final Batch Loss: 0.27747735381126404
Epoch 2136, Loss: 0.8837332874536514, Final Batch Loss: 0.23411087691783905
Epoch 2137, Loss: 0.7530854493379593, Final Batch Loss: 0.08916573226451874
Epoch 2138, Loss: 0.9725283086299896, Final Batch Loss: 0.34148702025413513
Epoch 2139, Loss: 0.

Epoch 2243, Loss: 0.8824934512376785, Final Batch Loss: 0.20333199203014374
Epoch 2244, Loss: 0.7185371741652489, Final Batch Loss: 0.07885619252920151
Epoch 2245, Loss: 1.0554279386997223, Final Batch Loss: 0.4776233732700348
Epoch 2246, Loss: 0.8044615685939789, Final Batch Loss: 0.1719551384449005
Epoch 2247, Loss: 0.8556760549545288, Final Batch Loss: 0.265796422958374
Epoch 2248, Loss: 1.0426585972309113, Final Batch Loss: 0.47507792711257935
Epoch 2249, Loss: 0.8782760500907898, Final Batch Loss: 0.2841080129146576
Epoch 2250, Loss: 1.2228443324565887, Final Batch Loss: 0.6355757117271423
Epoch 2251, Loss: 0.9021815061569214, Final Batch Loss: 0.2861860394477844
Epoch 2252, Loss: 0.7787950038909912, Final Batch Loss: 0.17020517587661743
Epoch 2253, Loss: 0.9115378856658936, Final Batch Loss: 0.2748991549015045
Epoch 2254, Loss: 0.9260956794023514, Final Batch Loss: 0.31032925844192505
Epoch 2255, Loss: 0.8460886478424072, Final Batch Loss: 0.17811930179595947
Epoch 2256, Loss: 0.

Epoch 2360, Loss: 1.1010383665561676, Final Batch Loss: 0.4809664785861969
Epoch 2361, Loss: 0.7933000177145004, Final Batch Loss: 0.14285831153392792
Epoch 2362, Loss: 0.8179675936698914, Final Batch Loss: 0.21190056204795837
Epoch 2363, Loss: 1.0377165973186493, Final Batch Loss: 0.3377109169960022
Epoch 2364, Loss: 0.8057158142328262, Final Batch Loss: 0.1809232085943222
Epoch 2365, Loss: 1.0201101303100586, Final Batch Loss: 0.3408176302909851
Epoch 2366, Loss: 0.9748915731906891, Final Batch Loss: 0.3424778878688812
Epoch 2367, Loss: 1.0796456038951874, Final Batch Loss: 0.5236622095108032
Epoch 2368, Loss: 0.7053753510117531, Final Batch Loss: 0.10026810318231583
Epoch 2369, Loss: 0.9230621457099915, Final Batch Loss: 0.27927398681640625
Epoch 2370, Loss: 1.013007491827011, Final Batch Loss: 0.4656338393688202
Epoch 2371, Loss: 0.8486962467432022, Final Batch Loss: 0.18248872458934784
Epoch 2372, Loss: 0.890102744102478, Final Batch Loss: 0.3268655836582184
Epoch 2373, Loss: 0.78

Epoch 2482, Loss: 0.7287871763110161, Final Batch Loss: 0.10394265502691269
Epoch 2483, Loss: 0.7548899576067924, Final Batch Loss: 0.12145914882421494
Epoch 2484, Loss: 0.6576360166072845, Final Batch Loss: 0.06529697775840759
Epoch 2485, Loss: 0.9332756102085114, Final Batch Loss: 0.2721741497516632
Epoch 2486, Loss: 1.065201848745346, Final Batch Loss: 0.4572789967060089
Epoch 2487, Loss: 1.0808401107788086, Final Batch Loss: 0.5223912596702576
Epoch 2488, Loss: 0.8447131365537643, Final Batch Loss: 0.19633762538433075
Epoch 2489, Loss: 1.289351299405098, Final Batch Loss: 0.6371640563011169
Epoch 2490, Loss: 0.8591054081916809, Final Batch Loss: 0.2510230243206024
Epoch 2491, Loss: 1.101508766412735, Final Batch Loss: 0.46394509077072144
Epoch 2492, Loss: 0.9817136526107788, Final Batch Loss: 0.3067033588886261
Epoch 2493, Loss: 0.9801949858665466, Final Batch Loss: 0.34230878949165344
Epoch 2494, Loss: 0.8991856575012207, Final Batch Loss: 0.28951096534729004
Epoch 2495, Loss: 0.8

Epoch 2600, Loss: 0.8517942726612091, Final Batch Loss: 0.1962651014328003
Epoch 2601, Loss: 0.8611461073160172, Final Batch Loss: 0.2293388694524765
Epoch 2602, Loss: 0.953361451625824, Final Batch Loss: 0.3202393651008606
Epoch 2603, Loss: 0.8008095175027847, Final Batch Loss: 0.22267703711986542
Epoch 2604, Loss: 0.8831939697265625, Final Batch Loss: 0.30656394362449646
Epoch 2605, Loss: 0.9122813045978546, Final Batch Loss: 0.32236000895500183
Epoch 2606, Loss: 0.8485393077135086, Final Batch Loss: 0.23573489487171173
Epoch 2607, Loss: 0.8496896475553513, Final Batch Loss: 0.239632710814476
Epoch 2608, Loss: 1.0434464663267136, Final Batch Loss: 0.4523519277572632
Epoch 2609, Loss: 0.9079917967319489, Final Batch Loss: 0.3423028886318207
Epoch 2610, Loss: 0.9311398863792419, Final Batch Loss: 0.3240046203136444
Epoch 2611, Loss: 0.840794712305069, Final Batch Loss: 0.27000927925109863
Epoch 2612, Loss: 0.9736138582229614, Final Batch Loss: 0.44666510820388794
Epoch 2613, Loss: 0.88

Epoch 2713, Loss: 0.7922983169555664, Final Batch Loss: 0.19024226069450378
Epoch 2714, Loss: 0.7512511014938354, Final Batch Loss: 0.20038560032844543
Epoch 2715, Loss: 0.712373822927475, Final Batch Loss: 0.18002277612686157
Epoch 2716, Loss: 0.7197156250476837, Final Batch Loss: 0.13272106647491455
Epoch 2717, Loss: 0.9699328541755676, Final Batch Loss: 0.4227825403213501
Epoch 2718, Loss: 0.8042860180139542, Final Batch Loss: 0.22984589636325836
Epoch 2719, Loss: 0.7919436395168304, Final Batch Loss: 0.2446940839290619
Epoch 2720, Loss: 0.9166964590549469, Final Batch Loss: 0.2729980945587158
Epoch 2721, Loss: 0.8505858331918716, Final Batch Loss: 0.23091883957386017
Epoch 2722, Loss: 0.7229273468255997, Final Batch Loss: 0.1724899560213089
Epoch 2723, Loss: 0.9324217736721039, Final Batch Loss: 0.307070791721344
Epoch 2724, Loss: 0.8014902025461197, Final Batch Loss: 0.2236594408750534
Epoch 2725, Loss: 0.6241913065314293, Final Batch Loss: 0.07046540826559067
Epoch 2726, Loss: 0.

Epoch 2829, Loss: 0.7986854314804077, Final Batch Loss: 0.2257419377565384
Epoch 2830, Loss: 0.8265009820461273, Final Batch Loss: 0.3238225281238556
Epoch 2831, Loss: 0.8302194476127625, Final Batch Loss: 0.2920842170715332
Epoch 2832, Loss: 0.9012007713317871, Final Batch Loss: 0.31864821910858154
Epoch 2833, Loss: 0.8674830794334412, Final Batch Loss: 0.3217345178127289
Epoch 2834, Loss: 0.7971169948577881, Final Batch Loss: 0.19839772582054138
Epoch 2835, Loss: 0.967263787984848, Final Batch Loss: 0.38336828351020813
Epoch 2836, Loss: 1.1331034302711487, Final Batch Loss: 0.5244491696357727
Epoch 2837, Loss: 0.9036588668823242, Final Batch Loss: 0.27979686856269836
Epoch 2838, Loss: 0.6576612070202827, Final Batch Loss: 0.11829794198274612
Epoch 2839, Loss: 0.7130696624517441, Final Batch Loss: 0.16657720506191254
Epoch 2840, Loss: 0.9332888424396515, Final Batch Loss: 0.3685648441314697
Epoch 2841, Loss: 0.9775343239307404, Final Batch Loss: 0.3868221938610077
Epoch 2842, Loss: 0.

Epoch 2946, Loss: 0.9766724109649658, Final Batch Loss: 0.4138809144496918
Epoch 2947, Loss: 0.8352666199207306, Final Batch Loss: 0.20803844928741455
Epoch 2948, Loss: 0.8263255655765533, Final Batch Loss: 0.2866247594356537
Epoch 2949, Loss: 0.8261346966028214, Final Batch Loss: 0.2255275696516037
Epoch 2950, Loss: 0.6416044682264328, Final Batch Loss: 0.11376157402992249
Epoch 2951, Loss: 0.8547063767910004, Final Batch Loss: 0.2491612434387207
Epoch 2952, Loss: 0.8996249437332153, Final Batch Loss: 0.3909396231174469
Epoch 2953, Loss: 0.923609122633934, Final Batch Loss: 0.4571853578090668
Epoch 2954, Loss: 0.7013723254203796, Final Batch Loss: 0.1525944024324417
Epoch 2955, Loss: 0.7492153346538544, Final Batch Loss: 0.19270360469818115
Epoch 2956, Loss: 0.6969158947467804, Final Batch Loss: 0.1281145215034485
Epoch 2957, Loss: 0.8129420578479767, Final Batch Loss: 0.2916302978992462
Epoch 2958, Loss: 0.6357437968254089, Final Batch Loss: 0.1270289272069931
Epoch 2959, Loss: 0.898

Epoch 3067, Loss: 0.9311368018388748, Final Batch Loss: 0.4305274486541748
Epoch 3068, Loss: 0.7672970294952393, Final Batch Loss: 0.23007729649543762
Epoch 3069, Loss: 0.6982380896806717, Final Batch Loss: 0.1378360539674759
Epoch 3070, Loss: 0.8561746180057526, Final Batch Loss: 0.29651060700416565
Epoch 3071, Loss: 0.7223459780216217, Final Batch Loss: 0.16953924298286438
Epoch 3072, Loss: 1.036016821861267, Final Batch Loss: 0.4312819540500641
Epoch 3073, Loss: 0.7841869741678238, Final Batch Loss: 0.2331835776567459
Epoch 3074, Loss: 0.7253727316856384, Final Batch Loss: 0.17035400867462158
Epoch 3075, Loss: 0.6849194467067719, Final Batch Loss: 0.18391551077365875
Epoch 3076, Loss: 0.9486301839351654, Final Batch Loss: 0.41053107380867004
Epoch 3077, Loss: 0.8439301252365112, Final Batch Loss: 0.26103028655052185
Epoch 3078, Loss: 0.7493157535791397, Final Batch Loss: 0.19571852684020996
Epoch 3079, Loss: 0.7116900533437729, Final Batch Loss: 0.18325042724609375
Epoch 3080, Loss:

Epoch 3177, Loss: 0.7319049835205078, Final Batch Loss: 0.23427267372608185
Epoch 3178, Loss: 0.6392709463834763, Final Batch Loss: 0.09469030797481537
Epoch 3179, Loss: 0.6834024041891098, Final Batch Loss: 0.1962609738111496
Epoch 3180, Loss: 0.9016111791133881, Final Batch Loss: 0.37527787685394287
Epoch 3181, Loss: 0.7301263958215714, Final Batch Loss: 0.19668127596378326
Epoch 3182, Loss: 0.6003816947340965, Final Batch Loss: 0.08025433868169785
Epoch 3183, Loss: 0.7605661004781723, Final Batch Loss: 0.2151780128479004
Epoch 3184, Loss: 1.1863479614257812, Final Batch Loss: 0.5228530764579773
Epoch 3185, Loss: 0.6960562467575073, Final Batch Loss: 0.18054234981536865
Epoch 3186, Loss: 0.7748894095420837, Final Batch Loss: 0.2209843099117279
Epoch 3187, Loss: 0.9999923408031464, Final Batch Loss: 0.48609742522239685
Epoch 3188, Loss: 0.7719481140375137, Final Batch Loss: 0.17113400995731354
Epoch 3189, Loss: 0.7109338045120239, Final Batch Loss: 0.20037555694580078
Epoch 3190, Loss

Epoch 3290, Loss: 0.8878923952579498, Final Batch Loss: 0.35783565044403076
Epoch 3291, Loss: 0.8664964437484741, Final Batch Loss: 0.2962164878845215
Epoch 3292, Loss: 0.9827928245067596, Final Batch Loss: 0.4372592270374298
Epoch 3293, Loss: 0.8848735392093658, Final Batch Loss: 0.3505192995071411
Epoch 3294, Loss: 0.7757997959852219, Final Batch Loss: 0.14820213615894318
Epoch 3295, Loss: 0.8817481994628906, Final Batch Loss: 0.28171420097351074
Epoch 3296, Loss: 0.7847360968589783, Final Batch Loss: 0.27537021040916443
Epoch 3297, Loss: 0.8463933020830154, Final Batch Loss: 0.3186052143573761
Epoch 3298, Loss: 0.7798597663640976, Final Batch Loss: 0.20850487053394318
Epoch 3299, Loss: 1.0512772798538208, Final Batch Loss: 0.4234544336795807
Epoch 3300, Loss: 1.0380082428455353, Final Batch Loss: 0.4794350564479828
Epoch 3301, Loss: 0.9044687151908875, Final Batch Loss: 0.3708811104297638
Epoch 3302, Loss: 0.7701655477285385, Final Batch Loss: 0.29093989729881287
Epoch 3303, Loss: 0

Epoch 3401, Loss: 0.7305516600608826, Final Batch Loss: 0.24094219505786896
Epoch 3402, Loss: 0.8716254234313965, Final Batch Loss: 0.38255947828292847
Epoch 3403, Loss: 0.8406840711832047, Final Batch Loss: 0.3336879312992096
Epoch 3404, Loss: 0.8857594430446625, Final Batch Loss: 0.3916034400463104
Epoch 3405, Loss: 0.8800168037414551, Final Batch Loss: 0.410142183303833
Epoch 3406, Loss: 0.7102246284484863, Final Batch Loss: 0.21058343350887299
Epoch 3407, Loss: 0.7103155255317688, Final Batch Loss: 0.20099954307079315
Epoch 3408, Loss: 0.6404886320233345, Final Batch Loss: 0.030745841562747955
Epoch 3409, Loss: 0.7727188467979431, Final Batch Loss: 0.18819014728069305
Epoch 3410, Loss: 0.8326431512832642, Final Batch Loss: 0.32724979519844055
Epoch 3411, Loss: 0.6620810627937317, Final Batch Loss: 0.1745223104953766
Epoch 3412, Loss: 0.6688921004533768, Final Batch Loss: 0.13614626228809357
Epoch 3413, Loss: 0.889753669500351, Final Batch Loss: 0.4018881916999817
Epoch 3414, Loss: 

Epoch 3516, Loss: 0.6431522220373154, Final Batch Loss: 0.13253211975097656
Epoch 3517, Loss: 0.6400966942310333, Final Batch Loss: 0.19423317909240723
Epoch 3518, Loss: 0.8467066287994385, Final Batch Loss: 0.3785759508609772
Epoch 3519, Loss: 0.7116797268390656, Final Batch Loss: 0.27240756154060364
Epoch 3520, Loss: 0.6737679243087769, Final Batch Loss: 0.23493964970111847
Epoch 3521, Loss: 0.6306958794593811, Final Batch Loss: 0.22320497035980225
Epoch 3522, Loss: 1.0845575630664825, Final Batch Loss: 0.5876431465148926
Epoch 3523, Loss: 0.6405893862247467, Final Batch Loss: 0.15823495388031006
Epoch 3524, Loss: 0.7646029591560364, Final Batch Loss: 0.3078666627407074
Epoch 3525, Loss: 1.1008691042661667, Final Batch Loss: 0.6338726282119751
Epoch 3526, Loss: 0.6536338925361633, Final Batch Loss: 0.1424720287322998
Epoch 3527, Loss: 0.7964730560779572, Final Batch Loss: 0.298841655254364
Epoch 3528, Loss: 0.7892763763666153, Final Batch Loss: 0.31586551666259766
Epoch 3529, Loss: 0

Epoch 3630, Loss: 0.6479271203279495, Final Batch Loss: 0.1936616152524948
Epoch 3631, Loss: 0.738362729549408, Final Batch Loss: 0.19200153648853302
Epoch 3632, Loss: 0.8849326521158218, Final Batch Loss: 0.3996552526950836
Epoch 3633, Loss: 0.8290478438138962, Final Batch Loss: 0.33053556084632874
Epoch 3634, Loss: 0.6312662214040756, Final Batch Loss: 0.1832883656024933
Epoch 3635, Loss: 0.8813779652118683, Final Batch Loss: 0.4020674526691437
Epoch 3636, Loss: 0.6315519362688065, Final Batch Loss: 0.15371590852737427
Epoch 3637, Loss: 0.5731405019760132, Final Batch Loss: 0.11220456659793854
Epoch 3638, Loss: 0.6604117602109909, Final Batch Loss: 0.1851205825805664
Epoch 3639, Loss: 0.7579231858253479, Final Batch Loss: 0.32229405641555786
Epoch 3640, Loss: 0.7862879782915115, Final Batch Loss: 0.33708661794662476
Epoch 3641, Loss: 0.6512503623962402, Final Batch Loss: 0.2217947244644165
Epoch 3642, Loss: 0.6377875953912735, Final Batch Loss: 0.25834909081459045
Epoch 3643, Loss: 0

Epoch 3742, Loss: 0.633259043097496, Final Batch Loss: 0.17907744646072388
Epoch 3743, Loss: 0.6526750326156616, Final Batch Loss: 0.2080174684524536
Epoch 3744, Loss: 0.511554341763258, Final Batch Loss: 0.04418199881911278
Epoch 3745, Loss: 0.7221812158823013, Final Batch Loss: 0.29003533720970154
Epoch 3746, Loss: 0.5195733085274696, Final Batch Loss: 0.11437172442674637
Epoch 3747, Loss: 0.7160371392965317, Final Batch Loss: 0.3032110929489136
Epoch 3748, Loss: 0.8443601429462433, Final Batch Loss: 0.3452191948890686
Epoch 3749, Loss: 0.7785419523715973, Final Batch Loss: 0.3137674331665039
Epoch 3750, Loss: 0.6182039529085159, Final Batch Loss: 0.15443511307239532
Epoch 3751, Loss: 0.5237366929650307, Final Batch Loss: 0.09726611524820328
Epoch 3752, Loss: 0.684269443154335, Final Batch Loss: 0.16557912528514862
Epoch 3753, Loss: 0.9467489570379257, Final Batch Loss: 0.49013394117355347
Epoch 3754, Loss: 0.6784857511520386, Final Batch Loss: 0.21284981071949005
Epoch 3755, Loss: 0

Epoch 3858, Loss: 0.6597162038087845, Final Batch Loss: 0.25431329011917114
Epoch 3859, Loss: 0.524184450507164, Final Batch Loss: 0.14318561553955078
Epoch 3860, Loss: 0.5418142080307007, Final Batch Loss: 0.12881122529506683
Epoch 3861, Loss: 0.5440384075045586, Final Batch Loss: 0.06838717311620712
Epoch 3862, Loss: 0.6635257303714752, Final Batch Loss: 0.2333843857049942
Epoch 3863, Loss: 0.8360437154769897, Final Batch Loss: 0.3638150095939636
Epoch 3864, Loss: 0.4700668603181839, Final Batch Loss: 0.08844655752182007
Epoch 3865, Loss: 0.5250285938382149, Final Batch Loss: 0.10915718227624893
Epoch 3866, Loss: 0.5490574166178703, Final Batch Loss: 0.12200135737657547
Epoch 3867, Loss: 0.5974601656198502, Final Batch Loss: 0.20010441541671753
Epoch 3868, Loss: 0.7112933099269867, Final Batch Loss: 0.27528947591781616
Epoch 3869, Loss: 0.852959156036377, Final Batch Loss: 0.4778936207294464
Epoch 3870, Loss: 0.4892452843487263, Final Batch Loss: 0.03653492406010628
Epoch 3871, Loss:

Epoch 3970, Loss: 0.5936843901872635, Final Batch Loss: 0.1425592303276062
Epoch 3971, Loss: 0.7270622402429581, Final Batch Loss: 0.3307855427265167
Epoch 3972, Loss: 0.5215021371841431, Final Batch Loss: 0.15313905477523804
Epoch 3973, Loss: 0.5197661370038986, Final Batch Loss: 0.11665581166744232
Epoch 3974, Loss: 0.5908165574073792, Final Batch Loss: 0.134630024433136
Epoch 3975, Loss: 0.7196303158998489, Final Batch Loss: 0.27026620507240295
Epoch 3976, Loss: 0.5470132529735565, Final Batch Loss: 0.1292308270931244
Epoch 3977, Loss: 0.5021747574210167, Final Batch Loss: 0.1161159947514534
Epoch 3978, Loss: 0.6060133129358292, Final Batch Loss: 0.19156566262245178
Epoch 3979, Loss: 0.6127110123634338, Final Batch Loss: 0.21392063796520233
Epoch 3980, Loss: 0.6344237625598907, Final Batch Loss: 0.19933557510375977
Epoch 3981, Loss: 0.6575249582529068, Final Batch Loss: 0.20155520737171173
Epoch 3982, Loss: 0.6747868061065674, Final Batch Loss: 0.2656401991844177
Epoch 3983, Loss: 0

Epoch 4084, Loss: 0.6174231469631195, Final Batch Loss: 0.20735017955303192
Epoch 4085, Loss: 0.5742910802364349, Final Batch Loss: 0.2105216085910797
Epoch 4086, Loss: 0.5239660292863846, Final Batch Loss: 0.10913443565368652
Epoch 4087, Loss: 0.5193252637982368, Final Batch Loss: 0.07174823433160782
Epoch 4088, Loss: 0.6738804280757904, Final Batch Loss: 0.29326650500297546
Epoch 4089, Loss: 0.5158053413033485, Final Batch Loss: 0.1107678934931755
Epoch 4090, Loss: 0.5972528755664825, Final Batch Loss: 0.20591427385807037
Epoch 4091, Loss: 1.2633021771907806, Final Batch Loss: 0.7702118754386902
Epoch 4092, Loss: 0.5933397561311722, Final Batch Loss: 0.1508931666612625
Epoch 4093, Loss: 0.5085155144333839, Final Batch Loss: 0.09638649970293045
Epoch 4094, Loss: 0.9105749875307083, Final Batch Loss: 0.5196416974067688
Epoch 4095, Loss: 0.5651775002479553, Final Batch Loss: 0.1636379510164261
Epoch 4096, Loss: 0.5489993989467621, Final Batch Loss: 0.12810300290584564
Epoch 4097, Loss: 

Epoch 4197, Loss: 0.5483764261007309, Final Batch Loss: 0.17152859270572662
Epoch 4198, Loss: 0.728663370013237, Final Batch Loss: 0.3838988244533539
Epoch 4199, Loss: 0.47582298517227173, Final Batch Loss: 0.08023695647716522
Epoch 4200, Loss: 0.5734476894140244, Final Batch Loss: 0.13507631421089172
Epoch 4201, Loss: 0.6143457740545273, Final Batch Loss: 0.2782410681247711
Epoch 4202, Loss: 0.6864828020334244, Final Batch Loss: 0.24334684014320374
Epoch 4203, Loss: 0.5776292532682419, Final Batch Loss: 0.1797875165939331
Epoch 4204, Loss: 0.5585960149765015, Final Batch Loss: 0.20221047103405
Epoch 4205, Loss: 0.507631927728653, Final Batch Loss: 0.13588744401931763
Epoch 4206, Loss: 0.4793340861797333, Final Batch Loss: 0.1280404031276703
Epoch 4207, Loss: 0.7463009506464005, Final Batch Loss: 0.30311068892478943
Epoch 4208, Loss: 0.4312754385173321, Final Batch Loss: 0.05276454612612724
Epoch 4209, Loss: 0.9732839316129684, Final Batch Loss: 0.6139168739318848
Epoch 4210, Loss: 0.6

Epoch 4307, Loss: 0.4894314408302307, Final Batch Loss: 0.11500094830989838
Epoch 4308, Loss: 0.4789011999964714, Final Batch Loss: 0.0657973513007164
Epoch 4309, Loss: 0.848340317606926, Final Batch Loss: 0.42620524764060974
Epoch 4310, Loss: 0.6246431618928909, Final Batch Loss: 0.21894267201423645
Epoch 4311, Loss: 0.8040285259485245, Final Batch Loss: 0.40319862961769104
Epoch 4312, Loss: 0.6013142168521881, Final Batch Loss: 0.16713784635066986
Epoch 4313, Loss: 0.6347603648900986, Final Batch Loss: 0.1744089126586914
Epoch 4314, Loss: 0.4634440205991268, Final Batch Loss: 0.05876928195357323
Epoch 4315, Loss: 0.49626919627189636, Final Batch Loss: 0.10305289924144745
Epoch 4316, Loss: 0.41270124539732933, Final Batch Loss: 0.05007728561758995
Epoch 4317, Loss: 0.6513612568378448, Final Batch Loss: 0.27391552925109863
Epoch 4318, Loss: 0.4934235215187073, Final Batch Loss: 0.09852755069732666
Epoch 4319, Loss: 0.5462040603160858, Final Batch Loss: 0.17124469578266144
Epoch 4320, L

Epoch 4425, Loss: 0.5287074595689774, Final Batch Loss: 0.13511493802070618
Epoch 4426, Loss: 0.44024666398763657, Final Batch Loss: 0.04455230385065079
Epoch 4427, Loss: 0.7444811910390854, Final Batch Loss: 0.37708723545074463
Epoch 4428, Loss: 0.4318287558853626, Final Batch Loss: 0.036633748561143875
Epoch 4429, Loss: 0.4125948026776314, Final Batch Loss: 0.09051666408777237
Epoch 4430, Loss: 0.568734660744667, Final Batch Loss: 0.17854177951812744
Epoch 4431, Loss: 0.7086790204048157, Final Batch Loss: 0.3134894371032715
Epoch 4432, Loss: 0.5167838633060455, Final Batch Loss: 0.06841737031936646
Epoch 4433, Loss: 0.5869929939508438, Final Batch Loss: 0.11341238021850586
Epoch 4434, Loss: 0.6243516206741333, Final Batch Loss: 0.22112002968788147
Epoch 4435, Loss: 0.5310191363096237, Final Batch Loss: 0.1451830118894577
Epoch 4436, Loss: 0.5999242812395096, Final Batch Loss: 0.17663414776325226
Epoch 4437, Loss: 0.9344068467617035, Final Batch Loss: 0.5608299374580383
Epoch 4438, Lo

Epoch 4539, Loss: 0.5666025131940842, Final Batch Loss: 0.22004473209381104
Epoch 4540, Loss: 0.5003958716988564, Final Batch Loss: 0.11072828620672226
Epoch 4541, Loss: 0.4873794764280319, Final Batch Loss: 0.12752194702625275
Epoch 4542, Loss: 0.5282069146633148, Final Batch Loss: 0.15377716720104218
Epoch 4543, Loss: 0.4560009315609932, Final Batch Loss: 0.0675114169716835
Epoch 4544, Loss: 0.4628240242600441, Final Batch Loss: 0.0927400216460228
Epoch 4545, Loss: 0.601320207118988, Final Batch Loss: 0.25336387753486633
Epoch 4546, Loss: 0.65164715051651, Final Batch Loss: 0.31512442231178284
Epoch 4547, Loss: 0.45374777913093567, Final Batch Loss: 0.10328014194965363
Epoch 4548, Loss: 0.47137850522994995, Final Batch Loss: 0.12404219806194305
Epoch 4549, Loss: 0.6847534030675888, Final Batch Loss: 0.32900160551071167
Epoch 4550, Loss: 0.627078041434288, Final Batch Loss: 0.25666651129722595
Epoch 4551, Loss: 0.5005881190299988, Final Batch Loss: 0.06389287114143372
Epoch 4552, Loss

Epoch 4649, Loss: 0.6126088798046112, Final Batch Loss: 0.2501172721385956
Epoch 4650, Loss: 0.46036243438720703, Final Batch Loss: 0.08489848673343658
Epoch 4651, Loss: 0.4388889893889427, Final Batch Loss: 0.06672964245080948
Epoch 4652, Loss: 0.5758350044488907, Final Batch Loss: 0.2345191091299057
Epoch 4653, Loss: 0.5879097431898117, Final Batch Loss: 0.24848179519176483
Epoch 4654, Loss: 0.5226708203554153, Final Batch Loss: 0.12284795939922333
Epoch 4655, Loss: 0.6108836829662323, Final Batch Loss: 0.16355466842651367
Epoch 4656, Loss: 0.7377432882785797, Final Batch Loss: 0.36987796425819397
Epoch 4657, Loss: 0.5062555149197578, Final Batch Loss: 0.10905729979276657
Epoch 4658, Loss: 0.5136840343475342, Final Batch Loss: 0.1542578637599945
Epoch 4659, Loss: 0.6591385900974274, Final Batch Loss: 0.30231955647468567
Epoch 4660, Loss: 0.5002055317163467, Final Batch Loss: 0.16202254593372345
Epoch 4661, Loss: 0.4195764120668173, Final Batch Loss: 0.027504777535796165
Epoch 4662, L

Epoch 4759, Loss: 0.7410129308700562, Final Batch Loss: 0.3482312560081482
Epoch 4760, Loss: 0.7254333198070526, Final Batch Loss: 0.3496089577674866
Epoch 4761, Loss: 0.43782223016023636, Final Batch Loss: 0.08625680953264236
Epoch 4762, Loss: 0.5555392354726791, Final Batch Loss: 0.19565466046333313
Epoch 4763, Loss: 0.5398659259080887, Final Batch Loss: 0.14686493575572968
Epoch 4764, Loss: 0.5527344048023224, Final Batch Loss: 0.1711265593767166
Epoch 4765, Loss: 0.46629560366272926, Final Batch Loss: 0.04795316234230995
Epoch 4766, Loss: 0.5304825454950333, Final Batch Loss: 0.18546436727046967
Epoch 4767, Loss: 0.6671982258558273, Final Batch Loss: 0.22338338196277618
Epoch 4768, Loss: 0.7168437242507935, Final Batch Loss: 0.3494672477245331
Epoch 4769, Loss: 0.6112695634365082, Final Batch Loss: 0.22594642639160156
Epoch 4770, Loss: 0.5022335574030876, Final Batch Loss: 0.08229503780603409
Epoch 4771, Loss: 0.5304052531719208, Final Batch Loss: 0.15894407033920288
Epoch 4772, Lo

Epoch 4867, Loss: 0.5132587775588036, Final Batch Loss: 0.10961037129163742
Epoch 4868, Loss: 0.3832574337720871, Final Batch Loss: 0.06140872836112976
Epoch 4869, Loss: 0.48988211154937744, Final Batch Loss: 0.11283731460571289
Epoch 4870, Loss: 0.4494621455669403, Final Batch Loss: 0.06993342936038971
Epoch 4871, Loss: 0.8435440957546234, Final Batch Loss: 0.47234559059143066
Epoch 4872, Loss: 0.6603234112262726, Final Batch Loss: 0.25995898246765137
Epoch 4873, Loss: 0.4013790972530842, Final Batch Loss: 0.04509866610169411
Epoch 4874, Loss: 0.4695611819624901, Final Batch Loss: 0.1002378985285759
Epoch 4875, Loss: 0.557216003537178, Final Batch Loss: 0.18812207877635956
Epoch 4876, Loss: 0.5817863792181015, Final Batch Loss: 0.2120371162891388
Epoch 4877, Loss: 0.6012236475944519, Final Batch Loss: 0.1722569763660431
Epoch 4878, Loss: 0.474753275513649, Final Batch Loss: 0.13140808045864105
Epoch 4879, Loss: 0.482304185628891, Final Batch Loss: 0.13370653986930847
Epoch 4880, Loss:

Epoch 4975, Loss: 0.540269136428833, Final Batch Loss: 0.1484324187040329
Epoch 4976, Loss: 0.48319507390260696, Final Batch Loss: 0.08595595508813858
Epoch 4977, Loss: 0.5825638920068741, Final Batch Loss: 0.16121725738048553
Epoch 4978, Loss: 0.6094497814774513, Final Batch Loss: 0.2919441759586334
Epoch 4979, Loss: 0.610106498003006, Final Batch Loss: 0.21196182072162628
Epoch 4980, Loss: 0.7131725549697876, Final Batch Loss: 0.2602177560329437
Epoch 4981, Loss: 0.5090555250644684, Final Batch Loss: 0.1526234894990921
Epoch 4982, Loss: 0.449469979852438, Final Batch Loss: 0.0513029508292675
Epoch 4983, Loss: 0.43913131207227707, Final Batch Loss: 0.11312951892614365
Epoch 4984, Loss: 0.584770455956459, Final Batch Loss: 0.287514328956604
Epoch 4985, Loss: 0.5496720671653748, Final Batch Loss: 0.1941850632429123
Epoch 4986, Loss: 0.40753748267889023, Final Batch Loss: 0.06299030035734177
Epoch 4987, Loss: 0.7426247596740723, Final Batch Loss: 0.4283389449119568
Epoch 4988, Loss: 0.41

In [66]:
softmax = nn.Softmax(dim = 1)
model.eval()
for batch in test_loader:
    features, labels = batch
    _, preds = torch.max(softmax(model(features.float())), dim = 1)
    print(metrics.confusion_matrix((labels).cpu(), preds.cpu()))
    print(metrics.classification_report((labels).cpu(), preds.cpu(), digits = 5))

[[17  0  0  0  0  0  0  0  0  0  0  0]
 [ 0 14  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  4  0  0  0  0  0  0  0  0  0]
 [ 0  0  0 14  1  0  0  0  0  0  0  0]
 [ 0  0  0  0  7  0  0  0  0  0  0  0]
 [ 0  0  4  0  0  2  0  0  2  0  0  1]
 [ 0  0  0  0  1  1 11  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  7  0  0  0  0]
 [ 0  0  0  0  0  7  0  0  8  0  0  0]
 [ 0  0  0  0  0  0  0  0  0 15  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  6  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  8]]
              precision    recall  f1-score   support

           0    1.00000   1.00000   1.00000        17
           1    1.00000   1.00000   1.00000        14
           2    0.50000   1.00000   0.66667         4
           3    1.00000   0.93333   0.96552        15
           4    0.77778   1.00000   0.87500         7
           5    0.20000   0.22222   0.21053         9
           6    1.00000   0.84615   0.91667        13
           7    1.00000   1.00000   1.00000         7
           8    0.80000   0.53333   0.64000 

# Train on Fake Test on Real

In [10]:
# 20,000 epochs DEFAULT PARAMETERS FOR THRESHOLDS

In [11]:
gen = Generator(z_dim = 107)
load_model(gen, "3 Label 4 Subject GAN Ablation_gen.param")
gen.eval()

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): Linear(in_features=107, out_features=80, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Linear(in_features=80, out_features=60, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Linear(in_features=60, out_features=50, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (3): Linear(in_features=50, out_features=37, bias=True)
    (4): Tanh()
  )
)

In [54]:
size = len(X_test)
latent_vectors = get_noise(size, 100)
act_vectors = get_act_matrix(size, 3)
usr_vectors = get_usr_matrix(size, 4)

fake_labels = []

for k in range(size):
    if act_vectors[0][k] == 0 and usr_vectors[0][k] == 0:
        fake_labels.append(0)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 0:
        fake_labels.append(1)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 0:
        fake_labels.append(2)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 1:
        fake_labels.append(3)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 1:
        fake_labels.append(4)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 1:
        fake_labels.append(5)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 2:
        fake_labels.append(6)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 2:
        fake_labels.append(7)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 2:
        fake_labels.append(8)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 3:
        fake_labels.append(9)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 3:
        fake_labels.append(10)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 3:
        fake_labels.append(11)
        
fake_labels = np.asarray(fake_labels)
to_gen = torch.cat((latent_vectors, act_vectors[1], usr_vectors[1]), 1)
fake_features = gen(to_gen).detach()
#fake_features = gen(to_gen).detach().numpy()

In [55]:
_, preds = torch.max(softmax(model(fake_features.float())), dim = 1)
print(metrics.confusion_matrix((fake_labels), preds.cpu()))
print(metrics.classification_report((fake_labels), preds.cpu(), digits = 5))

[[14  0  0  0  0  0  0  0  0  0  0  0]
 [ 0 12  0  0  0  0  0  0  0  0  0  0]
 [ 0  0 11  0  0  0  0  0  0  0  0  0]
 [ 0  0  0 15  0  0  0  0  0  0  0  0]
 [ 0  0  0  0 10  0  0  0  0  0  1  0]
 [ 0  0  0  0  0  0  0  0  2  0  0  3]
 [ 0  0  0  0  0  0 12  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  5  0  0  0  0]
 [ 0  0  0  0  0 10  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0 10  0  0]
 [ 0  0  0  0  0  0  0  2  0  0 13  0]
 [ 0  0  0  0  0  0  0  0  0  0  0 10]]
              precision    recall  f1-score   support

           0    1.00000   1.00000   1.00000        14
           1    1.00000   1.00000   1.00000        12
           2    1.00000   1.00000   1.00000        11
           3    1.00000   1.00000   1.00000        15
           4    1.00000   0.90909   0.95238        11
           5    0.00000   0.00000   0.00000         5
           6    1.00000   1.00000   1.00000        12
           7    0.71429   1.00000   0.83333         5
           8    0.00000   0.00000   0.00000 