In [2]:
#!/usr/bin/env python
# coding: utf-8

'''
sample command: python T4_BT19_ae.py -k 0 -c 0 -r 1 --data_dir /home/ruihan/data
Individual training for BioTac data (full/partial data)
if -r=1, train with full data
if -r=2, train with half data
loss = classification loss + recon loss 
'''

# Import
import os,sys
import pickle
import argparse
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns

from vrae.vrae import VRAEC
from preprocess_data import get_TrainValTestLoader, get_TrainValTestDataset, get_TrainValTestData
from vrae.visual import plot_grad_flow, plot_stats

def train(args):
    # Set hyper params
    args_data_dir = args.data_dir
    kfold_number = args.kfold
    data_reduction_ratio = args.reduction
    shuffle = True # set to False for partial training
    sequence_length = 400
    number_of_features = 19

    hidden_size = args.h_s
    hidden_layer_depth = 1
    latent_length = 40
    batch_size = 32
    learning_rate = 0.001 # 0.0005
    n_epochs = 100
    dropout_rate = 0.2
    cuda = True # options: True, False
    header = None
    dataset = args.dataset
    if dataset == 'c50':
        num_class = 50
    else:
        num_class = 20

    # loss weightage
    w_r = args.w_r
    w_c = 1

    np.random.seed(1)
    torch.manual_seed(1)

    # Load data
    # data_dir = os.path.join(args_data_dir, "compiled_data/")
    logDir = 'models_and_stats/'
    if_plot = False

    # RNN block
    block = "phased_LSTM" # LSTM, GRU, phased_LSTM

    model_name = 'B_block_{}_data_{}_wrI_{}_wC_{}_hidden_{}_latent_{}_r_on_{}_p_max_{}'.format(block, dataset, w_r, w_c, str(hidden_size), str(latent_length), str(args.r_on), str(args.p_max))

    if torch.cuda.is_available() and cuda:
        device = torch.device("cuda:{}".format(args.cuda))
    else:
        device = torch.device('cpu')

    if args.reduction != 1:
        print("load {} kfold number, reduce data to {} folds, put to device: {}".format(args.kfold, args.reduction, device))
    else:
        print("load {} kfold number, train with full data, put to device: {}".format(args.kfold, device))

    prefix = ""
    dataset_dir = os.path.join(args_data_dir, dataset+"/") # TODO
    train_set, val_set, test_set = get_TrainValTestDataset(dataset_dir, k=0, prefix=prefix, seq_len=sequence_length)
    train_loader, val_loader, test_loader = get_TrainValTestLoader(dataset_dir, k=0, batch_size=batch_size,shuffle=shuffle, prefix=prefix,seq_len=sequence_length)
    X_train, X_val, X_test, Y_train, Y_val, Y_test = get_TrainValTestData(dataset_dir, k=0, prefix=prefix,seq_len=sequence_length)
    # Initialize models
    model = VRAEC(num_class=num_class,
                block=block,
                sequence_length=sequence_length, # TODO
                number_of_features = number_of_features,
                hidden_size = hidden_size, 
                hidden_layer_depth = hidden_layer_depth,
                latent_length = latent_length,
                batch_size = batch_size,
                learning_rate = learning_rate,
                n_epochs = n_epochs,
                dropout_rate = dropout_rate,
                cuda = cuda,
                model_name=model_name,
                header=header,
                device = device,
                ratio_on=args.r_on, period_init_max=args.p_max)
    model.to(device)

    # Initialize training settings
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    cl_loss_fn = nn.NLLLoss()
    recon_loss_fn = nn.MSELoss()

    # model.load_state_dict(torch.load('models_and_stats/model_phased_LSTM_B30.pt', map_location='cpu'))
    # saved_dicts = torch.load('models_and_stats/model_phased_LSTM_B.pt', map_location='cpu')
    # model.load_state_dict(saved_dicts['model_state_dict'])
    # optimizer.load_state_dict(saved_dicts['optimizer_state_dict'])

    training_start=datetime.now()
    # create empty lists to fill stats later
    epoch_train_loss = []
    epoch_train_acc = []
    epoch_val_loss = []
    epoch_val_acc = []
    max_val_acc = 0
    max_val_epoch = 0
    if block == "phased_LSTM":
        time = torch.Tensor(range(sequence_length))
        times = time.repeat(batch_size, 1)

