In [1]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn import metrics
import torch.optim as optim

In [2]:
sub_features = ['42 tGravityAcc-mean()-Y',
 '43 tGravityAcc-mean()-Z',
 '51 tGravityAcc-max()-Y',
 '52 tGravityAcc-max()-Z',
 '54 tGravityAcc-min()-Y',
 '55 tGravityAcc-min()-Z',
 '56 tGravityAcc-sma()',
 '58 tGravityAcc-energy()-Y',
 '59 tGravityAcc-energy()-Z',
 '475 fBodyGyro-bandsEnergy()-1,8',
 '483 fBodyGyro-bandsEnergy()-1,16',
 '559 angle(X,gravityMean)',
 '560 angle(Y,gravityMean)',
 '561 angle(Z,gravityMean)']

act_features = ['4 tBodyAcc-std()-X',
 '10 tBodyAcc-max()-X',
 '17 tBodyAcc-energy()-X',
 '90 tBodyAccJerk-max()-X',
 '202 tBodyAccMag-std()',
 '203 tBodyAccMag-mad()',
 '215 tGravityAccMag-std()',
 '216 tGravityAccMag-mad()',
 '266 fBodyAcc-mean()-X',
 '269 fBodyAcc-std()-X',
 '272 fBodyAcc-mad()-X',
 '282 fBodyAcc-energy()-X',
 '303 fBodyAcc-bandsEnergy()-1,8',
 '311 fBodyAcc-bandsEnergy()-1,16',
 '315 fBodyAcc-bandsEnergy()-1,24',
 '382 fBodyAccJerk-bandsEnergy()-1,8',
 '504 fBodyAccMag-std()',
 '505 fBodyAccMag-mad()',
 '509 fBodyAccMag-energy()']

input_shape = len(sub_features) + len(act_features)

In [3]:
def classifier_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.LeakyReLU(0.05)
    )

class Classifier(nn.Module):
    def __init__(self, feature_dim = input_shape):
        super(Classifier, self).__init__()
        self.network = nn.Sequential(
            classifier_block(feature_dim, 40),
            classifier_block(40, 30),
            classifier_block(30, 25),
            classifier_block(25, 20),
            nn.Linear(20, 21)
        )
    def forward(self, x):
        return self.network(x)

In [4]:
#defines each generator layer
#input and output dimensions needed
def generator_block(input_dim, output_dim):
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace = True)
    )

#returns n_samples of z_dim (number of dimensions of latent space) noise
def get_noise(n_samples, z_dim):
    return torch.randn(n_samples, z_dim)

#defines generator class
class Generator(nn.Module):
    def __init__(self, z_dim = 10, feature_dim = input_shape, hidden_dim = 128):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            generator_block(z_dim, 80),
            generator_block(80, 60),
            generator_block(60, 50),
            nn.Linear(50, feature_dim),
            nn.Tanh()
        )
    def forward(self, noise):
        return self.gen(noise)

def get_act_matrix(batch_size, a_dim):
    indexes = np.random.randint(a_dim, size = batch_size)
    
    one_hot = np.zeros((len(indexes), indexes.max()+1))
    one_hot[np.arange(len(indexes)),indexes] = 1
    return torch.Tensor(indexes).long(), torch.Tensor(one_hot)
    
def get_usr_matrix(batch_size, u_dim):
    indexes = np.random.randint(u_dim, size = batch_size)
    
    one_hot = np.zeros((indexes.size, indexes.max()+1))
    one_hot[np.arange(indexes.size),indexes] = 1
    return torch.Tensor(indexes).long(), torch.Tensor(one_hot)

def load_model(model, model_name):
    model.load_state_dict(torch.load(f'../../../saved_models/{model_name}'))

# Train on Real Test on Real

