In [4]:
import random
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn as nn

from transformers import T5Tokenizer, T5ForConditionalGeneration, get_scheduler

import regex
import json
import string
from tqdm.auto import tqdm
import numpy as np
import collections

from torch.utils.tensorboard import SummaryWriter
%load_ext tensorboard

In [32]:
# MODEL_NAME = 't5-base'
# MODEL_NAME = "NlpHUST/t5-en-vi-small"
MODEL_NAME = "VietAI/vit5-base"
SOURCE_MAX_LEN = 256
TARGET_MAX_LEN = 64
BATCH_SIZE = 32


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [39]:
checkpoint_name = 'ViT5-UIT-ViQuAD-1'
checkpoint_path = f"checkpoints/{checkpoint_name}"
os.makedirs(checkpoint_path, exist_ok=True)

In [None]:
writer = SummaryWriter(f'runs/{checkpoint_name}')

# Dataset

In [7]:
class ODQADataset(Dataset):
    def __init__(self, data_path, n_context=1):
        self.data = []
        with open(data_path, 'r') as f:
            data = json.load(f)
            for qa in data:
                if len(qa['contexts'][:n_context]) != 0:
                    self.data.append(qa)
        self.n_context = n_context
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
       
        source = [f"question: {sample['question']} context: {ctx.replace('_', ' ')}" for ctx in sample['contexts'][:self.n_context]]
        target = random.choice(sample['answers'])['text']
        
        return {'index': idx,
                'source': source,
                'target': target}

