In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move Working Directory

In [None]:
# `os` is not imported by the import cell above, so bring it in here;
# without it the chdir call below fails with a NameError.
import os

# Every relative path used later in the notebook is resolved against this
# directory. The path is machine-specific: edit it before running elsewhere.

# MaryClare's
#os.chdir('/Users/maryclaremartin/Documents/jup/ExtraSensory')

# Josh's
os.chdir("/Users/jdeoliveira/REU2021-human-context-recognition/ExtraSensory_data")

# The Generator

In [None]:
def generator_block(input_dim, output_dim):
    """Build one hidden stage of the generator.

    input_dim: number of features entering the stage
    output_dim: number of features leaving the stage

    Returns an nn.Sequential applying, in order: a Linear layer
    (input_dim -> output_dim), Dropout with p = 0.1, BatchNorm1d over the
    output_dim features, and an in-place ReLU.
    """
    stage_layers = [
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace = True),
    ]
    return nn.Sequential(*stage_layers)

def get_noise(n_samples, z_dim):
    """Draw standard-normal latent vectors.

    n_samples: number of rows to draw
    z_dim: number of dimensions of the latent space (columns)

    Returns a (n_samples, z_dim) tensor placed on the notebook-level
    `device`. The draw is not seeded here.
    """
    latent = torch.randn(n_samples, z_dim)
    return latent.to(device)

class Generator(nn.Module):
    """Fully connected generator mapping latent rows to feature rows.

    z_dim: width of the input rows
    feature_dim: width of the generated rows
    hidden_dim: base width; the first two stages use hidden_dim/2 and
        hidden_dim/4 units, followed by fixed stages of 30 and 28 units
        and a final Linear layer with no activation.
    """
    def __init__(self, z_dim = 10, feature_dim = 26, hidden_dim = 128):
        super().__init__()
        half_width = int(hidden_dim/2)
        quarter_width = int(hidden_dim/4)
        stages = [
            generator_block(z_dim, half_width),
            generator_block(half_width, quarter_width),
            generator_block(quarter_width, 30),
            generator_block(30, 28),
            nn.Linear(28, feature_dim),
        ]
        self.gen = nn.Sequential(*stages)

    def forward(self, noise):
        """Run `noise` (rows of width z_dim) through the stack."""
        generated = self.gen(noise)
        return generated

def get_gen_loss(gen, disc, act, usr, criterion1, criterion2, batch_size, z_dim, activities, users):
    """Compute the generator loss for one freshly sampled fake batch.

    gen: generator; receives rows of width z_dim + activities + users
    disc: discriminator scoring the generated rows
    act: activity classifier applied to the generated rows
    usr: user classifier applied to the generated rows
    criterion1: adversarial loss, compared against an all-ones target
    criterion2: classification loss, called as criterion2(logits, class_indices)
    batch_size: number of fake rows to generate
    z_dim: number of dimensions in the latent space
    activities: number of activity classes (width of the activity one-hot)
    users: number of user classes (width of the user one-hot)

    Returns criterion1(disc(fake), ones) + criterion2(act(fake), activity ids)
    + criterion2(usr(fake), user ids).
    """
    latent_vectors = get_noise(batch_size, z_dim)
    # Everything fed to the networks has to live where the noise lives.
    target_device = latent_vectors.device

    act_labels, act_one_hot = get_act_matrix(batch_size, activities)
    usr_labels, usr_one_hot = get_usr_matrix(batch_size, users)
    act_one_hot = act_one_hot.to(target_device)
    usr_one_hot = usr_one_hot.to(target_device)
    # Class-index targets must be int64 for losses such as nn.CrossEntropyLoss
    # (the default criterion2 in training_loop).
    act_labels = act_labels.long().to(target_device)
    usr_labels = usr_labels.long().to(target_device)

    to_gen = torch.cat((latent_vectors, act_one_hot, usr_one_hot), 1)
    fake_features = gen(to_gen)

    pred_disc = disc(fake_features)
    pred_act = act(fake_features)
    pred_usr = usr(fake_features)

    # The generator wants the discriminator to answer "real" (1) for fakes.
    # The original referenced an undefined name `pred` here.
    adversarial_loss = criterion1(pred_disc, torch.ones_like(pred_disc))
    activity_loss = criterion2(pred_act, act_labels)
    user_loss = criterion2(pred_usr, usr_labels)

    gen_loss = adversarial_loss + activity_loss + user_loss
    return gen_loss

def get_act_matrix(batch_size, a_dim):
    """Sample random activity conditions for a batch.

    batch_size: number of rows to sample
    a_dim: number of activity classes

    Returns a pair (indexes, one_hot):
      indexes -- int64 tensor of shape (batch_size,) with values in [0, a_dim)
      one_hot -- float32 tensor of shape (batch_size, a_dim) with a single 1
                 per row at the sampled index

    The one-hot width is always a_dim (not the largest index drawn), so the
    generator input width stays constant from batch to batch. The indexes are
    int64 so they can be used directly as class-index loss targets.
    """
    indexes = np.random.randint(a_dim, size = batch_size)

    one_hot = np.zeros((batch_size, a_dim), dtype = np.float32)
    # Plain (non-chained) assignment: `one_hot = one_hot[...] = 1` would first
    # rebind one_hot to the int 1 and then fail on the item assignment.
    one_hot[np.arange(batch_size), indexes] = 1

    return torch.from_numpy(indexes).long(), torch.from_numpy(one_hot)
    
