In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_bart import (
    PretrainedBartModel,  
    LayerNorm, 
    EncoderLayer, 
    DecoderLayer, 
    LearnedPositionalEmbedding,
    _prepare_bart_decoder_inputs,
    _make_linear_from_emb
)

import os, argparse, pickle, h5py
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

from utils import Timer, make_path, deleaf
from pprint import pprint
from tqdm import tqdm
from transformers import BartTokenizer, BartConfig, BartModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from tqdm import tqdm
from utils import *

def prepare_dataset(para_data, tokenizer, num, synt_vocab):
    max_sent_len = 40
    max_synt_len = 160
    sents1 = list(para_data['sents1'][:num])
    synts1 = list(para_data['synts1'][:num])
    sents2 = list(para_data['sents2'][:num])
    synts2 = list(para_data['synts2'][:num])

    sent1_token_ids = torch.ones((num, max_sent_len+2), dtype=torch.long) 
    sent2_token_ids = torch.ones((num, max_sent_len+2), dtype=torch.long)    		
    synt1_token_ids = torch.ones((num, max_synt_len+2), dtype=torch.long) 
    synt2_token_ids = torch.ones((num, max_synt_len+2), dtype=torch.long)
    synt1_bow = torch.ones((num, 74))
    synt2_bow = torch.ones((num, 74))
        
    bsz = 64
    
    for i in tqdm(range(0, num, bsz)):
        sent1_inputs = tokenizer([s.decode() for s in sents1[i:i+bsz]], padding='max_length', truncation=True, max_length=max_sent_len+2, return_tensors="pt")
        sent2_inputs = tokenizer([s.decode() for s in sents2[i:i+bsz]], padding='max_length', truncation=True, max_length=max_sent_len+2, return_tensors="pt")
        sent1_token_ids[i:i+bsz] = sent1_inputs['input_ids']
        sent2_token_ids[i:i+bsz] = sent2_inputs['input_ids']

    for i in tqdm(range(num)):
        synt1 = ['<s>'] + deleaf(synts1[i].decode()) + ['</s>']
        synt1_token_ids[i, :len(synt1)] = torch.tensor([synt_vocab[tag] for tag in synt1])[:max_synt_len+2]
        synt2 = ['<s>'] + deleaf(synts2[i].decode()) + ['</s>']
        synt2_token_ids[i, :len(synt2)] = torch.tensor([synt_vocab[tag] for tag in synt2])[:max_synt_len+2]
        
        for tag in synt1:
            if tag != '<s>' and tag != '</s>':
                synt1_bow[i][synt_vocab[tag]-3] += 1
        for tag in synt2:
            if tag != '<s>' and tag != '</s>':
                synt2_bow[i][synt_vocab[tag]-3] += 1

    synt1_bow /= synt1_bow.sum(1, keepdim=True)
    synt2_bow /= synt2_bow.sum(1, keepdim=True)
    
    sum = 0
    for i in range(num):
        if torch.equal(synt1_bow[i], synt2_bow[i]):
            sum += 1

    return {'sent1':sent1_token_ids, 'sent2':sent2_token_ids, 'synt1': synt1_token_ids, 'synt2': synt2_token_ids,
            'synt1bow': synt1_bow, 'synt2bow': synt2_bow}

In [3]:
import h5py, os
print("==== loading data ====")
mrpc_set = h5py.File(os.path.join('./test_data/test_data_mrpc.h5'), 'r')
mrpc_set.keys()

==== loading data ====


<KeysViewHDF5 ['sents1', 'sents2', 'synts1', 'synts2']>

In [4]:
from transformers import BartTokenizer, BartConfig, BartModel
from utils import Timer, make_path, deleaf
import pickle

print("==== preparing data ====")
make_path('./bart-base/')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir='./bart-base/')

with open('synt_vocab.pkl', 'rb') as f:
    synt_vocab = pickle.load(f)

==== preparing data ====


In [5]:
# Get the vocabulary from the tokenizer
vocab = tokenizer.get_vocab()

