In [3]:
import torch
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
import numpy as np
import gc
from joblib import Parallel, delayed
import random
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
from torch.utils.data import Subset
import pickle as p
from tqdm import tqdm
import os
from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_bert import BertOnlyMLMHead
from torch import nn
import sys


from collections import Counter
from nltk.util import ngrams 
from itertools import chain
from nltk.corpus import stopwords
# Global configuration shared by the cells below.
max_length = 200  # token length used by encode() (padding/truncation target)
positive_label = [3]  # class ids treated as the positive class by test()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed every RNG the notebook uses (random.sample, WeightedRandomSampler,
# random masking, dropout) so a Restart & Run All is reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
class LOTClassModel(BertPreTrainedModel):
    """BERT encoder with a per-position classification head and a frozen MLM head.

    ``forward`` runs the encoder once and applies one head, selected by
    ``pred_mode``:

    * ``"classification"``: dense -> tanh -> dropout -> linear, producing
      ``config.num_labels`` logits for every position of the last hidden states;
    * ``"mlm"``: the ``BertOnlyMLMHead`` applied to the last hidden states.

    The MLM head's parameters have ``requires_grad = False``, so only the
    encoder and the classification layers receive gradients.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
        # MLM head is not trained
        for param in self.cls.parameters():
            param.requires_grad = False

    def forward(self, input_ids, pred_mode, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None):
        """Encode the inputs and return the logits of the head chosen by `pred_mode`.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds: forwarded unchanged to ``BertModel``.
            pred_mode: ``"classification"`` or ``"mlm"``.

        Returns:
            Logits computed from the first element of the encoder output
            (presumably (batch, seq_len, num_labels) or (batch, seq_len, vocab_size)
            — TODO confirm against the installed transformers version).

        Raises:
            ValueError: if `pred_mode` is not one of the two supported modes.
        """
        # Validate before running the (expensive) encoder. Raise ValueError rather
        # than sys.exit: SystemExit is not an error a caller or notebook can handle.
        if pred_mode not in ("classification", "mlm"):
            raise ValueError(
                f"Wrong pred_mode! Expected 'classification' or 'mlm', got {pred_mode!r}")
        bert_outputs = self.bert(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask,
                                 inputs_embeds=inputs_embeds)
        last_hidden_states = bert_outputs[0]
        if pred_mode == "classification":
            trans_states = self.dense(last_hidden_states)
            trans_states = self.activation(trans_states)
            trans_states = self.dropout(trans_states)
            logits = self.classifier(trans_states)
        else:  # "mlm"
            logits = self.cls(last_hidden_states)
        return logits

In [None]:
# WordPiece tokenizer for bert-base-uncased, its token -> id vocabulary,
# and the inverse id -> token mapping.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
vocab = tokenizer.get_vocab()
inv_vocab = dict((token_id, token) for token, token_id in vocab.items())

In [None]:
# Load the three saved artefacts used below. torch.load unpickles its input,
# so only point it at files you trust.
train_data = torch.load("train.pt")
label_data = torch.load("label_name_data.pt")
data_vocab = torch.load("category_vocab.pt")

In [None]:
# Read the raw corpus (one document per line) and its gold labels (one per line).
# `with` closes both files; the original left them open.
with open('train.txt', encoding="utf-8") as corpus:
    docs = [doc.strip() for doc in corpus]
with open('train_labels.txt', encoding="utf-8") as true_labels:
    docs_labels = [label.strip() for label in true_labels]
list_label = [int(label) for label in docs_labels]
# Map each class id to the indices of its documents. The original tested the
# *string* label against the *int* keys, so the check always failed and every
# occurrence replaced the list, keeping only the last index of each class.
dict_label = {0: [], 1: [], 2: [], 3: []}
for i, label in enumerate(list_label):
    dict_label.setdefault(label, []).append(i)

In [None]:
def test(model, number = 512, test_batch_size = 32,docs = docs, all = False, true_label = positive_label):
    """Evaluate `model` as a binary classifier and print precision/recall/F1/accuracy.

    A document's binary target is 1 when its class id in the global
    ``list_label`` is in `true_label`, else 0. The prediction is the argmax of
    the classification logits at position 0.

    Args:
        model: module called as ``model(ids, attention_mask=..., pred_mode='classification')``.
        number: number of documents sampled without replacement when `all` is False.
        test_batch_size: batch size of the evaluation DataLoader.
        docs: raw documents; index ``i`` must correspond to ``list_label[i]``.
            NOTE: the default is bound once, when this cell is executed.
        all: evaluate every document of `docs` instead of a random sample.
        true_label: class ids counted as the positive class.

    Returns:
        The accuracy as a float.

    The model is put in eval mode for the evaluation and is always switched
    back to train mode afterwards, even if the evaluation raises.
    """
    model.eval()
    try:
        if all:
            test_list = list(range(len(docs)))
        else:
            test_list = random.sample(range(len(docs)), k=number)
        divider = len(test_list)
        # Encode each document once (the original called encode() twice per doc).
        encoded = [encode(docs[i]) for i in test_list]
        inputs = torch.stack([ids.squeeze() for ids, _ in encoded])
        attention_mask = torch.stack([att.squeeze() for _, att in encoded])
        true_labels = torch.tensor([int(list_label[i] in true_label) for i in test_list])
        test_dataloader = DataLoader(TensorDataset(inputs, attention_mask, true_labels),
                                     batch_size=test_batch_size)

        true_positive = true_negative = false_positive = false_negative = 0
        correct_pred = 0
        with torch.no_grad():
            for inputs_test, attention_test, labels_test in tqdm(test_dataloader):
                logits = model(inputs_test.to(device), attention_mask=attention_test.to(device),
                               pred_mode='classification')
                # binary prediction taken from the first position
                prediction = torch.argmax(logits[:, 0], -1).cpu()
                true_positive += (prediction * labels_test).sum().item()
                true_negative += ((1 - prediction) * (1 - labels_test)).sum().item()
                false_positive += (prediction * (1 - labels_test)).sum().item()
                false_negative += ((1 - prediction) * labels_test).sum().item()
                correct_pred += (labels_test == prediction).sum().item()
                assert correct_pred == (true_positive + true_negative)
        assert true_positive + true_negative + false_positive + false_negative == divider
        accuracy = correct_pred / divider
    finally:
        model.train()

    if (true_positive + false_positive) > 0:
        precision = true_positive / (true_positive + false_positive)
        print('Precision', precision)
    else:
        precision = None
        print("Precision Undefined")
    if (true_positive + false_negative) > 0:
        recall = true_positive / (true_positive + false_negative)
        print('Recall', recall)
    else:
        recall = None
        print("Recall Undefined")
    if recall is not None and precision is not None:
        # With no true positives, precision == recall == 0 and the harmonic mean
        # would divide by zero; F1 is 0 in that case.
        denom = recall + precision
        f1_score = 2 * (recall * precision) / denom if denom > 0 else 0.0
        print("F1_score", f1_score)
    else:
        print("F1_score Undefined")
    print("Accuracy ", accuracy)
    return accuracy
        
    
    
def encode(docs, tokenizer = tokenizer, max_length = max_length):
    """Tokenize one document into fixed-length BERT inputs.

    `docs` is a single text. It is padded/truncated to `max_length` tokens
    with special tokens added.

    Returns:
        ``(input_ids, attention_mask)`` as PyTorch tensors, exactly as
        produced by ``tokenizer.encode_plus(..., return_tensors='pt')``.
    """
    enc = tokenizer.encode_plus(
        docs,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return enc['input_ids'], enc['attention_mask']

In [None]:
# Flatten the per-category token-id collections of data_vocab into one list,
# following the dict's key order.
category_vocab = [token_id for k in data_vocab.keys() for token_id in data_vocab[k]]

In [None]:
# Build the list of positive keywords: the token string of every id in category_vocab.
list_pos_keyword = [inv_vocab[w] for w in category_vocab]

In [None]:
from tqdm import tqdm
# Keep only documents in which no positive keyword occurs as a whole word
# (word pieces re-joined) among the first `max_pieces - 1` word pieces. The
# kept documents and their gold labels form the candidate negative set.
# NOTE(review): the window is 512 pieces while encode() truncates to
# max_length tokens, so a keyword beyond max_length still rejects a document.
max_pieces = 512
positive_keywords = set(list_pos_keyword)  # O(1) membership instead of a list scan
negative_doc = []
negative_doc_label = []
for k, doc in enumerate(tqdm(docs)):
    tokenized_doc = tokenizer.tokenize(doc)
    has_keyword = False
    wordpcs = []
    for idx, wordpc in enumerate(tokenized_doc):
        wordpcs.append(wordpc[2:] if wordpc.startswith("##") else wordpc)
        if idx >= max_pieces - 1:  # last index will be [SEP] token
            break
        # A word ends when the next piece is not a "##" continuation. The original
        # looked at the raw string `doc` here instead of the word pieces, so every
        # piece was treated as a complete word.
        if idx == len(tokenized_doc) - 1 or not tokenized_doc[idx + 1].startswith("##"):
            if ''.join(wordpcs) in positive_keywords:
                has_keyword = True
                break
            wordpcs = []
    if not has_keyword:
        negative_doc_label.append(list_label[k])
        negative_doc.append(doc)
    

In [None]:
print("Negative pre-set", len(negative_doc))
if negative_doc_label:
    # fraction of retained documents whose gold class is not a positive class
    n_true_negative = sum(lbl not in positive_label for lbl in negative_doc_label)
    print("Precision pre-set, ", n_true_negative / len(negative_doc_label))
else:
    print("Precision pre-set undefined: no negative documents retained")

In [None]:
# Encode every candidate negative document once; encode() returns (input_ids, attention_mask).
encoded_negatives = [encode(doc) for doc in tqdm(negative_doc)]
inputs_list = [pair[0] for pair in encoded_negatives]
masks_list = [pair[1] for pair in encoded_negatives]

In [None]:
# Concatenate the per-document encodings along dim 0. Each encoding carries a
# leading batch dimension of 1 (test() squeezes it away), so torch.cat keeps N
# rows even when N == 1. The original stack(...).squeeze() dropped that dimension.
input_tensor = torch.cat(inputs_list)
mask_tensor = torch.cat(masks_list)
# gold class ids, one row per document: shape (N, 1)
label_tensor = torch.tensor(negative_doc_label).unsqueeze(1)
dataset = torch.utils.data.TensorDataset(input_tensor, mask_tensor, label_tensor)
dataloader = torch.utils.data.DataLoader(dataset, shuffle = False, batch_size = 8)

In [None]:
def intersect_tensor(t1, t2, device = None, mask = None):
    """Select the elements of `t1` whose value appears in `t2`.

    Args:
        t1: tensor of values (e.g. token ids) to filter.
        t2: iterable of values to look for (typically a 1-D tensor).
        device: device of the membership mask. It defaults to ``t1.device``;
            it was previously hard-coded to 'cuda', which failed for CPU tensors.
        mask: optional tensor broadcastable to `t1`; only non-zero positions are kept.

    Returns:
        ``(intersection, indices)``:
            intersection: the selected values of `t1` as a 1-D tensor.
            indices: the boolean selection mask.
    """
    if device is None:
        device = t1.device
    indices = torch.zeros_like(t1, dtype=torch.bool, device=device)
    for elem in t2:
        indices |= (t1 == elem)
    if mask is not None:
        # Boolean AND keeps `indices` a bool mask. The original `indices * mask`
        # produced an integer tensor, which `t1[...]` treats as gather indices.
        indices = indices & mask.to(device=device, dtype=torch.bool)
    intersection = t1[indices]
    return intersection, indices


In [None]:
def count_similar_words(batch, category_vocab = category_vocab, topk = 8):
    """Return True when no kept position's top-`topk` predictions hit `category_vocab`.

    Args:
        batch: pair ``(prediction, input_mask)``. Presumably `prediction` is a
            (seq_len, vocab_size) tensor of MLM logits for one document and
            `input_mask` its attention mask — TODO confirm against the caller.
        category_vocab: token ids of the category vocabulary.
        topk: number of top predictions inspected per position (was fixed at 8).

    Returns:
        False as soon as one position has a top-`topk` prediction in
        `category_vocab`, True otherwise.
    """
    prediction, input_mask = batch[0], batch[1]
    # keep the first sum(mask) positions (assumes right padding — TODO confirm)
    masked_pred = prediction[:input_mask.sum().item(), :]
    _, words = torch.topk(masked_pred, topk, -1)
    # The original also referenced undefined `time` / `intersect_time_start`
    # (NameError on every call) and printed a debug 'break'.
    for word in words.squeeze():
        if len(np.intersect1d(word.cpu().numpy(), category_vocab)) > 0:
            return False
    return True
            
def occurences(word, vocab = category_vocab):
    """Number of distinct values of tensor `word` that also appear in `vocab`."""
    shared_ids = np.intersect1d(word.cpu().numpy(), vocab)
    return len(shared_ids)

In [None]:
# #### Get negative set
# from time import time
# verified_negative = []
# correct_label = 0
# verbose = False
# topk = 10
# vocab = torch.tensor(category_vocab).to(device)
# min_similar_words = 0
# max_category_word = 0
# num_cpus = 8
# with torch.no_grad():
#     for k, batch in tqdm(enumerate(dataloader)):

#         input_ids, input_mask, label_id = batch
#         predictions = model(input_ids.to(device),
#                         pred_mode="mlm",
#                         token_type_ids=None, 
#                         attention_mask=input_mask.to(device))

        
        
#     ########### GPU    
# #         intersection, indices = intersect_tensor(torch.topk(predictions,topk,-1)[1],vocab, 
# #                                                 mask = input_mask.unsqueeze(2).repeat(1,1,topk).to(device))
# #         counts_similar_word_per_word = indices.sum(-1) 
        
# #         counts = (counts_similar_word_per_word>min_similar_words).sum(-1)
# #         end_intersection_time = time()
        
# #         indices_count = torch.where(counts<=0)[0]
# #         for j in indices_count:
# #             i = j.item()

#     ################## CPU ####################""
# #         counts = Parallel(n_jobs=num_cpus)(delayed(count_similar_words)(batch) for batch in zip(predictions.cpu(), input_mask))
    
    
    
#         for i, doc in enumerate(predictions.cpu()):
#             masked_pred = doc[:input_mask[i].sum().item(),:]
#             _ , words = torch.topk(masked_pred, topk, -1)
#             counter = 0
            
# #             counts = Parallel(n_jobs=num_cpus)(delayed(occurences)(word) 
# #                                                    for word in words.squeeze())
# #             counter += len(np.where(np.array(counts)>min_similar_words)[0])
#             for word in words.squeeze():
# #                 counter += int(len(intersect_tensor(word, vocab))>0)
#                 counter += int(len(np.intersect1d(word.cpu().numpy(), category_vocab))>min_similar_words)
#                 if counter > max_category_word:
#                     break

# #             j = i.item()
# #         for i in np.where(np.array(counts))[0]:

#             if counter <= max_category_word:             
#                 verified_negative.append(k*4+i)
#                 if label_id[i] not in positive_label:
#                     correct_label += 1 
             
#         if k%100 == 0:
#             if len(verified_negative)>0:
#                 print('accuracy :', correct_label/len(verified_negative))
#                 print('number of elements retrieved', len(verified_negative))
# #         if verbose:
# #             print('Prediction time', end_prediction_time-start_time) 
# # #             print('bascule cpu', start_loop-end_prediction_time)
# #             print('Intersection time', end_intersection_time-end_prediction_time)
# # #             print('topk time', topk_time-start_loop)
# #             print('counting time', end_counting_time-end_intersection_time)
# #         del predictions
# #         gc.collect()
# #         torch.cuda.empty_cache()
        
        
        
    

In [None]:
# import pickle as p

# p.dump(verified_negative, open('verified_negative_politics.p','wb'))
# p.dump(dataloader, open('dataloader_politics.p','wb'))

In [None]:
######################### TRAINING PART ###############################################

In [None]:
#### DATASET CONSTRUCTION ####

In [None]:
def decode(ids, tokenizer=tokenizer):
    """Turn a batch of token-id sequences back into strings, dropping special tokens."""
    return tokenizer.batch_decode(
        ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


In [None]:
# new_verified_negative = p.load(open('verified_negative_politics.p','rb'))
# new_dataloader = p.load(open('dataloader_politics.p','rb'))

In [4]:
# Load the MCP training tensors and relabel every example as the positive class (1).
mcp_data = torch.load('athlete/mcp_train.pt')
label = torch.ones(len(mcp_data['labels']), dtype=torch.long)
mcp_data['labels'] = label

In [7]:
# Positive examples: one (input_ids, attention_masks, labels) tuple per sample.
positive_dataset = TensorDataset(
    mcp_data['input_ids'],
    mcp_data['attention_masks'],
    mcp_data['labels'],
)

In [13]:
# Each sample is built from three tensors, so this displays 3.
first_sample = positive_dataset[0]
len(first_sample)

3

In [None]:
#### Statistics of the two sets ####
# negative_dataset = Subset(new_dataloader.dataset, new_verified_negative)
# positive_dataset = torch.utils.data.TensorDataset(mcp_data['input_ids'], mcp_data['attention_masks'], mcp_data['labels'])

## TO DO ##


In [None]:
### STATISTICS ON CORPUS
# corpus = open('train.txt', encoding="utf-8")
# stopwords_vocab = stopwords.words('english')
# lines = corpus.readlines()
# words = chain.from_iterable(line.lower().split() for line in lines)
# count = Counter(word for word in words if word not in stopwords_vocab)
# count.most_common(55)

In [6]:
# Pair the positive (MCP) examples with the keyword-filtered negatives.
# `dataset` comes from the negative-set construction cells above; running this
# cell in a fresh kernel without them raises NameError.
# negative_dataset = Subset(new_dataloader.dataset, new_verified_negative)
positive_dataset = TensorDataset(
    mcp_data['input_ids'], mcp_data['attention_masks'], mcp_data['labels'])
negative_dataset = dataset

NameError: name 'dataset' is not defined

In [None]:
#### Start of the training

In [None]:
# TEST without further refinement


In [None]:
# Binary targets: 0 for every negative example, then 1 for every positive one
# (same order as the `data` / `mask` tensors built below).
target = np.hstack((np.zeros(len(negative_dataset), dtype=np.int32),
                    np.ones(len(positive_dataset), dtype=np.int32)))

# Inverse-frequency weights so WeightedRandomSampler draws both classes equally often.
class_sample_count = np.bincount(target, minlength=2)
weight = 1. / np.maximum(class_sample_count, 1)  # guard against an empty class
# The original assigned .double() to a misspelled name (`samples_weigth`);
# here the conversion is applied to samples_weight itself.
samples_weight = torch.from_numpy(weight[target]).double()
target = torch.from_numpy(target).long()

In [None]:

# Input ids of every example, negatives first then positives (matching `target`),
# each truncated to its first max_length positions.
data = torch.stack([sample[0][:max_length]
                    for sample in chain(negative_dataset, positive_dataset)])

In [None]:
# Attention masks in the same order as `data`, truncated the same way.
mask = torch.stack([sample[1][:max_length]
                    for sample in chain(negative_dataset, positive_dataset)])

In [None]:
# Draw len(samples_weight) indices with replacement, proportional to the per-example weights.
sampler = WeightedRandomSampler(weights=samples_weight,
                                num_samples=len(samples_weight),
                                replacement=True)

In [None]:
# (input_ids, attention_mask, binary target) triples for training.
train_dataset = TensorDataset(data, mask, target)

In [None]:
batch_size = 16
# The weighted sampler decides order and class balance, so no shuffle flag here.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=sampler)

In [None]:
### TRAINING LOOPS

In [None]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

def onehot(indexes, N=None, ignore_index=None):
    """
    Creates a one-representation of indexes with N possible entries
    if N is not specified, it will suit the maximum index appearing.
    indexes is a long-tensor of indexes
    ignore_index will be zero in onehot representation
    """
    if N is None:
        N = indexes.max() + 1
    sz = list(indexes.size())
    output = indexes.new().byte().resize_(*sz, N).zero_()
    output.scatter_(-1, indexes.unsqueeze(-1), 1)
    if ignore_index is not None and ignore_index >= 0:
        output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0)
    return output


def _is_long(x):
    if hasattr(x, 'data'):
        x = x.data
    return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor)

def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean',
                  smooth_eps=None, smooth_dist=None, from_logits=True):
    """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567"""
    smooth_eps = smooth_eps or 0

    # ordinary log-liklihood - use cross_entropy from nn
    if _is_long(target) and smooth_eps == 0:
        if from_logits:
            return F.cross_entropy(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)
        else:
            return F.nll_loss(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)

    if from_logits:
        # log-softmax of inputs
        lsm = F.log_softmax(inputs, dim=-1)
    else:
        lsm = inputs

    masked_indices = None
    num_classes = inputs.size(-1)

    if _is_long(target) and ignore_index >= 0:
        masked_indices = target.eq(ignore_index)

    if smooth_eps > 0 and smooth_dist is not None:
        if _is_long(target):
            target = onehot(target, num_classes).type_as(inputs)
        if smooth_dist.dim() < target.dim():
            smooth_dist = smooth_dist.unsqueeze(0)
        target.lerp_(smooth_dist, smooth_eps)

    if weight is not None:
        lsm = lsm * weight.unsqueeze(0)

    if _is_long(target):
        eps_sum = smooth_eps / num_classes
        eps_nll = 1. - eps_sum - smooth_eps
        likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        loss = -(eps_nll * likelihood + eps_sum * lsm.sum(-1))
    else:
        loss = -(target * lsm).sum(-1)

    if masked_indices is not None:
        loss.masked_fill_(masked_indices, 0)

    if reduction == 'sum':
        loss = loss.sum()
    elif reduction == 'mean':
        if masked_indices is None:
            loss = loss.mean()
        else:
            loss = loss.sum() / float(loss.size(0) - masked_indices.sum())

    return loss

class CrossEntropyLoss(nn.CrossEntropyLoss):
    """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing"""

    def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None, from_logits=True):
        super(CrossEntropyLoss, self).__init__(weight=weight,
                                               ignore_index=ignore_index, reduction=reduction)
        self.smooth_eps = smooth_eps
        self.smooth_dist = smooth_dist
        self.from_logits = from_logits

    def forward(self, input, target, smooth_dist=None):
        if smooth_dist is None:
            smooth_dist = self.smooth_dist
        return cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index,
                             reduction=self.reduction, smooth_eps=self.smooth_eps,
                             smooth_dist=smooth_dist, from_logits=self.from_logits)

In [None]:
# Build the binary classifier on bert-base-uncased and fine-tune it on the
# class-balanced train_loader, with gradient accumulation over `accum_steps` batches.
model = LOTClassModel.from_pretrained('bert-base-uncased',
                                           output_attentions=False,
                                           output_hidden_states=False,
                                           num_labels=2).to(device)
accum_steps = 8
model.train()
epochs = 1
smooth_eps = 1e-2
eval_every = 5 * accum_steps  # run test() every this many batches
train_loss = CrossEntropyLoss(smooth_eps=smooth_eps)
# Optimizer steps per epoch: one per `accum_steps` batches, plus one for a trailing partial group.
total_steps = math.ceil(len(train_loader) / accum_steps) * epochs
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps),
                                            num_training_steps=total_steps)
mask_token_id = tokenizer.mask_token_id  # hoisted out of the batch loop
losses_track = []
try:
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        model.zero_grad()
        for j, batch in enumerate(train_loader):
            input_ids = batch[0].to(device)
            input_mask = batch[1].to(device)
            labels = batch[2].to(device)

            ### RANDOM MASKING
            # Mask one random position per example (never position 0). Uses the
            # real batch size: the last batch can be smaller than `batch_size`.
            random_masking = random.choices(range(input_ids.size(1) - 1), k=input_ids.size(0))
            for row, mask_pos in enumerate(random_masking):
                input_ids[row, mask_pos + 1] = mask_token_id

            ### PREDICTION
            logits = model(input_ids,
                           pred_mode="classification",
                           token_type_ids=None,
                           attention_mask=input_mask)
            ### LOSS on the logits of the first position
            logits_cls = logits[:, 0]
            loss = train_loss(logits_cls.view(-1, 2), labels.view(-1)) / accum_steps
            total_train_loss += loss.item()
            loss.backward()
            is_last_batch = (j + 1) == len(train_loader)
            if (j + 1) % accum_steps == 0 or is_last_batch:
                # store a float, not the tensor, so its autograd graph can be freed
                losses_track.append(loss.item() * accum_steps)
                # Clip the norm of the gradients to 1.0.
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
            # Parenthesised: the original `(j+1) % 5*accum_steps` parsed as
            # `((j+1) % 5) * accum_steps`, i.e. evaluated every 5 batches.
            if (j + 1) % eval_every == 0:
                print(loss.item() * accum_steps)
                test(model)
        avg_train_loss = total_train_loss / max(len(train_loader), 1) * accum_steps
        print(f"Average training loss: {avg_train_loss}")

except RuntimeError as err:
    # Best effort: report errors such as CUDA OOM but keep the partially trained model.
    print(err)


In [None]:
# import pickle as p
# p.dump(verified_negative, open('verified_negative.p','wb'))
# p.dump(dataloader, open('dataloader.p','wb'))

In [None]:
########## TRAINING #########

In [None]:
# Evaluate on every document; the returned accuracy is the cell's output.
test(model, all=True)

In [None]:
# Spot-check the classifier on one random document.
k = random.randrange(len(docs))  # was hard-coded to range(120000)
x, m = encode(docs[k])
label = list_label[k]
# eval + no_grad: deterministic prediction (no dropout) without building a graph
model.eval()
with torch.no_grad():
    pred = model(x.to(device), attention_mask = m.to(device), pred_mode='classification')
model.train()
print(torch.argmax(pred[0,0,:]))
print(label)
print(int(label in positive_label))  # binary target the classifier is trained on

In [None]:
# NOTE(review): ad-hoc arithmetic; the meaning of 227.73 is not documented here,
# and 120000 matches the corpus size hard-coded elsewhere — confirm.
(227.73 * 512) / 120000

### top k = top 8 -> ~0.8 accuracy, 400 datapoints