In [1]:
import json
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import T5Tokenizer, MT5ForConditionalGeneration

from preprocess_utils import get_highlighted_subtable, linearize_subtable

In [3]:
# Run on CPU.
device = torch.device('cpu')
# Examples per DataLoader batch (an earlier note used 10 for 't5-large').
batch_size = 12

In [4]:
# mT5 tokenizer, extended with the two table-linearization delimiters
# ('|' between triples, ':' inside a triple) registered as special tokens.
tokenizer = T5Tokenizer.from_pretrained('google/mt5-base')
table_tags = ['|', ':']
tokenizer.add_special_tokens({'additional_special_tokens': table_tags})

# mT5 model on the configured device; its embedding matrix is resized so the
# newly added tokens get rows.
pretrained = MT5ForConditionalGeneration.from_pretrained('google/mt5-base').to(device)
pretrained.resize_token_embeddings(len(tokenizer))

# Freeze every PLM parameter (presumably so only the separate prompt module is
# trained). The trailing ';' keeps IPython from displaying the returned module.
pretrained.requires_grad_(False);

In [23]:
class WebNLGDataset(Dataset):
    """WebNLG data-to-text examples, tokenized for a seq2seq model.

    Each entry's ``modifiedtripleset`` is linearized as
    ``" | subj : prop : obj | subj : prop : obj ..."``. That string is paired
    with every lexicalisation whose ``"lang"`` equals ``language``. For the
    ``'dev'`` split only the first matching lexicalisation is kept.

    Args:
        tokenizer: object exposing ``encode(str) -> list[int]`` and ``eos_token_id``.
        raw_path: directory passed to ``select_files`` when the preprocessed
            JSON does not exist yet.
        language: value compared against each lexicalisation's ``"lang"`` field.
        data_path: directory containing ``{split}.json``.
        split: base name of the JSON file to load (e.g. ``'train'``, ``'dev'``).
        max_src_len: maximum encoded source length. Longer sources are cut to
            ``max_src_len - 1`` tokens followed by ``tokenizer.eos_token_id``.

    Attributes:
        examples: encoded (possibly truncated) sources, one list of ids per example.
        targets: encoded reference texts, aligned index-by-index with ``examples``.
    """

    def __init__(self, tokenizer, raw_path='../webnlg_data/release_v3.0/ru', language='en',
                 data_path='../webnlg_data/preprocessed', split='train', max_src_len=512):
        json_path = f'{data_path}/{split}.json'

        if not os.path.exists(json_path):
            # NOTE(review): `Benchmark` and `select_files` are not imported anywhere
            # in the visible notebook (presumably the WebNLG corpus reader), so this
            # branch raises NameError unless they are brought into scope first.
            # `select_files` also receives no `split` — confirm `raw_path` holds
            # only the files for this split.
            b = Benchmark()
            files = select_files(raw_path)
            b.fill_benchmark(files)
            b.b2json(data_path, f'{split}.json')

        # JSON is UTF-8 by spec; being explicit avoids locale-dependent decoding
        # of non-ASCII (e.g. Russian) text.
        with open(json_path, 'r', encoding='utf-8') as f:
            entries = json.load(f)['entries']

        full_rela_lst = []
        full_src_lst = []
        full_tgt_lst = []
        for i, entry in enumerate(entries):
            # Each entry is a one-key dict keyed by its 1-based position.
            record = entry[str(i + 1)]
            rela_lst, linearized = self._linearize_triples(record['modifiedtripleset'])

            for sent in record['lexicalisations']:
                if sent["lang"] == language:
                    full_tgt_lst.append(sent["lex"])
                    full_src_lst.append(linearized)
                    full_rela_lst.append(rela_lst)
                    if split == 'dev':
                        # Dev keeps a single reference per entry.
                        break

        assert len(full_rela_lst) == len(full_src_lst) == len(full_tgt_lst)

        self.examples = []
        self.targets = []
        for src, tgt in zip(full_src_lst, full_tgt_lst):
            src_ids = tokenizer.encode(src)
            if len(src_ids) > max_src_len:
                # Truncate, but keep the sequence EOS-terminated. (Previously the
                # truncated copy was discarded and the full sequence kept.)
                src_ids = src_ids[:max_src_len - 1] + [tokenizer.eos_token_id]
            self.examples.append(src_ids)
            self.targets.append(tokenizer.encode(tgt))

    @staticmethod
    def _linearize_triples(triples):
        """Return ``(properties, linearized)`` for a list of triple dicts.

        ``linearized`` is ``" | s : p : o"`` repeated for each triple, in order.
        """
        rela_lst = []
        parts = []
        for triple in triples:
            subj, rela, obj = triple['subject'], triple['property'], triple['object']
            rela_lst.append(rela)
            parts.append(' | {} : {} : {}'.format(subj, rela, obj))
        return rela_lst, ''.join(parts)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return a copy of the encoded source at ``idx`` (targets are not returned).

        A copy is returned so that callers (e.g. a padding collate function)
        cannot mutate the stored example.
        """
        return list(self.examples[idx])


In [28]:
def collate_fn(batch):
    """Right-pad a batch of token-id lists to the batch's longest length.

    Args:
        batch: sequence of ``list[int]`` token ids. The lists are NOT modified;
            padded copies are built instead.

    Returns:
        ``(input_ids, attention_mask)``, both int64 tensors of shape
        ``(len(batch), max_len)``. Padding uses the global ``tokenizer.pad_token_id``.
        The mask is 1 for the original positions and 0 for padding. It is derived
        from each sequence's length rather than by comparing against the pad id.
    """
    max_len_data = max((len(data) for data in batch), default=0)
    pad_id = tokenizer.pad_token_id

    datas = []
    attn_masks = []
    for data in batch:
        n_pad = max_len_data - len(data)
        # Build a new list: extending `data` in place would permanently pad the
        # dataset's stored example.
        datas.append(list(data) + [pad_id] * n_pad)
        attn_masks.append([1] * len(data) + [0] * n_pad)

    return torch.tensor(datas), torch.tensor(attn_masks)

In [29]:
# Dev split, iterated in its stored order (no shuffling) and padded per batch.
dataset_dev = WebNLGDataset(
    tokenizer=tokenizer,
    split='dev',
)
dataloader_dev = DataLoader(
    dataset_dev,
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=False,
)

In [30]:
class PromptTuning(nn.Module):
    """Trainable soft prompt with an MLP reparameterization.

    The fixed index vector ``0 .. prompt_len-1`` is looked up in an
    ``nn.Embedding`` of width ``pretrained_config.d_model``. The result goes
    through a Tanh MLP ``d_model -> hidden_dim -> hidden_dim -> d_model``.

    Attribute names (``pre_prompt``, ``embd``, ``reparam``) are kept stable
    because the generation cell loads a fully pickled instance of this module.

    Args:
        pretrained_config: config of the frozen PLM; only ``d_model`` is read.
        prompt_len: number of soft-prompt positions.
        hidden_dim: width of the MLP's hidden layers.
    """

    def __init__(self, pretrained_config, prompt_len=20, hidden_dim=256):
        super().__init__()
        d_model = pretrained_config.d_model

        self.pretrained_config = pretrained_config
        # Plain tensor attribute (not a registered buffer); forward() moves it.
        self.pre_prompt = torch.arange(prompt_len)
        self.embd = nn.Embedding(num_embeddings=prompt_len, embedding_dim=d_model)
        self.reparam = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, d_model),
        )

    def forward(self, batch_size, device):
        """Return the soft prompt with shape ``(batch_size, prompt_len, d_model)``.

        All rows are identical: each row runs the same indices through the same
        embedding and MLP.
        """
        positions = self.pre_prompt.to(device).unsqueeze(0).expand(batch_size, -1)
        return self.reparam(self.embd(positions))

In [None]:
# Beam-search generation on the dev split, conditioned on the trained soft prompt.

# PLM in eval mode (its weights were frozen when it was loaded).
pretrained.eval()

MODEL_PATH = '../models/MT5-base_Prompt-Tuning_prompt-len100_hidden-dim768_lr0.3_batch8_epoch1of20.pt'
OUTPUT_PATH = '../totto_data/generation_dev_prompt_tuning.txt'

# The checkpoint is a whole pickled nn.Module, not a state_dict, so it needs
# weights_only=False on newer torch. map_location lets a GPU-saved checkpoint
# load on the configured device.
# NOTE(review): unpickling can execute arbitrary code — only load trusted checkpoints.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
model = model.to(device)
model.eval()

# The PLM's token-embedding table; the soft prompt has the same width (d_model).
embed_tokens = pretrained.get_input_embeddings()

# 'w' truncates any previous output, and `with` guarantees the file is closed.
with torch.no_grad(), open(OUTPUT_PATH, 'w') as f:
    for idx, (data, attn_mask) in enumerate(dataloader_dev):
        if (idx+1)%100==0: print(batch_size*(idx+1), 'generated')

        data = data.to(device)
        attn_mask = attn_mask.to(device)

        # Soft prompt: (batch, prompt_len, d_model)
        prompt = model(batch_size=data.shape[0], device=device)

        # Stock MT5 `generate` has no `prompt` argument, so the prompt must be fed
        # through the input embeddings. Prepend it to the token embeddings and
        # extend the mask so prompt positions are attended and padding is not.
        # NOTE(review): assumes training also *prepended* the prompt — confirm.
        inputs_embeds = torch.cat([prompt, embed_tokens(data)], dim=1)
        full_mask = torch.cat([attn_mask.new_ones(prompt.shape[:2]), attn_mask], dim=1)

        # Beam Search
        outputs = pretrained.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            max_length=300,
            num_beams=5,
            early_stopping=True,
        )

        for generation in tokenizer.batch_decode(outputs, skip_special_tokens=True):
            f.write(generation + '\n')

In [None]:
# Evaluation: run the ToTTo evaluation script on the generated predictions.
# Both paths are relative to ../language_repo/ after the `cd`.
# NOTE(review): the predictions come from the WebNLG dev split (dataset_dev), but
# --target_path points at ToTTo dev data, so the line counts and references
# presumably don't correspond — confirm the intended targets and eval script.
!cd ../language_repo/ && bash language/totto/totto_eval.sh --prediction_path ../totto_data/generation_dev_prompt_tuning.txt --target_path ../totto_data/totto_dev_data.jsonl