In [6]:
num = 1920
dataset = prepare_dataset(mrpc_set, tokenizer, num, synt_vocab)
dataset.keys()

100%|██████████| 30/30 [00:00<00:00, 40.75it/s]
100%|██████████| 1920/1920 [00:03<00:00, 507.78it/s]


dict_keys(['sent1', 'sent2', 'synt1', 'synt2', 'synt1bow', 'synt2bow'])

In [7]:
import random
test_idxs = random.sample(range(0, 1920), 1920)

In [8]:
test_loader = DataLoader(test_idxs, batch_size=16, shuffle=False)

## 3. Model

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_bart import (
    PretrainedBartModel,  
    LayerNorm, 
    EncoderLayer, 
    DecoderLayer, 
    LearnedPositionalEmbedding,
    _prepare_bart_decoder_inputs,
    _make_linear_from_emb
)

class ParaBart(PretrainedBartModel):
    def __init__(self, config):
        super().__init__(config)
        
        self.shared = nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)

        self.encoder = ParaBartEncoder(config, self.shared)
        self.decoder = ParaBartDecoder(config, self.shared)
                
        self.linear = nn.Linear(config.d_model, config.vocab_size)
        
        self.adversary = Discriminator(config)
        
        self.init_weights()

    def forward(
        self,
        input_ids,      
        decoder_input_ids,
        attention_mask=None,
        decoder_padding_mask=None,
        encoder_outputs=None,
        return_encoder_outputs=False,
    ):
        if attention_mask is None:
            attention_mask = input_ids == self.config.pad_token_id
        
        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
            
        if return_encoder_outputs:
            return encoder_outputs
        
        assert encoder_outputs is not None
        assert decoder_input_ids is not None

        decoder_input_ids = decoder_input_ids[:, :-1]
                
        _, decoder_padding_mask, decoder_causal_mask = _prepare_bart_decoder_inputs(
            self.config,
            input_ids=None,
            decoder_input_ids=decoder_input_ids,
            decoder_padding_mask=decoder_padding_mask,
            causal_mask_dtype=self.shared.weight.dtype,
        )    

        attention_mask2 = torch.cat((torch.zeros(input_ids.shape[0], 1).bool().cuda(), attention_mask[:, self.config.max_sent_len+2:]), dim=1)
           
        # decoder
        decoder_outputs = self.decoder(
            decoder_input_ids,
            torch.cat((encoder_outputs[1], encoder_outputs[0][:, self.config.max_sent_len+2:]), dim=1),           
            decoder_padding_mask=decoder_padding_mask,
            decoder_causal_mask=decoder_causal_mask,
            encoder_attention_mask=attention_mask2,
        )[0]
        
       
        batch_size = decoder_outputs.shape[0]
        outputs = self.linear(decoder_outputs.contiguous().view(-1, self.config.d_model))
        outputs = outputs.view(batch_size, -1, self.config.vocab_size)
        
        # discriminator
        for p in self.adversary.parameters():
            p.required_grad=False
        adv_outputs = self.adversary(encoder_outputs[1])        
        
        return outputs, adv_outputs
    
    def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"

        encoder_outputs = past[0]
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "decoder_input_ids": torch.cat((decoder_input_ids, torch.zeros((decoder_input_ids.shape[0], 1), dtype=torch.long).cuda()), 1),
            "attention_mask": attention_mask,
        }

    def get_encoder(self):
        return self.encoder

    def get_output_embeddings(self):
        return _make_linear_from_emb(self.shared)
    
    def get_input_embeddings(self):
        return self.shared
    
    @staticmethod
    def _reorder_cache(past, beam_idx):
        enc_out = past[0][0]

        new_enc_out = enc_out.index_select(0, beam_idx)

        past = ((new_enc_out, ), )
        return past

    def forward_adv(
        self,
        input_token_ids,      
        attention_mask=None,
        decoder_padding_mask=None
    ):
        for p in self.adversary.parameters():
            p.required_grad=True
        sent_embeds = self.encoder.embed(input_token_ids, attention_mask=attention_mask).detach()
        adv_outputs = self.adversary(sent_embeds)

        return adv_outputs

    def generate(self, input_ids, decoder_input_ids, attention_mask=None,decoder_padding_mask=None,
                 encoder_outputs=None,return_encoder_outputs=False, 
                 max_len = 40, sample=True, temp=0.5):
        
        max_targ_len = decoder_input_ids.size(1) - 2
        batch_size   = decoder_input_ids.size(0)
        # output index starts with <sos>
        idxs = torch.zeros((batch_size, max_targ_len +2), dtype=torch.long).cuda()
        idxs[:, 0] = 0

        if attention_mask is None:
            attention_mask = input_ids == self.config.pad_token_id
        
        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
            
        if return_encoder_outputs:
            return encoder_outputs
        
        assert encoder_outputs is not None
        assert decoder_input_ids is not None

        # decoder_input_ids = decoder_input_ids[:, :1]
                
        # _, decoder_padding_mask, decoder_causal_mask = _prepare_bart_decoder_inputs(
        #     self.config,
        #     input_ids=None,
        #     decoder_input_ids=decoder_input_ids,
        #     decoder_padding_mask=decoder_padding_mask,
        #     causal_mask_dtype=self.shared.weight.dtype,
        # )    

        attention_mask2 = torch.cat((torch.zeros(input_ids.shape[0], 1).bool().cuda(), attention_mask[:, self.config.max_sent_len+2:]), dim=1)
           
        # decoder
        decoder_outputs = self.decoder(
            idxs[:, :1],
            torch.cat((encoder_outputs[1], encoder_outputs[0][:, self.config.max_sent_len+2:]), dim=1),           
            decoder_padding_mask = decoder_padding_mask,
            decoder_causal_mask = None,
            encoder_attention_mask = attention_mask2,
        )[0].transpose(0,1)
        # print('decoder_outputs',decoder_outputs.shape)
        
        # output index starts with <sos>
        idxs = torch.zeros((batch_size, max_targ_len+2), dtype=torch.long).to(self.device)
        idxs[:, 0] = 1

        # auto-regressively generate output
        for i in range(1, max_targ_len+2):
            batch_size = decoder_outputs.shape[0]
            outputs = self.linear(decoder_outputs[-1].contiguous().view(-1, self.config.d_model))
            # outputs = outputs.view(batch_size, -1, self.config.vocab_size)
            # print('outputs', outputs.shape)
            # get argmax index or sample index
            if not sample:
                values, idx = torch.max(outputs, 1)
            else:
                probs = F.softmax(outputs/temp, dim=1)
                idx = torch.multinomial(probs, 1).squeeze(1)
                # print('idx',idx.shape)
            # save to output index
            idxs[:, i] = idx   
      
            attention_mask2 = torch.cat((torch.zeros(input_ids.shape[0], 1).bool().cuda(), attention_mask[:, self.config.max_sent_len+2:]), dim=1)
           
            # decoder
            decoder_outputs = self.decoder(
                idxs[:, :i+1],
                torch.cat((encoder_outputs[1], encoder_outputs[0][:, self.config.max_sent_len+2:]), dim=1),           
                decoder_padding_mask = decoder_padding_mask,
                decoder_causal_mask =  None,
                encoder_attention_mask = attention_mask2,
            )[0].transpose(0,1)
            # print('decoder_outputs',decoder_outputs.shape)
        
        return idxs[:, 1:]

