In [25]:
from scipy.stats import loguniform, uniform
import numpy as np
import argparse
import os
import sys
import time
import json
import pandas as pd

from IPython import embed

def convert(o):
    '''
    ``default`` hook for ``json.dumps`` that maps NumPy values to builtin types.

    json only calls this hook for objects it cannot serialise itself, so plain
    Python ints/floats/strings never reach it.
    :param o: the object json could not serialise.
    :return: a bool, int, float or (nested) list equivalent of ``o``.
    :raises TypeError: if ``o`` is not a NumPy scalar or array.
    '''
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.integer):  # np.int64, np.int32, np.uint8, ...
        return int(o)
    if isinstance(o, np.floating):  # np.float32 etc. (np.float64 is already a float)
        return float(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    raise TypeError('Object of type {} is not JSON serializable'.format(type(o).__name__))

def select_hyperparams(config, output_name, model, is_arc, score_key='f_macro'):
    '''
    Runs a hyperparameter search: for every trial it samples values, writes a
    config file, launches the training script and collects the resulting score.
    The per-trial results are finally written to a CSV file.
    :param config: the parsed search configuration (see parse_config).
    :param output_name: name of the CSV file to write. A relative name is placed
                        inside the result directory of this search; an absolute
                        path is used as is.
    :param model: which training script to launch, 'adv' or 'bicond'.
    :param is_arc: accepted for backward compatibility; not used by this function.
    :param score_key: the score the training script optimises; also part of the
                      name of the result file that is read back.
    '''
    # fail fast, before any directory or file is created
    if model not in ('adv', 'bicond'):
        print("ERROR: model {} is not supported".format(model))
        sys.exit(1)

    ### make directories
    config_path, checkpoint_path, result_path = make_dirs(config)

    # everything that is neither search setup nor sampling metadata is passed
    # through unchanged to the per-trial config file
    setup_params = ['tune_params', 'num_search_trials', 'dir_name']
    model_params = set()
    for p in config:
        if p in setup_params or ('range' in p or 'algo' in p or 'type' in p or p.startswith('CON')): continue
        model_params.add(p)
    print("[model params] {}".format(model_params))

    score_lst = []
    time_lst = []
    best_epoch_lst = []
    tn2vals = dict()
    for trial_num in range(int(config['num_search_trials'])):
        ### sample values
        print("[trial {}] Starting...".format(trial_num))
        print("[trial {}] sampling parameters in {}".format(trial_num, config['tune_params']))

        # re-sample until all CON* constraints hold
        # (sample_values reads the module-level config, not the argument)
        constraints_OK = False
        while not constraints_OK:
            p2v = sample_values(trial_num)
            constraints_OK = check_constraints(config, p2v)
        tn2vals[trial_num] = p2v

        ### construct the appropriate config file
        config_file_name = config_path + 'config-{}.txt'.format(trial_num)
        print("[trial {}] writing configuration to {}".format(trial_num, config_file_name))
        print("[trial {}] checkpoints to {}".format(trial_num, checkpoint_path))
        print("[trial {}] results to {}".format(trial_num, result_path))
        model_name = '{}_t{}'.format(config['name'], trial_num)
        # the file must be closed (not just flushed) before the script reads it
        with open(config_file_name, 'w') as f:
            f.write('name:{}\n'.format(model_name)) # include trial number in name
            f.write('ckp_path:{}\n'.format(checkpoint_path)) # checkpoint save location
            f.write('res_path:{}\n'.format(result_path)) # results save location
            for p in model_params:
                if p == 'name': continue
                f.write('{}:{}\n'.format(p, config[p]))
            for p in p2v:
                f.write('{}:{}\n'.format(p, p2v[p]))

        ### run the script
        print("[trial {}] running cross validation".format(trial_num))
        if model == 'adv':
            cmd = "./adv_train.sh 1 {} 0 {} > {}log_t{}.txt".format(config_file_name, score_key, result_path, trial_num)
        else:
            cmd = "./bicond.sh {} {} > {}log_t{}.txt".format(config_file_name, score_key, result_path, trial_num)
        start_time = time.time()
        status = os.system(cmd)
        script_time = (time.time() - start_time) / 60.
        print("[trial {}] running on ARC took {:.4f} minutes".format(trial_num, script_time))
        if status != 0:
            print("[trial {}] WARNING: training script exited with status {}".format(trial_num, status))

        ### process the result and update information on best
        if model == 'adv':
            res_name = '{}{}_t{}-{}.top5_{}.txt'.format(result_path, config['name'], trial_num, config['enc'], score_key)
        else:
            res_name = '{}{}_t{}.top5_{}.txt'.format(result_path, config['name'], trial_num, score_key)
        with open(res_name, 'r') as res_f:
            res_lines = res_f.readlines()
        # the last three lines are read as "<label>:<epoch>", "<label>:<score>" and
        # one more line; strip the value so no leading blank ends up in the CSV
        score_lst.append(res_lines[-2].strip().split(':')[1].strip())
        time_lst.append(script_time)
        best_epoch_lst.append(res_lines[-3].strip().split(':')[1].strip())

        print("[trial {}] Done.".format(trial_num))
        print()

    ### save the resulting scores and times, for calculating the expected validation f1
    data = [[ti, score_lst[ti], time_lst[ti], best_epoch_lst[ti],
             json.dumps(tn2vals[ti], default=convert)]
            for ti in tn2vals]
    df = pd.DataFrame(data, columns=['trial_num', 'avg_score', 'time', 'best_epoch', 'param_vals'])
    # result_path is 'data/model_results/<dir_name>-<num_search_trials>trials/'
    df.to_csv(os.path.join(result_path, output_name), index=False)
    print("results to {}".format(output_name))


def parse_config(fname):
    '''
    Reads a "key:value" per line search configuration.

    'tune_params' is split on ',' into a list. For every tuned parameter p the
    entry '<p>_range' is split into a list and converted according to
    '<p>_type' ('int', 'float', anything else stays a string).
    :param fname: path of the configuration file.
    :return: a dict from setting name to value.
    :raises ValueError: if a non-empty line contains no ':'.
    '''
    import re  # only needed here; kept local so this function is self-contained

    # Separator for numeric ranges: a '-' that is neither a sign at the start,
    # nor part of an exponent ("1e-5"), nor the sign of the second bound ("1--5").
    numeric_sep = re.compile(r'(?<!^)(?<![eE-])-')

    n2info = dict()
    with open(fname, 'r') as f:
        for l in f:
            l = l.strip()
            if not l:
                continue  # tolerate blank lines
            if ':' not in l:
                raise ValueError('config line is not in "key:value" format: {!r}'.format(l))
            # split once only, so values may themselves contain ':'
            n, info = l.split(':', 1)
            n2info[n] = info

    n2info['tune_params'] = n2info['tune_params'].split(',')
    for p in n2info['tune_params']:
        t = n2info['{}_type'.format(p)]
        raw = n2info['{}_range'.format(p)]
        if t == 'int':
            vals = [int(x) for x in numeric_sep.split(raw)]
        elif t == 'float':
            vals = [float(x) for x in numeric_sep.split(raw)]
        else:
            vals = raw.split('-')
        n2info['{}_range'.format(p)] = vals
    return n2info


def sample_values(trial_num, cfg=None):
    '''
    Samples one value for every parameter listed in 'tune_params'.

    The sampling method of parameter p is given by '<p>_algo':
      - 'selection': the trial_num-th entry of '<p>_range'
      - 'choice': a random entry of '<p>_range'
      - 'loguniform': log-uniform between the two entries of '<p>_range'
      - 'uniform-integer': uniform integer between the two entries, both included
      - 'uniform-float': uniform float between the two entries
    :param trial_num: index of the current trial (only used by 'selection').
    :param cfg: the parsed configuration. Default (None) uses the module-level
                ``config`` created in the __main__ block.
    :return: a dict from parameter name to sampled value.
    '''
    if cfg is None:
        cfg = config

    p2v = dict()
    for p in cfg['tune_params']:
        a = cfg['{}_algo'.format(p)]
        p_range = cfg['{}_range'.format(p)]
        if a == 'selection':        #To select in order from a list of hyperparam values
            p2v[p] = p_range[trial_num]
        elif a == 'choice':         #To randomly select any value from a list of hyperparam values
            p2v[p] = np.random.choice(p_range)
        elif a in ('loguniform', 'uniform-integer', 'uniform-float'):
            #To randomly select a value from a given range
            min_v, max_v = p_range
            if a == 'loguniform':
                p2v[p] = loguniform.rvs(min_v, max_v)
            elif a == 'uniform-integer':
                p2v[p] = np.random.randint(min_v, max_v + 1)
            else:
                # scipy's uniform is parameterised as [loc, loc + scale], so the
                # width of the interval (not its upper bound) must be passed
                p2v[p] = uniform.rvs(loc=min_v, scale=max_v - min_v)
        else:
            # a parameter without a value would silently vanish from the trial
            print("ERROR: sampling method specified as {}".format(a))
            sys.exit(1)

    return p2v


def check_constraints(n2info, p2v):
    '''
    Checks the sampled values against every constraint in the configuration.

    Constraints are the entries whose name starts with 'CON'. Each one is a
    '#'-separated equation in one of these forms:
      - param1#symbol#param2
      - param1#symbol#factor#param2   (compares param1 with factor * param2)
      - factor#param1#symbol#param2   (compares factor * param1 with param2)
    Every constraint is evaluated, even after one has already failed.
    :param n2info: the parsed configuration.
    :param p2v: dict from parameter name to sampled value.
    :return: a truthy value if all constraints hold, a falsy one otherwise.
    '''
    all_hold = True
    for key, spec in n2info.items():
        if not key.startswith('CON'):
            continue
        parts = spec.split('#')
        n_parts = len(parts)
        if n_parts == 3:
            lhs = p2v[parts[0]]
            symbol = parts[1]
            rhs = p2v[parts[2]]
        elif n_parts == 4:
            if parts[0] in p2v:
                # the factor belongs to the right-hand side
                lhs = p2v[parts[0]]
                symbol = parts[1]
                rhs = float(parts[2]) * p2v[parts[3]]
            else:
                # the factor belongs to the left-hand side
                lhs = float(parts[0]) * p2v[parts[1]]
                symbol = parts[2]
                rhs = p2v[parts[3]]
        else:
            print("ERROR: equation not parsable {}".format(parts))
            sys.exit(1)
        holds = parse_equation(lhs, symbol, rhs)
        all_hold = holds and all_hold
    return all_hold


def parse_equation(v1, s, v2):
    '''
    Evaluates the comparison "v1 <s> v2".
    :param v1: the left-hand value.
    :param s: the comparison symbol, one of '<', '<=', '=', '!=', '>', '>='.
    :param v2: the right-hand value.
    :return: the result of the comparison. Exits the program with status 1
             if the symbol is not recognized.
    '''
    # thunks, so that only the requested comparison is ever evaluated
    comparisons = {
        '<': lambda: v1 < v2,
        '<=': lambda: v1 <= v2,
        '=': lambda: v1 == v2,
        '!=': lambda: v1 != v2,
        '>': lambda: v1 > v2,
        '>=': lambda: v1 >= v2,
    }
    compare = comparisons.get(s)
    if compare is None:
        print("ERROR: symbol {} not recognized".format(s))
        sys.exit(1)
    return compare()


def make_dirs(config):
    '''
    Creates the config, checkpoint and result directories of a search. Their
    names are derived from 'dir_name' and 'num_search_trials'.

    The directories are handled one after the other; the program exits with
    status 1 at the first one that already exists.
    :param config: the parsed configuration.
    :return: the tuple (config_path, checkpoint_path, result_path).
    '''
    trial_key = (config['dir_name'], config['num_search_trials'])
    config_path = 'data/config/{}-{}trials/'.format(*trial_key)
    checkpoint_path = 'data/checkpoints/{}-{}trials/'.format(*trial_key)
    result_path = 'data/model_results/{}-{}trials/'.format(*trial_key)

    labelled_paths = (('config_path', config_path),
                      ('ckp_path', checkpoint_path),
                      ('result_path', result_path))
    for label, path in labelled_paths:
        if os.path.exists(path):
            print("[{}] Directory {} already exists!".format(label, path))
            sys.exit(1)
        os.makedirs(path)
    return config_path, checkpoint_path, result_path


def remove_dirs(config):
    '''
    Deletes the config, checkpoint and result directories of a search,
    including everything inside them. Directories that do not exist are skipped.
    :param config: the parsed configuration; 'dir_name' and 'num_search_trials'
                   determine the directory names.
    '''
    import shutil  # only needed here; kept local so this function is self-contained

    config_path = 'data/config/{}-{}trials/'.format(config['dir_name'],
                                                    config['num_search_trials'])
    checkpoint_path = 'data/checkpoints/{}-{}trials/'.format(config['dir_name'],
                                                             config['num_search_trials'])
    result_path = 'data/model_results/{}-{}trials/'.format(config['dir_name'],
                                             config['num_search_trials'])
    for p_name, p_path in [('config_path', config_path), ('ckp_path', checkpoint_path),
                           ('result_path', result_path)]:
        if not os.path.exists(p_path):
            print("[{}] directory {} doesn't exist".format(p_name, p_path))
            continue

        print("[{}] removing all files from {}".format(p_name, p_path))
        for fname in os.listdir(p_path):
            entry = os.path.join(p_path, fname)
            if os.path.isdir(entry) and not os.path.islink(entry):
                # os.remove cannot delete a directory; remove the whole sub-tree
                shutil.rmtree(entry)
            else:
                os.remove(entry)
        print("[{}] removing empty directory".format(p_name))
        os.rmdir(p_path)

if __name__ == '__main__':
    # Command line entry point: mode "1" runs a hyperparameter search (after
    # wiping the directories of a previous search with the same settings),
    # mode "2" only removes those directories.
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--mode', required=True,
                        help='What to do: "1" runs the hyperparameter search, '
                             '"2" removes the directories of the search')
    parser.add_argument('-s', '--settings', help='Name of the file containing hyperparam info', required=True)
    # select_hyperparams launches a training script for 'adv' and 'bicond' only.
    parser.add_argument('-n', '--model', help='Name of the model to run', required=False, default='adv')
    parser.add_argument('-o', '--output', help='Name of the output file (full path)', required=False,
                        default='trial_results.csv')
    parser.add_argument('-k', '--score_key', help='Score key for optimization', required=False, default='f_macro')
    args = vars(parser.parse_args())

    # reject an unknown mode before touching any file, and signal the failure
    # to the calling shell through a non-zero exit status
    if args['mode'] not in ('1', '2'):
        print("ERROR: unknown mode {}. exiting".format(args['mode']))
        sys.exit(1)

    # sample_values reads this module-level name
    config = parse_config(args['settings'])

    if args['mode'] == '1':
        ## run hyperparam search
        remove_dirs(config)
        select_hyperparams(config, args['output'], args['model'], is_arc=('arc' in args['settings'] or 'twitter' in args['settings']), score_key=args['score_key'])
    else:
        ## remove directories
        remove_dirs(config)

