In [None]:
#!/usr/bin/env python
# coding: utf-8

'''
sample command: python T4_BT19_ae.py -k 0 -c 0 -r 1 --data_dir /home/ruihan/data
Individual training for BioTac data (full/partial data)
if -r=1, train with full data
if -r=2, train with half data
loss = classification loss + recon loss 
'''

# Import
import os,sys
import pickle
import argparse
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

from vrae.vrae import VRAEC
from preprocess_data import get_TrainValTestLoader, get_TrainValTestDataset, get_TrainValTestData
from vrae.visual import plot_grad_flow

# # Parse argument
# parser = argparse.ArgumentParser()
# parser.add_argument("-i", "--rep", type=int, default=0, help='index of running repetition')
# parser.add_argument('--data_dir', type=str, default='data', help="DIR set in 'gh_download.sh' to store compiled_data")
# parser.add_argument("-k", "--kfold", type=int, default=0, help="kfold_number for loading data")
# parser.add_argument("-r", "--reduction", type=int, default=1, help="data reduction ratio for partial training")
# parser.add_argument("-c", "--cuda", default=0, help="index of cuda gpu to use")
# args = parser.parse_args()

# dummy class to replace argparser, if running jupyter notebook
class Args:
    """Stand-in for the argparse namespace when running inside a notebook.

    The attributes mirror the options of the commented-out parser and use
    the same defaults.
    """
    rep = 0  # index of running repetition (read when naming the stats file)
    data_dir = 'data'  # root directory that holds the data folders
    kfold = 0  # k-fold index used when loading data
    reduction = 1  # data reduction ratio; 1 means train with full data
    cuda = '0'  # index of the CUDA GPU to use


args = Args()

# ---- Run configuration ------------------------------------------------------
# Values taken from the (notebook) argument holder.
args_data_dir, kfold_number, data_reduction_ratio = args.data_dir, args.kfold, args.reduction

# Data description.
shuffle = True  # set to False for partial training
num_class, sequence_length, number_of_features = 20, 400, 19

# Network / optimisation hyper-parameters.
hidden_size, hidden_layer_depth, latent_length = 90, 1, 40
batch_size = 32
learning_rate = 0.001  # 0.0005
n_epochs = 100
dropout_rate = 0.2
cuda = True  # options: True, False
header = None

# Loss weights: w_r scales the reconstruction term, w_c the classification term.
w_r, w_c = 0.01, 1

# Fix the NumPy and PyTorch seeds so runs are repeatable.
np.random.seed(1)
torch.manual_seed(1)

# Directories and plotting switch.
data_dir = os.path.join(args_data_dir, "compiled_data/")
logDir = 'models_and_stats/'
if_plot = False

# Recurrent block used by the model; one of: LSTM, GRU, phased_LSTM.
block = "phased_LSTM"

# Checkpoint/figure name stem; an epoch index or suffix is appended when saving.
# model_name = 'BT19_ae_{}_wrI_{}_wC_{}_{}'.format(data_reduction_ratio, w_r, w_c, str(kfold_number))
model_name = "model_" + block + "_B"


# Pick the compute device: the GPU indexed by args.cuda when CUDA is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(args.cuda))
else:
    device = torch.device('cpu')

# Report the fold, the reduction setting and the chosen device.
if args.reduction != 1:
    print("load {} kfold number, reduce data to {} folds, put to device: {}".format(args.kfold, args.reduction, device))
else:
    print("load {} kfold number, train with full data, put to device: {}".format(args.kfold, device))

# Load the train/val/test split in three forms (dataset objects, batched
# loaders, and the X/Y splits). The fold index follows `kfold_number` so the
# loaded fold matches the one reported in the log line and the stats filename.
prefix = ""
dataset_dir = os.path.join(args_data_dir, "c20/") # TODO
train_set, val_set, test_set = get_TrainValTestDataset(
    dataset_dir, k=kfold_number, prefix=prefix, seq_len=sequence_length)
train_loader, val_loader, test_loader = get_TrainValTestLoader(
    dataset_dir, k=kfold_number, batch_size=batch_size, shuffle=shuffle,
    prefix=prefix, seq_len=sequence_length)
X_train, X_val, X_test, Y_train, Y_val, Y_test = get_TrainValTestData(
    dataset_dir, k=kfold_number, prefix=prefix, seq_len=sequence_length)
