In [41]:
from typing import Dict, List, Optional, Tuple

import pandas as pd
from transformers import T5Tokenizer, T5ForConditionalGeneration

def generate_inspirations(setup, num_return_sequences=5, temperature=1):
    """Sample inspiration strings for a joke setup.

    The prompt is `inspiration_prefix + setup`; it is encoded with the
    notebook-level `tokenizer` and passed to the notebook-level `model`.

    Args:
        setup: Text of the joke setup.
        num_return_sequences: How many sequences to sample.
        temperature: Sampling temperature forwarded to `model.generate`.

    Returns:
        A list of decoded strings (special tokens skipped), one per
        sampled sequence.
    """
    prompt = inspiration_prefix + setup
    encoded_prompt = tokenizer(prompt, return_tensors="pt").input_ids

    # Top-k sampling, capped at 50 tokens, with repeated bigrams forbidden.
    sampling_options = dict(
        top_k=20,
        do_sample=True,
        max_length=50,
        no_repeat_ngram_size=2,
        temperature=temperature,
        num_return_sequences=num_return_sequences,
    )
    sampled_sequences = model.generate(encoded_prompt, **sampling_options)

    inspirations = []
    for token_ids in sampled_sequences.tolist():
        inspirations.append(tokenizer.decode(token_ids, skip_special_tokens=True))
    return inspirations

def generate_punches(setup, inspiration, num_return_sequences=5, temperature=1):
    """Sample punchlines for a setup conditioned on an inspiration.

    The prompt has the form `punch_prefix + inspiration + '|' + setup` and is
    processed with the notebook-level `tokenizer` and `model`.

    Args:
        setup: Text of the joke setup.
        inspiration: Inspiration text placed before the '|' separator.
        num_return_sequences: How many sequences to sample.
        temperature: Sampling temperature forwarded to `model.generate`.

    Returns:
        A list of decoded strings (special tokens skipped), one per
        sampled sequence.
    """
    prompt = punch_prefix + inspiration + '|' + setup
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Same sampling configuration as used for inspirations.
    generated = model.generate(
        prompt_ids,
        do_sample=True,
        top_k=20,
        max_length=50,
        no_repeat_ngram_size=2,
        temperature=temperature,
        num_return_sequences=num_return_sequences,
    )
    return [
        tokenizer.decode(candidate, skip_special_tokens=True)
        for candidate in generated.tolist()
    ]

