In [1]:
import torch
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
import numpy as np
import gc
from joblib import Parallel, delayed
import random
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
from torch.utils.data import Subset
import pickle as p
from tqdm import tqdm
import os
from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_bert import BertOnlyMLMHead
from torch import nn
import sys


from collections import Counter
from nltk.util import ngrams 
from itertools import chain
from nltk.corpus import stopwords

positive_label = [3]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
class LOTClassModel(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
        # MLM head is not trained
        for param in self.cls.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, pred_mode, attention_mask=None, token_type_ids=None, 
                position_ids=None, head_mask=None, inputs_embeds=None):
        bert_outputs = self.bert(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask,
                                 inputs_embeds=inputs_embeds)
        last_hidden_states = bert_outputs[0]
        if pred_mode == "classification":
            trans_states = self.dense(last_hidden_states)
            trans_states = self.activation(trans_states)
            trans_states = self.dropout(trans_states)
            logits = self.classifier(trans_states)
        elif pred_mode == "mlm":
            logits = self.cls(last_hidden_states)
        else:
            sys.exit("Wrong pred_mode!")
        return logits

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
vocab = tokenizer.get_vocab()
inv_vocab = {k:v for v, k in vocab.items()}

In [4]:
data_vocab = torch.load("category_vocab.pt")
label_data = torch.load("label_name_data.pt")
train_data = torch.load('train.pt')

In [5]:
corpus = open('train.txt', encoding="utf-8")
true_labels = open('train_labels.txt', encoding="utf-8")
docs_labels = [doc.strip() for doc in true_labels.readlines()]
dict_label = {0:[], 1:[], 2:[],3:[]}
list_label = [int(label) for label in docs_labels]
for i, label in enumerate(docs_labels):
    if label in list(dict_label.keys()):
        
        dict_label[int(label)].append(i)
    else:
        dict_label[int(label)] = [i]
docs = [doc.strip() for doc in corpus.readlines()]

In [72]:
def test(model, number = 512, test_batch_size = 32,docs = docs, all = False, true_label = positive_label):
    model.eval()
    true_negative = 0
    true_positive = 0
    false_positive = 0
    false_negative = 0
    correct_pred = 0
    negative = 0
    divider = number
    if all:
        test_list = list(range(len(docs)))
        divider = len(docs)
    else:
        test_list = random.sample(list(range(len(docs))), k = number)
    inputs = torch.stack([encode(docs[i])[0].squeeze() for i in test_list])
    attention_mask = torch.stack([encode(docs[i])[1].squeeze() for i in test_list])
    true_labels = torch.stack([torch.tensor(int(list_label[i] in true_label)) for i in test_list])
    test_dataset = TensorDataset(inputs, attention_mask, true_labels)
    test_dataloader = DataLoader(test_dataset, batch_size = test_batch_size)
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            inputs_test, attention_test, labels_test = batch
            logits = model(inputs_test.to(device),attention_mask=attention_test.to(device), pred_mode='classification')
            logits_cls = logits[:,0]
            prediction = torch.argmax(logits_cls, -1)
            
            true_positive += ((labels_test == prediction.cpu())*torch.from_numpy(np.array([np.where(i in positive_label,1,0) for i in labels_test]))).sum().item()
            true_negative += ((labels_test == prediction.cpu())*torch.from_numpy(np.array([np.where(i in positive_label,0,1) for i in labels_test]))).sum().item()
            false_positive += ((labels_test != prediction.cpu())*torch.from_numpy(np.array([np.where(i in positive_label,1,0) for i in prediction.cpu()]))).sum().item()
            false_negative += ((labels_test != prediction.cpu())*torch.from_numpy(np.array([np.where(i in positive_label,0,1) for i in prediction.cpu()]))).sum().item()
            
            
            correct_pred += (labels_test == prediction.cpu()).sum().item()
            assert (correct_pred == (true_positive + true_negative))
        assert(true_positive+true_negative+false_positive+false_negative == divider)
        accuracy = correct_pred / divider
    if (true_positive+false_positive) > 0:
        precision = true_positive / (true_positive+false_positive)
        print('Precision', precision)
    else : 
        precision = None
        print("Precision Undefined")
    if (true_positive+false_negative) > 0 :
        recall = true_positive/(true_positive+false_negative)
        print('Recall', recall)
    else :
        recall = None
        print("Recall Undefined")
    if recall is not None and precision is not None:
        f1_score = 2*(recall*precison)/(recall+precision)
        print("F1_score", f1_score)
    else:
        print("F1_score Undefined")
    print("Accuracy ", accuracy)
    model.train()
    return accuracy
        
    
    