class ParaBartEncoder(nn.Module):
    def __init__(self, config, embed_tokens):
        super().__init__()
        self.config = config

        self.dropout = config.dropout
        self.embed_tokens = embed_tokens
                
        self.embed_synt = nn.Embedding(77, config.d_model, config.pad_token_id)       
        self.embed_synt.weight.data.normal_(mean=0.0, std=config.init_std)
        self.embed_synt.weight.data[config.pad_token_id].zero_()

        self.embed_positions = LearnedPositionalEmbedding(
            config.max_position_embeddings, config.d_model, config.pad_token_id, config.extra_pos_embeddings
        )
        
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
        self.synt_layers = nn.ModuleList([EncoderLayer(config) for _ in range(1)])

        self.layernorm_embedding = LayerNorm(config.d_model) 

        self.synt_layernorm_embedding = LayerNorm(config.d_model)
        
        self.pooling = MeanPooling(config)
        

    def forward(self, input_ids, attention_mask): 
        
        input_token_ids, input_synt_ids = torch.split(input_ids, [self.config.max_sent_len+2, self.config.max_synt_len+2], dim=1)
        input_token_mask, input_synt_mask = torch.split(attention_mask, [self.config.max_sent_len+2, self.config.max_synt_len+2], dim=1)
        
        x = self.forward_token(input_token_ids, input_token_mask)
        y = self.forward_synt(input_synt_ids, input_synt_mask)
                
        encoder_outputs = torch.cat((x,y), dim=1)

        sent_embeds = self.pooling(x, input_token_ids)

        return encoder_outputs, sent_embeds
    
    def forward_token(self, input_token_ids, attention_mask):
        if self.training:
            drop_mask = torch.bernoulli(self.config.word_dropout*torch.ones(input_token_ids.shape)).bool().cuda()
            input_token_ids = input_token_ids.masked_fill(drop_mask, 50264)
               
        input_token_embeds = self.embed_tokens(input_token_ids) + self.embed_positions(input_token_ids)
        x = self.layernorm_embedding(input_token_embeds)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        x = x.transpose(0, 1)
        
        for encoder_layer in self.layers:
            x, _ = encoder_layer(x, encoder_padding_mask=attention_mask)
            
        x = x.transpose(0, 1)
        return x
        
    def forward_synt(self, input_synt_ids, attention_mask):
        input_synt_embeds = self.embed_synt(input_synt_ids) + self.embed_positions(input_synt_ids)        
        y = self.synt_layernorm_embedding(input_synt_embeds)        
        y = F.dropout(y, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        y = y.transpose(0, 1)
            
        for encoder_synt_layer in self.synt_layers:
            y, _ = encoder_synt_layer(y, encoder_padding_mask=attention_mask)

        # T x B x C -> B x T x C
        y = y.transpose(0, 1)
        return y
        

    def embed(self, input_token_ids, attention_mask=None, pool='mean'):
        if attention_mask is None:
            attention_mask = input_token_ids == self.config.pad_token_id
            
        x = self.forward_token(input_token_ids, attention_mask)
        
        sent_embeds = self.pooling(x, input_token_ids)
        return sent_embeds
            
class MeanPooling(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
    def forward(self, x, input_token_ids):
        mask = input_token_ids != self.config.pad_token_id
        mean_mask = mask.float()/mask.float().sum(1, keepdim=True)
        x = (x*mean_mask.unsqueeze(2)).sum(1, keepdim=True)
        return x


class ParaBartDecoder(nn.Module):
    def __init__(self, config, embed_tokens):
        super().__init__()
        
        self.dropout = config.dropout
        
        self.embed_tokens = embed_tokens
        
        self.embed_positions = LearnedPositionalEmbedding(
            config.max_position_embeddings, config.d_model, config.pad_token_id, config.extra_pos_embeddings
        )
        
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(1)]) 
        self.layernorm_embedding = LayerNorm(config.d_model)

    def forward(
        self, 
        decoder_input_ids, 
        encoder_hidden_states,  
        decoder_padding_mask, 
        decoder_causal_mask,  
        encoder_attention_mask
    ):        
		
        x = self.embed_tokens(decoder_input_ids) + self.embed_positions(decoder_input_ids)
        x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = x.transpose(0, 1)
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        for idx, decoder_layer in enumerate(self.layers):
            x, _, _ = decoder_layer(
                x, 
                encoder_hidden_states,
                encoder_attn_mask=encoder_attention_mask,
                decoder_padding_mask=decoder_padding_mask,
                causal_mask=decoder_causal_mask)

        x = x.transpose(0, 1)
        return x,
    
