In [4]:
import os
from glob import glob
import re
import pandas as pd

from tqdm import tqdm
tqdm.pandas()

In [5]:
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

from transformers import T5Model, T5Tokenizer

In [9]:
# The tokenizer and the network weights are both loaded from the same
# pretrained checkpoint, so the name is spelled out only once.
PRETRAINED_CHECKPOINT = 'sberbank-ai/ruT5-base'

tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_CHECKPOINT)
# Loading into T5Model leaves the checkpoint's `lm_head.weight` unused
# (the loader prints a warning about it).
t5 = T5Model.from_pretrained(PRETRAINED_CHECKPOINT)

Some weights of the model checkpoint at sberbank-ai/ruT5-base were not used when initializing T5Model: ['lm_head.weight']
- This IS expected if you are initializing T5Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing T5Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Dataset class

In [6]:
class JokeDataset(Dataset):
    """Torch dataset that serves tokenized (setup, punch) text pairs.

    The frame passed in must have ``setup`` and ``punch`` columns. When it
    also has a ``mark`` column, that value is appended to every item as the
    target.
    """

    def __init__(self, df, tokenizer, sentence_length=150):
        """Store a sorted, re-indexed copy of ``df`` plus tokenization settings.

        Args:
            df: source DataFrame with ``setup`` and ``punch`` columns.
            tokenizer: callable tokenizer used by ``tokenize_input``.
            sentence_length: length every text is padded/truncated to.
        """
        super().__init__()
        # Rows are ordered by text; reset_index() gives positions 0..n-1 and
        # keeps the previous index as an extra column.
        ordered = df.sort_values(['setup', 'punch'])
        self.dataset = ordered.reset_index()
        self.tokenizer = tokenizer
        self.sentence_length = sentence_length

    def __len__(self):
        """Number of rows in the stored frame."""
        return len(self.dataset)

    def tokenize_input(self, input_tests):
        """Tokenize one text into fixed-length ids and a 0/1 mask.

        Returns:
            ``(word_ids, masks)``: the first row of the tokenizer's
            ``input_ids`` and an int8 tensor that is 1 where the tokenizer's
            ``special_tokens_mask`` is 0, and 0 elsewhere.
        """
        tokenizer_options = dict(
            add_special_tokens=True,
            return_attention_mask=False,
            padding='max_length',
            truncation=True,
            max_length=self.sentence_length,
            return_special_tokens_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer(input_tests, **tokenizer_options)

        # The tokenizer returns a batch of one; take its only row.
        token_ids = encoded.input_ids[0]
        is_regular_token = encoded.special_tokens_mask[0] == 0

        return token_ids, is_regular_token.to(torch.int8)

    def __getitem__(self, idx):
        """Return the tokenized pair at row ``idx``.

        Returns:
            ``(setup_ids, setup_mask, punch_ids, punch_mask)``, with the
            row's ``mark`` appended as a fifth element when that column
            exists.
        """
        frame = self.dataset
        setup_text = frame.setup[idx]
        punch_text = frame.punch[idx]

        setup_ids, setup_mask = self.tokenize_input(setup_text)
        punch_ids, punch_mask = self.tokenize_input(punch_text)
        sample = (setup_ids, setup_mask, punch_ids, punch_mask)

        if 'mark' not in frame.columns:
            return sample
        return sample + (frame.mark[idx],)

### Model class

In [7]:
class Markuper(nn.Module):
    """Regression head on top of an encoder-decoder transformer.

    The wrapped model's ``last_hidden_state`` is reduced to one value per
    sample by two linear layers: ``fс1`` collapses the hidden dimension and
    ``fс2`` collapses the sequence dimension.

    Args:
        model: the transformer to wrap. It is called with ``input_ids``,
            ``attention_mask``, ``decoder_input_ids`` and
            ``decoder_attention_mask`` and must return an object exposing
            ``last_hidden_state``.
        sentence_length: sequence length of ``last_hidden_state``; ``fс2``
            is sized to it, so inputs must be padded to exactly this length.
    """

    def __init__(self, model, sentence_length=150):
        super().__init__()
        self.model = model

        # Take the hidden width from the wrapped model's config when it has
        # one, so other model sizes work; 768 stays as the fallback.
        config = getattr(model, 'config', None)
        self.embedding_size = getattr(config, 'd_model', 768)
        self.sentence_length = sentence_length

        # NOTE(review): the "с" in these two attribute names appears to be the
        # Cyrillic letter rather than Latin "c". The names are kept unchanged
        # so that state_dict keys stay the same.
        self.fс1 = nn.Linear(self.embedding_size, 1)
        self.fс2 = nn.Linear(self.sentence_length, 1)

        self.activation = nn.ReLU()

    def forward(self,
                encoder_word_ids,
                encoder_mask,
                decoder_word_ids,
                decoder_mask):
        """Predict one value per sample.

        Returns:
            A tensor with one element per batch row (shape ``(batch,)``),
            including when the batch holds a single sample.
        """
        # Call the module itself rather than .forward() so registered hooks run.
        transformer_output = self.model(
            input_ids=encoder_word_ids,
            attention_mask=encoder_mask,
            decoder_input_ids=decoder_word_ids,
            decoder_attention_mask=decoder_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        ).last_hidden_state

        # Each linear layer ends in a size-1 dimension; drop only that one.
        # A bare squeeze() would also drop the batch dimension for a batch of
        # one and return a 0-dim tensor.
        result = self.activation(self.fс1(transformer_output)).squeeze(-1)
        # NOTE(review): this final ReLU clamps predictions to >= 0 and passes
        # no gradient while its input is negative; confirm that is intended
        # for a regression output.
        result = self.activation(self.fс2(result)).squeeze(-1)

        return result

### Train functions

In [8]:
import time

def train(dataloader):
    """Run one training pass over ``dataloader``.

    Reads the notebook-level globals ``model``, ``optimizer`` and
    ``criterion``; the global ``epoch`` is read only for the progress line.

    Args:
        dataloader: iterable of ``(enc_ids, enc_mask, dec_ids, dec_mask,
            mark)`` batches that also supports ``len()``.

    Every 50 batches it prints the mean absolute error of the predictions
    seen since the previous print.
    """
    model.train()
    abs_error_sum, sample_count = 0.0, 0
    log_interval = 50

    for idx, (enc_ids, enc_mask, dec_ids, dec_mask, mark) in enumerate(dataloader):
        mark = mark.to(torch.float32)
        optimizer.zero_grad()
        # Align the prediction with the target shape so that a 0-dim or
        # otherwise squeezed output cannot be silently broadcast by the loss.
        predicted_mark = model(enc_ids, enc_mask, dec_ids, dec_mask).reshape(mark.shape)
        loss = criterion(predicted_mark, mark)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        # The running metric is bookkeeping only, so it is taken from a
        # detached tensor.
        abs_error_sum += (predicted_mark.detach() - mark).abs().sum().item()
        sample_count += mark.size(0)

        if idx % log_interval == 0 and idx > 0:
            # The value is a mean absolute error, not a percentage error.
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| train MAE loss {:8.3f}'.format(epoch, idx, len(dataloader),
                                                   abs_error_sum / sample_count))
            abs_error_sum, sample_count = 0.0, 0