#     for epoch in range(n_epochs):
#         # if epoch < 30:
#         #     continue
#         # TRAIN
#         model.train()
#         correct = 0
#         train_loss = 0
#         train_num = 0
#         for i, (XB,  y) in enumerate(train_loader):
#             if model.header == 'CNN':
#                 x = XI
#             else:
#                 x = XB
#             # x = x[:, :, 1::2]
#             x, y = x.to(device), y.long().to(device) # 32, 19, 400
#             if x.size()[0] != batch_size:
#                 break
            
#             # reduce data by data_reduction_ratio times
#             if i % data_reduction_ratio == 0:
#                 train_num += x.size(0)
#                 optimizer.zero_grad()
#                 if block == "phased_LSTM":
#                     x_decoded, latent, output = model(x, times)
#                 else:
#                     x_decoded, latent, output = model(x)

#                 # assert not torch.isnan(y).any(), "batch_num="+str(i)
#                 # print((output == 0).nonzero().size(0)==0)

#                 # assert (output == 0).nonzero().size(0)==0, 'output contain zero, batch_num'+str(i)+' indices:'+str((output == 0).nonzero())
#                 if (output == 0).nonzero().size(0) != 0:
#                     print('batch_num'+str(i)+' indices:'+str((output == 0).nonzero()))
#                     cl_loss = cl_loss_fn(output+1e-5, y) # avoid nan
#                 else:
#                     cl_loss = cl_loss_fn(output, y) 

#                 recon_loss = recon_loss_fn(x_decoded, x)
#                 loss = w_c*cl_loss + w_r *recon_loss
                
#                 # compute classification acc
#                 pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
#                 correct += pred.eq(y.data.view_as(pred)).long().cpu().sum().item()
#                 # accumulator
#                 train_loss += loss.item()
#                 start_bp = datetime.now()
#                 loss.backward()
#                 figname = logDir + model_name + "grad_flow_plot_epoch" +str(epoch)+".png"
#                 if i == 0: # and epoch%50 == 0:
#                     print('cl_loss:', cl_loss, 'recon_loss:', recon_loss)
#                     print("grad flow for epoch {}".format(epoch))
#                     plot_grad_flow(model.named_parameters(), figname, if_plot=False)
#                     # if epoch % 5 == 0:
#                     #     k = model.encoder.k_out
#                     #     k = k.squeeze()
#                     #     # print('k:',k[:, 0, 70])
#                     #     # tau = model.encoder.model.phased_cell.tau
#                     #     # print('tau:', tau)
#                     #     # phase = model.encoder.model.phased_cell.phase
#                     #     # print('phase:', phase)
#                     #     k = k[:,0,:].cpu().detach().numpy()
#                     #     x = x[0,:,:].permute(1,0)
#                     #     x = x.cpu().detach().numpy()
#                     #     fig, ax =plt.subplots(1,2)
#                     #     sns.heatmap(x, ax=ax[0])
#                     #     sns.heatmap(k, vmin=0, vmax=1, ax=ax[1])
#                     #     fig.savefig(logDir + model_name + "x_k_plot_epoch" +str(epoch)+".png")
#                     #     fig.clf()
#                 optimizer.step()
#                 # print('1 batch bp time:', datetime.now()-start_bp)

#         # if epoch == 0:
#         #     print('first epoch training time:', datetime.now()-training_start)
        
#         # if epoch < 20 or epoch%200 == 0:
#         # print("train last batch {} of {}: cl_loss {:.3f} recon_loss {:.3f}".format(i, len(train_loader), cl_loss, recon_loss))

#         # fill stats
#         train_accuracy = correct / train_num 
#         train_loss /= train_num
#         epoch_train_loss.append(train_loss)
#         epoch_train_acc.append(train_accuracy) 
        
