In [2]:
import csv
import logging
import os
import random
import sys

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
#from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

from bertviz.bertviz import attention, visualization
from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
logger = logging.getLogger(__name__)
bert_classifier_model_dir = "../models/BERT/" ## Path of BERT classifier model path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("device: {}, n_gpu {}".format(device, n_gpu))

In [4]:
# file paths
data_dir = "/home/jack/Desktop/NN/clean/datasets/"
dataset = "yelp" # amazon / yelp / imagecaption
train_0 = os.path.join(data_dir ,"{}/sentiment.train.0".format(dataset))
train_1 = os.path.join(data_dir,"{}/sentiment.train.1".format(dataset))
test_0 = os.path.join(data_dir,"{}/sentiment.test.0".format(dataset))
test_1 = os.path.join(data_dir,"{}/sentiment.test.1".format(dataset))
dev_0 = os.path.join(data_dir,"{}/sentiment.dev.0".format(dataset))
dev_1 = os.path.join(data_dir,"{}/sentiment.dev.1".format(dataset))
reference_0 = os.path.join(data_dir,"{}/reference.0".format(dataset))
reference_1 = os.path.join(data_dir,"{}/reference.1".format(dataset))

reference_0_org = os.path.join(data_dir,"{}/reference_0_org.txt".format(dataset))
reference_1_org = os.path.join(data_dir,"{}/reference_1_org.txt".format(dataset))
reference_0_expected = os.path.join(data_dir,"{}/reference_0_expected.txt".format(dataset))
reference_1_expected = os.path.join(data_dir,"{}/reference_1_expected.txt".format(dataset))

In [5]:
def read_file(file_path):
    with open(file_path) as fp:
        data = fp.read().splitlines()
    return data

In [16]:
def split_Ref(ref_file_path, target_file_path, target_split):
    open_File = read_file(ref_file_path)
    #target_file = read_file(target_file_path)
    with open(target_file_path, "w") as target_file:
        for ref_line in open_File:
            print(ref_line.split("	")[target_split])
            target_file.write(ref_line.split("	")[target_split]+"\n")
            
        
split_Ref(reference_0,reference_0_org, 0)
split_Ref(reference_0,reference_0_expected, 1)
split_Ref(reference_1,reference_1_org, 0)
split_Ref(reference_1,reference_1_expected, 1)

ever since joes has changed hands it 's just gotten worse and worse .
there is definitely not enough room in that part of the venue .
so basically tasted watered down .
she said she 'd be back and disappeared for a few minutes .
i ca n't believe how inconsiderate this pharmacy is .
just left and took it off the bill .
it is n't terrible , but it is n't very good either .
definitely disappointed that i could not use my birthday gift !
new owner , i heard - but i do n't know the details .
but it probably sucks too !
we sit down and we got some really slow and lazy service .
the charge did include miso soup and a small salad .
there was no i 'm sorry or how did everything come out .
said we could n't sit at the table if we were n't ordering dinner .
the cash register area was empty and no one was watching the store front .
there chips are ok , but their salsa is really bland .
the wine was very average and the food was even less .
staffed primarily by teenagers that do n't understand cust

In [6]:
# file paths
data_dir = "/home/jack/Desktop/NN/clean/datasets/"
dataset = "yelp" # amazon / yelp / imagecaption
train_0_out = os.path.join(data_dir ,"{}/processed_files_with_bert_with_best_head/sentiment_train_0.txt".format(dataset))
train_1_out = os.path.join(data_dir,"{}/processed_files_with_bert_with_best_head/sentiment_train_1.txt".format(dataset))
test_0_out = os.path.join(data_dir,"{}/processed_files_with_bert_with_best_head/sentiment_test_0.txt".format(dataset))
test_1_out = os.path.join(data_dir,"{}/processed_files_with_bert_with_best_head/sentiment_test_1.txt".format(dataset))
dev_0_out = os.path.join(data_dir,"{}/processed_files_with_bert_with_best_head/sentiment_dev_0.txt".format(dataset))
dev_1_out = os.path.join(data_dir,"{}/processed_files_with_bert_with_best_head/sentiment_dev_1.txt".format(dataset))
reference_0_out = os.path.join(data_dir,"{}/processed_files_with_bert_with_best_head/reference_0.txt".format(dataset))
reference_1_out = os.path.join(data_dir,"{}/processed_files_with_bert_with_best_head/reference_1.txt".format(dataset))