def generate_mark(joke):
    """Generate a mark for a joke.

    The prompt is `mark_prefix + joke`. `model.generate` is called without
    extra arguments, and only the first returned sequence is decoded.

    Args:
        joke: Full joke text to be marked.

    Returns:
        The decoded text of the first generated sequence (special tokens
        skipped).
    """
    prompt_ids = tokenizer(mark_prefix + joke, return_tensors="pt").input_ids
    first_sequence = model.generate(prompt_ids).tolist()[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def inference(setup: str,
              inspirations: Optional[List[str]] = None,
              num_return_sequences: int = 5,
              temperature: float = 1) -> List[Tuple[str, str, str, str]]:
    """Run the full pipeline: inspirations -> punches -> marks.

    Args:
        setup: Text of the joke setup.
        inspirations: Inspirations to use. When None or empty, they are
            generated with `generate_inspirations`.
        num_return_sequences: Number of sequences sampled per generation
            call. It is used both for inspirations (when generated) and for
            the punches of every inspiration.
        temperature: Sampling temperature forwarded to the generators
            (fractional values such as 0.6 are valid).

    Returns:
        A list of `(setup, inspiration, punch, mark)` tuples, one per
        generated punch.
    """
    if not inspirations:
        # No inspirations supplied by the caller: sample them from the model.
        inspirations = generate_inspirations(setup,
                                             num_return_sequences=num_return_sequences,
                                             temperature=temperature)

    results: List[Tuple[str, str, str, str]] = []
    for inspiration in inspirations:
        punches = generate_punches(setup,
                                   inspiration,
                                   num_return_sequences=num_return_sequences,
                                   temperature=temperature)
        for punch in punches:
            # NOTE(review): setup and punch are joined without any separator;
            # confirm this matches the format the mark task was trained on.
            joke = setup + punch
            mark = generate_mark(joke)
            results.append((setup, inspiration, punch, mark))
    return results

In [3]:
# Models
#
# Alternatives that can be swapped in (set both names accordingly):
#   Default sberbank model used for finetune:
#     model_name, model_type = "sberbank-ai/ruT5-large", "pytorch"
#   Model finetuned on span masks task:
#     model_name, model_type = "naltukhov/joke-generator-t5-rus", "flax"

# Active choice: model finetuned on span masks task and generation task.
model_name = "naltukhov/joke-generator-t5-rus-finetune"
model_type = "flax"

In [71]:
# Versions
#
# Keep exactly ONE `revision` assignment active. Previously several
# assignments were left uncommented and only the last one took effect, which
# hid the fact that the "Last" and "# 36" choices were being overridden.

# Last
# revision = None

# Model with inspirations on small dataset (first success workflow)
# For naltukhov/joke-generator-t5-rus-finetune model
# revision = "159b2223b230be99faa5f9d661996c757f58d66a"  # 26
# revision = "9aaf18df0cb1dc5de2480222e52f504837306048"  # 27

# Model with inspirations on big dataset (first run, overfitting?)
# revision = "c2a18676b1e782c4deef22e1d85a261f134f3a85"  # 36
revision = "a001d2b3c44d193f489f2e3704ca13776a57a43b"  # 34
# revision = "5514501d62fab937ba8cd77ae17ed062fbb3bf74"  # 28



In [72]:
# Both the tokenizer and the model are loaded with identical options, so the
# keyword arguments are collected once and reused.
# NOTE(review): force_download=True is passed on every run; the captured
# output shows a ~2.75G download, so confirm this is intentional.
load_options = dict(
    from_flax=(model_type == "flax"),
    force_download=True,
    use_auth_token=True,
    revision=revision,
)
tokenizer = T5Tokenizer.from_pretrained(model_name, **load_options)
model = T5ForConditionalGeneration.from_pretrained(model_name, **load_options)
print(f'Loaded model {model_name}')

# Task prefixes prepended to the prompt by the generate_* functions.
inspiration_prefix = 'Сгенерировать вдохновение: '
mark_prefix = 'Сгенерировать оценку: '
punch_prefix = 'Сгенерировать шутку: '

Downloading: 100%|██████████| 980k/980k [00:00<00:00, 30.3MB/s]
Downloading: 100%|██████████| 1.74k/1.74k [00:00<00:00, 813kB/s]
Downloading: 100%|██████████| 1.90k/1.90k [00:00<00:00, 831kB/s]
Downloading: 100%|██████████| 1.39k/1.39k [00:00<00:00, 639kB/s]
Downloading: 100%|██████████| 2.75G/2.75G [01:50<00:00, 26.8MB/s] 
All Flax model weights were used when initializing T5ForConditionalGeneration.

Some weights of T5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'lm_head.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded model naltukhov/joke-generator-t5-rus-finetune


In [73]:
# Candidate setups; keep exactly one active. Previously the first line was
# left uncommented and was silently overwritten by the last assignment.
# setup = 'Российские войска ведут штурм и наступают на двух направлениях'
# setup = 'Медведев приехал на церемонию прощания с Горбачевым'
# setup = 'Я плохой рассказчик'
setup = 'Медведь шёл по лесу'

# 3 inspirations are sampled, then 3 punches per inspiration.
predicts = inference(setup,
                     num_return_sequences=3,
                     temperature=0.6)

In [74]:
# Tabulate the (setup, inspiration, punch, mark) tuples returned by
# `inference` and display them as a list of records.
result_columns = ['setup', 'inspiration', 'punch', 'mark']
predicts_df = pd.DataFrame(predicts, columns=result_columns)
predicts_df.to_dict(orient='records')

[{'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведь не шёл',
  'mark': '0'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'В медведей',
  'mark': '1'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведь не ходил',
  'mark': '0'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведь и медведев',
  'mark': '0'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведь не нашёл себе на нём наполеон',
  'mark': '0'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведь не нашёл',
  'mark': '0'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведев не пришёл',
  'mark': '0'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведь не видел',
  'mark': '0'},
 {'setup': 'Медведь шёл по лесу',
  'inspiration': 'медведь',
  'punch': 'Медведев уже не идёт',
  'mark': '0'}

In [None]:
# Punch generation
df = pd.read_csv('data/agg-generation-dataset/agg-generation-dataset-test.csv')
# Keep only the rows whose input is a punch-generation prompt.
df = df.loc[df.input.str.startswith(punch_prefix)]
for i in range(10):
    # Draw one random input. A dedicated name is used so the builtin
    # `input` is not shadowed.
    source_text = df.input.sample().iloc[0]
    print(f'\tExample {i + 1}:')

    # The dataset input already starts with `punch_prefix`. Passing it to
    # `inference` unchanged would make the generators add a task prefix a
    # second time, so the prefix is stripped first.
    payload = source_text[len(punch_prefix):]
    # NOTE(review): assumes the remainder is 'inspiration|setup', i.e. the
    # same layout that generate_punches builds - confirm against the dataset.
    inspiration, separator, setup_text = payload.partition('|')
    if separator:
        predicts = inference(setup_text, inspirations=[inspiration])
    else:
        # No separator found: treat the whole remainder as the setup and let
        # `inference` generate the inspirations itself.
        predicts = inference(payload)
    print(*predicts)
