In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# BERT


# Arguments and Hyperparameters

In [2]:
import torch

data_folder = "/content/drive/My Drive/NLP"

# Set the parameters
models_save_dir = data_folder + "/models_save"
snapshot_dir = data_folder + "/snapshot/results.txt"
is_cuda = torch.cuda.is_available()

args = {'train':True, 
        'test':False, 
        'cuda' : is_cuda,
        'dataset' : 'news_group',
        'class_balance':False, # False default,
        'debug_mode' : False,
        'init_lr':0.001, 
        'epochs':4, # default is 50
        'batch_size':64,
        'patience':5, 
        'models_save_dir': models_save_dir, 
        'model_path':'model.pt', 
        'results_path' : snapshot_dir, # where to dump model config and epoch stats
        'model':'TextCNN',
        'embedding_dim' : 768,
        'hidden_dims':100, 
        'num_layers':1, 
        'dropout':0.05, 
        'weight_decay':5e-06,
        'filter_num':100, 
        'filters':[3, 4, 5], 
        'num_class':20, 
        'emb_dims':196,
        'tuning_metric':'loss', 
        'num_workers':4, 
        'objective':'cross_entropy', 
        'get_rationales':True, 
        'continuity_lambda': 0.0, 
        'selection_lambda': 0.001,
        # gumbel 
        'gumbel_decay' : 1e-5,
        'gumbel_temprature' : 1.0,  # Start temprature for gumbel softmax
        'model_form':'cnn',
        'use_as_tagger':False, 
        'hidden_dim' : 100,
        'nb_texts': 1000  # Put a number >= 9051 for using the whole News_Group database 
        }

print("cuda on", is_cuda)
#'gumbel_temprature':1, 'gumbel_decay':1e-5,'tag_lambda':.5

cuda on True


In [3]:
!pip install transformers
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').cuda()

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d8/b2/57495b5309f09fa501866e225c84532d1fd89536ea62406b2181933fb418/transformers-4.5.1-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 19.1MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 49.9MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 50.8MB/s 
Installing collected packages: sacremoses, tokenizers, transformers
Successfully installed sacremoses-0.0.45 tokenizers-0.10.2 transformers-4.5.1


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




# Load data set

In [4]:
# import abstract
from abc import ABCMeta, abstractmethod, abstractproperty
import torch.utils.data as data
# import News_Group
import gzip
import re
import tqdm
from sklearn.datasets import fetch_20newsgroups
import random
random.seed(0)



## Utils 

In [5]:
# util abstract data set

TRAIN_ONLY_ERR_MSG = "{} only supported for train dataset! Instead saw {}"

class AbstractDataset(data.Dataset):
    __metaclass__ = ABCMeta

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self,index):
        sample = self.dataset[index]
        return sample

# util get_indices_tensor

import torch

def get_indices_tensor(text_arr, word_to_indx, max_length):
    '''
    -text_arr: array of word tokens
    -word_to_indx: mapping of word -> index
    -max length of return tokens

    returns tensor of same size as text with each words corresponding
    index
    '''
    nil_indx = 0
    text_indx = [ word_to_indx[x] if x in word_to_indx else nil_indx for x in text_arr][:max_length]
    if len(text_indx) < max_length:
        text_indx.extend( [nil_indx for _ in range(max_length - len(text_indx))])

    x =  torch.LongTensor([text_indx])

    return x

# util factory

NO_DATASET_ERR = "Dataset {} not in DATASET_REGISTRY! Available datasets are {}"

DATASET_REGISTRY = {}


def RegisterDataset(dataset_name):
    """Registers a dataset."""

    def decorator(f):
        DATASET_REGISTRY[dataset_name] = f
        return f

    return decorator


# Depending on arg, build dataset
def get_dataset(args, tokenizer, bert_model, truncate_train=False):
    if args["dataset"] not in DATASET_REGISTRY:
        raise Exception(
            NO_DATASET_ERR.format(args["dataset"], DATASET_REGISTRY.keys()))

    if args["dataset"] in DATASET_REGISTRY:
        train = DATASET_REGISTRY[args["dataset"]](args, tokenizer, bert_model, 'train')
        dev = DATASET_REGISTRY[args["dataset"]](args, tokenizer, bert_model, 'dev')
        test = DATASET_REGISTRY[args["dataset"]](args, tokenizer, bert_model, 'test')

    return train, dev, test

## Define dataset loader

