In [1]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn import metrics
import torch.optim as optim

In [2]:
sub_features = ['42 tGravityAcc-mean()-Y',
 '43 tGravityAcc-mean()-Z',
 '51 tGravityAcc-max()-Y',
 '52 tGravityAcc-max()-Z',
 '54 tGravityAcc-min()-Y',
 '55 tGravityAcc-min()-Z',
 '56 tGravityAcc-sma()',
 '58 tGravityAcc-energy()-Y',
 '59 tGravityAcc-energy()-Z',
 '475 fBodyGyro-bandsEnergy()-1,8',
 '559 angle(X,gravityMean)',
 '560 angle(Y,gravityMean)',
 '561 angle(Z,gravityMean)']

act_features = ['4 tBodyAcc-std()-X',
 '10 tBodyAcc-max()-X',
 '17 tBodyAcc-energy()-X',
 '202 tBodyAccMag-std()',
 '203 tBodyAccMag-mad()',
 '215 tGravityAccMag-std()',
 '216 tGravityAccMag-mad()',
 '266 fBodyAcc-mean()-X',
 '269 fBodyAcc-std()-X',
 '272 fBodyAcc-mad()-X',
 '282 fBodyAcc-energy()-X',
 '303 fBodyAcc-bandsEnergy()-1,8',
 '311 fBodyAcc-bandsEnergy()-1,16',
 '315 fBodyAcc-bandsEnergy()-1,24',
 '382 fBodyAccJerk-bandsEnergy()-1,8',
 '390 fBodyAccJerk-bandsEnergy()-1,16',
 '504 fBodyAccMag-std()',
 '505 fBodyAccMag-mad()',
 '509 fBodyAccMag-energy()']

input_shape = len(sub_features) + len(act_features)

In [3]:
def classifier_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.LeakyReLU(0.05)
    )

class Classifier(nn.Module):
    def __init__(self, feature_dim = input_shape):
        super(Classifier, self).__init__()
        self.network = nn.Sequential(
            classifier_block(feature_dim, 40),
            classifier_block(40, 30),
            classifier_block(30, 25),
            classifier_block(25, 25),
            nn.Linear(25, 24)
        )
    def forward(self, x):
        return self.network(x)

In [4]:
#defines each generator layer
#input and output dimensions needed
def generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace = True)
    )

#returns n_samples of z_dim (number of dimensions of latent space) noise
def get_noise(n_samples, z_dim):
    return torch.randn(n_samples, z_dim)

#defines generator class
class Generator(nn.Module):
    def __init__(self, z_dim = 10, feature_dim = input_shape, hidden_dim = 128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            generator_block(z_dim, 80),
            generator_block(80, 60),
            generator_block(60, 50),
            nn.Linear(50, feature_dim),
            nn.Tanh()
        )
    def forward(self, noise):
        return self.gen(noise)

def get_act_matrix(batch_size, a_dim):
    indexes = np.random.randint(a_dim, size = batch_size)
    
    one_hot = np.zeros((len(indexes), indexes.max()+1))
    one_hot[np.arange(len(indexes)),indexes] = 1
    return torch.Tensor(indexes).long(), torch.Tensor(one_hot)
    
def get_usr_matrix(batch_size, u_dim):
    indexes = np.random.randint(u_dim, size = batch_size)
    
    one_hot = np.zeros((indexes.size, indexes.max()+1))
    one_hot[np.arange(indexes.size),indexes] = 1
    return torch.Tensor(indexes).long(), torch.Tensor(one_hot)

def load_model(model, model_name):
    model.load_state_dict(torch.load(f'../../../saved_models/{model_name}'))

# Train on Real Test on Real