def get_usr_matrix(batch_size, u_dim):
    """Sample random user conditions for a batch.

    batch_size: number of rows to sample
    u_dim: number of user classes

    Returns a pair (indexes, one_hot):
      indexes -- int64 tensor of shape (batch_size,) with values in [0, u_dim)
      one_hot -- float32 tensor of shape (batch_size, u_dim) with a single 1
                 per row at the sampled index

    The one-hot width is always u_dim (not the largest index drawn), so the
    generator input width stays constant from batch to batch. The indexes are
    int64 so they can be used directly as class-index loss targets.
    """
    indexes = np.random.randint(u_dim, size = batch_size)

    one_hot = np.zeros((batch_size, u_dim), dtype = np.float32)
    # Plain (non-chained) assignment: `one_hot = one_hot[...] = 1` would first
    # rebind one_hot to the int 1 and then fail on the item assignment.
    one_hot[np.arange(batch_size), indexes] = 1

    return torch.from_numpy(indexes).long(), torch.from_numpy(one_hot)

# Create Fake Generated Samples

In [None]:
def get_fake_samples(gen, batch_size, z_dim):
    """
    Generates fake acceleration features from a generator.

    gen: callable applied to a (batch_size, z_dim) noise tensor
    batch_size: number of rows to generate
    z_dim: width of the noise rows

    Returns whatever `gen` produces for that noise: one generated row per
    noise row, with the generator's output width (not z_dim).
    NOTE(review): the noise is z_dim wide with no condition columns appended,
    so `gen` must accept exactly that width -- confirm against the caller.
    """
    noise = get_noise(batch_size, z_dim).to(device)
    return gen(noise)

# The Discriminator

In [None]:
def discriminator_block(input_dim, output_dim):
    """Build one hidden stage of the discriminator.

    input_dim: number of features entering the stage
    output_dim: number of features leaving the stage

    Returns an nn.Sequential of Linear(input_dim -> output_dim),
    Dropout(p = 0.1) and LeakyReLU with negative slope 0.05.
    """
    stage_layers = [
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.LeakyReLU(0.05),
    ]
    return nn.Sequential(*stage_layers)

class Discriminator(nn.Module):
    """Fully connected real-vs-fake scorer.

    feature_dim: width of the incoming feature rows
    hidden_dim: width of the first hidden stage; the next two stages use
        hidden_dim/2 and hidden_dim/4 units. A final Linear layer maps to a
        single value that is squashed by a Sigmoid.
    """
    def __init__(self, feature_dim = 26, hidden_dim = 16):
        super().__init__()
        half_width = int(hidden_dim/2)
        quarter_width = int(hidden_dim/4)
        stages = [
            discriminator_block(feature_dim, hidden_dim),
            discriminator_block(hidden_dim, half_width),
            discriminator_block(half_width, quarter_width),
            nn.Linear(quarter_width, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*stages)

    def forward(self, feature_vector):
        """Return one sigmoid score per input row."""
        score = self.disc(feature_vector)
        return score
    
def get_disc_loss(gen, disc, criterion, real_features, batch_size, z_dim, a_dim, u_dim):
    """Compute the discriminator loss on one fake batch and one real batch.

    gen: generator; receives rows of width z_dim + a_dim + u_dim
    disc: discriminator being trained
    criterion: loss comparing discriminator scores with 0 (fake) / 1 (real)
    real_features: batch of real feature rows
    batch_size: number of fake rows to generate
    z_dim: number of dimensions in the latent space
    a_dim: number of activity classes (width of the activity one-hot)
    u_dim: number of user classes (width of the user one-hot)

    Returns the mean of the fake-batch loss and the real-batch loss.
    """
    latent_vectors = get_noise(batch_size, z_dim)
    # Keep every tensor on the device the noise was created on, otherwise
    # torch.cat / the forward passes fail when running on CUDA.
    target_device = latent_vectors.device
    act_one_hot = get_act_matrix(batch_size, a_dim)[1].to(target_device)
    usr_one_hot = get_usr_matrix(batch_size, u_dim)[1].to(target_device)

    to_gen = torch.cat((latent_vectors, act_one_hot, usr_one_hot), 1)
    # Only the discriminator is updated from this loss, so the generator
    # forward pass does not need an autograd graph.
    with torch.no_grad():
        fake_features = gen(to_gen)
    pred_fake = disc(fake_features.detach())

    ground_truth = torch.zeros_like(pred_fake)
    loss_fake = criterion(pred_fake, ground_truth)

    pred_real = disc(real_features.to(target_device))
    ground_truth = torch.ones_like(pred_real)
    loss_real = criterion(pred_real, ground_truth)

    disc_loss = (loss_fake + loss_real) / 2
    return disc_loss


# User Classifier

In [None]:
def classifier_block(input_dim, output_dim):
    """Build one hidden stage of a classifier.

    Returns an nn.Sequential of Linear(input_dim -> output_dim),
    Dropout(p = 0.1) and LeakyReLU with negative slope 0.05.
    """
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.LeakyReLU(0.05)
    )