usage: ipykernel_launcher.py [-h] -m MODE -s SETTINGS [-n MODEL] [-o OUTPUT]
                             [-k SCORE_KEY]
ipykernel_launcher.py: error: the following arguments are required: -m/--mode, -s/--settings


SystemExit: 2

In [None]:
python train_and_eval_model.py [-h] -m "train" -s SETTINGS [-n MODEL] [-o OUTPUT]
                             [-k SCORE_KEY]

In [26]:
python train_and_eval_model.py --mode "train" --config_file config_example_toad.txt --trn_data data/twitter_data_naacl/twitter_testA_seenval/train.csv --dev_data data/twitter_data_naacl/twitter_testA_seenval/validation.csv --score_key f_macro --topics_vocab data/resources/twitter-topic-TRN-semi-sup.vocab.pkl --mode train

SyntaxError: invalid syntax (289735962.py, line 1)

In [10]:
import torch, pickle, time, json, copy
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np

class TorchModelHandler:
    '''
    Class that holds a model and provides the functionality to train it,
    save it, load it, and evaluate it. The model used here is assumed to be
    written in pytorch.
    '''
    def __init__(self, num_ckps=10, use_score='f_macro', use_cuda=False,
                 checkpoint_path='data/checkpoints/',
                 result_path='data/', **params):
        '''
        :param num_ckps: number of rotating checkpoint slots used by save().
        :param use_score: name of the score used to rank epochs in save_best().
        :param use_cuda: if True, the model and the loss function are moved to 'cuda'.
        :param checkpoint_path: prefix prepended to checkpoint file names.
        :param result_path: prefix prepended to the name of the top-5 score file.
        :param params: the remaining settings. Required keys: model, embed_model,
                        dataloader, batching_fn, batching_kwargs, setup_fn, name,
                        loss_function, optimizer, blackout_start, blackout_stop.
                        Optional keys: fine_tune and tune_embeds (both default to
                        False). Other keys are ignored.
        '''
        super().__init__()
        # data fields
        self.model = params['model']
        self.embed_model = params['embed_model']
        self.dataloader = params['dataloader']
        self.batching_fn = params['batching_fn']
        self.batching_kwargs = params['batching_kwargs']
        self.setup_fn = params['setup_fn']

        self.num_labels = self.model.num_labels
        self.name = params['name']

        # optimization fields
        self.loss_function = params['loss_function']
        self.optimizer = params['optimizer']
        self.fine_tune = params.get('fine_tune', False)
        # read by train_step(); must exist even when the caller does not pass it
        self.tune_embeds = params.get('tune_embeds', False)
        # accumulated loss of the last train/predict pass; stored by save()
        self.loss = 0.

        # stats fields
        self.checkpoint_path = checkpoint_path
        self.checkpoint_num = 0
        self.num_ckps = num_ckps
        self.epoch = 0

        self.result_path = result_path

        # evaluation fields
        self.score_dict = dict()
        self.max_score = 0.
        self.max_lst = []  # to keep top 5 scores
        self.score_key = use_score
        self.blackout_start = params['blackout_start']
        self.blackout_stop = params['blackout_stop']

        # GPU support
        self.use_cuda = use_cuda
        if self.use_cuda:
            # move model and loss function to GPU, NOT the embedder
            self.model = self.model.to('cuda')
            self.loss_function = self.loss_function.to('cuda')

    def save_best(self, data=None, scores=None, data_name=None, class_wise=False):
        '''
        Evaluates the model on data and then updates the best scores and saves the best model.
        Nothing is recorded while the current epoch lies in
        [blackout_start, blackout_stop).
        :param data: data to evaluate and update based on. Default (None) will evaluate on the internally
                        saved data. Otherwise, should be a DataSampler. Only used if scores is None.
        :param scores: a dictionary of precomputed scores. Default (None) will compute a list of scores
                        using the given data, name and class_wise flag.
        :param data_name: the name of the data evaluating and updating on. Only used if scores is None.
        :param class_wise: flag to determine whether to compute class-wise scores in
                            addition to macro-averaged scores. Only used if scores is None.
        '''
        if scores is None:
            # evaluate and print
            scores = self.eval_and_print(data=data, data_name=data_name,
                                         class_wise=class_wise)
        scores = copy.deepcopy(scores)  # copy the scores, otherwise storing a pointer which won't track properly

        if self.epoch in range(self.blackout_start, self.blackout_stop):
            return

        # update list of top scores (kept sorted, lowest first)
        curr_score = scores[self.score_key]
        is_first = not self.max_lst  # nothing recorded yet, so no BEST checkpoint exists
        score_updated = False
        if len(self.max_lst) < 5:
            score_updated = True
            if len(self.max_lst) > 0:
                prev_max = self.max_lst[-1][0][self.score_key] # last thing in the list
            else:
                prev_max = curr_score
            self.max_lst.append((scores, self.epoch - 1))
        elif curr_score > self.max_lst[0][0][self.score_key]: # if bigger than the smallest score
            score_updated = True
            prev_max = self.max_lst[-1][0][self.score_key] # last thing in the list
            self.max_lst[0] = (scores, self.epoch - 1) #  replace smallest score

        # update best saved model and file with top scores
        if score_updated:
            # sort the scores
            self.max_lst = sorted(self.max_lst, key=lambda p: p[0][self.score_key])  # lowest first
            # write top 5 scores; 'w' overrides the previous file
            top5_name = '{}{}.top5_{}.txt'.format(self.result_path, self.name, self.score_key)
            with open(top5_name, 'w') as f:
                for p in self.max_lst:
                    f.write('Epoch: {}\nScore: {}\nAll Scores: {}\n'.format(p[1], p[0][self.score_key],
                                                                                json.dumps(p[0])))
            # save best model step, if its this one
            print(curr_score, prev_max)
            # the first recorded score is always the best so far, also when the
            # blackout window made the first recording happen after epoch 1
            if curr_score > prev_max or self.epoch == 1 or is_first:
                self.save(num='BEST')

    def save(self, num=None):
        '''
        Saves the pytorch model in a checkpoint file.
        :param num: The number to associate with the checkpoint. By default uses
                    the internally tracked checkpoint number but this can be changed.
        '''
        if num is None:
            check_num = self.checkpoint_num
        else:
            check_num = num

        torch.save({
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': self.loss
        }, '{}ckp-{}-{}.tar'.format(self.checkpoint_path, self.name, check_num))

        if num is None:
            # rotate through the num_ckps numbered slots
            self.checkpoint_num = (self.checkpoint_num + 1) % self.num_ckps

    def load(self, filename='data/checkpoints/ckp-[NAME]-FINAL.tar', use_cpu=False):
        '''
        Loads a saved pytorch model from a checkpoint file. Only the model
        weights are restored.
        :param filename: the name of the file to load from. By default uses
                        the final checkpoint for the model of this' name.
                        '[NAME]' is replaced with the name of this model.
        :param use_cpu: if True, tensors in the checkpoint are mapped to the CPU
                        while loading.
        '''
        filename = filename.replace('[NAME]', self.name)
        checkpoint = torch.load(filename, map_location='cpu' if use_cpu else None)
        self.model.load_state_dict(checkpoint['model_state_dict'])

    def train_step(self):
        '''
        Runs one epoch of training on this model.
        '''
        print("[{}] epoch {}".format(self.name, self.epoch))
        self.model.train()
        self.loss = 0.  # clear the loss
        start_time = time.time()
        for i_batch, sample_batched in enumerate(self.dataloader):

            # zero gradients before EVERY optimizer step
            self.model.zero_grad()
            if self.tune_embeds:
                self.embed_model.zero_grad()

            y_pred, labels = self.get_pred_with_grad(sample_batched)

            label_tensor = torch.tensor(labels)
            if self.use_cuda:
                # move labels to cuda if necessary
                label_tensor = label_tensor.to('cuda')

            if self.dataloader.weighting:
                # scale the loss with a weight looked up per (topic, label)
                batch_loss = self.loss_function(y_pred, label_tensor)
                weight_lst = torch.tensor([self.dataloader.topic2c2w[b['ori_topic']][b['label']]
                                           for b in sample_batched])
                if self.use_cuda:
                    weight_lst = weight_lst.to('cuda')
                graph_loss = torch.mean(batch_loss * weight_lst)
            else:
                graph_loss = self.loss_function(y_pred, label_tensor)

            self.loss += graph_loss.item()  # update loss

            graph_loss.backward()

            self.optimizer.step()

        end_time = time.time()
        print("   took: {:.1f} min".format((end_time - start_time)/60.))
        self.epoch += 1

    def compute_scores(self, score_fn, true_labels, pred_labels, class_wise, name):
        '''
        Computes scores using the given scoring function of the given name. The scores
        are stored in the internal score dictionary.
        :param score_fn: the scoring function to use. It is called as
                         score_fn(true_labels, pred_labels, labels=..., average=None).
        :param true_labels: the true labels.
        :param pred_labels: the predicted labels.
        :param class_wise: flag to determine whether to compute class-wise scores in
                            addition to macro-averaged scores.
        :param name: the name of this score function, to be used in storing the scores.
        '''
        # only labels 0 and 1 are scored; the macro average is over these two
        labels = list(range(2))
        n = float(len(labels))

        vals = score_fn(true_labels, pred_labels, labels=labels, average=None)
        self.score_dict['{}_macro'.format(name)] = sum(vals) / n

        if class_wise:
            self.score_dict['{}_anti'.format(name)] = vals[0]
            self.score_dict['{}_pro'.format(name)] = vals[1]
            if n > 2:
                # only reachable if the scored label set above is widened
                self.score_dict['{}_none'.format(name)] = vals[2]

    def eval_model(self, data=None, class_wise=False, data_name=None):
        '''
        Evaluates this model on the given data. Stores computed
        scores in the field "score_dict". Currently computes macro-averaged
        F1 scores, precision and recall. Can also compute scores on a class-wise basis.
        :param data: the data to use for evaluation. By default uses the internally stored data
                    (should be a DataSampler if passed as a parameter).
        :param class_wise: flag to determine whether to compute class-wise scores in
                            addition to macro-averaged scores.
        :param data_name: accepted for interface compatibility; not used here.
        :return: a map from score names to values
        '''
        pred_labels, true_labels, t2pred, marks = self.predict(data)
        self.score(pred_labels, true_labels, class_wise, t2pred, marks)

        return self.score_dict

    def predict(self, data=None):
        '''
        Runs the model over the data without tracking gradients. The summed
        loss over all batches is stored in the field "loss".
        :param data: an iterable of batches. By default uses the internally stored data.
        :return: the tuple (pred_labels, true_labels, t2pred, marks) where t2pred maps
                 each 'ori_topic' to (predicted labels, true labels) and marks holds
                 the 'seen' entry of every sample.
        :raises ValueError: if the data contains no batches.
        '''
        all_y_pred = None
        all_labels = None
        all_marks = None
        all_tar_in_twe = None

        self.model.eval()
        self.loss = 0.

        if data is None:
            data = self.dataloader

        t2pred = dict()
        for sample_batched in data:
            with torch.no_grad():
                y_pred, labels = self.get_pred_noupdate(sample_batched)

                label_tensor = torch.tensor(labels)
                if self.use_cuda:
                    # move labels to cuda if necessary
                    label_tensor = label_tensor.to('cuda')  # .cuda()
                self.loss += self.loss_function(y_pred, label_tensor).item()

                y_pred_arr = y_pred.detach().cpu().numpy()
                ls = np.array(labels)

                m = [b['seen'] for b in sample_batched]
                tar_in_twe = [b['target_in_tweet'] for b in sample_batched]

                # collect raw predictions and labels per original topic
                for bi, b in enumerate(sample_batched):
                    topic_preds, topic_labels = t2pred.setdefault(b['ori_topic'], ([], []))
                    topic_preds.append(y_pred_arr[bi, :])
                    topic_labels.append(ls[bi])

                if all_y_pred is None:
                    all_y_pred = y_pred_arr
                    all_labels = ls
                    all_marks = m
                    all_tar_in_twe = tar_in_twe
                else:
                    all_y_pred = np.concatenate((all_y_pred, y_pred_arr), 0)
                    all_labels = np.concatenate((all_labels, ls), 0)
                    all_marks = np.concatenate((all_marks, m), 0)
                    all_tar_in_twe = np.concatenate((all_tar_in_twe, tar_in_twe), 0)

        if all_y_pred is None:
            raise ValueError('predict() received no batches to evaluate')

        for t in t2pred:
            t2pred[t] = (np.argmax(t2pred[t][0], axis=1), t2pred[t][1])

        if None not in all_tar_in_twe:
            # where 'target_in_tweet' is 1, the third prediction column is set to
            # -inf so that argmax can never select it (the mask has three columns)
            all_tar_in_twe = np.array(all_tar_in_twe)
            tar_in_twe_mask = np.column_stack((np.zeros(len(all_tar_in_twe)), np.zeros(len(all_tar_in_twe)), all_tar_in_twe))
            all_y_pred = np.where(tar_in_twe_mask == 1, -np.inf, all_y_pred)
        pred_labels = all_y_pred.argmax(axis=1)
        true_labels = all_labels
        return pred_labels, true_labels, t2pred, all_marks

    def eval_and_print(self, data=None, data_name=None, class_wise=False):
        '''
        Evaluates this model on the given data. Stores computed
        scores in the field "score_dict". Currently computes macro-averaged
        F1 scores, precision and recall. Can also compute scores on a class-wise basis.
        Prints the results to the console.
        :param data: the data to use for evaluation. By default uses the internally stored data
                    (should be a DataSampler if passed as a parameter).
        :param data_name: the name of the data evaluating.
        :param class_wise: flag to determine whether to compute class-wise scores in
                            addition to macro-averaged scores.
        :return: a map from score names to values
        '''
        # Passing data_name to eval_model as evaluation of adv model on train and dev are different
        scores = self.eval_model(data=data, class_wise=class_wise, data_name=data_name)
        print("Evaling on \"{}\" data".format(data_name))
        for s_name, s_val in scores.items():
            print("{}: {}".format(s_name, s_val))
        return scores

    def score(self, pred_labels, true_labels, class_wise, t2pred, marks, topic_wise=False):
        '''
        Helper Function to compute scores. Stores updated scores in
        the field "score_dict".
        :param pred_labels: the predicted labels
        :param true_labels: the correct labels
        :param class_wise: flag to determine whether to compute class-wise scores in
                            addition to macro-averaged scores.
        :param t2pred: map from topic to (predicted labels, true labels); only used
                        when topic_wise is True.
        :param marks: one mark per sample; scores are additionally computed over the
                        samples marked 1 and over the samples marked 0.
        :param topic_wise: flag to determine whether to also compute F1 per topic.
        '''
        # calculate class-wise and macro-averaged F1
        self.compute_scores(f1_score, true_labels, pred_labels, class_wise, 'f')
        # calculate class-wise and macro-average precision
        self.compute_scores(precision_score, true_labels, pred_labels, class_wise, 'p')
        # calculate class-wise and macro-average recall
        self.compute_scores(recall_score, true_labels, pred_labels, class_wise, 'r')

        # the same three scores, restricted to the samples with mark v
        for v in [1, 0]:
            tl_lst = []
            pl_lst = []
            for m, tl, pl in zip(marks, true_labels, pred_labels):
                if m != v: continue
                tl_lst.append(tl)
                pl_lst.append(pl)
            self.compute_scores(f1_score, tl_lst, pl_lst, class_wise, 'f-{}'.format(v))
            self.compute_scores(precision_score, tl_lst, pl_lst, class_wise, 'p-{}'.format(v))
            self.compute_scores(recall_score, tl_lst, pl_lst, class_wise, 'r-{}'.format(v))

        if topic_wise:
            for t in t2pred:
                self.compute_scores(f1_score, t2pred[t][1], t2pred[t][0], class_wise,
                                    '{}-f'.format(t))

    def _run_model(self, sample_batched):
        '''
        Batches the samples and runs them through the model.
        :param sample_batched: the batch of data samples
        :return: the predictions for the batch and the 'labels' entry of the
                    batched arguments
        '''
        args = self.batching_fn(sample_batched, **self.batching_kwargs)

        if not self.fine_tune:
            # EMBEDDING
            embed_args = self.embed_model(**args)
            args.update(embed_args)

            # PREDICTION
            y_pred = self.model(*self.setup_fn(args, self.use_cuda))
        else:
            # the model consumes the batched arguments directly
            y_pred = self.model(**args)

        return y_pred, args['labels']

    def get_pred_with_grad(self, sample_batched):
        '''
        Helper function for getting predictions while tracking gradients.
        Used for training the model.
        :param sample_batched: the batch of data samples
        :return: the predictions for the batch and the true labels for the batch
        '''
        return self._run_model(sample_batched)

    def get_pred_noupdate(self, sample_batched):
        '''
        Helper function for getting predictions without tracking gradients.
        Used for evaluating the model or getting predictions for other reasons.
        :param sample_batched: the batch of data samples
        :return: the predictions for the batch and the true labels for the batch
        '''
        with torch.no_grad():
            return self._run_model(sample_batched)