In [7]:
## Model for performing Classification
model_cls = BertForSequenceClassification.from_pretrained(bert_classifier_model_dir, num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model_cls.to(device)
model_cls.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
    

In [8]:
## Model to get the attention weights of all the heads
model = BertModel.from_pretrained(bert_classifier_model_dir)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model.to(device)
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features

In [9]:
max_seq_len=70 # Maximum sequence length 
sm = torch.nn.Softmax(dim=-1) ## Softmax over the batch

In [10]:
common_words=['is','are','was','were','has','have','had','a','an','the','this','that','these','those','there','how','i','we',
             'he','she','it','they','them','their','his','him','her','us','our', 'and','in','my','your','you', 'will', 'shall']
common_words_tokens = tokenizer.convert_tokens_to_ids(common_words)
not_to_remove_ids = tokenizer.convert_tokens_to_ids(["[CLS]","[SEP]", ".", "?", "!"])
not_to_remove_ids += common_words_tokens

In [11]:
def read_file(file_path):
    with open(file_path) as fp:
        data = fp.read().splitlines()
    return data

In [12]:
def create_output_file(original_sentences,processed_sentences, output_file, sentiment="<POS>"):
    with open(output_file,"w") as fp:
        for sen1,sen2 in zip(original_sentences,processed_sentences):
            if sen1 != None and sen2 != None:
                str1 = sentiment + " <CON_START> " + sen2 + " <START> " + sen1 + " <END>\n"
                fp.write(str1)

In [13]:
def create_ref_output_file(processed_sentences,original_sentences, output_file, sentiment="<POS>"):
    with open(output_file,"w") as fp:
        for sen in tqdm(processed_sentences):
            if sen != None:
                str1 = sentiment + " <CON_START> " + sen + " <START>\n"
                fp.write(str1)

In [15]:
def concate_files(inp_files, out_files):
    with open(out_files,"w") as fp:
        for file in inp_files:
            with open(file) as f:
                for line in f:
                    fp.write(line)

In [19]:
def run_attn_examples(input_sentences, layer, head, bs=128):
    """
    Returns Attention weights for selected Layer and Head along with ids and tokens
    of the input_sentence
    """
    ids = []
    ids_to_decode = [None for k in range(len(input_sentences))]
    tokens_to_decode = [None for k in range(len(input_sentences))]
    segment_ids = []
    input_masks = []
    attention_weights = [None for z in input_sentences]
    ## BERT pre-processing
    for j,sen in enumerate(tqdm(input_sentences)):
        
        text_tokens = tokenizer.tokenize(sen)
        if len(text_tokens) >= max_seq_len-2:
            text_tokens = text_tokens[:max_seq_len-4]
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        tokens_to_decode[j] = tokens
        temp_ids = tokenizer.convert_tokens_to_ids(tokens)
        ids_to_decode[j] = temp_ids
        input_mask = [1] * len(temp_ids)
        segment_id = [0] * len(temp_ids)
        padding = [0] * (max_seq_len - len(temp_ids))
        
        
        temp_ids += padding
        input_mask += padding
        segment_id += padding
        
        ids.append(temp_ids)
        input_masks.append(input_mask)
        segment_ids.append(segment_id)
    
    # Convert Ids to Torch Tensors
    ids = torch.tensor(ids) 
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)
    
    steps = len(ids) // bs
    
    for i in trange(steps+1):
        if i == steps:
            temp_ids = ids[i * bs : len(ids)]
            temp_segment_ids = segment_ids[i * bs: len(ids)]
            temp_input_masks = input_masks[i * bs: len(ids)]
        else:
            temp_ids = ids[i * bs : i * bs + bs]
            temp_segment_ids = segment_ids[i * bs: i * bs + bs]
            temp_input_masks = input_masks[i * bs: i * bs + bs]
        
        temp_ids = temp_ids.to(device)
        temp_segment_ids = temp_segment_ids.to(device)
        temp_input_masks = temp_input_masks.to(device)
        with torch.no_grad():
             _, _, attn = model(temp_ids, temp_segment_ids, temp_input_masks)
        # Concate Attention weights
        for j in range(len(attn[layer]['attn_probs'])):
            attention_weights[i * bs + j] = (attn[layer]['attn_probs'][j][head][0]).to('cpu')
    
    return attention_weights, ids_to_decode, tokens_to_decode

