In [1]:
import torch
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
import numpy as np
import gc
from joblib import Parallel, delayed
import random
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
from torch.utils.data import Subset
import pickle as p
from tqdm import tqdm
import os

positive_label = [0,2,3]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
vocab = tokenizer.get_vocab()
inv_vocab = {k:v for v, k in vocab.items()}

In [3]:
data_vocab = torch.load("category_vocab.pt")
label_data = torch.load("label_name_data.pt")
train_data = torch.load('train.pt')

In [6]:
corpus = open('train.txt', encoding="utf-8")
true_labels = open('train_labels.txt', encoding="utf-8")
docs_labels = [doc.strip() for doc in true_labels.readlines()]
dict_label = {0:[], 1:[], 2:[],3:[]}
list_label = [int(label) for label in docs_labels]
for i, label in enumerate(docs_labels):
    dict_label[int(label)].append(i)
docs = [doc.strip() for doc in corpus.readlines()]

In [59]:
def test(number = 512, test_batch_size = 32,docs = docs, model=model, all = False, true_label = 1):
    model.eval()
    correct_pred = 0
    if all:
        test_list = list(range(len(docs)))
    else:
        test_list = random.sample(list(range(len(docs))), k = number)
    inputs = torch.stack([encode(docs[i])[0].squeeze() for i in test_list])
    attention_mask = torch.stack([encode(docs[i])[1].squeeze() for i in test_list])
    true_labels = torch.stack([torch.tensor(int(list_label[i]==true_label)) for i in test_list])
    test_dataset = TensorDataset(inputs, attention_mask, true_labels)
    test_dataloader = DataLoader(test_dataset, batch_size = test_batch_size)
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            inputs_test, attention_test, labels_test = batch
            logits = model(inputs_test.to(device),attention_mask=attention_test.to(device), pred_mode='classification')
            logits_cls = logits[:,0]
            prediction = torch.argmax(logits_cls, -1)
            correct_pred += (labels_test == prediction.cpu()).sum().item()
        accuracy = correct_pred / number
    print("Accuracy ", accuracy)
    model.train()
    return accuracy
        

In [7]:
category_vocab = []
for k in data_vocab.keys():
    category_vocab += list(data_vocab[k])

In [8]:
# Creer la liste de mots positifs
list_pos_keyword = []
for w in category_vocab:
    list_pos_keyword.append(inv_vocab[w])

In [11]:
from tqdm import tqdm
negative_doc=[]
negative_doc_label = []
for k, doc in tqdm(enumerate(docs)):
    tokenized_doc = tokenizer.tokenize(doc)
    new_doc = []
    wordpcs = []
    label_idx = -1 * torch.ones(512, dtype=torch.long)
    for idx, wordpc in enumerate(tokenized_doc):
        wordpcs.append(wordpc[2:] if wordpc.startswith("##") else wordpc)
        if idx >= 512 - 1: # last index will be [SEP] token
            break
        if idx == len(doc) - 1 or not doc[idx+1].startswith("##"):
            word = ''.join(wordpcs)
            if word in list_pos_keyword:
                label_idx[idx] = 0
                break
                # replace label names that are not in tokenizer's vocabulary with the [MASK] token
    #             if word not in vocab:
    #                 wordpcs = [tokenizer.mask_token]
            new_word = ''.join(wordpcs)
            if new_word != tokenizer.unk_token:
                idx += len(wordpcs)
                new_doc.append(new_word)
            wordpcs = []
    if (label_idx>=0).any():
        continue
    else:
        negative_doc_label.append(list_label[k])
        negative_doc.append(doc)
    

120000it [01:04, 1848.19it/s]


In [12]:
print("Negative pre-set", len(negative_doc))
print("Accuracy pre-set, ", len([k for k in negative_doc_label if k not in positive_label])/len(negative_doc_label))

Negative pre-set 29892
Accuracy pre-set,  0.4458718051652616


In [8]:
from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_bert import BertOnlyMLMHead
from torch import nn
import sys



def encode(docs, tokenizer = tokenizer):
    encoded_dict = tokenizer.encode_plus(docs, add_special_tokens=True, max_length=200, padding='max_length',
                                                    return_attention_mask=True, truncation=True, return_tensors='pt')
    input_ids = encoded_dict['input_ids']
    attention_masks = encoded_dict['attention_mask']
    return input_ids, attention_masks


