# Check which tokens the neural-network method masks

In [1]:
# Environment variables go first: transformers reads TRANSFORMERS_CACHE once, when it
# is first imported, so setting it after `from transformers import ...` has no effect.
# CUDA_LAUNCH_BLOCKING likewise only matters if set before CUDA is initialised.
import os

os.environ['TRANSFORMERS_CACHE'] = './cache/'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import copy
import pickle
import sys

import torch
import torch.nn as nn
from IPython.display import clear_output
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForTokenClassification, BertTokenizerFast

# Project helpers (src/) live one directory up. The star imports supply
# get_freer_gpu and may shadow some of the names imported above.
sys.path.append('..')
from src.dataset import *
from src.utils   import *
from src.traineval  import *

SEED = 42
torch.manual_seed(SEED)  # seeds all devices; also drives DataLoader(shuffle=True)
torch.backends.cudnn.deterministic = True

device = get_freer_gpu()
print('device', device)

device cuda:0


In [None]:
import pickle

# NOTE(review): this is the *test* split of the domain chunks, but it is bound to
# `train_dataset` because later cells refer to that name.
DOMAIN_DS_PATH = '../data/domain/domainchunk-R10-test-ds.pkl'

# Context manager so the file handle is closed once the dataset is loaded.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open(DOMAIN_DS_PATH, "rb") as ds_file:
    train_dataset = pickle.load(ds_file)

# Peek at the fields of the first sample only.
for sample in train_dataset:
    for field in ('input_ids', 'word_ids', 'attention_mask', 'orig_text', 'labels'):
        print(sample[field])
    break

In [3]:
# The loader only feeds model inference, so sample order is irrelevant and
# shuffling is harmless.
batch_size = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
# Token-classification model with a 2-way head (trained with cross-entropy).
num_labels = 2  # CE
checkpoint = 'bert-base-uncased'
model = AutoModelForTokenClassification.from_pretrained(checkpoint, num_labels=num_labels).to(device)
# NOTE(review): in HF's BertForTokenClassification the dropout used in forward is
# `model.dropout`; `model.classifier` is an nn.Linear that never reads a `.dropout`
# attribute, so this assignment has no effect on the forward pass. Dropout has no
# parameters, so it does not affect load_state_dict, and model.eval() disables it anyway.
model.classifier.dropout = nn.Dropout(p=0.1, inplace=False)

# load model from trained model
path = '../save/NN-02-classitoken-round2/best-model-4200.tar'
print("Load model from : ", path)

# map_location loads tensors straight onto `device`, whatever GPU they were saved from.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
loaded_checkpoint = torch.load(path, map_location=device)
# Load once and keep the result (previously the state dict was loaded twice).
load_result = model.load_state_dict(loaded_checkpoint['model_state_dict'])
print(load_result)  # expect <All keys matched successfully>

model.eval()

tokenizer = BertTokenizerFast.from_pretrained(checkpoint)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForTokenClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-u

Load model from :  ../save/NN-02-classitoken-round2/best-model-4200.tar
<All keys matched successfully>


