In [55]:
import argparse

### Parse custom arguments
''' Explanation of the Arguments:
model (string): Type of model to run from among (PITOM, ConvNet, MeNTALmini, MeNTAL
subject (list of strings): subject id's as a list
shift (integer): Amount by which the onset should be shifted
lr (float): learning rate
gpus (int): number of gpus for the model to run on
epochs (int):
batch-size (int):
window-size (int): window size to consider for the word in ms
bin-size (int): bin size in ms
init-model (string): 
'''
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='MeNTAL')
parser.add_argument('--subjects', nargs='*', default=['625', '676'])
parser.add_argument('--shift', type=int, default=0)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--gpus', type=int, default=16)
parser.add_argument('--epochs', type=int, default=60)
parser.add_argument('--batch-size', type=int, default=48)
parser.add_argument('--window-size', type=int, default=2000)
parser.add_argument('--bin-size', type=int, default=50)
parser.add_argument('--init-model', type=str, default=None)
parser.add_argument('--no-plot', action='store_false', default=False)
parser.add_argument('--electrodes', nargs='*', default=list(range(1,65)))
parser.add_argument('--vocab-min-freq', type=int, default=10)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--shuffle', action="store_true", default=False)
parser.add_argument('--no-eval', action="store_true", default=False)
parser.add_argument('--temp', type=float, default=0.995)
parser.add_argument('--tf-dmodel', type=int, default=64)
parser.add_argument('--tf-dff', type=int, default=256)
parser.add_argument('--tf-nhead', type=int, default=8)
parser.add_argument('--tf-nlayer', type=int, default=12)
parser.add_argument('--tf-dropout', type=float, default=0.05)
parser.add_argument('--weight-decay', type=float, default=0.35)
args = parser.parse_args([])

In [56]:
def createDict(*args):
     return dict(((k, eval(k)) for k in args))

In [57]:
import os

# Default Configuration
'''
exclude_words_class: words to be excluded from the classifier vocabulary
exclude_words: words to be excluded from the tranformer vocabulary
log_interval: 
'''
CONFIG = {
    "begin_token": "<s>",
    "datum_suffix": ("conversation_trimmed", "trimmed"),
    "electrodes": 64,
    "end_token": "</s>",
    "exclude_words_class": ["sp", "{lg}", "{ns}", "it", "a", "an", "and", "are",\
                      "as", "at", "be", "being", "by", "for", "from", "is",\
                      "of", "on", "that", "that's", "the", "there", "there's",\
                      "this", "to", "their", "them", "these", "he", "him",\
                      "his", "had", "have", "was", "were", "would"],
    "exclude_words": ["sp", "{lg}", "{ns}"],
    "log_interval": 32,
    "main_dir": "/scratch/gpfs/hgazula/brain2en",
    "data_dir": "/scratch/gpfs/hgazula",
    "num_cpus": 8,
    "oov_token": "<unk>",
    "pad_token": "<pad>",
    "print_pad": 120,
    "train_convs": '-train-convs.txt',
    "valid_convs": '-valid-convs.txt'
}

if len(args.subjects) == 1:
    if args.subjects[0] == '625':
        CONFIG["datum_suffix"] = [CONFIG["datum_suffix"][0]]
    elif args.subjects[0] == '676':
        CONFIG["datum_suffix"] = [CONFIG["datum_suffix"][1]]

### Model objectives
MODEL_OBJ = {
    "ConvNet10": "classifier",
    "PITOM": "classifier",
    "MeNTALmini": "classifier",
    "MeNTAL": "seq2seq"
}

In [58]:
import torch

# GPUs
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
args.gpus = min(args.gpus, torch.cuda.device_count())

In [59]:
import sys

CONFIG.update(vars(args))
print("Script Configuration: ")
print(sorted(CONFIG.items()))
sys.stdout.flush()

