# Training Notebook

##### Import Packages

In [77]:
import pandas as pd

#import GPUtil

import argparse
import json
import logging
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
#import torchvision
#import torchvision.models
#import torchvision.transforms as transforms

from typing import List

# Notebook-wide logger. Its threshold is DEBUG, so the logger itself filters nothing out.
# No handler is attached in this cell, so what is actually displayed depends on the
# logging configuration of the environment the notebook runs in.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)

##### Import Data

In [78]:
# Read the curves table from GitHub and discard the 'Unnamed: 0' column in the same expression.
csv_url = "https://raw.githubusercontent.com/jcox22/Sagemaker_practice_gan/main/rank_1_curves.csv"
coef_df_binary = pd.read_csv(csv_url).drop(columns=['Unnamed: 0'])

In [79]:
# Per-column variance of the real data, used later as a target for the generator's output.

# This cell runs before the dataset cell that also sets `device`, so resolve it here;
# otherwise a fresh kernel (Restart & Run All) fails with a NameError on the line below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The 'std' row of describe() is the per-column sample standard deviation; squaring it
# gives the sample variance. `variances` stays a one-row DataFrame.
variances = coef_df_binary.describe().loc[['std']]**2

# Flatten to a 1-D tensor (one entry per column) so it has the same shape as
# torch.var(x, dim=0) of a batch; a (1, N) tensor would force MSELoss to broadcast
# and emit a size-mismatch warning.
real_var = torch.tensor(variances.to_numpy().ravel(), dtype=torch.float32, device=device)

##### Set Parameters

In [84]:
# Width (number of nodes) of every hidden layer in the networks
k = 10_000

# Each sample has 32 binary digits; the generated vector has the same length
input_length = 32
output_length = input_length

# Training hyper-parameters
epochs, batch_size = 1000, 256
lr = 1e-3
momentum = 0.9  # not read by the Adam optimizers; only a commented-out SGD line mentions it

# Directory names reserved for a later save_model step
model_dir = '/models'
data_dir = '/training'

##### Training Dataset

In [85]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# The whole table becomes a single float32 tensor that already lives on `device`
train_tensor = torch.tensor(
    coef_df_binary.to_numpy(),
    dtype=torch.float32,
    device=device,
)

# TensorDataset yields 1-tuples; the loader shuffles every epoch and drops the final
# partial batch, so each batch it produces has exactly `batch_size` rows
train_ds = torch.utils.data.TensorDataset(train_tensor)
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

##### Set Distribution of Initial Weights

In [86]:
def init_normal(m):
    """Re-initialise the weights of a linear layer from a standard normal distribution.

    Intended to be passed to ``nn.Module.apply``, which calls it on every sub-module.
    Modules that are not ``nn.Linear`` are left untouched, and biases are not modified.

    Args:
        m: the module currently being visited.
    """
    if isinstance(m, nn.Linear):
        # normal_ is the in-place initializer; the un-suffixed nn.init.normal is deprecated
        # and emits a warning on every call.
        nn.init.normal_(m.weight)

##### Create NN Classes

In [87]:
class Generator(nn.Module):
    """Fully connected generator: four linear layers with ReLU between them.

    Takes a batch of noise vectors of length ``output_length`` and returns a batch of
    vectors of the same length. The last layer has no activation, so the outputs are
    unbounded real values.
    """

    def __init__(self, output_length: int, hidden_size: "int | None" = None):
        """Build the four linear layers.

        Args:
            output_length: size of both the input noise vector and the generated vector.
            hidden_size: width of the three hidden layers. When omitted, the
                notebook-level ``k`` is used, which keeps ``Generator(output_length)``
                behaving as before.
        """
        super().__init__()
        if hidden_size is None:
            hidden_size = k
        # Attribute names are kept as-is so existing state dicts still load.
        self.dense_layer = nn.Linear(output_length, hidden_size)
        self.dense_layer2 = nn.Linear(hidden_size, hidden_size)
        self.dense_layer3 = nn.Linear(hidden_size, hidden_size)
        self.dense_layer4 = nn.Linear(hidden_size, output_length)

    def forward(self, x):
        """Map a batch of noise vectors to a batch of generated vectors."""
        hidden = F.relu(self.dense_layer(x))
        hidden = F.relu(self.dense_layer2(hidden))
        hidden = F.relu(self.dense_layer3(hidden))
        return self.dense_layer4(hidden)
    