class UserClassifier(nn.Module):
    """User classifier: feature rows -> one logit per user class.

    feature_dim: width of the incoming feature rows
    n_classes: number of output logits (defaults to the previously
        hard-coded 3)

    Hidden stages are 20, 15 and 10 units wide; the output layer has no
    activation, so forward() returns raw logits.
    """
    def __init__(self, feature_dim = 26, n_classes = 3):
        super().__init__()
        self.network = nn.Sequential(
            classifier_block(feature_dim, 20),
            classifier_block(20, 15),
            classifier_block(15, 10),
            nn.Linear(10, n_classes)
        )
    def forward(self, x):
        return self.network(x)

# Backward-compatible alias. The activity-classifier cell defines a class
# under the same `Classifier` name, which used to make this architecture
# unreachable once that cell had run; `UserClassifier` stays available.
Classifier = UserClassifier

# Activity Classifier

In [None]:
def classifier_block(input_dim, output_dim):
    """Build one hidden stage of a classifier.

    Returns an nn.Sequential of Linear(input_dim -> output_dim),
    Dropout(p = 0.1) and LeakyReLU with negative slope 0.05.
    """
    return nn.Sequential(
        nn.Linear(input_dim, output_dim),
        nn.Dropout(0.1),
        nn.LeakyReLU(0.05)
    )

class ActivityClassifier(nn.Module):
    """Activity classifier: feature rows -> one logit per activity class.

    feature_dim: width of the incoming feature rows
    n_classes: number of output logits (defaults to the previously
        hard-coded 3)

    Hidden stages are 20, 15, 10 and 5 units wide; no softmax is applied,
    so forward() returns raw logits.
    """
    def __init__(self, feature_dim = 26, n_classes = 3):
        super().__init__()
        self.network = nn.Sequential(
            classifier_block(feature_dim, 20),
            classifier_block(20, 15),
            classifier_block(15, 10),
            classifier_block(10, 5),
            nn.Linear(5, n_classes)
        )
    def forward(self, x):
        return self.network(x)

# Backward-compatible alias. This cell has always rebound the shared
# `Classifier` name; the distinct class name keeps this architecture
# addressable regardless of which classifier cell ran last.
Classifier = ActivityClassifier

# Interpolating acceleration columns with average values

In [None]:
def interpolation(df):
    """Fill NaN values in a feature/label data frame.

    df: data frame whose column names are strings

    Columns whose name starts with 'discrete' or 'label' have NaN replaced
    by 0; every other column has NaN replaced by that column's mean.

    Returns a new frame with the mean-filled columns first, followed by the
    zero-filled columns (each group keeps its original relative order).
    """
    zero_fill_prefixes = ('discrete', 'label')

    mean_fill_cols = [name for name in df.columns if not name.startswith(zero_fill_prefixes)]
    zero_fill_cols = [name for name in df.columns if name.startswith(zero_fill_prefixes)]

    continuous_part = df[mean_fill_cols]
    mean_filled = continuous_part.fillna(continuous_part.mean())
    zero_filled = df[zero_fill_cols].fillna(0)

    return pd.concat([mean_filled, zero_filled], axis = 1)

# Visualize Batches

In [None]:
def visualize_gen_batch(gen, b_size, epochs = -1):
    """Show a heat map of one generated batch.

    gen: generator
    b_size: batch size (number of rows drawn)
    epochs: epoch number shown in the plot title (default -1)

    NOTE(review): the latent width is read from a notebook-level `z_dim`
    variable, not from a parameter -- it must be defined before calling.
    """
    generated = gen(get_noise(b_size, z_dim))

    # Colour scale spans the full range of the generated values.
    lowest = torch.min(generated)
    highest = torch.max(generated)
    heat = generated.cpu().detach().numpy()

    image = plt.imshow(heat, cmap ='Reds', vmin = lowest, vmax = highest,
                       interpolation ='nearest', origin ='upper')
    plt.colorbar(image)
    plt.title('Generated Batch at Epoch ' + str(epochs), fontweight ="bold")
    plt.show()

def visualize_real_batch(features):
    """Show a heat map of one batch of real data.

    features: 2D tensor of real feature rows
    """
    # Colour scale spans the full range of the batch values.
    lowest = torch.min(features)
    highest = torch.max(features)
    heat = features.cpu().detach().numpy()

    image = plt.imshow(heat, cmap ='Reds', vmin = lowest, vmax = highest,
                       interpolation ='nearest', origin ='upper')
    plt.colorbar(image)
    plt.title('Real Batch of Data', fontweight ="bold")
    plt.show()