In [5]:
#label is a list of integers specifying which labels to filter by
#users is a list of integers specifying which users to filter by
#y_label is a string, either "Activity" or "Subject" depending on what y output needs to be returned
def start_data(label, users, y_label, sub_features, act_features):
    #get the dataframe column names
    name_dataframe = pd.read_csv('../../../data/features.txt', delimiter = '\n', header = None)
    names = name_dataframe.values.tolist()
    names = [k for row in names for k in row] #List of column names

    data = pd.read_csv('../../../data/X_train.txt', delim_whitespace = True, header = None) #Read in dataframe
    data.columns = names #Setting column names
    
    X_train_1 = data[sub_features]
    X_train_2 = data[act_features]
    X_train = pd.concat([X_train_1, X_train_2], axis = 1)
    
    y_train_activity = pd.read_csv('../../../data/y_train.txt', header = None)
    y_train_activity.columns = ['Activity']
    
    y_train_subject = pd.read_csv('../../../data/subject_train.txt', header = None)
    y_train_subject.columns = ['Subject']
    
    GAN_data = pd.concat([X_train, y_train_activity, y_train_subject], axis = 1)
    GAN_data = GAN_data[GAN_data['Activity'].isin(label)]
    GAN_data = GAN_data[GAN_data['Subject'].isin(users)]
    
    X_1 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_2 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_3 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_4 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_5 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_6 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_7 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_8 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_9 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_10 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_11 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_12 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_13 = GAN_data[(GAN_data['Subject'] == 8) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_14 = GAN_data[(GAN_data['Subject'] == 8) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_15 = GAN_data[(GAN_data['Subject'] == 8) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_16 = GAN_data[(GAN_data['Subject'] == 11) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_17 = GAN_data[(GAN_data['Subject'] == 11) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_18 = GAN_data[(GAN_data['Subject'] == 11) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_19 = GAN_data[(GAN_data['Subject'] == 14) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_20 = GAN_data[(GAN_data['Subject'] == 14) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_21 = GAN_data[(GAN_data['Subject'] == 14) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_22 = GAN_data[(GAN_data['Subject'] == 17) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_23 = GAN_data[(GAN_data['Subject'] == 17) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_24 = GAN_data[(GAN_data['Subject'] == 17) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    
    X_train = np.concatenate((X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8, X_9, X_10, X_11, X_12, X_13, X_14, X_15, X_16, X_17, X_18, X_19, X_20, X_21, X_22, X_23, X_24))
    y_train = [0] * len(X_1) + [1] * len(X_2) + [2] * len(X_3) + [3] * len(X_4) + [4] * len(X_5) + [5] * len(X_6) + [6] * len(X_7) + [7] * len(X_8) + [8] * len(X_9) + [9] * len(X_10) + [10] * len(X_11) + [11] * len(X_12) + [12] * len(X_13) + [13] * len(X_14) + [14] * len(X_15) + [15] * len(X_16) + [16] * len(X_17) + [17] * len(X_18) + [18] * len(X_19) + [19] * len(X_20) + [20] * len(X_21) + [21] * len(X_22) + [22] * len(X_23) + [23] * len(X_24)
    
    return X_train, np.asarray(y_train)

In [6]:
activities = [1, 3, 4]
users = [1, 3, 5, 7, 8, 11, 14, 17]

X, y = start_data(activities, users, "Activity", sub_features, act_features)

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, shuffle = True)

model = Classifier()
lr = 0.001
n_epochs = 5000
batch_size = 250

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

train_features = torch.tensor(X_train)
train_labels = torch.tensor(y_train)
test_features = torch.tensor(X_test)
test_labels = torch.tensor(y_test)

train_data = torch.utils.data.TensorDataset(train_features, train_labels)
test_data = torch.utils.data.TensorDataset(test_features, test_labels)

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = len(test_labels), shuffle = True)

In [8]:
for epoch in range(n_epochs):
    total_loss = 0
    for batch in train_loader:
        features, labels = batch
        
        optimizer.zero_grad()
        preds = model(features.float())
        
        loss = criterion(preds, labels.long()) 
        loss.backward()
        
        optimizer.step()
        total_loss += loss.item()
        
    print(f'Epoch {epoch + 1}, Loss: {total_loss}, Final Batch Loss: {loss.item()}')

Epoch 1, Loss: 15.937380075454712, Final Batch Loss: 3.235614776611328
Epoch 2, Loss: 15.838646650314331, Final Batch Loss: 3.1604206562042236
Epoch 3, Loss: 15.839359521865845, Final Batch Loss: 3.1766059398651123
Epoch 4, Loss: 15.766544103622437, Final Batch Loss: 3.1313016414642334
Epoch 5, Loss: 15.759380578994751, Final Batch Loss: 3.1677663326263428
Epoch 6, Loss: 15.655831336975098, Final Batch Loss: 3.109492778778076
Epoch 7, Loss: 15.57589340209961, Final Batch Loss: 3.1090586185455322
Epoch 8, Loss: 15.50944471359253, Final Batch Loss: 3.147876262664795
Epoch 9, Loss: 15.2925283908844, Final Batch Loss: 3.0778558254241943
Epoch 10, Loss: 15.008805274963379, Final Batch Loss: 2.959566831588745
Epoch 11, Loss: 14.978736877441406, Final Batch Loss: 3.174306869506836
Epoch 12, Loss: 14.5722975730896, Final Batch Loss: 2.9933362007141113
Epoch 13, Loss: 14.149859428405762, Final Batch Loss: 2.804372787475586
Epoch 14, Loss: 13.734985828399658, Final Batch Loss: 2.6676602363586426

Epoch 119, Loss: 7.556719541549683, Final Batch Loss: 1.3809492588043213
Epoch 120, Loss: 7.407825708389282, Final Batch Loss: 1.3216485977172852
Epoch 121, Loss: 7.701578497886658, Final Batch Loss: 1.5736013650894165
Epoch 122, Loss: 7.665084362030029, Final Batch Loss: 1.6203268766403198
Epoch 123, Loss: 7.688673377037048, Final Batch Loss: 1.5337859392166138
Epoch 124, Loss: 7.564011335372925, Final Batch Loss: 1.4941507577896118
Epoch 125, Loss: 7.771791100502014, Final Batch Loss: 1.5619646310806274
Epoch 126, Loss: 7.4590864181518555, Final Batch Loss: 1.4571233987808228
Epoch 127, Loss: 7.671635866165161, Final Batch Loss: 1.6422141790390015
Epoch 128, Loss: 7.920945405960083, Final Batch Loss: 1.9721649885177612
Epoch 129, Loss: 7.309614896774292, Final Batch Loss: 1.370534062385559
Epoch 130, Loss: 7.8140387535095215, Final Batch Loss: 1.799122929573059
Epoch 131, Loss: 7.541677474975586, Final Batch Loss: 1.6175528764724731
Epoch 132, Loss: 7.63666558265686, Final Batch Loss

Epoch 237, Loss: 6.265029072761536, Final Batch Loss: 1.2797572612762451
Epoch 238, Loss: 6.449880123138428, Final Batch Loss: 1.5597515106201172
Epoch 239, Loss: 6.1014145612716675, Final Batch Loss: 1.2609460353851318
Epoch 240, Loss: 5.682827174663544, Final Batch Loss: 0.8262578845024109
Epoch 241, Loss: 5.871882617473602, Final Batch Loss: 0.9991458058357239
Epoch 242, Loss: 6.264104843139648, Final Batch Loss: 1.395630955696106
Epoch 243, Loss: 5.889220356941223, Final Batch Loss: 1.0054367780685425
Epoch 244, Loss: 6.039827227592468, Final Batch Loss: 1.2137516736984253
Epoch 245, Loss: 5.9927942752838135, Final Batch Loss: 1.2187166213989258
Epoch 246, Loss: 6.385956048965454, Final Batch Loss: 1.4761635065078735
Epoch 247, Loss: 6.361606597900391, Final Batch Loss: 1.5796210765838623
Epoch 248, Loss: 5.888304591178894, Final Batch Loss: 1.0273126363754272
Epoch 249, Loss: 6.287814617156982, Final Batch Loss: 1.500476598739624
Epoch 250, Loss: 6.40536105632782, Final Batch Loss

Epoch 352, Loss: 5.0750232338905334, Final Batch Loss: 0.8376269936561584
Epoch 353, Loss: 5.36143171787262, Final Batch Loss: 1.2447483539581299
Epoch 354, Loss: 5.382011711597443, Final Batch Loss: 1.2122623920440674
Epoch 355, Loss: 5.125693380832672, Final Batch Loss: 1.0697342157363892
Epoch 356, Loss: 4.886724591255188, Final Batch Loss: 0.7198971509933472
Epoch 357, Loss: 5.417776048183441, Final Batch Loss: 1.1685751676559448
Epoch 358, Loss: 5.22874128818512, Final Batch Loss: 1.0814683437347412
Epoch 359, Loss: 5.234933018684387, Final Batch Loss: 1.074804425239563
Epoch 360, Loss: 5.542828559875488, Final Batch Loss: 1.373955488204956
Epoch 361, Loss: 5.212827444076538, Final Batch Loss: 1.1800851821899414
Epoch 362, Loss: 5.180410325527191, Final Batch Loss: 0.9596810936927795
Epoch 363, Loss: 5.2441911697387695, Final Batch Loss: 1.0934747457504272
Epoch 364, Loss: 5.070677876472473, Final Batch Loss: 1.0006989240646362
Epoch 365, Loss: 5.154768228530884, Final Batch Loss:

Epoch 470, Loss: 4.680335760116577, Final Batch Loss: 1.0522072315216064
Epoch 471, Loss: 4.3410614132881165, Final Batch Loss: 0.6707324981689453
Epoch 472, Loss: 4.678810894489288, Final Batch Loss: 1.031516432762146
Epoch 473, Loss: 4.9865840673446655, Final Batch Loss: 1.3513213396072388
Epoch 474, Loss: 4.5654279589653015, Final Batch Loss: 0.7094364166259766
Epoch 475, Loss: 4.711355745792389, Final Batch Loss: 1.0690332651138306
Epoch 476, Loss: 4.373522758483887, Final Batch Loss: 0.7072591185569763
Epoch 477, Loss: 5.0230395793914795, Final Batch Loss: 1.3021043539047241
Epoch 478, Loss: 4.403699815273285, Final Batch Loss: 0.7119150757789612
Epoch 479, Loss: 4.779449045658112, Final Batch Loss: 1.1492596864700317
Epoch 480, Loss: 4.615052402019501, Final Batch Loss: 0.9159550666809082
Epoch 481, Loss: 4.82598477602005, Final Batch Loss: 1.074839472770691
Epoch 482, Loss: 5.056930422782898, Final Batch Loss: 1.375945806503296
Epoch 483, Loss: 4.767443418502808, Final Batch Los

Epoch 587, Loss: 4.451578140258789, Final Batch Loss: 1.1782398223876953
Epoch 588, Loss: 4.197038114070892, Final Batch Loss: 0.9503714442253113
Epoch 589, Loss: 3.895269989967346, Final Batch Loss: 0.622497022151947
Epoch 590, Loss: 4.006572067737579, Final Batch Loss: 0.7450345158576965
Epoch 591, Loss: 4.150162756443024, Final Batch Loss: 0.7785460352897644
Epoch 592, Loss: 4.47522896528244, Final Batch Loss: 1.1474103927612305
Epoch 593, Loss: 4.283687353134155, Final Batch Loss: 0.8921141028404236
Epoch 594, Loss: 3.827569395303726, Final Batch Loss: 0.4256241023540497
Epoch 595, Loss: 3.97044438123703, Final Batch Loss: 0.6147072911262512
Epoch 596, Loss: 3.788129985332489, Final Batch Loss: 0.5150066018104553
Epoch 597, Loss: 4.040853977203369, Final Batch Loss: 0.8871555328369141
Epoch 598, Loss: 3.979031562805176, Final Batch Loss: 0.6913030743598938
Epoch 599, Loss: 4.259728848934174, Final Batch Loss: 0.9257954955101013
Epoch 600, Loss: 4.19517970085144, Final Batch Loss: 0

Epoch 703, Loss: 4.042892754077911, Final Batch Loss: 1.0500284433364868
Epoch 704, Loss: 4.17257297039032, Final Batch Loss: 1.0716512203216553
Epoch 705, Loss: 3.661650240421295, Final Batch Loss: 0.5297379493713379
Epoch 706, Loss: 3.8591538667678833, Final Batch Loss: 0.7884578704833984
Epoch 707, Loss: 4.137423396110535, Final Batch Loss: 1.141570806503296
Epoch 708, Loss: 3.993346333503723, Final Batch Loss: 0.8647961616516113
Epoch 709, Loss: 3.5855600237846375, Final Batch Loss: 0.6313950419425964
Epoch 710, Loss: 3.7141119241714478, Final Batch Loss: 0.6737118363380432
Epoch 711, Loss: 3.533948063850403, Final Batch Loss: 0.5212711691856384
Epoch 712, Loss: 3.5669317841529846, Final Batch Loss: 0.4812120795249939
Epoch 713, Loss: 3.854772686958313, Final Batch Loss: 0.8228899240493774
Epoch 714, Loss: 3.745350658893585, Final Batch Loss: 0.7727206349372864
Epoch 715, Loss: 4.224898993968964, Final Batch Loss: 1.1564759016036987
Epoch 716, Loss: 3.5949347019195557, Final Batch 

Epoch 817, Loss: 3.5654029846191406, Final Batch Loss: 0.6184566617012024
Epoch 818, Loss: 3.521871507167816, Final Batch Loss: 0.7351773381233215
Epoch 819, Loss: 3.3846647143363953, Final Batch Loss: 0.5141862630844116
Epoch 820, Loss: 3.410908877849579, Final Batch Loss: 0.5941787362098694
Epoch 821, Loss: 3.4910477995872498, Final Batch Loss: 0.6553317308425903
Epoch 822, Loss: 3.465358853340149, Final Batch Loss: 0.6788745522499084
Epoch 823, Loss: 3.7994651794433594, Final Batch Loss: 0.9958643317222595
Epoch 824, Loss: 3.797205328941345, Final Batch Loss: 1.0474690198898315
Epoch 825, Loss: 3.3429012298583984, Final Batch Loss: 0.5299071669578552
Epoch 826, Loss: 3.2725214064121246, Final Batch Loss: 0.48235949873924255
Epoch 827, Loss: 3.0918417274951935, Final Batch Loss: 0.33172520995140076
Epoch 828, Loss: 3.214739501476288, Final Batch Loss: 0.46065306663513184
Epoch 829, Loss: 3.5171753764152527, Final Batch Loss: 0.7245392799377441
Epoch 830, Loss: 3.5282431840896606, Fin

Epoch 931, Loss: 3.1690532565116882, Final Batch Loss: 0.541054368019104
Epoch 932, Loss: 3.1980621218681335, Final Batch Loss: 0.5041126608848572
Epoch 933, Loss: 3.189954459667206, Final Batch Loss: 0.6557515859603882
Epoch 934, Loss: 3.286547541618347, Final Batch Loss: 0.7407892942428589
Epoch 935, Loss: 3.5541642904281616, Final Batch Loss: 1.0014063119888306
Epoch 936, Loss: 3.1979544162750244, Final Batch Loss: 0.46842002868652344
Epoch 937, Loss: 3.161178410053253, Final Batch Loss: 0.5897001028060913
Epoch 938, Loss: 3.725052773952484, Final Batch Loss: 1.0954887866973877
Epoch 939, Loss: 3.486091911792755, Final Batch Loss: 0.8749174475669861
Epoch 940, Loss: 3.218143403530121, Final Batch Loss: 0.5356869101524353
Epoch 941, Loss: 3.3981735706329346, Final Batch Loss: 0.8129541277885437
Epoch 942, Loss: 3.484389007091522, Final Batch Loss: 1.0066343545913696
Epoch 943, Loss: 3.333767354488373, Final Batch Loss: 0.7268195152282715
Epoch 944, Loss: 2.7587678283452988, Final Bat

Epoch 1042, Loss: 2.6620081961154938, Final Batch Loss: 0.21297290921211243
Epoch 1043, Loss: 2.8256645500659943, Final Batch Loss: 0.3869730532169342
Epoch 1044, Loss: 3.2135862708091736, Final Batch Loss: 0.7692801356315613
Epoch 1045, Loss: 3.0323650240898132, Final Batch Loss: 0.5064999461174011
Epoch 1046, Loss: 3.0773606300354004, Final Batch Loss: 0.6174759268760681
Epoch 1047, Loss: 3.0842984318733215, Final Batch Loss: 0.5666951537132263
Epoch 1048, Loss: 3.532983660697937, Final Batch Loss: 0.9225796461105347
Epoch 1049, Loss: 3.0904229283332825, Final Batch Loss: 0.6249585747718811
Epoch 1050, Loss: 2.777259588241577, Final Batch Loss: 0.3786713480949402
Epoch 1051, Loss: 3.0583258867263794, Final Batch Loss: 0.6389416456222534
Epoch 1052, Loss: 3.0618852376937866, Final Batch Loss: 0.5906369090080261
Epoch 1053, Loss: 3.348967432975769, Final Batch Loss: 1.0671826601028442
Epoch 1054, Loss: 2.7820172905921936, Final Batch Loss: 0.40673238039016724
Epoch 1055, Loss: 3.331124

Epoch 1154, Loss: 2.9855674505233765, Final Batch Loss: 0.5608118176460266
Epoch 1155, Loss: 3.2300496697425842, Final Batch Loss: 0.9023270010948181
Epoch 1156, Loss: 2.8696924448013306, Final Batch Loss: 0.5814128518104553
Epoch 1157, Loss: 2.8375181555747986, Final Batch Loss: 0.6299611926078796
Epoch 1158, Loss: 2.7415904700756073, Final Batch Loss: 0.3753298223018646
Epoch 1159, Loss: 3.3190845251083374, Final Batch Loss: 0.9646862149238586
Epoch 1160, Loss: 3.3784740567207336, Final Batch Loss: 1.0570989847183228
Epoch 1161, Loss: 2.9208691120147705, Final Batch Loss: 0.37902265787124634
Epoch 1162, Loss: 2.5589442253112793, Final Batch Loss: 0.26530396938323975
Epoch 1163, Loss: 3.2297877073287964, Final Batch Loss: 0.8218773603439331
Epoch 1164, Loss: 2.9110403656959534, Final Batch Loss: 0.5889650583267212
Epoch 1165, Loss: 2.943133234977722, Final Batch Loss: 0.5712196826934814
Epoch 1166, Loss: 2.747560977935791, Final Batch Loss: 0.5155674815177917
Epoch 1167, Loss: 2.60167

Epoch 1264, Loss: 2.5040415227413177, Final Batch Loss: 0.37774038314819336
Epoch 1265, Loss: 2.5681983828544617, Final Batch Loss: 0.3662770390510559
Epoch 1266, Loss: 3.1570438146591187, Final Batch Loss: 1.0443581342697144
Epoch 1267, Loss: 2.776915490627289, Final Batch Loss: 0.44861191511154175
Epoch 1268, Loss: 2.7581529021263123, Final Batch Loss: 0.5088708996772766
Epoch 1269, Loss: 2.8369770646095276, Final Batch Loss: 0.5474634170532227
Epoch 1270, Loss: 2.4229461699724197, Final Batch Loss: 0.21419115364551544
Epoch 1271, Loss: 2.99090376496315, Final Batch Loss: 0.7843261361122131
Epoch 1272, Loss: 2.647695094347, Final Batch Loss: 0.43645504117012024
Epoch 1273, Loss: 2.85003525018692, Final Batch Loss: 0.4724969267845154
Epoch 1274, Loss: 2.6955322325229645, Final Batch Loss: 0.4320984184741974
Epoch 1275, Loss: 3.1258986592292786, Final Batch Loss: 0.9041799306869507
Epoch 1276, Loss: 2.681748330593109, Final Batch Loss: 0.469473659992218
Epoch 1277, Loss: 3.208678603172

Epoch 1374, Loss: 2.676491230726242, Final Batch Loss: 0.5950098037719727
Epoch 1375, Loss: 2.544954478740692, Final Batch Loss: 0.4606887400150299
Epoch 1376, Loss: 2.7102443277835846, Final Batch Loss: 0.6226475834846497
Epoch 1377, Loss: 3.050024092197418, Final Batch Loss: 0.9869090914726257
Epoch 1378, Loss: 2.4845909476280212, Final Batch Loss: 0.45452743768692017
Epoch 1379, Loss: 2.7564807534217834, Final Batch Loss: 0.5971169471740723
Epoch 1380, Loss: 2.8556411266326904, Final Batch Loss: 0.8487973809242249
Epoch 1381, Loss: 3.7045509815216064, Final Batch Loss: 1.3964669704437256
Epoch 1382, Loss: 2.345635309815407, Final Batch Loss: 0.15174327790737152
Epoch 1383, Loss: 2.810625195503235, Final Batch Loss: 0.6611502766609192
Epoch 1384, Loss: 2.458479329943657, Final Batch Loss: 0.1951904147863388
Epoch 1385, Loss: 2.602659434080124, Final Batch Loss: 0.4130711555480957
Epoch 1386, Loss: 2.41238334774971, Final Batch Loss: 0.30176353454589844
Epoch 1387, Loss: 2.51506263017

Epoch 1491, Loss: 2.6702761352062225, Final Batch Loss: 0.691491425037384
Epoch 1492, Loss: 2.5934222638607025, Final Batch Loss: 0.4877692759037018
Epoch 1493, Loss: 2.6362257599830627, Final Batch Loss: 0.6422598958015442
Epoch 1494, Loss: 2.7207189202308655, Final Batch Loss: 0.7145991921424866
Epoch 1495, Loss: 2.8882191479206085, Final Batch Loss: 0.9323150515556335
Epoch 1496, Loss: 2.792224556207657, Final Batch Loss: 0.8230805397033691
Epoch 1497, Loss: 2.8203614950180054, Final Batch Loss: 0.6584148406982422
Epoch 1498, Loss: 2.470856696367264, Final Batch Loss: 0.29127925634384155
Epoch 1499, Loss: 2.337303549051285, Final Batch Loss: 0.3231295943260193
Epoch 1500, Loss: 2.725342273712158, Final Batch Loss: 0.7686367034912109
Epoch 1501, Loss: 2.336157649755478, Final Batch Loss: 0.2799047529697418
Epoch 1502, Loss: 2.5958683490753174, Final Batch Loss: 0.5887660384178162
Epoch 1503, Loss: 2.65233650803566, Final Batch Loss: 0.611995279788971
Epoch 1504, Loss: 2.5206578075885

Epoch 1602, Loss: 2.4034528136253357, Final Batch Loss: 0.4584888815879822
Epoch 1603, Loss: 2.5050907731056213, Final Batch Loss: 0.6442188620567322
Epoch 1604, Loss: 2.8405532836914062, Final Batch Loss: 0.8595673441886902
Epoch 1605, Loss: 2.4156374335289, Final Batch Loss: 0.4327041506767273
Epoch 1606, Loss: 2.3903927505016327, Final Batch Loss: 0.5089300870895386
Epoch 1607, Loss: 2.31122624874115, Final Batch Loss: 0.38601064682006836
Epoch 1608, Loss: 2.273770183324814, Final Batch Loss: 0.29054728150367737
Epoch 1609, Loss: 2.012451335787773, Final Batch Loss: 0.07758991420269012
Epoch 1610, Loss: 2.3749831318855286, Final Batch Loss: 0.42287200689315796
Epoch 1611, Loss: 2.4566935896873474, Final Batch Loss: 0.5440369248390198
Epoch 1612, Loss: 2.6402996480464935, Final Batch Loss: 0.7004519104957581
Epoch 1613, Loss: 2.3898304104804993, Final Batch Loss: 0.40677446126937866
Epoch 1614, Loss: 2.6598819494247437, Final Batch Loss: 0.7050215601921082
Epoch 1615, Loss: 2.9060910

Epoch 1718, Loss: 2.2578163146972656, Final Batch Loss: 0.3435191512107849
Epoch 1719, Loss: 2.2233801782131195, Final Batch Loss: 0.3050099015235901
Epoch 1720, Loss: 2.115330845117569, Final Batch Loss: 0.3449226915836334
Epoch 1721, Loss: 2.293949633836746, Final Batch Loss: 0.44223424792289734
Epoch 1722, Loss: 1.9749472439289093, Final Batch Loss: 0.21052077412605286
Epoch 1723, Loss: 2.7035659551620483, Final Batch Loss: 0.7542632818222046
Epoch 1724, Loss: 2.55084890127182, Final Batch Loss: 0.7656283974647522
Epoch 1725, Loss: 2.484149694442749, Final Batch Loss: 0.6257988214492798
Epoch 1726, Loss: 2.5184419453144073, Final Batch Loss: 0.6211133003234863
Epoch 1727, Loss: 2.012462481856346, Final Batch Loss: 0.2115679532289505
Epoch 1728, Loss: 2.2963375449180603, Final Batch Loss: 0.40187567472457886
Epoch 1729, Loss: 2.114516243338585, Final Batch Loss: 0.22536064684391022
Epoch 1730, Loss: 2.9355777502059937, Final Batch Loss: 1.0744378566741943
Epoch 1731, Loss: 2.41898673

Epoch 1829, Loss: 2.139389470219612, Final Batch Loss: 0.23501335084438324
Epoch 1830, Loss: 2.2213033735752106, Final Batch Loss: 0.44849446415901184
Epoch 1831, Loss: 2.4282300770282745, Final Batch Loss: 0.5059633255004883
Epoch 1832, Loss: 2.782124698162079, Final Batch Loss: 0.9797729849815369
Epoch 1833, Loss: 2.3520622551441193, Final Batch Loss: 0.520107090473175
Epoch 1834, Loss: 2.3784931302070618, Final Batch Loss: 0.42696648836135864
Epoch 1835, Loss: 2.192035436630249, Final Batch Loss: 0.3844609260559082
Epoch 1836, Loss: 2.328195869922638, Final Batch Loss: 0.5455626845359802
Epoch 1837, Loss: 2.1353572010993958, Final Batch Loss: 0.33418846130371094
Epoch 1838, Loss: 2.3264014422893524, Final Batch Loss: 0.5531275868415833
Epoch 1839, Loss: 2.240060865879059, Final Batch Loss: 0.26446643471717834
Epoch 1840, Loss: 2.696379214525223, Final Batch Loss: 0.7086352109909058
Epoch 1841, Loss: 2.314227432012558, Final Batch Loss: 0.4784274399280548
Epoch 1842, Loss: 2.28541305

Epoch 1942, Loss: 2.2602561116218567, Final Batch Loss: 0.49592095613479614
Epoch 1943, Loss: 2.058957040309906, Final Batch Loss: 0.24314674735069275
Epoch 1944, Loss: 2.0616932213306427, Final Batch Loss: 0.2547931373119354
Epoch 1945, Loss: 2.0710346400737762, Final Batch Loss: 0.30421438813209534
Epoch 1946, Loss: 2.3655996918678284, Final Batch Loss: 0.6065478920936584
Epoch 1947, Loss: 2.108894407749176, Final Batch Loss: 0.34366700053215027
Epoch 1948, Loss: 2.28672131896019, Final Batch Loss: 0.4819726347923279
Epoch 1949, Loss: 2.273907631635666, Final Batch Loss: 0.4490041136741638
Epoch 1950, Loss: 1.986885815858841, Final Batch Loss: 0.1989668309688568
Epoch 1951, Loss: 2.1750669479370117, Final Batch Loss: 0.39500492811203003
Epoch 1952, Loss: 2.4734452664852142, Final Batch Loss: 0.6597999334335327
Epoch 1953, Loss: 2.391150325536728, Final Batch Loss: 0.578861653804779
Epoch 1954, Loss: 1.9506221115589142, Final Batch Loss: 0.24430349469184875
Epoch 1955, Loss: 2.1530481

Epoch 2057, Loss: 1.9086463749408722, Final Batch Loss: 0.2587532699108124
Epoch 2058, Loss: 2.2753641605377197, Final Batch Loss: 0.633583128452301
Epoch 2059, Loss: 2.015865668654442, Final Batch Loss: 0.24629037082195282
Epoch 2060, Loss: 2.1001496613025665, Final Batch Loss: 0.3272917568683624
Epoch 2061, Loss: 2.3899534046649933, Final Batch Loss: 0.5606130361557007
Epoch 2062, Loss: 2.237119495868683, Final Batch Loss: 0.4910913407802582
Epoch 2063, Loss: 2.121095597743988, Final Batch Loss: 0.4101352095603943
Epoch 2064, Loss: 2.0422704219818115, Final Batch Loss: 0.41347378492355347
Epoch 2065, Loss: 2.4722579419612885, Final Batch Loss: 0.7069603204727173
Epoch 2066, Loss: 2.1565282940864563, Final Batch Loss: 0.5446034669876099
Epoch 2067, Loss: 1.829342320561409, Final Batch Loss: 0.13303013145923615
Epoch 2068, Loss: 2.113856256008148, Final Batch Loss: 0.4681573212146759
Epoch 2069, Loss: 2.323078542947769, Final Batch Loss: 0.6471737027168274
Epoch 2070, Loss: 1.867087289

Epoch 2169, Loss: 2.395291656255722, Final Batch Loss: 0.5909348130226135
Epoch 2170, Loss: 2.0509800016880035, Final Batch Loss: 0.2534719407558441
Epoch 2171, Loss: 1.9892387092113495, Final Batch Loss: 0.26371827721595764
Epoch 2172, Loss: 2.2395528852939606, Final Batch Loss: 0.4894144833087921
Epoch 2173, Loss: 2.122625768184662, Final Batch Loss: 0.4672413766384125
Epoch 2174, Loss: 2.15603169798851, Final Batch Loss: 0.4100801944732666
Epoch 2175, Loss: 1.778222531080246, Final Batch Loss: 0.21772277355194092
Epoch 2176, Loss: 2.056923121213913, Final Batch Loss: 0.32231757044792175
Epoch 2177, Loss: 2.316058248281479, Final Batch Loss: 0.6375243067741394
Epoch 2178, Loss: 2.8586871922016144, Final Batch Loss: 1.1919564008712769
Epoch 2179, Loss: 2.2887759506702423, Final Batch Loss: 0.574081301689148
Epoch 2180, Loss: 2.7431493401527405, Final Batch Loss: 0.7452926635742188
Epoch 2181, Loss: 2.304284989833832, Final Batch Loss: 0.5166283845901489
Epoch 2182, Loss: 2.76485231518

Epoch 2281, Loss: 2.5729140043258667, Final Batch Loss: 0.9569322466850281
Epoch 2282, Loss: 2.1531553268432617, Final Batch Loss: 0.45836013555526733
Epoch 2283, Loss: 1.9672739207744598, Final Batch Loss: 0.30647802352905273
Epoch 2284, Loss: 2.257246971130371, Final Batch Loss: 0.6302801966667175
Epoch 2285, Loss: 2.2102458775043488, Final Batch Loss: 0.5136035680770874
Epoch 2286, Loss: 2.168722778558731, Final Batch Loss: 0.47392040491104126
Epoch 2287, Loss: 2.054363250732422, Final Batch Loss: 0.3896106779575348
Epoch 2288, Loss: 1.9951739013195038, Final Batch Loss: 0.28501778841018677
Epoch 2289, Loss: 2.2746669352054596, Final Batch Loss: 0.6002090573310852
Epoch 2290, Loss: 2.0624599754810333, Final Batch Loss: 0.37833255529403687
Epoch 2291, Loss: 1.9252579510211945, Final Batch Loss: 0.33916693925857544
Epoch 2292, Loss: 1.8381383121013641, Final Batch Loss: 0.2236439287662506
Epoch 2293, Loss: 2.424625873565674, Final Batch Loss: 0.7784889340400696
Epoch 2294, Loss: 1.959

Epoch 2394, Loss: 2.6150615513324738, Final Batch Loss: 0.9839878082275391
Epoch 2395, Loss: 2.097316265106201, Final Batch Loss: 0.36060211062431335
Epoch 2396, Loss: 2.221217304468155, Final Batch Loss: 0.5301672220230103
Epoch 2397, Loss: 2.0750701129436493, Final Batch Loss: 0.24041175842285156
Epoch 2398, Loss: 1.9610023647546768, Final Batch Loss: 0.23156292736530304
Epoch 2399, Loss: 1.9589523524045944, Final Batch Loss: 0.1803116351366043
Epoch 2400, Loss: 2.0465650856494904, Final Batch Loss: 0.370246559381485
Epoch 2401, Loss: 2.020585298538208, Final Batch Loss: 0.38758185505867004
Epoch 2402, Loss: 1.9427394568920135, Final Batch Loss: 0.3280128836631775
Epoch 2403, Loss: 1.8529632687568665, Final Batch Loss: 0.2821286618709564
Epoch 2404, Loss: 1.829059511423111, Final Batch Loss: 0.23603889346122742
Epoch 2405, Loss: 1.8787650614976883, Final Batch Loss: 0.22500617802143097
Epoch 2406, Loss: 2.084884852170944, Final Batch Loss: 0.3970238268375397
Epoch 2407, Loss: 2.07408

Epoch 2507, Loss: 1.9062752425670624, Final Batch Loss: 0.34923386573791504
Epoch 2508, Loss: 2.1335874497890472, Final Batch Loss: 0.44544264674186707
Epoch 2509, Loss: 1.9564216136932373, Final Batch Loss: 0.32691410183906555
Epoch 2510, Loss: 1.949002891778946, Final Batch Loss: 0.26919445395469666
Epoch 2511, Loss: 1.9898816347122192, Final Batch Loss: 0.41724762320518494
Epoch 2512, Loss: 1.6938487514853477, Final Batch Loss: 0.08276604861021042
Epoch 2513, Loss: 2.167422890663147, Final Batch Loss: 0.5223144888877869
Epoch 2514, Loss: 2.082979679107666, Final Batch Loss: 0.4340454638004303
Epoch 2515, Loss: 2.332226812839508, Final Batch Loss: 0.6638995409011841
Epoch 2516, Loss: 1.8497570306062698, Final Batch Loss: 0.19942839443683624
Epoch 2517, Loss: 2.0752953588962555, Final Batch Loss: 0.4433809816837311
Epoch 2518, Loss: 1.736672393977642, Final Batch Loss: 0.11307819932699203
Epoch 2519, Loss: 1.9312059581279755, Final Batch Loss: 0.2543948292732239
Epoch 2520, Loss: 1.85

Epoch 2619, Loss: 1.6940953060984612, Final Batch Loss: 0.09826614707708359
Epoch 2620, Loss: 1.9449054598808289, Final Batch Loss: 0.44621047377586365
Epoch 2621, Loss: 1.8850726187229156, Final Batch Loss: 0.45321378111839294
Epoch 2622, Loss: 1.905201107263565, Final Batch Loss: 0.2764430046081543
Epoch 2623, Loss: 1.991724044084549, Final Batch Loss: 0.46803969144821167
Epoch 2624, Loss: 2.044481724500656, Final Batch Loss: 0.4769309461116791
Epoch 2625, Loss: 1.7836255729198456, Final Batch Loss: 0.3265838921070099
Epoch 2626, Loss: 1.6813192516565323, Final Batch Loss: 0.17332400381565094
Epoch 2627, Loss: 1.9665555953979492, Final Batch Loss: 0.34994062781333923
Epoch 2628, Loss: 2.0952526330947876, Final Batch Loss: 0.5285928845405579
Epoch 2629, Loss: 2.1735468804836273, Final Batch Loss: 0.6174405217170715
Epoch 2630, Loss: 1.6623397469520569, Final Batch Loss: 0.1352774202823639
Epoch 2631, Loss: 1.779154870659113, Final Batch Loss: 0.05127403512597084
Epoch 2632, Loss: 2.05

Epoch 2734, Loss: 1.9676395952701569, Final Batch Loss: 0.3745596408843994
Epoch 2735, Loss: 2.0788058638572693, Final Batch Loss: 0.44391968846321106
Epoch 2736, Loss: 2.1881821751594543, Final Batch Loss: 0.6615462899208069
Epoch 2737, Loss: 1.897524744272232, Final Batch Loss: 0.3911585211753845
Epoch 2738, Loss: 1.8550027310848236, Final Batch Loss: 0.3414137363433838
Epoch 2739, Loss: 1.7555465027689934, Final Batch Loss: 0.10846801847219467
Epoch 2740, Loss: 1.89923757314682, Final Batch Loss: 0.39155882596969604
Epoch 2741, Loss: 1.871328592300415, Final Batch Loss: 0.34956058859825134
Epoch 2742, Loss: 1.6118585243821144, Final Batch Loss: 0.06543765217065811
Epoch 2743, Loss: 1.8559861183166504, Final Batch Loss: 0.3280427157878876
Epoch 2744, Loss: 1.7136452570557594, Final Batch Loss: 0.06140131503343582
Epoch 2745, Loss: 1.870427668094635, Final Batch Loss: 0.327888160943985
Epoch 2746, Loss: 2.029688835144043, Final Batch Loss: 0.4533940255641937
Epoch 2747, Loss: 1.975287

Epoch 2843, Loss: 1.5970125421881676, Final Batch Loss: 0.09884055703878403
Epoch 2844, Loss: 1.8460439443588257, Final Batch Loss: 0.360103040933609
Epoch 2845, Loss: 1.5911220610141754, Final Batch Loss: 0.2043595016002655
Epoch 2846, Loss: 1.7542199790477753, Final Batch Loss: 0.27177149057388306
Epoch 2847, Loss: 1.75023752450943, Final Batch Loss: 0.3143220841884613
Epoch 2848, Loss: 1.8285980522632599, Final Batch Loss: 0.3421168923377991
Epoch 2849, Loss: 1.9013644754886627, Final Batch Loss: 0.32705339789390564
Epoch 2850, Loss: 2.0987802743911743, Final Batch Loss: 0.5302807688713074
Epoch 2851, Loss: 1.7153591215610504, Final Batch Loss: 0.21597817540168762
Epoch 2852, Loss: 1.7678442299365997, Final Batch Loss: 0.24972951412200928
Epoch 2853, Loss: 2.1921893656253815, Final Batch Loss: 0.6607053875923157
Epoch 2854, Loss: 1.692915752530098, Final Batch Loss: 0.1694924384355545
Epoch 2855, Loss: 2.000639945268631, Final Batch Loss: 0.49123474955558777
Epoch 2856, Loss: 2.1576

Epoch 2956, Loss: 1.8276166021823883, Final Batch Loss: 0.27760374546051025
Epoch 2957, Loss: 2.0121428072452545, Final Batch Loss: 0.3800511360168457
Epoch 2958, Loss: 1.7571665942668915, Final Batch Loss: 0.2567586600780487
Epoch 2959, Loss: 1.6924197971820831, Final Batch Loss: 0.23538827896118164
Epoch 2960, Loss: 2.3071461617946625, Final Batch Loss: 0.7068365812301636
Epoch 2961, Loss: 1.6871298998594284, Final Batch Loss: 0.1936293989419937
Epoch 2962, Loss: 2.063665807247162, Final Batch Loss: 0.507523238658905
Epoch 2963, Loss: 1.7460215836763382, Final Batch Loss: 0.12541471421718597
Epoch 2964, Loss: 1.9672633707523346, Final Batch Loss: 0.2996692359447479
Epoch 2965, Loss: 2.3094770908355713, Final Batch Loss: 0.5622653961181641
Epoch 2966, Loss: 2.592919886112213, Final Batch Loss: 0.9798049926757812
Epoch 2967, Loss: 1.9269823729991913, Final Batch Loss: 0.22109660506248474
Epoch 2968, Loss: 2.1972790360450745, Final Batch Loss: 0.4788103699684143
Epoch 2969, Loss: 2.1141

Epoch 3066, Loss: 1.913848638534546, Final Batch Loss: 0.4008290469646454
Epoch 3067, Loss: 2.017664074897766, Final Batch Loss: 0.5104352235794067
Epoch 3068, Loss: 1.699044182896614, Final Batch Loss: 0.23165486752986908
Epoch 3069, Loss: 2.0398579835891724, Final Batch Loss: 0.5809992551803589
Epoch 3070, Loss: 1.955028623342514, Final Batch Loss: 0.3633013367652893
Epoch 3071, Loss: 1.8341506719589233, Final Batch Loss: 0.48550543189048767
Epoch 3072, Loss: 2.068960815668106, Final Batch Loss: 0.5087495446205139
Epoch 3073, Loss: 1.9698545038700104, Final Batch Loss: 0.37966641783714294
Epoch 3074, Loss: 1.6954728066921234, Final Batch Loss: 0.20183420181274414
Epoch 3075, Loss: 1.8723827600479126, Final Batch Loss: 0.3878819942474365
Epoch 3076, Loss: 2.3462518751621246, Final Batch Loss: 0.8097214102745056
Epoch 3077, Loss: 1.5974483639001846, Final Batch Loss: 0.0640900582075119
Epoch 3078, Loss: 1.8349459171295166, Final Batch Loss: 0.4032290279865265
Epoch 3079, Loss: 1.927259

Epoch 3178, Loss: 1.8043113052845001, Final Batch Loss: 0.3485831320285797
Epoch 3179, Loss: 1.6348433792591095, Final Batch Loss: 0.17881730198860168
Epoch 3180, Loss: 1.6593983173370361, Final Batch Loss: 0.1954542100429535
Epoch 3181, Loss: 2.232485979795456, Final Batch Loss: 0.7112273573875427
Epoch 3182, Loss: 2.2709136307239532, Final Batch Loss: 0.8472275733947754
Epoch 3183, Loss: 1.7816099971532822, Final Batch Loss: 0.20964930951595306
Epoch 3184, Loss: 1.9379447400569916, Final Batch Loss: 0.40355372428894043
Epoch 3185, Loss: 2.279056668281555, Final Batch Loss: 0.7352291345596313
Epoch 3186, Loss: 1.6075390949845314, Final Batch Loss: 0.09710530191659927
Epoch 3187, Loss: 1.7990627735853195, Final Batch Loss: 0.19314272701740265
Epoch 3188, Loss: 2.087369292974472, Final Batch Loss: 0.48683005571365356
Epoch 3189, Loss: 2.0341930389404297, Final Batch Loss: 0.48554590344429016
Epoch 3190, Loss: 2.0936070382595062, Final Batch Loss: 0.5798554420471191
Epoch 3191, Loss: 1.6

Epoch 3290, Loss: 1.5994803309440613, Final Batch Loss: 0.286302387714386
Epoch 3291, Loss: 1.809081882238388, Final Batch Loss: 0.37922120094299316
Epoch 3292, Loss: 1.8869928121566772, Final Batch Loss: 0.40496304631233215
Epoch 3293, Loss: 1.8202517926692963, Final Batch Loss: 0.46998950839042664
Epoch 3294, Loss: 1.5241800248622894, Final Batch Loss: 0.18669018149375916
Epoch 3295, Loss: 1.7997027039527893, Final Batch Loss: 0.36932533979415894
Epoch 3296, Loss: 1.5891956388950348, Final Batch Loss: 0.17227014899253845
Epoch 3297, Loss: 1.6852195709943771, Final Batch Loss: 0.23001264035701752
Epoch 3298, Loss: 1.7759045958518982, Final Batch Loss: 0.33616238832473755
Epoch 3299, Loss: 1.7597108483314514, Final Batch Loss: 0.36964747309684753
Epoch 3300, Loss: 2.215262711048126, Final Batch Loss: 0.7207756638526917
Epoch 3301, Loss: 1.7070141732692719, Final Batch Loss: 0.3108905255794525
Epoch 3302, Loss: 1.7091746628284454, Final Batch Loss: 0.3240102231502533
Epoch 3303, Loss: 1

Epoch 3406, Loss: 1.5529136210680008, Final Batch Loss: 0.11108706891536713
Epoch 3407, Loss: 1.6708002388477325, Final Batch Loss: 0.30452674627304077
Epoch 3408, Loss: 1.504167065024376, Final Batch Loss: 0.11384634673595428
Epoch 3409, Loss: 1.5232724100351334, Final Batch Loss: 0.19484789669513702
Epoch 3410, Loss: 1.468019299209118, Final Batch Loss: 0.11488258093595505
Epoch 3411, Loss: 1.8623719960451126, Final Batch Loss: 0.505659282207489
Epoch 3412, Loss: 2.047182321548462, Final Batch Loss: 0.7257800102233887
Epoch 3413, Loss: 2.139722168445587, Final Batch Loss: 0.5198814868927002
Epoch 3414, Loss: 2.412121683359146, Final Batch Loss: 0.7148996591567993
Epoch 3415, Loss: 1.950114667415619, Final Batch Loss: 0.3401780426502228
Epoch 3416, Loss: 1.7845974825322628, Final Batch Loss: 0.06162377819418907
Epoch 3417, Loss: 1.9794232547283173, Final Batch Loss: 0.29853034019470215
Epoch 3418, Loss: 1.8386962413787842, Final Batch Loss: 0.3725271224975586
Epoch 3419, Loss: 2.04106

Epoch 3517, Loss: 1.671934962272644, Final Batch Loss: 0.17351174354553223
Epoch 3518, Loss: 1.6569068431854248, Final Batch Loss: 0.3665630519390106
Epoch 3519, Loss: 1.725468248128891, Final Batch Loss: 0.28214675188064575
Epoch 3520, Loss: 1.7329452335834503, Final Batch Loss: 0.44520166516304016
Epoch 3521, Loss: 1.9331556558609009, Final Batch Loss: 0.5483071804046631
Epoch 3522, Loss: 1.689322680234909, Final Batch Loss: 0.3036908209323883
Epoch 3523, Loss: 1.6770178973674774, Final Batch Loss: 0.3162488639354706
Epoch 3524, Loss: 2.010677456855774, Final Batch Loss: 0.644112229347229
Epoch 3525, Loss: 1.6994785219430923, Final Batch Loss: 0.18409554660320282
Epoch 3526, Loss: 1.8495438992977142, Final Batch Loss: 0.37809106707572937
Epoch 3527, Loss: 2.155876874923706, Final Batch Loss: 0.7304315567016602
Epoch 3528, Loss: 1.988464117050171, Final Batch Loss: 0.5361183881759644
Epoch 3529, Loss: 1.5970681607723236, Final Batch Loss: 0.12461787462234497
Epoch 3530, Loss: 1.463821

Epoch 3633, Loss: 1.7527889013290405, Final Batch Loss: 0.4050881266593933
Epoch 3634, Loss: 1.9148178696632385, Final Batch Loss: 0.37129661440849304
Epoch 3635, Loss: 2.3201206028461456, Final Batch Loss: 0.948304295539856
Epoch 3636, Loss: 1.6397092044353485, Final Batch Loss: 0.3524138629436493
Epoch 3637, Loss: 1.6369935013353825, Final Batch Loss: 0.040164265781641006
Epoch 3638, Loss: 1.7414373010396957, Final Batch Loss: 0.1881394237279892
Epoch 3639, Loss: 2.19612455368042, Final Batch Loss: 0.6193372011184692
Epoch 3640, Loss: 1.5106603801250458, Final Batch Loss: 0.0860406756401062
Epoch 3641, Loss: 1.7739792466163635, Final Batch Loss: 0.2586461901664734
Epoch 3642, Loss: 1.7410172820091248, Final Batch Loss: 0.3410632312297821
Epoch 3643, Loss: 1.7845811545848846, Final Batch Loss: 0.28082576394081116
Epoch 3644, Loss: 1.4455076456069946, Final Batch Loss: 0.08509629964828491
Epoch 3645, Loss: 1.7213121950626373, Final Batch Loss: 0.2603656053543091
Epoch 3646, Loss: 1.483

Epoch 3744, Loss: 1.4363694936037064, Final Batch Loss: 0.21154595911502838
Epoch 3745, Loss: 1.717183381319046, Final Batch Loss: 0.38539496064186096
Epoch 3746, Loss: 1.6821835339069366, Final Batch Loss: 0.38004782795906067
Epoch 3747, Loss: 1.960587352514267, Final Batch Loss: 0.6236698031425476
Epoch 3748, Loss: 1.8179587721824646, Final Batch Loss: 0.38226228952407837
Epoch 3749, Loss: 1.528958447277546, Final Batch Loss: 0.12397242337465286
Epoch 3750, Loss: 1.874734103679657, Final Batch Loss: 0.4364781081676483
Epoch 3751, Loss: 1.666137009859085, Final Batch Loss: 0.2782686948776245
Epoch 3752, Loss: 1.7061849534511566, Final Batch Loss: 0.27047714591026306
Epoch 3753, Loss: 1.565783068537712, Final Batch Loss: 0.23176248371601105
Epoch 3754, Loss: 1.6606099605560303, Final Batch Loss: 0.4071520268917084
Epoch 3755, Loss: 1.5917856395244598, Final Batch Loss: 0.14830851554870605
Epoch 3756, Loss: 1.6463274359703064, Final Batch Loss: 0.3251475691795349
Epoch 3757, Loss: 1.484

Epoch 3858, Loss: 1.5015320628881454, Final Batch Loss: 0.22277326881885529
Epoch 3859, Loss: 1.4930284172296524, Final Batch Loss: 0.21899618208408356
Epoch 3860, Loss: 1.5767284333705902, Final Batch Loss: 0.20726454257965088
Epoch 3861, Loss: 1.8905069828033447, Final Batch Loss: 0.6252564787864685
Epoch 3862, Loss: 1.5859998762607574, Final Batch Loss: 0.3324475884437561
Epoch 3863, Loss: 1.5251133441925049, Final Batch Loss: 0.21928462386131287
Epoch 3864, Loss: 1.7591766715049744, Final Batch Loss: 0.42466649413108826
Epoch 3865, Loss: 1.5298977345228195, Final Batch Loss: 0.20096419751644135
Epoch 3866, Loss: 2.060969591140747, Final Batch Loss: 0.7155415415763855
Epoch 3867, Loss: 1.5220509767532349, Final Batch Loss: 0.1749386489391327
Epoch 3868, Loss: 1.8312151432037354, Final Batch Loss: 0.34873777627944946
Epoch 3869, Loss: 1.785616785287857, Final Batch Loss: 0.4034174084663391
Epoch 3870, Loss: 1.5612440556287766, Final Batch Loss: 0.22070665657520294
Epoch 3871, Loss: 1

Epoch 3973, Loss: 1.6574046313762665, Final Batch Loss: 0.33096012473106384
Epoch 3974, Loss: 1.5846691131591797, Final Batch Loss: 0.28346604108810425
Epoch 3975, Loss: 1.514073833823204, Final Batch Loss: 0.15289051830768585
Epoch 3976, Loss: 1.609542340040207, Final Batch Loss: 0.350526362657547
Epoch 3977, Loss: 1.953877717256546, Final Batch Loss: 0.6250749826431274
Epoch 3978, Loss: 2.225178509950638, Final Batch Loss: 0.8146029114723206
Epoch 3979, Loss: 1.636960282921791, Final Batch Loss: 0.23862190544605255
Epoch 3980, Loss: 1.7679249793291092, Final Batch Loss: 0.23112885653972626
Epoch 3981, Loss: 1.9601770043373108, Final Batch Loss: 0.48511403799057007
Epoch 3982, Loss: 1.7260498702526093, Final Batch Loss: 0.4421728253364563
Epoch 3983, Loss: 1.5501283258199692, Final Batch Loss: 0.16529031097888947
Epoch 3984, Loss: 1.6832843124866486, Final Batch Loss: 0.4126870930194855
Epoch 3985, Loss: 1.9652334451675415, Final Batch Loss: 0.5490394830703735
Epoch 3986, Loss: 1.6527

Epoch 4082, Loss: 1.6543003022670746, Final Batch Loss: 0.19683316349983215
Epoch 4083, Loss: 1.7222935557365417, Final Batch Loss: 0.41866418719291687
Epoch 4084, Loss: 1.740930199623108, Final Batch Loss: 0.13821235299110413
Epoch 4085, Loss: 2.084661513566971, Final Batch Loss: 0.6730309128761292
Epoch 4086, Loss: 1.6047578006982803, Final Batch Loss: 0.17505042254924774
Epoch 4087, Loss: 1.9504174292087555, Final Batch Loss: 0.5134692788124084
Epoch 4088, Loss: 2.006587117910385, Final Batch Loss: 0.5871421694755554
Epoch 4089, Loss: 2.272611141204834, Final Batch Loss: 0.7930352091789246
Epoch 4090, Loss: 1.5971624106168747, Final Batch Loss: 0.1839151233434677
Epoch 4091, Loss: 1.9365494847297668, Final Batch Loss: 0.3844606280326843
Epoch 4092, Loss: 1.653414636850357, Final Batch Loss: 0.3625578284263611
Epoch 4093, Loss: 1.7499569952487946, Final Batch Loss: 0.34891775250434875
Epoch 4094, Loss: 2.077423334121704, Final Batch Loss: 0.7355822920799255
Epoch 4095, Loss: 2.119643

Epoch 4195, Loss: 1.3979288786649704, Final Batch Loss: 0.1636226922273636
Epoch 4196, Loss: 1.5689540952444077, Final Batch Loss: 0.21888889372348785
Epoch 4197, Loss: 1.41623155772686, Final Batch Loss: 0.18055473268032074
Epoch 4198, Loss: 1.7864137589931488, Final Batch Loss: 0.4173758029937744
Epoch 4199, Loss: 1.686057686805725, Final Batch Loss: 0.34601807594299316
Epoch 4200, Loss: 1.4765211045742035, Final Batch Loss: 0.30324456095695496
Epoch 4201, Loss: 1.5502971857786179, Final Batch Loss: 0.3581823706626892
Epoch 4202, Loss: 1.708988606929779, Final Batch Loss: 0.45083457231521606
Epoch 4203, Loss: 1.4056779593229294, Final Batch Loss: 0.1764795035123825
Epoch 4204, Loss: 1.5087740421295166, Final Batch Loss: 0.16146329045295715
Epoch 4205, Loss: 1.6190583258867264, Final Batch Loss: 0.19398237764835358
Epoch 4206, Loss: 1.4471477419137955, Final Batch Loss: 0.1745065599679947
Epoch 4207, Loss: 1.43008291721344, Final Batch Loss: 0.18739959597587585
Epoch 4208, Loss: 1.639

Epoch 4311, Loss: 1.6572920978069305, Final Batch Loss: 0.32837316393852234
Epoch 4312, Loss: 1.6261481940746307, Final Batch Loss: 0.33751222491264343
Epoch 4313, Loss: 1.508090615272522, Final Batch Loss: 0.08714157342910767
Epoch 4314, Loss: 1.459261417388916, Final Batch Loss: 0.053846269845962524
Epoch 4315, Loss: 1.8997796773910522, Final Batch Loss: 0.4724533259868622
Epoch 4316, Loss: 1.6049737483263016, Final Batch Loss: 0.23729531466960907
Epoch 4317, Loss: 1.4118899255990982, Final Batch Loss: 0.06794138252735138
Epoch 4318, Loss: 1.660523772239685, Final Batch Loss: 0.31302008032798767
Epoch 4319, Loss: 1.80455881357193, Final Batch Loss: 0.46969208121299744
Epoch 4320, Loss: 1.493957556784153, Final Batch Loss: 0.10937399417161942
Epoch 4321, Loss: 1.7044222950935364, Final Batch Loss: 0.3442240357398987
Epoch 4322, Loss: 1.805236965417862, Final Batch Loss: 0.46362432837486267
Epoch 4323, Loss: 1.6059061586856842, Final Batch Loss: 0.33018049597740173
Epoch 4324, Loss: 1.

Epoch 4421, Loss: 1.4354041442275047, Final Batch Loss: 0.0996432825922966
Epoch 4422, Loss: 1.7181451916694641, Final Batch Loss: 0.39370378851890564
Epoch 4423, Loss: 1.6870754063129425, Final Batch Loss: 0.3942279517650604
Epoch 4424, Loss: 1.6047371625900269, Final Batch Loss: 0.4109848141670227
Epoch 4425, Loss: 1.485688254237175, Final Batch Loss: 0.1513158529996872
Epoch 4426, Loss: 1.357923462986946, Final Batch Loss: 0.08818657696247101
Epoch 4427, Loss: 1.4033411741256714, Final Batch Loss: 0.16484618186950684
Epoch 4428, Loss: 1.7092696577310562, Final Batch Loss: 0.4103213846683502
Epoch 4429, Loss: 1.6075778901576996, Final Batch Loss: 0.32805970311164856
Epoch 4430, Loss: 1.74995955824852, Final Batch Loss: 0.2725980877876282
Epoch 4431, Loss: 1.5225881040096283, Final Batch Loss: 0.18142366409301758
Epoch 4432, Loss: 2.2146187722682953, Final Batch Loss: 0.901379406452179
Epoch 4433, Loss: 1.4258845746517181, Final Batch Loss: 0.1891176402568817
Epoch 4434, Loss: 1.56158

Epoch 4537, Loss: 1.7959008812904358, Final Batch Loss: 0.6057360768318176
Epoch 4538, Loss: 1.9117180407047272, Final Batch Loss: 0.6859598755836487
Epoch 4539, Loss: 1.2167630270123482, Final Batch Loss: 0.06574303656816483
Epoch 4540, Loss: 1.4526012241840363, Final Batch Loss: 0.20804288983345032
Epoch 4541, Loss: 1.344199150800705, Final Batch Loss: 0.11685365438461304
Epoch 4542, Loss: 1.3385108970105648, Final Batch Loss: 0.051686253398656845
Epoch 4543, Loss: 1.6538376212120056, Final Batch Loss: 0.39232465624809265
Epoch 4544, Loss: 1.7836168706417084, Final Batch Loss: 0.4495803415775299
Epoch 4545, Loss: 1.3354259729385376, Final Batch Loss: 0.13686522841453552
Epoch 4546, Loss: 1.4401096850633621, Final Batch Loss: 0.23881883919239044
Epoch 4547, Loss: 1.5220265686511993, Final Batch Loss: 0.24050727486610413
Epoch 4548, Loss: 1.7252914309501648, Final Batch Loss: 0.5042794346809387
Epoch 4549, Loss: 1.797893911600113, Final Batch Loss: 0.5371542572975159
Epoch 4550, Loss: 

Epoch 4652, Loss: 1.4926857650279999, Final Batch Loss: 0.31707528233528137
Epoch 4653, Loss: 1.8978899419307709, Final Batch Loss: 0.6415185928344727
Epoch 4654, Loss: 1.6471637785434723, Final Batch Loss: 0.3913722038269043
Epoch 4655, Loss: 1.5062413439154625, Final Batch Loss: 0.09036274999380112
Epoch 4656, Loss: 1.3932485282421112, Final Batch Loss: 0.049750059843063354
Epoch 4657, Loss: 1.6235210001468658, Final Batch Loss: 0.39148879051208496
Epoch 4658, Loss: 1.4918900281190872, Final Batch Loss: 0.24011816084384918
Epoch 4659, Loss: 1.70378777384758, Final Batch Loss: 0.42004913091659546
Epoch 4660, Loss: 1.8739351034164429, Final Batch Loss: 0.4563877284526825
Epoch 4661, Loss: 1.4624264985322952, Final Batch Loss: 0.09192900359630585
Epoch 4662, Loss: 1.7628569304943085, Final Batch Loss: 0.38380399346351624
Epoch 4663, Loss: 1.360618695616722, Final Batch Loss: 0.09064634144306183
Epoch 4664, Loss: 1.5213534533977509, Final Batch Loss: 0.17719149589538574
Epoch 4665, Loss:

Epoch 4764, Loss: 1.3181604892015457, Final Batch Loss: 0.08824069797992706
Epoch 4765, Loss: 1.6328075230121613, Final Batch Loss: 0.3384116590023041
Epoch 4766, Loss: 1.4635802954435349, Final Batch Loss: 0.1685890406370163
Epoch 4767, Loss: 1.6410718858242035, Final Batch Loss: 0.376504123210907
Epoch 4768, Loss: 1.6540104150772095, Final Batch Loss: 0.3832504451274872
Epoch 4769, Loss: 1.3623510748147964, Final Batch Loss: 0.18857689201831818
Epoch 4770, Loss: 1.3561007231473923, Final Batch Loss: 0.11581166088581085
Epoch 4771, Loss: 1.560865968465805, Final Batch Loss: 0.31873688101768494
Epoch 4772, Loss: 1.549048274755478, Final Batch Loss: 0.2719821333885193
Epoch 4773, Loss: 1.8771938979625702, Final Batch Loss: 0.5844098329544067
Epoch 4774, Loss: 1.6636315286159515, Final Batch Loss: 0.3811592161655426
Epoch 4775, Loss: 1.757193773984909, Final Batch Loss: 0.5177679657936096
Epoch 4776, Loss: 1.5711308717727661, Final Batch Loss: 0.35538095235824585
Epoch 4777, Loss: 1.6290

Epoch 4875, Loss: 1.4723033010959625, Final Batch Loss: 0.13981953263282776
Epoch 4876, Loss: 1.5574347078800201, Final Batch Loss: 0.27486488223075867
Epoch 4877, Loss: 1.7846229374408722, Final Batch Loss: 0.48560264706611633
Epoch 4878, Loss: 1.5859536230564117, Final Batch Loss: 0.3776788115501404
Epoch 4879, Loss: 1.5126896500587463, Final Batch Loss: 0.28425973653793335
Epoch 4880, Loss: 1.4493619501590729, Final Batch Loss: 0.20003843307495117
Epoch 4881, Loss: 1.5925387144088745, Final Batch Loss: 0.27369818091392517
Epoch 4882, Loss: 1.329044371843338, Final Batch Loss: 0.17108884453773499
Epoch 4883, Loss: 1.5584344863891602, Final Batch Loss: 0.3970637917518616
Epoch 4884, Loss: 1.4926426708698273, Final Batch Loss: 0.3132566809654236
Epoch 4885, Loss: 1.2397691532969475, Final Batch Loss: 0.07684733718633652
Epoch 4886, Loss: 1.526261329650879, Final Batch Loss: 0.3180946409702301
Epoch 4887, Loss: 1.6051425337791443, Final Batch Loss: 0.29118695855140686
Epoch 4888, Loss: 

Epoch 4991, Loss: 1.6779500246047974, Final Batch Loss: 0.40050381422042847
Epoch 4992, Loss: 1.5949321389198303, Final Batch Loss: 0.3211595118045807
Epoch 4993, Loss: 1.3985361903905869, Final Batch Loss: 0.17456196248531342
Epoch 4994, Loss: 1.6181000769138336, Final Batch Loss: 0.3325694501399994
Epoch 4995, Loss: 1.4676940329372883, Final Batch Loss: 0.058357227593660355
Epoch 4996, Loss: 1.3499031066894531, Final Batch Loss: 0.08158093690872192
Epoch 4997, Loss: 1.4322989135980606, Final Batch Loss: 0.07290883362293243
Epoch 4998, Loss: 1.5558934211730957, Final Batch Loss: 0.34720054268836975
Epoch 4999, Loss: 1.3225022703409195, Final Batch Loss: 0.15284039080142975
Epoch 5000, Loss: 1.3336227014660835, Final Batch Loss: 0.0535629466176033


In [9]:
softmax = nn.Softmax(dim = 1)
model.train()
for batch in test_loader:
    features, labels = batch
    _, preds = torch.max(softmax(model(features.float())), dim = 1)
    print(metrics.confusion_matrix((labels).cpu(), preds.cpu()))
    print(metrics.classification_report((labels).cpu(), preds.cpu(), digits = 5))

[[16  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  9  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  5  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0 11  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0 10  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  3  0  0  1  0  0  0  0  0  2  0  0  0  0  0  1  0  0  1]
 [ 0  0  0  0  0  0 12  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0 11  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  4  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0 12  0  0  1  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  1 12  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  1 10  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 1  2  0  0  0  0  0  0  0  0  0  0  9  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0

# Train on Fake Test on Real

In [10]:
# 20,000 epochs DEFAULT PARAMETERS FOR THRESHOLDS

In [11]:
gen = Generator(z_dim = 111)
load_model(gen, "3 Label 8 Subject GAN Ablation_gen.param")
gen.eval()

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): Linear(in_features=111, out_features=80, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Linear(in_features=80, out_features=60, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Linear(in_features=60, out_features=50, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (3): Linear(in_features=50, out_features=32, bias=True)
    (4): Tanh()
  )
)

In [30]:
size = len(X_test)
latent_vectors = get_noise(size, 100)
act_vectors = get_act_matrix(size, 3)
usr_vectors = get_usr_matrix(size, 8)

fake_labels = []

for k in range(size):
    if act_vectors[0][k] == 0 and usr_vectors[0][k] == 0:
        fake_labels.append(0)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 0:
        fake_labels.append(1)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 0:
        fake_labels.append(2)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 1:
        fake_labels.append(3)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 1:
        fake_labels.append(4)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 1:
        fake_labels.append(5)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 2:
        fake_labels.append(6)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 2:
        fake_labels.append(7)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 2:
        fake_labels.append(8)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 3:
        fake_labels.append(9)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 3:
        fake_labels.append(10)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 3:
        fake_labels.append(11)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 4:
        fake_labels.append(12)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 4:
        fake_labels.append(13)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 4:
        fake_labels.append(14)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 5:
        fake_labels.append(15)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 5:
        fake_labels.append(16)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 5:
        fake_labels.append(17)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 6:
        fake_labels.append(18)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 6:
        fake_labels.append(19)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 6:
        fake_labels.append(20)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 7:
        fake_labels.append(21)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 7:
        fake_labels.append(22)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 7:
        fake_labels.append(23)
        
fake_labels = np.asarray(fake_labels)
to_gen = torch.cat((latent_vectors, act_vectors[1], usr_vectors[1]), 1)
fake_features = gen(to_gen).detach()
#fake_features = gen(to_gen).detach().numpy()

In [31]:
_, preds = torch.max(softmax(model(fake_features.float())), dim = 1)
print(metrics.confusion_matrix((fake_labels), preds.cpu()))
print(metrics.classification_report((fake_labels), preds.cpu(), digits = 5, zero_division = 0))

[[10  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  8  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0 11  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  7  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0 14  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  1  0  0  4  0  0  0  0  0  2  0  0  0  0  0  1  0  0  1]
 [ 0  0  0  0  0  0 13  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0]
 [ 0  0  0  0  0  0  0 10  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  8  0  0  0  0  0  0  0  0  0  0  0  1]
 [ 0  0  0  0  0  0  0  0  0  6  0  1  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  1 10  0  0  1  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  8  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  5  0  0  0  0  0]
 [ 0  2  0  0  0  0  0  0