#         # VALIDATION
#         model.eval()
#         correct = 0
#         val_loss = 0
#         val_num = 0
#         for i, (XB, y) in enumerate(val_loader):
#             if model.header == 'CNN':
#                 x = XI
#             else:
#                 x = XB
#             # x = x[:, :, 1::2]
#             x, y = x.to(device), y.long().to(device)
#             if x.size()[0] != batch_size:
#                 break
#             val_num += x.size(0)
#             if block == "phased_LSTM":
#                 x_decoded, latent, output = model(x, times)
#             else:
#                 x_decoded, latent, output = model(x)

#             # construct loss function
#             cl_loss = cl_loss_fn(output, y)
#             recon_loss = recon_loss_fn(x_decoded, x)
#             loss = w_c*cl_loss + w_r *recon_loss
            
#             # compute classification acc
#             pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
#             correct += pred.eq(y.data.view_as(pred)).long().cpu().sum().item()
#             # accumulator
#             val_loss += loss.item()
        
#         # fill stats
#         val_accuracy = correct / val_num
#         val_loss /= val_num
#         epoch_val_loss.append(val_loss)  # only save the last batch
#         epoch_val_acc.append(val_accuracy)
        
#         # if epoch < 20 or epoch%200 == 0:
#         print("train_num {}, val_num {}".format(train_num, val_num))
#         print('Epoch: {} Loss: train {:.3f}, valid {:.3f}. Accuracy: train: {:.3f}, valid {:.3f}'.format(epoch, train_loss, val_loss, train_accuracy, val_accuracy))
        
#         # choose model
#         if max_val_acc <= val_accuracy:
#             model_dir = logDir + model_name + str(epoch) + '.pt'
#             print('Saving model at {} epoch to {}'.format(epoch, model_dir))
#             max_val_acc = val_accuracy
#             max_val_epoch = epoch
#             torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, model_dir)
#         # model_dir = logDir + model_name + str(epoch) + '.pt'
#         # print('Saving model at {} epoch to {}'.format(epoch, model_dir))
#         # max_val_acc = val_accuracy
#         # torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, model_dir)
#     print('Best model at epoch {} with acc {:.3f}'.format(max_val_epoch, max_val_acc))
#     training_end =  datetime.now()
#     training_time = training_end -training_start 
#     print("training takes time {}".format(training_time))

#     model.is_fitted = True
#     model.eval()

#     # TEST at last epoch
#     correct = 0
#     test_num = 0
#     for i, (XB,  y) in enumerate(test_loader):
#         if model.header == 'CNN':
#             x = XI
#         else:
#             x = XB
#         # x = x[:, :, 1::2]
#         x, y = x.to(device), y.long().to(device)
        
#         if x.size(0) != batch_size:
#             print(" test batch {} size {} < {}, skip".format(i, x.size()[0], batch_size))
#             break
#         test_num += x.size(0)
#         if block == "phased_LSTM":
#             x_decoded, latent, output = model(x, times)
#         else:
#             x_decoded, latent, output = model(x)

#         # compute classification acc
#         pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
#         correct += pred.eq(y.data.view_as(pred)).long().cpu().sum().item()
        
#     test_acc1 = correct / test_num #len(test_loader.dataset)
#     print('last epoch Test accuracy for', str(kfold_number), ' fold : ', test_acc1)

    # TEST at the best model
    correct = 0
    test_num = 0
    
    saved_dicts = torch.load('models_and_stats/'+'B_block_phased_LSTM_data_c20_wrI_0.005_wC_1_hidden_90_latent_40_r_on_0.1_p_max_20060.pt', map_location='cpu')
    # saved_dicts = torch.load('models_and_stats/'+'model_name+str(max_val_epoch)+'.pt'', map_location='cpu')
    model.load_state_dict(saved_dicts['model_state_dict'])

    for i, (XB,  y) in enumerate(test_loader):
        if model.header == 'CNN':
            x = XI
        else:
            x = XB
        # x = x[:, :, 1::2]
        x, y = x.to(device), y.long().to(device)
        
        if x.size(0) != batch_size:
            print(" test batch {} size {} < {}, skip".format(i, x.size()[0], batch_size))
            break
        test_num += x.size(0)
        if block == "phased_LSTM":
            x_decoded, latent, output = model(x, times)
        else:
            x_decoded, latent, output = model(x)

        # compute classification acc
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        print(i, pred.eq(y.data.view_as(pred)).long().cpu().sum().item())
        correct += pred.eq(y.data.view_as(pred)).long().cpu().sum().item()
        
    print(correct, test_num)    
    test_acc2 = correct / test_num #len(test_loader.dataset)
    print('at the best model Test accuracy for', str(kfold_number), ' fold : ', test_acc2)