class Discriminator(nn.Module):
    """Fully connected discriminator: four linear layers with ReLU between them.

    Takes a batch of vectors of length ``input_length`` and returns one raw score
    (a logit, no sigmoid applied) per row.
    """

    def __init__(self, input_length: int, hidden_size: "int | None" = None):
        """Build the four linear layers.

        Args:
            input_length: size of each input vector; it is passed through ``int()``.
            hidden_size: width of the three hidden layers. When omitted, the
                notebook-level ``k`` is used, which keeps ``Discriminator(input_length)``
                behaving as before.
        """
        super().__init__()
        if hidden_size is None:
            hidden_size = k
        # Attribute names are kept as-is so existing state dicts still load.
        self.dense_layer = nn.Linear(int(input_length), hidden_size)
        self.dense_layer2 = nn.Linear(hidden_size, hidden_size)
        self.dense_layer3 = nn.Linear(hidden_size, hidden_size)
        self.dense_layer4 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return one unnormalised score per input row, shape ``(batch, 1)``."""
        hidden = F.relu(self.dense_layer(x))
        hidden = F.relu(self.dense_layer2(hidden))
        hidden = F.relu(self.dense_layer3(hidden))
        return self.dense_layer4(hidden)

##### Set up for training function

In [88]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Device Type: {}".format(device))

# Build both networks and place them on the chosen device
generator = Generator(output_length).to(device)
discriminator = Discriminator(input_length).to(device)

# Overwrite the weights of every nn.Linear layer via init_normal
for network in (generator, discriminator):
    network.apply(init_normal)

# Losses: binary cross-entropy on raw logits for the adversarial terms, and a
# summed squared error used for the variance-matching term
loss = nn.BCEWithLogitsLoss().to(device)
MSE = torch.nn.MSELoss(reduction='sum').to(device)

# One Adam optimizer per network, sharing the same learning rate
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
dis_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

  nn.init.normal(m.weight)


##### Training Loop

In [62]:
%%time
for epoch in range(0, epochs):
    running_loss = 0.0
    average_td_gd_loss = 0.0
    running_loss_true_d = 0.0
    g_d_running_loss = 0.0
    for batch in enumerate(train_loader):
        noise = torch.randint(0, 2, size=(batch_size, output_length)).float()
        noise = noise.to(device)
    
        # Generate examples of data
        true_labels = [1] * batch_size
        true_labels = torch.tensor(true_labels).float()
        true_labels = true_labels.to(device).resize_((batch_size, 1))
            
        true_data = batch[1][0]

        # zero the parameter gradients
        gen_optimizer.zero_grad()

        # forward + backward + optimize
        #outputs = model(inputs)
        #G_of_noise = generator(noise)
        #loss = criterion(outputs, labels)
        #loss.backward()
        #optimizer.step()
        G_of_noise = generator(noise)
        D_of_G_of_noise = discriminator(G_of_noise)
        generator_loss = loss(D_of_G_of_noise, true_labels) + MSE(real_var, torch.var(G_of_noise, dim = 0))
        generator_loss.backward()
        gen_optimizer.step()
            
        # Train the discriminator on the true/generated data
        dis_optimizer.zero_grad()
        true_discriminator_out = discriminator(true_data)
        true_discriminator_loss = loss(true_discriminator_out, true_labels)

        # add .detach() here think about this
        generator_discriminator_out = discriminator(G_of_noise.detach()) # introduce new d_of_g_of_noise without gradient
        generator_discriminator_loss = loss(generator_discriminator_out, torch.zeros(batch_size).to(device).resize_((batch_size, 1)))
        discriminator_loss = (true_discriminator_loss*0.1 + generator_discriminator_loss) / 2
        discriminator_loss.backward()
        dis_optimizer.step()

        # print statistics
        running_loss += generator_loss.item() * len(batch)
        average_td_gd_loss += discriminator_loss.item() * len(batch)
        running_loss_true_d += true_discriminator_loss.item() * len(batch)
        g_d_running_loss += generator_discriminator_loss.item() * len(batch)
        #if batch % 2000 == 1999:  # print every 2000 mini-batches
            #print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 2000))
            #print(f"Loss is {generator_loss.item()}.  Running loss is {running_loss/2000}.  Discriminator loss is {discriminator_loss.item()}.  True Discriminator Loss is {true_discriminator_loss.item()}")
            #print(GPUtil.showUtilization())
            #running_loss = 0.0
            #print(torch.cuda.memory_summary(device))
            #print(torch.cuda.list_gpu_processes(device))
        # print(running_loss/10)
    print(f"epoch:{epoch} average generator loss {running_loss / len(train_loader.dataset)}")
    print(f"epoch:{epoch} average discriminator loss {average_td_gd_loss / len(train_loader.dataset)}")
    print(f"epoch:{epoch} average true discriminator loss {running_loss_true_d / len(train_loader.dataset)}")
    print(f"epoch:{epoch} average generator discriminator loss {g_d_running_loss / len(train_loader.dataset)}")

    print("Finished Epoch")

  return F.mse_loss(input, target, reduction=self.reduction)


epoch:0 average generator loss 2640287454927767.0
epoch:0 average discriminator loss 140.53275008053956
epoch:0 average true discriminator loss 9.573726243718792
epoch:0 average generator discriminator loss 280.10813805643664
Finished Epoch
epoch:1 average generator loss 245166732454396.25
epoch:1 average discriminator loss 5.026207713154748
epoch:1 average true discriminator loss 0.1306223243778101
epoch:1 average generator discriminator loss 10.039353193700098
Finished Epoch
epoch:2 average generator loss 104673092992938.7
epoch:2 average discriminator loss 0.00020101520152487301
epoch:2 average true discriminator loss 0.004020303900440295
epoch:2 average generator discriminator loss 0.0
Finished Epoch
epoch:3 average generator loss 57976259667725.414
epoch:3 average discriminator loss 2.293431045493983
epoch:3 average true discriminator loss 0.14363663058230283
epoch:3 average generator discriminator loss 4.5724984277589416
Finished Epoch
epoch:4 average generator loss 3637814907679

##### Saving the Model

In [53]:
# Bundle the generator's weights and its optimizer state into a single checkpoint file
checkpoint = {
    'generator_state_dict': generator.state_dict(),
    'optimizer_state_dict': gen_optimizer.state_dict(),
}
torch.save(checkpoint, './gen9.pt')