class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.sent_layernorm_embedding = LayerNorm(config.d_model, elementwise_affine=False)
        self.adv = nn.Linear(config.d_model, 74)
        
    def forward(self, sent_embeds):
        x = self.sent_layernorm_embedding(sent_embeds).squeeze(1)
        x = self.adv(x)
        return x
    

In [10]:
print("==== loading model ====")
config = BartConfig.from_pretrained('facebook/bart-base', cache_dir='./bart-base/')
config.word_dropout = 0.2
config.max_sent_len = 40
config.max_synt_len = 160

bart = BartModel.from_pretrained('facebook/bart-base', cache_dir='./bart-base/')
model = ParaBart(config)
# model.load_state_dict(bart.state_dict(), strict=False)
model.zero_grad()
del bart

==== loading model ====


In [11]:
idx2word = {v:k for k,v in vocab.items()}
span_vocab = vocab.copy()
span_vocab.pop('<s>')
span_vocab.pop('<pad>')
span_vocab.pop('</s>')
span_vocab.pop('<unk>')
span_vocab.pop('<mask>')

def reverse_bpe(sent):
    x = []
    cache = ''

    for w in sent:
        if w.startswith('Ġ'):
            cache += w.replace('Ġ', '')
            # cache = cache.strip()
        elif cache != '':
            x.append(cache + w)
            cache = ''
        else:
            x.append(w)

    return ' '.join(x)