#     # Save stats
#     results_dict = {"epoch_train_loss": epoch_train_loss,
#                     "epoch_train_acc": epoch_train_acc,
#                     "epoch_val_loss": epoch_val_loss,
#                     "epoch_val_acc": epoch_val_acc,
#                     "test_acc1": test_acc1,
#                     "test_acc2": test_acc2}

#     dict_name = model_name + '_stats_fold{}_{}.pkl'.format(str(kfold_number), args.rep)
#     pickle.dump(results_dict, open(logDir + dict_name, 'wb'))
#     print("dump results dict to {}".format(dict_name))

    # assert n_epochs == len(epoch_train_acc), "different epoch length {} {}".format(n_epochs, len(epoch_train_acc))
    # fig, ax = plt.subplots(figsize=(15, 7))
    # ax.plot(np.arange(n_epochs), epoch_train_acc, label="train acc")
    # ax.set_xlabel('epoch')
    # ax.set_ylabel('acc')
    # ax.grid(True)
    # plt.legend(loc='upper right')
    # figname = logDir + model_name +"_train_acc.png"
    # if if_plot:
    #     plt.show()

    # plot_stats(logDir + dict_name)

# # Parse argument
# parser = argparse.ArgumentParser()
# parser.add_argument("-i", "--rep", type=int, default=0, help='index of running repetition')
# parser.add_argument('--data_dir', type=str, default='data', help="DIR set in 'gh_download.sh' to store compiled_data")
# parser.add_argument("-k", "--kfold", type=int, default=0, help="kfold_number for loading data")
# parser.add_argument("-r", "--reduction", type=int, default=1, help="data reduction ratio for partial training")
# parser.add_argument("-c", "--cuda", default=0, help="index of cuda gpu to use")
# parser.add_argument("--r_on", default=0.1, help="ratio_on for phased lstm")
# parser.add_argument("--p_max", default=200, help="period_init_max for phased lstm")
# parser.add_argument("--w_r", default=0.005, type=float, help="weight of recon loss")
# parser.add_argument("--h_s", default=90, type=int, help="hidden size of rnn layers")
# parser.add_argument("--dataset", default='c20', type=str, help="name of dataset")
# args = parser.parse_args()

# dummy class to replace argparser, if running jupyter notebook
class Args:
    rep = 0
    data_dir = 'data'
    kfold = 0
    cuda = '0'
    reduction = 1
    r_on = 0.1
    p_max = 200
    w_r = 0.005
    h_s = 90
    dataset = 'c20'

args=Args()

# for ds in ['c20', 'c20new', 'c50']:
#     args.dataset = ds
#     print(args)
#     train(args)


    # 0.932

    # for r_on in [0.1, 0.2, 0.3, 0.4]:
    #     args.r_on = r_on
    #     print(args)
    #     train(args)

    # for p_max in [75.0, 60.0, 45.0, 30.0]:
    #     args.p_max = p_max
    #     print(args)
    #     train(args) 

    # for hidden_size in [60, 70, 80, 90, 100]:
    #     args.h_s = hidden_size
    #     print(args)
    #     train(args)

    # for w_r in [0, 0.0005, 0.001, 0.005, 0.01]:
    #     args.w_r = w_r
    #     print(args)
    #     train(args)

In [3]:
print(args)
train(args)

<__main__.Args object at 0x7ffa609b7490>
load 0 kfold number, train with full data, put to device: cpu
chop org data of length 400 into 1 segments, each of which is has length 400
chop org data of length 400 into 1 segments, each of which is has length 400
chop org data of length 400 into 1 segments, each of which is has length 400
0 25
1 27
2 28
3 26
4 30
5 29
 test batch 6 size 8 < 32, skip
165 192
at the best model Test accuracy for 0  fold :  0.859375