# Initialize models
# Gather the VRAEC constructor arguments in one mapping, build the model from
# it, then move the model to the selected device.
vraec_kwargs = dict(
    num_class=num_class,
    block=block,
    sequence_length=sequence_length,
    number_of_features=number_of_features,
    hidden_size=hidden_size,
    hidden_layer_depth=hidden_layer_depth,
    latent_length=latent_length,
    batch_size=batch_size,
    learning_rate=learning_rate,
    n_epochs=n_epochs,
    dropout_rate=dropout_rate,
    cuda=cuda,
    model_name=model_name,
    header=header,
    device=device,
)
model = VRAEC(**vraec_kwargs)
model.to(device)

# Initialize training settings
# Loss terms: negative log-likelihood for classification, mean squared error
# for reconstruction.
cl_loss_fn = nn.NLLLoss()
recon_loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Restore model and optimizer state from a saved checkpoint; the tensors are
# mapped to CPU while loading.
# model.load_state_dict(torch.load('models_and_stats/model_phased_LSTM_B30.pt', map_location='cpu'))
saved_dicts = torch.load(
    'models_and_stats/model_phased_LSTM_B32.pt',
    map_location='cpu',
)
model.load_state_dict(saved_dicts['model_state_dict'])
optimizer.load_state_dict(saved_dicts['optimizer_state_dict'])

# Wall-clock reference used for the total-training-time report.
training_start = datetime.now()

# Per-epoch statistics, filled in by the training loop.
epoch_train_loss, epoch_train_acc = [], []
epoch_val_loss, epoch_val_acc = [], []
max_val_acc = 0

# The phased-LSTM model is called with an extra `times` tensor; here it is
# all ones, one value per (batch item, time step).
if block == "phased_LSTM":
    times = torch.ones(batch_size, sequence_length)

In [None]:
# B33 nan grads

In [None]:
# Train and validate for the remaining epochs. Epochs below `start_epoch` are
# skipped (the checkpoint restored earlier is named "...B32").
start_epoch = 33
for epoch in range(start_epoch, n_epochs):
    # TRAIN
    model.train()
    correct = 0
    train_loss = 0
    train_num = 0
    for i, (XB, y) in enumerate(train_loader):
        if model.header == 'CNN':
            # NOTE(review): XI is not defined in the visible script; this
            # branch would raise NameError if header were ever 'CNN'.
            x = XI
        else:
            x = XB
        x, y = x.to(device), y.long().to(device)
        # A batch that is not full-sized ends the epoch.
        if x.size()[0] != batch_size:
            break

        # reduce data by data_reduction_ratio times: only every
        # data_reduction_ratio-th batch is used for training
        if i % data_reduction_ratio != 0:
            continue

        train_num += x.size(0)
        optimizer.zero_grad()
        if block == "phased_LSTM":
            x_decoded, latent, output = model(x, times)
        else:
            x_decoded, latent, output = model(x)

        # Debugging guard: stop as soon as the classifier output contains an
        # exact zero, reporting where.
        assert (output == 0).nonzero().size(0)==0, 'output contain zero, batch_num'+str(i)+' indices:'+str((output == 0).nonzero())

        # Weighted sum of classification and reconstruction losses.
        cl_loss = cl_loss_fn(output, y)
        recon_loss = recon_loss_fn(x_decoded, x)
        loss = w_c*cl_loss + w_r *recon_loss

        # compute classification acc
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(y.data.view_as(pred)).long().cpu().sum().item()
        # accumulator
        train_loss += loss.item()
        loss.backward()
        if i == 0: # and epoch%50 == 0:
            # Dump every trainable parameter's gradient and draw the
            # gradient-flow plot once per epoch, on the first batch.
            for n, p in model.named_parameters():
                if p.requires_grad:
                    print(n,p.grad)
            print("grad flow for epoch {}".format(epoch))
            figname = logDir + model_name + "grad_flow_plot_epoch" +str(epoch)+".png"
            plot_grad_flow(model.named_parameters(), figname, if_plot=False)
        optimizer.step()

    # fill stats
    # NOTE(review): train_loss is a sum of per-batch loss values divided by the
    # number of samples, not by the number of batches - confirm this scaling
    # is intended.
    train_accuracy = correct / train_num
    train_loss /= train_num
    epoch_train_loss.append(train_loss)
    epoch_train_acc.append(train_accuracy)

    # VALIDATION
    model.eval()
    correct = 0
    val_loss = 0
    val_num = 0
    # No parameter update happens here, so gradient tracking is switched off.
    with torch.no_grad():
        for i, (XB, y) in enumerate(val_loader):
            if model.header == 'CNN':
                # NOTE(review): XI is undefined in the visible script (see above).
                x = XI
            else:
                x = XB
            x, y = x.to(device), y.long().to(device)
            if x.size()[0] != batch_size:
                break
            val_num += x.size(0)
            if block == "phased_LSTM":
                x_decoded, latent, output = model(x, times)
            else:
                x_decoded, latent, output = model(x)

            # construct loss function
            cl_loss = cl_loss_fn(output, y)
            recon_loss = recon_loss_fn(x_decoded, x)
            loss = w_c*cl_loss + w_r *recon_loss

            # compute classification acc
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(y.data.view_as(pred)).long().cpu().sum().item()
            # accumulator
            val_loss += loss.item()

    # fill stats
    val_accuracy = correct / val_num
    val_loss /= val_num
    epoch_val_loss.append(val_loss)  # accumulated over all full-sized validation batches
    epoch_val_acc.append(val_accuracy)

    # if epoch < 20 or epoch%200 == 0:
    print("train_num {}, val_num {}".format(train_num, val_num))
    print('Epoch: {} Loss: train {:.3f}, valid {:.3f}. Accuracy: train: {:.3f}, valid {:.3f}'.format(epoch, train_loss, val_loss, train_accuracy, val_accuracy))
    
    