# Calculate Performance Statistics

In [None]:
def performance_stats(gen, disc, b_size, z_dim, batch = None):
    """Score the discriminator on generated (and optionally real) rows.

    gen: generator, called on a (rows, z_dim) noise tensor
    disc: discriminator; its rounded output is the predicted label
          (1 = real, 0 = fake)
    b_size: batch size
    z_dim: number of dimensions of the latent space
    batch: optional tensor of real rows. When None, b_size fake rows are
           scored. Otherwise b_size/2 fake rows and the first b_size/2 rows
           of `batch` are scored together.

    Returns (accuracy, precision, recall, fpR, f1). Ratios whose denominator
    is zero are reported as 0. f1 keeps the 0.001 term in its denominator.

    Raises ValueError when a fake row receives a prediction that is neither
    0 nor 1 (for example NaN). This used to call exit(), which shuts the
    notebook kernel down.
    """
    with torch.no_grad():
        if batch is None:
            fake_features = gen(get_noise(b_size, z_dim))
            y_hat = torch.round(disc(fake_features))
            y_label = torch.zeros(b_size)
        else:
            half = int(b_size/2)
            fake_features = gen(get_noise(half, z_dim).to(device))
            fake_y_hat = torch.round(disc(fake_features.to(device)))
            real_y_hat = torch.round(disc(batch[:half].to(device)))
            y_hat = torch.cat((fake_y_hat, real_y_hat), dim = 0)
            y_label = torch.cat((torch.zeros(half), torch.ones(half)), dim = 0)

        # One prediction per row; trim the labels when `batch` supplied fewer
        # than b_size/2 real rows, and compare on a single device.
        y_hat = y_hat.reshape(-1)
        y_label = y_label[:len(y_hat)].to(y_hat.device)

        is_real = y_label == 1
        said_real = y_hat == 1
        said_fake = y_hat == 0

        # A fake row must be classified as exactly 0 or 1.
        invalid = ~is_real & ~said_real & ~said_fake
        if bool(invalid.any()):
            raise ValueError("discriminator produced a prediction that is neither 0 nor 1")

        tp = int((is_real & said_real).sum())     #true positive
        fn = int((is_real & ~said_real).sum())    #false negative
        tn = int((~is_real & said_fake).sum())    #true negative
        fp = int((~is_real & said_real).sum())    #false positive

        total = tp + tn + fp + fn
        class_acc = (tp + tn) / total if total else 0

        #What percentage of the model's positive predictions were actually positive
        precision = tp / (tp + fp) if tp + fp else 0
        #What percent of the true positives were identified
        recall = tp / (tp + fn) if tp + fn else 0
        fpR = fp / (fp + tn) if fp + tn else 0

        f1 = 2*(precision * recall / (precision + recall + 0.001))
        return class_acc, precision, recall, fpR, f1

# Create Density Curves

In [1]:
def density_curves(reals, fakes):
    """Plot real vs. generated density curves for four feature columns.

    reals: real data, indexable as reals[:, column]
    fakes: generated data as a tensor (moved to the CPU and converted to
           numpy here, so it must not require grad)

    Draws a 2x2 grid comparing columns 0, 18, 19 and 20, labelled as mean,
    mean X-, mean Y- and mean Z-acceleration, then shows the figure.
    """
    fake_values = fakes.cpu().numpy()

    # (column index, x-axis label) for each panel, in grid order.
    panels = (
        (0, 'Mean Acceleration'),
        (18, 'Mean X-Acceleration'),
        (19, 'Mean Y-Acceleration'),
        (20, 'Mean Z-Acceleration'),
    )

    plt.figure(figsize = (15, 15))
    for position, (column, x_label) in enumerate(panels, start = 1):
        # The original called a bare `subplot`, which is not defined.
        plt.subplot(2, 2, position)
        sns.kdeplot(fake_values[:, column], color = 'r', shade = True, label = 'Fake Distribution')
        sns.kdeplot(reals[:, column], color = 'b', shade = True, label = 'Real Distribution')
        plt.xlabel(x_label)
        plt.ylabel('Density')
        plt.legend()
    plt.show()

# Calculate Wasserstein distance for each dimension

In [None]:
def all_Wasserstein_dists(gen, z_dim, feature_dim, sample):
    """Calculate the Wasserstein distance for each feature dimension.

    gen: generator, called on a (len(sample), z_dim) noise tensor
    z_dim: number of dimensions of the latent space
    feature_dim: number of feature columns to compare
    sample: tensor of real rows

    Returns a 1D tensor of length feature_dim whose k-th entry is the
    distance between column k of the generated rows and column k of `sample`.
    """
    # Imported here because the notebook's import cell does not provide it.
    from scipy.stats import wasserstein_distance

    latent_vectors = get_noise(len(sample), z_dim)
    # Convert both sides to numpy once instead of once per column.
    fake_values = gen(latent_vectors.to(device)).cpu().detach().numpy()
    real_values = sample.cpu().detach().numpy()

    wasser_dim = [
        wasserstein_distance(fake_values[:, k], real_values[:, k])
        for k in range(feature_dim)
    ]
    return torch.tensor(wasser_dim)