def encode(docs, tokenizer = tokenizer, max_length = 200):
    encoded_dict = tokenizer.encode_plus(docs, add_special_tokens=True, max_length=max_length, padding='max_length',
                                                    return_attention_mask=True, truncation=True, return_tensors='pt')
    input_ids = encoded_dict['input_ids']
    attention_masks = encoded_dict['attention_mask']
    return input_ids, attention_masks

In [73]:
test(model)

  6%|▋         | 1/16 [00:00<00:05,  2.61it/s]

0
25


 12%|█▎        | 2/16 [00:00<00:05,  2.77it/s]

0
45


 19%|█▉        | 3/16 [00:01<00:04,  2.88it/s]

0
65


 25%|██▌       | 4/16 [00:01<00:04,  2.88it/s]

0
91


 31%|███▏      | 5/16 [00:01<00:03,  2.96it/s]

0
119


 38%|███▊      | 6/16 [00:01<00:03,  3.01it/s]

0
137


 44%|████▍     | 7/16 [00:02<00:03,  2.98it/s]

0
157


 50%|█████     | 8/16 [00:02<00:02,  3.02it/s]

0
180


 56%|█████▋    | 9/16 [00:02<00:02,  3.04it/s]

0
197


 62%|██████▎   | 10/16 [00:03<00:01,  3.01it/s]

0
224


 69%|██████▉   | 11/16 [00:03<00:01,  3.03it/s]

0
244


 75%|███████▌  | 12/16 [00:03<00:01,  3.05it/s]

0
267


 81%|████████▏ | 13/16 [00:04<00:00,  3.02it/s]

0
287


 88%|████████▊ | 14/16 [00:04<00:00,  3.03it/s]

0
310


 94%|█████████▍| 15/16 [00:04<00:00,  3.03it/s]

0
337


100%|██████████| 16/16 [00:05<00:00,  3.02it/s]

0
360
Precision Undefined
Recall 0.0
F1_score Undefined
Accuracy  0.703125





0.703125

In [25]:
a = np.array([1,0,2,3,5,8,4,1])
[torch.where(i.item() in positive_label, 1,0) for i in torch.from_numpy(a)]

TypeError: where() received an invalid combination of arguments - got (bool, int, int), but expected one of:
 * (Tensor condition)
 * (Tensor condition, Tensor input, Tensor other)
      didn't match because some of the arguments have invalid types: ([31;1mbool[0m, [31;1mint[0m, [31;1mint[0m)
 * (Tensor condition, Number self, Tensor other)
      didn't match because some of the arguments have invalid types: ([31;1mbool[0m, [31;1mint[0m, [31;1mint[0m)
 * (Tensor condition, Tensor input, Number other)
      didn't match because some of the arguments have invalid types: ([31;1mbool[0m, [31;1mint[0m, [31;1mint[0m)
 * (Tensor condition, Number self, Number other)
      didn't match because some of the arguments have invalid types: ([31;1mbool[0m, [31;1mint[0m, [31;1mint[0m)


In [9]:
category_vocab = []
for k in data_vocab.keys():
    category_vocab += list(data_vocab[k])

In [10]:
# Creer la liste de mots positifs
list_pos_keyword = []
for w in category_vocab:
    list_pos_keyword.append(inv_vocab[w])

In [15]:
test(model)

  0%|          | 0/16 [00:00<?, ?it/s]


AttributeError: 'generator' object has no attribute 'sum'

In [11]:
from tqdm import tqdm
negative_doc=[]
negative_doc_label = []
for k, doc in tqdm(enumerate(docs)):
    tokenized_doc = tokenizer.tokenize(doc)
    new_doc = []
    wordpcs = []
    label_idx = -1 * torch.ones(512, dtype=torch.long)
    for idx, wordpc in enumerate(tokenized_doc):
        wordpcs.append(wordpc[2:] if wordpc.startswith("##") else wordpc)
        if idx >= 512 - 1: # last index will be [SEP] token
            break
        if idx == len(doc) - 1 or not doc[idx+1].startswith("##"):
            word = ''.join(wordpcs)
            if word in list_pos_keyword:
                label_idx[idx] = 0
                break
            new_word = ''.join(wordpcs)
            if new_word != tokenizer.unk_token:
                idx += len(wordpcs)
                new_doc.append(new_word)
            wordpcs = []
    if (label_idx>=0).any():
        continue
    else:
        negative_doc_label.append(list_label[k])
        negative_doc.append(doc)
    