In [8]:
class Collator:
    def __init__(self, tokenizer, source_max_length, target_max_length):
        self.tokenizer = tokenizer
        self.source_max_length = source_max_length
        self.target_max_length = target_max_length
    
    def __call__(self, batch):
        indices = torch.tensor([sample['index'] for sample in batch])
        
        sources = [sample['source'] for sample in batch]
        sources_encoding = self.__encode_sources_batch__(sources)
        sources_ids = sources_encoding['input_ids']
        sources_masks = sources_encoding['attention_mask'].bool()
        
        targets = [sample['target'] for sample in batch]
        targets_encoding = self.tokenizer.batch_encode_plus(
            targets, 
            max_length=self.target_max_length,
            add_special_tokens=True, 
            padding='max_length',
            truncation=True,
            return_tensors='pt'
            )
        
        targets_ids = targets_encoding['input_ids']
        targets_masks = targets_encoding['attention_mask'].bool()
        targets_ids = targets_ids.masked_fill(~targets_masks, -100)
        
        return indices, sources_ids, sources_masks, targets_ids, targets_masks
    
    def __encode_sources_batch__(self, sources_batch):
        input_ids, attention_mask = [], []
        for sources in sources_batch:
            encoding = self.tokenizer.batch_encode_plus(
                sources, 
                max_length=self.source_max_length,
                add_special_tokens=True, 
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            input_ids.append(encoding['input_ids'][None])
            attention_mask.append(encoding['attention_mask'][None])
        
        input_ids = torch.cat(input_ids, dim=0)
        attention_mask = torch.cat(attention_mask, dim=0)
        
        return {'input_ids': input_ids, 'attention_mask': attention_mask}

In [9]:
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
collate_fn = Collator(tokenizer, SOURCE_MAX_LEN, TARGET_MAX_LEN)

# Model

In [10]:
class CheckpointWrapper(torch.nn.Module):
    """
    Wrapper replacing None outputs by empty tensors, which allows the use of
    checkpointing.
    """
    def __init__(self, module, use_checkpoint=False):
        super().__init__()
        self.module = module
        self.use_checkpoint = use_checkpoint

    def forward(self, hidden_states, attention_mask, position_bias, **kwargs):
        if self.use_checkpoint and self.training:
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            def custom_forward(*inputs):
                output = self.module(*inputs, **kwargs)
                empty = torch.tensor(
                    [],
                    dtype=torch.float,
                    device=output[0].device,
                    requires_grad=True)
                output = tuple(x if x is not None else empty for x in output)
                return output

            output = torch.utils.checkpoint.checkpoint(
                custom_forward,
                hidden_states,
                attention_mask,
                position_bias
            )
            output = tuple(x if x.size() != 0 else None for x in output)
        else:
            output = self.module(hidden_states, attention_mask, position_bias, **kwargs)
        return output
    
def apply_checkpoint_wrapper(t5stack, use_checkpoint):
    """
    Wrap each block of the encoder to enable checkpointing.
    """
    block = []
    for mod in t5stack.block:
        wrapped_mod = CheckpointWrapper(mod, use_checkpoint)
        block.append(wrapped_mod)
    block = nn.ModuleList(block)
    t5stack.block = block

In [11]:
class MultipleInputsT5(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)        
        # self.wrap_encoder()

    # We need to resize as B x (N * L) instead of (B * N) x L here
    # because the T5 forward method uses the input tensors to infer
    # dimensions used in the decoder.
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        if input_ids != None:
            # inputs might have already be resized in the generate method
            if input_ids.dim() == 3:
                self.encoder.n_passages = input_ids.size(1)
            input_ids = input_ids.view(input_ids.size(0), -1)
        if attention_mask != None:
            attention_mask = attention_mask.view(attention_mask.size(0), -1)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        
    def generate(self, input_ids, attention_mask, max_length):
        self.encoder.n_passages = input_ids.size(1)
        return super().generate(
            input_ids=input_ids.view(input_ids.size(0), -1),
            attention_mask=attention_mask.view(attention_mask.size(0), -1),
            max_length=max_length
        )
        
            
    def wrap_encoder(self, use_checkpoint=False):
        """
        Wrap T5 encoder to obtain a Fusion-in-Decoder model.
        """
        self.encoder = WrappedEncoder(self.encoder, use_checkpoint=use_checkpoint)
    
    def unwrap_encoder(self):
        """
        Unwrap Fusion-in-Decoder encoder, useful to load T5 weights.
        """
        self.encoder = self.encoder.encoder
        block = []
        for mod in self.encoder.block:
            block.append(mod.module)
        block = nn.ModuleList(block)
        self.encoder.block = block
        
    def load_t5(self, state_dict):
        self.unwrap_encoder()
        self.load_state_dict(state_dict)
        self.wrap_encoder()
        
class WrappedEncoder(torch.nn.Module):
    """
    Encoder Wrapper for T5 Wrapper to obtain a Fusion-in-Decoder model.
    """
    def __init__(self, encoder, use_checkpoint=False):
        super().__init__()

        self.encoder = encoder
        self.main_input_name = encoder.main_input_name
        apply_checkpoint_wrapper(self.encoder, use_checkpoint)

    def forward(self, input_ids=None, attention_mask=None, **kwargs,):
        # total_length = n_passages * passage_length
        batch_sz, total_length = input_ids.shape
        passage_length = total_length // self.n_passages
        input_ids = input_ids.view(batch_sz * self.n_passages, passage_length)
        attention_mask = attention_mask.view(batch_sz * self.n_passages, passage_length)
        outputs = self.encoder(input_ids, attention_mask, **kwargs)
        outputs = (outputs[0].view(batch_sz, self.n_passages * passage_length, -1), ) + outputs[1:]
        return outputs


# Fine-tuning

In [8]:
train_set = ODQADataset('data/ViQuAD/train_viquad.json', n_context=10)

In [12]:
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

## Model

In [22]:
model = MultipleInputsT5.from_pretrained(MODEL_NAME).to(device)




In [23]:
num_epochs = 2
num_training_steps = num_epochs * len(train_loader)
optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler(name='linear', optimizer=optimizer, num_warmup_steps=1000, num_training_steps=num_training_steps)

In [24]:
progress_bar = tqdm(range(num_training_steps))

verbose_step = 50
checkpoint_step = 5000
model.train()
step = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch}:')
    for iter, batch in enumerate(train_loader):
        _, sources_ids, sources_masks, targets_ids, targets_masks = batch
        
        outputs = model(
            input_ids=sources_ids.to(device),
            attention_mask=sources_masks.to(device),
            labels=targets_ids.to(device)
        )
        loss = outputs.loss
        loss.backward()
        lr = lr_scheduler.get_lr()[0]
        
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        
        step += 1
        if step % verbose_step == 0:
            print(f'### Iter: {step} - Learning rate: {round(lr, 10)} - Loss: {round(loss.item(), 10)}')
            writer.add_scalar('Learning rate', lr, step)
            writer.add_scalar('Loss', loss.item(), step)
            
        if step % checkpoint_step == 0:
            torch.save({
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'scheduler': lr_scheduler.state_dict(),
                'model': model.state_dict(),
            }, f'checkpoints/{checkpoint_name}/checkpoint_iter{step}.pth')
            model.save_pretrained(f'checkpoints/{checkpoint_name}/model_iter{step}')

torch.save({
    'epoch': epoch,
    'optimizer': optimizer.state_dict(),
    'scheduler': lr_scheduler.state_dict(),
    'model': model.state_dict(),
}, f'checkpoints/{checkpoint_name}/checkpoint_final.pth')
model.save_pretrained(f'checkpoints/{checkpoint_name}/model_final')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=37154.0), HTML(value='')))