class AdvTorchModelHandler(TorchModelHandler):
    def __init__(self, num_ckps=10, use_score='f_macro', use_cuda=False, use_last_batch=True,
                 num_gpus=None, checkpoint_path='data/checkpoints/',
                 result_path='data/', opt_for='score_key', **params):
        '''
        Sets up the base handler and then stores the adversarial training settings.
        :param params: forwarded to TorchModelHandler. In addition it must contain
                        adv_optimizer, tot_epochs, initial_lr, alpha, beta,
                        num_constant_lr and batch_size.
        All named parameters are forwarded to TorchModelHandler unchanged.
        '''
        base_kwargs = dict(num_ckps=num_ckps, use_score=use_score, use_cuda=use_cuda,
                           use_last_batch=use_last_batch, num_gpus=num_gpus,
                           checkpoint_path=checkpoint_path, result_path=result_path,
                           opt_for=opt_for)
        TorchModelHandler.__init__(self, **base_kwargs, **params)

        # settings specific to adversarial training, copied in this order
        adv_fields = ('adv_optimizer', 'tot_epochs', 'initial_lr', 'alpha', 'beta',
                      'num_constant_lr', 'batch_size')
        for field in adv_fields:
            setattr(self, field, params[field])

    def adjust_learning_rate(self, epoch):
        '''
        Sets the learning rate of the main and the adversary optimizer for the
        given epoch. During the first num_constant_lr epochs the learning rate
        is left untouched; afterwards it is
            initial_lr / (1 + alpha * p) ** beta
        where p is the fraction of the remaining epochs that has passed.
        :param epoch: the epoch to compute the learning rate for.
        '''
        if epoch < self.num_constant_lr:
            return

        tot_epochs_for_calc = self.tot_epochs - self.num_constant_lr
        if tot_epochs_for_calc <= 0:
            # the constant phase covers the whole run: there is no decay phase
            # (and p would be a division by zero), so keep the learning rate
            return

        epoch_for_calc = epoch - self.num_constant_lr
        p = epoch_for_calc / tot_epochs_for_calc
        new_lr = self.initial_lr / ((1 + self.alpha * p) ** self.beta)
        for optimizer in (self.optimizer, self.adv_optimizer):
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

    def get_learning_rate(self):
        '''
        Returns the learning rate of the first parameter group of the main optimizer.
        '''
        first_group = next(iter(self.optimizer.param_groups))
        return first_group['lr']

    def train_step(self):
        '''
        Runs one epoch of adversarial training on this model.

        For every batch the combined loss is back-propagated and the main
        optimizer steps; then the gradients are zeroed, the adversarial loss is
        back-propagated and the adversary optimizer steps. After the epoch the
        learning rates of both optimizers are adjusted.
        '''
        if self.epoch > 0:
            # update the adversarial parameter
            self.loss_function.update_param_using_p(self.epoch)
        print("[{}] epoch {}".format(self.name, self.epoch))
        print("Adversarial parameter rho - {}".format(self.loss_function.adv_param))
        print("Learning rate - {}".format(self.get_learning_rate()))
        self.model.train()

        # reset the running totals of this epoch
        self.loss = 0.
        self.adv_loss = 0

        # TRAIN
        epoch_start = time.time()
        print(len(self.dataloader))
        for batch_idx, batch in enumerate(self.dataloader):
            print("Batch {} in epoch {} -".format(batch_idx, self.epoch))
            # zero gradients before EVERY optimizer step
            self.model.zero_grad()

            pred_info, labels = self.get_pred_with_grad(batch)

            device = 'cuda' if self.use_cuda else 'cpu'
            # stance labels
            label_tensor = torch.tensor(labels, device=device)
            # topic indices of the training topics
            topic_tensor = torch.tensor([sample['topic_i'] for sample in batch], device=device)
            pred_info['W'] = self.model.trans_layer.W
            # the topic indices are handed to the loss function through pred_info
            pred_info['topic_i'] = topic_tensor

            # While training we want to compute adversarial loss.
            main_loss, adversary_loss = self.loss_function(pred_info, label_tensor,
                                                           compute_adv_loss=True)
            self.loss += main_loss.item()
            self.adv_loss += adversary_loss.item()

            # the graph is kept because the adversarial loss is back-propagated below
            main_loss.backward(retain_graph=True)  # NOT on adv. params
            self.optimizer.step()

            print("Main loss", main_loss.item())

            self.model.zero_grad()
            # always do this, train adversary a bit first on it's own
            print("Adv loss", adversary_loss.item())
            adversary_loss.backward()
            self.adv_optimizer.step()  # only on adv params

        elapsed_minutes = (time.time() - epoch_start) / 60.
        print("   took: {:.1f} min".format(elapsed_minutes))
        self.epoch += 1
        # Adjusts the main and adversary optimizer learning rates.
        self.adjust_learning_rate(self.epoch)

    def predict(self, data=None, data_name='DEV'):
        """Run the model over `data` without weight updates and collect predictions.

        Puts the model in eval mode, resets `self.loss` / `self.adv_loss` to 0 and
        accumulates both losses over all batches.

        :param data: iterable of batches (each batch is a list of sample dicts);
                     falls back to `self.dataloader` when None.
        :param data_name: when equal to 'TRAIN' the adversarial (topic) loss and the
                          topic predictions are computed as well; for any other value
                          the adversarial loss is not computed.
        :return: tuple `(pred_labels, true_labels, t2pred, pred_topics, true_topics, all_marks)`:
                 argmax stance predictions, gold stance labels, a dict mapping each
                 `ori_topic` to `(argmax predictions, gold labels)` for that topic,
                 argmax topic predictions and gold topic ids (both None unless
                 data_name == 'TRAIN'), and the per-example `seen` flags.

        NOTE(review): if `data` yields no batches every accumulator stays None and the
        post-processing below raises - confirm callers never pass empty data.
        """
        # Accumulators; they stay None until the first batch has been processed.
        all_y_pred = None
        true_labels = None
        all_top_pred = None
        true_topics = None
        all_marks = None
        all_tar_in_twe = None

        self.model.eval()
        self.loss = 0.
        self.adv_loss = 0.

        if data is None:
            data = self.dataloader

        # original topic string -> ([stance score rows], [gold labels])
        t2pred = dict()
        for sample_batched in data:
            with torch.no_grad():
                pred_info, labels = self.get_pred_noupdate(sample_batched)

                label_tensor = torch.tensor(labels, device=('cuda' if self.use_cuda else 'cpu'))
                # The loss function reads the transformation matrix from pred_info['W'].
                pred_info['W'] = self.model.trans_layer.W

                if data_name == 'TRAIN':        # On train data the adversarial loss is always computed, whether or not it is part of the main loss
                    topics = [b['topic_i'] for b in sample_batched]
                    topic_tensor = torch.tensor(topics, device=('cuda' if self.use_cuda else 'cpu'))
                    pred_info['topic_i'] = topic_tensor
                    graph_loss_all, graph_loss_adv = self.loss_function(pred_info, label_tensor, compute_adv_loss=True)
                else:
                    # graph_loss_adv will be 0 - not calculated. graph_loss_all won't include adv loss.
                    graph_loss_all, graph_loss_adv = self.loss_function(pred_info, label_tensor, compute_adv_loss=False)

                self.loss += graph_loss_all.item()
                self.adv_loss += graph_loss_adv.item()

                y_pred_arr = pred_info['stance_pred'].detach().cpu().numpy()
                ls = np.array(labels)

                # Per-example metadata copied straight from the sample dicts.
                m = [b['seen'] for b in sample_batched]
                tar_in_twe = [b['target_in_tweet'] for b in sample_batched]

                if data_name == 'TRAIN':
                    top_pred_arr = pred_info['adv_pred'].detach().cpu().numpy()
                    tops = np.array(topics)

                # Group the raw stance scores and gold labels by original topic.
                for bi, b in enumerate(sample_batched):
                    t = b['ori_topic']
                    t2pred[t] = t2pred.get(t, ([], []))
                    t2pred[t][0].append(y_pred_arr[bi, :])
                    t2pred[t][1].append(ls[bi])

                if all_y_pred is None:
                    # First batch: keep the objects as they are (all_marks and
                    # all_tar_in_twe are still plain lists at this point).
                    all_y_pred = y_pred_arr
                    true_labels = ls
                    all_marks = m
                    all_tar_in_twe = tar_in_twe
                    if data_name == 'TRAIN':
                        all_top_pred = top_pred_arr
                        true_topics = tops
                else:
                    # Later batches: append along axis 0 (results become ndarrays).
                    all_y_pred = np.concatenate((all_y_pred, y_pred_arr), 0)
                    true_labels = np.concatenate((true_labels, ls), 0)
                    all_marks = np.concatenate((all_marks, m), 0)
                    all_tar_in_twe = np.concatenate((all_tar_in_twe, tar_in_twe), 0)
                    if data_name == 'TRAIN':
                        all_top_pred = np.concatenate((all_top_pred, top_pred_arr), 0)
                        true_topics = np.concatenate((true_topics, tops), 0)

        # Per-topic argmax is taken on the raw scores, i.e. before the
        # target-in-tweet masking applied to the global predictions below.
        for t in t2pred:
            t2pred[t] = (np.argmax(t2pred[t][0], axis=1), t2pred[t][1])

        # When every example carries a target_in_tweet flag, set the score of the
        # third class (column index 2) to -inf for examples whose flag equals 1,
        # so that class can never win the argmax for them.
        # NOTE(review): the mask is built with exactly three columns, so this
        # presumably assumes three stance classes - confirm.
        if None not in all_tar_in_twe:
            all_tar_in_twe = np.array(all_tar_in_twe)
            tar_in_twe_mask = np.column_stack((np.zeros(len(all_tar_in_twe)), np.zeros(len(all_tar_in_twe)), all_tar_in_twe))
            all_y_pred = np.where(tar_in_twe_mask == 1, -np.inf, all_y_pred)

        pred_labels = all_y_pred.argmax(axis=1)
        if data_name == 'TRAIN':
            pred_topics = all_top_pred.argmax(axis=1)
        else:
            pred_topics = None

        return pred_labels, true_labels, t2pred, pred_topics, true_topics, all_marks

    def eval_model(self, data=None, class_wise=False, data_name='DEV'):
        """Predict on `data`, score the stance predictions and return the scores.

        :param data: batches forwarded to `self.predict` (None lets `predict`
                     pick its own default).
        :param class_wise: forwarded unchanged to the scoring helpers.
        :param data_name: forwarded to `predict`; when it equals 'TRAIN' the
                          topic predictions are scored as well (key prefix 'topic-f').
        :return: `self.score_dict` after the scoring helpers have run.
        """
        # Topic predictions / gold topics are None unless data_name == 'TRAIN'.
        (stance_pred, stance_gold, per_topic,
         topic_pred, topic_gold, seen_marks) = self.predict(data, data_name)

        self.score(stance_pred, stance_gold, class_wise, per_topic, seen_marks)

        # Topic-prediction score: tracks the adversary on the training set.
        scoring_train_split = data_name == 'TRAIN'
        if scoring_train_split:
            self.compute_scores(f1_score, topic_gold, topic_pred, class_wise, 'topic-f')

        return self.score_dict