560000it [06:00, 1554.48it/s]


In [12]:
print("Negative pre-set", len(negative_doc))
print("Precision pre-set, ", len([k for k in negative_doc_label if k not in positive_label])/len(negative_doc_label))

Negative pre-set 277759
Precision pre-set,  0.9844721503173615


In [12]:
model = LOTClassModel.from_pretrained('bert-base-uncased',
                                           output_attentions=False,
                                           output_hidden_states=False,
                                           num_labels=2).to('cuda')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing LOTClassModel: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing LOTClassModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LOTClassModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LOTClassModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias', 'dense.weight', 'dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predict

In [120]:
inputs_list = []
masks_list = []
for doc in tqdm(negative_doc):
    input_ids, input_mask = encode(doc)
    inputs_list.append(input_ids)
    masks_list.append(input_mask)

100%|██████████| 66370/66370 [00:38<00:00, 1706.85it/s]


In [121]:
input_tensor = torch.stack(inputs_list).squeeze()
mask_tensor = torch.stack(masks_list).squeeze()
label_tensor = torch.stack([torch.tensor(i).unsqueeze(0) for i in negative_doc_label])
dataset = torch.utils.data.TensorDataset(input_tensor,mask_tensor, label_tensor)
dataloader = torch.utils.data.DataLoader(dataset, shuffle = False, batch_size = 8)

In [122]:
def intersect_tensor(t1, t2, device = 'cuda', mask = None):    
    indices = torch.zeros_like(t1, dtype = torch.uint8, device = device)
    for elem in t2:
        indices = indices | (t1 == elem) 
        indices = indices.to(bool)
        
    if mask is not None:
        indices = indices * mask 
    intersection = t1[indices]  
    return intersection, indices


In [123]:
def count_similar_words(batch, category_vocab = category_vocab):
    prediction= batch[0]
    input_mask = batch[1]
    masked_pred = prediction[:input_mask.sum().item(),:]
    _ , words = torch.topk(masked_pred, 8, -1)
    counter = 0
    for word in words.squeeze():
        counter += int(len(np.intersect1d(word.numpy(), category_vocab))>0)
        intersect_time = time() - intersect_time_start
        if counter > 0:
            print('break')
            return False
            break
    return True
            
def occurences(word, vocab = category_vocab):
    return len(np.intersect1d(word.cpu().numpy(), vocab))

In [124]:
#### Get negative set
from time import time
verified_negative = []
correct_label = 0
verbose = False
topk = 10
vocab = torch.tensor(category_vocab).to(device)
min_similar_words = 0
max_category_word = 0
num_cpus = 8
with torch.no_grad():
    for k, batch in tqdm(enumerate(dataloader)):

        input_ids, input_mask, label_id = batch
        predictions = model(input_ids.to(device),
                        pred_mode="mlm",
                        token_type_ids=None, 
                        attention_mask=input_mask.to(device))

        
        
    ########### GPU    
#         intersection, indices = intersect_tensor(torch.topk(predictions,topk,-1)[1],vocab, 
#                                                 mask = input_mask.unsqueeze(2).repeat(1,1,topk).to(device))
#         counts_similar_word_per_word = indices.sum(-1) 
        
#         counts = (counts_similar_word_per_word>min_similar_words).sum(-1)
#         end_intersection_time = time()
        
#         indices_count = torch.where(counts<=0)[0]
#         for j in indices_count:
#             i = j.item()

    ################## CPU ####################""
#         counts = Parallel(n_jobs=num_cpus)(delayed(count_similar_words)(batch) for batch in zip(predictions.cpu(), input_mask))
    
    
    
        for i, doc in enumerate(predictions.cpu()):
            masked_pred = doc[:input_mask[i].sum().item(),:]
            _ , words = torch.topk(masked_pred, topk, -1)
            counter = 0
            
