In [4]:
import os
from glob import glob
import re
import pandas as pd

from tqdm import tqdm
tqdm.pandas()

import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

from transformers import T5Model, T5Tokenizer

from textrl import TextRLEnv, TextRLActor
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  

In [5]:
# Checkpoint id shared by the tokenizer and the seq2seq model loaded below.
RUT5_CHECKPOINT = 'sberbank-ai/ruT5-base'

# Tokenizer first, then the sequence-to-sequence LM, both from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(RUT5_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(RUT5_CHECKPOINT)

In [None]:
# SentenceTransformer is not imported by the notebook's import cell, so it is
# brought into scope here; without this the cell raises NameError on a fresh
# kernel (Restart Kernel -> Run All).
from sentence_transformers import SentenceTransformer

# Load the model stored in a local directory; the path is relative to the
# notebook's working directory.
sentence_sim_model = SentenceTransformer('models/DeepPavlov_rubert-base-cased-sentence/')

See the TextRL README for how to subclass `TextRLEnv` and define a custom reward function:
https://github.com/voidful/TextRL/blob/main/README.md

In [None]:
class MyRLEnv(TextRLEnv):
    """TextRL environment that penalises generations containing an unknown token."""

    # "[UNK]" is the BERT-style marker used in the TextRL README example.
    # "<unk>" is the unknown token of T5 SentencePiece tokenizers such as the
    # ruT5 checkpoint loaded in this notebook; checking only "[UNK]" would make
    # the penalty branch unreachable with that tokenizer.
    UNK_TOKENS = ("[UNK]", "<unk>")

    def get_reward(self, input_text, predicted_list, finish):
        """Return -1 if the prediction contains an unknown-token marker, else 1.

        Args:
            input_text: The input the prediction was generated for (unused).
            predicted_list: The list of predicted tokens.
            finish: Whether generation has finished (unused).

        Returns:
            int: -1 when any unknown-token marker is in `predicted_list`, 1 otherwise.
        """
        unk_markers = list(self.UNK_TOKENS)

        # Also honour the environment's own tokenizer, if it exposes one.
        # NOTE(review): assumes TextRLEnv stores the tokenizer as `self.tokenizer`;
        # the getattr chain degrades gracefully if it does not.
        tokenizer_unk = getattr(getattr(self, "tokenizer", None), "unk_token", None)
        if tokenizer_unk is not None and tokenizer_unk not in unk_markers:
            unk_markers.append(tokenizer_unk)

        # Same membership test as before (`marker in predicted_list`), applied
        # to every known marker instead of only "[UNK]".
        if any(marker in predicted_list for marker in unk_markers):
            return -1
        return 1

In [9]:
# Checkpoint id shared by the tokenizer and the T5 model loaded below.
RUT5_CHECKPOINT = 'sberbank-ai/ruT5-base'

# NOTE: this rebinds `tokenizer`, which an earlier cell created via AutoTokenizer.
tokenizer = T5Tokenizer.from_pretrained(RUT5_CHECKPOINT)

# Loading the checkpoint as T5Model leaves `lm_head.weight` unused, which is
# what the warning printed under this cell reports.
t5 = T5Model.from_pretrained(RUT5_CHECKPOINT)

Some weights of the model checkpoint at sberbank-ai/ruT5-base were not used when initializing T5Model: ['lm_head.weight']
- This IS expected if you are initializing T5Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing T5Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Dataset class

In [6]:
class JokeDataset(Dataset):
    """Torch dataset yielding tokenized (setup, punch) pairs taken from a DataFrame.

    The frame must provide `setup` and `punch` columns. When a `mark` column is
    present as well, its value is appended to every returned item.
    """

    def __init__(self, df, tokenizer, sentence_length=150):
        """Keep a sorted, re-indexed copy of `df` plus the tokenizer settings.

        Args:
            df: DataFrame with `setup` and `punch` columns (optionally `mark`).
            tokenizer: Callable tokenizer used by `tokenize_input`.
            sentence_length: Length each text is padded/truncated to.
        """
        super().__init__()
        ordered = df.sort_values(['setup', 'punch'])
        # reset_index() stores the previous index as an extra column and labels
        # the sorted rows 0..n-1, so they can be addressed by those labels.
        self.dataset = ordered.reset_index()
        self.sentence_length = sentence_length
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of rows in the stored frame."""
        return len(self.dataset.index)

    def tokenize_input(self, input_tests):
        """Tokenize a text into fixed-length ids and a mask of non-special tokens.

        Args:
            input_tests: The text handed to the tokenizer.

        Returns:
            tuple: (ids, mask) where `ids` is the first row of `input_ids` and
            `mask` is an int8 tensor equal to 1 wherever the tokenizer's
            `special_tokens_mask` is 0, and 0 elsewhere.
        """
        tokenizer_options = dict(
            add_special_tokens=True,
            return_attention_mask=False,
            padding='max_length',
            truncation=True,
            max_length=self.sentence_length,
            return_special_tokens_mask=True,
            return_tensors='pt',
        )
        encoded = self.tokenizer(input_tests, **tokenizer_options)

        # Only the first row of the returned batch is used.
        token_ids = encoded.input_ids[0]
        is_regular_token = encoded.special_tokens_mask[0] == 0
        return token_ids, is_regular_token.to(torch.int8)

    def __getitem__(self, idx):
        """Return the tokenized setup/punch pair for row `idx`.

        Returns:
            tuple: (setup_ids, setup_mask, punch_ids, punch_mask), followed by
            the row's `mark` value when the frame has a `mark` column.
        """
        frame = self.dataset
        setup_text = frame.setup[idx]
        punch_text = frame.punch[idx]

        setup_pair = self.tokenize_input(setup_text)
        punch_pair = self.tokenize_input(punch_text)
        item = (*setup_pair, *punch_pair)

        # Unlabelled frames yield just the four tensors.
        if 'mark' not in frame.columns:
            return item
        return item + (frame.mark[idx],)