def sent2str(sent, vocab):
    return " ".join([idx2word[i] for i in sent if i != vocab["<pad>"]])

def synt2str(synt, vocab):
    eos_pos = np.where(synt==vocab["</s>"])[0]
    eos_pos = eos_pos[0] if len(eos_pos) > 0 else len(synt)
    return " ".join([idx2word[i][1:-1] if i in span_vocab.values() else idx2word[i] for i in synt[:eos_pos]])

In [12]:
print("==== loading model ====")
config = BartConfig.from_pretrained('facebook/bart-base', cache_dir="./bart-base/")
embed_model = ParaBart(config)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base', cache_dir="./bart-base/")
model.load_state_dict(torch.load(os.path.join("./model/model.pt"), map_location='cpu'))
model = model.cuda()

==== loading model ====


In [13]:
def sent2str(sent,something=None):
    return " ".join([idx2word[i].replace('Ġ','') for i in sent if i != vocab["<pad>"]]).split('</s>')[0]


def generate(model, loader, vocab_transform):
    #turn off dropout (and batch norm if used)
    model.eval()
    epoch_loss = 0
    with open("./eval_via2/target_sents_parabart.txt", "w") as target_sent, \
         open("./eval_via2/target_synts_parabart.txt", "w") as syntax_keep, \
         open("./eval_via2/outputs_parabart.txt", "w") as output_sentence,\
         open("./eval_via2/inputs_parabart.txt", "w") as input_sentences:
        with torch.no_grad():
            for idxs in tqdm(loader):
                
                sent1_token_ids = dataset['sent1'][idxs].cuda()
                synt1_token_ids = dataset['synt1'][idxs].cuda()
                sent2_token_ids = dataset['sent2'][idxs].cuda()
                synt2_token_ids = dataset['synt2'][idxs].cuda()
                synt1_bow = dataset['synt1bow'][idxs].cuda()
                synt2_bow = dataset['synt2bow'][idxs].cuda()

                # generate
                idxs = model.generate(torch.cat((sent1_token_ids, synt2_token_ids),1), sent1_token_ids, temp=0.5)
                
                for sent, idx,sent2, synt,synt2 in zip(sent1_token_ids.cpu().numpy(), idxs.cpu().numpy(),sent2_token_ids.cpu().numpy(), synt1_token_ids.cpu().numpy(), synt2_token_ids.cpu().numpy()):
                    
                    convert_idx_out =sent2str(idx, None)
                    targetSent = sent2str(sent2[1:-1],None) 
                    inputSente = sent2str(sent[1:-1],None) 
                    input_sentences.write(inputSente+'\n')
                    target_sent.write(targetSent+'\n') 
                    output_sentence.write(convert_idx_out+'\n')

                    # fp1.write(convert_sent)
                    # fp2.write(convert_synt)
                    # fp3.write(convert_idx)
                    
                    # fp1.write(sent2str(sent, vocab_transform) +'\n')
                    # fp2.write(synt2str(synt[1:], vocab_transform)+'\n')
                    # fp3.write(reverse_bpe(synt2str(idx, vocab_transform).replace("<pad>", "")) +'\n')