In [11]:
import torch
import torch.nn as nn
import math
from IPython import embed

class ReconstructionLoss(torch.nn.Module):
    '''
    Length-normalised squared reconstruction error.
    The original embeddings are passed through tanh before being compared
    with the embeddings produced by the model.
    '''
    def __init__(self):
        super().__init__()
        self.tanh = nn.Tanh()

    def forward(self, ori_embeds, model_embeds, embed_l):
        '''
        :param ori_embeds: original embeddings; reduced over dim 2 and then dim 1,
                           so a rank-3 tensor (B, L, E) is expected.
        :param model_embeds: reconstructed embeddings, same shape as ori_embeds.
        :param embed_l: per-example divisor (the sequence lengths), broadcast against (B).
        :return: scalar tensor, the mean over the batch.
        '''
        target = self.tanh(ori_embeds)
        # squared L2 distance per token: (B, L)
        token_sq_err = torch.norm(model_embeds - target, dim=2) ** 2
        # sum over tokens, normalise by the given lengths: (B)
        example_err = token_sq_err.sum(1) / embed_l
        return example_err.mean()  # combine the loss across the batch


class TransformationLoss(torch.nn.Module):
    '''
    Penalty l * ||W - I||^2 that keeps a square matrix close to the identity.
    '''
    def __init__(self, dim, l, use_cuda=False):
        '''
        :param dim: side length of the identity matrix W is compared against.
        :param l: weight multiplied onto the squared norm.
        :param use_cuda: create the identity matrix on 'cuda' instead of 'cpu'.
        '''
        super().__init__()

        target_device = 'cuda' if use_cuda else 'cpu'
        # Plain tensor attribute (not a registered buffer).
        self.eye = torch.eye(dim, device=target_device)
        self.l = l

    def forward(self, W):
        '''
        :param W: matrix with the same shape as the stored identity.
        :return: scalar tensor l * ||W - I||^2 (norm over all entries).
        '''
        distance = torch.norm(W - self.eye)
        return self.l * distance ** 2


class AdvBasicLoss(torch.nn.Module):
    '''
    Combined training objective: text and topic reconstruction losses, the
    transformation penalty and the stance cross-entropy; once the adversary is
    switched on, the weighted topic cross-entropy is subtracted from that sum.
    The second returned value is the topic loss used to train the adversary.
    '''
    def __init__(self, trans_dim, trans_param, num_no_adv=None, tot_epochs=20, rho_adv=False, gamma=10,
                 rec_weight=1, semi_sup=False, use_cuda=False):
        '''
        :param trans_dim: size of the transformation matrix (for TransformationLoss).
        :param trans_param: weight of the transformation penalty.
        :param num_no_adv: number of initial epochs without the adversary term.
        :param tot_epochs: total number of epochs, used for the adversary schedule.
        :param rho_adv: if True the adversary's own loss is scaled by the adversary weight.
        :param gamma: steepness of the adversary-weight schedule.
        :param rec_weight: multiplier on both reconstruction losses.
        :param semi_sup: if True stance label 3 is ignored by the stance loss.
        :param use_cuda: place the zero placeholders and the identity matrix on 'cuda'.
        '''
        super().__init__()

        self.rec_loss = ReconstructionLoss()
        self.trans_loss = TransformationLoss(dim=trans_dim, l=trans_param, use_cuda=use_cuda)

        # the adversary weight starts at 0 and is raised by update_param_using_p
        self.adv_param = 0.

        self.semi_sup = semi_sup
        stance_kwargs = {'ignore_index': 3} if self.semi_sup else {}
        self.stance_loss = nn.CrossEntropyLoss(**stance_kwargs)
        self.topic_loss = nn.CrossEntropyLoss()

        # Adversary is not used for num_no_adv initial epochs
        self.use_adv = num_no_adv == 0
        self.num_no_adv = num_no_adv
        self.tot_epochs = tot_epochs
        self.rec_weight = rec_weight
        self.i = 0  # number of forward calls so far, drives the periodic log line
        self.rho_adv = rho_adv
        self.gamma = gamma
        self.use_cuda = use_cuda

    def update_param_using_p(self, epoch):
        '''
        Switch the adversary on or off for `epoch` and refresh its weight.
        From epoch num_no_adv onwards the weight is 2 / (1 + exp(-gamma * p)) - 1,
        where p is the fraction of the remaining epochs already completed.
        '''
        if epoch < self.num_no_adv:
            self.use_adv = False
            return

        self.use_adv = True
        remaining_total = self.tot_epochs - self.num_no_adv
        elapsed = epoch - self.num_no_adv
        p = elapsed/remaining_total

        self.adv_param = 2/(1 + math.exp(-self.gamma*p)) - 1

    def forward(self, pred_info, labels, compute_adv_loss=True, print_=False):
        '''
        :param pred_info: dict with 'text', 'recon_embeds', 'text_l', 'topic',
                          'topic_recon_embeds', 'topic_l', 'W', 'stance_pred' and, when
                          compute_adv_loss is True, 'adv_pred', 'adv_pred_', 'topic_i'.
        :param labels: gold stance labels for the stance cross-entropy.
        :param compute_adv_loss: compute the topic losses; otherwise both stay 0.
        :param print_: print every loss component for this call.
        :return: (main loss, adversary loss).
        '''
        lrec = self.rec_weight * self.rec_loss(ori_embeds=pred_info['text'],
                                               model_embeds=pred_info['recon_embeds'],
                                               embed_l=pred_info['text_l'])
        lrec_topic = self.rec_weight * self.rec_loss(ori_embeds=pred_info['topic'],
                                                     model_embeds=pred_info['topic_recon_embeds'],
                                                     embed_l=pred_info['topic_l'])
        ltrans = self.trans_loss(W=pred_info['W'])
        llabel = self.stance_loss(pred_info['stance_pred'], labels)

        # zero placeholders, replaced below when the adversarial losses are requested
        placeholder_device = 'cuda' if self.use_cuda else 'cpu'
        ladv = torch.tensor(0, device=placeholder_device)
        adversarial_loss = torch.tensor(0, device=placeholder_device)

        if compute_adv_loss:
            topic_gold = pred_info['topic_i']
            ladv = self.topic_loss(pred_info['adv_pred'], topic_gold)
            adversarial_loss = self.topic_loss(pred_info['adv_pred_'], topic_gold)
            if self.rho_adv:
                adversarial_loss = self.adv_param * adversarial_loss

        if print_:
            print("lrec - {}, lrec_topic - {}, ltrans - {}, llabel - {}, ladv - {}".format(lrec, lrec_topic, ltrans, llabel, ladv))

        self.i += 1
        log_now = self.i % 100 == 0
        supervised = lrec + lrec_topic + ltrans + llabel

        if not self.use_adv:
            if log_now:
                print("loss: {:.4f} +  {:.4f} + {:.4f}; adv: {:.4f}".format(lrec.item(), ltrans.item(), llabel.item(),
                                                     ladv))
            return supervised, adversarial_loss

        if log_now:
            print("loss: {:.4f} + {:.4f} + {:.4f} - {:.4f}; adv: {:.4f}".format(lrec.item(), ltrans.item(), llabel.item(),
                                               (self.adv_param * ladv).item(), ladv))
        return supervised - self.adv_param * ladv, adversarial_loss


In [12]:
import torch, pickle, json
from torch.utils.data import Dataset, DataLoader, Sampler
from transformers import BertTokenizer
import pandas as pd
from functools import reduce