#             counts = Parallel(n_jobs=num_cpus)(delayed(occurences)(word) 
#                                                    for word in words.squeeze())
#             counter += len(np.where(np.array(counts)>min_similar_words)[0])
            for word in words.squeeze():
#                 counter += int(len(intersect_tensor(word, vocab))>0)
                counter += int(len(np.intersect1d(word.cpu().numpy(), category_vocab))>min_similar_words)
                if counter > max_category_word:
                    break

#             j = i.item()
#         for i in np.where(np.array(counts))[0]:

            if counter <= max_category_word:             
                verified_negative.append(k*4+i)
                if label_id[i] not in positive_label:
                    correct_label += 1 
             
        if k%100 == 0:
            if len(verified_negative)>0:
                print('accuracy :', correct_label/len(verified_negative))
                print('number of elements retrieved', len(verified_negative))
#         if verbose:
#             print('Prediction time', end_prediction_time-start_time) 
# #             print('bascule cpu', start_loop-end_prediction_time)
#             print('Intersection time', end_intersection_time-end_prediction_time)
# #             print('topk time', topk_time-start_loop)
#             print('counting time', end_counting_time-end_intersection_time)
#         del predictions
#         gc.collect()
#         torch.cuda.empty_cache()
        
        
        
    

102it [00:17,  6.00it/s]

accuracy : 0.8870967741935484
number of elements retrieved 62


202it [00:34,  5.88it/s]

accuracy : 0.8888888888888888
number of elements retrieved 117


302it [00:51,  5.88it/s]

accuracy : 0.8881987577639752
number of elements retrieved 161


402it [01:08,  5.87it/s]

accuracy : 0.8979591836734694
number of elements retrieved 196


502it [01:25,  5.86it/s]

accuracy : 0.9051383399209486
number of elements retrieved 253


602it [01:42,  5.95it/s]

accuracy : 0.9042904290429042
number of elements retrieved 303


702it [01:59,  5.87it/s]

accuracy : 0.8976608187134503
number of elements retrieved 342


802it [02:16,  5.92it/s]

accuracy : 0.8981233243967829
number of elements retrieved 373


902it [02:33,  5.89it/s]

accuracy : 0.894484412470024
number of elements retrieved 417


1002it [02:50,  5.82it/s]

accuracy : 0.886021505376344
number of elements retrieved 465


1102it [03:07,  5.86it/s]

accuracy : 0.8843813387423936
number of elements retrieved 493


1202it [03:24,  5.91it/s]

accuracy : 0.8886756238003839
number of elements retrieved 521


1302it [03:41,  5.92it/s]

accuracy : 0.8914185639229422
number of elements retrieved 571


1402it [03:58,  5.77it/s]

accuracy : 0.8903436988543372
number of elements retrieved 611


1502it [04:15,  5.57it/s]

accuracy : 0.8893939393939394
number of elements retrieved 660


1602it [04:33,  5.76it/s]

accuracy : 0.8866571018651362
number of elements retrieved 697


1702it [04:50,  5.68it/s]

accuracy : 0.8763297872340425
number of elements retrieved 752


1802it [05:08,  5.77it/s]

accuracy : 0.8735919899874843
number of elements retrieved 799


1902it [05:25,  5.78it/s]

accuracy : 0.8656361474435196
number of elements retrieved 841


2002it [05:43,  5.92it/s]

accuracy : 0.8693693693693694
number of elements retrieved 888


2102it [06:00,  5.57it/s]

accuracy : 0.8751334044823906
number of elements retrieved 937


2202it [06:18,  5.82it/s]

accuracy : 0.8752556237218814
number of elements retrieved 978


2302it [06:35,  5.91it/s]

accuracy : 0.8725490196078431
number of elements retrieved 1020


2402it [06:52,  5.54it/s]

accuracy : 0.8726591760299626
number of elements retrieved 1068


2502it [07:10,  5.59it/s]

accuracy : 0.8710550045085663
number of elements retrieved 1109


2602it [07:27,  5.69it/s]

accuracy : 0.8729184925503944
number of elements retrieved 1141


2702it [07:44,  5.41it/s]

accuracy : 0.8761506276150628
number of elements retrieved 1195


2802it [08:02,  5.74it/s]