Script Configuration: 
[('batch_size', 48), ('begin_token', '<s>'), ('bin_size', 50), ('data_dir', '/scratch/gpfs/hgazula'), ('datum_suffix', ('conversation_trimmed', 'trimmed')), ('electrodes', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64]), ('end_token', '</s>'), ('epochs', 60), ('exclude_words', ['sp', '{lg}', '{ns}']), ('exclude_words_class', ['sp', '{lg}', '{ns}', 'it', 'a', 'an', 'and', 'are', 'as', 'at', 'be', 'being', 'by', 'for', 'from', 'is', 'of', 'on', 'that', "that's", 'the', 'there', "there's", 'this', 'to', 'their', 'them', 'these', 'he', 'him', 'his', 'had', 'have', 'was', 'were', 'would']), ('gpus', 0), ('init_model', None), ('log_interval', 32), ('lr', 0.0001), ('main_dir', '/scratch/gpfs/hgazula/brain2en'), ('model', 'MeNTAL'), ('no_eval', False), ('no_plot', False), (

In [60]:
import random
import numpy as np

# Fix random seed
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

In [61]:
# Format directory logistics
CONV_DIRS = [
    CONFIG["data_dir"] + '/%s-conversations/' % i for i in args.subjects
]
META_DIRS = [CONFIG["data_dir"] + '/%s-metadata/' % i for i in args.subjects]
SAVE_DIR = './ipynbresults/%s/' % (args.model)

In [62]:
import os

if not os.path.isdir(SAVE_DIR):
    os.makedirs(SAVE_DIR)
args.model = args.model.split("_")[0]
classify = False if (args.model in MODEL_OBJ
                     and MODEL_OBJ[args.model] == "seq2seq") else True

In [63]:
from data_util import read_file

# Conversation splits
TRAIN_CONV, VALID_CONV = [], []
for meta, subject in zip(META_DIRS, args.subjects):
    TRAIN_CONV.append(
        read_file("%s%s%s" % (meta, subject, CONFIG["train_convs"])))
    VALID_CONV.append(
        read_file("%s%s%s" % (meta, subject, CONFIG["valid_convs"])))

Number of Conversations is: 63
Number of Conversations is: 13
Number of Conversations is: 49
Number of Conversations is: 24


In [72]:
conv_dirs = CONV_DIRS
subjects = args.subjects
conversations = TRAIN_CONV
algo='unigram'
vocab_size=1000
exclude_words=['sp', '{lg}', '{ns}']
datum_suffix=CONFIG["datum_suffix"]
oov_tok=CONFIG["oov_token"]
begin_tok=CONFIG["begin_token"]
end_tok=CONFIG["end_token"]
pad_tok=CONFIG["pad_token"]
min_freq=10

In [73]:
from collections import Counter
import glob 
import pandas as pd
# Generating vocaulary from reading the datums
exclude_words = set(exclude_words)
word2freq = Counter()
columns = ["word", "onset", "offset", "accuracy", "speaker"]
# files = [
#     f for conv_dir, subject, ds in zip(conv_dirs, subjects, datum_suffix)
#     for f in glob.glob(conv_dir + f'NY{subject}*/misc/*datum_{ds}.txt')
# ]

convs = [(conv_dir + conv_name, '/misc/*datum_%s.txt' % ds, idx)\
         for idx, (conv_dir, convs, ds) in enumerate(zip(conv_dirs, conversations, datum_suffix))\
         for conv_name in convs]

print(len(convs))

conv_count = 0
for conversation, suffix, idx in convs:
    
    # Check if files exists, if it doesn't go to next
    datum_fn = glob.glob(conversation + suffix)[0]
    if not datum_fn:
        print('File DNE: ', conversation + suffix)
        continue
        
    conv_count += 1
    with open(datum_fn, 'r') as fin:
        lines = map(lambda x: x.split(), fin)
        examples = map(lambda x: (" ".join([z for y in x[0:-4] if (z := y.lower().strip().replace('"', '')) not in exclude_words])), lines)
        examples = filter(lambda x: len(x) > 0, examples)
        examples = list(map(lambda x: x.split(), examples))
    word2freq.update(word for example in examples for word in example)

    
if min_freq > 1:
    word2freq = {
        word: freq
        for word, freq in word2freq.items() if freq >= min_freq
    }
vocab = sorted(word2freq.keys())
n_classes = len(vocab)
w2i = {word: i for i, word in enumerate(vocab)}
i2w = {i: word for word, i in w2i.items()}
print("# Conversations:", conv_count)
print("Vocabulary size (min_freq=%d): %d" % (min_freq, len(word2freq)))

112
# Conversations: 112
Vocabulary size (min_freq=10): 1389


In [52]:
conv_dirs = CONV_DIRS
subjects = args.subjects
conversations = TRAIN_CONV

In [75]:
import sentencepiece as spm
import glob
import pandas as pd

# Assign tokens
oov_token = oov_tok
begin_token = begin_tok
end_token = end_tok
pad_token = pad_tok

# set of exclude words
exclude_words = set(exclude_words)
columns = ["word", "onset", "offset", "accuracy", "speaker"]
convs = [(conv_dir + conv_name, '/misc/*datum_%s.txt' % ds, idx)\
         for idx, (conv_dir, convs, ds) in enumerate(zip(conv_dirs, conversations, datum_suffix))\
         for conv_name in convs]

print(len(convs))

words = []

conv_count = 0
for conversation, suffix, idx in convs:
    
    # Check if files exists, if it doesn't go to next
    datum_fn = glob.glob(conversation + suffix)[0]
    if not datum_fn:
        print('File DNE: ', conversation + suffix)
        continue

    df = pd.read_csv(datum_fn,
                     delimiter=' ',
                     header=None,
                     names=columns)
    df.word = df.word.str.lower()
    df = df[df.speaker == "Speaker1"]
    df.word = df[~df.word.str.lower().isin(exclude_words)]
    words.append(df.word.dropna().tolist())
    
wordsl = [item for sublist in words for item in sublist]
with open("vocab_temp.txt", "w") as fh:
    fh.writelines("%s\n" % str(place) for place in wordsl)

spm.SentencePieceTrainer.Train('--input=vocab_temp.txt --model_prefix=BrainTransformer --model_type=%s --vocab_size=%d --bos_id=0 --eos_id=1 --unk_id=2 --unk_surface=%s --pad_id=3' % (algo, vocab_size, oov_token))

sys.stdout.flush()
vocab = spm.SentencePieceProcessor()
vocab.Load("BrainTransformer.model")

print("# Conversations:", len(files))
print("Vocabulary size (%s): %d" % (algo, vocab_size))


112
# Conversations: 155
Vocabulary size (unigram): 1000


In [76]:
print([(i, vocab.IdToPiece(i)) for i in range(len(vocab))])

[(0, '<s>'), (1, '</s>'), (2, '<unk>'), (3, '<pad>'), (4, '▁i'), (5, "'"), (6, 's'), (7, '▁like'), (8, '▁you'), (9, 't'), (10, '▁yeah'), (11, '▁it'), (12, '▁a'), (13, '▁the'), (14, '▁to'), (15, '▁and'), (16, '▁know'), (17, '▁that'), (18, 'm'), (19, '▁so'), (20, 'ing'), (21, '▁just'), (22, 'n'), (23, '▁was'), (24, '▁they'), (25, '▁'), (26, 'y'), (27, '▁my'), (28, '▁me'), (29, 'e'), (30, '▁don'), (31, '▁of'), (32, 're'), (33, '▁in'), (34, 'ed'), (35, '▁but'), (36, '▁what'), (37, '▁have'), (38, '▁she'), (39, '▁no'), (40, 'd'), (41, 'ay'), (42, '▁be'), (43, '▁for'), (44, '▁um'), (45, '▁this'), (46, '▁oh'), (47, '▁on'), (48, '▁ok'), (49, '▁do'), (50, '▁not'), (51, 'a'), (52, '▁he'), (53, 'll'), (54, '▁go'), (55, '▁m'), (56, '▁right'), (57, '▁if'), (58, '▁did'), (59, '▁we'), (60, '▁there'), (61, '▁with'), (62, '▁can'), (63, '▁gonna'), (64, '▁good'), (65, 've'), (66, 'h'), (67, 'o'), (68, 'ly'), (69, '▁really'), (70, 'er'), (71, '▁one'), (72, 'r'), (73, '▁or'), (74, 'u'), (75, '▁when'), (76, 

==========================================================================================================

==========================================================================================================

==========================================================================================================

==========================================================================================================

In [30]:
### Get electrode date helper
def get_electrode(elec_id):
    conversation, electrode = elec_id
    search_str = conversation + f'/preprocessed/*_{electrode}.mat'
    mat_fn = glob.glob(search_str)
    if len(mat_fn) == 0:
        print(f'[WARNING] electrode {electrode} DNE in {search_str}')
        return None
    return loadmat(mat_fn[0])['p1st'].squeeze().astype(np.float32)

In [31]:
# Function input arguments
conv_dirs = CONV_DIRS
subjects = args.subjects
conversations = TRAIN_CONV
delimiter=" "
bin_ms=args.bin_size
shift_ms=args.shift
window_ms=args.window_size
electrodes=args.electrodes
datum_suffix=CONFIG["datum_suffix"]
exclude_words=['sp', '{lg}', '{ns}']
aug_shift_ms=[-1000, -500]
fs = 512

# extra stuff that happens inside
oov_token=CONFIG["oov_token"]
begin_token=CONFIG["begin_token"]
end_token=CONFIG["end_token"]
pad_token=CONFIG["pad_token"]

In [32]:
import glob
from multiprocessing import Pool
from scipy.io import loadmat
import math
from pprint import pprint
import pandas as pd
# import seaborn as sns
# import matplotlib.pyplot as plt

# from matplotlib import rc
# rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
# ## for Palatino and other serif fonts use:
# #rc('font',**{'family':'serif','serif':['Palatino']})
# rc('text', usetex=True)

In [33]:
signals, labels = [], []
bin_fs = int(bin_ms / 1000 * fs)
shift_fs = int(shift_ms / 1000 * fs)
window_fs = int(window_ms / 1000 * fs)
half_window = window_fs // 2
start_offset = - half_window + shift_fs
end_offset = half_window + shift_fs
aug_shift_fs = [int(s / 1000 * fs) for s in aug_shift_ms]

signal_param_dict = createDict('bin_fs',
                               'shift_fs',
                               'window_fs',
                               'half_window',
                               'start_offset',
                               'end_offset')

In [34]:
convs = [(conv_dir + conv_name, '/misc/*datum_%s.txt' % ds, idx)\
         for idx, (conv_dir, convs, ds) in enumerate(zip(conv_dirs, conversations, datum_suffix))\
         for conv_name in convs]

In [35]:
len(convs)

112

In [77]:
def return_electrode_array(conv, suf):
    # Read signals
    elec_ids = ((conversation, electrode) for electrode in electrodes)
    with Pool() as pool:
        ecogs = list(filter(lambda x: x is not None,
                            pool.map(get_electrode, elec_ids)))

    ecogs = np.asarray(ecogs)
    ecogs = (ecogs - ecogs.mean(axis=1).reshape(ecogs.shape[0], 1)) / ecogs.std(axis=1).reshape(ecogs.shape[0], 1)
    ecogs = ecogs.T
    assert(ecogs.ndim == 2 and ecogs.shape[1] == len(electrodes))
    return ecogs

In [37]:
# def generate_w2i(string, wtoi, oov_tok):
#     index_list = [wtoi[x] if x in wtoi else wtoi[oov_tok] for x in string.split(' ')]
#     return index_list


# def return_examples(file, delim, word2i, oov_tok, vocab):
#     columns = ["word", "onset", "offset", "accuracy", "speaker"]  # columns in the datum
#     df = pd.read_csv(file, delimiter=' ', header=None, names=columns)

#     df.drop(columns=['offset', 'accuracy'], inplace=True)  # dropping offset and accuracy
#     df.word = df.word.str.lower()  # converting words to lower case
#     df = df[df.word != 'sp']  # deleting rows with 'sp' in word column
#     df.sort_values(by=['onset'], inplace=True)  # sorting the datum based on onset
#     df.speaker = df.speaker.str.strip() # stripping \n from Speaker
        
# #     # concatenate strings where the onset is the same for the same speaker
# #     df['word'] = df.groupby(['onset', 'speaker'])['word'].transform(lambda x: ' '.join(x))
#     df.drop_duplicates()
    
#     '''# Example for the above line
#     a = [['hello', 'Speaker 1', 6445],
#          ['zaid', 'Speaker 1', 6445]
#          ['harsha', 'Speaker 2', 6445],
#          ['yeah', 'Speaker 1', 7345]
#         ] 
#     df_a = pd.DataFrame(a, columns=['word', 'speaker', 'onset'])
#     '''
#     print(df.word.tolist()[:50])
#     df.speaker = df.speaker == 'Speaker1'  # check if speaker 1
#     df.word = df.word.apply(lambda x: generate_w2i(x, word2i, oov_tok))
#     print(df.word.tolist()[:50])

#     df = df[['word', 'speaker', 'onset']]
#     print(df.head(50))
#     df_vec = df.onset.diff() <= 0
#     plt.figure()
#     plt.plot(df.onset)
    
#     plt.figure()
#     plt.plot(df_vec[1:])
  
#     examples = list(df.to_records(index=False))
    
#     ####### part 2 ######

        
#     return examples

### This following block of code extracts  bi-grams and their corresponding neural signals

In [78]:
def return_examples_std(file, delim, vocabulary, ex_words):
    with open(file, 'r') as fin:
        lines = map(lambda x: x.split(delim), fin)
        examples = map(lambda x: (" ".join([z for y in x[0:-4] if (z := y.lower().strip().replace('"', '')) not in ex_words]),
                                  x[-1].strip() == "Speaker1",
                                  x[-4], x[-3]), lines)
        examples = filter(lambda x: len(x[0]) > 0, examples)
        examples = map(lambda x: ([vocabulary[x] for x in x[0].split()], x[1], int(float(x[2])), int(float(x[3]))), examples)
        return list(examples)
    
    
def return_examples_spm(file, delim, vocabulary, ex_words):
    
    with open(file, 'r') as fin:
        lines = map(lambda x: x.split(delim), fin)
        examples = map(lambda x: (" ".join([z for y in x[0:-4] if (z := y.lower().strip().replace('"', '')) not in ex_words]),
                                  x[-1].strip() == "Speaker1",
                                  x[-4], x[-3]), lines)
        examples = filter(lambda x: len(x[0]) > 0, examples)
        examples = map(lambda x: (vocabulary.EncodeAsIds(x[0]), x[1], int(float(x[2])), int(float(x[3]))), examples)
        return list(examples)

    
def generate_wordpairs_ghub(examples):
    '''if the first set already has two words and is speaker 1
        if the second set already has two words and is speaker 1
        the onset of the first word is earlier than the second word
    '''
    my_grams = []
    for first, second in zip(examples, examples[1:]):
        len1, len2 = len(first[0]), len(second[0])
        if first[1] and len1 == 2:
            my_grams.append(first)
        if second[1] and len2 == 2:
            my_grams.append(second)
        if ((first[1] and second[1]) and (len1 == 1 and len2 == 1)
                and (first[2] < second[2])):
            ak = (first[0] + second[0], True, first[2], second[3])
            my_grams.append(ak)
    return my_grams


def generate_wordpairs(examples):
    my_grams = []
    for first, second in zip(examples, examples[1:]):
        len1, len2 = len(first[0]), len(second[0])
        if first[1] and len1 == 2:  # if the first set already has two words and is speaker 1
            my_grams.append(first)
        if second[1] and len2 == 2:  # if the second set already has two words and is speaker 1
            my_grams.append(second)
        if first[1] and second[1]:
            if len1 == 1 and len2 == 1:
                if first[2] < second[2]:  # the onset of the first word is earlier than the second word
                    ak = (first[0] + second[0], True, first[2], second[3])
                    my_grams.append(ak)
    return my_grams


def remove_duplicates(grams):
    df = pd.DataFrame(grams)
    df[['fw', 'sw']] = pd.DataFrame(df[0].tolist()) 
    df = df.drop(columns=[0]).drop_duplicates()
    df[0] = df[['fw', 'sw']].values.tolist()
    df = df.drop(columns=['fw', 'sw'])
    df = df[sorted(df.columns)]
    return list(df.to_records(index=False))


def add_begin_end_tokens(word_pair, vocabulary, start_tok, stop_tok):
    word_pair.insert(0, vocabulary[start_tok])  # Add start token
    word_pair.append(vocabulary[stop_tok])  # Add end token
    return word_pair


def test_for_bad_window(start, stop, shape, window):
    return (start < 0 or  # if the window_begin is less than 0 or
        start > shape[0] or  # check if onset is within limits
        stop < 0 or  # if the window_end is less than 0 or
        stop > shape[0] or  # if the window_end is outside the signal 
        stop - start < window)  # if there are not enough frames in the window


def calculate_windows_params(gram, param_dict):
    seq_length = gram[3] - gram[2]
    begin_window = bigram[2] + param_dict['start_offset']
    end_window = bigram[3] + param_dict['end_offset']
    bin_size = int(math.ceil((end_window - begin_window) / param_dict['bin_fs']))  # calculate number of bins

    return seq_length, begin_window, end_window, bin_size


def remove_oovs(grams, vocabulary, data_tag='train'):
    if data_tag == 'train':
        grams = filter(lambda x: vocabulary['<unk>'] not in x[0], grams)
    else:
        grams = filter(lambda x: x[0] != [vocabulary['<unk>']] * 2, grams)
    return list(grams)

#### Generating Training Data Set

In [80]:
delimiter

' '

In [87]:
train_seq_lengths, signals, labels = [], [], []
for conversation, suffix, idx in convs:
    
    # Check if files exists, if it doesn't go to next
    datum_fn = glob.glob(conversation + suffix)[0]
    if not datum_fn:
        print('File DNE: ', conversation + suffix)
        continue
        
    # Extract electrode data
    ecogs = return_electrode_array(conversation, suffix)
    if not ecogs.size:
        print(f'Skipping bad conversation: {conversation}')
        continue

    examples = return_examples_spm(datum_fn, delimiter, vocab, exclude_words)  # for spm
    bigrams = generate_wordpairs(examples)
    bigrams_ghub = generate_wordpairs_ghub(examples)
    if not bigrams:
        print(f'Skipping bad conversation: {conversation}')
        continue
    bigrams = remove_duplicates(bigrams)
    
    for bigram in bigrams:
        seq_length, start_onset, end_onset, n_bins = calculate_windows_params(bigram, signal_param_dict)
        
        if seq_length <= 0:
            continue
            
        train_seq_lengths.append(seq_length)
        
        if test_for_bad_window(start_onset, end_onset, ecogs.shape, window_fs):
            continue

        labels.append(add_begin_end_tokens(bigram[0], vocab, begin_tok, end_tok))  # put the sentence in the label vector/list
        word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32) 
        for i, f in enumerate(np.array_split(ecogs[start_onset:end_onset,:], n_bins, axis=0)):
            word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)

        #TODO: Data Augmentation
        signals.append(word_signal)
        
print('final')
assert len(labels) == len(signals), "Bad Shape for Lengths"
x_train = signals
y_train = labels

final


In [88]:
len(labels)

65037

#### Generating validation data set

In [92]:
conversations = VALID_CONV
convs = [(conv_dir + conv_name, '/misc/*datum_%s.txt' % ds, idx)\
         for idx, (conv_dir, convs, ds) in enumerate(zip(conv_dirs, conversations, datum_suffix))\
         for conv_name in convs]

signals, labels = [], []
for (conversation, suffix, idx) in convs:
    # Check if files exists, if it doesn't go to next
    datum_fn = glob.glob(conversation + suffix)[0]
    if not datum_fn:
        print('File DNE: ', conversation + suffix)
        continue
        
    # Extract electrode data
    ecogs = return_electrode_array(conversation, suffix)
    if not ecogs.size:
        print(f'Skipping bad conversation: {conversation}')
        continue

    examples = return_examples_spm(datum_fn, delimiter, vocab, exclude_words)
    bigrams = generate_wordpairs(examples)
    if not bigrams:
        print(f'Skipping no bigrams: {conversation}')
        continue
    bigrams = remove_duplicates(bigrams)
    
    valid_seq_lengths = 0
    for bigram in bigrams:
        seq_length, start_onset, end_onset, n_bins = calculate_windows_params(bigram, signal_param_dict)
        
        if seq_length <=0:
            continue
            
        if test_for_bad_window(start_onset, end_onset, ecogs.shape, window_fs):
            continue

        labels.append(add_begin_end_tokens(bigram[0], vocab, begin_tok, end_tok))  # put the sentence in the label vector/list
        word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32) 
        for i, f in enumerate(np.array_split(ecogs[start_onset:end_onset,:], n_bins, axis=0)):
            word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)

        #TODO: Data Augmentation
        signals.append(word_signal)
        
