In [40]:
from typing import List, Dict, Tuple
from transformers import T5Tokenizer, T5ForConditionalGeneration
import pandas as pd

In [39]:
class T5GenerationModel:
    inspiration_prefix = 'Сгенерировать вдохновение: '
    mark_prefix = 'Сгенерировать оценку: '
    punch_prefix = 'Сгенерировать шутку: '

    def __init__(self):
        self.model = None
        self.tokenizer = None

    def load_model_from_file(self, model_dir):
        self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = T5Tokenizer.from_pretrained(model_dir)

    def load_model_from_hub(self,
                      model_name,
                      model_type,
                      force_download,
                      use_auth_token,
                      revision):

        self.tokenizer = T5Tokenizer.from_pretrained(model_name,
                                                     from_flax=model_type == "flax",
                                                     force_download=force_download,
                                                     use_auth_token=use_auth_token,
                                                     revision=revision)

        self.model = T5ForConditionalGeneration.from_pretrained(model_name,
                                                                from_flax=model_type == "flax",
                                                                force_download=force_download,
                                                                use_auth_token=use_auth_token,
                                                                revision=revision)
        
    def save_weights(self, path):
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        

    def generate_inspirations(self, setup: str,
                              num_return_sequences: int = 5, temperature: float = 1) -> List[str]:
        # Generate inspirations
        setup_ids = self.tokenizer(self.inspiration_prefix + setup, return_tensors="pt").input_ids
        predict_inspiration_ids = self.model.generate(setup_ids,
                                                      top_k=20,
                                                      do_sample=True,
                                                      max_length=50,
                                                      no_repeat_ngram_size=2,
                                                      temperature=temperature,
                                                      num_return_sequences=num_return_sequences).tolist()
        predict_inspirations = [self.tokenizer.decode(p, skip_special_tokens=True) for p in predict_inspiration_ids]
        return predict_inspirations

    def generate_punches(self, setup: str, inspiration: str,
                         num_return_sequences: int = 5, temperature: float = 1) -> List[str]:

        input_ids = self.tokenizer(self.punch_prefix + inspiration + '|' + setup, return_tensors="pt").input_ids
        predict_punches_ids = self.model.generate(input_ids,
                                                  do_sample=True,
                                                  top_k=20,
                                                  max_length=50,
                                                  no_repeat_ngram_size=2,
                                                  temperature=temperature,
                                                  num_return_sequences=num_return_sequences).tolist()
        predict_punches = [self.tokenizer.decode(p, skip_special_tokens=True) for p in predict_punches_ids]
        return predict_punches

    def generate_mark(self, joke: str) -> str:
        input_ids = self.tokenizer(self.mark_prefix + joke, return_tensors="pt").input_ids
        predict_mark_ids = self.model.generate(input_ids).tolist()
        predict_mark = self.tokenizer.decode(predict_mark_ids[0], skip_special_tokens=True)
        return predict_mark

    def inference(self, setup: str, inspirations: List = None,
                  num_return_sequences: int = 5, temperature: float = 1) -> List[Tuple[str, str, str, str]]:
        result_list = list()
        if not inspirations:
            # Generate inspirations
            inspirations = self.generate_inspirations(setup,
                                                      num_return_sequences=num_return_sequences,
                                                      temperature=temperature, )
        for inspiration in inspirations:
            # Generate punches
            punches = self.generate_punches(setup,
                                            inspiration,
                                            num_return_sequences=num_return_sequences,
                                            temperature=temperature)
            for punch in punches:
                joke = setup + punch
                mark = self.generate_mark(joke)
                result_list.append((setup, inspiration, punch, mark))

        sorted_result_list = sorted(result_list, key=lambda tup: tup[3], reverse=True)[:num_return_sequences]
        return sorted_result_list


In [44]:
# Models

# Default sberbank model used for finetune
# model_name, model_type = "sberbank-ai/ruT5-large", "pytorch"

# Model finetuned on span masks task
# model_name, model_type = "naltukhov/joke-generator-t5-rus", "flax"

# Model finetuned on span masks task and generation task
model_name, model_type = "naltukhov/joke-generator-t5-rus-finetune-gen", "flax"

In [42]:
# Gererator model versions

# Last
revision = None


# Model with inspirations on small dataset (first success workflow)
# For naltukhov/joke-generator-t5-rus-finetune model
# revision = "159b2223b230be99faa5f9d661996c757f58d66a"  # 26
# revision = "9aaf18df0cb1dc5de2480222e52f504837306048"  # 27

# Model with inspirations on big dataset (the first run, overfitting?)
# revision = "c2a18676b1e782c4deef22e1d85a261f134f3a85" # 36
# revision = "a001d2b3c44d193f489f2e3704ca13776a57a43b" # 34
# revision = "5514501d62fab937ba8cd77ae17ed062fbb3bf74" # 28