# Visualizing Generation Quality

In [None]:
def visualize_gen(data, gen, z_dim):
    """Plot density curves of real data against generated data.

    data: real data; every row is used
    gen: generator
    z_dim: number of dimensions of the latent space

    Generates as many fake rows as there are real rows and hands both sets
    to density_curves.
    """
    n_rows = len(data)
    real_rows = data[:n_rows, :]
    generated_rows = get_fake_samples(gen, n_rows, z_dim).detach()
    density_curves(real_rows, generated_rows)

# Initialize Training Environment

In [None]:
def initialize_params(X, y, z_dim, a_dim, u_dim, disc_lr, gen_lr, DISCRIMINATOR, batch_size, disc):
    """Initialize everything that depends on the training-loop parameters.

    X: acceleration data (converted with torch.tensor)
    y: labels associated with X (converted with torch.tensor)
    z_dim: number of dimensions of the latent space
    a_dim: number of activity classes appended to the generator input
    u_dim: number of user classes appended to the generator input
    disc_lr: discriminator learning rate
    gen_lr: generator learning rate
    DISCRIMINATOR: value that marks the discriminator as the machine to train
    batch_size: batch size
    disc: initialized discriminator

    a_dim and u_dim are explicit parameters because training_loop passes them
    positionally and the generator width below depends on them; previously
    they were missing from the signature, so that call raised a TypeError.

    Returns (gen, to_train, train_features, train_labels, train_data,
    train_loader, opt_disc, opt_gen).
    """
    #initialize generator; its input is noise plus both one-hot conditions
    gen = Generator(z_dim + a_dim + u_dim).to(device)
    #indicate that discriminator is training
    to_train = DISCRIMINATOR
    #create training features
    train_features = torch.tensor(X)
    #create training labels
    train_labels = torch.tensor(y)
    #pair features with labels to create training data
    train_data = torch.utils.data.TensorDataset(train_features, train_labels)
    #create data loader for training data
    train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True)
    #initialize generator and discriminator optimizers
    opt_disc = optim.Adam(disc.parameters(), lr = disc_lr)
    opt_gen = optim.Adam(gen.parameters(), lr = gen_lr)

    return gen, to_train, train_features, train_labels, train_data, train_loader, opt_disc, opt_gen

# Save / Load Models

In [None]:
def save_model(gen, disc, model_name):
    """Save the generator and discriminator weights.

    gen: generator module
    disc: discriminator module
    model_name: file-name stem

    Writes the two state dicts to saved_models/<model_name>_gen and
    saved_models/<model_name>_disc, relative to the current working
    directory. The saved_models directory is created when it is missing.
    """
    import os

    os.makedirs("saved_models", exist_ok = True)
    torch.save(gen.state_dict(), f"saved_models/{model_name}_gen")
    torch.save(disc.state_dict(), f"saved_models/{model_name}_disc")
    
def load_model(model, model_name):
    """Load saved weights into an existing model, in place.

    model: module whose architecture matches the saved state dict
    model_name: file name under saved_models/ (for files written by
        save_model this includes the _gen or _disc suffix)

    The checkpoint is first mapped to the CPU so that weights saved on a GPU
    machine can be loaded on a CPU-only one; load_state_dict then copies the
    values into the model's own parameters.
    """
    state = torch.load(f'saved_models/{model_name}', map_location = "cpu")
    model.load_state_dict(state)

# Training

In [None]:
######Training loop to train GAN

#Parameters to specify:
    #X: starting accelerometer data
    #y: starting labels for X data (fake or real)
    
#Set parameters (do not change)
    #criterion: loss function (BCE)
    #dig: number of significant digits for printing (5)
    #feature_dim: Number of dimensions of output from generator (26)
    #GENERATOR: set generator to zero for training
    #DISCRIMINATOR: set discriminator to one for training
    #train_string: starting machine to train (DISC)
    #disc: initalize discriminator
    #rel_epochs: Epochs passed since last switch (constant training) (0)
    #rows: initialization of array to save data of each epoch to CSV file ([])
    #heading: array of column headings for table (["Epoch", "Machine Training", "Discriminator Loss", 
                    #"Generator Loss", "FPR", "Recall", "Median Wasserstein", "Mean Wasserstein"])
    #table: intialize a table as a pretty table to save epoch data
    #switch_count: number of switches in dynamic training (0)
    