print('final')
assert len(labels) == len(signals), "Bad Shape for Lengths"
x_valid = signals
y_valid = labels

Skipping no bigrams: /scratch/gpfs/hgazula/676-conversations/NY676_618_Part8_conversation1
final


#### Converting train and validation data to Loader objects

In [None]:
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

### Pytorch Dataset wrapper
class Brain2enDataset(Dataset):
    """Brainwave-to-English Dataset."""

    def __init__(self, signals, labels):
        """
        Args:
            signals (list): brainwave examples.
            labels (list): english examples.
        """
        global oov_token, vocab

        assert(len(signals) == len(labels))
        indices = [(i, len(signals[i]), len(labels[i]))\
                   for i in range(len(signals))]
        indices.sort(key=lambda x: (x[1], x[2], x[0]))
        self.examples = []
        self.max_seq_len = 0
        self.max_sent_len = 0
        self.train_freq = Counter()
        c = 0
        for i in indices:
            if i[1] > 384 or i[2] < 4 or i[2] > 128:
                c += 1
                continue
            lab = labels[i[0]]
            self.train_freq.update(lab)
            lab = torch.tensor(lab).long()
            self.examples.append((torch.from_numpy(signals[i[0]]).float(), lab))
            self.max_seq_len = max(self.max_seq_len, i[1])
            self.max_sent_len = max(self.max_sent_len, len(lab))
        print("Skipped", c, "examples")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

### Create a mask for subsequent positions and a mask for padding
def masks(labels):
    global pad_token, vocab
    pos_mask = (torch.triu(torch.ones(labels.size(1), labels.size(1))) == 1).transpose(0, 1).unsqueeze(0)
    pos_mask = pos_mask.float().masked_fill(pos_mask == 0,
        float('-inf')).masked_fill(pos_mask == 1, float(0.0))
    pad_mask = labels == vocab[pad_token]
    return pos_mask, pad_mask


### Batch padding method
def pad_batch(batch):
    global end_token, pad_token, vocab
    src = pad_sequence([batch[i][0] for i in range(len(batch))], batch_first=True, padding_value=0.)
    labels = pad_sequence([batch[i][1] for i in range(len(batch))], batch_first=True, padding_value=vocab[pad_token])
    trg = torch.zeros(labels.size(0), labels.size(1), len(vocab)).scatter_(2, labels.unsqueeze(-1), 1)
    trg, trg_y = trg[:,:-1,:], labels[:,1:]
    pos_mask, pad_mask = masks(trg_y)
    return src, trg, trg_y, pos_mask, pad_mask

In [None]:
from collections import Counter
import torch.utils.data as data

### Shuffle labels if required
if args.shuffle:
    print("Shuffling labels")
    np.random.shuffle(y_train)
    np.random.shuffle(y_valid)
train_ds = Brain2enDataset(x_train, y_train)
print("Number of training signals: ", len(train_ds))
valid_ds = Brain2enDataset(x_valid, y_valid)
print("Number of validation signals: ", len(valid_ds))
train_dl = data.DataLoader(train_ds, batch_size=args.batch_size,
                           shuffle=True, num_workers=CONFIG["num_cpus"],
                           collate_fn=pad_batch)
valid_dl = data.DataLoader(valid_ds, batch_size=args.batch_size,
                           num_workers=CONFIG["num_cpus"],
                           collate_fn=pad_batch)

#### Default models and parameters

In [None]:
### Default models and parameters
DEFAULT_MODELS = {
    "ConvNet10": (len(vocab),),
    "PITOM": (len(vocab), len(args.electrodes)*len(args.subjects)),
    "MeNTALmini": (len(args.electrodes)*len(args.subjects), len(vocab),\
                        args.tf_dmodel, args.tf_nhead, args.tf_nlayer,\
                        args.tf_dff, args.tf_dropout),
    "MeNTAL": (len(args.electrodes)*len(args.subjects), len(vocab),
                         args.tf_dmodel, args.tf_nhead, args.tf_nlayer,\
                         args.tf_dff, args.tf_dropout)
}

#### Creating a Model

In [None]:
from models import *

if args.init_model is None:
    if args.model in DEFAULT_MODELS:
        print("Building default model: %s" % args.model, end="")
        model_class = globals()[args.model]
        model = model_class(*(DEFAULT_MODELS[args.model]))
    else:
        print("Building custom model: %s" % args.model, end="")
        sys.exit(1)
else:
    model_name = "%s%s.pt" % (SAVE_DIR, args.model)
    if os.path.isfile(model_name):
        model = torch.load(model_name)
        model = model.module if hasattr(model, 'module') else model
        print("Loaded initial model: %s " % args.model)
    else:
        print("No models found in: ", SAVE_DIR)
        sys.exit(1)
print(" with %d trainable parameters"
    % sum([p.numel() for p in model.parameters() if p.requires_grad]))
sys.stdout.flush()

#### Initialize loss and optimizer

In [None]:
from transformers import AdamW, get_cosine_schedule_with_warmup

# Initialize loss and optimizer
# weights = torch.ones(n_classes)
# max_freq = -1.
# for i in range(n_classes):
#     max_freq = max(max_freq, word2freq[vocab[i]])
#     weights[i] = 1./float(word2freq[vocab[i]])
# weights = weights*max_freq
# print(sorted([(vocab[i], round(float(weights[i]),1)) for i in range(n_classes)],
#              key=lambda x: x[1]))
criterion = nn.CrossEntropyLoss()
step_size = int(math.ceil(len(train_ds)/args.batch_size))
# optimizer = optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98),
#                         eps=1e-9, weight_decay=args.weight_decay)
# optimizer = optim.AdamW(model.parameters(), lr=args.lr,
#                         weight_decay=args.weight_decay)
optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# optimizer = NoamOpt(args.tf_dmodel, 0.2, 5*step_size,
#             optim.Adam(model.parameters(), lr=0., betas=(0.9, 0.98), eps=1e-9))
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
#                                            milestones=[10*step_size,20*step_size, 40*step_size],
#                                            gamma=0.2)
# scheduler = get_cosine_schedule_with_warmup(optimizer, 10*step_size,
#                                             args.epochs*step_size, num_cycles=2.5)
scheduler = None

In [None]:
# Move model and loss to GPUs
if args.gpus:
    model.cuda()
    criterion.cuda()
    if args.gpus > 1:
        model = nn.DataParallel(model)


# Batch chunk size to send to single GPU
# import math
# chunk_size = int(math.ceil(float(args.batch_size)/max(1,args.gpus)))

In [None]:
import time
################################################################################
#
# Brain2En > Training and Evaluation Utilities
#
################################################################################

### Libraries
from collections import Counter
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from sklearn.metrics import auc, confusion_matrix, roc_curve
import sys
import time
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


################################################################################
#
# Optimization Classes and Methods
#
################################################################################


CLIP_NORM = 1.0
REGEX = re.compile('[^a-zA-Z]')

### NOAM Optimizer
class NoamOpt:
    "Optimizer wrapper implementing learning scheme"
    def __init__(self, d_model, prefactor, warmup, optimizer):
        self.d_model = d_model
        self.optimizer = optimizer
        self.warmup = warmup
        self.prefactor = prefactor
        self._step = 0
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        "Implement learning rate warmup scheme"
        if step is None:
            step = self._step
        return self.prefactor * (self.d_model ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

### Regularization by Label Smoothing
class LabelSmoothing(nn.Module):
    "Implements label smoothing on a multiclass target."
    def __init__(self, criterion, size, pad_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = criterion
        self.pad_idx = pad_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert(x.size(1) == self.size)
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1,
            target.data.unsqueeze(1).long(),
            self.confidence)
        true_dist[:, self.pad_idx] = 0
        mask = torch.nonzero(target.data == self.pad_idx)
        if mask.sum() > 0 and len(mask) > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))

### Single GPU Loss Computation
class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, criterion, opt=None, scheduler=None):
        self.criterion = criterion
        self.opt = opt
        self.scheduler = scheduler

    def __call__(self, x, y, val=False):
        loss = self.criterion(x.view(-1, x.size(-1)),\
                              y.view(-1))
        if not val:
            loss.backward()
            if self.opt is not None:
                self.opt.step()
            if self.scheduler is not None:
                self.scheduler.step()
        return loss.data.item()

### Multi GPU Loss Computation
class MultiGPULossCompute:
    "A multi-gpu loss compute and train function."
    def __init__(self, criterion, devices, opt=None, scheduler=None, chunk_size=5):
        # Send out to different gpus.
        self.criterion = nn.parallel.replicate(criterion,
                                               devices=devices)
        self.opt = opt
        self.scheduler = scheduler
        self.devices = devices
        self.chunk_size = chunk_size

    def __call__(self, x, y, val=False):
        total = 0.0
        out_scatter = nn.parallel.scatter(out,
                                          target_gpus=self.devices)
        out_grad = [[] for _ in out_scatter]
        targets = nn.parallel.scatter(targets,
                                      target_gpus=self.devices)

        # Divide generating into chunks.
        chunk_size = self.chunk_size
        for i in range(0, out_scatter[0].size(1), chunk_size):
            # Predict distributions
            out_column = [[Variable(o[:, i:i+chunk_size].data,
                                    requires_grad=self.opt is not None)]
                           for o in out_scatter]
            gen = nn.parallel.parallel_apply(generator, out_column)

            # Compute loss.
            y = [(g.contiguous().view(-1, g.size(-1)),
                  t[:, i:i+chunk_size].contiguous().view(-1))
                 for g, t in zip(gen, targets)]
            loss = nn.parallel.parallel_apply(self.criterion, y)

            # Sum and normalize loss
            l = nn.parallel.gather(loss,
                                   target_device=self.devices[0])
            l = l.sum()[0]
            total += l.data[0]

            # Backprop loss to output of transformer
            if not val and self.opt is not None:
                l.backward()
                for j, l in enumerate(loss):
                    out_grad[j].append(out_column[j][0].grad.data.clone())
        if not val:
            if self.opt is not None:
                out_grad = [Variable(torch.cat(og, dim=1)) for og in out_grad]
                o1 = out
                o2 = nn.parallel.gather(out_grad,
                                        target_device=self.devices[0])
                o1.backward(gradient=o2)
                self.opt.step()
            if self.scheduler is not None:
                self.scheduler.step()
        return total