class StanceData(Dataset):
    '''
    Holds the stance dataset.

    Reads a CSV of examples, converts every text and topic into fixed-length
    index lists (vocabulary lookup, or BERT tokenisation when is_bert is set)
    and serves one dictionary per example through __getitem__.

    Columns read from the CSV: 'text' and 'topic' (JSON-encoded token lists),
    'label', 'seen?', and optionally 'topic_str', 'text_s', 'target_in_tweet'.
    '''
    def __init__(self, data_name, vocab_name, topic_name=None, name='',
                 max_sen_len=10, max_tok_len=200, max_top_len=5, binary=False,
                 pad_val=0, is_bert=False, add_special_tokens=True, use_tar_in_twe=False):
        '''
        :param data_name: path of the CSV file with the examples.
        :param vocab_name: path of a pickled word -> index dict, or None
                           (word2i is then left unset; only valid for BERT input).
        :param topic_name: path of a pickled topic -> index dict; None gives an empty dict.
        :param name: free-form name, stored only.
        :param max_sen_len: stored only; not used inside this class.
        :param max_tok_len: number of token ids every text is cut / padded to.
        :param max_top_len: number of token ids every topic is cut / padded to.
        :param binary: stored only; not used inside this class.
        :param pad_val: id used for padding in the non-BERT path.
        :param is_bert: tokenize with BertTokenizer instead of the vocabulary.
        :param add_special_tokens: forwarded to the tokenizer for the separate
                                   text / topic encodings.
        :param use_tar_in_twe: expose the 'target_in_tweet' column in the samples
                               (only when that column exists).
        '''
        self.data_name = data_name
        self.data_file = pd.read_csv(data_name)
        if vocab_name is not None:
            # NOTE: pickle.load can execute arbitrary code - only load trusted files.
            with open(vocab_name, 'rb') as vocab_f:
                self.word2i = pickle.load(vocab_f)
        self.name = name
        self.max_sen_len = max_sen_len
        self.max_tok_len = max_tok_len
        self.max_top_len = max_top_len
        self.binary = binary
        self.pad_value = pad_val
        if topic_name is not None:
            with open(topic_name, 'rb') as topic_f:
                self.topic2i = pickle.load(topic_f)
        else:
            self.topic2i = dict()
        self.is_bert = is_bert
        self.add_special_tokens = add_special_tokens
        self.tar_in_twe = ('target_in_tweet' in self.data_file.columns)
        self.use_tar_in_twe = use_tar_in_twe

        if self.is_bert:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.preprocess_data()

        if self.is_bert:
            # filter unlabeled examples for Twitter
            labeled = self.data_file['label'] != 3
            # reset the index so we can access correctly later
            # (the old index is kept as an 'index' column, as before)
            self.data_file = self.data_file.loc[labeled].reset_index()

    def process_bert(self):
        '''
        Tokenize every row with the BERT tokenizer and store the id lists
        ('text_idx', 'topic_idx', 'text_topic_idx', 'token_type_ids',
        'attention_mask') in the data frame. Adds 'topic_str' / 'text_s'
        columns when the CSV did not provide them.
        '''
        print("processing BERT")
        topic_str_lst = []
        text_str_lst = []
        has_topic_str = 'topic_str' in self.data_file.columns
        has_text_s = 'text_s' in self.data_file.columns
        # NOTE(review): no truncation is requested from the tokenizer, so inputs
        # longer than max_length are presumably not cut - confirm this is intended.
        for i in self.data_file.index:
            # label-based read, matching the label-based .at writes below
            row = self.data_file.loc[i]
            num_sens = 1
            if has_topic_str:
                ori_topic = row['topic_str']
            else:
                ori_topic = ' '.join(json.loads(row['topic']))
                topic_str_lst.append(ori_topic)

            if has_text_s:
                ori_text = row['text_s']
            else:
                ori_text = ' '.join(sum(json.loads(row['text']), []))
                text_str_lst.append(ori_text)

            text_topic = self.tokenizer(ori_text, ori_topic, padding='max_length', max_length=int(self.max_tok_len),
                                        return_token_type_ids=True,
                                        return_attention_mask=True)
            text = self.tokenizer(ori_text, add_special_tokens=self.add_special_tokens,
                                  max_length=int(self.max_tok_len), padding='max_length')
            topic = self.tokenizer(ori_topic, add_special_tokens=self.add_special_tokens,
                                   max_length=int(self.max_top_len), padding='max_length')
            self.data_file.at[i, 'text_idx'] = text['input_ids']
            self.data_file.at[i, 'ori_text'] = ori_text
            self.data_file.at[i, 'topic_idx'] = topic['input_ids']
            self.data_file.at[i, 'num_sens'] = num_sens
            self.data_file.at[i, 'text_topic_idx'] = text_topic['input_ids']
            self.data_file.at[i, 'token_type_ids'] = text_topic['token_type_ids']
            self.data_file.at[i, 'attention_mask'] = text_topic['attention_mask']
        print("...finished pre-processing for BERT")
        if not has_topic_str:
            self.data_file['topic_str'] = topic_str_lst

        if not has_text_s:
            self.data_file['text_s'] = text_str_lst
        return

    def _pad_with_mask(self, ids, size):
        # Right-pad `ids` to `size` with the pad value; the mask has 1 for real
        # tokens and 0 for padding.
        n_real = len(ids)
        n_pad = max(size - n_real, 0)
        return ids + [self.pad_value] * n_pad, [1] * n_real + [0] * n_pad

    def process_nonbert(self):
        '''
        Index every row with the vocabulary: flatten the text sentences, cut
        text / topic to their maximum lengths, pad them and store ids, lengths
        (before padding) and masks in the data frame.
        '''
        # Creating topic_string from tokenized topic column for twitter dataset
        add_topic_string = 'topic_str' not in self.data_file.columns
        if add_topic_string:
            # create the column up front so the per-row .at writes hit an existing column
            self.data_file['topic_str'] = ''
        has_text_s = 'text_s' in self.data_file.columns

        for i in self.data_file.index:
            # label-based read, matching the label-based .at writes below
            row = self.data_file.loc[i]

            # Tokenized text in the form of [[tokenized sentence 1],[tokenized sent 2],...].
            # In twitter data it is a 2 D array with [[tokenized text]].
            ori_text = json.loads(row['text'])
            # Tokenized topic array - 1D array with tokenized topic
            ori_topic = json.loads(row['topic'])

            # index text & topic; sentences are flattened into one token list
            # (a comprehension also copes with an empty sentence list)
            text = [self.get_index(w) for s in ori_text for w in s][:self.max_tok_len]
            topic = [self.get_index(w) for w in ori_topic][:self.max_top_len]

            # lengths are taken before padding
            text_lens = len(text)
            topic_lens = len(topic)
            num_sens = 1

            text, text_mask = self._pad_with_mask(text, self.max_tok_len)
            topic, topic_mask = self._pad_with_mask(topic, self.max_top_len)

            if has_text_s:
                ori_text_ = row['text_s']
            else:
                ori_text_ = ' '.join([' '.join(ti) for ti in ori_text])

            if add_topic_string:
                self.data_file.at[i, 'topic_str'] = ' '.join(ori_topic)

            self.data_file.at[i, 'text_idx'] = text
            self.data_file.at[i, 'topic_idx'] = topic
            self.data_file.at[i, 'text_l'] = text_lens
            self.data_file.at[i, 'topic_l'] = topic_lens
            self.data_file.at[i, 'ori_text'] = ori_text_
            self.data_file.at[i, 'num_sens'] = num_sens
            self.data_file.at[i, 'text_mask'] = text_mask
            self.data_file.at[i, 'topic_mask'] = topic_mask

    def preprocess_data(self):
        '''
        Create the derived columns with empty defaults, then fill them with
        process_bert or process_nonbert depending on is_bert.
        '''
        print('preprocessing data {} ...'.format(self.data_name))

        n_rows = len(self.data_file)
        # Columns holding one list per row must exist with object dtype before
        # list cells are written. 'attention_mask' was missing from this set
        # although process_bert writes a list into it for every row.
        list_columns = ('text_idx', 'topic_idx', 'text_topic_idx', 'token_type_ids',
                        'attention_mask', 'text_mask', 'topic_mask')
        for col in list_columns:
            self.data_file[col] = [[] for _ in range(n_rows)]
        self.data_file['text_l'] = 0
        self.data_file['ori_text'] = ''
        self.data_file['topic_l'] = 0
        self.data_file['num_sens'] = 0

        if self.is_bert:
            self.process_bert()
        else:
            self.process_nonbert()

        print("... finished preprocessing")

    def get_index(self, word):
        '''Return the vocabulary index of word; unknown words map to len(word2i).'''
        return self.word2i.get(word, len(self.word2i))

    def __len__(self):
        return len(self.data_file)

    def __getitem__(self, idx, corpus=None):
        '''
        Return the sample at position idx as a dictionary.
        For BERT input without special tokens the joint text/topic encoding is
        added; otherwise the topic id, topic mask and target_in_tweet flag are added.
        `corpus` is accepted but not used.
        '''
        row = self.data_file.iloc[idx]

        l = int(row['label'])

        if self.tar_in_twe and self.use_tar_in_twe:
            tar_in_twe_value = row['target_in_tweet']
        else:
            tar_in_twe_value = None

        sample = {'text': row['text_idx'], 'topic': row['topic_idx'],
                  'label': l,
                  'txt_l': row['text_l'], 'top_l': row['topic_l'],
                  'ori_topic': row['topic_str'],
                  'ori_text': row['ori_text'],
                  'text_mask': row['text_mask'],
                  'num_s': row['num_sens'],
                  'seen': row['seen?'],
                  }
        if self.is_bert and not self.add_special_tokens:
            sample['text_topic'] = row['text_topic_idx']
            sample['token_type_ids'] = row['token_type_ids']
            sample['attention_mask'] = row['attention_mask']
        else:
            # topics missing from topic2i fall back to id 0
            sample['topic_i'] = self.topic2i.get(row['topic'], 0)
            sample['topic_mask'] = row['topic_mask']
            sample['target_in_tweet'] = tar_in_twe_value

        return sample

In [6]:
pip install transformers

Collecting transformers
  Using cached transformers-4.26.1-py3-none-any.whl (6.3 MB)
Collecting huggingface-hub<1.0,>=0.11.0
  Using cached huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-win_amd64.whl (3.3 MB)
     ---------------------------------------- 3.3/3.3 MB 1.1 MB/s eta 0:00:00
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.12.1 tokenizers-0.13.2 transformers-4.26.1
Note: you may need to restart the kernel to use updated packages.


In [13]:
import torch, random
import numpy as np


def load_vectors(vecfile, dim=300, unk_rand=True, seed=0):
    '''
    Loads saved vectors and appends two extra rows: one for <unk>, one for padding.
    :param vecfile: the name of the file to load the vectors from.
    :param dim: length of the two appended rows (must match the vector width).
    :param unk_rand: draw the <unk> row from a standard normal instead of using zeros.
    :param seed: seed given to numpy's global random generator before drawing.
    :return: a numpy float array with all the vectors, <unk> at V-2 and pad at V-1.
    '''
    loaded = np.load(vecfile)
    np.random.seed(seed)  # seeds the global generator, also when unk_rand is False

    unk_row = np.random.randn(dim) if unk_rand else np.zeros(dim)  # <unk> -> V-2
    pad_row = np.zeros(dim)  # pad -> V-1

    stacked = np.vstack((loaded, unk_row, pad_row))
    return stacked.astype(float, copy=False)

def prepare_batch(sample_batched, **kwargs):
    '''
    Collates a list of samples into the batch dictionary used for training or evaluation.
    :param sample_batched: a list of dictionaries, where each is a sample
    :param kwargs: accepted for interface compatibility, not used
    :return: a dictionary containing:
            'text': a tensor of all the text instances,
            'topic': a tensor of all topic instances,
            'labels': a list of labels for the text,topic instances
            'txt_l': a numpy array of the text lengths
            'top_l': a list of the topic lengths
            'ori_text': a list with the original texts
            'ori_topic': a list with the original topics
            AND (depending on the keys present in the first sample)
            'text_topic_batch', 'token_type_ids', 'attention_mask': tensors
            built from the joint text/topic encoding (for Bert)
            'topic_rep_ids': a tensor built from the 'topic_rep_id' values
    '''
    def column(key):
        # the value stored under `key` in every sample, in batch order
        return [sample[key] for sample in sample_batched]

    args = {
        'text': torch.tensor(column('text')),
        'topic': torch.tensor(column('topic')),
        'labels': column('label'),
        'txt_l': np.array(column('txt_l')),
        'top_l': column('top_l'),
        'ori_text': column('ori_text'),
        'ori_topic': column('ori_topic'),
    }

    first_sample = sample_batched[0]

    if 'text_topic' in first_sample:
        bert_fields = (('text_topic_batch', 'text_topic'),
                       ('token_type_ids', 'token_type_ids'),
                       ('attention_mask', 'attention_mask'))
        for out_key, sample_key in bert_fields:
            args[out_key] = torch.tensor(column(sample_key))

    if 'topic_rep_id' in first_sample:
        args['topic_rep_ids'] = torch.tensor(column('topic_rep_id'))

    return args


def prepare_batch_adv(sample_batched, **kwargs):
    '''
    Same as prepare_batch, plus the per-sample masks needed by the adversarial model.
    :param sample_batched: a list of dictionaries, where each is a sample
    :return: the prepare_batch dictionary extended with
            'txt_mask': a list of the text masks
            'top_mask': a list of the topic masks
    '''
    args = prepare_batch(sample_batched, **kwargs)

    for out_key, sample_key in (('txt_mask', 'text_mask'), ('top_mask', 'topic_mask')):
        args[out_key] = [sample[sample_key] for sample in sample_batched]

    return args


class DataSampler:
    '''
    A sampler for a dataset. Yields batches (lists of samples) of a fixed size;
    the last batch may be smaller.
    Is iterable. By default shuffles the data each time a new iteration starts.
    '''
    def __init__(self, data, batch_size, shuffle=True):
        '''
        :param data: indexable dataset supporting len() and __getitem__.
        :param batch_size: number of samples per batch.
        :param shuffle: shuffle the sample order on every reset / new iteration.
        '''
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        random.seed(0)  # seeds the global random module so the shuffles are repeatable

        self.reset()
        self.batch_num = 0

    def __len__(self):
        # number of samples, not number of batches
        return len(self.data)

    def num_batches(self):
        # fractional value: samples divided by batch size, not rounded
        return len(self.data) / float(self.batch_size)

    def __iter__(self):
        # every new iteration starts from a fresh (optionally shuffled) index order
        self.reset()
        return self

    def __next__(self):
        if not self.indices:
            raise StopIteration

        chosen = self.indices[:self.batch_size]
        batch = [self.data.__getitem__(i) for i in chosen]
        # drop the indices that were just served
        self.indices = self.indices[self.batch_size:]
        return batch

    def get(self):
        '''Reset the sampler and return the first batch of the new order.'''
        self.reset()
        return next(self)

    def reset(self):
        '''Rebuild the full index list, shuffled when shuffling is enabled.'''
        self.indices = list(range(len(self.data)))
        if self.shuffle:
            random.shuffle(self.indices)