#Set parameters (can change):
    #z_dim: number of dimensions of latent vector (100)
    #gen_lr: generator learning rate (.001)
    #disc_lr: discriminator learning rate (.001) (should be equal to gen_lr)
    #batch_size: batch size (75)
    #print_batches: Show model performance per batch (False)
    #n_epochs: number of epochs to train (100)
    #constant_train_flag: (False)
        #Set to true to train based on constant # of epochs per machine 
        #Set to false to train dynamically based on machine performance
        
    #Constant training approach:
        #disc_epochs: Number of consecutive epochs to train discriminator before epoch threshold (5)
        #gen_epochs: Number of consecutive epochs to train generator before epoch threshold (2)
        #epoch_threshold: Epoch number to change training epoch ratio (50)
        #disc_epochs_change: New number of consecutive epochs to train discriminator after epoch threshold is exceeded (1)
        #gen_epochs_change: New number of consecutive epochs to train generator after epoch threshold is exceeded (50)
    
    #Dynamic training approach:                        
        #static_threshold: Epoch number to change from static ratio to dynamic (18)
        #static_disc_epochs: Number of consecutive epochs to train discriminator before epoch threshold (4)
        #static_gen_epochs: Number of consecutive epochs to train generator before epoch threshold (2)
        #pull_threshold: Accuracy threshold for switching machine training when the generator is no longer competitive (0.4)
        #push_threshold: Accuracy threshold for switching machine training when the discriminator is no longer competitive (0.6)
        #recall_threshold: threshold for recall to switch machine training when discriminator is training well
        #switch_flag: indicates if we should switch our training machine (False)
        