In [18]:
def prepare_data(aw, ids_to_decode, tokens_to_decode):
    out_sen = [None for i in range(len(aw))]
    for i in trange(len(aw)):
        #topv, topi = aw[i].topk(len(inps_tokens[i]))
        topv, topi = aw[i].topk(ids_to_decode[i].index(0))
        topi = topi.tolist()
        topv = topv.tolist()
        #print(i,train_0[i])
        #print(tokens_to_decode[i])
        #print("Original Top Indexes = {}".format(topi))
        topi = [topi[j] for j in range(len(topi)) if ids_to_decode[i][topi[j]] not in not_to_remove_ids] # remove noun and common words
        #print("After removing Nouns = {}".format(topi))
        topi = [topi[j] for j in range(len(topi)) if "##" not in tokens_to_decode[i][topi[j]]] # Remove half words
        #print("After removing Half-words = {}".format(topi))

        if (len(topi) < 4 and len(topi) > 0):
            topi = [topi[0]]
        elif(len(topi) < 8):
            topi = topi[:2]
        else:
            topi = topi[:3]

        #print("Final Topi = {}".format(topi))
        final_indexes = []
        count = 0
        count1 = 0
        #print(ids_to_decode[i], tokens_to_decode[i])
        while ids_to_decode[i][count] != 0:
            if count in topi:
                while ids_to_decode[i][count + count1 + 1] != 0:
                    if "##" in tokens_to_decode[i][count + count1 + 1]:
                        count1 += 1
                    else:
                        break
                count += count1
                count1 = 0
            else:
                final_indexes.append(ids_to_decode[i][count])
            count += 1

        #print(final_indexes)
        temp_out_sen = tokenizer.convert_ids_to_tokens(final_indexes)
        temp_out_sen = " ".join(temp_out_sen).replace(" ##", "").replace("[CLS]","").replace("[SEP]","")
        #print(temp_out_sen, "\n\n")
        out_sen[i] = temp_out_sen.strip()
    
    return out_sen

In [17]:
train_0_data = read_file(train_0)
train_1_data = read_file(train_1)
dev_0_data = read_file(dev_0)
dev_1_data = read_file(dev_1)
test_0_data = read_file(test_0)
test_1_data = read_file(test_1)
ref_0_data = read_file(reference_0)
ref_1_data = read_file(reference_1)

In [16]:
aw, ids_to_decode, tokens_to_decode = run_attn_examples(train_0_data, layer=4, head=2, bs=128)
train_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(train_0_data, train_0_out_sen, train_0_out, sentiment="<NEG>")

100%|██████████| 177218/177218 [00:16<00:00, 10895.78it/s]
100%|██████████| 1385/1385 [03:42<00:00,  6.22it/s]
100%|██████████| 177218/177218 [00:02<00:00, 65505.84it/s]


In [17]:
aw, ids_to_decode, tokens_to_decode = run_attn_examples(train_1_data, layer=4, head=2, bs=128)
train_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(train_1_data, train_1_out_sen, train_1_out, sentiment="<POS>")

100%|██████████| 266041/266041 [00:23<00:00, 11485.38it/s]
100%|██████████| 2079/2079 [05:33<00:00,  6.23it/s]
100%|██████████| 266041/266041 [00:03<00:00, 76176.93it/s]


In [18]:
aw, ids_to_decode, tokens_to_decode = run_attn_examples(dev_0_data, layer=4, head=2, bs=128)
dev_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(dev_0_data, dev_0_out_sen, dev_0_out, sentiment="<NEG>")

100%|██████████| 2000/2000 [00:00<00:00, 10967.68it/s]
100%|██████████| 16/16 [00:02<00:00,  7.27it/s]
100%|██████████| 2000/2000 [00:00<00:00, 67851.43it/s]


In [19]:
aw, ids_to_decode, tokens_to_decode = run_attn_examples(dev_1_data, layer=4, head=2, bs=128)
dev_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(dev_1_data, dev_1_out_sen, dev_1_out, sentiment="<POS>")

100%|██████████| 2000/2000 [00:00<00:00, 12829.84it/s]
100%|██████████| 16/16 [00:02<00:00,  7.14it/s]
100%|██████████| 2000/2000 [00:00<00:00, 76325.30it/s]


In [20]:
aw, ids_to_decode, tokens_to_decode = run_attn_examples(test_1_data, layer=4, head=2, bs=128)
test_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(test_1_data, test_1_out_sen, test_1_out, sentiment="<POS>")

100%|██████████| 500/500 [00:00<00:00, 7652.28it/s]
100%|██████████| 4/4 [00:00<00:00,  7.27it/s]
100%|██████████| 500/500 [00:00<00:00, 67281.10it/s]


In [28]:
aw, ids_to_decode, tokens_to_decode = run_attn_examples(test_0_data, layer=4, head=2, bs=128)
test_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
print(test_0_out_sen[:4])
print(test_0_data[:4])

create_output_file(test_0_data, test_0_out_sen, test_0_out, sentiment="<NEG>")

100%|██████████| 500/500 [00:00<00:00, 10439.20it/s]
100%|██████████| 4/4 [00:00<00:00,  6.90it/s]
100%|██████████| 500/500 [00:00<00:00, 66147.87it/s]

