In [1]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [2]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import re

import pandas as pd
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, PegasusConfig
import matplotlib.pyplot as plt
from rouge_score import rouge_scorer


In [3]:
#globals
MAX_LENGTH = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [4]:
# Define data directory
import os
DATA_DIR = "./data/aan_data"

#load our 10k data into a dataframe
papers = []
filenames = [] #keep a reference for later transformations

abstracts2, bodys2, cits2, files2 = [],[],[],[]

for root, dirs, files in os.walk(DATA_DIR):
    for f in files:
        fn = root+"/"+f
        if "abstract" in fn:            
            in_file = open(fn, 'r')
            file = in_file.readlines()

            new_list = [''.join(file[i*4:(i+1)*4]) for i in range(int(len(file)/4))]
            list_no_n = [item.replace('\n','').replace('- ','').replace("\'", "") for item in new_list]

            string = ''
            for item in list_no_n:
                string = string + item            
            abstracts2.append(string)
            files2.append(f.split('_')[1].split('.')[0])
            
        elif "body" in fn:
            in_file = open(fn, 'r')
            file = in_file.readlines()

            new_list = [''.join(file[i*4:(i+1)*4]) for i in range(int(len(file)/4))]
            
            list_no_n = [item.replace('\n','').replace('- ','').replace("\'", "") for item in new_list]
            #list_no_n = [item for item in new_list]
            string = ''
            for item in list_no_n:
                string = string + item            
            bodys2.append(string)
        
        elif "CS" in fn:
            in_file = open(fn, 'r')
            file = in_file.readlines()
            string = [''.join(file) for i in file]
            if len(string) > 0: 
                string = [''.join(file) for i in file][0].replace('\n','').replace('- ','').replace("\'", "")
            cits2.append(string)
            
        else:
            pass
        
        
        
        
    #    with open(fn) as jsonfile:
    #        d = json.load(jsonfile)
    #    papers.append(d)
    #    filenames.append(f)

    
    
df2 = pd.DataFrame({'paper_id':files2,'abstract':abstracts2,'body':bodys2,'citations':cits2})
df2
        