################################################################################
#
# Training Methods
#
################################################################################


### Training loop
def train(data_iter, model, criterion, devices, opt,
          scheduler=None, seq2seq=False, pad_idx=-1):
    model.train()
    start_time = time.time()
    total_loss = 0.
    total_acc = 0.
    count, batch_count = 0, 0
    for i, batch in enumerate(data_iter):
        # Prevent gradient accumulation
        model.zero_grad()
        src = batch[0].cuda()
        trg = batch[1].long().to(src.device)
        if seq2seq:
            # trg_y = batch[2].long().to(src.device)
            # trg_pos_mask, trg_pad_mask = batch[3].to(src.device), batch[4].to(src.device)
            # out, trg_y = model.forward(src, trg, trg_y, trg_pos_mask, trg_pad_mask)
            # # Fix asymmetrical load on single GPU by computing loss in parallel
            # out_scatter = nn.parallel.scatter(out, target_gpus=devices)
            # out_grad = [[] for _ in out_scatter]
            # trg_y_scatter = nn.parallel.scatter(trg_y, target_gpus=devices)
            # chunk_size = int(math.ceil(out_scatter[0].size(0)/len(devices)))
            # for i in range(0, out_scatter[0].size(0), chunk_size):
            #     out_scatter_chunk = [Variable(o[i:end], requires_grad=True)\
            #                          for o in out_scatter if (end := min(i+chunk_size,len(o))) > i]
            #     trg_y_scatter_chunk = [t[i:end] for t in trg_y_scatter\
            #                            if (end := min(i+chunk_size, len(t))) > i]
            #     idx_scatter = [(t != pad_idx).nonzero(as_tuple=True) for t in trg_y_scatter_chunk]
            #     y = [(o.contiguous().view(-1, o.size(-1)), t.contiguous().view(-1)) for o, t in zip(out_scatter_chunk, trg_y_scatter_chunk)]
            #     print(len(criterion), len(y))
            #     print([(o.size(), t.size()) for o, t in y])
            #     loss = nn.parallel.parallel_apply(criterion[:len(y)], y)
            #     loss = [l.unsqueeze(-1) for l in loss]
            #     num_dev = len(loss)
            #     # print(num_dev)
            #     loss = nn.parallel.gather(loss, target_device=devices[0])
            #     loss = loss.sum()
            #     total_loss += float(loss.item() / num_dev)
            #     loss.backward()
            #     for j in range(num_dev):
            #         out_grad[j].append(out_scatter_chunk[j].grad.data.clone())
            #     out_scatter_chunk = [torch.argmax(o[idx], dim=1) for idx, o in zip(idx_scatter, out_scatter_chunk)]
            #     trg_y_scatter_chunk = [t[idx] for idx, t in zip(idx_scatter, trg_y_scatter_chunk)]
            #     total_acc += sum([float((o == t).sum()) for o, t in zip(out_scatter_chunk, trg_y_scatter_chunk)])
            #     count += sum(int(o.size(0)) for o in out_scatter_chunk)
            #     # del out_scatter_chunk, trg_y_scatter_chunk, idx_scatter, y, loss
            #     print(total_loss, total_acc / count)
            #     sys.stdout.flush()
            # print("Here now")
            # out_grad = [Variable(torch.cat(og, dim=0)) for og in out_grad]
            # print(out.size(), trg_y.size())
            # print([o.size() for o in out_grad])
            # sys.stdout.flush()
            # out.backward(gradient=nn.parallel.gather(out_grad,
            #                                          target_device=devices[0]))
            # del src, trg, trg_y, trg_pos_mask, trg_pad_mask, out, out_grad
            # print("brah")
            # loss = criterion(out.view(-1, out.size(-1)), trg_y.view(-1))
            # loss.backward()
            trg_y = batch[2].long().to(src.device)
            trg_pos_mask, trg_pad_mask = batch[3].to(src.device), batch[4].to(src.device)
            # Perform loss computation during forward pass for parallelism
            out, trg_y, loss = model.forward(src, trg, trg_pos_mask, trg_pad_mask, trg_y, criterion)
            idx = (trg_y != pad_idx).nonzero(as_tuple=True)
            total_loss += loss.data.item()
            out = out[idx]
            trg_y = trg_y[idx]
            out = torch.argmax(out, dim=1)
            total_acc += float((out == trg_y).sum())
            opt.step()
            if scheduler is not None:
                scheduler.step()
            # total_loss += loss.data.item()
            # out = out[idx]
            # trg_y = trg_y[idx]
            # out = torch.argmax(out, dim=1)
            # total_acc += float((out == trg_y).sum())
            # print("hereo")
            # sys.stdout.flush()
        else:
            out = model.forward(src)
            loss = criterion(out.view(-1, out.size(-1)), trg.view(-1))
            loss.backward()
            if opt is not None:
                opt.step()
            if scheduler is not None:
                scheduler.step()
            total_loss += loss.data.item()
            total_acc += float((torch.argmax(out, dim=1) == trg).sum())
            # count += int(out.size(0))
        # Prevent gradient blowup
        nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        count += int(out.size(0))
        batch_count += 1
    total_loss /= batch_count
    total_acc /= count
    elapsed = (time.time() - start_time) * 1000. / batch_count
    perplexity = float('inf')
    try:
        perplexity = math.exp(total_loss)
    except:
        pass
    print('loss {:5.3f} | accuracy {:5.3f} | perplexity {:3.2f} | ms/batch {:5.2f}'.format(total_loss, total_acc, perplexity, elapsed), end='')
    return total_loss, total_acc

### Validation loop
def valid(data_iter, model, criterion,
          temperature=1.0, n_samples=10, seq2seq=False, pad_idx=-1):
    model.eval()
    total_loss = 0.
    total_acc = 0.
    total_sample_rank_acc = 0.
    batch_count, count = 0, 0
    for i, batch in enumerate(data_iter):
        src = batch[0].cuda()
        trg = batch[1].long().to(src.device)
        if seq2seq:
            trg_y = batch[2].long().to(src.device)
            trg_pos_mask, trg_pad_mask = batch[3].to(src.device), batch[4].to(src.device)
            out, trg_y, loss = model.forward(src, trg, trg_pos_mask, trg_pad_mask, trg_y, criterion)
            idx = (trg_y != pad_idx).nonzero(as_tuple=True)
            total_loss += loss.data.item()
            out = out[idx]
            trg_y = trg_y[idx]
            out_top1 = torch.argmax(out, dim=1)
            total_acc += float((out_top1 == trg_y).sum())
            out = F.softmax(out/temperature, dim=1)
            samples = torch.multinomial(out, n_samples)
            pred = torch.zeros(samples.size(0)).cuda()
            for j in range(len(pred)):
                pred[j] = samples[j,torch.argmax(out[j,samples[j]])]
            total_sample_rank_acc += float((pred == trg_y).sum())
        else:
            out = model.forward(src)
            loss = criterion(out.view(-1, out.size(-1)), trg.view(-1))
            total_loss += loss.data.item()
            total_acc += float((torch.argmax(out, dim=1) == trg).sum())
            out = F.softmax(out/temperature, dim=1)
            samples = torch.multinomial(out, n_samples)
            pred = torch.zeros(samples.size(0)).cuda()
            for j in range(len(pred)):
                pred[j] = samples[j,torch.argmax(out[j,samples[j]])]
            total_sample_rank_acc += float((pred == trg).sum())
        count += int(out.size(0))
        batch_count += 1
    total_loss /= batch_count
    total_acc /= count
    total_sample_rank_acc /= count
    perplexity = float('inf')
    try:
        perplexity = math.exp(total_loss)
    except:
        pass
    print('loss {:5.3f} | accuracy {:5.3f} | sample-rank acc {:5.3f} | perplexity {:3.2f}'.format(total_loss, total_acc, total_sample_rank_acc, perplexity))
    return total_loss, total_acc

### Plot train/val loss and accuracy and save figures
def plot_training(history, save_dir, title='', val=True):
    plt.plot(history["train_loss"])
    if val:
        plt.plot(history["valid_loss"])
    plt.title('Model loss: %s' % title)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.savefig(save_dir + 'loss.png')
    plt.clf()
    plt.plot(history["train_acc"])
    if val:
        plt.plot(history["valid_acc"])
    plt.title('Model accuracy: %s' % title)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.savefig(save_dir + 'accuracy.png')


################################################################################
#
# Evaluation Methods
#
################################################################################


### Choose point of minimum distance to an ideal point,
### (For ROC: (0,1); for PR: (1,1)).
def best_threshold(X, Y, T, best_x=0., best_y=1.):
    min_d, min_i = np.inf, 0
    for i, (x, y) in enumerate(zip(X, Y)):
        d = np.sqrt((best_x-x)**2 + (best_y-y)**2)
        if d < min_d:
            min_d, min_i = d, i
    return X[min_i], Y[min_i], T[min_i]