accuracy : 0.8769230769230769
number of elements retrieved 1235


2902it [08:19,  5.69it/s]

accuracy : 0.8778565799842396
number of elements retrieved 1269


3002it [08:36,  5.86it/s]

accuracy : 0.878234398782344
number of elements retrieved 1314


3102it [08:54,  5.49it/s]

accuracy : 0.8780308596620132
number of elements retrieved 1361


3202it [09:11,  5.73it/s]

accuracy : 0.8767908309455588
number of elements retrieved 1396


3302it [09:29,  5.90it/s]

accuracy : 0.8782365290412876
number of elements retrieved 1429


3402it [09:46,  5.94it/s]

accuracy : 0.8778371161548731
number of elements retrieved 1498


3502it [10:03,  5.92it/s]

accuracy : 0.8764705882352941
number of elements retrieved 1530


3602it [10:20,  5.93it/s]

accuracy : 0.8779402415766052
number of elements retrieved 1573


3702it [10:37,  5.89it/s]

accuracy : 0.8788819875776398
number of elements retrieved 1610


3802it [10:54,  5.87it/s]

accuracy : 0.876357056694813
number of elements retrieved 1658


3902it [11:11,  5.93it/s]

accuracy : 0.8763250883392226
number of elements retrieved 1698


4002it [11:28,  5.88it/s]

accuracy : 0.8752146536920435
number of elements retrieved 1747


4102it [11:45,  5.85it/s]

accuracy : 0.8748603351955307
number of elements retrieved 1790


4202it [12:02,  5.94it/s]

accuracy : 0.8767567567567568
number of elements retrieved 1850


4302it [12:19,  5.83it/s]

accuracy : 0.877906976744186
number of elements retrieved 1892


4402it [12:36,  5.91it/s]

accuracy : 0.8767967145790554
number of elements retrieved 1948


4502it [12:53,  5.85it/s]

accuracy : 0.8761952692501258
number of elements retrieved 1987


4602it [13:10,  5.98it/s]

accuracy : 0.8743196437407225
number of elements retrieved 2021


4702it [13:27,  5.80it/s]

accuracy : 0.8748187530207829
number of elements retrieved 2069


4802it [13:44,  5.77it/s]

accuracy : 0.8751182592242195
number of elements retrieved 2114


4902it [14:01,  5.91it/s]

accuracy : 0.8730964467005076
number of elements retrieved 2167


5002it [14:18,  5.91it/s]

accuracy : 0.8712291760468257
number of elements retrieved 2221


5102it [14:35,  5.92it/s]

accuracy : 0.870509977827051
number of elements retrieved 2255


5202it [14:52,  5.93it/s]

accuracy : 0.8709677419354839
number of elements retrieved 2294


5302it [15:09,  5.87it/s]

accuracy : 0.8725573491928632
number of elements retrieved 2354


5402it [15:26,  5.90it/s]

accuracy : 0.8736447039199333
number of elements retrieved 2398


5502it [15:43,  5.91it/s]

accuracy : 0.8765281173594132
number of elements retrieved 2454


5602it [16:00,  5.85it/s]

accuracy : 0.8777064955894146
number of elements retrieved 2494


5702it [16:17,  5.89it/s]

accuracy : 0.8800157356412274
number of elements retrieved 2542


5802it [16:34,  5.82it/s]

accuracy : 0.8803716608594657
number of elements retrieved 2583


5902it [16:51,  5.85it/s]

accuracy : 0.8806883365200765
number of elements retrieved 2615


6002it [17:08,  5.89it/s]

accuracy : 0.8807097017742544
number of elements retrieved 2649


6102it [17:25,  5.97it/s]

accuracy : 0.8823311061618412
number of elements retrieved 2694


6202it [17:42,  5.88it/s]

accuracy : 0.8808309037900874
number of elements retrieved 2744


6302it [17:59,  5.89it/s]

accuracy : 0.8813803019410497
number of elements retrieved 2782


6402it [18:16,  5.88it/s]

accuracy : 0.8813257305773343
number of elements retrieved 2806


6502it [18:33,  5.94it/s]

accuracy : 0.8823322795925536
number of elements retrieved 2847


6602it [18:50,  5.92it/s]

accuracy : 0.8826177285318559
number of elements retrieved 2888


6702it [19:07,  5.92it/s]

