In [None]:
from tqdm import tqdm
import json
import os
import numpy as np
import nltk
import random
from nltk import pos_tag
from sklearn.model_selection import train_test_split
from transformers import BertModel, BertTokenizerFast
import torch
import datasets

In [None]:
# Reproducibility: seed every RNG this notebook touches.
# `random` drives the noun / negative sampling below.
# NumPy's global RNG is what sklearn's `train_test_split` falls back to when it is
# called without `random_state`, so it must be seeded too.
random_seed = 1234
torch.manual_seed(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)

data_dir = 'data'
max_len = 128  # max token length used when tokenizing the saved dataset (preprocess_function)

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
sp_token = '[MASK]'  # placeholder written in place of the selected word in each sentence
# The embedding step finds the placeholder via `tokenizer.mask_token_id`, so the two must agree.
assert sp_token == tokenizer.mask_token, f"sp_token {sp_token!r} != tokenizer mask token {tokenizer.mask_token!r}"


def _load_json(name):
    """Load a UTF-8 encoded JSON file from `data_dir`."""
    with open(os.path.join(data_dir, name), 'r', encoding='utf-8') as file:
        return json.load(file)


target_keywords = _load_json('target_keywords.json')
other_sentences = _load_json('other_sentences.json')
target_sentences = _load_json('target_sentences.json')


ns = [5]  # R_stms: number of non-target nouns sampled (and masked) per target sentence
print(f"R_STMS is {ns[0]}")

### Making STMS data

In [None]:
# Build the STMS (positive) examples and the in-sentence hard negatives.
# For every target sentence:
#   * each occurrence of a target keyword is masked on its own -> one positive example;
#   * for each n in `ns`, up to n non-keyword alphanumeric nouns are sampled and each is
#     masked on its own -> hard negatives drawn from the same sentence.
# NOTE(review): hard negatives for every n in `ns` are appended to the same lists, so with
# more than one value in `ns` they would be mixed together in every saved dataset.
word_tokenizer = nltk.WordPunctTokenizer()  # stateless regex tokenizer; build once, not per sentence

masked_target_sen, replaced_words, masked_target_hard_neg, replaced_words_hard_neg = [], [], [], []
for sen in tqdm(target_sentences):
    words = word_tokenizer.tokenize(sen)
    noun_inds = []
    target_inds = []
    for ind, (token, tag) in enumerate(nltk.pos_tag(words)):
        is_target = token.lower() in target_keywords
        if tag.startswith('N') and not is_target and token.isalnum():
            noun_inds.append(ind)
        elif is_target:
            target_inds.append(ind)

    # Positives: mask each target-keyword occurrence separately.
    for target_ind in target_inds:
        words_ = words.copy()
        replaced_words.append(words_[target_ind])  # store the replaced word
        words_[target_ind] = sp_token  # substitute the target keyword with the SP token
        masked_target_sen.append(" ".join(words_))

    # Hard negatives: mask sampled non-keyword nouns, one per example.
    for n in ns:
        if len(noun_inds) >= n:
            change_inds = random.sample(noun_inds, n)
        else:
            change_inds = noun_inds  # fewer than n candidates: use all of them (no RNG draw)

        for change_ind in change_inds:
            words_ = words.copy()
            replaced_words_hard_neg.append(words[change_ind])
            words_[change_ind] = sp_token  # substitute a non-target noun with the SP token
            masked_target_hard_neg.append(" ".join(words_))
len(masked_target_sen), len(replaced_words), len(masked_target_hard_neg), len(replaced_words_hard_neg)

### Making non-STMS data

In [None]:
# Build the non-STMS examples: from each "other" sentence, sample up to `n`
# alphanumeric nouns and mask each one on its own (one masked noun per example).
word_tokenizer = nltk.WordPunctTokenizer()  # stateless regex tokenizer; build once, not per sentence