def training_loop(X, y, act, usr, criterion1 = nn.BCELoss(), criterion2 = nn.CrossEntropyLoss(), gan_id = "Mod Test Gan", dig = 5, feature_dim = 26,
                  GENERATOR = 0, DISCRIMINATOR = 1, train_string = "DISC", disc = None, z_dim = 100, a_dim = 3, u_dim = 3,
                  gen_lr =  0.001, disc_lr = 0.001, batch_size = 100, constant_train_flag = False, disc_epochs = 5,
                  gen_epochs = 2, epoch_threshold = 50, disc_epochs_change = 5, gen_epochs_change = 2, rel_epochs = 0,
                  static_threshold = 18, static_disc_epochs = 5, static_gen_epochs = 2, pull_threshold = 0.2,
                  push_threshold = 0.8, recall_threshold = 0.75, print_batches = False, n_epochs = 1000, rows = None,
                  heading = ["Epoch", "Machine Training", "Discriminator Loss", "Generator Loss", "Accuracy", "FPR", "Precision", "Recall", "F1", "Median Wasserstein", "Mean Wasserstein"],
                  table = None, switch_flag = False, switch_count = 0, last_real_features = None):
    """Train the conditional GAN, alternating between discriminator and generator.

    X, y: training features and their labels (handed to initialize_params)
    act, usr: activity and user classifiers used in the generator loss
    criterion1: adversarial loss; criterion2: classification loss
    gan_id: stem of every output file (model_outputs/<gan_id>.txt/.csv and
        the saved model files)
    dig: number of decimal digits in printed/saved metrics
    feature_dim: number of feature columns compared by the Wasserstein metric
        and the input width of the default discriminator
    GENERATOR / DISCRIMINATOR: markers for the machine currently training
    train_string: label printed for the machine currently training
    disc: discriminator to train; None creates a fresh Discriminator
    z_dim, a_dim, u_dim: latent width and number of activity / user classes
    gen_lr, disc_lr, batch_size, n_epochs: usual optimisation settings
    constant_train_flag: True switches machines after a fixed number of
        epochs (disc_epochs / gen_epochs, replaced by disc_epochs_change /
        gen_epochs_change at epoch_threshold). False uses the static ratio
        (static_disc_epochs / static_gen_epochs) until static_threshold and
        then switches on pull_threshold / push_threshold / recall_threshold.
    rel_epochs: epochs since the last machine switch
    print_batches: print metrics per batch instead of per epoch (per-epoch
        rows are only recorded when this is False)
    rows, heading, table: per-epoch records, their column names and the
        PrettyTable receiving them; None creates fresh ones
    switch_flag, switch_count: dynamic-training bookkeeping
    last_real_features: accepted for compatibility; not used

    Returns the per-epoch metrics read back from model_outputs/<gan_id>.csv
    as a pandas DataFrame.
    """
    import csv
    import os

    # Defaults that used to be created once at definition time (and were then
    # shared, already trained / already filled, by every later call).
    if disc is None:
        disc = Discriminator(feature_dim = feature_dim)
    if rows is None:
        rows = []
    if table is None:
        table = PrettyTable()
    heading = list(heading)

    disc.to(device)
    # The classifiers take part in the generator loss, so they must sit on
    # the same device as the generated rows.
    for helper_net in (act, usr):
        if isinstance(helper_net, nn.Module):
            helper_net.to(device)

    #returns generator, sets discriminator training, creates training tensor, loads data, and initializes optimizers
    gen, to_train, train_features, train_labels, train_data, train_loader, opt_disc, opt_gen = initialize_params(X, y, z_dim, a_dim, u_dim, disc_lr, gen_lr, DISCRIMINATOR, batch_size, disc)

    def conditioned_gen(latent_vectors):
        # The generator expects z_dim + a_dim + u_dim inputs, while the
        # metric/plot helpers only supply z_dim noise; append random one-hot
        # conditions so those helpers receive rows of the right width.
        n_rows = latent_vectors.shape[0]
        act_one_hot = get_act_matrix(n_rows, a_dim)[1].to(latent_vectors.device)
        usr_one_hot = get_usr_matrix(n_rows, u_dim)[1].to(latent_vectors.device)
        return gen(torch.cat((latent_vectors, act_one_hot, usr_one_hot), 1))

    #set pretty table field names
    table.field_names = heading

    visualize_gen(X, conditioned_gen, z_dim)

    # Separate counter: reusing `gen_epochs` for this overwrote the
    # scheduling parameter of the same name with 0.
    gen_epochs_trained = 0

    last_D_loss = -1.0
    last_G_loss = -1.0

    # Defined up front so the dynamic schedule and the epoch summary never
    # read them before the first batch has been evaluated.
    acc = P = R = fpR = F1 = 0.0

    for epoch in range(n_epochs):
        if constant_train_flag:
            if to_train == DISCRIMINATOR and rel_epochs >= disc_epochs:
                rel_epochs = 0
                to_train = GENERATOR
                train_string = "GEN"

            elif to_train == GENERATOR and rel_epochs >= gen_epochs:
                rel_epochs = 0
                to_train = DISCRIMINATOR
                train_string = "DISC"

            # Change epoch ratio after initial 'leveling out'
            if epoch == epoch_threshold:
                rel_epochs = 0
                to_train = GENERATOR
                train_string = "GENERATOR"

                old_ratio = gen_epochs / disc_epochs
                gen_epochs = gen_epochs_change
                disc_epochs = disc_epochs_change
                new_ratio = gen_epochs / disc_epochs
                print(f'\n\nTraining ratio of G/D switched from {old_ratio:.{dig}f} to {new_ratio:.{dig}f}\n\n')
        else:
            if epoch < static_threshold:
                if to_train == DISCRIMINATOR and rel_epochs >= static_disc_epochs:
                    rel_epochs = 0
                    to_train = GENERATOR
                    train_string = "GEN"

                elif to_train == GENERATOR and rel_epochs >= static_gen_epochs:
                    rel_epochs = 0
                    to_train = DISCRIMINATOR
                    train_string = "DISC"

            else:
                if not switch_flag:
                    print("\nSwitching to Dynamic Training\n")
                    switch_flag = True
                if to_train == DISCRIMINATOR and fpR <= pull_threshold and R >= recall_threshold:
                    to_train = GENERATOR
                    train_string = "GEN"
                    print("\nPull Generator\n")
                    switch_count += 1
                if to_train == GENERATOR and fpR >= push_threshold:
                    to_train = DISCRIMINATOR
                    train_string = "DISC"
                    print("\nPush Generator\n")
                    switch_count += 1
        print(f'Epoch [{epoch + 1}/{n_epochs}] Training: {train_string} ', end ='')

        # Per-batch Wasserstein summaries, averaged once per epoch.
        mean_mean = []
        mean_median = []

        for batch_idx, (real_features, _) in enumerate(train_loader):
            real_features = real_features.float().to(device)

            if print_batches:
                print(f'\n\tBatch [{batch_idx + 1}/{len(train_loader)}] |', end ='')

            if to_train == DISCRIMINATOR:
                ### Training Discriminator
                opt_disc.zero_grad()
                disc_loss = get_disc_loss(gen, disc, criterion1, real_features, batch_size, z_dim, a_dim, u_dim)
                disc_loss.backward()
                opt_disc.step()
                last_D_loss = disc_loss.item()

                # No generator step has run yet: evaluate its loss once so
                # the report has a number (stored as a float, not a tensor).
                if last_G_loss == -1.0:
                    with torch.no_grad():
                        last_G_loss = get_gen_loss(gen, disc, act, usr, criterion1, criterion2, batch_size, z_dim, a_dim, u_dim).item()
            else:
                ### Training Generator
                opt_gen.zero_grad()
                gen_loss = get_gen_loss(gen, disc, act, usr, criterion1, criterion2, batch_size, z_dim, a_dim, u_dim)
                gen_loss.backward()
                opt_gen.step()
                last_G_loss = gen_loss.item()

                if last_D_loss == -1.0:
                    with torch.no_grad():
                        last_D_loss = get_disc_loss(gen, disc, criterion1, real_features, batch_size, z_dim, a_dim, u_dim).item()

            acc, P, R, fpR, F1 = performance_stats(conditioned_gen, disc, batch_size, z_dim, batch = real_features)
            w_dist = all_Wasserstein_dists(conditioned_gen, z_dim, feature_dim, real_features)
            median_w_dist = torch.median(w_dist)
            mean_w_dist = torch.mean(w_dist)

            mean_mean.append(mean_w_dist)
            mean_median.append(median_w_dist)

            if print_batches:
                print(f'Loss D: {last_D_loss:.{dig}f}, Loss G: {last_G_loss:.{dig}f} | Acc: {acc:.{dig}f} | fpR: {fpR:.{dig}f} P: {P:.{dig}f} | R: {R:.{dig}f} | F1: {F1:.{dig}f} | Median Wasserstein: {median_w_dist:.{dig}f} | Mean Wasserstein: {mean_w_dist:.{dig}f}')

        if not print_batches:
            mean_mean_w = torch.mean(torch.Tensor(mean_mean))
            mean_median_w = torch.mean(torch.Tensor(mean_median))

            machine = "Discriminator" if to_train == DISCRIMINATOR else "Generator"
            print(f'| Loss D: {last_D_loss:.{dig}f}, Loss G: {last_G_loss:.{dig}f} | Acc: {acc:.{dig}f} | fpR: {fpR:.{dig}f} | P: {P:.{dig}f} | R: {R:.{dig}f} | F1: {F1:.{dig}f} | Median W: {mean_median_w:.{dig}f} | Mean W: {mean_mean_w:.{dig}f}')
            row_to_add = [f"{epoch + 1}", machine, f"{last_D_loss:.{dig}f}", f"{last_G_loss:.{dig}f}", f"{acc:.{dig}f}", f"{fpR:.{dig}f}", f"{P:.{dig}f}", f"{R:.{dig}f}", f"{F1:.{dig}f}", f"{mean_median_w:.{dig}f}", f"{mean_mean_w:.{dig}f}"]
            table.add_row(row_to_add)
            rows.append(row_to_add)

        # Count generator epochs regardless of the printing mode.
        if to_train != DISCRIMINATOR:
            gen_epochs_trained += 1
        rel_epochs += 1

    print("\n\nTraining Session Finished")
    print(f"Encountered {switch_count} non-trivial training swaps")
    percent = gen_epochs_trained / n_epochs if n_epochs else 0.0
    print(f"Trained Generator {gen_epochs_trained} out of {n_epochs} ({percent:.3f})")

    os.makedirs("model_outputs", exist_ok = True)
    with open("model_outputs/" + gan_id + ".txt", "w") as txtfile:
        txtfile.write(table.get_string())
    print("Model Results Sucessfully Saved to \"model_outputs/" + gan_id + ".txt\"")

    # newline = "" is what the csv module expects for files it writes.
    with open("model_outputs/" + gan_id + ".csv", "w", newline = "") as csvfile:
        # creating a csv writer object
        csvwriter = csv.writer(csvfile)
        # writing the fields
        csvwriter.writerow(heading)
        # writing the data rows
        csvwriter.writerows(rows)
    print("Model Results Sucessfully Saved to \"model_outputs/" + gan_id + ".csv\"")

    # Save before plotting so a plotting failure cannot lose the weights.
    save_model(gen, disc, gan_id)
    model_output = pd.read_csv("model_outputs/" + gan_id + ".csv")
    visualize_gen(X, conditioned_gen, z_dim)

    return model_output