def evaluate(dataloader):
    """Return the mean absolute error of the global ``model`` on ``dataloader``.

    Args:
        dataloader: iterable of ``(enc_ids, enc_mask, dec_ids, dec_mask,
            mark)`` batches.

    Returns:
        Sum of absolute prediction errors divided by the number of samples,
        or ``nan`` when the dataloader yields no samples.
    """
    model.eval()
    abs_error_sum, sample_count = 0.0, 0

    with torch.no_grad():
        for enc_ids, enc_mask, dec_ids, dec_mask, mark in dataloader:
            mark = mark.to(torch.float32)
            # Align the prediction with the target shape so a 0-dim or
            # squeezed output is not silently broadcast in the subtraction.
            predicted_mark = model(enc_ids, enc_mask, dec_ids, dec_mask).reshape(mark.shape)
            abs_error_sum += (predicted_mark - mark).abs().sum().item()
            sample_count += mark.size(0)

    # An empty loader has no defined average; report NaN rather than
    # raising ZeroDivisionError.
    if sample_count == 0:
        return float('nan')
    return abs_error_sum / sample_count

### Train process

In [12]:
# Hyperparameters and runtime configuration.
SENTENCE_LENGTH = 50
BATCH_SIZE = 8
EPOCHS = 10
# NOTE(review): 5 is a very large step size for SGD on a pretrained
# transformer, and the logged validation loss does not move from its initial
# value. Confirm this value is intended.
LR = 5
# device = torch.device('cuda:0') # gpu
device = torch.device('cpu') # cpu