["ever since joes has changed hands it ' just gotten and .", 'there is not enough room in that part of the .', 'so basically down .', "she she ' d be back and for a few ."]
["ever since joes has changed hands it 's just gotten worse and worse .", 'there is definitely not enough room in that part of the venue .', 'so basically tasted watered down .', "she said she 'd be back and disappeared for a few minutes ."]





In [20]:
original_content_data_1 = [x.split("	")[0] for x in ref_1_data]

aw, ids_to_decode, tokens_to_decode = run_attn_examples(original_content_data_1, layer=4, head=2, bs=128)
ref_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
print(ref_1_out_sen[:4])
print(ref_1_data[:4])

create_ref_output_file(ref_1_out_sen, original_content_data_1 , reference_1_out, sentiment="<NEG>")

100%|██████████| 500/500 [00:00<00:00, 8660.98it/s]
100%|██████████| 4/4 [00:01<00:00,  2.55it/s]
100%|██████████| 500/500 [00:00<00:00, 39971.64it/s]


["it ' s yet they you feel at home .", 'i will be going back and this place !', 'the drinks were and a pour .', 'my husband got a ruben , he it .']
["it 's small yet they make you feel right at home .\tit's small yet they make you feel like a stranger.", "i will be going back and enjoying this great place !\ti won't be going back and suffering at this terrible place !", 'the drinks were affordable and a good pour .\tthe drinks were expensive and half full.', 'my husband got a ruben sandwich , he loved it .\tmy husband got a reuben sandwich, he hated it.']


100%|██████████| 500/500 [00:00<00:00, 956729.93it/s]


In [21]:
original_content_data_0 = [x.split("	")[0] for x in ref_0_data]


aw, ids_to_decode, tokens_to_decode = run_attn_examples(original_content_data_0, layer=4, head=2, bs=128)
ref_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)

create_ref_output_file(ref_0_out_sen, original_content_data_0, reference_0_out, sentiment="<POS>")

100%|██████████| 500/500 [00:00<00:00, 9356.27it/s]
100%|██████████| 4/4 [00:00<00:00,  6.59it/s]
100%|██████████| 500/500 [00:00<00:00, 47344.05it/s]
100%|██████████| 500/500 [00:00<00:00, 1428577.66it/s]


In [15]:
# Hacky merge here - Do NOT execute without need
content_dir = data_dir + "yelp/processed_files_with_bert_with_best_head/delete_retrieve_edit_model/" 
att_dir = data_dir + "yelp/processed_files_with_bert_with_best_head/delete_retrieve_edit_model/tfidf/backup/original/"
output_dir = data_dir + "yelp/processed_files_with_bert_with_best_head/delete_retrieve_edit_model/tfidf/" 

ref_0_content = read_file(content_dir + "reference_0.txt")
ref_1_content = read_file(content_dir + "reference_1.txt")

ref_0_att = read_file(att_dir + "reference_0.txt")
ref_1_att = read_file(att_dir + "reference_1.txt")

print(ref_0_content[:4])
print(ref_0_att[:4])

content_0 = [x.split("<CON_START>")[1] for x in ref_0_content]
content_1 = [x.split("<CON_START>")[1] for x in ref_1_content]

att_0 = [x.split("<CON_START>")[0] for x in ref_0_att]
att_1 = [x.split("<CON_START>")[0] for x in ref_1_att]

print(content_0[:4])
print(att_0[:4])



with open(output_dir+"reference_0", 'w') as outfile:
    for index, content in enumerate(content_0):
            outfile.write(att_0[index] + "<CON_START>" + content +"\n")

with open(output_dir+"reference_1", 'w') as outfile:
    for index, content in enumerate(content_1):
            outfile.write(att_1[index] + "<CON_START>" + content +"\n")

["<ATTR_WORDS> worse worse <CON_START> ever since joes has changed hands it ' just gotten <REPLACE> and <REPLACE> . <START> ", '<ATTR_WORDS> definitely venue <CON_START> there is <REPLACE> not enough room in that part of the <REPLACE> . <START> ', '<ATTR_WORDS> tasted watered <CON_START> so basically <REPLACE> <REPLACE> down . <START> ', "<ATTR_WORDS> said disappeared <CON_START> she she ' d be back and <REPLACE> for a few <REPLACE> . <START> "]
["<ATTR_WORDS> prompt friendly <CON_START> ever since joes has changed hands it ' just gotten and . <START>", '<ATTR_WORDS> does does well <CON_START> there is not enough room in that part of the . <START>', '<ATTR_WORDS> work patrick <CON_START> so basically down . <START>', "<ATTR_WORDS> honestly talent talent <CON_START> she she ' d be back and for a few . <START>"]
[" ever since joes has changed hands it ' just gotten <REPLACE> and <REPLACE> . <START> ", ' there is <REPLACE> not enough room in that part of the <REPLACE> . <START> ', ' so ba