In [5]:
#label is a list of integers specifying which labels to filter by
#users is a list of integers specifying which users to filter by
#y_label is a string, either "Activity" or "Subject" depending on what y output needs to be returned
def start_data(label, users, y_label, sub_features, act_features):
    #get the dataframe column names
    name_dataframe = pd.read_csv('../../../data/features.txt', delimiter = '\n', header = None)
    names = name_dataframe.values.tolist()
    names = [k for row in names for k in row] #List of column names

    data = pd.read_csv('../../../data/X_train.txt', delim_whitespace = True, header = None) #Read in dataframe
    data.columns = names #Setting column names
    
    X_train_1 = data[sub_features]
    X_train_2 = data[act_features]
    X_train = pd.concat([X_train_1, X_train_2], axis = 1)
    
    y_train_activity = pd.read_csv('../../../data/y_train.txt', header = None)
    y_train_activity.columns = ['Activity']
    
    y_train_subject = pd.read_csv('../../../data/subject_train.txt', header = None)
    y_train_subject.columns = ['Subject']
    
    GAN_data = pd.concat([X_train, y_train_activity, y_train_subject], axis = 1)
    GAN_data = GAN_data[GAN_data['Activity'].isin(label)]
    GAN_data = GAN_data[GAN_data['Subject'].isin(users)]
    
    X_1 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_2 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_3 = GAN_data[(GAN_data['Subject'] == 1) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_4 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_5 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_6 = GAN_data[(GAN_data['Subject'] == 3) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_7 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_8 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_9 = GAN_data[(GAN_data['Subject'] == 5) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_10 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_11 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_12 = GAN_data[(GAN_data['Subject'] == 7) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_13 = GAN_data[(GAN_data['Subject'] == 8) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_14 = GAN_data[(GAN_data['Subject'] == 8) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_15 = GAN_data[(GAN_data['Subject'] == 8) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_16 = GAN_data[(GAN_data['Subject'] == 11) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_17 = GAN_data[(GAN_data['Subject'] == 11) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_18 = GAN_data[(GAN_data['Subject'] == 11) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    X_19 = GAN_data[(GAN_data['Subject'] == 14) & (GAN_data['Activity'] == 1)].iloc[:,:-2].values
    X_20 = GAN_data[(GAN_data['Subject'] == 14) & (GAN_data['Activity'] == 3)].iloc[:,:-2].values
    X_21 = GAN_data[(GAN_data['Subject'] == 14) & (GAN_data['Activity'] == 4)].iloc[:,:-2].values
    
    X_train = np.concatenate((X_1, X_2, X_3, X_4, X_5, X_6, X_7, X_8, X_9, X_10, X_11, X_12, X_13, X_14, X_15, X_16, X_17, X_18, X_19, X_20, X_21))
    y_train = [0] * len(X_1) + [1] * len(X_2) + [2] * len(X_3) + [3] * len(X_4) + [4] * len(X_5) + [5] * len(X_6) + [6] * len(X_7) + [7] * len(X_8) + [8] * len(X_9) + [9] * len(X_10) + [10] * len(X_11) + [11] * len(X_12) + [12] * len(X_13) + [13] * len(X_14) + [14] * len(X_15) + [15] * len(X_16) + [16] * len(X_17) + [17] * len(X_18) + [18] * len(X_19) + [19] * len(X_20) + [20] * len(X_21)
    
    return X_train, np.asarray(y_train)

In [6]:
activities = [1, 3, 4]
users = [1, 3, 5, 7, 8, 11, 14]

X, y = start_data(activities, users, "Activity", sub_features, act_features)

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, shuffle = True)

model = Classifier()
lr = 0.001
n_epochs = 5000
batch_size = 250

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

train_features = torch.tensor(X_train)
train_labels = torch.tensor(y_train)
test_features = torch.tensor(X_test)
test_labels = torch.tensor(y_test)

train_data = torch.utils.data.TensorDataset(train_features, train_labels)
test_data = torch.utils.data.TensorDataset(test_features, test_labels)

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = len(test_labels), shuffle = True)

In [8]:
for epoch in range(n_epochs):
    total_loss = 0
    for batch in train_loader:
        features, labels = batch
        
        optimizer.zero_grad()
        preds = model(features.float())
        
        loss = criterion(preds, labels.long()) 
        loss.backward()
        
        optimizer.step()
        total_loss += loss.item()
        
    print(f'Epoch {epoch + 1}, Loss: {total_loss}, Final Batch Loss: {loss.item()}')

Epoch 1, Loss: 12.233332395553589, Final Batch Loss: 3.0618832111358643
Epoch 2, Loss: 12.213953971862793, Final Batch Loss: 3.0347023010253906
Epoch 3, Loss: 12.213301181793213, Final Batch Loss: 3.047750949859619
Epoch 4, Loss: 12.19971489906311, Final Batch Loss: 3.0395970344543457
Epoch 5, Loss: 12.189426183700562, Final Batch Loss: 3.0559213161468506
Epoch 6, Loss: 12.165900945663452, Final Batch Loss: 3.048717737197876
Epoch 7, Loss: 12.13656210899353, Final Batch Loss: 3.037327527999878
Epoch 8, Loss: 12.075984716415405, Final Batch Loss: 2.9997315406799316
Epoch 9, Loss: 12.004512310028076, Final Batch Loss: 2.9855034351348877
Epoch 10, Loss: 11.918658971786499, Final Batch Loss: 2.9668917655944824
Epoch 11, Loss: 11.808464050292969, Final Batch Loss: 2.948636054992676
Epoch 12, Loss: 11.671109437942505, Final Batch Loss: 2.9222023487091064
Epoch 13, Loss: 11.510560750961304, Final Batch Loss: 2.8589224815368652
Epoch 14, Loss: 11.307159900665283, Final Batch Loss: 2.8106188774

Epoch 118, Loss: 4.698841691017151, Final Batch Loss: 1.1599957942962646
Epoch 119, Loss: 4.677254557609558, Final Batch Loss: 1.112173080444336
Epoch 120, Loss: 4.6663841009140015, Final Batch Loss: 1.154649019241333
Epoch 121, Loss: 4.712216258049011, Final Batch Loss: 1.2512235641479492
Epoch 122, Loss: 4.568700432777405, Final Batch Loss: 1.09285569190979
Epoch 123, Loss: 4.531087875366211, Final Batch Loss: 1.0430294275283813
Epoch 124, Loss: 4.609476923942566, Final Batch Loss: 1.1612225770950317
Epoch 125, Loss: 4.297021567821503, Final Batch Loss: 0.9713675379753113
Epoch 126, Loss: 4.484549283981323, Final Batch Loss: 1.1157035827636719
Epoch 127, Loss: 4.626560211181641, Final Batch Loss: 1.2059450149536133
Epoch 128, Loss: 4.420087456703186, Final Batch Loss: 1.0262075662612915
Epoch 129, Loss: 4.590576887130737, Final Batch Loss: 1.086307406425476
Epoch 130, Loss: 4.46967077255249, Final Batch Loss: 1.009466528892517
Epoch 131, Loss: 4.458341002464294, Final Batch Loss: 1.1

Epoch 236, Loss: 3.5752620100975037, Final Batch Loss: 0.834503710269928
Epoch 237, Loss: 3.4958959221839905, Final Batch Loss: 0.8297495245933533
Epoch 238, Loss: 3.5615339279174805, Final Batch Loss: 0.9187162518501282
Epoch 239, Loss: 3.4417473673820496, Final Batch Loss: 0.7822977304458618
Epoch 240, Loss: 3.5136115550994873, Final Batch Loss: 0.9147074818611145
Epoch 241, Loss: 3.5589849948883057, Final Batch Loss: 0.87595134973526
Epoch 242, Loss: 3.582333207130432, Final Batch Loss: 0.9950692057609558
Epoch 243, Loss: 3.5375319123268127, Final Batch Loss: 0.926461398601532
Epoch 244, Loss: 3.4872779846191406, Final Batch Loss: 0.9068548679351807
Epoch 245, Loss: 3.4987675547599792, Final Batch Loss: 0.8405829668045044
Epoch 246, Loss: 3.425139904022217, Final Batch Loss: 0.7193805575370789
Epoch 247, Loss: 3.609040379524231, Final Batch Loss: 1.044350266456604
Epoch 248, Loss: 3.4547483921051025, Final Batch Loss: 0.8686450719833374
Epoch 249, Loss: 3.3499684929847717, Final Bat

Epoch 359, Loss: 2.912096679210663, Final Batch Loss: 0.6039761900901794
Epoch 360, Loss: 2.9687530994415283, Final Batch Loss: 0.8059918880462646
Epoch 361, Loss: 3.047074556350708, Final Batch Loss: 0.7167794704437256
Epoch 362, Loss: 2.916670560836792, Final Batch Loss: 0.6941214203834534
Epoch 363, Loss: 2.9280834197998047, Final Batch Loss: 0.8540814518928528
Epoch 364, Loss: 2.8668729662895203, Final Batch Loss: 0.7752936482429504
Epoch 365, Loss: 2.8994336128234863, Final Batch Loss: 0.7381933927536011
Epoch 366, Loss: 2.894223213195801, Final Batch Loss: 0.7058090567588806
Epoch 367, Loss: 2.9976886510849, Final Batch Loss: 0.8636389374732971
Epoch 368, Loss: 2.8940106630325317, Final Batch Loss: 0.6122897267341614
Epoch 369, Loss: 2.823961079120636, Final Batch Loss: 0.7077443599700928
Epoch 370, Loss: 2.779483735561371, Final Batch Loss: 0.633307695388794
Epoch 371, Loss: 2.952241361141205, Final Batch Loss: 0.7157540321350098
Epoch 372, Loss: 2.797391951084137, Final Batch L

Epoch 480, Loss: 2.604950726032257, Final Batch Loss: 0.682703971862793
Epoch 481, Loss: 2.528316915035248, Final Batch Loss: 0.747729480266571
Epoch 482, Loss: 2.6874207258224487, Final Batch Loss: 0.6960917115211487
Epoch 483, Loss: 2.5010933876037598, Final Batch Loss: 0.7156448364257812
Epoch 484, Loss: 2.51448518037796, Final Batch Loss: 0.5594602823257446
Epoch 485, Loss: 2.573279321193695, Final Batch Loss: 0.5329327583312988
Epoch 486, Loss: 2.396823674440384, Final Batch Loss: 0.4650556147098541
Epoch 487, Loss: 2.5239654779434204, Final Batch Loss: 0.5148017406463623
Epoch 488, Loss: 2.596328854560852, Final Batch Loss: 0.8449313640594482
Epoch 489, Loss: 2.5615434646606445, Final Batch Loss: 0.599432110786438
Epoch 490, Loss: 2.5551674365997314, Final Batch Loss: 0.7122249007225037
Epoch 491, Loss: 2.4480600357055664, Final Batch Loss: 0.5687416791915894
Epoch 492, Loss: 2.400842010974884, Final Batch Loss: 0.6103988289833069
Epoch 493, Loss: 2.3949485421180725, Final Batch 

Epoch 593, Loss: 2.3362648487091064, Final Batch Loss: 0.6479887366294861
Epoch 594, Loss: 2.2089000940322876, Final Batch Loss: 0.5803834795951843
Epoch 595, Loss: 2.2883410155773163, Final Batch Loss: 0.6052349209785461
Epoch 596, Loss: 2.3689993619918823, Final Batch Loss: 0.6247851848602295
Epoch 597, Loss: 2.1770834624767303, Final Batch Loss: 0.4610605537891388
Epoch 598, Loss: 2.237056791782379, Final Batch Loss: 0.532462477684021
Epoch 599, Loss: 2.2794530987739563, Final Batch Loss: 0.5516117811203003
Epoch 600, Loss: 2.2316310703754425, Final Batch Loss: 0.4930417239665985
Epoch 601, Loss: 2.2294393181800842, Final Batch Loss: 0.5573912262916565
Epoch 602, Loss: 2.330112099647522, Final Batch Loss: 0.6500775814056396
Epoch 603, Loss: 2.2140761017799377, Final Batch Loss: 0.5752606987953186
Epoch 604, Loss: 2.236133873462677, Final Batch Loss: 0.5846093893051147
Epoch 605, Loss: 2.317681133747101, Final Batch Loss: 0.6823776960372925
Epoch 606, Loss: 2.3837989270687103, Final 

Epoch 709, Loss: 2.1349990367889404, Final Batch Loss: 0.5629603862762451
Epoch 710, Loss: 2.089241921901703, Final Batch Loss: 0.5360115170478821
Epoch 711, Loss: 2.0413854718208313, Final Batch Loss: 0.4893907308578491
Epoch 712, Loss: 2.1089271903038025, Final Batch Loss: 0.42911219596862793
Epoch 713, Loss: 2.0622985661029816, Final Batch Loss: 0.49544963240623474
Epoch 714, Loss: 1.9340402781963348, Final Batch Loss: 0.49511653184890747
Epoch 715, Loss: 2.0824490785598755, Final Batch Loss: 0.628005862236023
Epoch 716, Loss: 2.017549604177475, Final Batch Loss: 0.4751424491405487
Epoch 717, Loss: 1.951129287481308, Final Batch Loss: 0.4926178753376007
Epoch 718, Loss: 2.019603908061981, Final Batch Loss: 0.4563901126384735
Epoch 719, Loss: 2.14053612947464, Final Batch Loss: 0.6069933772087097
Epoch 720, Loss: 2.146846652030945, Final Batch Loss: 0.5611493587493896
Epoch 721, Loss: 2.0314164757728577, Final Batch Loss: 0.43914562463760376
Epoch 722, Loss: 2.135922372341156, Final 

Epoch 829, Loss: 1.9563463628292084, Final Batch Loss: 0.4543167054653168
Epoch 830, Loss: 1.8073405027389526, Final Batch Loss: 0.40083256363868713
Epoch 831, Loss: 1.8923964500427246, Final Batch Loss: 0.42097803950309753
Epoch 832, Loss: 1.830840289592743, Final Batch Loss: 0.47074002027511597
Epoch 833, Loss: 1.8360697627067566, Final Batch Loss: 0.46356329321861267
Epoch 834, Loss: 1.9790136218070984, Final Batch Loss: 0.4506000578403473
Epoch 835, Loss: 1.9132872819900513, Final Batch Loss: 0.46302345395088196
Epoch 836, Loss: 1.957951843738556, Final Batch Loss: 0.42697903513908386
Epoch 837, Loss: 1.934652417898178, Final Batch Loss: 0.4619993567466736
Epoch 838, Loss: 2.022989183664322, Final Batch Loss: 0.6305065751075745
Epoch 839, Loss: 1.9012122750282288, Final Batch Loss: 0.5561623573303223
Epoch 840, Loss: 1.9716505408287048, Final Batch Loss: 0.5752500295639038
Epoch 841, Loss: 1.8620086908340454, Final Batch Loss: 0.4136582612991333
Epoch 842, Loss: 1.9435869455337524,

Epoch 939, Loss: 1.7091864049434662, Final Batch Loss: 0.44734933972358704
Epoch 940, Loss: 1.8075524866580963, Final Batch Loss: 0.4466835856437683
Epoch 941, Loss: 1.9134099185466766, Final Batch Loss: 0.5314103364944458
Epoch 942, Loss: 1.7138963639736176, Final Batch Loss: 0.45292210578918457
Epoch 943, Loss: 1.7971899509429932, Final Batch Loss: 0.49270927906036377
Epoch 944, Loss: 1.8190406560897827, Final Batch Loss: 0.3603323698043823
Epoch 945, Loss: 1.7624720633029938, Final Batch Loss: 0.4366745948791504
Epoch 946, Loss: 1.8043701648712158, Final Batch Loss: 0.4545115828514099
Epoch 947, Loss: 1.9932920932769775, Final Batch Loss: 0.5708891153335571
Epoch 948, Loss: 1.7026067674160004, Final Batch Loss: 0.3895443081855774
Epoch 949, Loss: 1.8155658543109894, Final Batch Loss: 0.3674766421318054
Epoch 950, Loss: 1.7306508123874664, Final Batch Loss: 0.5097714066505432
Epoch 951, Loss: 1.8594937026500702, Final Batch Loss: 0.45180198550224304
Epoch 952, Loss: 1.750029623508453

Epoch 1054, Loss: 1.631069928407669, Final Batch Loss: 0.37769171595573425
Epoch 1055, Loss: 1.6989712119102478, Final Batch Loss: 0.3936980366706848
Epoch 1056, Loss: 1.654195874929428, Final Batch Loss: 0.4031747877597809
Epoch 1057, Loss: 1.5774421393871307, Final Batch Loss: 0.3942711353302002
Epoch 1058, Loss: 1.6944229900836945, Final Batch Loss: 0.40940892696380615
Epoch 1059, Loss: 1.7330299317836761, Final Batch Loss: 0.4498690962791443
Epoch 1060, Loss: 1.5354251861572266, Final Batch Loss: 0.36089393496513367
Epoch 1061, Loss: 1.7667548060417175, Final Batch Loss: 0.4260311424732208
Epoch 1062, Loss: 1.6723062694072723, Final Batch Loss: 0.384891152381897
Epoch 1063, Loss: 1.7443055510520935, Final Batch Loss: 0.4568955898284912
Epoch 1064, Loss: 1.7091373205184937, Final Batch Loss: 0.5090287923812866
Epoch 1065, Loss: 1.7265967428684235, Final Batch Loss: 0.378781259059906
Epoch 1066, Loss: 1.6970050930976868, Final Batch Loss: 0.37631258368492126
Epoch 1067, Loss: 1.68193

Epoch 1163, Loss: 1.5790298283100128, Final Batch Loss: 0.41334107518196106
Epoch 1164, Loss: 1.5244792103767395, Final Batch Loss: 0.3078263998031616
Epoch 1165, Loss: 1.6363479495048523, Final Batch Loss: 0.28856998682022095
Epoch 1166, Loss: 1.6462036967277527, Final Batch Loss: 0.5127888917922974
Epoch 1167, Loss: 1.5766706466674805, Final Batch Loss: 0.36233553290367126
Epoch 1168, Loss: 1.566380649805069, Final Batch Loss: 0.3895244300365448
Epoch 1169, Loss: 1.5973795056343079, Final Batch Loss: 0.4182776212692261
Epoch 1170, Loss: 1.5747813284397125, Final Batch Loss: 0.33371037244796753
Epoch 1171, Loss: 1.4617098420858383, Final Batch Loss: 0.2499818354845047
Epoch 1172, Loss: 1.5906071662902832, Final Batch Loss: 0.3963920474052429
Epoch 1173, Loss: 1.4628554582595825, Final Batch Loss: 0.3986073136329651
Epoch 1174, Loss: 1.645782232284546, Final Batch Loss: 0.40276971459388733
Epoch 1175, Loss: 1.5535263419151306, Final Batch Loss: 0.4161882996559143
Epoch 1176, Loss: 1.51

Epoch 1278, Loss: 1.508245974779129, Final Batch Loss: 0.3825157582759857
Epoch 1279, Loss: 1.499863177537918, Final Batch Loss: 0.4503103494644165
Epoch 1280, Loss: 1.6022680699825287, Final Batch Loss: 0.4393494725227356
Epoch 1281, Loss: 1.5988104939460754, Final Batch Loss: 0.4998704195022583
Epoch 1282, Loss: 1.419845312833786, Final Batch Loss: 0.2525591552257538
Epoch 1283, Loss: 1.4665622115135193, Final Batch Loss: 0.36172500252723694
Epoch 1284, Loss: 1.4329343438148499, Final Batch Loss: 0.3431859016418457
Epoch 1285, Loss: 1.46715047955513, Final Batch Loss: 0.33599844574928284
Epoch 1286, Loss: 1.452584445476532, Final Batch Loss: 0.3860221803188324
Epoch 1287, Loss: 1.419712781906128, Final Batch Loss: 0.3601854145526886
Epoch 1288, Loss: 1.5402754545211792, Final Batch Loss: 0.4438112676143646
Epoch 1289, Loss: 1.4384129047393799, Final Batch Loss: 0.3371511995792389
Epoch 1290, Loss: 1.649865835905075, Final Batch Loss: 0.3537713289260864
Epoch 1291, Loss: 1.47658896446

Epoch 1393, Loss: 1.2579872906208038, Final Batch Loss: 0.22807058691978455
Epoch 1394, Loss: 1.5054311752319336, Final Batch Loss: 0.3511565327644348
Epoch 1395, Loss: 1.3808437585830688, Final Batch Loss: 0.3548218309879303
Epoch 1396, Loss: 1.3963981568813324, Final Batch Loss: 0.3151536285877228
Epoch 1397, Loss: 1.4271415770053864, Final Batch Loss: 0.4462703466415405
Epoch 1398, Loss: 1.3550719618797302, Final Batch Loss: 0.32501667737960815
Epoch 1399, Loss: 1.375457912683487, Final Batch Loss: 0.30814090371131897
Epoch 1400, Loss: 1.4177883565425873, Final Batch Loss: 0.4195784628391266
Epoch 1401, Loss: 1.3363745510578156, Final Batch Loss: 0.3733515739440918
Epoch 1402, Loss: 1.4081531167030334, Final Batch Loss: 0.3626898527145386
Epoch 1403, Loss: 1.2037738263607025, Final Batch Loss: 0.2673141360282898
Epoch 1404, Loss: 1.4174632728099823, Final Batch Loss: 0.29176247119903564
Epoch 1405, Loss: 1.5487779676914215, Final Batch Loss: 0.4239804148674011
Epoch 1406, Loss: 1.39

Epoch 1509, Loss: 1.2983070015907288, Final Batch Loss: 0.4194644093513489
Epoch 1510, Loss: 1.3332290947437286, Final Batch Loss: 0.34394219517707825
Epoch 1511, Loss: 1.3262234032154083, Final Batch Loss: 0.2826007008552551
Epoch 1512, Loss: 1.3966149389743805, Final Batch Loss: 0.4560718834400177
Epoch 1513, Loss: 1.3668961226940155, Final Batch Loss: 0.3781610429286957
Epoch 1514, Loss: 1.2158215045928955, Final Batch Loss: 0.2698648273944855
Epoch 1515, Loss: 1.3535643815994263, Final Batch Loss: 0.2899298369884491
Epoch 1516, Loss: 1.2856312096118927, Final Batch Loss: 0.29099422693252563
Epoch 1517, Loss: 1.2907030880451202, Final Batch Loss: 0.3585217595100403
Epoch 1518, Loss: 1.279444932937622, Final Batch Loss: 0.2703173756599426
Epoch 1519, Loss: 1.279944658279419, Final Batch Loss: 0.3528183400630951
Epoch 1520, Loss: 1.5512376725673676, Final Batch Loss: 0.5709919929504395
Epoch 1521, Loss: 1.3553167879581451, Final Batch Loss: 0.3050765097141266
Epoch 1522, Loss: 1.30868

Epoch 1618, Loss: 1.276725858449936, Final Batch Loss: 0.2455584704875946
Epoch 1619, Loss: 1.3467244803905487, Final Batch Loss: 0.33981382846832275
Epoch 1620, Loss: 1.3494274318218231, Final Batch Loss: 0.2950887978076935
Epoch 1621, Loss: 1.3734923601150513, Final Batch Loss: 0.2844175398349762
Epoch 1622, Loss: 1.2469312846660614, Final Batch Loss: 0.2512040138244629
Epoch 1623, Loss: 1.260559618473053, Final Batch Loss: 0.2840365171432495
Epoch 1624, Loss: 1.33988156914711, Final Batch Loss: 0.36100447177886963
Epoch 1625, Loss: 1.1631300300359726, Final Batch Loss: 0.3296755850315094
Epoch 1626, Loss: 1.2147355675697327, Final Batch Loss: 0.3015711307525635
Epoch 1627, Loss: 1.1688214838504791, Final Batch Loss: 0.2721986174583435
Epoch 1628, Loss: 1.1833476722240448, Final Batch Loss: 0.2931167483329773
Epoch 1629, Loss: 1.357517033815384, Final Batch Loss: 0.4023972749710083
Epoch 1630, Loss: 1.2605552673339844, Final Batch Loss: 0.3401653468608856
Epoch 1631, Loss: 1.26894862

Epoch 1730, Loss: 1.21982941031456, Final Batch Loss: 0.31324657797813416
Epoch 1731, Loss: 1.1622167229652405, Final Batch Loss: 0.25214719772338867
Epoch 1732, Loss: 1.206113576889038, Final Batch Loss: 0.29707059264183044
Epoch 1733, Loss: 1.1315299570560455, Final Batch Loss: 0.22443026304244995
Epoch 1734, Loss: 1.1912576407194138, Final Batch Loss: 0.2672812342643738
Epoch 1735, Loss: 1.2040081322193146, Final Batch Loss: 0.25254884362220764
Epoch 1736, Loss: 1.2466617822647095, Final Batch Loss: 0.298519104719162
Epoch 1737, Loss: 1.2413937747478485, Final Batch Loss: 0.2926304042339325
Epoch 1738, Loss: 1.1836809515953064, Final Batch Loss: 0.3449168801307678
Epoch 1739, Loss: 1.2419263422489166, Final Batch Loss: 0.3994089961051941
Epoch 1740, Loss: 1.249837338924408, Final Batch Loss: 0.4224082827568054
Epoch 1741, Loss: 1.0774898529052734, Final Batch Loss: 0.19332224130630493
Epoch 1742, Loss: 1.1775627434253693, Final Batch Loss: 0.1907193958759308
Epoch 1743, Loss: 1.2681

Epoch 1847, Loss: 1.171823799610138, Final Batch Loss: 0.3658939599990845
Epoch 1848, Loss: 1.1947399079799652, Final Batch Loss: 0.2983073890209198
Epoch 1849, Loss: 1.2211551666259766, Final Batch Loss: 0.3222426772117615
Epoch 1850, Loss: 1.1979619562625885, Final Batch Loss: 0.3098447024822235
Epoch 1851, Loss: 1.1901240646839142, Final Batch Loss: 0.29101648926734924
Epoch 1852, Loss: 1.190134733915329, Final Batch Loss: 0.3245391547679901
Epoch 1853, Loss: 1.2843464016914368, Final Batch Loss: 0.42695602774620056
Epoch 1854, Loss: 1.133267879486084, Final Batch Loss: 0.26344218850135803
Epoch 1855, Loss: 1.1688359379768372, Final Batch Loss: 0.33006373047828674
Epoch 1856, Loss: 1.207933485507965, Final Batch Loss: 0.32548975944519043
Epoch 1857, Loss: 1.080199882388115, Final Batch Loss: 0.2325538545846939
Epoch 1858, Loss: 1.097073271870613, Final Batch Loss: 0.29440414905548096
Epoch 1859, Loss: 1.1782735586166382, Final Batch Loss: 0.26429012417793274
Epoch 1860, Loss: 1.1993

Epoch 1965, Loss: 1.1935808658599854, Final Batch Loss: 0.3457079529762268
Epoch 1966, Loss: 1.2991037666797638, Final Batch Loss: 0.3302418887615204
Epoch 1967, Loss: 1.306849867105484, Final Batch Loss: 0.3722410798072815
Epoch 1968, Loss: 1.1372565031051636, Final Batch Loss: 0.2405708134174347
Epoch 1969, Loss: 1.191614419221878, Final Batch Loss: 0.3034266233444214
Epoch 1970, Loss: 1.1032797545194626, Final Batch Loss: 0.31303372979164124
Epoch 1971, Loss: 1.1939961165189743, Final Batch Loss: 0.2223469763994217
Epoch 1972, Loss: 1.2051593661308289, Final Batch Loss: 0.33999332785606384
Epoch 1973, Loss: 1.1986853629350662, Final Batch Loss: 0.18210159242153168
Epoch 1974, Loss: 1.189076989889145, Final Batch Loss: 0.3333881199359894
Epoch 1975, Loss: 1.169198453426361, Final Batch Loss: 0.4005577862262726
Epoch 1976, Loss: 1.1886871457099915, Final Batch Loss: 0.34502413868904114
Epoch 1977, Loss: 1.086362898349762, Final Batch Loss: 0.27396097779273987
Epoch 1978, Loss: 1.05007

Epoch 2077, Loss: 1.1064829677343369, Final Batch Loss: 0.37560218572616577
Epoch 2078, Loss: 1.1302912384271622, Final Batch Loss: 0.3802802860736847
Epoch 2079, Loss: 1.0999396443367004, Final Batch Loss: 0.2293320596218109
Epoch 2080, Loss: 1.1192714720964432, Final Batch Loss: 0.31315916776657104
Epoch 2081, Loss: 0.9955030679702759, Final Batch Loss: 0.17505046725273132
Epoch 2082, Loss: 1.1776880323886871, Final Batch Loss: 0.4107125401496887
Epoch 2083, Loss: 1.0054108202457428, Final Batch Loss: 0.3070022761821747
Epoch 2084, Loss: 1.1563032269477844, Final Batch Loss: 0.3196943700313568
Epoch 2085, Loss: 1.1499724984169006, Final Batch Loss: 0.3370051980018616
Epoch 2086, Loss: 1.1664460003376007, Final Batch Loss: 0.2367645800113678
Epoch 2087, Loss: 1.069944128394127, Final Batch Loss: 0.26610639691352844
Epoch 2088, Loss: 0.9661027193069458, Final Batch Loss: 0.17123812437057495
Epoch 2089, Loss: 1.083542361855507, Final Batch Loss: 0.27809762954711914
Epoch 2090, Loss: 0.9

Epoch 2192, Loss: 0.9810749739408493, Final Batch Loss: 0.22734931111335754
Epoch 2193, Loss: 1.2452434301376343, Final Batch Loss: 0.28788653016090393
Epoch 2194, Loss: 1.1435818821191788, Final Batch Loss: 0.353736937046051
Epoch 2195, Loss: 0.9977386146783829, Final Batch Loss: 0.23077447712421417
Epoch 2196, Loss: 1.0390779674053192, Final Batch Loss: 0.2860211431980133
Epoch 2197, Loss: 1.0800763219594955, Final Batch Loss: 0.24847327172756195
Epoch 2198, Loss: 1.080582395195961, Final Batch Loss: 0.1903351992368698
Epoch 2199, Loss: 1.0377936959266663, Final Batch Loss: 0.26209187507629395
Epoch 2200, Loss: 1.1033525466918945, Final Batch Loss: 0.20766830444335938
Epoch 2201, Loss: 1.246446281671524, Final Batch Loss: 0.36092761158943176
Epoch 2202, Loss: 1.0645992159843445, Final Batch Loss: 0.22056931257247925
Epoch 2203, Loss: 1.0433247834444046, Final Batch Loss: 0.25373589992523193
Epoch 2204, Loss: 0.9289336055517197, Final Batch Loss: 0.2097027450799942
Epoch 2205, Loss: 1

Epoch 2305, Loss: 1.0804753452539444, Final Batch Loss: 0.22718383371829987
Epoch 2306, Loss: 1.1249172687530518, Final Batch Loss: 0.2672739028930664
Epoch 2307, Loss: 1.0572274923324585, Final Batch Loss: 0.28759488463401794
Epoch 2308, Loss: 1.110117107629776, Final Batch Loss: 0.3221757113933563
Epoch 2309, Loss: 1.033198818564415, Final Batch Loss: 0.255679190158844
Epoch 2310, Loss: 1.1401866525411606, Final Batch Loss: 0.35843101143836975
Epoch 2311, Loss: 0.9821776896715164, Final Batch Loss: 0.2224532663822174
Epoch 2312, Loss: 0.9477530270814896, Final Batch Loss: 0.20894832909107208
Epoch 2313, Loss: 1.053358256816864, Final Batch Loss: 0.2577294707298279
Epoch 2314, Loss: 0.9691120386123657, Final Batch Loss: 0.22571566700935364
Epoch 2315, Loss: 1.2404529750347137, Final Batch Loss: 0.4351038932800293
Epoch 2316, Loss: 1.1009010076522827, Final Batch Loss: 0.2624824643135071
Epoch 2317, Loss: 1.0754258632659912, Final Batch Loss: 0.29650208353996277
Epoch 2318, Loss: 1.103

Epoch 2414, Loss: 1.156296581029892, Final Batch Loss: 0.37432584166526794
Epoch 2415, Loss: 1.0664803832769394, Final Batch Loss: 0.2837996780872345
Epoch 2416, Loss: 1.17837955057621, Final Batch Loss: 0.34393179416656494
Epoch 2417, Loss: 1.0686412751674652, Final Batch Loss: 0.2823784649372101
Epoch 2418, Loss: 1.021787017583847, Final Batch Loss: 0.31066012382507324
Epoch 2419, Loss: 1.0962984263896942, Final Batch Loss: 0.262249231338501
Epoch 2420, Loss: 0.9342750310897827, Final Batch Loss: 0.2458827644586563
Epoch 2421, Loss: 0.956953689455986, Final Batch Loss: 0.2395033836364746
Epoch 2422, Loss: 1.0084394663572311, Final Batch Loss: 0.22182568907737732
Epoch 2423, Loss: 0.9479959905147552, Final Batch Loss: 0.21979868412017822
Epoch 2424, Loss: 1.0536170601844788, Final Batch Loss: 0.3506234288215637
Epoch 2425, Loss: 0.9806940853595734, Final Batch Loss: 0.24929489195346832
Epoch 2426, Loss: 0.9994959980249405, Final Batch Loss: 0.2652088701725006
Epoch 2427, Loss: 1.01417

Epoch 2532, Loss: 1.0031526535749435, Final Batch Loss: 0.1851927936077118
Epoch 2533, Loss: 1.0283146649599075, Final Batch Loss: 0.27462175488471985
Epoch 2534, Loss: 1.0098763406276703, Final Batch Loss: 0.26635095477104187
Epoch 2535, Loss: 0.9941925704479218, Final Batch Loss: 0.24454206228256226
Epoch 2536, Loss: 1.0714965909719467, Final Batch Loss: 0.2556731700897217
Epoch 2537, Loss: 0.8838133066892624, Final Batch Loss: 0.20082035660743713
Epoch 2538, Loss: 1.0350473821163177, Final Batch Loss: 0.24604225158691406
Epoch 2539, Loss: 1.0228653401136398, Final Batch Loss: 0.34235629439353943
Epoch 2540, Loss: 0.9825363308191299, Final Batch Loss: 0.2284199297428131
Epoch 2541, Loss: 0.905736654996872, Final Batch Loss: 0.14771920442581177
Epoch 2542, Loss: 0.9387829750776291, Final Batch Loss: 0.21825754642486572
Epoch 2543, Loss: 0.9207461029291153, Final Batch Loss: 0.17371848225593567
Epoch 2544, Loss: 1.1113253235816956, Final Batch Loss: 0.33201003074645996
Epoch 2545, Loss

Epoch 2641, Loss: 1.0205335468053818, Final Batch Loss: 0.2564835548400879
Epoch 2642, Loss: 0.8474623262882233, Final Batch Loss: 0.1675814837217331
Epoch 2643, Loss: 0.8338806480169296, Final Batch Loss: 0.2061227411031723
Epoch 2644, Loss: 1.0179252475500107, Final Batch Loss: 0.28315839171409607
Epoch 2645, Loss: 1.0369476228952408, Final Batch Loss: 0.3036796450614929
Epoch 2646, Loss: 1.010381504893303, Final Batch Loss: 0.30329960584640503
Epoch 2647, Loss: 0.8482538014650345, Final Batch Loss: 0.15477143228054047
Epoch 2648, Loss: 1.0181952714920044, Final Batch Loss: 0.27454930543899536
Epoch 2649, Loss: 1.0851696729660034, Final Batch Loss: 0.30978432297706604
Epoch 2650, Loss: 0.9032420963048935, Final Batch Loss: 0.17990007996559143
Epoch 2651, Loss: 0.9335588067770004, Final Batch Loss: 0.1921030730009079
Epoch 2652, Loss: 0.9385126531124115, Final Batch Loss: 0.23217949271202087
Epoch 2653, Loss: 1.00566066801548, Final Batch Loss: 0.1999441236257553
Epoch 2654, Loss: 1.0

Epoch 2752, Loss: 0.961842492222786, Final Batch Loss: 0.16703610122203827
Epoch 2753, Loss: 1.000541478395462, Final Batch Loss: 0.26771292090415955
Epoch 2754, Loss: 0.9270349442958832, Final Batch Loss: 0.19766438007354736
Epoch 2755, Loss: 0.9048165529966354, Final Batch Loss: 0.16552163660526276
Epoch 2756, Loss: 0.9195852428674698, Final Batch Loss: 0.2428138107061386
Epoch 2757, Loss: 1.0203864127397537, Final Batch Loss: 0.23665013909339905
Epoch 2758, Loss: 0.9676743894815445, Final Batch Loss: 0.190299391746521
Epoch 2759, Loss: 0.8400699570775032, Final Batch Loss: 0.12251298874616623
Epoch 2760, Loss: 0.9682050943374634, Final Batch Loss: 0.27545586228370667
Epoch 2761, Loss: 1.0314185619354248, Final Batch Loss: 0.2936912178993225
Epoch 2762, Loss: 1.0172606706619263, Final Batch Loss: 0.26151975989341736
Epoch 2763, Loss: 0.9084202796220779, Final Batch Loss: 0.203753262758255
Epoch 2764, Loss: 1.0036988705396652, Final Batch Loss: 0.27402934432029724
Epoch 2765, Loss: 0.

Epoch 2865, Loss: 0.9055959284305573, Final Batch Loss: 0.13343991339206696
Epoch 2866, Loss: 0.9931622743606567, Final Batch Loss: 0.3164735734462738
Epoch 2867, Loss: 0.8983224630355835, Final Batch Loss: 0.18775013089179993
Epoch 2868, Loss: 0.9568829834461212, Final Batch Loss: 0.2752327620983124
Epoch 2869, Loss: 0.9392874836921692, Final Batch Loss: 0.22745919227600098
Epoch 2870, Loss: 0.9801304489374161, Final Batch Loss: 0.23960807919502258
Epoch 2871, Loss: 1.0058914870023727, Final Batch Loss: 0.2669244110584259
Epoch 2872, Loss: 0.8935917913913727, Final Batch Loss: 0.22078463435173035
Epoch 2873, Loss: 0.9920122921466827, Final Batch Loss: 0.24189019203186035
Epoch 2874, Loss: 0.8814467191696167, Final Batch Loss: 0.1956697255373001
Epoch 2875, Loss: 0.9443814009428024, Final Batch Loss: 0.19495034217834473
Epoch 2876, Loss: 0.9031165540218353, Final Batch Loss: 0.23237143456935883
Epoch 2877, Loss: 0.8768322020769119, Final Batch Loss: 0.21620458364486694
Epoch 2878, Loss

Epoch 2975, Loss: 0.8537309169769287, Final Batch Loss: 0.2516128420829773
Epoch 2976, Loss: 0.8748435527086258, Final Batch Loss: 0.238599494099617
Epoch 2977, Loss: 1.0264575332403183, Final Batch Loss: 0.27620601654052734
Epoch 2978, Loss: 0.973309263586998, Final Batch Loss: 0.3271159529685974
Epoch 2979, Loss: 0.9900061637163162, Final Batch Loss: 0.23348796367645264
Epoch 2980, Loss: 0.9210378080606461, Final Batch Loss: 0.23311066627502441
Epoch 2981, Loss: 0.9328132271766663, Final Batch Loss: 0.26626476645469666
Epoch 2982, Loss: 0.9811854958534241, Final Batch Loss: 0.19253148138523102
Epoch 2983, Loss: 1.0487482398748398, Final Batch Loss: 0.3233921527862549
Epoch 2984, Loss: 1.0270491242408752, Final Batch Loss: 0.2738848924636841
Epoch 2985, Loss: 1.002372071146965, Final Batch Loss: 0.2866996228694916
Epoch 2986, Loss: 0.8944871872663498, Final Batch Loss: 0.13861674070358276
Epoch 2987, Loss: 0.9007477164268494, Final Batch Loss: 0.2129938155412674
Epoch 2988, Loss: 1.00

Epoch 3089, Loss: 0.9382287263870239, Final Batch Loss: 0.3362310528755188
Epoch 3090, Loss: 0.7770169526338577, Final Batch Loss: 0.17837920784950256
Epoch 3091, Loss: 0.809233158826828, Final Batch Loss: 0.15714094042778015
Epoch 3092, Loss: 0.9218316674232483, Final Batch Loss: 0.1927669793367386
Epoch 3093, Loss: 0.8360482454299927, Final Batch Loss: 0.22847945988178253
Epoch 3094, Loss: 0.9612899869680405, Final Batch Loss: 0.29611143469810486
Epoch 3095, Loss: 1.0101988315582275, Final Batch Loss: 0.2645364999771118
Epoch 3096, Loss: 0.9515356123447418, Final Batch Loss: 0.23127329349517822
Epoch 3097, Loss: 1.0120864659547806, Final Batch Loss: 0.26980745792388916
Epoch 3098, Loss: 0.8713834136724472, Final Batch Loss: 0.2458084374666214
Epoch 3099, Loss: 0.9954452514648438, Final Batch Loss: 0.23088958859443665
Epoch 3100, Loss: 0.8778524100780487, Final Batch Loss: 0.24722544848918915
Epoch 3101, Loss: 0.8648380041122437, Final Batch Loss: 0.2507557272911072
Epoch 3102, Loss: 

Epoch 3199, Loss: 0.8714417666196823, Final Batch Loss: 0.25560060143470764
Epoch 3200, Loss: 0.8787247389554977, Final Batch Loss: 0.2882094383239746
Epoch 3201, Loss: 0.942981019616127, Final Batch Loss: 0.168968066573143
Epoch 3202, Loss: 0.8696235418319702, Final Batch Loss: 0.18061432242393494
Epoch 3203, Loss: 0.9039217382669449, Final Batch Loss: 0.21739287674427032
Epoch 3204, Loss: 0.9971814155578613, Final Batch Loss: 0.3782045245170593
Epoch 3205, Loss: 0.9225217401981354, Final Batch Loss: 0.20449933409690857
Epoch 3206, Loss: 0.8790398240089417, Final Batch Loss: 0.23388099670410156
Epoch 3207, Loss: 0.8566240817308426, Final Batch Loss: 0.17363256216049194
Epoch 3208, Loss: 0.9428450465202332, Final Batch Loss: 0.1973314881324768
Epoch 3209, Loss: 1.0266814082860947, Final Batch Loss: 0.3044261634349823
Epoch 3210, Loss: 0.9105624705553055, Final Batch Loss: 0.21598397195339203
Epoch 3211, Loss: 0.9268667995929718, Final Batch Loss: 0.23139216005802155
Epoch 3212, Loss: 0

Epoch 3312, Loss: 0.8827892243862152, Final Batch Loss: 0.2221081405878067
Epoch 3313, Loss: 0.9325339049100876, Final Batch Loss: 0.1819418966770172
Epoch 3314, Loss: 0.9762844145298004, Final Batch Loss: 0.21587298810482025
Epoch 3315, Loss: 0.9602854996919632, Final Batch Loss: 0.2558448910713196
Epoch 3316, Loss: 0.8843145370483398, Final Batch Loss: 0.19010475277900696
Epoch 3317, Loss: 0.8142568022012711, Final Batch Loss: 0.1871432363986969
Epoch 3318, Loss: 0.7717925012111664, Final Batch Loss: 0.16891326010227203
Epoch 3319, Loss: 0.9032381772994995, Final Batch Loss: 0.2528320252895355
Epoch 3320, Loss: 0.8465839326381683, Final Batch Loss: 0.22838804125785828
Epoch 3321, Loss: 0.7981197237968445, Final Batch Loss: 0.23389855027198792
Epoch 3322, Loss: 0.7895005494356155, Final Batch Loss: 0.17471390962600708
Epoch 3323, Loss: 1.0131622403860092, Final Batch Loss: 0.3693508505821228
Epoch 3324, Loss: 0.8779487907886505, Final Batch Loss: 0.25418350100517273
Epoch 3325, Loss: 

Epoch 3430, Loss: 0.8675270080566406, Final Batch Loss: 0.26067110896110535
Epoch 3431, Loss: 0.9047722518444061, Final Batch Loss: 0.2828562557697296
Epoch 3432, Loss: 0.9588387757539749, Final Batch Loss: 0.3023928701877594
Epoch 3433, Loss: 1.0864993631839752, Final Batch Loss: 0.26413851976394653
Epoch 3434, Loss: 0.8914002776145935, Final Batch Loss: 0.16598904132843018
Epoch 3435, Loss: 0.7932939380407333, Final Batch Loss: 0.1920429766178131
Epoch 3436, Loss: 0.8551091998815536, Final Batch Loss: 0.1944247931241989
Epoch 3437, Loss: 0.8274282515048981, Final Batch Loss: 0.16757741570472717
Epoch 3438, Loss: 0.8996100127696991, Final Batch Loss: 0.26577192544937134
Epoch 3439, Loss: 0.9214640408754349, Final Batch Loss: 0.19529563188552856
Epoch 3440, Loss: 0.9100984185934067, Final Batch Loss: 0.3106110990047455
Epoch 3441, Loss: 0.9264786541461945, Final Batch Loss: 0.2781352400779724
Epoch 3442, Loss: 0.7644638419151306, Final Batch Loss: 0.19262102246284485
Epoch 3443, Loss: 

Epoch 3544, Loss: 0.8633039146661758, Final Batch Loss: 0.24733562767505646
Epoch 3545, Loss: 0.8593568354845047, Final Batch Loss: 0.18108543753623962
Epoch 3546, Loss: 0.9682154655456543, Final Batch Loss: 0.27566036581993103
Epoch 3547, Loss: 0.8655034601688385, Final Batch Loss: 0.21717746555805206
Epoch 3548, Loss: 0.8834343701601028, Final Batch Loss: 0.17850351333618164
Epoch 3549, Loss: 0.8677487671375275, Final Batch Loss: 0.21774698793888092
Epoch 3550, Loss: 0.9615660607814789, Final Batch Loss: 0.2755320966243744
Epoch 3551, Loss: 0.8708124607801437, Final Batch Loss: 0.16007782518863678
Epoch 3552, Loss: 0.8703029751777649, Final Batch Loss: 0.2574625015258789
Epoch 3553, Loss: 0.8232491314411163, Final Batch Loss: 0.20101334154605865
Epoch 3554, Loss: 0.856653556227684, Final Batch Loss: 0.196630597114563
Epoch 3555, Loss: 0.8243339657783508, Final Batch Loss: 0.16729506850242615
Epoch 3556, Loss: 0.8336522728204727, Final Batch Loss: 0.21727077662944794
Epoch 3557, Loss:

Epoch 3658, Loss: 0.9523391127586365, Final Batch Loss: 0.3515715003013611
Epoch 3659, Loss: 0.8558930307626724, Final Batch Loss: 0.19012145698070526
Epoch 3660, Loss: 0.8438646644353867, Final Batch Loss: 0.206222802400589
Epoch 3661, Loss: 0.8588914275169373, Final Batch Loss: 0.24512025713920593
Epoch 3662, Loss: 0.8238535523414612, Final Batch Loss: 0.2354126274585724
Epoch 3663, Loss: 0.8427722007036209, Final Batch Loss: 0.17165647447109222
Epoch 3664, Loss: 0.8225376904010773, Final Batch Loss: 0.164991095662117
Epoch 3665, Loss: 0.8586521446704865, Final Batch Loss: 0.2174040675163269
Epoch 3666, Loss: 0.9005732089281082, Final Batch Loss: 0.2654038965702057
Epoch 3667, Loss: 0.873676672577858, Final Batch Loss: 0.21757477521896362
Epoch 3668, Loss: 0.8395347744226456, Final Batch Loss: 0.1447208821773529
Epoch 3669, Loss: 0.8800843507051468, Final Batch Loss: 0.18139863014221191
Epoch 3670, Loss: 0.8256434202194214, Final Batch Loss: 0.1794939935207367
Epoch 3671, Loss: 0.879

Epoch 3770, Loss: 0.7859028279781342, Final Batch Loss: 0.16339178383350372
Epoch 3771, Loss: 0.6907720416784286, Final Batch Loss: 0.15424570441246033
Epoch 3772, Loss: 0.821938082575798, Final Batch Loss: 0.1154937744140625
Epoch 3773, Loss: 0.7599291652441025, Final Batch Loss: 0.1781480461359024
Epoch 3774, Loss: 0.8677796870470047, Final Batch Loss: 0.264166921377182
Epoch 3775, Loss: 0.9158768206834793, Final Batch Loss: 0.24001401662826538
Epoch 3776, Loss: 0.9259315878152847, Final Batch Loss: 0.25520116090774536
Epoch 3777, Loss: 0.8919269740581512, Final Batch Loss: 0.2294810712337494
Epoch 3778, Loss: 0.8445201218128204, Final Batch Loss: 0.20746523141860962
Epoch 3779, Loss: 1.0116876661777496, Final Batch Loss: 0.3115551769733429
Epoch 3780, Loss: 0.9221636801958084, Final Batch Loss: 0.2060009390115738
Epoch 3781, Loss: 0.8489518761634827, Final Batch Loss: 0.2960943281650543
Epoch 3782, Loss: 0.796099528670311, Final Batch Loss: 0.19745390117168427
Epoch 3783, Loss: 0.82

Epoch 3881, Loss: 0.896800309419632, Final Batch Loss: 0.2657497525215149
Epoch 3882, Loss: 0.8567282557487488, Final Batch Loss: 0.12625610828399658
Epoch 3883, Loss: 0.8357811272144318, Final Batch Loss: 0.30096718668937683
Epoch 3884, Loss: 0.844030112028122, Final Batch Loss: 0.3409058153629303
Epoch 3885, Loss: 0.8543970584869385, Final Batch Loss: 0.24715454876422882
Epoch 3886, Loss: 1.0335042923688889, Final Batch Loss: 0.2345544546842575
Epoch 3887, Loss: 0.7742043435573578, Final Batch Loss: 0.14979639649391174
Epoch 3888, Loss: 0.8654814809560776, Final Batch Loss: 0.29246222972869873
Epoch 3889, Loss: 0.8387381881475449, Final Batch Loss: 0.22228385508060455
Epoch 3890, Loss: 0.7735909521579742, Final Batch Loss: 0.17870967090129852
Epoch 3891, Loss: 0.8218897581100464, Final Batch Loss: 0.18885496258735657
Epoch 3892, Loss: 0.9299386143684387, Final Batch Loss: 0.24468664824962616
Epoch 3893, Loss: 0.7747306227684021, Final Batch Loss: 0.15565316379070282
Epoch 3894, Loss:

Epoch 3994, Loss: 0.7598345577716827, Final Batch Loss: 0.18983183801174164
Epoch 3995, Loss: 0.7846847325563431, Final Batch Loss: 0.2330498993396759
Epoch 3996, Loss: 0.8780447393655777, Final Batch Loss: 0.19049760699272156
Epoch 3997, Loss: 0.8731092065572739, Final Batch Loss: 0.2882844805717468
Epoch 3998, Loss: 0.8772238343954086, Final Batch Loss: 0.21754583716392517
Epoch 3999, Loss: 0.8453424572944641, Final Batch Loss: 0.23597724735736847
Epoch 4000, Loss: 0.7159246206283569, Final Batch Loss: 0.14148490130901337
Epoch 4001, Loss: 0.7678609490394592, Final Batch Loss: 0.1484544426202774
Epoch 4002, Loss: 0.8928678631782532, Final Batch Loss: 0.2873627841472626
Epoch 4003, Loss: 0.7623073905706406, Final Batch Loss: 0.1850063055753708
Epoch 4004, Loss: 0.8425126373767853, Final Batch Loss: 0.20047307014465332
Epoch 4005, Loss: 0.9514031708240509, Final Batch Loss: 0.21586570143699646
Epoch 4006, Loss: 0.8864817768335342, Final Batch Loss: 0.30153346061706543
Epoch 4007, Loss:

Epoch 4109, Loss: 0.8765981644392014, Final Batch Loss: 0.24731014668941498
Epoch 4110, Loss: 0.844117060303688, Final Batch Loss: 0.20585288107395172
Epoch 4111, Loss: 0.8191174417734146, Final Batch Loss: 0.1696373075246811
Epoch 4112, Loss: 0.8198423683643341, Final Batch Loss: 0.19764205813407898
Epoch 4113, Loss: 0.8531225472688675, Final Batch Loss: 0.20862430334091187
Epoch 4114, Loss: 0.8185591101646423, Final Batch Loss: 0.1694781333208084
Epoch 4115, Loss: 0.86611607670784, Final Batch Loss: 0.15750077366828918
Epoch 4116, Loss: 0.7410265356302261, Final Batch Loss: 0.1484515219926834
Epoch 4117, Loss: 0.8413120657205582, Final Batch Loss: 0.1906139850616455
Epoch 4118, Loss: 0.7674901485443115, Final Batch Loss: 0.13664644956588745
Epoch 4119, Loss: 0.8567418605089188, Final Batch Loss: 0.1911351978778839
Epoch 4120, Loss: 0.9184966534376144, Final Batch Loss: 0.22734509408473969
Epoch 4121, Loss: 0.7968709915876389, Final Batch Loss: 0.21732200682163239
Epoch 4122, Loss: 0.

Epoch 4219, Loss: 0.8275857269763947, Final Batch Loss: 0.2031540870666504
Epoch 4220, Loss: 0.7345279008150101, Final Batch Loss: 0.16439147293567657
Epoch 4221, Loss: 0.9087780117988586, Final Batch Loss: 0.2840431332588196
Epoch 4222, Loss: 0.9451910853385925, Final Batch Loss: 0.2385772466659546
Epoch 4223, Loss: 0.8692936897277832, Final Batch Loss: 0.2292487472295761
Epoch 4224, Loss: 0.9754549413919449, Final Batch Loss: 0.2671952545642853
Epoch 4225, Loss: 0.883508488535881, Final Batch Loss: 0.2856799364089966
Epoch 4226, Loss: 0.7509662210941315, Final Batch Loss: 0.23092901706695557
Epoch 4227, Loss: 0.8086834847927094, Final Batch Loss: 0.222713902592659
Epoch 4228, Loss: 0.8999847918748856, Final Batch Loss: 0.29463887214660645
Epoch 4229, Loss: 0.8649639338254929, Final Batch Loss: 0.2606908679008484
Epoch 4230, Loss: 0.8056792169809341, Final Batch Loss: 0.18238092958927155
Epoch 4231, Loss: 0.8632761240005493, Final Batch Loss: 0.20735090970993042
Epoch 4232, Loss: 0.84

Epoch 4334, Loss: 0.813982829451561, Final Batch Loss: 0.2547842264175415
Epoch 4335, Loss: 0.8067512959241867, Final Batch Loss: 0.1724250316619873
Epoch 4336, Loss: 0.779552772641182, Final Batch Loss: 0.14031462371349335
Epoch 4337, Loss: 0.7582770884037018, Final Batch Loss: 0.21113531291484833
Epoch 4338, Loss: 0.7556063532829285, Final Batch Loss: 0.18892641365528107
Epoch 4339, Loss: 0.7910087704658508, Final Batch Loss: 0.16042682528495789
Epoch 4340, Loss: 0.8096951693296432, Final Batch Loss: 0.20165924727916718
Epoch 4341, Loss: 0.7862327247858047, Final Batch Loss: 0.19687420129776
Epoch 4342, Loss: 0.7667343020439148, Final Batch Loss: 0.1429566591978073
Epoch 4343, Loss: 0.7516976296901703, Final Batch Loss: 0.12819485366344452
Epoch 4344, Loss: 0.8876600414514542, Final Batch Loss: 0.329197496175766
Epoch 4345, Loss: 0.7775493413209915, Final Batch Loss: 0.1532488912343979
Epoch 4346, Loss: 0.8178679347038269, Final Batch Loss: 0.17985749244689941
Epoch 4347, Loss: 0.844

Epoch 4447, Loss: 0.8742592036724091, Final Batch Loss: 0.22309380769729614
Epoch 4448, Loss: 0.8138376921415329, Final Batch Loss: 0.1811475157737732
Epoch 4449, Loss: 0.9606448113918304, Final Batch Loss: 0.20933376252651215
Epoch 4450, Loss: 0.7613934725522995, Final Batch Loss: 0.14825142920017242
Epoch 4451, Loss: 0.794733002781868, Final Batch Loss: 0.23942941427230835
Epoch 4452, Loss: 0.757598340511322, Final Batch Loss: 0.13675615191459656
Epoch 4453, Loss: 0.9409362822771072, Final Batch Loss: 0.3826952278614044
Epoch 4454, Loss: 0.8342868387699127, Final Batch Loss: 0.20676352083683014
Epoch 4455, Loss: 0.8686467409133911, Final Batch Loss: 0.2737444043159485
Epoch 4456, Loss: 0.8858194202184677, Final Batch Loss: 0.30078810453414917
Epoch 4457, Loss: 0.9184727072715759, Final Batch Loss: 0.28276902437210083
Epoch 4458, Loss: 0.8688316494226456, Final Batch Loss: 0.2450047731399536
Epoch 4459, Loss: 0.8876108229160309, Final Batch Loss: 0.25547078251838684
Epoch 4460, Loss: 

Epoch 4565, Loss: 0.8197595477104187, Final Batch Loss: 0.25639650225639343
Epoch 4566, Loss: 0.8221640735864639, Final Batch Loss: 0.17306913435459137
Epoch 4567, Loss: 0.7961457669734955, Final Batch Loss: 0.21120429039001465
Epoch 4568, Loss: 0.7474959641695023, Final Batch Loss: 0.1809263676404953
Epoch 4569, Loss: 0.7962273806333542, Final Batch Loss: 0.16799749433994293
Epoch 4570, Loss: 0.8190481960773468, Final Batch Loss: 0.16000892221927643
Epoch 4571, Loss: 0.7006601542234421, Final Batch Loss: 0.23649248480796814
Epoch 4572, Loss: 0.8473779857158661, Final Batch Loss: 0.21555456519126892
Epoch 4573, Loss: 0.8595103025436401, Final Batch Loss: 0.23863893747329712
Epoch 4574, Loss: 0.9058763980865479, Final Batch Loss: 0.2066964954137802
Epoch 4575, Loss: 0.7539674639701843, Final Batch Loss: 0.17412042617797852
Epoch 4576, Loss: 0.7101278007030487, Final Batch Loss: 0.1776084005832672
Epoch 4577, Loss: 0.8429161310195923, Final Batch Loss: 0.20391975343227386
Epoch 4578, Los

Epoch 4673, Loss: 0.6858913898468018, Final Batch Loss: 0.13335703313350677
Epoch 4674, Loss: 0.728084072470665, Final Batch Loss: 0.21144968271255493
Epoch 4675, Loss: 0.6807893812656403, Final Batch Loss: 0.15831395983695984
Epoch 4676, Loss: 0.8370793163776398, Final Batch Loss: 0.24066224694252014
Epoch 4677, Loss: 0.841180294752121, Final Batch Loss: 0.2238043248653412
Epoch 4678, Loss: 0.7054988592863083, Final Batch Loss: 0.1433410495519638
Epoch 4679, Loss: 0.8925461620092392, Final Batch Loss: 0.19752137362957
Epoch 4680, Loss: 0.7727579772472382, Final Batch Loss: 0.1681729406118393
Epoch 4681, Loss: 0.8347740918397903, Final Batch Loss: 0.2667687237262726
Epoch 4682, Loss: 0.7885138541460037, Final Batch Loss: 0.2254902720451355
Epoch 4683, Loss: 0.7892804741859436, Final Batch Loss: 0.1823621243238449
Epoch 4684, Loss: 0.7072395980358124, Final Batch Loss: 0.16574683785438538
Epoch 4685, Loss: 0.8663592636585236, Final Batch Loss: 0.17516566812992096
Epoch 4686, Loss: 0.828

Epoch 4786, Loss: 0.7130503803491592, Final Batch Loss: 0.15755538642406464
Epoch 4787, Loss: 0.820206418633461, Final Batch Loss: 0.18659652769565582
Epoch 4788, Loss: 0.7470470815896988, Final Batch Loss: 0.1487952321767807
Epoch 4789, Loss: 0.6370075345039368, Final Batch Loss: 0.12558525800704956
Epoch 4790, Loss: 0.6706331893801689, Final Batch Loss: 0.11607753485441208
Epoch 4791, Loss: 0.883887991309166, Final Batch Loss: 0.30394428968429565
Epoch 4792, Loss: 0.7617763429880142, Final Batch Loss: 0.12748165428638458
Epoch 4793, Loss: 0.7847517132759094, Final Batch Loss: 0.18554653227329254
Epoch 4794, Loss: 0.7364758998155594, Final Batch Loss: 0.1372622847557068
Epoch 4795, Loss: 0.794413149356842, Final Batch Loss: 0.2283088117837906
Epoch 4796, Loss: 0.8554303795099258, Final Batch Loss: 0.169374480843544
Epoch 4797, Loss: 0.7444017827510834, Final Batch Loss: 0.18277089297771454
Epoch 4798, Loss: 0.6836745589971542, Final Batch Loss: 0.20513896644115448
Epoch 4799, Loss: 0.

Epoch 4897, Loss: 0.6965188533067703, Final Batch Loss: 0.17931023240089417
Epoch 4898, Loss: 0.8298775553703308, Final Batch Loss: 0.24044086039066315
Epoch 4899, Loss: 0.6869909763336182, Final Batch Loss: 0.16733165085315704
Epoch 4900, Loss: 0.8044253438711166, Final Batch Loss: 0.18309861421585083
Epoch 4901, Loss: 0.8681163191795349, Final Batch Loss: 0.16053734719753265
Epoch 4902, Loss: 0.7759074717760086, Final Batch Loss: 0.29072487354278564
Epoch 4903, Loss: 0.7508832961320877, Final Batch Loss: 0.20749405026435852
Epoch 4904, Loss: 0.7851029485464096, Final Batch Loss: 0.19196248054504395
Epoch 4905, Loss: 0.7341313511133194, Final Batch Loss: 0.17210440337657928
Epoch 4906, Loss: 0.7424214035272598, Final Batch Loss: 0.21976277232170105
Epoch 4907, Loss: 0.9466035664081573, Final Batch Loss: 0.2737135589122772
Epoch 4908, Loss: 0.7988851219415665, Final Batch Loss: 0.215914785861969
Epoch 4909, Loss: 0.7831306755542755, Final Batch Loss: 0.18221580982208252
Epoch 4910, Los

In [34]:
softmax = nn.Softmax(dim = 1)
model.train()
for batch in test_loader:
    features, labels = batch
    _, preds = torch.max(softmax(model(features.float())), dim = 1)
    print(metrics.confusion_matrix((labels).cpu(), preds.cpu()))
    print(metrics.classification_report((labels).cpu(), preds.cpu(), digits = 5))

[[21  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  8  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0]
 [ 0  0  0 15  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  7  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0]
 [ 0  0  1  0  0  5  0  0  2  0  0  0  0  0  4  0  0  1  0  0  1]
 [ 0  0  0  0  0  0 10  1  0  0  1  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  8  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  4  0  0  0  0  0  1  0  0  2  0  0  1]
 [ 0  0  0  0  0  0  0  0  0 13  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0 11  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  5  0  0  0  0  0  0  0  0  3]
 [ 1  0  0  0  1  0  0  0  0  0  0  0  4  0  0  0  0  0  1  0  0]
 [ 0  1  0  0  0  0  0  0  0  0  0  0  0  6  0  0  0  0  0  0  0]
 [ 0  0  3  0  0  0  0  0  0  0  0  1  0  0  3  0  0  0  0  0  0]
 [ 0  0  0

# Train on Fake Test on Real

In [10]:
# 20,000 epochs DEFAULT PARAMETERS FOR THRESHOLDS

In [11]:
gen = Generator(z_dim = 110)
load_model(gen, "3 Label 7 Subject GAN Ablation_gen.param")
gen.eval()

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): Linear(in_features=110, out_features=80, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Linear(in_features=80, out_features=60, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Linear(in_features=60, out_features=50, bias=True)
      (1): Dropout(p=0.1, inplace=False)
      (2): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
    (3): Linear(in_features=50, out_features=33, bias=True)
    (4): Tanh()
  )
)

In [17]:
size = len(X_test)
latent_vectors = get_noise(size, 100)
act_vectors = get_act_matrix(size, 3)
usr_vectors = get_usr_matrix(size, 7)

fake_labels = []

for k in range(size):
    if act_vectors[0][k] == 0 and usr_vectors[0][k] == 0:
        fake_labels.append(0)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 0:
        fake_labels.append(1)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 0:
        fake_labels.append(2)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 1:
        fake_labels.append(3)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 1:
        fake_labels.append(4)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 1:
        fake_labels.append(5)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 2:
        fake_labels.append(6)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 2:
        fake_labels.append(7)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 2:
        fake_labels.append(8)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 3:
        fake_labels.append(9)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 3:
        fake_labels.append(10)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 3:
        fake_labels.append(11)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 4:
        fake_labels.append(12)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 4:
        fake_labels.append(13)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 4:
        fake_labels.append(14)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 5:
        fake_labels.append(15)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 5:
        fake_labels.append(16)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 5:
        fake_labels.append(17)
    elif act_vectors[0][k] == 0 and usr_vectors[0][k] == 6:
        fake_labels.append(18)
    elif act_vectors[0][k] == 1 and usr_vectors[0][k] == 6:
        fake_labels.append(19)
    elif act_vectors[0][k] == 2 and usr_vectors[0][k] == 6:
        fake_labels.append(20)
        
fake_labels = np.asarray(fake_labels)
to_gen = torch.cat((latent_vectors, act_vectors[1], usr_vectors[1]), 1)
fake_features = gen(to_gen).detach()
#fake_features = gen(to_gen).detach().numpy()

In [18]:
_, preds = torch.max(softmax(model(fake_features.float())), dim = 1)
print(metrics.confusion_matrix((fake_labels), preds.cpu()))
print(metrics.classification_report((fake_labels), preds.cpu(), digits = 5, zero_division = 0))

[[ 5  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0 17  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0 17  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0 11  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0 11  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0 12  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  9  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  8  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  9  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0 10  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  8  0  0  4  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0 11  0  0  0  0  0  0  0  0  0]
 [10  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  9  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  9  0  0  0  0  0  0]
 [ 0  0  0