def collate_to_device(samples):
    """Batch dataset items with ``default_collate`` and move each field to ``device``."""
    return tuple(field.to(device) for field in default_collate(samples))


# NOTE(review): the file has a .tsv extension but is read with a comma
# separator; confirm against the actual file.
df = pd.read_csv('comedy-news-tg-dataset/marked/full_dataset.tsv', sep=',')
# All rows sharing this one setup text form the validation split.
df['is_valid'] = df.setup == 'Balenciaga выпустил кроссовки на каблуках почти за 100 тысяч рублей'

train_dataset = JokeDataset(df.loc[~df.is_valid], tokenizer, sentence_length=SENTENCE_LENGTH)
valid_dataset = JokeDataset(df.loc[df.is_valid], tokenizer, sentence_length=SENTENCE_LENGTH)

train_dataloader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=collate_to_device)

# The validation metric is an average over all samples, so batch order does
# not matter and shuffling is switched off.
valid_dataloader = DataLoader(valid_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=0,
                              collate_fn=collate_to_device)

model = Markuper(model=t5,
                 sentence_length=SENTENCE_LENGTH).to(device)

In [13]:
# Loss and optimizer shared by the training and evaluation cells.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Each scheduler.step() multiplies the learning rate by 0.1. step_size is
# documented as an integer number of steps, so it is passed as 1, not 1.0.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
# Validation metric remembered between epochs by the training loop;
# None until the first epoch has been evaluated.
total_accu = None

In [14]:
# Smoke test: pull one batch from the training loader, run it through the
# model and compute the loss. The loss is the cell's last expression, so the
# notebook displays it.
enc_ids, enc_mask, dec_ids, dec_mask, mark = next(iter(train_dataloader))
predict = model(encoder_word_ids=enc_ids,
                encoder_mask=enc_mask,
                decoder_word_ids=dec_ids,
                decoder_mask=dec_mask)
criterion(input=predict, target=mark)

tensor(1.7162, grad_fn=<MseLossBackward0>)

In [41]:
# Baseline before any training. evaluate() returns a mean absolute error,
# so the printed labels say MAE rather than MAPE.
init_accu_val = evaluate(valid_dataloader)
print('Init eval MAE loss: {:8.3f} '.format(init_accu_val))

# `epoch` is deliberately a module-level loop variable: train() reads it for
# its progress line.
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    # accu_val is an error, so lower is better. Decay the learning rate only
    # when this epoch is worse than the best value seen so far; otherwise
    # remember the new best. The original comparison was written for an
    # accuracy-style metric and therefore did the opposite.
    if total_accu is not None and accu_val > total_accu:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid MAE loss {:8.3f} '.format(epoch,
                                          time.time() - epoch_start_time,
                                          accu_val))
    print('-' * 59)

Init eval MAPE loss:    2.983 
| epoch   1 |    50/  432 batches | accuracy    2.686
| epoch   1 |   100/  432 batches | accuracy    2.645
| epoch   1 |   150/  432 batches | accuracy    2.942
| epoch   1 |   200/  432 batches | accuracy    2.420
| epoch   1 |   250/  432 batches | accuracy    2.750
| epoch   1 |   300/  432 batches | accuracy    2.430
| epoch   1 |   350/  432 batches | accuracy    2.690
| epoch   1 |   400/  432 batches | accuracy    2.788
-----------------------------------------------------------
| end of epoch   1 | time: 86.99s | valid MAPE loss    2.983 
-----------------------------------------------------------
| epoch   2 |    50/  432 batches | accuracy    3.115
| epoch   2 |   100/  432 batches | accuracy    2.263
| epoch   2 |   150/  432 batches | accuracy    2.803
| epoch   2 |   200/  432 batches | accuracy    2.415
| epoch   2 |   250/  432 batches | accuracy    2.493
| epoch   2 |   300/  432 batches | accuracy    2.397
| epoch   2 |   350/  432 batch