accuracy : 0.883617747440273
number of elements retrieved 2930


6802it [19:23,  5.90it/s]

accuracy : 0.8836032388663968
number of elements retrieved 2964


6902it [19:40,  5.90it/s]

accuracy : 0.8844370860927152
number of elements retrieved 3020


7002it [19:57,  5.88it/s]

accuracy : 0.8846028113762667
number of elements retrieved 3059


7102it [20:14,  5.96it/s]

accuracy : 0.8848758465011287
number of elements retrieved 3101


7202it [20:31,  5.89it/s]

accuracy : 0.8842607313195548
number of elements retrieved 3145


7302it [20:48,  5.91it/s]

accuracy : 0.8842897460018815
number of elements retrieved 3189


7402it [21:05,  5.88it/s]

accuracy : 0.8841028819336846
number of elements retrieved 3227


7502it [21:22,  5.93it/s]

accuracy : 0.8850398528510116
number of elements retrieved 3262


7602it [21:39,  5.88it/s]

accuracy : 0.8848374354299605
number of elements retrieved 3291


7702it [21:56,  5.92it/s]

accuracy : 0.8852606351108449
number of elements retrieved 3338


7802it [22:13,  5.95it/s]

accuracy : 0.8854939187184812
number of elements retrieved 3371


7902it [22:30,  5.87it/s]

accuracy : 0.8845029239766082
number of elements retrieved 3420


8002it [22:47,  5.89it/s]

accuracy : 0.8849608809040858
number of elements retrieved 3451


8102it [23:04,  5.35it/s]

accuracy : 0.8844827586206897
number of elements retrieved 3480


8202it [23:21,  5.93it/s]

accuracy : 0.8844950213371267
number of elements retrieved 3515


8297it [23:37,  5.85it/s]


In [125]:
import pickle as p

p.dump(verified_negative, open('verified_negative_politics.p','wb'))
p.dump(dataloader, open('dataloader_politics.p','wb'))

In [27]:
######################### TRAINING PART ###############################################

In [None]:
#### DATASET CONSTRUCTION ####

In [41]:
def decode(ids, tokenizer=tokenizer):
    strings = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return strings


In [126]:
new_verified_negative = p.load(open('verified_negative_politics.p','rb'))
new_dataloader = p.load(open('dataloader_politics.p','rb'))

In [127]:
mcp_data = torch.load('mcp_train.pt')
label = torch.LongTensor([1]).repeat(len(mcp_data['labels']))
mcp_data['labels'] = label

In [128]:
#### Statistiques des deux Sets ####
negative_dataset = Subset(new_dataloader.dataset, new_verified_negative)
positive_dataset = torch.utils.data.TensorDataset(mcp_data['input_ids'], mcp_data['attention_masks'], mcp_data['labels'])

## TO DO ##


In [93]:
### STATISTICS ON CORPUS
corpus = open('train.txt', encoding="utf-8")
stopwords_vocab = stopwords.words('english')
lines = corpus.readlines()
words = chain.from_iterable(line.lower().split() for line in lines)
count = Counter(word for word in words if word not in stopwords_vocab)
count.most_common(55)

In [30]:
negative_dataset = Subset(new_dataloader.dataset, new_verified_negative)
positive_dataset = torch.utils.data.TensorDataset(mcp_data['input_ids'], mcp_data['attention_masks'], mcp_data['labels'])


In [129]:
#### Start of the training

In [130]:
target = np.hstack((np.zeros(int(len(negative_dataset)), dtype=np.int32),
                    np.ones(int(len(positive_dataset)), dtype=np.int32)))

class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weigth = samples_weight.double()
target = torch.from_numpy(target).long()

In [131]:
data = torch.stack([negative_data[0][:200] for negative_data in negative_dataset] + 
            [positive_data[0][:200] for positive_data in positive_dataset])

In [132]:
mask = torch.stack([negative_data[1][:200] for negative_data in negative_dataset] + 
            [positive_data[1][:200] for positive_data in positive_dataset])

In [133]:
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

In [134]:
train_dataset = torch.utils.data.TensorDataset(data,mask, target)

In [135]:
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size = batch_size, sampler=sampler)

In [136]:
### TRAINING LOOPS

In [137]:

model = LOTClassModel.from_pretrained('bert-base-uncased',
                                           output_attentions=False,
                                           output_hidden_states=False,
                                           num_labels=2).to(device)
accum_steps = 8
model.train()
epochs = 1
train_loss = nn.CrossEntropyLoss()
total_steps = len(train_loader) * epochs / accum_steps
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)
losses_track = []
try:
    for i in range(epochs):
        model.train()
        total_train_loss = 0
        model.zero_grad()
        for j, batch in enumerate(train_loader):
            input_ids = batch[0].to(device)
            input_mask = batch[1].to(device)
            labels = batch[2].to(device)


            ### RANDOM MASKING
            random_masking = random.choices(list(range(199)),k=batch_size)
            for i, mask_pos in enumerate(random_masking):
                input_ids[i,mask_pos+1] = tokenizer.get_vocab()[tokenizer.mask_token]
            
            ### PREDICTION
            logits = model(input_ids, 
                           pred_mode="classification",
                           token_type_ids=None, 
                           attention_mask=input_mask)
            ### LOSS
            logits_cls = logits[:,0]
            loss = train_loss(logits_cls.view(-1, 2), labels.view(-1)) / accum_steps            
            total_train_loss += loss.item()
            loss.backward()
            if (j+1) % accum_steps == 0:
                # Clip the norm of the gradients to 1.0.
                print(loss*accum_steps)
                losses_track.append(loss*accum_steps)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
            if j % 10 == 0 :
                test(model)
        avg_train_loss = torch.tensor([total_train_loss / len(train_loader) * accum_steps]).to(device)
        print(f"Average training loss: {avg_train_loss.mean().item()}")

except RuntimeError as err:
    print(err)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing LOTClassModel: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing LOTClassModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LOTClassModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LOTClassModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias', 'dense.weight', 'dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predict

Accuracy  0.640625
tensor(0.7087, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.03it/s]


Accuracy  0.703125
tensor(0.6882, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.04it/s]


Accuracy  0.619140625
tensor(0.6369, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.13it/s]


Accuracy  0.6484375
tensor(0.6561, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.6796, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.02it/s]


Accuracy  0.6328125
tensor(0.6825, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.658203125
tensor(0.6421, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.12it/s]


Accuracy  0.669921875
tensor(0.6525, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.07it/s]


Accuracy  0.666015625
tensor(0.6039, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.6513, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.744140625
tensor(0.5866, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.01it/s]


Accuracy  0.75390625
tensor(0.5831, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.10it/s]


Accuracy  0.7890625
tensor(0.5620, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.779296875
tensor(0.5717, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.5439, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.80859375
tensor(0.5009, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.841796875
tensor(0.5708, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.09it/s]


Accuracy  0.822265625
tensor(0.5623, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.09it/s]


Accuracy  0.8359375
tensor(0.5491, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.5320, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.10it/s]


Accuracy  0.798828125
tensor(0.5086, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.10it/s]


Accuracy  0.81640625
tensor(0.5521, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.09it/s]


Accuracy  0.828125
tensor(0.4803, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.10it/s]


Accuracy  0.83984375
tensor(0.4790, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.4496, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.99it/s]


Accuracy  0.830078125
tensor(0.4735, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.05it/s]


Accuracy  0.841796875
tensor(0.4460, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.81640625
tensor(0.4971, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.861328125
tensor(0.4522, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.4925, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.12it/s]


Accuracy  0.859375


IndexError: index 15 is out of bounds for dimension 0 with size 15

In [None]:
# import pickle as p
# p.dump(verified_negative, open('verified_negative.p','wb'))
# p.dump(dataloader, open('dataloader.p','wb'))

In [None]:
########## TRAINING #########

In [46]:
test(model=model, all=True)

100%|██████████| 3750/3750 [20:06<00:00,  3.11it/s]

Accuracy  0.8578916666666667





0.8578916666666667

In [58]:
k = random.sample(range(120000), k=1)[0]
x, m = encode(docs[k])
label = list_label[k]
pred = model(x.to(device), attention_mask = m.to(device), pred_mode='classification')
print(torch.argmax(pred[0,0,:]))
print(label)

tensor(0, device='cuda:0')
3


In [61]:
227.73*512/120000

0.971648

tensor(0, device='cuda:0')
2


### top k = top 8 -> ~0.8 accuracy, 400 datapoints