In [1]:
import os
import json
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel


class IMDbDataset(torch.utils.data.Dataset):
    """Map-style dataset over column-oriented encodings.

    ``encodings`` is a mapping from field name to a per-example sequence.
    It must contain an ``'input_ids'`` entry, which is used to determine
    the number of examples.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # The dataset length is taken from the 'input_ids' field.
        return len(self.encodings['input_ids'])

    def __getitem__(self, idx):
        # Assemble one example: entry ``idx`` of every field, as a tensor.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        return sample

from numpy import dot
from numpy.linalg import norm
import numpy as np
def cos_sim(A, B):
    """Return the cosine similarity of vectors ``A`` and ``B``.

    Computed as ``dot(A, B) / (norm(A) * norm(B))``.

    Cosine similarity is undefined when either vector has zero norm. In that
    case ``0.0`` is returned instead of dividing by zero, which would yield
    ``nan`` together with a NumPy RuntimeWarning. Threshold checks such as
    ``cos_sim(a, b) > 0.95`` are False for both ``nan`` and ``0.0``.
    """
    denominator = norm(A) * norm(B)
    if denominator == 0:
        return 0.0
    return dot(A, B) / denominator

In [2]:
# --- Configuration ---------------------------------------------------------
DATASET_PATH = "../dataset/IMDb/triplet_automated_averaged_gradient_1word_augmented_1x_aclImdb/"
OUTPUT_PATH = "../dataset/IMDb/triplet_automated_averaged_gradient_wanglike_1word_augmented_1x_aclImdb/"
BATCH_SIZE = 4

# --- Model: pretrained BERT in eval mode, on the GPU when one is available --
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")
model.to(device)
model.eval()

# --- Data: load the training records ----------------------------------------
with open(os.path.join(DATASET_PATH, "train.json")) as f:
    data = json.load(f)

# Tokenize the 'negative_text' field of every record in one call.
causal_sentences = [record['negative_text'] for record in data]
causal_encodings = tokenizer(causal_sentences, truncation=True, padding=True)

# shuffle=False keeps the batches in the same order as `data`, so the
# resulting rows can later be paired with the records positionally.
causal_dataset = IMDbDataset(causal_encodings)
causal_loader = DataLoader(causal_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- Encode every sentence without tracking gradients -----------------------
causal_reps = []
with torch.no_grad():
    for batch in tqdm(causal_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Element 1 of the BertModel output is kept. NOTE(review): presumably
        # the pooled representation rather than logits, despite the variable
        # name -- confirm against the installed transformers version.
        logits = model(input_ids, attention_mask)[1]
        causal_reps.append(logits.detach().cpu())

# Stack the per-batch tensors into a single NumPy array.
causal_reps = torch.cat(causal_reps).detach().cpu().numpy()

100%|██████████| 5625/5625 [02:25<00:00, 38.77it/s]


In [8]:
# Accumulators of (id, representation) pairs, one list per label group.
positive_reps, negative_reps = [], []

In [9]:
# Split the representations by the label of the record they belong to:
# records labelled [1.0, 0.0] go to `negative_reps`, all others to
# `positive_reps`. Each entry is an (id, representation) pair.
#
# The lists are rebuilt from scratch here so that re-running this cell cannot
# append duplicate entries to lists left over from an earlier execution.
positive_reps = []
negative_reps = []
for rep, d in zip(causal_reps, data):
    bucket = negative_reps if d['label'] == [1.0, 0.0] else positive_reps
    bucket.append((d['id'], rep))


In [15]:
# Minimum cosine similarity (exclusive) for two representations to count as
# a match when computing the mask below.
SIMILARITY_THRESHOLD = 0.95

# For every record, set d['triplet_sample_mask'] to True when at least one
# representation from the other label group -- belonging to a record with a
# different id -- has cosine similarity above the threshold; otherwise False.
#
# zip() has no length, so `total` is passed explicitly to let tqdm render a
# real progress bar with an ETA instead of a bare iteration counter.
for rep, d in tqdm(zip(causal_reps, data), total=len(data)):
    # Records labelled [1.0, 0.0] are compared against `positive_reps`;
    # all other records are compared against `negative_reps`.
    candidates = positive_reps if d['label'] == [1.0, 0.0] else negative_reps

    # any() stops at the first match, like the early `break` it replaces,
    # and always yields a plain Python bool (JSON-serializable).
    d['triplet_sample_mask'] = any(
        cos_sim(rep, other_rep) > SIMILARITY_THRESHOLD
        for other_id, other_rep in candidates
        if other_id != d['id']
    )
    

22500it [00:10, 2063.01it/s]


In [16]:
import shutil

# exist_ok=True keeps this cell re-runnable: a bare os.makedirs() raises
# FileExistsError once the output directory has been created.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Write the training records (now carrying 'triplet_sample_mask') to the
# output directory.
with open(os.path.join(OUTPUT_PATH, "train.json"), 'w') as f:
    json.dump(data, f)

# The validation and test splits are copied over unchanged.
for split_file in ("valid.json", "test.json"):
    shutil.copy(os.path.join(DATASET_PATH, split_file), os.path.join(OUTPUT_PATH, split_file))

# Fraction of training records whose mask is True.
print(sum(d['triplet_sample_mask'] for d in data) / len(data))

0.9987555555555555


In [17]:
# Number of training records whose 'triplet_sample_mask' is True.
print(sum(d['triplet_sample_mask'] for d in data))

22472


In [12]:
# Display the record at index 1 of `data` for a manual spot check.
data[1]

{'id': 9516,
 'anchor_text': 'excellent show. instead of watching the same old sitcom type shows where it\'s the same old thing, just different " stars ", this refreshing show provided an incredibly entertaining view of office situations. we have been away from watching any television for 2 years and after coming back, of all the shows available we look forward to watching this show on w. shame on global for pulling the plug on this one. i thought this one would be a winner. let\'s be realistic about things, few canadian shows make it. everyone i talk to enjoys this show and i believe it was foolish of global to walk away. i guess they want to stick it out with the typical mind numbing shows from the states instead of pulling behind a canadian made show that had a lot of promise. don\'t get me wrong, i enjoy a lot of shows on tv, but, come on people, let\'s keep the variety. this unique show provided a very comedic view of a slightly exaggerated realistic side of office life and relati