In [7]:
class NNMLMDataset(Dataset):
    """Dataset built by running a trained token-classification model over another dataset.

    Positions the model predicts as label 1 are treated as "important" tokens. For every
    batch of ``data_loader`` the input ids at those positions (padding excluded) are
    stored in ``list_important_input_ids`` as one 1-D CPU tensor per batch.

    With ``build_mlm=True`` the flagged tokens are also turned into masked-LM samples
    (flagged positions -> [MASK]; labels hold the original id there and -100 elsewhere),
    which ``__len__`` / ``__getitem__`` expose. With the default ``build_mlm=False``
    those per-sample lists stay empty, so ``len(dataset) == 0``.

    Args:
        data_loader: iterable of batch dicts with 'input_ids' and 'attention_mask'
            tensors (plus 'word_ids' and 'orig_text' when ``build_mlm`` is set).
        classitoken_model: token-classification model whose output has ``.logits``
            of shape (batch, seq_len, num_labels).
        tokenizer: tokenizer exposing mask/cls/sep token ids and ``decode``.
        build_mlm: also build the masked-LM samples (default False).
    """

    def __init__(self, data_loader, classitoken_model, tokenizer, build_mlm=False):
        self.tokenizer = tokenizer
        # Read the special-token ids directly rather than tokenising the token strings.
        self.model_mask_id = tokenizer.mask_token_id
        self.model_cls_id = tokenizer.cls_token_id
        self.model_sep_id = tokenizer.sep_token_id
        self.make_NN_MLM_ds(data_loader, classitoken_model, build_mlm=build_mlm)

    def __len__(self):
        return len(self.list_input_ids)

    def __getitem__(self, idx):
        return {'input_ids'      : self.list_input_ids[idx],
                'word_ids'       : self.list_word_ids[idx],
                'attention_mask' : self.list_attention_mask[idx],
                'orig_text'      : self.list_orig_text[idx],
                'masked_text'    : self.list_masked_text[idx],
                'labels'         : self.list_labels[idx]}

    @staticmethod
    def _model_device(model):
        """Device of the model's first parameter, or CPU for a parameter-less model."""
        first_param = next(model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def make_NN_MLM_ds(self, data_loader, model, build_mlm=False):
        """Run ``model`` over ``data_loader`` and (re)fill all of the dataset's lists.

        Args:
            data_loader: iterable of batch dicts (see the class docstring).
            model: token-classification model; label 1 marks an important token.
            build_mlm: also build the per-sample masked-LM lists.
        """
        self.list_input_ids      = []
        self.list_word_ids       = []
        self.list_attention_mask = []
        self.list_orig_text      = []
        self.list_masked_text    = []
        self.list_labels         = []

        self.list_important_input_ids = []

        # Send inputs to wherever the model lives instead of relying on a global `device`.
        model_device = self._model_device(model)
        model.eval()

        for idx, batch in enumerate(data_loader):
            sys.stdout.write(f'{idx} ')  # lightweight progress: one index per batch
            sys.stdout.flush()

            input_ids = batch['input_ids'].to(model_device)
            att_mask = batch['attention_mask'].to(model_device)

            with torch.no_grad():
                logits = model(input_ids=input_ids, attention_mask=att_mask).logits

            # argmax of the logits equals argmax of their softmax, so no softmax is needed.
            # Padding positions are not real tokens and must not be counted as flagged.
            important = (logits.argmax(dim=2) == 1) & att_mask.bool()

            # Boolean-mask indexing keeps row-major order (batch row, then position).
            self.list_important_input_ids.append(input_ids[important].cpu())

            if build_mlm:
                self._append_masked_samples(input_ids, att_mask, important,
                                            batch['word_ids'], batch['orig_text'])
        sys.stdout.write('\n')

        assert (len(self.list_input_ids) == len(self.list_word_ids)
                == len(self.list_attention_mask) == len(self.list_orig_text)
                == len(self.list_masked_text) == len(self.list_labels))

    def _append_masked_samples(self, input_ids, att_mask, important, word_ids, orig_text):
        """Mask the flagged tokens of one batch and append one MLM sample per row."""
        to_mask = important.clone()
        # Never mask the first / last position (expected to hold [CLS] / [SEP]) nor any
        # [CLS] / [SEP] token elsewhere (e.g. before padding).
        to_mask[:, 0] = False
        to_mask[:, -1] = False
        to_mask &= (input_ids != self.model_cls_id) & (input_ids != self.model_sep_id)

        masked_input_ids = input_ids.clone()
        masked_input_ids[to_mask] = self.model_mask_id

        # Labels: original id at masked positions, -100 (not predicted) everywhere else.
        masked_labels = torch.full_like(input_ids, -100)
        masked_labels[to_mask] = input_ids[to_mask]

        for row in range(input_ids.shape[0]):
            row_ids = masked_input_ids[row].cpu()
            self.list_input_ids.append(row_ids)
            # NOTE(review): assumes the collated 'word_ids' is a tensor -- confirm.
            self.list_word_ids.append(word_ids[row].clone())
            self.list_attention_mask.append(att_mask[row].cpu())
            self.list_orig_text.append(orig_text[row])
            self.list_masked_text.append(self.tokenizer.decode(row_ids))
            self.list_labels.append(masked_labels[row].cpu())

In [8]:
# Run the trained token classifier over every batch and collect the flagged tokens.
my_train_dataset = NNMLMDataset(
    data_loader=train_loader,
    classitoken_model=model,
    tokenizer=tokenizer,
)

012345678910

In [12]:
from collections import Counter

# How often each decoded token was flagged by the model across the whole dataset.
# Counter is a dict subclass that keeps first-seen insertion order, so later cells can
# use .items() exactly as with a plain dict. Ids are converted to Python ints in one
# .tolist() call per batch instead of decoding one scalar tensor at a time.
masked_token_freq = Counter(
    tokenizer.decode([token_id])
    for batch_ids in my_train_dataset.list_important_input_ids
    for token_id in batch_ids.tolist()
)

In [13]:
# (token, count) pairs from most to least frequent; the sort is stable, so ties keep
# the order in which tokens were first seen.
keyword = sorted(
    masked_token_freq.items(),
    key=lambda token_count: token_count[1],
    reverse=True,
)

# Just the token strings, in the same order.
masked_words = [token for token, _count in keyword]

[('you', 45), ('i', 41), ('that', 29), ('is', 29), ('and', 27), ('it', 26), ('to', 23), ('in', 20), ('a', 19), ('of', 18), ('for', 17), ('the', 16), ('this', 16), ('if', 16), (',', 16), ('u', 16), ("'", 15), ('are', 15), ('have', 15), ('t', 15), ('can', 14), ('there', 13), ('me', 13), ('with', 13), ('not', 13), ('ru', 13), ('##u', 12), ('nu', 12), ('up', 11), ('##s', 10), ('try', 10), ('know', 10), ('on', 9), ('your', 9), ('time', 9), ('all', 9), ('but', 9), ('use', 9), ('just', 9), ('thanks', 9), ('was', 8), ('would', 8), ('level', 8), ('?', 8), ('some', 7), ('want', 7), ('also', 7), ('then', 7), ('##y', 7), ('like', 7), ('be', 7), ('so', 7), ('at', 7), ('##l', 7), ('##13', 7), ('will', 6), ('different', 6), ('from', 6), ('no', 6), ('my', 6), ('good', 6), ('people', 6), ('do', 6), ('##kk', 6), ('don', 6), ('when', 6), ('who', 6), ('thing', 6), ('ice', 6), ('let', 6), ('##31', 6), ('has', 5), ('think', 5), ('##i', 5), ('about', 5), ('or', 5), ('see', 5), ('one', 5), ('app', 5), ('make'

In [11]:
# One token per line, in the order of `masked_words`. UTF-8 is set explicitly because
# decoded tokens may be non-ASCII and the platform's default encoding may not cover them.
with open("./nn_masked_words.txt", "w", encoding="utf-8") as f:
    f.writelines(word + "\n" for word in masked_words)