In [None]:
# choose model
# Disabled variant that only saves when validation accuracy does not drop:
# if max_val_acc <= val_accuracy:
#     model_dir = logDir + model_name + str(epoch) + '.pt'
#     print('Saving model at {} epoch to {}'.format(epoch, model_dir))
#     max_val_acc = val_accuracy
#     torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, model_dir)

# This cell runs at top level after the epoch loop, so it must not be
# indented. It saves the model and optimizer state of the most recent epoch,
# using that epoch's index in the file name.
model_dir = logDir + model_name + str(epoch) + '.pt'
print('Saving model at {} epoch to {}'.format(epoch, model_dir))
max_val_acc = val_accuracy
torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, model_dir)

# Report the elapsed time since `training_start`.
training_end = datetime.now()
training_time = training_end - training_start
print("training takes time {}".format(training_time))

model.is_fitted = True
model.eval()

# TEST
# Classification accuracy over the full-sized test batches. Gradient tracking
# is switched off because nothing is optimized here.
correct = 0
test_num = 0
with torch.no_grad():
    for i, (XB, y) in enumerate(test_loader):
        if model.header == 'CNN':
            # NOTE(review): XI is not defined in the visible script; this
            # branch would raise NameError if header were ever 'CNN'.
            x = XI
        else:
            x = XB
        x, y = x.to(device), y.long().to(device)

        # A batch that is not full-sized is reported and ends the evaluation.
        if x.size(0) != batch_size:
            print(" test batch {} size {} < {}, skip".format(i, x.size()[0], batch_size))
            break
        test_num += x.size(0)
        if block == "phased_LSTM":
            x_decoded, latent, output = model(x, times)
        else:
            x_decoded, latent, output = model(x)

        # compute classification acc
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(y.data.view_as(pred)).long().cpu().sum().item()

test_acc = correct / test_num #len(test_loader.dataset)
print('Test accuracy for', str(kfold_number), ' fold : ', test_acc)

# Save stats
# Bundle the per-epoch curves and the final test accuracy into one pickle.
results_dict = {"epoch_train_loss": epoch_train_loss,
                "epoch_train_acc": epoch_train_acc,
                "epoch_val_loss": epoch_val_loss,
                "epoch_val_acc": epoch_val_acc,
                "test_acc": test_acc}

# The notebook stand-in for argparse may not define `rep`; fall back to 0,
# the default of the commented-out "--rep" option.
rep = getattr(args, "rep", 0)
dict_name = model_name + '_stats_fold{}_{}.pkl'.format(str(kfold_number), rep)
# The context manager closes the file even if pickling fails.
with open(logDir + dict_name, 'wb') as stats_file:
    pickle.dump(results_dict, stats_file)
print("dump results dict to {}".format(dict_name))

# Plot training accuracy against epoch index.
# Only epochs that actually ran are logged (the loop skips the leading ones
# when resuming), so the x-axis is derived from the number of logged values
# rather than asserting it equals n_epochs.
num_logged = len(epoch_train_acc)
logged_epochs = np.arange(n_epochs - num_logged, n_epochs)
fig, ax = plt.subplots(figsize=(15, 7))
ax.plot(logged_epochs, epoch_train_acc, label="train acc")
ax.set_xlabel('epoch')
ax.set_ylabel('acc')
ax.grid(True)
ax.legend(loc='upper right')
# Write the figure to disk; show it interactively only when requested.
figname = logDir + model_name +"_train_acc.png"
fig.savefig(figname)
if if_plot:
    plt.show()

