In [48]:
import json
import math
import os
import random
import sys
import time
import warnings
from collections import Counter
from datetime import datetime
from pprint import pprint

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from transformers import AdamW

from arg_parser import arg_parser
from build_matrices import (build_design_matrices_classification,
                            build_design_matrices_seq2seq)
from config import build_config
from dl_utils import Brain2enDataset, MyCollator
from models import PITOM, ConvNet10, MeNTAL, MeNTALmini
from train_eval import evaluate_roc, evaluate_topk, plot_training, train, valid
from vocab_builder import get_sp_vocab, get_std_vocab, get_vocab

# from train_eval import *

# Capture the start time once: a verbose form for the console and a compact
# form that is passed to build_config below.
now = datetime.now()
dt_string = now.strftime("%A %d/%m/%Y %H:%M:%S")
print("Start Time: ", dt_string)
results_str = now.strftime("%Y-%m-%d-%H:%M")

# Parse the run arguments and derive the run configuration from them.
args = arg_parser()
CONFIG = build_config(args, results_str)

# Optional: redirect stdout to the configured log file.
# sys.stdout = open(CONFIG["LOG_FILE"], 'w')

# Objective associated with each known model name.
MODEL_OBJ = {
    "ConvNet10": "classifier",
    "PITOM": "classifier",
    "MeNTALmini": "classifier",
    "MeNTAL": "seq2seq",
}

# Compute device: first CUDA device when available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
# Never ask for more GPUs than torch reports.
args.gpus = min(args.gpus, torch.cuda.device_count())

# Seed every random number generator used here with the same seed
# (python, numpy, torch CPU, torch CUDA -- in that order).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(args.seed)

# Keep only the part of the model name before the first underscore.
args.model = args.model.split("_")[0]
# Every model is treated as a classifier unless it is explicitly registered
# with the "seq2seq" objective (unknown names therefore also classify).
classify = MODEL_OBJ.get(args.model) != "seq2seq"

# Frequently used configuration entries.
CONV_DIRS = CONFIG["CONV_DIRS"]
SAVE_DIR = CONFIG["SAVE_DIR"]
TRAIN_CONV = CONFIG["TRAIN_CONV"]
VALID_CONV = CONFIG["VALID_CONV"]

# Load train and validation datasets
# (if model is seq2seq, using speaker switching for sentence cutoff,
# and custom batching)
#
# Both branches end with the same names bound: `vocab`, `train_ds`,
# `valid_ds`, `train_dl` and `valid_dl`.
if classify:
    print("Building vocabulary")
    word2freq, vocab, n_classes, w2i, i2w = get_vocab(CONFIG)

    print("Loading training data")
    # Only the training split is built with a non-empty aug_shift_ms list.
    x_train, y_train = build_design_matrices_classification(
        'train', CONFIG, w2i, delimiter=" ", aug_shift_ms=[-1000])
    sys.stdout.flush()
    print("Loading validation data")
    x_valid, y_valid = build_design_matrices_classification(
        'valid', CONFIG, w2i, delimiter=" ", aug_shift_ms=[])
    sys.stdout.flush()
    if args.model == "ConvNet10":
        # ConvNet10 inputs get an extra singleton axis at position 1.
        x_train = x_train[:, np.newaxis, ...]
        x_valid = x_valid[:, np.newaxis, ...]

    # Shuffle labels if required
    if args.shuffle:
        print("Shuffling labels")
        np.random.shuffle(y_train)
        np.random.shuffle(y_valid)

    x_train = torch.from_numpy(x_train).float()
    print("Shape of training signals: ", x_train.size())
    y_train = torch.from_numpy(y_train)
    train_ds = data.TensorDataset(x_train, y_train)

    x_valid = torch.from_numpy(x_valid).float()
    print("Shape of validation signals: ", x_valid.size())
    y_valid = torch.from_numpy(y_valid)
    valid_ds = data.TensorDataset(x_valid, y_valid)

    # Create dataset and data generators
    print("Creating dataset and generators")
    sys.stdout.flush()
    # Only the training loader shuffles; validation keeps dataset order.
    train_dl = data.DataLoader(train_ds,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=CONFIG["num_cpus"])
    valid_dl = data.DataLoader(valid_ds,
                               batch_size=args.batch_size,
                               num_workers=CONFIG["num_cpus"])