masked_sentences_other, replaced_words_other = [], []
n = 1  # nouns masked per other sentence
for sen in tqdm(other_sentences):
    words = word_tokenizer.tokenize(sen)
    noun_inds = [ind for ind, (token, tag) in enumerate(nltk.pos_tag(words))
                 if tag.startswith('N') and token.isalnum()]
    if len(noun_inds) >= n:
        change_inds = random.sample(noun_inds, n)
    else:
        change_inds = noun_inds  # fewer than n nouns: use all of them (no RNG draw)
    for change_ind in change_inds:
        words_ = words.copy()
        replaced_words_other.append(words[change_ind])
        words_[change_ind] = sp_token  # substitute the sampled noun with the SP token
        masked_sentences_other.append(" ".join(words_))
len(masked_sentences_other), len(replaced_words_other)

### Make Dataset

In [None]:
def preprocess_function(examples):
    """Tokenize the "text" column of a batch for `datasets.Dataset.map(..., batched=True)`.

    Uses the notebook-wide `tokenizer` and truncates each sequence to `max_len` tokens.
    """
    batch_texts = examples["text"]
    return tokenizer(batch_texts, max_length=max_len, truncation=True)

def get_target_word_embeddings(sentences_with_mask, target_words):
    """Return the contextual embedding of each target word inside its sentence.

    Each sentence in `sentences_with_mask` is expected to contain one tokenizer mask
    token (the `sp_token` placeholder). The mask is replaced by the matching entry of
    `target_words`, and the filled-in sentences are run through the module-level
    `model` on `device`. The last-hidden-state vector at the mask's position is then
    taken. The mask is whitespace-separated (sentences are built with " ".join), so the
    tokens before it should tokenize identically in both versions. That position should
    therefore hold the first word piece of the target word.

    Args:
        sentences_with_mask: sequence of str, each containing the mask token.
        target_words: sequence of str, same length as `sentences_with_mask`.

    Returns:
        CPU tensor with one row per sentence (assumes `last_hidden_state` is
        (batch, seq, hidden), so rows have length hidden_size).

    Raises:
        ValueError: if the inputs differ in length, or a sentence has no mask token
            after tokenization (e.g. it was truncated away).
    """
    sentences_with_mask = list(sentences_with_mask)
    target_words = list(target_words)
    if len(sentences_with_mask) != len(target_words):
        raise ValueError(
            f"got {len(sentences_with_mask)} sentences but {len(target_words)} target words")
    if not sentences_with_mask:
        # Empty chunk: nothing to embed, and tokenizer/model calls on an empty batch fail.
        return torch.empty((0, model.config.hidden_size))

    # Replace the mask placeholder with the target word.
    mask_token = tokenizer.mask_token
    sentences = [sentence.replace(mask_token, target_word)
                 for sentence, target_word in zip(sentences_with_mask, target_words)]

    # Locate the (first) mask token in each masked sentence's tokenization.
    masked_ids = tokenizer(sentences_with_mask, return_tensors='pt', padding=True, truncation=True)['input_ids']
    mask_positions = []
    for row, input_ids in enumerate(masked_ids):
        hits = torch.nonzero(input_ids == tokenizer.mask_token_id, as_tuple=True)[0]
        if hits.numel() == 0:
            raise ValueError(
                f"no mask token found in sentence {row} after tokenization "
                f"(was it truncated away?): {sentences_with_mask[row][:80]!r}")
        mask_positions.append(hits[0].item())
    mask_positions = torch.tensor(mask_positions, dtype=torch.long, device=device)

    # Run the model on the sentences containing the actual target words.
    with torch.no_grad():
        tokenized_inputs = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = model(**tokenized_inputs)
        embeddings = outputs.last_hidden_state.detach()

    # Pick, per sentence, the hidden state at its mask position.
    batch_rows = torch.arange(embeddings.size(0), device=embeddings.device)
    target_embeddings = embeddings[batch_rows, mask_positions].cpu()

    # Explicitly release large objects so GPU memory is freed between chunks.
    del outputs, embeddings, tokenized_inputs
    torch.cuda.empty_cache()

    return target_embeddings