def setup_helper_bicond(args, use_cuda):
    '''
    Pulls the embedded text/topic and their lengths out of `args`, turns the
    lengths into tensors and moves all four to the GPU when use_cuda is set.
    :param args: dictionary with 'txt_E', 'top_E' (tensors) and 'txt_l', 'top_l' (lengths)
    :param use_cuda: move everything to 'cuda'
    :return: txt_E, top_E, txt_l, top_l
    '''
    txt_E = args['txt_E']  # (B,T,E)
    top_E = args['top_E']  # (B,C,E)
    txt_l = torch.tensor(args['txt_l'])
    top_l = torch.tensor(args['top_l'])  # (B)

    if use_cuda:
        txt_E, top_E, txt_l, top_l = (t.to('cuda') for t in (txt_E, top_E, txt_l, top_l))

    return txt_E, top_E, txt_l, top_l


def setup_helper_adv(args, use_cuda):
    '''
    Pulls the embedded text/topic out of `args` (moved to the GPU when use_cuda
    is set) and builds the length and mask tensors on the matching device.
    :param args: dictionary with 'txt_E', 'top_E' (tensors) and
                 'txt_l', 'top_l', 'txt_mask', 'top_mask' (array-likes)
    :param use_cuda: place everything on 'cuda' instead of 'cpu'
    :return: txt_E, top_E, txt_l, top_l, txt_mask, top_mask
    '''
    device = 'cuda' if use_cuda else 'cpu'

    txt_E = args['txt_E']  # (B,T,E)
    top_E = args['top_E']  # (B,C,E)
    if use_cuda:
        txt_E = txt_E.to('cuda')
        top_E = top_E.to('cuda')

    # lengths and masks are created directly on the target device
    txt_l, top_l, txt_mask, top_mask = (
        torch.tensor(args[key], device=device)
        for key in ('txt_l', 'top_l', 'txt_mask', 'top_mask')
    )

    return txt_E, top_E, txt_l, top_l, txt_mask, top_mask


In [15]:
import torch, sys
import torch.nn as nn
import baseline_model_layers as bml
from transformers import  BertForSequenceClassification


class BiCondLSTMModel(torch.nn.Module):
    '''
    Bidirectional Conditional Encoding LSTM (Augenstein et al, 2016, EMNLP).
    The text is encoded by a bidirectional LSTM whose initial states come from
    a bidirectional LSTM over the topic. Dropout is applied to the final text
    states, which are then fed to a prediction layer producing one score per label.
    '''

    def __init__(self, hidden_dim, embed_dim, input_dim, drop_prob=0, num_layers=1, num_labels=3,
                 use_cuda=False):
        '''
        :param hidden_dim: hidden size of each LSTM direction.
        :param embed_dim: embedding size passed to the encoder.
        :param input_dim: input size passed to the encoder.
        :param drop_prob: dropout probability (encoder argument and final dropout).
        :param num_layers: number of layers; also scales the prediction layer input size.
        :param num_labels: number of output scores.
        :param use_cuda: forwarded to the encoder and the prediction layer.
        '''
        super().__init__()
        self.use_cuda = use_cuda
        self.num_labels = num_labels

        self.bilstm = bml.BiCondLSTMLayer(hidden_dim=hidden_dim, embed_dim=embed_dim, input_dim=input_dim,
                                          drop_prob=drop_prob, num_layers=num_layers,
                                          use_cuda=use_cuda)
        # so we can have dropouts on last layer
        self.dropout = nn.Dropout(p=drop_prob)
        # This is BiCond specific: forward and backward final states are concatenated
        pred_input_size = 2 * num_layers * hidden_dim
        self.pred_layer = bml.PredictionLayer(input_size=pred_input_size,
                                              output_size=self.num_labels,
                                              pred_fn=nn.Tanh(), use_cuda=use_cuda)

    def forward(self, text, topic, text_l, topic_l):
        '''
        :param text: embedded text, batch first; transposed to (T, B, E) for the encoder.
        :param topic: embedded topic, batch first; transposed to (C, B, E) for the encoder.
        :param text_l: text lengths.
        :param topic_l: topic lengths.
        :return: the output of the prediction layer, one row per example.
        '''
        text_time_major = text.transpose(0, 1)  # (T, B, E)
        topic_time_major = topic.transpose(0, 1)  # (C, B, E)

        # only the final text states (second element) are needed here
        _, final_states, _, _ = self.bilstm(text_time_major, topic_time_major, topic_l, text_l)

        return self.pred_layer(self.dropout(final_states))




class AdversarialBasic(torch.nn.Module):
    '''
    Adversarial stance model: encodes text conditioned on the topic, attends
    over the text with the topic state, reconstructs the text and topic
    embeddings, applies the transformation layer and feeds the result to a
    stance classifier and to a topic classifier (the adversary).
    '''
    def __init__(self, enc_params, enc_type, stance_dim, topic_dim, num_labels, num_topics,
                 drop_prob=0.0, use_cuda=False):
        '''
        :param enc_params: dict read for the keys 'h' (hidden size), 'embed_dim'
                           and 'drop_prob' (encoder dropout).
        :param enc_type: encoder type; only 'bicond' is supported, any other
                         value prints an error and exits the process.
        :param stance_dim: hidden size of the stance classifier.
        :param topic_dim: hidden size of the topic classifier.
        :param num_labels: number of stance classes.
        :param num_topics: number of topic classes.
        :param drop_prob: dropout probability for the input and output dropout.
        :param use_cuda: forwarded to the sub-layers.
        '''
        super(AdversarialBasic, self).__init__()
        self.enc_type = enc_type
        self.use_cuda = use_cuda
        self.hidden_dim = enc_params['h']
        self.embed_dim = enc_params['embed_dim']
        self.stance_dim = stance_dim
        self.num_labels = num_labels
        self.num_topics = num_topics

        if self.enc_type == 'bicond':
            self.enc = bml.BiCondLSTMLayer(hidden_dim=self.hidden_dim, embed_dim=self.embed_dim, input_dim=self.embed_dim,
                                           drop_prob=enc_params['drop_prob'], num_layers=1, use_cuda=use_cuda)
            self.att_layer = bml.ScaledDotProductAttention(input_dim=2*self.hidden_dim, use_cuda=self.use_cuda)
        else:
            print("ERROR: invalid encoder type. exiting")
            sys.exit(1)
        self.in_dropout = nn.Dropout(p=drop_prob)
        self.out_dropout = nn.Dropout(p=drop_prob)

        # separate reconstruction layers for the text and the topic encodings
        self.recon_layer = bml.ReconstructionLayer(hidden_dim=self.hidden_dim, embed_dim=self.embed_dim,
                                                   use_cuda=self.use_cuda)

        self.topic_recon_layer = bml.ReconstructionLayer(hidden_dim=self.hidden_dim, embed_dim=self.embed_dim, use_cuda=self.use_cuda)
        self.trans_layer = bml.TransformationLayer(input_size=2*self.hidden_dim)

        # the stance classifier sees the transformed text vector (2H) concatenated
        # with the topic state (2H), hence 4 * hidden_dim inputs
        multiplier = 4
        self.stance_classifier = bml.TwoLayerFFNNLayer(input_dim=multiplier*self.hidden_dim, hidden_dim=stance_dim,
                                                       out_dim=self.num_labels, nonlinear_fn=nn.ReLU())
        self.topic_classifier = bml.TwoLayerFFNNLayer(input_dim=2*self.hidden_dim, hidden_dim=topic_dim,
                                                       out_dim=self.num_topics, nonlinear_fn=nn.ReLU())

    def forward(self, text, topic, text_l, topic_l, text_mask=None, topic_mask=None):
        '''
        :return: dict with the inputs ('text', 'text_l', 'topic', 'topic_l'), the
                 stance scores ('stance_pred'), the topic scores ('adv_pred'), the
                 topic scores computed on a detached representation ('adv_pred_')
                 and the reconstructed embeddings ('recon_embeds', 'topic_recon_embeds').
        '''
        # text: (B, T, E), topic: (B, C, E), text_l: (B), topic_l: (B), text_mask: (B, T), topic_mask: (B, C)

        # apply dropout on the input
        dropped_text = self.in_dropout(text)

        # encode the text
        if self.enc_type == 'bicond':
            output, _, last_top_hn, topic_output = self.enc(dropped_text.transpose(0, 1),
                                              topic.transpose(0, 1),
                                              topic_l, text_l)
            output = output.transpose(0, 1)     # token-level text encodings of size (B, T, 2*H)
            topic_output = topic_output.transpose(0, 1)   # token-level topic encodings of size (B, C, 2*H)
            last_top_hn = last_top_hn.transpose(0, 1).reshape(-1, 2*self.hidden_dim)        # final topic states, (B, 2*H)
            att_vecs = self.att_layer(output, last_top_hn)      # topic-attended text vector, (B, 2*H)


        # reconstruct the original embeddings
        recon_embeds = self.recon_layer(output, text_mask) #(B, L, E)
        # reconstruct topic embeddings
        topic_recon_embeds = self.topic_recon_layer(topic_output, topic_mask)

        # transform the representation
        trans_reps = self.trans_layer(att_vecs) #(B, 2H)

        # the same dropout module is applied to both classifier inputs
        trans_reps = self.out_dropout(trans_reps)  # adding dropout
        last_top_hn = self.out_dropout(last_top_hn)

        # stance prediction
        # added topic input to stance classifier
        stance_input = torch.cat((trans_reps, last_top_hn), 1)      #(B, 4H)
        stance_preds = self.stance_classifier(stance_input)

        # topic prediction
        # topic_preds keeps the graph back into the encoder; topic_preds_ is computed
        # on a detached copy, so its gradient reaches only the topic classifier
        topic_preds = self.topic_classifier(trans_reps)
        topic_preds_ = self.topic_classifier(trans_reps.detach())

        pred_info = {'text': text, 'text_l': text_l,
                     'topic': topic, 'topic_l': topic_l,
                     'adv_pred': topic_preds, 'adv_pred_':topic_preds_, 'stance_pred': stance_preds,
                     'topic_recon_embeds': topic_recon_embeds, 'recon_embeds': recon_embeds}

        return pred_info


class JointSeqBERTLayer(torch.nn.Module):
    '''
    Wraps a pretrained 'bert-base-uncased' sequence classifier that takes the
    joint text/topic encoding and returns the first element of its output.
    '''
    def __init__(self, num_labels=3, use_cuda=False):
        '''
        :param num_labels: stored on the layer.
        :param use_cuda: move the BERT model (and, in forward, its inputs) to 'cuda'.
        '''
        super(JointSeqBERTLayer, self).__init__()

        self.num_labels = num_labels
        self.use_cuda = use_cuda
        # NOTE(review): num_labels is not passed to from_pretrained, so the
        # classification head keeps the library's default size - confirm this is intended.
        self.bert_layer = BertForSequenceClassification.from_pretrained('bert-base-uncased')

        self.dim = 768
        if self.use_cuda:
            self.bert_layer = self.bert_layer.to('cuda')

    def forward(self, **kwargs):
        '''
        :param kwargs: must contain the tensors 'text_topic_batch' (input ids),
                       'token_type_ids' and 'attention_mask'.
        :return: first element of the BERT classifier output.
        '''
        # Follow the use_cuda flag instead of always moving the inputs to the
        # GPU, so the layer also works on CPU-only machines.
        device = 'cuda' if self.use_cuda else 'cpu'
        output = self.bert_layer(input_ids=kwargs['text_topic_batch'].to(device),
                                 token_type_ids=kwargs['token_type_ids'].to(device),
                                 attention_mask=kwargs['attention_mask'].to(device))
        return output[0]


class WordEmbedLayer(torch.nn.Module):
    '''
    Embedding lookup initialised from a pretrained vector matrix.
    '''
    def __init__(self, vecs, static_embeds=True, use_cuda=False):
        '''
        :param vecs: 2-D array of pretrained vectors, one row per vocabulary entry.
        :param static_embeds: freeze the embedding weights (no gradient updates).
        :param use_cuda: stored only; the lookup itself does not move data.
        '''
        super().__init__()

        self.embeds = nn.Embedding.from_pretrained(torch.tensor(vecs), freeze=static_embeds)

        num_rows, num_cols = vecs.shape[0], vecs.shape[1]
        self.dim = num_cols
        print("Input layer embedding size -  ", self.dim)
        self.vocab_size = float(num_rows)  # kept as a float
        self.use_cuda = use_cuda

    def forward(self, **kwargs):
        '''
        :param kwargs: must contain the index tensors 'text' and 'topic'.
        :return: dict with 'txt_E' (B, T, E) and 'top_E' (B, C, E), both cast
                 to torch.FloatTensor.
        '''
        def embed(key):
            return self.embeds(kwargs[key]).type(torch.FloatTensor)

        return {'txt_E': embed('text'), 'top_E': embed('topic')}