In [14]:
generate(model, test_loader, vocab)

100%|██████████| 120/120 [00:21<00:00,  5.63it/s]


In [15]:
import numpy as np
from nltk.translate.bleu_score import sentence_bleu

def cal_bleu(hypothesis, reference, n):
    hypothesis = hypothesis.strip().split(' ')
    reference = reference.strip().split(' ')

    if n == 0:
        return sentence_bleu([reference], hypothesis)
    elif n == 1:
        weights = (1, 0, 0, 0)
    elif n == 2:
        weights = (0, 1, 0, 0)
    elif n == 3:
        weights = (0, 0, 1, 0)
    elif n == 4:
        weights = (0, 0, 0, 1)

    return sentence_bleu([reference], hypothesis, weights=weights)  

In [16]:
from tqdm import tqdm

with open('./eval_via2/target_sents_parabart.txt') as fp:
    targs = fp.readlines()
with open('./eval_via2/outputs_parabart.txt') as fp:
    preds = fp.readlines()
with open('./eval_via2/inputs_parabart.txt') as fp:
    inps = fp.readlines()

print(f"number of examples: {len(preds)} , {len(targs)}")

number of examples: 1920 , 1920


In [17]:
preds[4], targs[4]

('the terrorist activity will be awareness awareness awareness . \n',
 'the new name will be terrorism information awareness . \n')

In [24]:
scores0 = [cal_bleu(pred, targ, 1) for pred, targ in tqdm(zip(preds, targs))]
print(f"BLEU: {np.mean(scores0)*100.0}")

1920it [00:00, 7774.23it/s]

BLEU: 45.794625299054395





In [19]:
import numpy as np
from nltk.translate import meteor

def cal_meteor(hypothesis, reference):
    hypothesis = hypothesis.strip().split(' ')
    reference = reference.strip().split(' ')

    return meteor([reference], hypothesis)   

scoresm = [cal_meteor(pred, targ) for pred, targ in tqdm(zip(preds, targs))]
print(f"METHEO: {np.mean(scoresm)*100.0}")

1920it [00:02, 696.31it/s] 

METHEO: 40.84170222329394





In [20]:
from rouge import Rouge
scorer = Rouge()
scoresR = [scorer.get_scores(pred,refs= targ)[0]['rouge-1']['r'] for pred, targ in tqdm(zip(preds, targs))]
scoresP = [scorer.get_scores(pred,refs= targ)[0]['rouge-1']['p'] for pred, targ in tqdm(zip(preds, targs))]
scoresF = [scorer.get_scores(pred,refs= targ)[0]['rouge-1']['f'] for pred, targ in tqdm(zip(preds, targs))]

# # for ind,k in enumerate(scoresโร้ก):
print(f"Rouge-r: {np.mean(scoresR)*100.0}") 
print(f"Rouge-p: {np.mean(scoresP)*100.0}") 
print(f"Rouge-f: {np.mean(scoresF)*100.0}") 

1920it [00:00, 2442.85it/s]
1920it [00:00, 2505.09it/s]
1920it [00:00, 2494.01it/s]

Rouge-r: 45.352399615722454
Rouge-p: 49.067626041670145
Rouge-f: 46.99901940638212