In [None]:
# Choose the first GPU if one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Load the BERT checkpoint stored under data_dir/BERT_pretrained_reddit in eval mode
# (it is only used for embedding extraction), then move it to the chosen device.
checkpoint_dir = os.path.join(data_dir, 'BERT_pretrained_reddit')
model = BertModel.from_pretrained(checkpoint_dir).eval()
model = model.to(device)

In [None]:
# R_nonSTMS: number of non-STMS examples sampled per STMS (positive) example.
ratios = [2]
print("R_nonSTMS is {}".format(ratios[0]))

In [None]:
def _embed_in_chunks(texts, target_words, n_chunks=1000):
    """Extract target-word embeddings chunk by chunk to bound GPU memory.

    Chunk boundaries are those of `np.array_split(..., n_chunks)`. Chunking by index avoids
    building a fixed-width NumPy string array of all sentences. Empty chunks (when there
    are fewer items than chunks) are skipped. Returns a list with one embedding row per text.
    """
    embs = []
    for chunk in tqdm(np.array_split(np.arange(len(texts)), n_chunks)):
        if len(chunk) == 0:
            continue
        embs.extend(get_target_word_embeddings([texts[i] for i in chunk],
                                               [target_words[i] for i in chunk]))
    return embs


for ratio in ratios:
    # Sample ratio * (#positives) non-STMS examples, keeping sentence/word pairs aligned.
    n_other = ratio * len(masked_target_sen)
    if n_other > len(masked_sentences_other):
        raise ValueError(
            f"ratio {ratio} needs {n_other} non-STMS examples but only "
            f"{len(masked_sentences_other)} are available")
    masked_sentences_other_sampled_words = random.sample(
        list(zip(masked_sentences_other, replaced_words_other)), n_other)
    masked_sentences_other_sampled = [sentence for sentence, _ in masked_sentences_other_sampled_words]
    replaced_word_other_sampled = [word for _, word in masked_sentences_other_sampled_words]

    for maxn in ns:
        # NOTE(review): the hard negatives were built for every value in `ns` at once and do
        # not depend on `maxn`; the output name is only accurate while len(ns) == 1.
        print(ratio, maxn)
        print(len(masked_target_sen), len(masked_target_hard_neg), len(masked_sentences_other_sampled))

        all_texts = masked_target_sen + masked_target_hard_neg + masked_sentences_other_sampled
        labels = [1] * len(masked_target_sen) + [0] * len(masked_target_hard_neg) + [0] * len(masked_sentences_other_sampled)
        all_words = replaced_words + replaced_words_hard_neg + replaced_word_other_sampled

        # Explicit random_state: without it the split uses NumPy's global RNG and the saved
        # train/valid partition is not reproducible from `random_seed`.
        train_text, valid_text, train_label, valid_label, train_words, valid_words = train_test_split(
            all_texts, labels, all_words, train_size=0.8, shuffle=True, random_state=random_seed)

        print('Extracting embs for Train')
        train_embs = _embed_in_chunks(train_text, train_words)

        print('Extracting embs for Valid')
        valid_embs = _embed_in_chunks(valid_text, valid_words)

        # datasets stores tensors/arrays as lists; convert 'emb' back to tensors when loading.
        train = datasets.Dataset.from_dict({'text': train_text, 'label': train_label, 'word': train_words, 'emb': train_embs})
        valid = datasets.Dataset.from_dict({'text': valid_text, 'label': valid_label, 'word': valid_words, 'emb': valid_embs})
        dataset = datasets.DatasetDict({'train': train, 'valid': valid})

        tokenized_dataset = dataset.map(preprocess_function, batched=True)

        dataset_name = f'Emb_Dataset_ratio{ratio}_max{maxn}_{random_seed}'
        print(dataset_name)
        tokenized_dataset.save_to_disk(os.path.join(data_dir, dataset_name))