# Model with inspirations on big dataset (the second run)
revision = "af089323b89ca9c1968f7c0f34c1d77be2d3d6d4" # 40 -> production

In [46]:
model = T5GenerationModel()
model.load_model_from_hub(
    model_name=model_name,
    model_type=model_type,
    revision=revision,
    force_download=True,
    use_auth_token=True)

Downloading:   0%|          | 0.00/980k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.39k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.75G [00:00<?, ?B/s]

All Flax model weights were used when initializing T5ForConditionalGeneration.

Some weights of T5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['lm_head.weight', 'decoder.embed_tokens.weight', 'encoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [47]:
setup = 'Российские войска ведут штурм и наступают на двух направлениях'
# setup = 'Медведев приехал на церемонию прощания с Горбачевым'
# setup = 'Я плохой рассказчик'
# setup = 'Медведь шёл по лесу'

predicts = model.inference(setup, 
                           num_return_sequences=5, 
                           temperature=0.9)

In [48]:
predicts_df = pd.DataFrame(predicts, columns=['setup', 'inspiration', 'punch', 'mark'])
predicts_df.to_dict(orient='records')

[{'setup': 'Российские войска ведут штурм и наступают на двух направлениях',
  'inspiration': 'сирия',
  'punch': 'А у нас есть Сирии. А сирия - в Саратове',
  'mark': '1'},
 {'setup': 'Российские войска ведут штурм и наступают на двух направлениях',
  'inspiration': 'сирия',
  'punch': 'Сирю на две сири в завязке на двух, на четыре...',
  'mark': '1'},
 {'setup': 'Российские войска ведут штурм и наступают на двух направлениях',
  'inspiration': 'ща',
  'punch': 'А я как ща, в рюрик в отруби!',
  'mark': '1'},
 {'setup': 'Российские войска ведут штурм и наступают на двух направлениях',
  'inspiration': 'ща',
  'punch': 'Е не надо, я в щах!',
  'mark': '1'},
 {'setup': 'Российские войска ведут штурм и наступают на двух направлениях',
  'inspiration': 'ща',
  'punch': 'Ну, и ща им в завязке, ели у нас уже',
  'mark': '1'}]

In [None]:
# Punch generation
df = pd.read_csv('data/agg-generation-dataset/agg-generation-dataset-test.csv')
df = df.loc[df.input.str.startswith(punch_prefix)]
for i in range(10):
    input, output = df.sample().values[0]
    print(f'\tExample {i + 1}:')

    predicts = inference(input)
    print(*predicts)


## Save prod version

In [51]:
import os
import shutil
from huggingface_hub import Repository

dir_name = 't5_jokes_generator_v3'
model_name = "naltukhov/joke-generator-rus-t5"


if os.path.exists(dir_name):
    shutil.rmtree(dir_name)

# Create local repo
repo = Repository(dir_name, clone_from=model_name)
    
# Save as pytorch model
model.save_weights(dir_name)
print('Saved model weights locally')

# Push
print('Start pushing model weiths to repo')
repo.push_to_hub(commit_message=f"Add new model version", blocking=False)


Cloning https://huggingface.co/naltukhov/joke-generator-rus-t5 into local empty directory.


('https://huggingface.co/naltukhov/joke-generator-rus-t5/commit/829409111cfd78e24077fe4034edb4b526c1e661',
 [push command, status code: running, in progress. PID: 42940])

In [62]:
# Test prod model
model = T5GenerationModel()
model.load_model_from_hub(model_name="naltukhov/joke-generator-rus-t5",
                          model_type="pytorch",
                          use_auth_token=False,
                          force_download=True,
                          revision=None)

setup = 'Российские войска ведут штурм и наступают на двух направлениях'
predicts = model.inference(setup, 
                           num_return_sequences=5, 
                           temperature=0.9)

predicts

Downloading:   0%|          | 0.00/980k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.75G [00:00<?, ?B/s]

[('Российские войска ведут штурм и наступают на двух направлениях',
  'вражеский атака',
  'Вражеская атака',
  '1'),
 ('Российские войска ведут штурм и наступают на двух направлениях',
  'вражеский атака',
  'Вражеские атаки и ступаются в отпуске, в вражейский',
  '1'),
 ('Российские войска ведут штурм и наступают на двух направлениях',
  'вражеский атака',
  'Вражеский атаку',
  '1'),
 ('Российские войска ведут штурм и наступают на двух направлениях',
  'вражеский атака',
  'Вражеские атаки и так посылают',
  '1'),
 ('Российские войска ведут штурм и наступают на двух направлениях',
  'сбить',
  'Чтобы не сбили',
  '1')]