class LOTClassModel(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
        # MLM head is not trained
        for param in self.cls.parameters():
            param.requires_grad = False
    
    def forward(self, input_ids, pred_mode, attention_mask=None, token_type_ids=None, 
                position_ids=None, head_mask=None, inputs_embeds=None):
        bert_outputs = self.bert(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask,
                                 inputs_embeds=inputs_embeds)
        last_hidden_states = bert_outputs[0]
        if pred_mode == "classification":
            trans_states = self.dense(last_hidden_states)
            trans_states = self.activation(trans_states)
            trans_states = self.dropout(trans_states)
            logits = self.classifier(trans_states)
        elif pred_mode == "mlm":
            logits = self.cls(last_hidden_states)
        else:
            sys.exit("Wrong pred_mode!")
        return logits
    
    
model = LOTClassModel.from_pretrained('bert-base-uncased',
                                           output_attentions=False,
                                           output_hidden_states=False,
                                           num_labels=2).to('cuda')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing LOTClassModel: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing LOTClassModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LOTClassModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LOTClassModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias', 'dense.weight', 'dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predict

In [14]:
inputs_list = []
masks_list = []
for doc in tqdm(negative_doc):
    input_ids, input_mask = encode(doc)
    inputs_list.append(input_ids)
    masks_list.append(input_mask)

100%|██████████| 29892/29892 [00:16<00:00, 1858.49it/s]


In [15]:
input_tensor = torch.stack(inputs_list).squeeze()
mask_tensor = torch.stack(masks_list).squeeze()
label_tensor = torch.stack([torch.tensor(i).unsqueeze(0) for i in negative_doc_label])
dataset = torch.utils.data.TensorDataset(input_tensor,mask_tensor, label_tensor)
dataloader = torch.utils.data.DataLoader(dataset, shuffle = False, batch_size = 8)

In [16]:
def intersect_tensor(t1, t2, device = 'cuda', mask = None):    
    indices = torch.zeros_like(t1, dtype = torch.uint8, device = device)
    for elem in t2:
        indices = indices | (t1 == elem) 
        indices = indices.to(bool)
        
    if mask is not None:
        indices = indices * mask 
    intersection = t1[indices]  
    return intersection, indices


In [17]:
def count_similar_words(batch, category_vocab = category_vocab):
    prediction= batch[0]
    input_mask = batch[1]
    masked_pred = prediction[:input_mask.sum().item(),:]
    _ , words = torch.topk(masked_pred, 8, -1)
    counter = 0
    for word in words.squeeze():
        counter += int(len(np.intersect1d(word.numpy(), category_vocab))>0)
        intersect_time = time() - intersect_time_start
        if counter > 0:
            print('break')
            return False
            break
    return True
            
def occurences(word, vocab = category_vocab):
    return len(np.intersect1d(word.cpu().numpy(), vocab))

In [27]:
# #### Get negative set

# from time import time
# verified_negative = []
# correct_label = 0
# verbose = False
# topk = 25
# vocab = torch.tensor(category_vocab).to(device)
# min_similar_words = 1
# max_category_word = 0
# num_cpus = 8
# with torch.no_grad():
#     for k, batch in tqdm(enumerate(dataloader)):
#         start_time = time()
#         input_ids, input_mask, label_id = batch
#         predictions = model(input_ids.to(device),
#                         pred_mode="mlm",
#                         token_type_ids=None, 
#                         attention_mask=input_mask.to(device))
#         end_prediction_time = time()
        
        
#     ########### GPU    
# #         intersection, indices = intersect_tensor(torch.topk(predictions,topk,-1)[1],vocab, 
# #                                                 mask = input_mask.unsqueeze(2).repeat(1,1,topk).to(device))
# #         counts_similar_word_per_word = indices.sum(-1) 
        
# #         counts = (counts_similar_word_per_word>min_similar_words).sum(-1)
# #         end_intersection_time = time()
        
# #         indices_count = torch.where(counts<=0)[0]
# #         for j in indices_count:
# #             i = j.item()

#     ################## CPU ####################""
# #         counts = Parallel(n_jobs=num_cpus)(delayed(count_similar_words)(batch) for batch in zip(predictions.cpu(), input_mask))
    
    
    
#         for i, doc in enumerate(predictions.cpu()):
#             start_loop = time()
#             masked_pred = doc[:input_mask[i].sum().item(),:]
#             _ , words = torch.topk(masked_pred, topk, -1)
#             counter = 0
#             topk_time = time()
            
# #             counts = Parallel(n_jobs=num_cpus)(delayed(occurences)(word) 
# #                                                    for word in words.squeeze())
# #             counter += len(np.where(np.array(counts)>min_similar_words)[0])
#             for word in words.squeeze():
# #                 counter += int(len(intersect_tensor(word, vocab))>0)
#                 counter += int(len(np.intersect1d(word.cpu().numpy(), category_vocab))>min_similar_words)
#                 intersect_time = time() - intersect_time_start
#                 if counter > max_category_word:
#                     break

# #             j = i.item()
# #         for i in np.where(np.array(counts))[0]:

#             if counter <= max_category_word:             
#                 verified_negative.append(k*4+i)
#                 if label_id[i] not in positive_label:
#                     correct_label += 1 
            
#         end_counting_time = time()    
#         if k%100 == 0:
#             if len(verified_negative)>0:
#                 print('accuracy :', correct_label/len(verified_negative))
#                 print('number of elements retrieved', len(verified_negative))
# #         if verbose:
# #             print('Prediction time', end_prediction_time-start_time) 
# # #             print('bascule cpu', start_loop-end_prediction_time)
# #             print('Intersection time', end_intersection_time-end_prediction_time)
# # #             print('topk time', topk_time-start_loop)
# #             print('counting time', end_counting_time-end_intersection_time)
# #         del predictions
# #         gc.collect()
# #         torch.cuda.empty_cache()
        
        
        
    

101it [01:25,  1.19it/s]

accuracy : 0.8571428571428571
number of elements retrieved 7


155it [02:12,  1.17it/s]


KeyboardInterrupt: 

In [11]:
# import pickle as p

# p.dump(verified_negative, open('verified_negative_sports.p','wb'))
# p.dump(dataloader, open('dataloader_sports.p','wb'))

In [None]:
######################### TRAINING PART ###############################################

In [None]:
#### DATASET CONSTRUCTION ####

In [10]:
new_verified_negative = p.load(open('verified_negative.p','rb'))
new_dataloader = p.load(open('dataloader.p','rb'))

In [11]:
mcp_data = torch.load('mcp_train.pt')
label = torch.LongTensor([1]).repeat(len(mcp_data['labels']))
mcp_data['labels'] = label

In [12]:
negative_dataset = Subset(new_dataloader.dataset, new_verified_negative)
positive_dataset = torch.utils.data.TensorDataset(mcp_data['input_ids'], mcp_data['attention_masks'], mcp_data['labels'])


In [13]:
target = np.hstack((np.zeros(int(len(negative_dataset)), dtype=np.int32),
                    np.ones(int(len(positive_dataset)), dtype=np.int32)))

class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weigth = samples_weight.double()
target = torch.from_numpy(target).long()

In [14]:
data = torch.stack([negative_data[0][:200] for negative_data in negative_dataset] + 
            [positive_data[0][:200] for positive_data in positive_dataset])

In [15]:
mask = torch.stack([negative_data[1][:200] for negative_data in negative_dataset] + 
            [positive_data[1][:200] for positive_data in positive_dataset])

In [16]:
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

In [17]:
train_dataset = torch.utils.data.TensorDataset(data,mask, target)

In [18]:
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size = batch_size, sampler=sampler)

