In [47]:
import json
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import T5Tokenizer, MT5ForConditionalGeneration, AdamW, get_linear_schedule_with_warmup

# Google's Official Preprocess Codes
# https://github.com/google-research/language/blob/master/language/totto/baseline_preprocessing/preprocess_utils.py
from preprocess_utils import get_highlighted_subtable, linearize_subtable

In [48]:
# Train Config
device=torch.device('cpu')
lr=3e-1
batch_size=8 # 4(max 6) for 't5-large' and make 'accumulation_steps' larger
accumulation_steps=1
epochs=20

# Prompt Config
prompt_len=100
hidden_dim=768

In [49]:
# Pre-Trained T5 Tokenizer
tokenizer=T5Tokenizer.from_pretrained('google/mt5-base')
# Add Special Tokens: Table Tags
tokenizer.add_special_tokens({
    'additional_special_tokens': [
        '|',
        ':',
    ]
})
# Pre-Trained T5 Model
pretrained = MT5ForConditionalGeneration.from_pretrained('google/mt5-base').to(device)
# Resize PLM's Embedding Layer
pretrained.resize_token_embeddings(len(tokenizer))
# Freeze LM
for param in pretrained.parameters():
    param.requires_grad=False

In [65]:
class WebNLGDataset(Dataset):
    
    def __init__(self, tokenizer, raw_path='../webnlg_data/release_v3.0/ru', data_path='../webnlg_data/preprocessed', split='train'):
        
        if not os.path.exists(f'{data_path}/{split}.json'):
            b = Benchmark()
            files = select_files(raw_path)
            b.fill_benchmark(files)
            b.b2json(data_path, f'{split}.json')
        
        with open(f'{data_path}/{split}.json', 'r') as f:
            dataset = json.load(f)
            entries = dataset['entries']

        full_lan_list = []
        full_rela_lst = []
        full_src_lst = []
        full_tgt_lst = []
        for i, entry in enumerate(entries):
            sents = entry[str(i + 1)]['lexicalisations']
            triples = entry[str(i + 1)]['modifiedtripleset']
            
            rela_lst = []
            temp_triples = ''
            for j, tripleset in enumerate(triples):
                subj, rela, obj = tripleset['subject'], tripleset['property'], tripleset['object']
                rela_lst.append(rela)
                temp_triples += ' | '
                temp_triples += '{} : {} : {}'.format(subj, rela, obj)

            for sent in sents:
                full_tgt_lst.append(sent["lex"])
                full_src_lst.append(temp_triples)
                full_rela_lst.append(rela_lst)
                full_lan_list.append(sent["lang"])
                if split == 'dev':
                    break
            
        assert len(full_rela_lst) == len(full_src_lst)
        assert len(full_rela_lst) == len(full_tgt_lst)

        self.examples = []
        self.targets = []
        self.languages = []
        
        language_transform = {'en': 'english', 'ru': 'russian'}
        
        for src, tgt, lan in zip(full_src_lst, full_tgt_lst, full_lan_list):
            src = tokenizer.encode(src)
            if len(src)>512:
                # Truncate
                encoded = src[:511] + [tokenizer.eos_token_id]
            self.examples.append(src)
    
            tgt = tokenizer.encode(tgt)
            self.targets.append(tgt)
            
            self.languages.append(language_transform[lan])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx], self.targets[idx], self.languages[idx]


In [105]:
def collate_fn(batch):
    """
    Same Sequence Length on Same Batch
    """
    max_len_data=0
    max_len_label=0
    for data, label, _ in batch:
        if len(data)>max_len_data: max_len_data=len(data)
        if len(label)>max_len_label: max_len_label=len(label)
            
    datas=[]
    attn_masks=[]
    labels=[]
    controls = []
    for data, label, control in batch:
        data.extend([tokenizer.pad_token_id]*(max_len_data-len(data)))
        datas.append(data)
        
        attn_mask=[int(e!=tokenizer.pad_token_id) for e in data]
        attn_masks.append(attn_mask)
        
        label.extend([-100]*(max_len_label-len(label)))
        labels.append(label)
        
        controls.append([control])
        
    return torch.tensor(datas), torch.tensor(attn_masks), torch.tensor(labels), controls