else:
    print("Building vocabulary")
    if CONFIG["vocabulary"] == 'spm':
        vocab = get_sp_vocab(CONFIG, algo='unigram', vocab_size=500)
    elif CONFIG["vocabulary"] == 'std':
        word2freq, word_list, n_classes, vocab, i2w = get_std_vocab(
            CONFIG, classify)
    else:
        # Fail here with a clear message: continuing would leave `vocab`
        # unbound and crash later with an unrelated NameError.
        raise ValueError(
            "Such vocabulary doesn't exist: {!r} (expected 'spm' or 'std')"
            .format(CONFIG["vocabulary"]))
    # print([(i, vocab.IdToPiece(i)) for i in range(len(vocab))])

    print("Loading training data")
    # Only the training split is built with a non-empty aug_shift_ms list.
    x_train, y_train = build_design_matrices_seq2seq(
        'train', CONFIG, vocab, delimiter=" ", aug_shift_ms=[-1000, -500])

    print("Loading validation data")
    x_valid, y_valid = build_design_matrices_seq2seq(
        'valid', CONFIG, vocab, delimiter=" ", aug_shift_ms=[])
    sys.stdout.flush()
    # Shuffle labels if required
    if args.shuffle:
        print("Shuffling labels")
        np.random.shuffle(y_train)
        np.random.shuffle(y_valid)
    train_ds = Brain2enDataset(x_train, y_train)
    print("Number of training signals: ", len(train_ds))
    valid_ds = Brain2enDataset(x_valid, y_valid)
    print("Number of validation signals: ", len(valid_ds))
    # The same collator instance builds batches for both loaders.
    my_collator = MyCollator(CONFIG, vocab)
    train_dl = data.DataLoader(train_ds,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=CONFIG["num_cpus"],
                               collate_fn=my_collator)
    valid_dl = data.DataLoader(valid_ds,
                               batch_size=args.batch_size,
                               num_workers=CONFIG["num_cpus"],
                               collate_fn=my_collator)


Start Time:  Friday 15/05/2020 00:47:57
Subject: 625
Training Data:: Number of Conversations is: 63
Validation Data:: Number of Conversations is: 13
Subject: 676
Training Data:: Number of Conversations is: 49
Validation Data:: Number of Conversations is: 24
Building vocabulary
# Conversations: 112
Vocabulary size (min_freq=10): 1393
Saving word counter
Loading training data
NY625_418_Part3_conversation1
Number of train samples is: 2777
Number of train samples is: 2777
Maximum Sequence Length is: 1028
Loading validation data
NY625_421_Part4_one_conversation2
Number of valid samples is: 699
Number of valid samples is: 699
Maximum Sequence Length is: 3799
Skipped 5 examples
Number of training signals:  2772
Skipped 5 examples
Number of validation signals:  694


In [61]:
# Shape of the second item (index 1) of the first training example.
first_example = train_ds[0]
first_example[1].shape

torch.Size([4])

In [49]:
# Create an iterator over the training DataLoader so batches can be pulled
# one at a time for manual inspection.
dataiter = iter(train_dl)

In [50]:
# Pull one batch and unpack its five components.
# Use the built-in next(): iterators are only guaranteed to implement
# __next__, and recent PyTorch DataLoader iterators no longer expose the
# Python-2 style .next() alias (calling it raises AttributeError).
src, trg, trg_y, pos_mask, pad_mask = next(dataiter)

In [54]:
src.shape

torch.Size([48, 60, 128])

In [78]:
# Shapes of the first item (index 0) of training examples
# 48*21 .. 48*22 - 1, i.e. the 22nd consecutive window of 48 examples.
window_size, window_idx = 48, 21
window_start = window_size * window_idx
[train_ds[idx][0].shape
 for idx in range(window_start, window_start + window_size)]

[torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48, 128]),
 torch.Size([48,

In [137]:
from torch.nn.utils.rnn import pad_sequence

# Hand-collate training examples 48..95 into one padded batch.
# NOTE: this rebinds `src`, `trg` and `trg_y` from the DataLoader batch cell.
batch_indices = range(48, 96)

# Item 0 of each example, padded with 0.0 to a common length.
signals = [train_ds[idx][0] for idx in batch_indices]
src = pad_sequence(signals, batch_first=True, padding_value=0.)

# Item 1 of each example, padded with the vocabulary's '<pad>' id.
targets = [train_ds[idx][1] for idx in batch_indices]
labels = pad_sequence(targets,
                      batch_first=True,
                      padding_value=vocab['<pad>'])

# One-hot encode the padded labels along a new last axis of size len(vocab).
one_hot = torch.zeros(labels.size(0), labels.size(1), len(vocab))
one_hot.scatter_(2, labels.unsqueeze(-1), 1)

# `trg` drops the last position, `trg_y` drops the first one, so position t
# of `trg` lines up with label t + 1 in `trg_y`.
trg = one_hot[:, :-1, :]
trg_y = labels[:, 1:]

In [140]:
def masks(labels, pad_idx=None):
    """Build the causal (position) mask and the padding mask for `labels`.

    Args:
        labels: 2-D tensor of token ids; its second dimension is used as the
            sequence length.
        pad_idx: id that marks padding positions. When omitted, the
            notebook-level ``vocab['<pad>']`` is looked up, which is the
            original behaviour of this function.

    Returns:
        A ``(pos_mask, pad_mask)`` tuple:
          * ``pos_mask`` -- float tensor of shape ``(1, L, L)`` with
            ``L = labels.size(1)``; entries on and below the diagonal are
            ``0.0`` and entries above it are ``-inf``.
          * ``pad_mask`` -- boolean tensor with the same shape as ``labels``,
            True wherever ``labels`` equals the padding id.
    """
    if pad_idx is None:
        # Backward-compatible fallback to the global vocabulary.
        pad_idx = vocab['<pad>']

    seq_len = labels.size(1)
    # triu(..., diagonal=1) keeps the -inf entries strictly above the
    # diagonal and zeroes everything else, which is the additive causal mask.
    pos_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')),
                          diagonal=1).unsqueeze(0)
    pad_mask = labels == pad_idx

    return pos_mask, pad_mask