In [None]:
### TRAINING LOOPS

In [19]:

model = LOTClassModel.from_pretrained('bert-base-uncased',
                                           output_attentions=False,
                                           output_hidden_states=False,
                                           num_labels=2).to(device)
accum_steps = 8
model.train()
epochs = 1
train_loss = nn.CrossEntropyLoss()
total_steps = len(train_loader) * epochs / accum_steps
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)
losses_track = []
try:
    for i in range(epochs):
        model.train()
        total_train_loss = 0
        model.zero_grad()
        for j, batch in enumerate(train_loader):
            input_ids = batch[0].to(device)
            input_mask = batch[1].to(device)
            labels = batch[2].to(device)


            ### RANDOM MASKING
            random_masking = random.choices(list(range(199)),k=batch_size)
            for i, mask_pos in enumerate(random_masking):
                input_ids[i,mask_pos+1] = tokenizer.get_vocab()[tokenizer.mask_token]
            
            ### PREDICTION
            logits = model(input_ids, 
                           pred_mode="classification",
                           token_type_ids=None, 
                           attention_mask=input_mask)
            ### LOSS
            logits_cls = logits[:,0]
            loss = train_loss(logits_cls.view(-1, 2), labels.view(-1)) / accum_steps            
            total_train_loss += loss.item()
            loss.backward()
            if (j+1) % accum_steps == 0:
                # Clip the norm of the gradients to 1.0.
                print(loss*accum_steps)
                losses_track.append(loss*accum_steps)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
            if j % 10 == 0 :
                test()
        avg_train_loss = torch.tensor([total_train_loss / len(train_loader) * accum_steps]).to(device)
        print(f"Average training loss: {avg_train_loss.mean().item()}")