Epoch 0:
### Iter: 50 - Learning rate: 2.45e-06 - Loss: 5.5834293365
### Iter: 100 - Learning rate: 4.95e-06 - Loss: 16.8477725983
### Iter: 150 - Learning rate: 7.45e-06 - Loss: 1.9185131788
### Iter: 200 - Learning rate: 9.95e-06 - Loss: 3.182046175
### Iter: 250 - Learning rate: 1.245e-05 - Loss: 3.7116234303
### Iter: 300 - Learning rate: 1.495e-05 - Loss: 1.5477397442
### Iter: 350 - Learning rate: 1.745e-05 - Loss: 0.7607566118
### Iter: 400 - Learning rate: 1.995e-05 - Loss: 3.6523406506
### Iter: 450 - Learning rate: 2.245e-05 - Loss: 2.2885518074
### Iter: 500 - Learning rate: 2.495e-05 - Loss: 4.9177060127
### Iter: 550 - Learning rate: 2.745e-05 - Loss: 1.3733707666
### Iter: 600 - Learning rate: 2.995e-05 - Loss: 6.1418886185
### Iter: 650 - Learning rate: 3.245e-05 - Loss: 0.7553626299
### Iter: 700 - Learning rate: 3.495e-05 - Loss: 1.6045964956
### Iter: 750 - Learning rate: 3.745e-05 - Loss: 0.783010006
### Iter: 800 - Learning rate: 3.995e-05 - Loss: 5.4374928474
### I

# Evaluation

## Utilities

In [12]:
def normalize_text(s):
    def remove_articles(text):
        return regex.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    text = white_space_fix(remove_articles(remove_punc(lower(s))))
    return text

def exact_match_score(prediction, ground_truth):
    return normalize_text(prediction) == normalize_text(ground_truth)

def ems(prediction, ground_truths):
    return max([exact_match_score(prediction, gt) for gt in ground_truths])

def compute_f1(prediction, truth):
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()
    
    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)
    
    common_tokens = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    
    n_common_tokens = sum(common_tokens.values())
    
    if n_common_tokens == 0:
        return 0
    
    prec = n_common_tokens / len(pred_tokens)
    rec = n_common_tokens / len(truth_tokens)
    
    return 2 * (prec * rec) / (prec + rec)

def compute_f1_all(prediction, ground_truths):
    return max([compute_f1(prediction, gt) for gt in ground_truths])
    

In [40]:
model = MultipleInputsT5.from_pretrained(f'{checkpoint_path}/model_final').to(device)

In [41]:
test_set = ODQADataset('data/ViQuAD/test_ViQuAD_converted.json', 1)

In [42]:
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [43]:
model.eval()

exact_match = []
f1 = []
total = 0
progress_bar = tqdm(range(len(test_loader)))
result = []
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        idx, context_ids, context_masks, _, _, = batch

        outputs = model.generate(
            input_ids=context_ids.to(device),
            attention_mask=context_masks.to(device),
            max_length=50,
        )

        for k, o in enumerate(outputs):
            predict = tokenizer.decode(o, skip_special_tokens=True)
            sample = test_set.data[idx[k]]
            ground_truths = [ans['text'] for ans in sample['answers']]
            em_score = ems(predict, ground_truths)
            f1_score = compute_f1_all(predict, ground_truths)
            result.append({
                'question': sample['question'],
                'predict': predict,
                'ground_truths': ground_truths,
                'em_score': em_score,
                'f1_score': f1_score,
                })
            exact_match.append(em_score)
            f1.append(f1_score)
            total += 1
        
        progress_bar.update(1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=70.0), HTML(value='')))

In [44]:
em_score = np.mean(exact_match)
print(em_score)

0.5235294117647059


In [45]:
f1_score = np.mean(f1)
f1_score

0.747181659068327

Ground truth only: 0.42036199095022625 0.6388585316330179

top10: em: 0.36334841628959275, f1: 0.5881111813897126

In [20]:
result

[{'question': 'Cộng hòa Weimar chính thức thay thế đế quốc Đức kể từ sau sự kiện nào?',
  'predict': 'Chiến tranh thế giới thứ nhất và Cách mạng Đức 1918-1919',
  'ground_truths': ['Chiến tranh thế giới thứ nhất và Cách mạng Đức 1918-1919',
   'Chiến tranh thế giới thứ nhất và Cách mạng Đức 1918-1919',
   'Chiến tranh thế giới thứ nhất và Cách mạng Đức 1918-1919',
   'Chiến tranh thế giới thứ nhất và Cách mạng Đức 1918-1919'],
  'em_score': True,
  'f1_score': 1.0},
 {'question': 'Nước Đức hiện nay sự hợp thành của hai nước nào trong thời kỳ Đồng minh chiếm đóng Đức?',
  'predict': 'Cộng hòa Liên bang Đức và Cộng hòa Dân chủ Đức',
  'ground_truths': ['Cộng hòa Liên bang Đức và Cộng hòa Dân chủ Đức',
   'Cộng hòa Liên bang Đức và Cộng hòa Dân chủ Đức',
   'Cộng hòa Liên bang Đức và Cộng hòa Dân chủ Đức',
   'Cộng hòa Liên bang Đức và Cộng hòa Dân chủ Đức'],
  'em_score': True,
  'f1_score': 1.0},
 {'question': 'Hậu quả nào đã xảy ra khi tại Đức chế độ độc tài quốc xã được hình thành?',