Unnamed: 0,paper_id,abstract,body,citations
0,A00-1004,A major obstacle to the construction ofa pro...,Parallel texts have been used in a number of...,A compilation of parallel texts offered in a s...
1,A00-1006,This paper proposes a way to improve the tran...,"Recently, various dialogue translation syste...",We plan to improve the accuracy obtained so fa...
2,A00-1008,This paper describes an application of APE (t...,The purpose of the Atlas project is to enlarg...,Computer dialogue is now used at production st...
3,A00-1009,In this paper we describe an implemented fram...,In this paper we present a linguistically mot...,"From this viewpoint, research on paraphrasing ..."
4,A00-1011,"This paper reports on a large-scale, end-toen...",One major goal of information extraction (IE)...,"For example, Aone and Ramos-Santacruz (2000) p..."
...,...,...,...,...
4268,W99-0904,We present in this paper an unsupervised meth...,Development of electronic morphological resou...,Next along the spectrum of orthographic simila...
4269,W99-0905,This paper presents an unsupervised method fo...,Choosing the correct translation of a content...,"For comparable corpora, previous bilingual sen..."
4270,W99-0909,In this paper we report on an unsupervised a...,In this paper we discuss a potential solutio...,Watkinson and Manandhar (1999) present an unsu...
4271,X93-1018,This paper presents results from a study comp...,In evaluating the state of technology for ext...,"It may be noted that ""correctly"" is a problema..."


In [6]:
file1 = open("cited_spans_full1.txt","r+") 
  
text = file1.read()
text = text.split('\n')

paper_ids, cts = [], []
for i in text:
    if i != '':
        p = i.split('@@')
        paper_ids.append(p[0])
        cts.append(p[1])

dfcts = pd.DataFrame({'paper_id':paper_ids,
             'cited text spans':cts})

df = df2.merge(dfcts, on='paper_id')
df

Unnamed: 0,paper_id,abstract,body,citations,cited text spans
0,A00-1004,A major obstacle to the construction ofa pro...,Parallel texts have been used in a number of...,A compilation of parallel texts offered in a s...,Parallel texts have been used in a number of...
1,A00-1006,This paper proposes a way to improve the tran...,"Recently, various dialogue translation syste...",We plan to improve the accuracy obtained so fa...,If we want to make a conversation proceed smoo...
2,A00-1008,This paper describes an application of APE (t...,The purpose of the Atlas project is to enlarg...,Computer dialogue is now used at production st...,The purpose of the Atlas project is to enlarg...
3,A00-1009,In this paper we describe an implemented fram...,In this paper we present a linguistically mot...,"From this viewpoint, research on paraphrasing ...",In this paper we present a linguistically mot...
4,A00-1011,"This paper reports on a large-scale, end-toen...",One major goal of information extraction (IE)...,"For example, Aone and Ramos-Santacruz (2000) p...",One major goal of information extraction (IE)...
...,...,...,...,...,...
4203,W99-0904,We present in this paper an unsupervised meth...,Development of electronic morphological resou...,Next along the spectrum of orthographic simila...,Development of electronic morphological resou...
4204,W99-0905,This paper presents an unsupervised method fo...,Choosing the correct translation of a content...,"For comparable corpora, previous bilingual sen...",Choosing the correct translation of a content...
4205,W99-0909,In this paper we report on an unsupervised a...,In this paper we discuss a potential solutio...,Watkinson and Manandhar (1999) present an unsu...,In this paper we discuss a potential solutio...
4206,X93-1018,This paper presents results from a study comp...,In evaluating the state of technology for ext...,"It may be noted that ""correctly"" is a problema...",In evaluating the state of technology for ext...


In [7]:
#dataset and tokenizer building
#load our 10k data into a dataframe
#limit = 10
#papers = []
#for root, dirs, files in os.walk("./data/mini_10k"):
#    for f in files:
#        fn = root+"/"+f
#        with open(fn) as jsonfile:
#            d = json.load(jsonfile)
#        papers.append(d)
#        
#        if len(papers) >= limit:
#            break
#    if len(papers) >= limit:
#        break
#df = pd.DataFrame(papers)

In [8]:
#load our rouge scorer
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2'], use_stemmer=True)

In [9]:
#load our pretrained model
model_name = 'google/pegasus-large'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
config = PegasusConfig.from_pretrained(model_name, output_hidden_states=True, output_attentions=True)  
pt_model = PegasusForConditionalGeneration.from_pretrained(model_name, config=config).to(device)

In [11]:
#example batch (size 1)
batch = tokenizer(df.body[3], truncation=True, padding='longest', return_tensors="pt").to(device)
print(batch.keys())

dict_keys(['input_ids', 'attention_mask'])


In [12]:
#example pretrained generation with keys
out = pt_model.generate(return_dict_in_generate=True, **batch)
print(out.keys())
print(len(out["encoder_hidden_states"]))

odict_keys(['sequences', 'encoder_attentions', 'encoder_hidden_states', 'decoder_attentions', 'cross_attentions', 'decoder_hidden_states'])
17


In [20]:
#example batch (size 1)
batch2 = tokenizer(df.citations[3], truncation=True, padding='longest', return_tensors="pt").to(device)
print(batch2.keys())
out2 = pt_model.generate(return_dict_in_generate=True, **batch2)

#example batch (size 1)
batch3 = tokenizer(df['cited text spans'][3], truncation=True, padding='longest', return_tensors="pt").to(device)
print(batch3.keys())
out3 = pt_model.generate(return_dict_in_generate=True, **batch3)

dict_keys(['input_ids', 'attention_mask'])
dict_keys(['input_ids', 'attention_mask'])


In [33]:
class AttentionAttention(nn.Module):
    def __init__(self,  vocab_size=32000, input_size=1024, target_size=256):
        super(AttentionAttention, self).__init__()
        
        #edit this to manipulate the network:
        #attn head1
        self.ah1_1 = nn.Linear(input_size, target_size)
        self.ah1_2 = nn.Linear(target_size, target_size//2)
        self.ah1_3 = nn.Linear(target_size//2, target_size//4)
        #self.ah1_4 = nn.Linear(target_size//4, target_size//8)
        
        #attn head1
        self.ah2_1 = nn.Linear(input_size, target_size)
        self.ah2_2 = nn.Linear(target_size, target_size//2)
        self.ah2_3 = nn.Linear(target_size//2, target_size//4)
        
        #embedding head
        self.ah3_1 = nn.Linear(input_size, target_size)
        self.ah3_2 = nn.Linear(target_size, target_size//2)
        self.ah3_3 = nn.Linear(target_size//2, target_size//4)
        
        #compression head 
        
        
        #output head
        self.fc_out = nn.Linear(target_size//4, vocab_size)
        self.sm = nn.Softmax(dim=1)
        
    def forward(self, out, citations, cts):
        
        #initialize a random tensor as our 'shallow' attn
        shallow_attn = torch.rand((1024,1024), requires_grad=True).to(device)
        shallow_attn2 = torch.rand((1024,1024), requires_grad=True).to(device)
        shallow_attn3 = torch.rand((1024,1024), requires_grad=True).to(device)
        
        #we're going to focus on the first N^2 attn layers
        num_layers = 2
        num_attentions = 2
        
        for i,attn in enumerate(out["encoder_attentions"]):
            if i >= num_layers:
                break
            
            for j,block in enumerate(attn[0]):
                attn = torch.tensor(block).to(device)
                
                #add our attention to the noise mask
                shallow_attn = shallow_attn.add(attn)
                
                #edit this to manipulate the attention
                #manipulate attention
                #shallow_attn = torch.einsum("ab,cd->bc", shallow_attn, attn)
                #shallow_attn = torch.einsum("ab,cd->ad", shallow_attn, attn)
                
                if j >= num_attentions:
                    break
                    
        for i,attn2 in enumerate(citations["encoder_attentions"]):
            if i >= num_layers:
                break
            
            for j,block in enumerate(attn[0]):
                attn2 = torch.tensor(block).to(device)
                
                #add our attention to the noise mask
                shallow_attn2 = shallow_attn2.add(attn2)
                
                #edit this to manipulate the attention
                #manipulate attention
                #shallow_attn = torch.einsum("ab,cd->bc", shallow_attn, attn)
                #shallow_attn = torch.einsum("ab,cd->ad", shallow_attn, attn)
                
                if j >= num_attentions:
                    break
                    
    
        for i,attn3 in enumerate(cts["encoder_attentions"]):
            if i >= num_layers:
                break
            
            for j,block in enumerate(attn[0]):
                attn3 = torch.tensor(block).to(device)
                
                #add our attention to the noise mask
                shallow_attn3 = shallow_attn3.add(attn3)
                
                #edit this to manipulate the attention
                #manipulate attention
                #shallow_attn = torch.einsum("ab,cd->bc", shallow_attn, attn)
                #shallow_attn = torch.einsum("ab,cd->ad", shallow_attn, attn)
                
                if j >= num_attentions:
                    break
                    
        
        #values,indices = torch.sort(global_attn)
        
        #learn from shallow_attn
        
        x1 = F.relu(self.ah1_1(shallow_attn))
        x1 = F.relu(self.ah1_2(x1))
        x1 = F.relu(self.ah1_3(x1))
        #x1 = F.relu(self.ah1_4(x1))
        
        x2 = F.relu(self.ah2_1(shallow_attn2))
        x2 = F.relu(self.ah2_2(x2))
        x2 = F.relu(self.ah2_3(x2))
        
        #x1 = F.relu(self.ah1_4(x1))
        x3 = F.relu(self.ah3_1(shallow_attn3))
        x3 = F.relu(self.ah3_2(x3))
        x3 = F.relu(self.ah3_3(x3))
        #x1 = F.relu(self.ah1_4(x1))
        #cited head
        
        #learn from raw attn
#         raw_attn = out["encoder_attentions"][-1][0][-1]
#         x2 = F.relu(self.ah2_1(raw_attn))
#         x2 = F.relu(self.ah2_2(x2))
#         x2 = F.relu(self.ah2_3(x2))
        
#         #learn from embeds
#         last_embed = out["encoder_hidden_states"][-1][0]
#         x3 = F.relu(self.eh1_1(last_embed))
#         x3 = F.relu(self.eh1_2(x3))
#         x3 = F.relu(self.eh1_3(x3))
                
#         #concatenate all heads
#         #print(x1.shape, x2.shape, x3.shape)
        x_concat = torch.cat((x1,x2,x3), 1)
    
        
        #print(x_concat.shape)

        #x_input = h1+h2
        #print(x_input.shape)
        x = F.relu(self.fc_out(x1))
        
        x = self.sm(x)

        return x

In [34]:
aa = AttentionAttention(vocab_size=tokenizer.vocab_size).to(device)

In [35]:
o3 = aa.forward(out, out2, out3)
o3.shape

  attn = torch.tensor(block).to(device)
  attn2 = torch.tensor(block).to(device)
  attn3 = torch.tensor(block).to(device)


torch.Size([1024, 96103])

In [36]:

def summary2tensor(summary, batch_size=1, vocab_size=32000):
    z = torch.zeros(batch_size,vocab_size).to(device)
    for i,wid in enumerate(summary):   
        z[i][wid] = 1.0 
    return z


def pred2tensor(pred):
    ids = []
    for r in pred:
        idx = torch.argmax(r)
        ids.append(idx)
    return torch.tensor(ids)

In [37]:
#loss_fn = nn.CrossEntropyLoss()
loss_fn = torch.nn.MSELoss()

for param in aa.parameters():
    param.requires_grad = True

lr = 5.0 # learning rate
optimizer = torch.optim.SGD(aa.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
ntokens = tokenizer.vocab_size

for i in range(len(papers)):
    
    batch = tokenizer(df.fulltext[i], truncation=True, padding='longest', return_tensors="pt").to(device)
    
    out = pt_model.generate(return_dict_in_generate=True, **batch)
    
    cts_batch = tokenizer(df.cts[i], truncation=True, padding='longest', return_tensors="pt").to(device)
    
    citations_batch = tokenizer(df.citations[i], truncation=True, padding='longest', return_tensors="pt").to(device)

    
    
    try:
        pred = aa.forward(out,  cts_batch['input_ids'][0], citations_batch['input_ids'][0])
    except Exception as e:
        print(e)
        continue
    
    y = tokenizer(df.summary[i], truncation=True, padding='longest', return_tensors="pt").to(device)
    y = y["input_ids"]
    y = summary2tensor(y, batch_size=1024, vocab_size=tokenizer.vocab_size)
    #print(pred.shape, y.shape)
    #print(pred)
    loss = loss_fn(pred, y)
    print("{} {}".format(i, loss.item()))
    loss.backward()
    optimizer.step()
    
    

In [13]:
#tgt_text = tokenizer.batch_decode(ids, skip_special_tokens=True)
#print(" ".join(tgt_text))

In [38]:
#save state
torch.save(aa.state_dict(), "data/aa_derek_model.state")

#save params
torch.save(aa, "data/aa_derek_model.param")

In [39]:
def decode_output(pred):
    ids = []
    for x in pred:
        pred_id = torch.argmax(x)
        ids.append(pred_id)
    return ids

In [46]:
#example generate summary from fulltext
summary, fulltext, cits, cts1 = df.abstract[4], df.body[4], df.citations[4], df['cited text spans'][4]
batch = tokenizer(fulltext, truncation=True, padding='longest', return_tensors="pt").to(device)
out = pt_model.generate(return_dict_in_generate=True, **batch)


batch2 = tokenizer(cits, truncation=True, padding='longest', return_tensors="pt").to(device)
out2 = pt_model.generate(return_dict_in_generate=True, **batch2)
batch3 = tokenizer(cts1, truncation=True, padding='longest', return_tensors="pt").to(device)
out3 = pt_model.generate(return_dict_in_generate=True, **batch3)



pred = aa.forward(out, out2, out3)
ids = decode_output(pred)
tgt_text = tokenizer.batch_decode(ids, skip_special_tokens=True)
tgt_text = " ".join(tgt_text)
print(tgt_text)

  attn = torch.tensor(block).to(device)
  attn2 = torch.tensor(block).to(device)
  attn3 = torch.tensor(block).to(device)


Suffolk Suffolk $30.00 sold Suffolk Suffolk sai upbeat sold Suffolk Suffolk aisles Suffolk SEN Suffolk Suffolk Suffolk sold Suffolk Darn sold Suffolk sold sold upbeat Suffolk Suffolk sold trifle sold sold sold sold sold sold Suffolk sai sold Suffolk Suffolk Suffolk Suffolk Suffolk Darn Suffolk sai sold foal foal upbeat sold prorated Suffolk Suffolk sold sai Suffolk Darn Suffolk sold Suffolk trifle Suffolk Suffolk Converse sold upbeat sold Suffolk Suffolk sold Darn sold Darn Mix Darn Suffolk Suffolk Suffolk sold sold Converse Suffolk Suffolk Darn Suffolk sold sold Darn Suffolk Darn aisles sold sold sold Suffolk sold sold sai sold Suffolk sai Darn Suffolk Suffolk Suffolk Suffolk Suffolk constrained foal sold Suffolk Darn Suffolk Suffolk Suffolk Darn sold sold upbeat sold brooke Suffolk Suffolk SEN Suffolk foal sold sold upbeat sold Suffolk Suffolk Suffolk Suffolk Suffolk Suffolk Suffolk upbeat sold Suffolk sold foal Suffolk Suffolk Suffolk sold Suffolk Suffolk Suffolk sold Suffolk Darn Z

In [47]:
tgt_text = " ".join(tgt_text)

score = scorer.score(tgt_text, summary)
score

{'rouge1': Score(precision=0.07462686567164178, recall=0.0008947745168217609, fmeasure=0.0017683465959328027),
 'rouge2': Score(precision=0.0, recall=0.0, fmeasure=0.0)}