### Evaluate ROC performance of the model
### (predictions, labels of shape (n_examples, n_classes))
def evaluate_roc(predictions, labels, i2w, train_freqs, save_dir, do_plot,
                 given_thresholds=None, title='', suffix='', min_train=10,
                 tokens_to_remove=[]):
    assert(predictions.shape == labels.shape)
    lines, scores, word_freqs = [], [], []
    n_examples, n_classes = predictions.shape
    thresholds = np.full(n_classes, np.nan)
    rocs, fprs, tprs = {}, [], []

    # Create directory for plots if required
    if do_plot:
        roc_dir = save_dir + 'rocs/'
        if not os.path.isdir(roc_dir): os.mkdir(roc_dir)

    # Go over each class and compute AUC
    for i in range(n_classes):
        if i2w[i] in tokens_to_remove:
            continue
        train_count = train_freqs[i]
        n_true = np.count_nonzero(labels[:,i])
        if train_count < 1 or n_true == 0: continue
        word = i2w[i]
        probs = predictions[:,i]
        c_labels = labels[:,i]
        fpr, tpr, thresh = roc_curve(c_labels, probs)
        if given_thresholds is None:
            x, y, threshold = best_threshold(fpr, tpr, thresh)
        else:
            x, y, threshold = 0, 0, given_thresholds[i]
        thresholds[i] = threshold
        score = auc(fpr, tpr)
        scores.append(score)
        word_freqs.append(train_count)
        rocs[word] = score
        fprs.append(fpr)
        tprs.append(tpr)
        y_pred = probs >= threshold
        tn, fp, fn, tp = confusion_matrix(c_labels, y_pred).ravel()
        lines.append('%s\t%3d\t%3d\t%.3f\t%d\t%d\t%d\t%d\n' \
                % (word, n_true, train_count, score, tp, fp, fn, tn))
        if do_plot:
            fig, axes = plt.subplots(1,2, figsize=(16,6))
            axes[0].plot(fpr, tpr, color='darkorange', lw=2, marker='.')
            axes[0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            axes[0].plot(x, y, marker='o', color='blue')
            axes[0].set_xlim([0.0, 1.0])
            axes[0].set_ylim([0.0, 1.05])
            axes[0].set_xlabel('False Positive Rate')
            axes[0].set_ylabel('True Positive Rate')
            h1 = probs[c_labels == 1].reshape(-1)
            h2 = probs[c_labels == 0].reshape(-1)
            axes[1].hist(h2, bins=20, color='orange',
                             alpha=0.5, label='Neg. Examples')
            #axes[1].twinx().hist(h1, bins=50, alpha=0.5, label='Pos. Examples')
            axes[1].hist(h1, bins=50, alpha=0.5, label='Pos. Examples')
            axes[1].axvline(threshold, color='k')
            axes[1].set_xlabel('Activation')
            axes[1].set_ylabel('Frequency')
            axes[1].legend()
            axes[1].set_title('%d TP | %d FP | %d FN | %d TN'\
                              % (tp, fp, fn, tn))
            fig.suptitle('ROC Curve | %s | AUC = %.3f | N = %d'\
                         % (word, score, n_true))
            plt.savefig(roc_dir + '%s.png' % word)
            fig.clear()
            plt.close(fig)

    # Compute statistics
    scores, word_freqs = np.array(scores), np.array(word_freqs)
    normed_freqs = word_freqs / word_freqs.sum()
    avg_auc = scores.mean()
    weighted_avg = (scores * normed_freqs).sum()
    print('Avg AUC: %d\t%.6f' % (scores.size, avg_auc))
    print('Weighted Avg AUC: %d\t%.6f' % (scores.size, weighted_avg))

    # Write to file
    with open(save_dir + 'aucs%s.txt' % suffix, 'w') as fout:
        for line in lines:
            fout.write(line)

    # Plot histogram and AUC as a function of num of examples
    _, ax = plt.subplots(1,1)
    ax.scatter(word_freqs, scores, marker='.')
    ax.set_xlabel('# examples')
    ax.set_ylabel('AUC')
    ax.set_title('%s | avg: %.3f | N = %d' % (title, weighted_avg, scores.size))
    ax.set_yticks(np.arange(0., 1.1, 0.1))
    ax.grid()
    plt.savefig(save_dir + 'roc-auc-examples.png', bbox_inches='tight')

    _, ax = plt.subplots(1,1)
    ax.hist(scores, bins=20)
    ax.set_xlabel('AUC')
    ax.set_ylabel('# labels')
    ax.set_title('%s | avg: %.3f | N = %d' % (title, weighted_avg, scores.size))
    ax.set_xticks(np.arange(0., 1., 0.1))
    plt.savefig(save_dir + 'roc-auc.png', bbox_inches='tight')

    _, ax = plt.subplots(1,1)
    for fpr, tpr in zip(fprs, tprs):
        ax.plot(fpr, tpr, lw=1)
    ax.plot([0, 1], [0, 1], color='navy', linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('%s | avg: %.3f | N = %d' % (title, weighted_avg, scores.size))
    plt.savefig(save_dir + 'roc-auc-all.png', bbox_inches='tight')

    return {
        'rocauc_avg': avg_auc,
        'rocauc_stddev': scores.std(),
        'rocauc_w_avg': weighted_avg,
        'rocauc_n': scores.size,
        'rocs': rocs
    }

### Evaluate top-k performance of the model. (assumes activations can be
### interpreted as probabilities).
### (predictions, labels of shape (n_examples, n_classes))
def evaluate_topk(predictions, labels, i2w, train_freqs, save_dir,
                  min_train=10, prefix='', suffix='', tokens_to_remove=[]):
    ranks = []
    n_examples, n_classes = predictions.shape
    fid = open(save_dir + 'guesses%s.csv' % suffix, 'w')
    top1_uw, top5_uw, top10_uw = set(), set(), set()
    accs, sizes = {}, {}
    total_freqs = float(sum(train_freqs.values()))

    # Go through each example and calculate its rank and top-k
    for i in range(n_examples):
        y_true_idx = labels[i]

        if train_freqs[y_true_idx] < 1:
            continue

        word = i2w[y_true_idx]
        if word in tokens_to_remove:
            continue

        # Get example predictions
        ex_preds = np.argsort(predictions[i])[::-1]
        rank = np.where(y_true_idx == ex_preds)[0][0]
        ranks.append(rank)

        fid.write('%s,%d,' % (word, rank))
        fid.write(','.join(i2w[j] for j in ex_preds[:10]))
        fid.write('\n')

        if rank == 0:
            top1_uw.add(ex_preds[0])
        elif rank < 5:
            top5_uw.update(ex_preds[:5])
        elif rank < 10:
            top10_uw.update(ex_preds[:10])

        if word not in accs:
            accs[word] = float(rank == 0)
            sizes[y_true_idx] = 1.
        else:
            accs[word] += float(rank == 0)
            sizes[y_true_idx] += 1.
    for idx in sizes:
        word = i2w[idx]
        chance_acc = float(train_freqs[idx]) / total_freqs * 100.
        if sizes[idx] > 0:
            rounded_acc = round(accs[word] / sizes[idx] * 100, 3)
            accs[word] = (rounded_acc, chance_acc, rounded_acc - chance_acc)
        else:
            accs[word] = (0., chance_acc, -chance_acc)
    accs = sorted(accs.items(), key=lambda x: -x[1][2])

    fid.close()
    print('Top1 #Unique:', len(top1_uw))
    print('Top5 #Unique:', len(top5_uw))
    print('Top10 #Unique:', len(top10_uw))

    n_examples = len(ranks)
    ranks = np.array(ranks)
    top1 = sum(ranks == 0) / (1e-12 + len(ranks)) * 100
    top5 = sum(ranks < 5) / (1e-12 + len(ranks)) * 100
    top10 = sum(ranks < 10) / (1e-12 + len(ranks)) * 100

    # Calculate chance levels based on training word frequencies
    freqs = Counter(labels)
    freqs = np.array([freqs[i] for i,_ in train_freqs.most_common()])
    freqs = freqs[freqs > 0]
    chances = (freqs / freqs.sum()).cumsum() * 100

    # Print and write to file
    if suffix is not None:
        with open(save_dir + 'topk%s.txt' % suffix, 'w') as fout:
            line = 'n_classes: %d\nn_examples: %d' % (n_classes, n_examples)
            print(line)
            fout.write(line + '\n')
            line = 'Top-1\t%.4f %% (%.2f %%)' % (top1, chances[0])
            print(line)
            fout.write(line + '\n')
            line = 'Top-5\t%.4f %% (%.2f %%)' % (top5, chances[4])
            print(line)
            fout.write(line + '\n')
            line = 'Top-10\t%.4f %% (%.2f %%)' % (top10, chances[9])
            print(line)
            fout.write(line + '\n')

    return {
        prefix + 'top1': top1,
        prefix + 'top5': top5,
        prefix + 'top10': top10,
        prefix + 'top1_chance':  chances[0],
        prefix + 'top5_chance':  chances[4],
        prefix + 'top10_chance': chances[9],
        prefix + 'top1_above':  (top1 - chances[0]) / chances[0],
        prefix + 'top5_above':  (top5 - chances[4]) / chances[4],
        prefix + 'top10_above': (top10 - chances[9]) / chances[9],
        prefix + 'top1_n_uniq_correct': len(top1_uw),
        prefix + 'top5_n_uniq_correct': len(top5_uw),
        prefix + 'top10_n_uniq_correct': len(top10_uw),
        prefix + 'word_accuracies': accs
    }

In [None]:


### Training and evaluation script
# if __name__ == "__main__":

print("Training on %d GPU(s) with batch_size %d for %d epochs"\
    % (args.gpus, args.batch_size, args.epochs))
print("=" * CONFIG["print_pad"])
sys.stdout.flush()

best_val_loss = float("inf")
best_model = model
history = {'train_loss': [],
           'train_acc': [],
           'valid_loss': [],
           'valid_acc': []}

# train_loss_compute = SimpleLossCompute(criterion,
#                                        opt=optimizer, scheduler=scheduler)
# valid_loss_compute = SimpleLossCompute(criterion, opt=None, scheduler=None)

epoch = 0
model_name = "%s%s.pt" % (SAVE_DIR, args.model)

# totalfreq = float(sum(train_ds.train_freq.values()))
# print(sorted(((i2w[l],f/totalfreq) for l, f in train_ds.train_freq.most_common()), key=lambda x: -x[1]))

# Run training and validation for args.epochs epochs
lr = args.lr
for epoch in range(1, args.epochs+1):
    epoch_start_time = time.time()
    print('| train | epoch %d | ' % epoch, end='')
    train_loss, train_acc = train(train_dl, model,
                                  criterion, list(range(args.gpus)),
                                  optimizer, scheduler=scheduler,
                                  seq2seq=not classify,
                                  pad_idx=vocab[CONFIG["pad_token"]] if not classify else -1)
    for param_group in optimizer.param_groups:
        if 'lr' in param_group:
            print(' | lr {:1.2E}'.format(param_group['lr']))
            break
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    print('| valid | epoch %d | ' % epoch, end='')
    with torch.no_grad():
        valid_loss, valid_acc = valid(valid_dl, model,
                                      criterion, temperature=args.temp,
                                      seq2seq=not classify,
                                      pad_idx=vocab[CONFIG["pad_token"]] if not classify else -1)
    history['valid_loss'].append(valid_loss)
    history['valid_acc'].append(valid_acc)
    print('|' + '-'*(CONFIG["print_pad"]-2) + '|')
    # Store best model so far
    if valid_loss < best_val_loss:
        best_model, best_val_loss = model, valid_loss
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_to_save = best_model.module\
                if hasattr(best_model, 'module') else best_model
            torch.save(model_to_save, model_name)
        sys.stdout.flush()

    # if epoch > 10 and valid_loss > max(history['valid_loss'][-3:]):
    #     lr /= 2.
    #     for param_group in optimizer.param_groups:
    #         param_group['lr'] = lr

# Plot loss,accuracy vs. time and save figures
plot_training(history, SAVE_DIR, title="%s_lr%s" % (args.model, args.lr))

# Save best model found
# print("Saving best model as %s.pt" % args.model)
# sys.stdout.flush()

if not args.no_eval and classify:

    print("Evaluating predictions on test set")
    # Load best model
    model = torch.load(model_name)
    if args.gpus:
        if args.gpus > 1: model = nn.DataParallel(model)
        model.to(DEVICE)

    start, end = 0, 0
    softmax = nn.Softmax(dim=1)
    all_preds = np.zeros((x_valid.size(0), n_classes), dtype=np.float32)
    print('Allocating', np.prod(all_preds.shape)*5/1e9,'GB')

    # Calculate all predictions on test set
    with torch.no_grad():
        for batch in valid_dl:
            src, trg = batch[0].to(DEVICE), batch[1].to(DEVICE, dtype=torch.long)
            end = start + src.size(0)
            out = softmax(model(src))
            all_preds[start:end, :] = out.cpu()
            start = end

    print("Calculated predictions")

    # Make categorical
    n_examples = y_valid.shape[0]
    categorical = np.zeros((n_examples, n_classes), dtype=np.float32)
    categorical[np.arange(n_examples), y_valid] = 1

    train_freq = Counter(y_train.tolist())

    # Evaluate top-k
    print("Evaluating top-k")
    sys.stdout.flush()
    res = evaluate_topk(all_preds, y_valid.numpy(), i2w, train_freq,
                        SAVE_DIR, suffix='-val',
                        min_train=args.vocab_min_freq)

    # Evaluate ROC-AUC
    print("Evaluating ROC-AUC")
    sys.stdout.flush()
    res.update(evaluate_roc(all_preds, categorical, i2w, train_freq,
                            SAVE_DIR, do_plot=not args.no_plot,
                            min_train=args.vocab_min_freq))
    pprint(res.items())
    print("Saving results")
    with open(SAVE_DIR + "results.json", "w") as fp:
        json.dump(res, fp, indent=4)

In [None]:
if not args.no_eval and not classify:

    print("Evaluating predictions on test set")
    # Load best model
    model = torch.load(model_name)
    if args.gpus:
        model.cuda()

    all_preds, categorical, all_labs = [], [], []
    softmax = nn.Softmax(dim=1)

    # Calculate all predictions on test set
    with torch.no_grad():
        model.eval()
        for batch in valid_dl:
            src, trg_y = batch[0].cuda(), batch[2].long().cuda()
            trg_pos_mask, trg_pad_mask = batch[3].cuda().squeeze(), batch[4].cuda()
            memory = model.encode(src)
            y = torch.zeros(src.size(0), 1, len(vocab)).long().cuda()
            y_sr = torch.zeros(src.size(0), 1, len(vocab)).long().cuda()
            probs = torch.zeros(src.size(0), 1, len(vocab)).long().cuda()
            y[:,:,vocab[CONFIG["begin_token"]]] = 1
            y_sr[:,:,vocab[CONFIG["begin_token"]]] = 1
            for i in range(trg_y.size(1)):
                out = model.decode(memory, y,
                                   trg_pos_mask[:y.size(1),:y.size(1)],
                                   trg_pad_mask[:,:y.size(1)])[:,-1,:]
                out = softmax(out/args.temp)
                temp = torch.zeros(src.size(0), len(vocab)).long().cuda()
                temp = temp.scatter_(1, torch.argmax(out, dim=1).unsqueeze(-1), 1)
                y = torch.cat([y, temp.unsqueeze(1)], dim=1)
                # probs = torch.cat([probs, out.unsqueeze(1)], dim=1)
                samples = torch.multinomial(out, 20)
                pred = torch.zeros(out.size(0)).long().cuda()
                for j in range(len(samples)):
                    pred[j] = samples[j,torch.argmax(out[j,samples[j]])]
                temp = torch.zeros(pred.size(0), len(vocab)).long().cuda()
                pred = temp.scatter_(1, pred.unsqueeze(-1), 1).unsqueeze(1)
                y_sr = torch.cat([y_sr, pred], dim=1)
            y, y_sr = y[:,1:,:], y_sr[:,1:,:]
            idx = (trg_y != vocab[CONFIG["pad_token"]]).nonzero(as_tuple=True)
            lab = trg_y[idx]
            cat = torch.zeros((lab.size(0), len(vocab)), dtype=torch.long).to(lab.device)
            cat = cat.scatter_(1, lab.unsqueeze(-1), 1)
            all_preds.extend(y[idx].cpu().numpy())
            categorical.extend(cat.cpu().numpy())
            all_labs.extend(lab.cpu().numpy())
            print("Output: ", vocab.DecodeIds(torch.argmax(y[0], dim=1).tolist()))
            print("Output_sr: ", vocab.DecodeIds(torch.argmax(y_sr[0], dim=1).tolist()))
            print("Target: ", vocab.DecodeIds(trg_y[0].tolist()))
            print()
            print("Output: ", vocab.DecodeIds(torch.argmax(y[-1], dim=1).tolist()))
            print("Output_sr: ", vocab.DecodeIds(torch.argmax(y_sr[-1], dim=1).tolist()))
            print("Target: ", vocab.DecodeIds(trg_y[-1].tolist()))
            # print("BLEU: ", sentence_bleu([target_sent], predicted_sent))

    all_preds = np.array(all_preds)
    categorical = np.array(categorical)
    all_labs = np.array(all_labs)
    print("Calculated predictions")

    train_freq = train_ds.train_freq
    i2w = {i:vocab.IdToPiece(i) for i in range(len(vocab))}
    markers = [CONFIG["begin_token"], CONFIG["end_token"],\
               CONFIG["oov_token"], CONFIG["pad_token"]]

    # Evaluate top-k
    print("Evaluating top-k")
    sys.stdout.flush()
    res = evaluate_topk(all_preds, all_labs, i2w, train_freq,
                        SAVE_DIR, suffix='-val',
                        min_train=args.vocab_min_freq,
                        tokens_to_remove=markers)

    # Evaluate ROC-AUC
    print("Evaluating ROC-AUC")
    sys.stdout.flush()
    res.update(evaluate_roc(all_preds, categorical, i2w, train_freq,
                            SAVE_DIR, do_plot=not args.no_plot,
                            min_train=args.vocab_min_freq,
                            tokens_to_remove=markers))
    pprint(res.items())
    print("Saving results")
    with open(SAVE_DIR + "results.json", "w") as fp:
        json.dump(res, fp, indent=4)

    print("Done!")

In [None]:
# for conversation, suffix, idx in convs[:1]:
    
#     # Check if files exists, if it doesn't go to next
#     datum_fn = glob.glob(conversation + suffix)[0]
#     if not datum_fn:
#         print('File DNE: ', conversation + suffix)
#         continue
        
#     # Extract electrode data
# #     ecogs = return_electrode_array(conversation, suffix)
#     ecogs = np.load('ecogs.npy')
#     if not ecogs.size:
#         print(f'Skipping bad conversation: {conversation}')
#         continue

#     # Read conversations and form examples
#     old_size = len(signals)
#     max_len = 0
    
#     examples = return_examples(datum_fn, delimiter, vocab)
#     my_grams = []
#     for first, second in zip(shorty, shorty[1:]):
#         len1, len2 = len(first[0]), len(second[0])
#         if first[1] and len1 == 2:  # if the first set already has two words and is speaker 1
#             my_grams.append(first)
#         if second[1] and len2 == 2:  # if the second set already has two words and is speaker 1
#             my_grams.append(second)
#         if first[1] and second[1]:
#             if len1 == 1 and len2 == 1:
#                 ak = (first[0] + second[0], True, first[2], second[3])
#                 my_grams.append(ak)

#     print(my_grams)
#     df = pd.DataFrame(my_grams)
#     df[['fw', 'sw']] = pd.DataFrame(df[0].tolist(), index= df.index) 
#     df.drop(columns=[0], inplace=True) 
#     df.drop_duplicates(inplace=True)
#     df[0] = df[['fw', 'sw']].values.tolist()
#     df.drop(columns=['fw', 'sw'], inplace=True)
#     df = df[[0, 1, 2, 3]]
#     my_exams = list(df.to_records(index=False))
#     print(my_exams)
    
#     signals, labels = [], []
#     speaker = False
#     cur_sentence, start_onset, end_onset = [vocab[begin_token]], 0, 0
#     ###################################################################
# #     for example in examples:  # loop over each example
# #         if example[1]:  # check if it is speaker 1
# #             '''If this is the beginning of the sentence'''
# #             if len(cur_sentence) == 1:  # if sentence length is 1 (begin token)
# #                 onset, offset = float(x[2]), float(x[3])  # grab onset and offset of current word
# #                 start_onset = onset + start_window  # calculate the window_begin (its before the onset) 
# #                 end_onset = offset + end_window  # calculate the window_end (its after the offset)
# #                 if start_onset < 0 or start_onset > ecogs.shape[0]:  # check if onset is within limits
# #                     continue  # move on to the next example
# #                 if (end_onset < 0 or  # if the window_end is less than 0 or
# #                     end_onset > ecogs.shape[0]  # if the window_end is outside the signal 
# #                     or end_onset - start_onset < window_fs):  # if there are not enough frames in the window
# #                     continue  # move on to the next example
# #                 true_start_onset = start_onset
# #                 true_end_onset = end_onset  # save onset for next word 
# #                 cur_sentence.extend(x[0])  # add the word to the cur_sentence
                
# #         n_bins = int(math.ceil((end_onset - start_onset) / bin_fs))  # calculate number of bins
        
# #         '''If this is the end of the sentence (perform augmentation here)'''
# #         if ((not x[1] and  # if not speaker 1 (speaker switched) and
# #              speaker and  # speaker switched and
# #              len(cur_sentence) >= 2)  # the sentence length is >= 2
# #             and n_bins):  # number of bins is non-zero
# #             cur_sentence.append(vocab[end_token])  # end the current sentence
# #             labels.append(cur_sentence)  # put the sentence in the label vector/list
# #             word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32) 
# #             for i, f in enumerate(np.array_split(ecogs[start_onset:end_onset,:], n_bins, axis=0)):
# #                 word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
# #             signals.append(word_signal)
           
# #             '''
# #             # Data augmentation by shifts
# #             for i, s in enumerate(aug_shift_fs):
# #                 aug_start = start_onset + s
# #                 if aug_start < 0 or aug_start > ecogs.shape[0]:
# #                     continue
# #                 aug_end = end_onset + s
# #                 if aug_end < 0 or aug_end > ecogs.shape[0]:
# #                     continue
# #                 n_bins = int(math.ceil((aug_end - aug_start) / bin_fs))
# #                 if n_bins > 0:
# #                     word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32)
# #                     for i, f in enumerate(np.array_split(ecogs[aug_start:aug_end,:], n_bins, axis=0)):
# #                         word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
# #                     signals.append(word_signal)
# #                     labels.append(cur_sentence)
# #             cur_sentence = [vocab[begin_token]]
# #             '''
# #         speaker = example[1]
            
#     ###################################################################    

#     for x in examples:  # loop over each example
#         if x[1]:  # Check if it is speaker 1
#             if len(cur_sentence) == 1:  # if sentence length is 1 (that's the begin token)
#                 start_onset = int(float(x[2])) + start_offset  # take the onset of that word
#                 if start_onset < 0 or start_onset > ecogs.shape[0]:  # check if onset is within limits
#                     continue  # move on to the next example
#             new_end_onset = int(float(x[2])) + end_offset
#             if (new_end_onset < 0 or  # if the window_end is less than 0 or 
#                 new_end_onset > ecogs.shape[0] or  # if the window_end is outside the signal 
#                 new_end_onset - start_onset < window_fs):  # if there are not enough frames in the window
#                 continue  # move on to the next example
#             end_onset = new_end_onset
#             cur_sentence.extend(x[0])
        
#         n_bins = int(math.ceil((end_onset - start_onset) / bin_fs))  # calculate number of bins
        
#         if ((not x[1] and  # if not speaker 1 (speaker switched) and
#              speaker and  # speaker switched and
#              len(cur_sentence) >= 2)  # the sentence length is >= 2 and 
#             and (n_bins > 0)):  # number of bins is non-zero
#             cur_sentence.append(vocab[end_token])  # end the current sentence
#             labels.append(cur_sentence)  # put the sentence in the label vector/list
#             max_len = max(max_len, len(cur_sentence))  # this is not really used anywhere
#             word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32) 
#             for i, f in enumerate(np.array_split(ecogs[start_onset:end_onset,:], n_bins, axis=0)):
#                 word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
#             signals.append(word_signal)
          
#             # Data augmentation by shifts
#             for i, s in enumerate(aug_shift_fs):
#                 aug_start = start_onset + s
#                 if aug_start < 0 or aug_start > ecogs.shape[0]:
#                     continue
#                 aug_end = end_onset + s
#                 if aug_end < 0 or aug_end > ecogs.shape[0]:
#                     continue
#                 n_bins = int(math.ceil((aug_end - aug_start) / bin_fs))
#                 if n_bins > 0:
#                     word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32)
#                     for i, f in enumerate(np.array_split(ecogs[aug_start:aug_end,:], n_bins, axis=0)):
#                         word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
#                     signals.append(word_signal)
#                     labels.append(cur_sentence)
#             cur_sentence = [vocab[begin_token]]
#         speaker = x[1]
    
#     # what is this block of code doing here
#     if speaker and len(cur_sentence) >= 2 and n_bins > 0:
#         cur_sentence.append(vocab[end_token])
#         labels.append(cur_sentence)
#         n_bins = int(math.ceil((end_onset - start_onset) / bin_fs))
#         word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32)
#         for i, f in enumerate(np.array_split(ecogs[start_onset:end_onset,:], n_bins, axis=0)):
#             word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
#         signals.append(word_signal)

# if len(signals) == old_size:
#     print(f'[WARNING] no examples built for {conversation}')

# if len(signals) == 0:
#     print('[ERROR] signals is empty')
#     sys.exit(1)

# print('final')
# x = signals
# y = labels

In [None]:
# for conversation, suffix, idx in convs[0:1]:
    
#     # Check if files exists, if it doesn't go to next
#     datum_fn = glob.glob(conversation + suffix)[0]
#     if not datum_fn:
#         print('File DNE: ', conversation + suffix)
#         continue

#     # Extract electrode data
#     ecogs = return_electrode_array(conversation, suffix)
#     if not ecogs.size:
#         print(f'Skipping bad conversation: {conversation}')
#         continue

#     # Read conversations and form examples
#     old_size = len(signals)
#     max_len = 0
#     speaker = False
#     cur_sentence, start_onset, end_onset = [w2i[begin_token]], 0, 0
    
#     examples = return_examples(datum_fn, delimiter, w2i, oov_token)
#     examples1 = return_examples1(datum_fn, delimiter, w2i, oov_token)
        
#     for x in examples:
#         if x[1]:  # first example in the list
#             if len(cur_sentence) == 1:
#                 start_onset = int(float(x[2])) + start_offset
#                 if start_onset < 0 or start_onset > ecogs.shape[0]:
#                     continue
#             new_end_onset = int(float(x[2])) + end_offset
#             if new_end_onset < 0 or new_end_onset > ecogs.shape[0]\
#                or new_end_onset - start_onset < window_fs:
#                 continue
#             end_onset = new_end_onset
#             cur_sentence.append(x[0])
        
#         n_bins = int(math.ceil((end_onset - start_onset) / bin_fs))
        
#         if (((not x[1]) and speaker and len(cur_sentence) >= 2) or (x[1] and len(cur_sentence) >= 3)) \
#            and n_bins > 0:
#             cur_sentence.append(w2i[end_token])
#             labels.append(cur_sentence)
#             max_len = max(max_len, len(cur_sentence))
#             word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32)
#             for i, f in enumerate(np.array_split(ecogs[start_onset:end_onset,:], n_bins, axis=0)):
#                 word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
#             signals.append(word_signal)

#             # Data augmentation by shifts
#             for i, s in enumerate(aug_shift_fs):
#                 aug_start = start_onset + s
#                 if aug_start < 0 or aug_start > ecogs.shape[0]:
#                     continue
#                 aug_end = end_onset + s
#                 if aug_end < 0 or aug_end > ecogs.shape[0]:
#                     continue
#                 n_bins = int(math.ceil((aug_end - aug_start) / bin_fs))
#                 if n_bins > 0:
#                     word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32)
#                     for i, f in enumerate(np.array_split(ecogs[aug_start:aug_end,:], n_bins, axis=0)):
#                         word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
#                     signals.append(word_signal)
#                     labels.append(cur_sentence)
#             cur_sentence = [w2i[begin_token]]
#         speaker = x[1]
    
#     if speaker and len(cur_sentence) >= 2 and n_bins > 0:
#         cur_sentence.append(w2i[end_token])
#         labels.append(cur_sentence)
#         n_bins = int(math.ceil((end_onset - start_onset) / bin_fs))
#         word_signal = np.zeros((n_bins, len(electrodes)*len(subjects)), np.float32)
#         for i, f in enumerate(np.array_split(ecogs[start_onset:end_onset,:], n_bins, axis=0)):
#             word_signal[i,idx*len(electrodes):(idx+1)*len(electrodes)] = f.mean(axis=0)
#         signals.append(word_signal)

#     if len(signals) == old_size:
#         print(f'[WARNING] no examples built for {conversation}')

# if len(signals) == 0:
#     print('[ERROR] signals is empty')
#     sys.exit(1)

# print('final')
# x = signals
# y = labels

==========================================================================================================

==========================================================================================================

==========================================================================================================

==========================================================================================================

In [None]:
# # Default models and parameters
# DEFAULT_MODELS = {
#     "ConvNet10": (n_classes, ),
#     "PITOM": (n_classes, len(args.electrodes) * len(args.subjects)),
#     "BrainClassifier":
#     (len(args.electrodes) * len(args.subjects), n_classes, args.tf_dmodel,
#      args.tf_nhead, args.tf_nlayer, args.tf_dff, args.tf_dropout),
#     "BrainTransformer":
#     (len(args.electrodes) * len(args.subjects), n_classes, args.tf_dmodel,
#      args.tf_nhead, args.tf_nlayer, args.tf_dff, args.tf_dropout)
# }

In [None]:
# from models import *

# # Create model
# if args.init_model is None:
#     if args.model in DEFAULT_MODELS:
#         print("Building default model: %s" % args.model, end="")
#         model_class = globals()[args.model]
#         model = model_class(*(DEFAULT_MODELS[args.model]))
#     else:
#         print("Building custom model: %s" % args.model, end="")
#         sys.exit(1)
# else:
#     model_name = "%s%s.pt" % (SAVE_DIR, args.model)
#     if os.path.isfile(model_name):
#         model = torch.load(model_name)
#         model = model.module if hasattr(model, 'module') else model
#         print("Loaded initial model: %s " % args.model)
#     else:
#         print("No models found in: ", SAVE_DIR)
#         sys.exit(1)
# print(" with %d trainable parameters" %
#       sum([p.numel() for p in model.parameters() if p.requires_grad]))

# sys.stdout.flush()

In [None]:
# from torchsummary import summary

# print(model)
# for param in model.parameters():
#     print(param.shape)
# summary(model, x_train.shape[1:])

In [None]:
# from transformers import AdamW

# # Initialize loss and optimizer
# # weights = torch.ones(n_classes)
# # max_freq = -1.
# # for i in range(n_classes):
# #     max_freq = max(max_freq, word2freq[vocab[i]])
# #     weights[i] = 1./float(word2freq[vocab[i]])
# # weights = weights*max_freq
# # print(sorted([(vocab[i], round(float(weights[i]),1))
# #               for i in range(n_classes)], key=lambda x: x[1]))
# criterion = nn.CrossEntropyLoss()
# # step_size = int(math.ceil(len(train_ds)/args.batch_size))
# # optimizer = optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98),
# #                         eps=1e-9, weight_decay=args.weight_decay)
# # optimizer = optim.AdamW(model.parameters(), lr=args.lr,
# #                         weight_decay=args.weight_decay)
# optimizer = AdamW(model.parameters(),
#                   lr=args.lr,
#                   weight_decay=args.weight_decay)
# # optimizer = NoamOpt(args.tf_dmodel, 1., 2000,
# #                     optim.Adam(model.parameters(),
# #                                lr=0.,
# #                                betas=(0.9, 0.98),
# #                                eps=1e-9))
# # scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
# #                                            milestones=[20*step_size,40*step_size],
# #                                            gamma=0.2)
# # scheduler = get_cosine_schedule_with_warmup(optimizer, 10*step_size,
# #                                             args.epochs*step_size,
# # num_cycles=0.5)
# scheduler = None

In [None]:
# # Move model and loss to GPUs
# if args.gpus:
#     if args.gpus > 1:
#         model = nn.DataParallel(model)
#     model.to(DEVICE)
#     criterion.to(DEVICE)

In [None]:
# print("Training on %d GPU(s) with batch_size %d for %d epochs" %
#       (args.gpus, args.batch_size, args.epochs))
# print("=" * CONFIG["print_pad"])
# sys.stdout.flush()

# best_val_loss = float("inf")
# best_model = model
# history = {
#     'train_loss': [],
#     'train_acc': [],
#     'valid_loss': [],
#     'valid_acc': []
# }

In [None]:
# from train_eval import SimpleLossCompute

# train_loss_compute = SimpleLossCompute(criterion,
#                                        opt=optimizer,
#                                        scheduler=scheduler)
# valid_loss_compute = SimpleLossCompute(criterion, opt=None, scheduler=None)

# epoch = 0
# model_name = "%s%s.pt" % (SAVE_DIR, args.model)
# # Run training and validation for args.epochs epochs
# lr = args.lr

In [None]:
# import time
# import warnings
# from train_eval import train, valid

# for epoch in range(1, args.epochs + 1):
#     epoch_start_time = time.time()
#     print('| train | epoch %d | ' % epoch, end='')
#     train_loss, train_acc = train(train_dl, model, DEVICE,
#                                   train_loss_compute, epoch, i2w)
#     history['train_loss'].append(train_loss)
#     history['train_acc'].append(train_acc)
#     print('| valid | epoch %d | ' % epoch, end='')
#     valid_loss, valid_acc = valid(valid_dl, model, DEVICE,
#                                   valid_loss_compute, epoch)
#     history['valid_loss'].append(valid_loss)
#     history['valid_acc'].append(valid_acc)
#     print('|' + '-' * (CONFIG["print_pad"] - 2) + '|')
#     # Store best model so far
#     if valid_loss < best_val_loss:
#         best_model, best_val_loss = model, valid_loss
#         with warnings.catch_warnings():
#             warnings.simplefilter("ignore")
#             model_to_save = best_model.module\
#                 if hasattr(best_model, 'module') else best_model
#             torch.save(model_to_save, model_name)
#     sys.stdout.flush()

In [None]:
# from train_eval import plot_training

# # Plot loss,accuracy vs. time and save figures
# plot_training(history, SAVE_DIR, title="%s_lr%s" % (args.model, args.lr))

In [None]:
# print("Evaluating predictions on test set")
# # Load best model
# model = torch.load(model_name)
# if args.gpus:
#     if args.gpus > 1:
#         model = nn.DataParallel(model)
#     model.to(DEVICE)

In [None]:
# start, end = 0, 0
# softmax = nn.Softmax(dim=1)
# all_preds = np.zeros((x_valid.size(0), n_classes), dtype=np.float32)
# print('Allocating', np.prod(all_preds.shape) * 5 / 1e9, 'GB')

In [None]:
# all_preds.shape

In [None]:
# # Calculate all predictions on test set
# with torch.no_grad():
#     for batch in valid_dl:
#         src, trg = batch[0].to(DEVICE), batch[1].to(DEVICE,
#                                                     dtype=torch.long)
#         end = start + src.size(0)
#         out = softmax(model(src))
#         all_preds[start:end, :] = out.cpu()
#         start = end

# print("Calculated predictions")
# sys.stdout.flush()

In [None]:
# from collections import Counter

# # Make categorical
# n_examples = y_valid.shape[0]
# categorical = np.zeros((n_examples, n_classes), dtype=np.float32)
# categorical[np.arange(n_examples), y_valid] = 1

# train_freq = Counter(y_train.tolist())

In [None]:
# x_valid.shape

In [None]:
# train_freq

In [None]:
# predictions = all_preds
# labels = y_valid.numpy()
# train_freqs = train_freq
# min_train=args.vocab_min_freq
# suffix='-val'
# prefix = ''
# save_dir = SAVE_DIR

In [None]:
# ranks = []
# n_examples, n_classes = predictions.shape
# fid = open(save_dir + 'guesses%s.csv' % suffix, 'w')
# top1_uw, top5_uw, top10_uw = set(), set(), set()
# accs, sizes = {}, [0.] * n_classes
# total_freqs = float(sum(train_freqs.values()))

In [None]:
# # Go through each example and calculate its rank and top-k
# for i in range(n_examples):
#     y_true_idx = labels[i]
#     word = i2w[y_true_idx]

#     if train_freqs[y_true_idx] < min_train:
#         continue

#     # Get example predictions
#     ex_preds = np.argsort(predictions[i])[::-1]
#     rank = np.where(y_true_idx == ex_preds)[0][0]
#     ranks.append(rank)

#     if i == 0:
#         raise Exception('Done')
#     fid.write('%s,%d,' % (word, rank))
#     fid.write(','.join(i2w[j] for j in ex_preds[:10]))
#     fid.write('\n')

#     if rank == 0:
#         top1_uw.add(ex_preds[0])
#     elif rank < 5:
#         top5_uw.update(ex_preds[:5])
#     elif rank < 10:
#         top10_uw.update(ex_preds[:10])

#     # Counts of how many times a word has been ranked the top
#     accs[i2w[y_true_idx]] = accs.get(i2w[y_true_idx], 0.) + (rank == 0)
    
#     # size of validation set excluding the samples with < min_freq
#     sizes[y_true_idx] += 1.

# print(accs, sizes, n_examples, sum(sizes))

In [None]:
# for i in range(n_classes):
#     chance_acc = train_freqs[i] / total_freqs * 100.
#     if sizes[i] > 0:
#         rounded_acc = round(accs[i2w[i]] / sizes[i] * 100, 3)
#         accs[i2w[i]] = (rounded_acc, chance_acc, rounded_acc - chance_acc)
#     else:
#         accs[i2w[i]] = (0., chance_acc, -chance_acc)
        
# accs = sorted(accs.items(), key=lambda x: -x[1][2])
# fid.close()
# print('Top1 #Unique:', len(top1_uw))
# print('Top5 #Unique:', len(top5_uw))
# print('Top10 #Unique:', len(top10_uw))

In [None]:
# n_examples = len(ranks)
# ranks = np.array(ranks)
# top1 = sum(ranks == 0) / (1e-12 + len(ranks)) * 100
# top5 = sum(ranks < 5) / (1e-12 + len(ranks)) * 100
# top10 = sum(ranks < 10) / (1e-12 + len(ranks)) * 100

# print(top1, top5, top10)
# # Calculate chance levels based on training word frequencies
# freqs = Counter(labels)
# print(freqs, len(labels))
# freqs = np.array([freqs[i] for i, _ in train_freqs.most_common()])
# print(freqs)
# freqs = freqs[freqs != 0]
# print(freqs)
# chances = (freqs / freqs.sum()).cumsum() * 100

# # Print and write to file
# if suffix is not None:
#     with open(save_dir + 'topk%s.txt' % suffix, 'w') as fout:
#         line = 'n_classes: %d\nn_examples: %d' % (n_classes, n_examples)
#         print(line)
#         fout.write(line + '\n')
#         line = 'Top-1\t%.4f %% (%.2f %%)' % (top1, chances[0])
#         print(line)
#         fout.write(line + '\n')
#         line = 'Top-5\t%.4f %% (%.2f %%)' % (top5, chances[4])
#         print(line)
#         fout.write(line + '\n')
#         line = 'Top-10\t%.4f %% (%.2f %%)' % (top10, chances[9])
#         print(line)
#         fout.write(line + '\n')
#         line = 'Token Accuracies\t%s' % str(accs)
#         print(line)
#         fout.write(line + '\n')

# final_dict = {
#     prefix + 'top1': top1,
#     prefix + 'top5': top5,
#     prefix + 'top10': top10,
#     prefix + 'top1_chance': chances[0],
#     prefix + 'top5_chance': chances[4],
#     prefix + 'top10_chance': chances[9],
#     prefix + 'top1_above': (top1 - chances[0]) / chances[0],
#     prefix + 'top5_above': (top5 - chances[4]) / chances[4],
#     prefix + 'top10_above': (top10 - chances[9]) / chances[9],
#     prefix + 'top1_n_uniq_correct': len(top1_uw),
#     prefix + 'top5_n_uniq_correct': len(top5_uw),
#     prefix + 'top10_n_uniq_correct': len(top10_uw),
#     prefix + 'word_accuracies': accs
# }

In [None]:
# # Evaluate ROC-AUC
# print("Evaluating ROC-AUC")
# sys.stdout.flush()
# res.update(
#     evaluate_roc(all_preds,
#                  categorical,
#                  i2w,
#                  train_freq,
#                  SAVE_DIR,
#                  do_plot=not args.no_plot,
#                  min_train=args.vocab_min_freq))
# pprint(res.items())
# print("Saving results")
# with open(SAVE_DIR + "results.json", "w") as fp:
#     json.dump(res, fp, indent=4)
# print("Done!")

In [None]:
# # Evaluate ROC performance of the model
# # (predictions, labels of shape (n_examples, n_classes))
# predictions = all_preds
# labels = categorical
# train_freqs = train_freq
# save_dir = SAVE_DIR
# do_plot = not args.no_plot
# given_thresholds=None
# title=''
# suffix=''
# min_train=args.vocab_min_freq

In [None]:
# assert (predictions.shape == labels.shape)
# lines, scores, word_freqs = [], [], []
# n_examples, n_classes = predictions.shape
# thresholds = np.full(n_classes, np.nan)
# rocs, fprs, tprs = {}, [], []

In [None]:
# # Create directory for plots if required
# if do_plot:
#     roc_dir = save_dir + 'rocs/'
#     if not os.path.isdir(roc_dir):
#         os.mkdir(roc_dir)

In [None]:
# from sklearn.metrics import auc, confusion_matrix, roc_curve
# from train_eval import best_threshold
# import matplotlib.pyplot as plt

# # Go over each class and compute AUC
# for i in range(n_classes):
#     train_count = train_freqs[i]
#     n_true = np.count_nonzero(labels[:, i])
#     if train_count < min_train or n_true == 0:
#         continue
#     word = i2w[i]
#     probs = predictions[:, i]
#     c_labels = labels[:, i]
#     fpr, tpr, thresh = roc_curve(c_labels, probs)
#     if given_thresholds is None:
#         x, y, threshold = best_threshold(fpr, tpr, thresh)
#     else:
#         x, y, threshold = 0, 0, given_thresholds[i]
#     thresholds[i] = threshold
#     score = auc(fpr, tpr)
#     scores.append(score)
#     word_freqs.append(train_count)
#     rocs[word] = score
#     fprs.append(fpr)
#     tprs.append(tpr)
#     y_pred = probs >= threshold
#     tn, fp, fn, tp = confusion_matrix(c_labels, y_pred).ravel()
#     lines.append('%s\t%3d\t%3d\t%.3f\t%d\t%d\t%d\t%d\n' %
#                  (word, n_true, train_count, score, tp, fp, fn, tn))
#     if do_plot:
#         fig, axes = plt.subplots(1, 2, figsize=(16, 6))
#         axes[0].plot(fpr, tpr, color='darkorange', lw=2, marker='.')
#         axes[0].plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
#         axes[0].plot(x, y, marker='o', color='blue')
#         axes[0].set_xlim([0.0, 1.0])
#         axes[0].set_ylim([0.0, 1.05])
#         axes[0].set_xlabel('False Positive Rate')
#         axes[0].set_ylabel('True Positive Rate')
#         h1 = probs[c_labels == 1].reshape(-1)
#         h2 = probs[c_labels == 0].reshape(-1)
#         axes[1].hist(h2,
#                      bins=20,
#                      color='orange',
#                      alpha=0.5,
#                      label='Neg. Examples')
#         # axes[1].twinx().hist(h1, bins=50, alpha=0.5,
#         # label='Pos. Examples')
#         axes[1].hist(h1, bins=50, alpha=0.5, label='Pos. Examples')
#         axes[1].axvline(threshold, color='k')
#         axes[1].set_xlabel('Activation')
#         axes[1].set_ylabel('Frequency')
#         axes[1].legend()
#         axes[1].set_title('%d TP | %d FP | %d FN | %d TN' %
#                           (tp, fp, fn, tn))
#         fig.suptitle('ROC Curve | %s | AUC = %.3f | N = %d' %
#                      (word, score, n_true))
#         plt.savefig(roc_dir + '%s.png' % word)
#         fig.clear()
#         plt.close(fig)

In [None]:
# # Compute statistics
# scores, word_freqs = np.array(scores), np.array(word_freqs)
# normed_freqs = word_freqs / word_freqs.sum()
# avg_auc = scores.mean()
# weighted_avg = (scores * normed_freqs).sum()
# print('Avg AUC: %d\t%.6f' % (scores.size, avg_auc))
# print('Weighted Avg AUC: %d\t%.6f' % (scores.size, weighted_avg))

# # Write to file
# with open(save_dir + 'aucs%s.txt' % suffix, 'w') as fout:
#     for line in lines:
#         fout.write(line)

# # Plot histogram and AUC as a function of num of examples
# _, ax = plt.subplots(1, 1)
# ax.scatter(word_freqs, scores, marker='.')
# ax.set_xlabel('# examples')
# ax.set_ylabel('AUC')
# ax.set_title('%s | avg: %.3f | N = %d' %
#              (title, weighted_avg, scores.size))
# ax.set_yticks(np.arange(0., 1.1, 0.1))
# ax.grid()
# plt.savefig(save_dir + 'roc-auc-examples.png', bbox_inches='tight')

# _, ax = plt.subplots(1, 1)
# ax.hist(scores, bins=20)
# ax.set_xlabel('AUC')
# ax.set_ylabel('# labels')
# ax.set_title('%s | avg: %.3f | N = %d' %
#              (title, weighted_avg, scores.size))
# ax.set_xticks(np.arange(0., 1., 0.1))
# plt.savefig(save_dir + 'roc-auc.png', bbox_inches='tight')

# _, ax = plt.subplots(1, 1)
# for fpr, tpr in zip(fprs, tprs):
#     ax.plot(fpr, tpr, lw=1)
# ax.plot([0, 1], [0, 1], color='navy', linestyle='--')
# ax.set_xlim([0.0, 1.0])
# ax.set_ylim([0.0, 1.05])
# ax.set_xlabel('False Positive Rate')
# ax.set_ylabel('True Positive Rate')
# ax.set_title('%s | avg: %.3f | N = %d' %
#              (title, weighted_avg, scores.size))
# plt.savefig(save_dir + 'roc-auc-all.png', bbox_inches='tight')

# final_dict.update({
#     'rocauc_avg': avg_auc,
#     'rocauc_stddev': scores.std(),
#     'rocauc_w_avg': weighted_avg,
#     'rocauc_n': scores.size,
#     'rocs': rocs
# })