In [6]:
SMALL_TRAIN_SIZE = 800
CATEGORIES = ['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

def preprocess_data(data):
    processed_data = []
    for indx, sample in enumerate(data['data']):
        text, label = sample, data['target'][indx]
        label_name = data['target_names'][label]
        text = re.sub('\W+', ' ', text).lower().strip().replace('_',' ')
        processed_data.append( (text, label, label_name) )
    return processed_data


@RegisterDataset('news_group')
class NewsGroupDataset(AbstractDataset): # MODIF

    def __init__(self, args, tokenizer, bert_model, name, max_length=80):
        self.args = args
        self.args["num_class"] = 20
        self.name = name
        self.dataset = []
        self.max_length = max_length
        self.class_balance = {}
        self.tokenizer = tokenizer
        self.bert_model = bert_model

        if name in ['train', 'dev']:
            data = preprocess_data(fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'), categories=CATEGORIES))
            random.shuffle(data)
            num_train = int(len(data)*.8)
            if name == 'train':
                data = data[:num_train]
            else:
                data = data[num_train:]
        else:
            data = preprocess_data(fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'), categories=CATEGORIES))
        
        for indx, _sample in tqdm.tqdm(enumerate(data)):
          if(indx in range(args['nb_texts'])):                     
            sample = self.processLine(_sample)

            if not sample['y'] in self.class_balance:
                self.class_balance[ sample['y'] ] = 0
            self.class_balance[ sample['y'] ] += 1
            self.dataset.append(sample)

        print ("Class balance", self.class_balance)

        if args["class_balance"]:
            raise NotImplementedError("NewsGroup dataset doesn't support balanced sampling")
        if args["objective"] == 'mse':
            raise NotImplementedError("News Group does not support Regression objective")

    ## Convert one line from beer dataset to {Text, Tensor, Labels}
    def processLine(self, row):
        text, label, label_name = row
        text = " ".join(text.split()[:self.max_length])
        sample = {'text':text, 'y':label, 'y_name': label_name}
        return sample

## Get dataset

In [7]:
# Loading the dataset
train_data, dev_data, test_data = get_dataset(args, tokenizer, bert_model)

# Printing 1 datapoints
for datapoint in train_data[:1]:
    print(datapoint)

Downloading 20news dataset. This may take a few minutes.
Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)
9051it [00:00, 410087.77it/s]


Class balance {14: 48, 4: 44, 8: 37, 13: 61, 19: 35, 12: 50, 3: 63, 10: 59, 0: 51, 11: 47, 1: 48, 7: 55, 15: 51, 2: 54, 18: 43, 5: 44, 17: 50, 9: 58, 16: 47, 6: 55}


2263it [00:00, 150083.17it/s]


Class balance {14: 55, 7: 45, 15: 55, 17: 48, 8: 39, 5: 58, 0: 39, 13: 51, 3: 70, 18: 48, 11: 56, 19: 29, 9: 47, 4: 55, 10: 58, 2: 60, 12: 51, 6: 45, 1: 56, 16: 35}


7532it [00:00, 383117.64it/s]

Class balance {7: 53, 5: 60, 0: 37, 17: 50, 19: 33, 13: 45, 15: 56, 1: 40, 2: 52, 8: 61, 4: 62, 6: 39, 16: 47, 14: 55, 3: 54, 11: 43, 9: 58, 10: 69, 18: 38, 12: 48}
{'text': 'thanks again one final question the name gehrels wasn t known to me before this thread came up but the may issue of scientific american has an article about the inconstant cosmos with a photo of neil gehrels project scientist for nasa s compton gamma ray observatory same person mark brader softquad inc toronto information we want information utzoo sq msb msb sq com the prisoner', 'y': 14, 'y_name': 'sci.space'}





# Model

In [8]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.nn.functional as F
import pdb

# imports for util
import numpy as np
import torch.utils.data as data

## Utils

In [9]:
# utils.learn

def get_train_loader(train_data, args):
    if args['class_balance']:
        sampler = data.sampler.WeightedRandomSampler(
                weights=train_data.weights,
                num_samples=len(train_data),
                replacement=True)
        train_loader = data.DataLoader(
                train_data,
                num_workers=args['num_workers'],
                sampler=sampler,
                batch_size=args['batch_size'])
    else:
        train_loader = data.DataLoader(
            train_data,
            batch_size=args['batch_size'],
            shuffle=True,
            num_workers=args['num_workers'],
            drop_last=False)

    return train_loader

def get_rationales(mask, text):
    if mask is None:
        return text
    masked_text = []
    for i, t in enumerate(text):
        sample_mask = list(mask.data[i])
        original_words = t.split()
        words = [ w if m  > .5 else "_" for w,m in zip(original_words, sample_mask) ]
        masked_sample = " ".join(words)
        masked_text.append(masked_sample)
    return masked_text


def get_dev_loader(dev_data, args):
    dev_loader = data.DataLoader(
        dev_data,
        batch_size=args['batch_size'],
        shuffle=False,
        num_workers=args['num_workers'],
        drop_last=False)
    return dev_loader

def get_optimizer(models, args):
    '''
        -models: List of models (such as Generator, classif, memory, etc)
        -args: experiment level config

        returns: torch optimizer over models
    '''
    params = []
    for model in models:
        params.extend([param for param in model.parameters() if param.requires_grad])
    return torch.optim.Adam(params, lr=args['lr'],  weight_decay=args['weight_decay'])


def get_x_indx(batch, args, eval_model):
    x_indx = autograd.Variable(batch['x'], volatile=eval_model)
    return x_indx


def get_hard_mask(z, return_ind=False):
    '''
        -z: torch Tensor where each element probablity of element
        being selected
        -args: experiment level config

        returns: A torch variable that is binary mask of z >= .5
    '''
    max_z, ind = torch.max(z, dim=-1)
    if return_ind:
        del z
        return ind
    masked = torch.ge(z, max_z.unsqueeze(-1)).float()
    del z
    return masked


def get_gen_path(model_path):
    '''
        -model_path: path of encoder model

        returns: path of generator
    '''
    return '{}.gen'.format(model_path)

def one_hot(label, num_class):
    vec = torch.zeros( (1, num_class) )
    vec[0][label] = 1
    return vec


def gumbel_softmax(input, temperature, cuda):
    noise = torch.rand(input.size())
    noise.add_(1e-9).log_().neg_()
    noise.add_(1e-9).log_().neg_()
    noise = autograd.Variable(noise)
    if cuda:
        noise = noise.cuda()
    x = (input + noise) / temperature
    x = F.softmax(x.view(-1,  x.size()[-1]), dim=-1)
    return x.view_as(input)


def bert_embeddings(text,tokenizer,bert_model,args):
    embeddings = torch.zeros(1,args['emb_dims'],args['embedding_dim'])
    tokens = torch.tensor([tokenizer.encode(text,add_special_tokens=True)]).cuda()
    source = bert_model(tokens)[0].cpu()
    embeddings[:,:source.shape[1],:] = source
    return(embeddings)

## Define CNN

In [10]:
class CNN(nn.Module):

    def __init__(self, args, max_pool_over_time=False):
        super(CNN, self).__init__()

        self.args = args
        self.layers = []
        for layer in range(args['num_layers']):
            convs = []
            for filt in args['filters']:
                in_channels =  args["embedding_dim"] if layer == 0 else args.filter_num * len( args["filters"])
                kernel_size = filt
                new_conv = nn.Conv1d(in_channels=in_channels, out_channels=args["filter_num"], kernel_size=kernel_size)
                self.add_module( 'layer_'+str(layer)+'_conv_'+str(filt), new_conv)
                convs.append(new_conv)

            self.layers.append(convs)

        self.max_pool = max_pool_over_time



    def _conv(self, x):
        layer_activ = x
        for layer in self.layers:
            next_activ = []
            for conv in layer:
                left_pad = conv.kernel_size[0] - 1
                pad_tensor_size = [d for d in layer_activ.size()]
                pad_tensor_size[2] = left_pad
                left_pad_tensor =autograd.Variable( torch.zeros( pad_tensor_size ) )
                if self.args['cuda']:
                    left_pad_tensor = left_pad_tensor.cuda()
                padded_activ = torch.cat( (left_pad_tensor, layer_activ), dim=2)
                next_activ.append( conv(padded_activ) )

            # concat across channels
            layer_activ = F.relu( torch.cat(next_activ, 1) )

        return layer_activ


    def _pool(self, relu):
        pool = F.max_pool1d(relu, relu.size(2)).squeeze(-1)
        return pool


    def forward(self, x):
        activ = self._conv(x)
        if self.max_pool:
            activ =  self._pool(activ)
        return activ

## Define generator


In [11]:
'''
    The generator selects a rationale z from a document x that should be sufficient
    for the encoder to make it's prediction.

    Several forms of Generator are supported. Namely CNN with arbitary number of layers, and @taolei's FastKNN
'''
class Generator(nn.Module):

    def __init__(self, args, tokenizer, bert_model):
        super(Generator, self).__init__()
        self.args = args
        if args['model_form'] == 'cnn':
            self.cnn = CNN(args, max_pool_over_time = False)

        self.z_dim = 2

        self.hidden = nn.Linear((len(args["filters"])* args["filter_num"]), self.z_dim)
        self.dropout = nn.Dropout(args["dropout"])


    def  __z_forward(self, activ):
        '''
            Returns prob of each token being selected
        '''
        activ = activ.transpose(1,2)
        logits = self.hidden(activ)
        probs = gumbel_softmax(logits, self.args["gumbel_temprature"], self.args["cuda"])
        z = probs[:,:,1]
        return z


    def forward(self, text):
        '''
            Given input x_indx of dim (batch, length), return z (batch, length) such that z
            can act as element-wise mask on x
        '''
        if self.args["model_form"] == 'cnn':
            x_list = []
            for t in text:
                x = bert_embeddings(t,tokenizer,bert_model,self.args)
                if self.args["cuda"]:
                    x = x.cuda()
                x = torch.transpose(x, 1, 2) # Switch X to (Batch, Embed, Length)
                x_list.append(x)
            activ = self.cnn(torch.cat(x_list))
        else:
            raise NotImplementedError("Model form {} not yet supported for generator!".format(args["model_form"]))

        z = self.__z_forward(F.relu(activ))
        mask = self.sample(z)
        return mask, z


    def sample(self, z):
        '''
            Get mask from probablites at each token. Use gumbel
            softmax at train time, hard mask at test time
        '''
        mask = z
        if self.training:
            mask = z
        else:
            ## pointwise set <.5 to 0 >=.5 to 1
            mask = get_hard_mask(z)
        return mask


    def loss(self, mask):
        '''
            Compute the generator specific costs, i.e selection cost, continuity cost, and global vocab cost
        '''
        selection_cost = torch.mean( torch.sum(mask, dim=1) )
        l_padded_mask =  torch.cat( [mask[:,0].unsqueeze(1), mask] , dim=1)
        r_padded_mask =  torch.cat( [mask, mask[:,-1].unsqueeze(1)] , dim=1)
        continuity_cost = torch.mean( torch.sum( torch.abs( l_padded_mask - r_padded_mask ) , dim=1) )
        return selection_cost, continuity_cost

## Define encodor

In [12]:
class Encoder(nn.Module):

    def __init__(self, args, tokenizer, bert_model):
        super(Encoder, self).__init__()
        ### Encoder
        self.args = args
        hidden_dim = args["embedding_dim"]
        self.hidden_dim = hidden_dim
        self.embedding_fc = nn.Linear( hidden_dim, hidden_dim ) 
        self.embedding_bn = nn.BatchNorm1d( hidden_dim)

        if args["model_form"] == 'cnn':
            self.cnn = CNN(args, max_pool_over_time=(not args["use_as_tagger"]))
            self.fc = nn.Linear( len(args["filters"])*args["filter_num"],  args["hidden_dim"])
        else:
            raise NotImplementedError("Model form {} not yet supported for encoder!".format(args["model_form"]))

        self.dropout = nn.Dropout(args["dropout"])
        self.hidden = nn.Linear(args["hidden_dim"], args["num_class"])

    def forward(self, text, mask=None):
        '''
            x_indx:  batch of word indices
            mask: Mask to apply over embeddings for tao ratioanles
        '''

        x_list = []
        for t in text:
            x = bert_embeddings(t,tokenizer,bert_model,self.args)
            x_list.append(x)
        x = torch.cat(x_list)
        
        if self.args["cuda"]:
                x = x.cuda()
        if not mask is None:
            x = x * mask.unsqueeze(-1)
        
        x = F.relu( self.embedding_fc(x))
        x = self.dropout(x)

        
        if self.args["model_form"] == 'cnn':
            x = torch.transpose(x, 1, 2) # Switch X to (Batch, Embed, Length)
            hidden = self.cnn(x)
            hidden = F.relu( self.fc(hidden) )
        else:
            raise Exception("Model form {} not yet supported for encoder!".format(args["model_form"]))

        hidden = self.dropout(hidden)
        logit = self.hidden(hidden)
        return logit, hidden


## Get joint model

In [13]:
gen   = Generator(args,tokenizer,bert_model)
model = Encoder(args,tokenizer,bert_model)

# Train

In [14]:
import os
import sys
import torch
import torch.autograd as autograd
import torch.nn.functional as F

import tqdm
import numpy as np
import pdb
import sklearn.metrics

## Utils

In [15]:
# utils learn already define above

# utils generic (with corrected bug)
def tensor_to_numpy(tensor):
    return tensor.cpu().detach().numpy()

# utils metrics
import sklearn.metrics

def collate_epoch_stat(stat_dict, epoch_details, mode, args):
    '''
        Update stat_dict with details from epoch_details and create
        log statement

        - stat_dict: a dictionary of statistics lists to update
        - epoch_details: list of statistics for a given epoch
        - mode: train, dev or test
        - args: model run configuration

        returns:
        -stat_dict: updated stat_dict with epoch details
        -log_statement: log statement sumarizing new epoch

    '''
    log_statement_details = ''
    for metric in epoch_details:
        loss = epoch_details[metric]
        stat_dict['{}_{}'.format(mode, metric)].append(loss)

        log_statement_details += ' -{}: {}'.format(metric, loss)

    log_statement = '\n {} - {}\n--'.format(
        args["objective"], log_statement_details )

    return stat_dict, log_statement

def get_metrics(preds, golds, args):
    metrics = {}

    if args["objective"]  in ['cross_entropy', 'margin']:
        metrics['accuracy'] = sklearn.metrics.accuracy_score(y_true=golds, y_pred=preds)
        metrics['confusion_matrix'] = sklearn.metrics.confusion_matrix(y_true=golds,y_pred=preds)
        metrics['precision'] = sklearn.metrics.precision_score(y_true=golds, y_pred=preds, average="weighted")
        metrics['recall'] = sklearn.metrics.recall_score(y_true=golds,y_pred=preds, average="weighted")
        metrics['f1'] = sklearn.metrics.f1_score(y_true=golds,y_pred=preds, average="weighted")

        metrics['mse'] = "NA"

    elif args["objective"] == 'mse':
        metrics['mse'] = sklearn.metrics.mean_squared_error(y_true=golds, y_pred=preds)
        metrics['confusion_matrix'] = "NA"
        metrics['accuracy'] = "NA"
        metrics['precision'] = "NA"
        metrics['recall'] = "NA"
        metrics['f1'] = 'NA'

    return metrics

def init_metrics_dictionary(modes):
    '''
    Create dictionary with empty array for each metric in each mode
    '''
    epoch_stats = {}
    metrics = [
        'loss', 'obj_loss', 'k_selection_loss',
        'k_continuity_loss', 'accuracy', 'precision', 'recall', 'f1', 'confusion_matrix', 'mse']
    for metric in metrics:
        for mode in modes:
            key = "{}_{}".format(mode, metric)
            epoch_stats[key] = []

    return epoch_stats

## Define train and co

In [16]:
def train_model(train_data, dev_data, model, gen, args):
    '''
    Train model and tune on dev set. If model doesn't improve dev performance within args.patience
    epochs, then halve the learning rate, restore the model to best and continue training.

    At the end of training, the function will restore the model to best dev version.

    returns epoch_stats: a dictionary of epoch level metrics for train and test
    returns model : best model from this call to train
    '''

    if args["cuda"]:
        model = model.cuda()
        gen = gen.cuda()

    args["lr"] = args["init_lr"]
    optimizer = get_optimizer([model, gen], args)

    num_epoch_sans_improvement = 0
    epoch_stats = init_metrics_dictionary(modes=['train', 'dev'])
    step = 0
    tuning_key = "dev_{}".format(args["tuning_metric"])
    best_epoch_func = min if tuning_key == 'loss' else max

    train_loader = get_train_loader(train_data, args)
    dev_loader = get_dev_loader(dev_data, args)


    for epoch in range(1, args["epochs"] + 1):

        print("-------------\nEpoch {}:\n".format(epoch))
        for mode, dataset, loader in [('Train', train_data, train_loader), ('Dev', dev_data, dev_loader)]:
            train_model = mode == 'Train'
            print('{}'.format(mode))
            key_prefix = mode.lower()
            epoch_details, step, _, _, _, _ = run_epoch(
                data_loader=loader,
                train_model=train_model,
                model=model,
                gen=gen,
                optimizer=optimizer,
                step=step,
                args=args)

            epoch_stats, log_statement = collate_epoch_stat(epoch_stats, epoch_details, key_prefix, args)

            # Log  performance
            print(log_statement)


        # Save model if beats best dev
        best_func = min if args["tuning_metric"] == 'loss' else max
        if best_func(epoch_stats[tuning_key]) == epoch_stats[tuning_key][-1]:
            num_epoch_sans_improvement = 0
            if not os.path.isdir(args["models_save_dir"]):
                os.makedirs(args["models_save_dir"])
            # Subtract one because epoch is 1-indexed and arr is 0-indexed
            epoch_stats['best_epoch'] = epoch - 1
            torch.save(model, args["models_save_dir"] + "/encod_model.pt")
            torch.save(gen, args["models_save_dir"] + "/gen_model.pt")
        else:
            num_epoch_sans_improvement += 1

        if not train_model:
            print('---- Best Dev {} is {:.4f} at epoch {}'.format(
                args["tuning_metric"],
                epoch_stats[tuning_key][epoch_stats['best_epoch']],
                epoch_stats['best_epoch'] + 1))

        if num_epoch_sans_improvement >= args["patience"]:
            print("Reducing learning rate")
            num_epoch_sans_improvement = 0
            model.cpu()
            gen.cpu()
            model = torch.load(args["models_save_dir"] + "/encod_model.pt")
            gen = torch.load(args["models_save_dir"] + "/gen_model.pt")

            if args["cuda"]:
                model = model.cuda()
                gen   = gen.cuda()
            args["lr"] *= .5
            optimizer = get_optimizer([model, gen], args)

    # Restore model to best dev performance
    if os.path.exists(args["model_path"]):
        model.cpu()
        model = torch.load(args["models_save_dir"] + "/encod_model.pt")
        gen.cpu()
        gen = torch.load(args["models_save_dir"] + "/gen_model.pt")

    return epoch_stats, model, gen


def run_epoch(data_loader, train_model, model, gen, optimizer, step, args):
    '''
    Train model for one pass of train data, and return loss, acccuracy
    '''
    eval_model = not train_model
    data_iter = data_loader.__iter__()

    losses = []
    obj_losses = []
    k_selection_losses = []
    k_continuity_losses = []
    preds = []
    golds = []
    losses = []
    texts = []
    rationales = []

    if train_model:
        model.train()
        gen.train()
    else:
        gen.eval()
        model.eval()

    num_batches_per_epoch = len(data_iter)
    if train_model:
        num_batches_per_epoch = min(len(data_iter), 10000)

    for _ in tqdm.tqdm(range(num_batches_per_epoch)):
        batch = data_iter.next()
        if train_model:
            step += 1
            if  step % 100 == 0 or args["debug_mode"]:
                args["gumbel_temprature"] = max( np.exp((step+1) *-1* args["gumbel_decay"]), .05)

        text = batch['text']
        y = autograd.Variable(batch['y'], volatile=eval_model)

        if args["cuda"]:
            y = y.cuda()

        if train_model:
            optimizer.zero_grad()

        if args["get_rationales"]:
            mask, z = gen(text)
        else:
            mask = None

        logit, _ = model(text, mask=mask)

        if args["use_as_tagger"]:
            logit = logit.view(-1, 2)
            y = y.view(-1)

        loss = get_loss(logit, y, args)
        obj_loss = loss

        if args["get_rationales"]:
            selection_cost, continuity_cost = gen.loss(mask)

            loss += args["selection_lambda"] * selection_cost
            loss += args["continuity_lambda"] * continuity_cost

        if train_model:
            loss.backward()
            optimizer.step()

        if args["get_rationales"]:
            k_selection_losses.append(tensor_to_numpy(selection_cost))
            k_continuity_losses.append(tensor_to_numpy(continuity_cost))

        obj_losses.append(tensor_to_numpy(obj_loss))
        losses.append(tensor_to_numpy(loss) )
        batch_softmax = F.softmax(logit, dim=-1).cpu()
        preds.extend(torch.max(batch_softmax, 1)[1].view(y.size()).data.numpy())

        texts.extend(text)
        rationales.extend(get_rationales(mask, text))

        if args["use_as_tagger"]:
            golds.extend(batch['y'].view(-1).numpy())
        else:
            golds.extend(batch['y'].numpy())



    epoch_metrics = get_metrics(preds, golds, args)

    epoch_stat = {
        'loss' : np.mean(losses),
        'obj_loss': np.mean(obj_losses)
    }

    for metric_k in epoch_metrics.keys():
        epoch_stat[metric_k] = epoch_metrics[metric_k]

    if args["get_rationales"]:
        epoch_stat['k_selection_loss'] = np.mean(k_selection_losses)
        epoch_stat['k_continuity_loss'] = np.mean(k_continuity_losses)

    return epoch_stat, step, losses, preds, golds, rationales


def get_loss(logit,y, args):
    if args["objective"] == 'cross_entropy':
        if args["use_as_tagger"]:
            loss = F.cross_entropy(logit, y, reduce=False)
            neg_loss = torch.sum(loss * (y == 0).float()) / torch.sum(y == 0).float()
            pos_loss = torch.sum(loss * (y == 1).float()) / torch.sum(y == 1).float()
            loss = args["tag_lambda"] * neg_loss + (1 - args["tag_lambda"]) * pos_loss
        else:
            loss = F.cross_entropy(logit, y)
    elif args["objective"] == 'mse':
        loss = F.mse_loss(logit, y.float())
    else:
        raise Exception(
            "Objective {} not supported!".format(args["objective"]))
    return loss

## Do training

In [17]:
import pickle

epoch_stats, model, gen = train_model(train_data, dev_data, model, gen, args)
args["epoch_stats"] = epoch_stats
save_path_train_results = data_folder + "/train_results"
print("Save train/dev results to", save_path_train_results)
# save
pickle.dump(args, open(save_path_train_results + "/train_result",'wb') ) 

-------------
Epoch 1:

Train


  cpuset_checked))