# Plot Metrics

In [None]:
def plot_metrics(data, vanilla = True):
    """Plot per-epoch training metrics.

    data: data frame with an 'Epoch' column plus 'Median Wasserstein' and
          'Mean Wasserstein' (and 'FPR' and 'Recall' when vanilla is True),
          e.g. the frame returned by training_loop
    vanilla: True draws a 2x2 grid (FPR and Recall scatter plots, then both
             Wasserstein regression plots); False draws only the two
             Wasserstein regression plots side by side.
    """
    sns.set(style = 'whitegrid', context = 'talk', palette = 'rainbow')

    # (plot kind, column) per panel, in grid order.
    wasserstein_panels = [('reg', 'Median Wasserstein'), ('reg', 'Mean Wasserstein')]
    if vanilla:
        figure_size, n_rows = (15, 15), 2
        panels = [('scatter', 'FPR'), ('scatter', 'Recall')] + wasserstein_panels
    else:
        figure_size, n_rows = (15, 8), 1
        panels = wasserstein_panels

    plt.figure(figsize = figure_size)
    for position, (kind, column) in enumerate(panels, start = 1):
        # The original called a bare `subplot`, which is not defined.
        plt.subplot(n_rows, 2, position)
        if kind == 'scatter':
            sns.scatterplot(x = 'Epoch', y = column, data = data).set(xlim = (0, None))
        else:
            sns.regplot(x = 'Epoch', y = column, data = data, line_kws = {'color': 'orange'}).set(xlim = (0, None))
        sns.despine()
    plt.show()

# Run Training

In [None]:
# Load one participant's raw feature file with SITTING as the target label.
# Alternative input (aggregated data):
#X, y = start_data("aggregated_data/aggregated_data.csv", "label:SITTING")
X, y = start_data(
    "raw_data/0A986513-7828-4D53-AA1F-E02D6DF9561B.features_labels.csv",
    "label:SITTING",
)

# Short constant-schedule run; outputs are written under the id "10".
model_output = training_loop(
    X,
    y,
    activity_classifier,
    user_classifier,
    gan_id = "10",
    batch_size = 200,
    gen_lr = .0001,
    disc_lr = .0001,
    n_epochs = 10,
    dig = 5,
    constant_train_flag = True,
)
plot_metrics(model_output, True)