In [16]:
import torch, math
import torch.nn as nn
import torch.nn.utils.rnn as rnn


class TwoLayerFFNNLayer(torch.nn.Module):
    '''
    Two linear layers with the given nonlinearity in between.
    Returns raw scores: must be followed with some kind of prediction layer
    or loss for actual prediction.
    '''
    def __init__(self, input_dim, hidden_dim, out_dim, nonlinear_fn):
        '''
        :param input_dim: size of the input features.
        :param hidden_dim: size of the hidden layer.
        :param out_dim: size of the output.
        :param nonlinear_fn: module applied between the two linear layers, ex: nn.ReLU()
        '''
        super().__init__()

        self.input_dim, self.hidden_dim, self.out_dim = input_dim, hidden_dim, out_dim

        # built in this order so parameter initialisation matches the layer order
        in_proj = nn.Linear(input_dim, hidden_dim)
        out_proj = nn.Linear(hidden_dim, out_dim)
        self.model = nn.Sequential(in_proj, nonlinear_fn, out_proj)

    def forward(self, input):
        return self.model(input)


class PredictionLayer(torch.nn.Module):
    '''
    Prediction layer: a bias-free linear projection to the output size.
    The function given as pred_fn (ex: pass pred_fn=nn.Tanh()) is stored on the
    layer but is not applied in forward.
    '''
    def __init__(self, input_size, output_size, pred_fn, use_cuda=False):
        '''
        :param input_size: size of the input features.
        :param output_size: size of the projection output.
        :param pred_fn: stored as self.pred_fn.
        :param use_cuda: move the projection to 'cuda'.
        '''
        super().__init__()

        self.use_cuda = use_cuda

        self.input_dim = input_size
        self.output_dim = output_size
        self.pred_fn = pred_fn

        projection = nn.Linear(self.input_dim, self.output_dim, bias=False)
        self.model = nn.Sequential(projection)

        if self.use_cuda:
            self.model = self.model.to('cuda')

    def forward(self, input_data):
        return self.model(input_data)


class BiCondLSTMLayer(torch.nn.Module):
    '''
    Bidirection Conditional Encoding (Augenstein et al. 2016 EMNLP).
    Bidirectional LSTM with initial states from topic encoding.
    Topic encoding is also a bidirectional LSTM.
    '''
    def __init__(self, hidden_dim, embed_dim, input_dim, drop_prob=0, num_layers=1, use_cuda=False):
        '''
        :param hidden_dim: hidden size of each LSTM direction.
        :param embed_dim: input size of the text LSTM.
        :param input_dim: input size of the topic LSTM.
        :param drop_prob: accepted but not used by this layer.
        :param num_layers: stored only; both LSTMs are built with a single layer.
        :param use_cuda: stored only.
        '''
        super(BiCondLSTMLayer, self).__init__()

        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.use_cuda = use_cuda

        self.topic_lstm = nn.LSTM(input_dim, self.hidden_dim, bidirectional=True)
        self.text_lstm = nn.LSTM(self.embed_dim, self.hidden_dim, bidirectional=True)

    @staticmethod
    def _lengths_on_cpu(lengths):
        # pack_padded_sequence requires tensor lengths to live on the CPU.
        return lengths.cpu() if torch.is_tensor(lengths) else lengths

    def forward(self, txt_e, top_e, top_l, txt_l):
        '''
        :param txt_e: embedded text, (Lx, B, E)
        :param top_e: embedded topic, (Lt, B, E)
        :param top_l: topic lengths (B), tensor or list; may live on any device.
        :param txt_l: text lengths (B), tensor or list; may live on any device.
        :return: padded text outputs (Lx, B, 2*H),
                 final text states with both directions concatenated (B, 2*H),
                 final topic hidden states (2, B, H),
                 padded topic outputs (Lt, B, 2*H)
        '''
        # Callers in this file move the lengths to the GPU together with the
        # other tensors; bring them back so packing works on every device.
        top_l = self._lengths_on_cpu(top_l)
        txt_l = self._lengths_on_cpu(txt_l)

        p_top_embeds = rnn.pack_padded_sequence(top_e, top_l, enforce_sorted=False)

        self.topic_lstm.flatten_parameters()

        # feed topic
        topic_output, last_top_hn_cn = self.topic_lstm(p_top_embeds)  # (seq_ln, B, 2*H),((2, B, H), (2, B, H))
        last_top_hn = last_top_hn_cn[0]  # hidden state of the (h, c) pair
        padded_topic_output, _ = rnn.pad_packed_sequence(topic_output, total_length=top_e.shape[0])

        p_text_embeds = rnn.pack_padded_sequence(txt_e, txt_l, enforce_sorted=False)
        self.text_lstm.flatten_parameters()

        # feed text conditioned on topic: the topic's final (h, c) initialise the text LSTM
        output, (txt_last_hn, _) = self.text_lstm(p_text_embeds, last_top_hn_cn)  # (2, B, H)
        # (2, B, H) -> (B, 2, H) -> (B, 2*H): forward state followed by backward state
        txt_fw_bw_hn = txt_last_hn.transpose(0, 1).reshape((-1, 2 * self.hidden_dim))
        padded_output, _ = rnn.pad_packed_sequence(output, total_length=txt_e.shape[0])
        return padded_output, txt_fw_bw_hn, last_top_hn, padded_topic_output


class ScaledDotProductAttention(torch.nn.Module):
    '''
    Dot-product attention of a single query vector over a sequence,
    with scores divided by sqrt(2 * input_dim) before the softmax.
    '''
    def __init__(self, input_dim, use_cuda=False):
        '''
        :param input_dim: used only to compute the scaling factor.
        :param use_cuda: accepted for interface compatibility; not used here.
        '''
        super().__init__()
        self.input_dim = input_dim
        self.scale = math.sqrt(2 * input_dim)

    def forward(self, inputs, query):
        '''
        :param inputs: sequence to attend over, (B, L, 2*H).
        :param query: one query vector per example, (B, 2*H).
        :return: attention-weighted sum of inputs, (B, 2*H).
        '''
        scores = torch.einsum('blh,bh->bl', inputs, query)  # (B, L)
        scores = scores / self.scale
        weights = scores.softmax(dim=1)  # normalise over sequence positions
        return torch.einsum('blh,bl->bh', inputs, weights)
    

class TransformationLayer(torch.nn.Module):
    '''
    Linear transformation layer: multiplies each input vector by a
    learned square matrix W (no bias).
    '''
    def __init__(self, input_size):
        '''
        :param input_size: the dimension D of the vectors to transform.
        '''
        super(TransformationLayer, self).__init__()

        self.dim = input_size

        # (D, D) weight matrix, Xavier-normal initialised
        self.W = nn.Parameter(nn.init.xavier_normal_(torch.empty((self.dim, self.dim))))

    def forward(self, text):
        '''
        :param text: the vectors to transform, (B, D).
        :return: text @ W, (B, D).
        '''
        # The previous subscripts 'bd,dd->bd' repeated 'd' within W, which
        # makes einsum take W's diagonal, so the layer only rescaled each
        # feature. Distinct subscripts give the full matrix product.
        return torch.einsum('bd,de->be', text, self.W)


class ReconstructionLayer(torch.nn.Module):
    '''
    Embedding reconstruction layer: maps each encoder output vector back
    to embedding space with tanh(x W + b) and zeroes masked positions.
    '''
    def __init__(self, hidden_dim, embed_dim, use_cuda=False):
        '''
        :param hidden_dim: encoder hidden size H; inputs have 2 * H features.
        :param embed_dim: size E of the reconstructed embeddings.
        :param use_cuda: whether to create the parameters on the GPU.
        '''
        super(ReconstructionLayer, self).__init__()

        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.use_cuda = use_cuda

        device = 'cuda' if self.use_cuda else 'cpu'

        weight = torch.empty((2 * self.hidden_dim, self.embed_dim), device=device)
        self.recon_w = nn.Parameter(nn.init.xavier_normal_(weight))  # (2H, E)

        # xavier_normal_ needs a 2-D tensor, so the bias is initialised as
        # (E, 1) and squeezed BEFORE being wrapped. Squeezing the Parameter
        # itself returns a plain tensor that the module does not register,
        # which left the bias out of parameters()/state_dict() and untrained.
        # NOTE: state_dict() now has a 'recon_b' entry, so checkpoints saved
        # before this fix need strict=False to load.
        bias = torch.empty((self.embed_dim, 1), device=device)
        self.recon_b = nn.Parameter(nn.init.xavier_normal_(bias).squeeze(1))  # (E,)

        self.tanh = nn.Tanh()

    def forward(self, text_output, text_mask):
        '''
        :param text_output: encoder outputs, (B, L, 2*H).
        :param text_mask: per-position multiplier, (B, L); 0 blanks a position.
        :return: the reconstructed embeddings, (B, L, E).
        '''
        recon_embeds = self.tanh(torch.einsum('blh,he->ble', text_output, self.recon_w) + self.recon_b)  # (B,L,E)
        recon_embeds = torch.einsum('ble,bl->ble', recon_embeds, text_mask)

        return recon_embeds

In [17]:
import torch, sys, os, argparse, time
sys.path.append('./modeling')
import baseline_models as bm
import data_utils, model_utils, datasets
import loss_fn as lf
import torch.optim as optim
import torch.nn as nn
from itertools import chain
import pandas as pd
import copy
from transformers import get_linear_schedule_with_warmup

SEED = 0
LOCAL = True
use_cuda = torch.cuda.is_available()


def train(model_handler, num_epochs, verbose=True, dev_data=None, num_warm=0, phases=False, is_adv=True):
    '''
    Trains the given model for the specified number of epochs. Once
    num_warm epochs have passed, and only when verbose is set, each epoch
    prints the training loss, evaluates on the training (& dev) data and
    asks the handler to keep the best model. Always saves a final
    checkpoint and prints the final training (& dev) scores.
    NOTE: with verbose=False no per-epoch evaluation happens, so
    save_best() is never called.
    :param model_handler: a holder with a model and data to be trained.
                            Assuming the model is a pytorch model.
    :param num_epochs: the number of epochs to train the model for.
    :param verbose: whether or not to print train results while training.
                    Default (True): do print intermediate results.
    :param dev_data: data to evaluate on after each epoch. When given, the
                     best model is chosen on dev scores, otherwise on
                     training scores.
    :param num_warm: the number of initial epochs to skip before any
                     intermediate evaluation.
    :param phases: whether to call train_step_phases() instead of
                   train_step() each epoch.
    :param is_adv: whether the handler is adversarial; if so the learning
                   rate and the loss function's adv_param are recorded
                   with the per-epoch training scores.
    '''
    trn_scores_dict = {}
    dev_scores_dict = {}
    for epoch in range(num_epochs):
        if is_adv:
            # read before the step so the recorded value is the one used in this epoch
            learning_rate = model_handler.get_learning_rate()
        if phases:
            model_handler.train_step_phases()
        else:
            model_handler.train_step()

        if epoch >= num_warm:
            if verbose:
                # print training loss and training (& dev) scores, ignores the first few epochs
                print("training loss: {}".format(model_handler.loss))
                # eval model on training data
                trn_scores = eval_helper(model_handler, data_name='TRAIN')
                trn_scores_dict[epoch] = copy.deepcopy(trn_scores)
                if is_adv:
                    trn_scores_dict[epoch].update({'lr': copy.deepcopy(learning_rate),
                                                   'rho': copy.deepcopy(model_handler.loss_function.adv_param)})
                # update best scores
                if dev_data is not None:
                    dev_scores = eval_helper(model_handler, data_name='DEV',
                                             data=dev_data)
                    dev_scores_dict[epoch] = copy.deepcopy(dev_scores)
                    model_handler.save_best(scores=dev_scores)
                else:
                    model_handler.save_best(scores=trn_scores)

    # report the epoch count itself: the loop variable is the last index
    # (one too few) and is undefined when num_epochs is 0
    print("TRAINED for {} epochs".format(num_epochs))

    # save final checkpoint
    model_handler.save(num="FINAL")

    # print final training (& dev) scores
    eval_helper(model_handler, data_name='TRAIN')
    if dev_data is not None:
        eval_helper(model_handler,  data_name='DEV', data=dev_data)
    # Can uncomment to save epoch_level_scores
    #save_epoch_level_results_to_csv(trn_scores_dict, dev_scores_dict, model_handler.result_path, model_handler.name, is_adv)