except RuntimeError as err:
    print(err)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing LOTClassModel: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing LOTClassModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LOTClassModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LOTClassModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias', 'dense.weight', 'dense.bias', 'classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predict

Accuracy  0.41015625
tensor(0.6924, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.96it/s]


Accuracy  0.390625
tensor(0.7166, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.04it/s]


Accuracy  0.3828125
tensor(0.6924, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.00it/s]


Accuracy  0.39453125
tensor(0.6705, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.6785, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.99it/s]


Accuracy  0.37109375
tensor(0.6506, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.96it/s]


Accuracy  0.349609375
tensor(0.6073, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.06it/s]


Accuracy  0.416015625
tensor(0.6605, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.00it/s]


Accuracy  0.4765625
tensor(0.5990, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.5902, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.95it/s]


Accuracy  0.654296875
tensor(0.5558, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.05it/s]


Accuracy  0.736328125
tensor(0.5541, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.04it/s]


Accuracy  0.7890625
tensor(0.5121, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.04it/s]


Accuracy  0.8515625
tensor(0.4840, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.4897, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.07it/s]


Accuracy  0.92578125
tensor(0.4963, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.09it/s]


Accuracy  0.91796875
tensor(0.3918, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.08it/s]


Accuracy  0.953125
tensor(0.4688, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.09it/s]


Accuracy  0.9609375
tensor(0.3814, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.3585, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.10it/s]


Accuracy  0.953125
tensor(0.2797, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.09it/s]


Accuracy  0.966796875
tensor(0.3243, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.06it/s]


Accuracy  0.95703125
tensor(0.3700, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.93it/s]


Accuracy  0.962890625
tensor(0.3145, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.3599, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.97it/s]


Accuracy  0.953125
tensor(0.3279, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.07it/s]


Accuracy  0.955078125
tensor(0.2930, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.10it/s]


Accuracy  0.970703125
tensor(0.1995, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.970703125
tensor(0.2089, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.1774, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.12it/s]


Accuracy  0.95703125
tensor(0.1930, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.07it/s]


Accuracy  0.97265625
tensor(0.2282, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.08it/s]


Accuracy  0.958984375
tensor(0.1713, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.10it/s]


Accuracy  0.958984375
tensor(0.1886, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.1567, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.9609375
tensor(0.1194, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.12it/s]


Accuracy  0.978515625
tensor(0.2339, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.12it/s]


Accuracy  0.966796875
tensor(0.1345, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.11it/s]


Accuracy  0.95703125
tensor(0.1123, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.1248, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.09it/s]


Accuracy  0.982421875
tensor(0.2491, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.07it/s]


Accuracy  0.986328125
tensor(0.1834, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.06it/s]


Accuracy  0.9609375
tensor(0.1234, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.06it/s]


Accuracy  0.958984375
tensor(0.2068, device='cuda:0', grad_fn=<MulBackward0>)
tensor(0.1264, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  3.05it/s]


Accuracy  0.984375
tensor(0.1303, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 16/16 [00:05<00:00,  2.99it/s]


Accuracy  0.970703125
tensor(0.1651, device='cuda:0', grad_fn=<MulBackward0>)


 56%|█████▋    | 9/16 [00:03<00:02,  2.72it/s]


KeyboardInterrupt: 

In [None]:
# import pickle as p
# p.dump(verified_negative, open('verified_negative.p','wb'))
# p.dump(dataloader, open('dataloader.p','wb'))

In [None]:
########## TRAINING #########

In [60]:
test(all=True)

100%|██████████| 3750/3750 [20:03<00:00,  3.12it/s]

Accuracy  227.732421875





227.732421875

In [58]:
k = random.sample(range(120000), k=1)[0]
x, m = encode(docs[k])
label = list_label[k]
pred = model(x.to(device), attention_mask = m.to(device), pred_mode='classification')
print(torch.argmax(pred[0,0,:]))
print(label)

tensor(0, device='cuda:0')
3


In [61]:
227.73*512/120000

0.971648

tensor(0, device='cuda:0')
2


### top k = top 8 -> ~0.8 accuracy, 400 datapoints