In [106]:
dataset_train = WebNLGDataset(tokenizer=tokenizer)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [109]:
class ControlPrefixes(nn.Module):
    """
    """
    def __init__(self, pretrained_config, control_config, prompt_len=20, hidden_dim=256):
        super().__init__()

        # Config of Base (Pre-Trained) LM
        self.pretrained_config = pretrained_config
        # Config of Control-Codes
        control_config['pad'] = 0
        self.control_config = control_config

        # Input: 0, 1, 2 ... prompt_len (General-Prefix)
        self.preseq=torch.arange(prompt_len)
        # Embedding for General-Prefix
        self.embd_general=nn.Embedding(prompt_len, pretrained_config.d_model)
        # Embedding for Control-Prefix
        self.embd_control=nn.Embedding(len(control_config), pretrained_config.d_model)
        # Reparam (Shared Between General-Prefix & Control-Prefix)
        self.reparam=nn.Sequential(
            nn.Linear(pretrained_config.d_model,hidden_dim),
            nn.Tanh(),
            nn.Tanh(),
            nn.Linear(hidden_dim, pretrained_config.d_model)
        )

    def forward(self, batch_size, control, device):
        # General-Prefix: batch_size, prompt_len
        preseq=self.preseq.unsqueeze(0).expand(batch_size,-1).to(device)
        # General-Prefix: batch_size, prompt_len, n_embd
        preseq=self.embd_general(preseq)

        # Control-Prefix
        preseq_control=[[self.control_config[code] for code in codes] for codes in control]
        print(preseq_control)
        # Control-Prefix: batch_size, len(control_codes)
        preseq_control=torch.tensor(preseq_control).to(device)
        # Control-Prefix: batch_size, len(control_codes), n_embd
        preseq_control=self.embd_control(preseq_control)

        # Merge: [Control-Prefix, General-Prefix]
        # batch_size, len(control_codes)+preseqlen, n_embd
        preseq=torch.cat((preseq_control,preseq), dim=1)
        # batch_size, len(control_codes)+preseqlen, 2*n_layer*n_embd
        prompt = preseq=self.reparam(preseq)

        return prompt

In [110]:
control_config = {'english': 1, 'russian': 2}

# Model: Prompt Tuning
model = ControlPrefixes(pretrained_config=pretrained.config, control_config=control_config, prompt_len=prompt_len, hidden_dim=hidden_dim)

# Optim, Scheduler
optimizer=AdamW(model.parameters(), lr=lr)
# NO Warm-Up
scheduler=get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=int(epochs*len(dataset_train)/(accumulation_steps*batch_size))
)

# TensorBoard: Logging
writer=SummaryWriter()
step_global=0

for epoch in range(epochs):
    # Train Phase
    model.train()
    model.to(device)
    
    loss_train=0
    optimizer.zero_grad()
    
    for step, (data, attn_mask, label, control) in enumerate(dataloader_train):
        data=data.to(device)
        attn_mask=attn_mask.to(device)
        label=label.to(device)
        
        prompt=model(batch_size=data.shape[0], control=control, device=device)
        outputs=pretrained(input_ids=data, attention_mask=attn_mask, labels=label, prompt=prompt)
        
        loss=outputs[0]/accumulation_steps
        loss.backward()
        
        loss_train+=loss.item()
        
        if (step+1)%accumulation_steps==0:
            step_global+=1
            
            # TensorBoard
            writer.add_scalar(
                f'loss_train/MT5-base_Prompt-Tuning_prompt-len{prompt_len}_hidden-dim{hidden_dim}_lr{lr}_batch{int(accumulation_steps*batch_size)}_epoch{epochs}',
                loss_train,
                step_global
            )
            # Console
            if step_global%1000==0:
                print(f'epoch {epoch+1} step {step_global} loss_train {loss_train:.4f}')
            # Set Loss to 0
            loss_train=0
            
            optimizer.step()
            scheduler.step()
            
            optimizer.zero_grad()
            
    # Save Model
    model.to(torch.device('cpu'))
    torch.save(model, f'../model/MT5-base_Control-Prefixes_prompt-len{prompt_len}_hidden-dim{hidden_dim}_lr{lr}_batch{int(accumulation_steps*batch_size)}_epoch{epoch+1}of{epochs}.pt')

[['russian'], ['english'], ['english'], ['russian'], ['english'], ['english'], ['russian'], ['russian']]
[[2], [1], [1], [2], [1], [1], [2], [2]]
[['russian'], ['russian'], ['english'], ['russian'], ['english'], ['russian'], ['russian'], ['russian']]
[[2], [2], [1], [2], [1], [2], [2], [2]]


KeyboardInterrupt: 