100%|██████████| 16/16 [01:06<00:00,  4.16s/it]


 cross_entropy -  -loss: 3.0565123558044434 -obj_loss: 3.0565123558044434 -accuracy: 0.067 -confusion_matrix: [[10  0  5  0  0  4  0  0  0  0  0  0  0 32  0  0  0  0  0  0]
 [ 3  0 15  0  1  5  0  0  0  0  0  0  0 24  0  0  0  0  0  0]
 [ 2  0 16  0  0  4  0  0  0  0  0  0  0 32  0  0  0  0  0  0]
 [ 6  0 13  0  0 13  0  0  0  0  0  0  0 31  0  0  0  0  0  0]
 [ 4  0 11  0  0  5  0  0  0  0  0  0  0 24  0  0  0  0  0  0]
 [ 3  0 15  0  0  3  0  0  0  0  0  0  0 23  0  0  0  0  0  0]
 [ 7  0  9  0  1  8  0  0  0  0  1  0  0 29  0  0  0  0  0  0]
 [ 6  0  9  0  1  4  0  0  0  0  0  0  0 35  0  0  0  0  0  0]
 [ 1  0  4  0  1  5  0  0  0  0  0  0  0 26  0  0  0  0  0  0]
 [ 2  0 10  0  1  3  0  0  0  1  0  0  0 41  0  0  0  0  0  0]
 [ 5  0 12  0  0 10  0  0  0  1  0  0  0 31  0  0  0  0  0  0]
 [ 3  0 11  0  0  4  0  0  0  0  0  0  0 29  0  0  0  0  0  0]
 [ 2  0  5  0  0  4  0  0  0  0  0  0  0 39  0  0  0  0  0  0]
 [ 7  0 12  0  1  4  0  0  0  0  0  0  0 37  0  0  0  0  0  0]
 [ 1  0


  _warn_prf(average, modifier, msg_start, len(result))
100%|██████████| 16/16 [00:31<00:00,  1.98s/it]



 cross_entropy -  -loss: 2.9924750328063965 -obj_loss: 2.9924750328063965 -accuracy: 0.051 -confusion_matrix: [[ 0  0  0  0  0  0  0  0  0  0  0  0  0 39  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 56  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 60  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 70  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 55  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 58  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  1  0  0  0 44  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 45  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 39  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 47  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 58  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 56  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 51  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0 51  0  0  0  0  0  0]
 [ 0  0

100%|██████████| 16/16 [01:05<00:00,  4.10s/it]


 cross_entropy -  -loss: 3.0049099922180176 -obj_loss: 3.0049099922180176 -accuracy: 0.081 -confusion_matrix: [[ 1  0  8  0  0  0  0  0  0  3  3  0  0 36  0  0  0  0  0  0]
 [ 0  0 30  0  0  0  2  0  0  1  0  0  0 15  0  0  0  0  0  0]
 [ 0  0 29  0  0  0  0  0  0  5  0  0  0 20  0  0  0  0  0  0]
 [ 0  0 36  0  0  0  2  0  0  4  2  0  0 19  0  0  0  0  0  0]
 [ 0  0 29  0  0  0  2  0  0  1  1  0  0 11  0  0  0  0  0  0]
 [ 0  0 35  0  0  0  1  0  0  2  2  0  0  4  0  0  0  0  0  0]
 [ 0  0 30  0  0  0  5  0  0  5  0  0  0 15  0  0  0  0  0  0]
 [ 0  0 10  0  0  0  0  0  0  6  7  0  0 32  0  0  0  0  0  0]
 [ 0  0  6  0  0  0  0  0  0  4  2  0  0 25  0  0  0  0  0  0]
 [ 0  0 13  0  0  0  2  0  0 15  2  0  0 26  0  0  0  0  0  0]
 [ 0  0 15  0  0  0  0  0  0 15  2  0  0 27  0  0  0  0  0  0]
 [ 0  0 19  0  0  0  2  0  0  5  5  0  0 16  0  0  0  0  0  0]
 [ 0  0 19  0  0  0  5  0  0  3  3  0  0 20  0  0  0  0  0  0]
 [ 1  0 12  0  0  0  2  0  0 13  4  0  0 29  0  0  0  0  0  0]
 [ 0  0


100%|██████████| 16/16 [00:31<00:00,  1.96s/it]



 cross_entropy -  -loss: 2.9463820457458496 -obj_loss: 2.9463820457458496 -accuracy: 0.07 -confusion_matrix: [[ 0  0  0  0  0  0  0  0  0  0 38  0  0  1  0  0  0  0  0  0]
 [ 0  0  7  0  0  0  9  0  0  1 39  0  0  0  0  0  0  0  0  0]
 [ 0  0  7  0  0  0  7  0  0  0 46  0  0  0  0  0  0  0  0  0]
 [ 0  0  1  0  0  0  6  0  0  1 62  0  0  0  0  0  0  0  0  0]
 [ 0  0  2  0  0  0  1  0  0  0 52  0  0  0  0  0  0  0  0  0]
 [ 0  0 11  0  0  0  3  0  0  0 44  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  6  0  0  2 37  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  1 44  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0 38  0  0  1  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  6 41  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  7 51  0  0  0  0  0  0  0  0  0]
 [ 0  0  2  0  0  1  2  0  0  2 48  0  0  1  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  3  0  0  0 48  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  1 50  0  0  0  0  0  0  0  0  0]
 [ 0  0 

100%|██████████| 16/16 [01:05<00:00,  4.10s/it]


 cross_entropy -  -loss: 2.664330244064331 -obj_loss: 2.664330244064331 -accuracy: 0.182 -confusion_matrix: [[ 0  0  1  3  0  0  2  0  0  4  0  0  0 29  0  3  1  8  0  0]
 [ 0  1 12  5  2 13  8  0  0  2  0  0  0  4  0  0  0  1  0  0]
 [ 0  0 14  7  0 14 13  0  0  4  0  0  0  2  0  0  0  0  0  0]
 [ 0  0 12 21  4  7 13  0  0  4  0  0  0  2  0  0  0  0  0  0]
 [ 0  0  3 19  0  4 10  0  0  3  0  0  0  3  0  0  0  2  0  0]
 [ 0  1 17  4  1 18  2  0  0  0  0  0  0  1  0  0  0  0  0  0]
 [ 0  0  2  9  2  0 31  0  0  6  0  0  0  5  0  0  0  0  0  0]
 [ 0  0  0  7  0  0  3  3  0 11  0  0  0 28  0  0  0  3  0  0]
 [ 0  0  0  3  1  1  3  1  0  9  0  0  0 16  0  0  0  3  0  0]
 [ 0  0  0  1  0  0  3  0  0 43  0  0  0  9  0  1  0  1  0  0]
 [ 0  0  1  0  0  0  0  0  0 49  0  0  0  9  0  0  0  0  0  0]
 [ 0  0  0 10  2  7  3  0  0  2  0  0  0 14  0  1  0  8  0  0]
 [ 0  1  2 15  2  0  6  1  0  8  0  0  0 15  0  0  0  0  0  0]
 [ 0  0  1  5  1  0  2  1  0  9  0  0  0 37  0  1  0  4  0  0]
 [ 0  0  


100%|██████████| 16/16 [00:31<00:00,  1.95s/it]



 cross_entropy -  -loss: 2.4353792667388916 -obj_loss: 2.4353792667388916 -accuracy: 0.204 -confusion_matrix: [[10  0  0  0  3  0  0  0  0  2  3  6  0  6  0  9  0  0  0  0]
 [ 5  0  0  0  2 44  4  0  0  0  0  1  0  0  0  0  0  0  0  0]
 [ 3  2  0  0  4 51  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 2  0  0  0 18 47  2  0  0  0  0  1  0  0  0  0  0  0  0  0]
 [ 5  1  0  0 13 31  4  0  0  0  1  0  0  0  0  0  0  0  0  0]
 [ 4  0  0  0  0 53  1  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 3  0  0  0  3  6 32  0  0  1  0  0  0  0  0  0  0  0  0  0]
 [ 6  1  0  0  8  0 16  1  0  3  1  0  0  9  0  0  0  0  0  0]
 [ 3  0  0  0 14  0  6  1  0  4  0  0  0 10  0  1  0  0  0  0]
 [ 4  0  0  0  1  0  1  0  1 38  1  1  0  0  0  0  0  0  0  0]
 [ 4  0  0  0  1  2  1  0  0 46  3  0  0  1  0  0  0  0  0  0]
 [ 5  0  0  0 12 19  1  0  0  2  1  6  0  8  0  2  0  0  0  0]
 [ 4  0  0  0 27 12  6  0  0  2  0  0  0  0  0  0  0  0  0  0]
 [ 2  0  0  0 13  1  0  0  0  6  0  1  0 26  0  2  0  0  0  0]
 [ 3  1

100%|██████████| 16/16 [01:05<00:00,  4.09s/it]


 cross_entropy -  -loss: 2.2426042556762695 -obj_loss: 2.2426042556762695 -accuracy: 0.282 -confusion_matrix: [[ 2  0  1  0  0  0  0  5  0  2  3  2  0  5  0 17  3  9  2  0]
 [ 0  8  9  4  1 13  6  1  0  0  0  0  1  3  1  1  0  0  0  0]
 [ 0  8 12 10  1 11  3  3  0  0  1  2  1  0  1  0  0  0  1  0]
 [ 0  5  9 19 10  4 10  3  1  0  0  1  0  1  0  0  0  0  0  0]
 [ 0  1  7 12  8  2  6  4  0  0  0  2  0  1  0  0  0  1  0  0]
 [ 0  4 14  1  1 20  0  0  0  0  0  1  0  2  0  1  0  0  0  0]
 [ 0  0  1  2  3  0 38  6  0  1  1  1  0  1  0  1  0  0  0  0]
 [ 2  0  0  2  2  0  7 25  0  1  1  0  0 12  0  1  2  0  0  0]
 [ 3  0  0  0  0  0  2 11  1  2  1  1  0  8  1  4  0  2  1  0]
 [ 0  0  0  0  0  0  0  4  0 32 15  1  2  2  0  1  0  1  0  0]
 [ 0  0  0  0  0  0  0  1  0 23 33  0  0  0  0  1  0  1  0  0]
 [ 3  3  3  4  4  4  0  4  0  0  1  8  1  1  0  5  0  4  2  0]
 [ 0  1  3  5  3  1  6 15  0  4  1  2  3  4  2  0  0  0  0  0]
 [ 4  0  0  1  1  1  0  4  0  0  0  4  0 30  0  7  4  3  2  0]
 [ 4  0


100%|██████████| 16/16 [00:31<00:00,  1.96s/it]



 cross_entropy -  -loss: 2.118983507156372 -obj_loss: 2.118983507156372 -accuracy: 0.292 -confusion_matrix: [[12  0  0  0  0  0  0  0  0  0  1  6  0  4  2  0  0  1 13  0]
 [ 2 18  0  2  1 28  1  0  0  0  0  2  0  0  2  0  0  0  0  0]
 [ 4 16  0  2  7 31  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  9  0 15 22 18  0  1  0  0  0  2  0  0  3  0  0  0  0  0]
 [ 1 11  0 15 11 10  2  1  0  0  0  1  0  0  2  0  0  0  1  0]
 [ 0 10  0  1  1 44  0  0  0  0  0  0  0  0  2  0  0  0  0  0]
 [ 1  5  0  4  1  1 29  1  0  0  0  0  0  0  3  0  0  0  0  0]
 [ 1  2  0  2  3  0 10 19  0  0  0  1  0  1  5  0  0  0  1  0]
 [ 4  0  0  0  7  0  5 10  0  0  0  0  0  2 10  0  0  0  1  0]
 [ 2  1  0  0  0  0  0  0  1  7 29  1  0  0  6  0  0  0  0  0]
 [ 3  2  0  0  0  0  0  0  0  2 46  0  0  2  3  0  0  0  0  0]
 [ 0  1  0  1  6 12  0  3  0  0  0 16  0  6  3  0  0  0  8  0]
 [ 2 12  0 12 11  1  3  6  0  0  0  0  0  0  4  0  0  0  0  0]
 [ 3  2  0  0  0  0  0  6  0  0  1  1  0 27  8  0  0  0  3  0]
 [ 3  5  

# Test

## Define test

In [18]:
def test_model(test_data, model, gen, args):
    '''
    Run model on test data, and return loss, accuracy.
    '''
    if args["cuda"]:
        model = model.cuda()
        gen = gen.cuda()

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args["batch_size"],
        shuffle=False,
        num_workers=args["num_workers"],
        drop_last=False)

    test_stats = init_metrics_dictionary(modes=['test'])

    mode = 'Test'
    train_model = False
    key_prefix = mode.lower()
    print("-------------\nTest")
    epoch_details, _, losses, preds, golds, rationales = run_epoch(
        data_loader=test_loader,
        train_model=train_model,
        model=model,
        gen=gen,
        optimizer=None,
        step=None,
        args=args)

    test_stats, log_statement = collate_epoch_stat(test_stats, epoch_details, 'test', args)
    test_stats['losses'] = losses
    test_stats['preds'] = preds
    test_stats['golds'] = golds
    test_stats['rationales'] = rationales

    print(log_statement)

    return test_stats

# Do testing

In [19]:
## bon enfait ça fait presque comme la cellule juste au dessus mais tant pis
import pickle 

test_stats = test_model(test_data, model, gen, args)
args["test_stats"] = test_stats
args["train_data"] = train_data
args["test_data"] = test_data

save_path_test_results = data_folder + "/test_results/results"
print("Save test results to", save_path_test_results)
pickle.dump(args, open(save_path_test_results,'wb') )

-------------
Test


  cpuset_checked))
100%|██████████| 16/16 [00:31<00:00,  1.96s/it]
  _warn_prf(average, modifier, msg_start, len(result))



 cross_entropy -  -loss: 2.1484084129333496 -obj_loss: 2.1484084129333496 -accuracy: 0.276 -confusion_matrix: [[ 9  0  0  0  0  1  0  2  0  0  1  6  0  4  5  0  0  0  9  0]
 [ 0 19  0  0  2 14  1  1  0  0  0  2  0  1  0  0  0  0  0  0]
 [ 0  9  0  0  2 38  0  1  0  0  0  1  0  0  0  0  0  1  0  0]
 [ 0 11  0 11 10 19  0  1  0  0  0  0  0  1  1  0  0  0  0  0]
 [ 1 10  0 18 15 10  3  1  0  0  0  0  0  0  4  0  0  0  0  0]
 [ 1 10  0  1  1 46  0  0  0  0  0  0  0  0  1  0  0  0  0  0]
 [ 2  3  0 13  1  0 17  1  0  0  0  0  0  0  2  0  0  0  0  0]
 [ 4  1  0  1  4  0 14 25  0  0  0  0  0  1  2  0  0  0  1  0]
 [ 4  3  0  1  4  2  4 21  0  0  1  1  0  4 13  0  0  0  3  0]
 [ 1  4  0  0  0  1  0  0  0 12 32  0  0  0  7  0  0  0  1  0]
 [ 1  2  0  0  0  0  1  0  0  5 54  0  0  0  6  0  0  0  0  0]
 [ 3  3  0  0  5  4  0  1  0  0  0 10  0  2  5  0  0  0 10  0]
 [ 0  8  0  4 15  2  9  6  0  1  0  0  0  1  2  0  0  0  0  0]
 [ 2  1  0  0  4  0  2  6  0  0  0  0  0 22  7  0  0  0  1  0]
 [ 6  5

In [20]:
test_stats

{'golds': [7,
  5,
  0,
  17,
  19,
  13,
  15,
  15,
  5,
  1,
  2,
  5,
  17,
  8,
  0,
  2,
  4,
  1,
  6,
  16,
  1,
  6,
  17,
  14,
  3,
  13,
  11,
  7,
  7,
  3,
  5,
  5,
  4,
  3,
  14,
  1,
  9,
  4,
  6,
  1,
  17,
  2,
  8,
  1,
  11,
  1,
  14,
  3,
  11,
  0,
  8,
  8,
  13,
  9,
  1,
  2,
  10,
  17,
  16,
  14,
  8,
  10,
  17,
  19,
  2,
  18,
  18,
  13,
  0,
  9,
  8,
  18,
  10,
  19,
  3,
  7,
  18,
  7,
  18,
  9,
  6,
  18,
  17,
  4,
  12,
  10,
  16,
  15,
  8,
  4,
  14,
  9,
  16,
  15,
  16,
  0,
  9,
  3,
  3,
  16,
  2,
  10,
  14,
  15,
  3,
  16,
  10,
  4,
  14,
  12,
  8,
  3,
  17,
  3,
  8,
  14,
  9,
  5,
  9,
  17,
  12,
  4,
  4,
  5,
  9,
  13,
  18,
  8,
  3,
  1,
  16,
  11,
  6,
  13,
  10,
  5,
  5,
  1,
  3,
  12,
  10,
  14,
  7,
  7,
  10,
  5,
  10,
  12,
  0,
  13,
  14,
  4,
  15,
  4,
  6,
  14,
  18,
  2,
  10,
  4,
  11,
  17,
  9,
  5,
  3,
  8,
  2,
  16,
  6,
  7,
  1,
  7,
  7,
  8,
  4,
  12,
  10,
  0,
  10,
  18,
  4,
  9,
  