def save_epoch_level_results_to_csv(trn_scores_dict, dev_scores_dict, output_path, name, is_adv):
    '''
    Writes one CSV row per recorded epoch with the dev and training
    F-scores (plus learning rate, rho and topic F-score for the
    adversarial model) to "<output_path><name>_epoch_level_scores.csv".
    :param trn_scores_dict: maps epoch -> training scores for that epoch
    :param dev_scores_dict: maps epoch -> dev scores; needs an entry for
                            every epoch in trn_scores_dict
    :param output_path: prefix (directory) the file name is appended to
    :param name: the prefix for the file name
    :param is_adv: whether or not the scores are from the adversarial model
    '''
    fscore_keys = ('f_macro', 'f-1_macro', 'f-0_macro')  # overall, seen, unseen

    rows = []
    for epoch, trn in trn_scores_dict.items():
        adv_head, adv_topic = (), ()
        if is_adv:
            adv_head = (trn['lr'], trn['rho'])
            adv_topic = (trn['topic-f_macro'],)
        dev = dev_scores_dict[epoch]
        dev_part = tuple(dev[k] for k in fscore_keys)
        trn_part = tuple(trn[k] for k in fscore_keys)
        rows.append((epoch,) + adv_head + dev_part + adv_topic + trn_part)

    dev_columns = ['Dev Fscore overall', 'Dev Fscore seen', 'Dev Fscore unseen']
    trn_columns = ['Train Fscore overall', 'Train Fscore seen', 'Train Fscore unseen']
    if is_adv:
        columns = ['Epoch', 'Learning Rate', 'Rho'] + dev_columns + ['Topic Fscore'] + trn_columns
    else:
        columns = ['Epoch'] + dev_columns + trn_columns

    scores_frame = pd.DataFrame(rows, columns=columns)
    scores_frame.to_csv("{}{}_epoch_level_scores.csv".format(output_path, name), index=False)


def eval_helper(model_handler, data_name, data=None):
    '''
    Thin wrapper used during training: asks the handler to evaluate and
    print class-wise scores.
    :param model_handler: the holder for the model
    :param data_name: label for the data being evaluated (e.g. 'TRAIN', 'DEV')
    :param data: the data to evaluate on; passed through as-is (None by default)
    :return: whatever model_handler.eval_and_print returns
    '''
    return model_handler.eval_and_print(data=data, data_name=data_name, class_wise=True)


if __name__ == '__main__':
    # Script entry point: parse the CLI, read the "key:value" config file,
    # load vectors and data, build the model + handler selected by
    # config['name'], then train or evaluate depending on --mode.

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', dest='mode', help='What to do', required=True)
    parser.add_argument('--config_file', dest='config_file', help='Name of the cofig data file', required=False)
    parser.add_argument('--trn_data', dest='trn_data', help='Name of the training data file', required=False)
    parser.add_argument('--dev_data', dest='dev_data', help='Name of the dev data file', default=None, required=False)
    parser.add_argument('--name', dest='name', help='something to add to the saved model name',
                        required=False, default='')
    parser.add_argument('-p', '--num_warm', help='Number of warm-up epochs', required=False,
                        type=int, default=0)
    parser.add_argument('--topics_vocab', dest='topics_vocab', help='Name of the topic file', required=False,
                        type=str, default='twitter-topic.vocab.pkl')
    parser.add_argument('--score_key', dest='score_key', help='Score key for optimization', required=False,
                        default='f_macro')
    parser.add_argument('--saved_model_file_name', dest='saved_model_file_name', required=False, default=None)
    args = parser.parse_args()

    # the option is declared optional, but every code path below reads the
    # config, so fail with a usage message instead of an error from open(None)
    if args.config_file is None:
        parser.error('--config_file is required')

    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True

    ####################
    # load config file #
    ####################
    config = dict()
    with open(args.config_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at end of file)
                continue
            # split on the FIRST ':' only so values that contain ':' themselves
            # (such as paths) are kept whole
            key, sep, value = line.partition(':')
            if not sep:
                print("ERROR with config line: {}".format(line))
                sys.exit(1)
            config[key] = value

    ################
    # load vectors #
    ################
    if not LOCAL:
        vec_path = 'resources'
    else:
        vec_path = 'data/resources'     #Need to set path to vectors here

    if 'bert' not in config['name']:

        vec_name = config['vec_name']
        vec_dim = int(config['vec_dim'])

        vecs = data_utils.load_vectors('{}/{}.vectorsF.npy'.format(vec_path, vec_name),
                                       dim=vec_dim, seed=SEED)

    #############
    # LOAD DATA #
    #############
    # NOTE(review): vec_name and vecs are only defined for non-BERT configs,
    # yet they are used unconditionally below, so a config whose name contains
    # 'bert' fails here with NameError. The intended data setup for BERT is
    # not visible in this file -- confirm and fix.
    # load training data
    vocab_name = '{}/{}.vocabF.pkl'.format(vec_path, vec_name)
    data = datasets.StanceData(args.trn_data, vocab_name, topic_name='{}/{}'.format(vec_path, args.topics_vocab),
                           pad_val=len(vecs) - 1,
                           max_tok_len=int(config.get('max_tok_len', '200')),
                           max_sen_len=int(config.get('max_sen_len', '10')),
                           max_top_len=int(config.get('max_top_len', '5')))

    dataloader = data_utils.DataSampler(data,  batch_size=int(config['b']))

    # load dev data if specified
    if args.dev_data is not None:
        dev_data = datasets.StanceData(args.dev_data, vocab_name, topic_name=None,
                                       pad_val=len(vecs) - 1,
                                       max_tok_len=int(config.get('max_tok_len', '200')),
                                       max_sen_len=int(config.get('max_sen_len', '10')),
                                       max_top_len=int(config.get('max_top_len', '5')),
                                       use_tar_in_twe=('use_tar_in_twe' in config))

        dev_dataloader = data_utils.DataSampler(dev_data, batch_size=int(config['b']), shuffle=False)

    else:
        dev_dataloader = None

    # set the optimizer
    if 'optimizer' not in config:
        optim_fn = optim.Adam
    else:
        if config['optimizer'] == 'adamw':
            optim_fn = optim.AdamW
        elif config['optimizer'] == 'sgd':
            optim_fn = optim.SGD
        else:
            print("ERROR with optimizer")
            sys.exit(1)

    lr = float(config.get('lr', '0.001'))
    nl = 3
    # NOTE(review): adv is never set to True, not even for the BasicAdv model,
    # so train() never records lr/rho with the epoch scores -- confirm intended.
    adv = False

    # RUN
    print("Using cuda?: {}".format(use_cuda))

    if 'bert' in config['name']:
        batch_args = {'keep_sen': False}
        setup_fn = data_utils.setup_helper_bert_ffnn
        loss_fn = nn.CrossEntropyLoss()

        input_layer = None
        model = bm.JointSeqBERTLayer(nl, use_cuda=use_cuda)

        optimizer = optim.AdamW(model.parameters(), lr=lr)

        num_training_steps = len(data) * int(config['epochs'])
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=0.1 * num_training_steps,
                                                    num_training_steps=num_training_steps)

        kwargs = {'model': model, 'embed_model': input_layer, 'dataloader': dataloader,
                  'batching_fn': data_utils.prepare_batch,
                  'batching_kwargs': batch_args, 'name': config['name'],
                  'loss_function': loss_fn,
                  'optimizer': optimizer,
                  'scheduler': scheduler,
                  'setup_fn': setup_fn,
                  'fine_tune': (config.get('fine-tune', 'no') == 'yes')}

        # save_ckp=args.save_ckp used to be passed here, but the parser defines
        # no --save_ckp option, so that raised AttributeError. The handler is
        # now built without it, as in the BiCond branch below.
        model_handler = model_utils.TorchModelHandler(use_cuda=use_cuda,
                                                      checkpoint_path=config.get('ckp_path', 'data/checkpoints/'),
                                                      result_path=config.get('res_path', 'data/gen-stance/'),
                                                      use_score=args.score_key,
                                                      **kwargs)

    elif 'BiCond' in config['name']:
        batch_args = {}
        input_layer = bm.WordEmbedLayer(vecs=vecs, use_cuda=use_cuda)

        setup_fn = data_utils.setup_helper_bicond

        loss_fn = nn.CrossEntropyLoss()

        # NOTE(review): 'in_dim' is looked up in config['name'] (the model name
        # string), not in config itself -- presumably `'in_dim' in config` was
        # meant. Left unchanged because it would alter which input_dim existing
        # configs get; confirm.
        model = bm.BiCondLSTMModel(int(config['h']), embed_dim=input_layer.dim,
                                   input_dim=(int(config['in_dim']) if 'in_dim' in config['name'] else input_layer.dim),
                                   drop_prob=float(config['dropout']), use_cuda=use_cuda,
                                   num_labels=nl)
        o = optim_fn(model.parameters(), lr=lr)

        bf = data_utils.prepare_batch

        kwargs = {'model': model, 'embed_model': input_layer, 'dataloader': dataloader,
                  'batching_fn': bf,
                  'batching_kwargs': batch_args, 'name': config['name'] + args.name,
                  'loss_function': loss_fn,
                  'optimizer': o,
                  'setup_fn': setup_fn,
                  'blackout_start': int(config['blackout_start']),
                  'blackout_stop': int(config['blackout_stop'])}

        model_handler = model_utils.TorchModelHandler(use_cuda=use_cuda,
                                                      checkpoint_path=config.get('ckp_path', 'data/checkpoints/'),
                                                      result_path=config.get('res_path','data/gen-stance/'),
                                                      **kwargs)

    elif 'BasicAdv' in config['name']:
        batch_args = {}
        input_layer = bm.WordEmbedLayer(vecs=vecs, use_cuda=use_cuda)
        setup_fn = data_utils.setup_helper_adv

        loss_fn = lf.AdvBasicLoss(trans_dim=2*int(config['h']), trans_param=float(config['trans_w']),
                                  num_no_adv=float(config['num_na']),
                                  tot_epochs=int(config['epochs']),
                                  rho_adv=('rho_adv' in config),
                                  gamma=float(config.get('gamma', 10.0)),
                                  semi_sup=('semi_sup' in config),
                                  use_cuda=use_cuda)

        enc_params = {'h': int(config['h']), 'embed_dim': input_layer.dim, 'drop_prob' : float(config['dropout'])}

        model = bm.AdversarialBasic(enc_params=enc_params, enc_type=config['enc'],
                                    stance_dim=int(config['sd']), topic_dim=int(config['td']),
                                    num_labels=nl, num_topics=int(config['num_top']),
                                    drop_prob=float(config['dropout']),
                                    use_cuda=use_cuda)

        # Everything except the topic classifier is trained by the main
        # optimizer; the topic classifier gets its own (adversary) optimizer.
        main_params = chain(model.enc.parameters(),
                            model.recon_layer.parameters(),
                            model.topic_recon_layer.parameters(),
                            model.trans_layer.parameters(),
                            model.stance_classifier.parameters())
        # SGD runs with Nesterov momentum; Adam/AdamW take no extra options.
        # Building both optimizers from optim_fn also covers 'adamw', which
        # previously matched neither branch and left o_main/o_adv undefined.
        sgd_kwargs = {'momentum': 0.9, 'nesterov': True} if optim_fn is optim.SGD else {}
        o_main = optim_fn(main_params,
                          lr=lr,
                          weight_decay=float(config.get('l2_main', '0')),
                          **sgd_kwargs)
        o_adv = optim_fn(model.topic_classifier.parameters(),
                         lr=lr,
                         weight_decay=float(config.get('l2_adv', '0')),
                         **sgd_kwargs)

        kwargs = {'model': model, 'embed_model': input_layer, 'dataloader': dataloader,
                  'batching_fn': data_utils.prepare_batch_adv,
                  'batching_kwargs': batch_args, 'name': config['name'] + '-{}'.format(config['enc']) + args.name,
                  'loss_function': loss_fn,
                  'optimizer': o_main,
                  'adv_optimizer': o_adv,
                  'setup_fn': setup_fn,
                  'tot_epochs': int(config['epochs']),
                  'initial_lr': lr,
                  'alpha': float(config.get('alpha', 10.0)),
                  'beta': float(config.get('beta', 0.75)),
                  'num_constant_lr': float(config['num_constant_lr']),
                  'batch_size': int(config['b']),
                  'blackout_start': int(config['blackout_start']),
                  'blackout_stop': int(config['blackout_stop'])}

        model_handler = model_utils.AdvTorchModelHandler(use_score=args.score_key, use_cuda=use_cuda,
                                                         checkpoint_path=config.get('ckp_path', 'data/checkpoints/'),
                                                         result_path=config.get('res_path', 'data/gen-stance/'),
                                                         opt_for=config.get('opt', 'score_key'),
                                                         **kwargs)

    else:
        # without this, an unrecognised name only surfaced later as a
        # NameError on model_handler
        print("ERROR with model name: {}".format(config['name']))
        sys.exit(1)

    if args.mode == 'train':
        # Train model
        start_time = time.time()
        train(model_handler, int(config['epochs']), dev_data=dev_dataloader,
             num_warm=args.num_warm, phases=('phases' in config),is_adv=adv)
        print("[{}] total runtime: {:.2f} minutes".format(config['name'], (time.time() - start_time)/60.))

    elif args.mode == 'eval':
        # Evaluate saved model
        model_handler.load(filename=args.saved_model_file_name)
        eval_helper(model_handler,'DEV',data=dev_dataloader)

usage: ipykernel_launcher.py [-h] --mode MODE [--config_file CONFIG_FILE]
                             [--trn_data TRN_DATA] [--dev_data DEV_DATA]
                             [--name NAME] [-p NUM_WARM]
                             [--topics_vocab TOPICS_VOCAB]
                             [--score_key SCORE_KEY]
                             [--saved_model_file_name SAVED_MODEL_FILE_NAME]
ipykernel_launcher.py: error: the following arguments are required: --mode


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
