# Подготовка данных

In [None]:
!pip install transformers rouge-score nltk sklearn seaborn pandas torch torchvision pytelegrambotapi 

In [None]:
import json
import math
import re
from functools import reduce
import numpy as np
import telebot
import random

from telebot import apihelper

from nltk.translate.bleu_score import sentence_bleu as bleu
from nltk.translate.bleu_score import SmoothingFunction
import nltk
from rouge_score import rouge_scorer
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import pickle
from random import shuffle

from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelWithLMHead, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the pretrained tokenizer and LM-head model for the ruGPT-3 small checkpoint.
# NOTE(review): AutoModelWithLMHead is reported as deprecated in newer transformers
# releases in favour of AutoModelForCausalLM — confirm against the installed version.
tokenizer = AutoTokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")

model = AutoModelWithLMHead.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=608.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1713123.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1270925.0, style=ProgressStyle(descript…




Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=551290714.0, style=ProgressStyle(descri…




In [None]:
# Move the model to the selected device (CUDA when available, otherwise CPU).
model = model.to(device)

In [None]:
# Make the embedding matrix size equal to the tokenizer's current vocabulary size.
model.resize_token_embeddings(len(tokenizer.get_vocab()))

Embedding(50258, 768)

In [None]:
# Marker tokens that delimit and label the parts of a dialogue sample.
BOS = '<bos>'
EOS = '<eos>'
PAD = '<pad>'
SP1 = '<speaker1>'
SP2 = '<speaker2>'
HISTORY = '<history>'
INPUT = '<input>'
REPLY = '<reply>'

# Mapping handed to tokenizer.add_special_tokens().
ATTR_TO_SPECIAL_TOKEN = dict(
    bos_token=BOS,
    eos_token=EOS,
    pad_token=PAD,
    additional_special_tokens=[SP1, SP2, HISTORY, INPUT, REPLY],
)

In [None]:
# Register the dialogue marker tokens with the tokenizer and, if any of them were
# new, grow the model's embedding matrix by the number of added tokens.
orig_num_tokens = model.get_input_embeddings().num_embeddings
num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
if num_added_tokens > 0:
    model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)

In [None]:
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load('flibusta_model', map_location=device))

<All keys matched successfully>

In [None]:
# Load the cached, already tokenized samples.
# NOTE: pickle can execute arbitrary code while loading — only open trusted files.
with open('tokenized_flibusta_small', 'rb') as f:
    tokenized_data = pickle.load(f)

In [None]:
# Characters removed by normalize(): , . ! ? ( ) ; : _ / [ ]
# A raw string keeps the pattern free of invalid string escape sequences.
_PUNCTUATION_RE = re.compile(r'[,.!?();:_/\[\]]')


def normalize(lines):
    """Lower-case every utterance and strip punctuation from it.

    Args:
        lines: iterable of dialogues, each an iterable of utterance strings.

    Returns:
        A new list of lists with the same nesting as ``lines``; the input
        is not modified.
    """
    return [
        [_PUNCTUATION_RE.sub('', item.lower()) for item in line]
        for line in lines
    ]

def tokenize(lines):
    """Encode every utterance of every dialogue into token ids.

    Uses the module-level ``tokenizer``. Returns a list of dialogues, each a
    list holding the ``input_ids`` of its utterances.
    """
    return [
        [tokenizer.encode_plus(utterance)['input_ids'] for utterance in dialogue]
        for dialogue in lines
    ]

def delete_bad_sized_strings(lines, min_len=3, max_len=15):
    """Split dialogues at utterances whose length is out of range.

    An utterance with fewer than ``min_len`` or more than ``max_len`` tokens
    is dropped and ends the current dialogue fragment. Fragments with fewer
    than two utterances are discarded, since they cannot form an
    input/reply pair.

    Args:
        lines: iterable of dialogues, each a sequence of token-id lists.
        min_len: smallest accepted utterance length (inclusive).
        max_len: largest accepted utterance length (inclusive).

    Returns:
        A list of dialogue fragments (lists of token-id lists).
    """
    data = []
    for line in lines:
        fragment = []
        for item in line:
            if min_len <= len(item) <= max_len:
                fragment.append(item)
                continue
            # Out-of-range utterance: close the fragment collected so far.
            if len(fragment) > 1:
                data.append(fragment)
            fragment = []
        if len(fragment) > 1:
            data.append(fragment)
    return data

def build_data(lines, history_len=3):
    """Turn dialogues into (history, input, reply) training samples.

    Every pair of consecutive utterances in a dialogue yields one sample;
    ``history`` holds up to ``history_len`` utterances that precede the input.

    Args:
        lines: iterable of dialogues, each a list of utterances.
        history_len: maximum number of previous utterances kept as history;
            values <= 0 produce samples with an empty history.

    Returns:
        A list of dicts with the keys ``'history'``, ``'input'`` and ``'reply'``.
    """
    data = []
    for line in lines:
        history = []
        # zip() yields nothing for dialogues shorter than two utterances.
        for utterance, reply in zip(line, line[1:]):
            data.append({
                'history': list(history),
                'input': utterance,
                'reply': reply,
            })
            history.append(utterance)
            if len(history) > history_len:
                del history[0]
    return data

In [None]:
def process_data(data, history_len=3):
    """Run the full preprocessing pipeline on raw dialogues.

    Steps: normalize text, tokenize, split at badly sized utterances, then
    build (history, input, reply) samples.

    Args:
        data: iterable of dialogues, each an iterable of utterance strings.
        history_len: maximum history length forwarded to ``build_data``.

    Returns:
        The list of sample dicts produced by ``build_data``.
    """
    data = normalize(data)
    data = tokenize(data)
    data = delete_bad_sized_strings(data)
    # Forward history_len; previously the argument was silently ignored.
    return build_data(data, history_len=history_len)

In [None]:
def preprocess_data_for_model(src, dest):
    """Build training samples from pickled dialogues and cache them on disk.

    Args:
        src: path of a pickle file holding a list of dialogues (each a list
            of utterance strings).
        dest: path the resulting samples are pickled to.

    Returns:
        The list of sample dicts that was written to ``dest``.
    """
    # NOTE: pickle can execute arbitrary code while loading — only use trusted files.
    with open(src, 'rb') as f:
        lines = pickle.load(f)

    # Keep only the first dialogue for every distinct opening utterance.
    no_doubled_lines = []
    seen = set()
    for line in lines:
        if not line:
            # An empty dialogue has no opening utterance and yields no samples.
            continue
        start = line[0]
        if start in seen:
            continue
        seen.add(start)
        # Append the dialogue itself (not the whole ``lines`` list).
        no_doubled_lines.append(line)

    samples = process_data(no_doubled_lines)

    # Once a sample echoes its input as the reply, drop it and every following
    # sample until one with an empty history appears.
    cleared_data = []
    take = True
    for item in samples:
        if item['input'] == item['reply']:
            take = False
            continue
        if len(item['history']) == 0:
            take = True
        if take:
            cleared_data.append(item)

    with open(dest, 'wb') as f:
        pickle.dump(cleared_data, f)
    return cleared_data

In [None]:
# Rebuild the samples from the raw pickle 'test' and cache them in 'tokenized_flibusta_small'.
tokenized_data = preprocess_data_for_model('test', 'tokenized_flibusta_small')

# Dataset и Dataloader

In [None]:
# Cut the sample list into consecutive chunks of at most batch_size items.
batch_size = 64
batches = [
    tokenized_data[start:start + batch_size]
    for start in range(0, len(tokenized_data), batch_size)
]

In [None]:
class SequenceBucketingData(Dataset):
    """Dataset whose items are whole, pre-grouped batches of dialogue samples.

    ``data`` is a list of batches; each batch is a list of dicts with the keys
    ``'history'``, ``'input'`` and ``'reply'`` holding token-id lists. Indexing
    returns five tensors for one batch, padded to that batch's longest row.
    """

    def __init__(self, data, tokenizer):
        super().__init__()

        self.data = data
        self.tokenizer = tokenizer
        self.bos = BOS
        self.eos = EOS
        # Id used for padding every returned tensor.
        self.pad = 0
        self.sp1 = SP1
        self.sp2 = SP2
        self.bos_emb = tokenizer.convert_tokens_to_ids(self.bos)
        self.eos_emb = tokenizer.convert_tokens_to_ids(self.eos)
        self.sp1_emb = tokenizer.convert_tokens_to_ids(self.sp1)
        self.sp2_emb = tokenizer.convert_tokens_to_ids(self.sp2)
        self.history = tokenizer.convert_tokens_to_ids(HISTORY)
        self.input = tokenizer.convert_tokens_to_ids(INPUT)
        self.reply = tokenizer.convert_tokens_to_ids(REPLY)

        # Kept for backward compatibility; not used by this class.
        self.shuffle = shuffle

    def __len__(self):
        """Return the number of batches."""
        return len(self.data)

    def padding(self, text, length):
        """Return a copy of ``text`` right-padded with ``self.pad`` to ``length``.

        The argument is left untouched; rows already at least ``length`` long
        are returned unchanged (as a copy).
        """
        return list(text) + [self.pad] * (length - len(text))

    def _encode_sample(self, sample):
        """Flatten one sample into (token ids, position ids, token type ids).

        Each utterance is prefixed with its tag id (history / input / reply);
        positions restart at 1 for every utterance.
        """
        token_ids = []
        position_ids = []
        token_type_ids = []

        def add_segment(tag_id, tokens, speaker_id):
            segment_len = len(tokens) + 1  # +1 for the tag token
            token_ids.extend([tag_id] + tokens)
            position_ids.extend(range(1, segment_len + 1))
            token_type_ids.extend([speaker_id] * segment_len)

        # NOTE(review): history speakers alternate starting from speaker1 while
        # the input is always speaker1, so with an odd number of history turns
        # the last history turn and the input share a speaker id — confirm
        # this is intended.
        speaker = self.sp1_emb
        for utterance in sample['history']:
            add_segment(self.history, utterance, speaker)
            speaker = self.sp2_emb if speaker == self.sp1_emb else self.sp1_emb
        add_segment(self.input, sample['input'], self.sp1_emb)
        add_segment(self.reply, sample['reply'], self.sp2_emb)
        return token_ids, position_ids, token_type_ids

    def __getitem__(self, index):
        """Return ``(input ids, targets, position ids, attention mask, token type ids)``.

        Targets are the inputs shifted left by one with the EOS id appended.
        All tensors are moved to the module-level ``device``.
        """
        batch = self.data[index]

        xx, yy, pi, am, tti = [], [], [], [], []
        for sample in batch:
            x, position_ids, token_type_ids = self._encode_sample(sample)
            xx.append(x)
            yy.append(x[1:] + [self.eos_emb])
            pi.append(position_ids)
            # Plain ints (not bools) so the mask dtype does not depend on
            # whether any row needed padding.
            am.append([int(token != self.pad) for token in x])
            tti.append(token_type_ids)

        max_length = max(len(row) for row in xx)

        def to_tensor(rows):
            padded = [self.padding(row, max_length) for row in rows]
            return torch.tensor(padded).to(device)

        return tuple(to_tensor(rows) for rows in (xx, yy, pi, am, tti))

In [None]:
# Number of batches held out for validation: 10% of all batches, rounded down.
validation_start_index = int(len(batches) * 0.1)

In [None]:
# Split on an explicit index: slicing with a negative count breaks when the
# count is 0 (batches[:-0] is empty and batches[-0:] is the whole list).
split_index = len(batches) - validation_start_index
train_dataset_seq = SequenceBucketingData(batches[:split_index], tokenizer)
valid_dataset_seq = SequenceBucketingData(batches[split_index:], tokenizer)

In [None]:
# Each dataset item is already a full batch, so the loaders use batch_size=1;
# only the order of the training batches is shuffled.
train_loader = DataLoader(train_dataset_seq, batch_size=1, shuffle=True)
valid_loader = DataLoader(valid_dataset_seq, batch_size=1)

# Обучение модели

In [None]:
def ids_to_string(ids, tokenizer):
    """Decode token ids into a string, leaving out special tokens."""
    tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
    text = tokenizer.convert_tokens_to_string(tokens)
    return text

In [None]:
def get_nlp_metrics(preds, yy, tokenizer):
    """Compute per-sample ROUGE-L F1 and smoothed BLEU for decoded predictions.

    Args:
        preds: sequence of predicted token-id rows.
        yy: sequence of target token-id rows, aligned with ``preds``.
        tokenizer: tokenizer used to decode the ids into text.

    Returns:
        Tuple ``(rouge_f1, bleus)`` of two lists with one score per sample.
    """
    smoothie = SmoothingFunction().method4
    # NOTE(review): rouge_score's default tokenizer appears to keep only
    # [a-z0-9] characters — verify that Cyrillic text is scored as intended.
    scorer = rouge_scorer.RougeScorer(['rougeL'])
    rouge_f1 = []
    bleus = []
    for pred_ids, target_ids in zip(preds, yy):
        pred = ids_to_string(pred_ids, tokenizer)
        target = ids_to_string(target_ids, tokenizer)
        # RougeScorer.score takes (target, prediction).
        rouge_f1.append(scorer.score(target, pred)['rougeL'].fmeasure)
        # sentence_bleu takes a list of tokenized references and a tokenized
        # hypothesis; raw strings would be compared character by character.
        bleus.append(bleu([target.split()], pred.split(), smoothing_function=smoothie))
    return rouge_f1, bleus

In [None]:
def train(lm_model, opt, cri, x, y, pi, am, tti, batch_size, curr_step, accumulation_step, clip_norm, tok):
    """Run one training step and score the greedy predictions.

    Args:
        lm_model: language model being trained.
        opt: optimizer stepping ``lm_model``'s parameters.
        cri: loss taking flattened logits and flattened targets.
        x, y: input ids and target ids.
        pi: position ids (accepted for interface symmetry; not passed to the model).
        am, tti: attention mask and token type ids.
        batch_size: when below 16, the optimizer only steps every
            ``accumulation_step`` calls (gradient accumulation).
        curr_step: 1-based step counter used for accumulation.
        accumulation_step: number of steps between optimizer updates.
        clip_norm: max gradient norm.
        tok: tokenizer used to decode ids for the text metrics.

    Returns:
        Tuple ``(loss, mean ROUGE-L F1, mean BLEU)``.
    """
    lm_model.train()
    output = lm_model(input_ids=x, attention_mask=am, token_type_ids=tti)
    logits = output.logits
    loss = cri(logits.view(-1, logits.size(-1)), y.view(-1))
    loss.backward()

    # argmax over the vocabulary axis; softmax is monotonic and therefore not
    # needed (and the old dim=2 was the wrong axis for 4-D logits).
    predicted_ids = torch.argmax(logits.detach(), dim=-1)
    rouge_f1, bleus = get_nlp_metrics(predicted_ids[0], y[0], tok)
    rouge_f1_score = np.mean(rouge_f1)
    bleu_score = np.mean(bleus)

    # Clip the gradients of the model passed in, not of a global one.
    torch.nn.utils.clip_grad_norm_(lm_model.parameters(), clip_norm)
    if batch_size >= 16 or curr_step % accumulation_step == 0:
        opt.step()
        opt.zero_grad()

    return loss, rouge_f1_score, bleu_score

In [None]:
def validate(lm_model, cri, x, y, pi, am, tti, tok):
    """Evaluate one batch without gradient tracking.

    Args:
        lm_model: language model to evaluate (switched to eval mode).
        cri: loss taking flattened logits and flattened targets.
        x, y: input ids and target ids.
        pi: position ids (accepted for interface symmetry; not passed to the model).
        am, tti: attention mask and token type ids.
        tok: tokenizer used to decode ids for the text metrics.

    Returns:
        Tuple ``(loss, mean ROUGE-L F1, mean BLEU)``.
    """
    lm_model.eval()
    with torch.no_grad():
        output = lm_model(input_ids=x, attention_mask=am, token_type_ids=tti)
        logits = output.logits
        loss = cri(logits.view(-1, logits.size(-1)), y.view(-1))
        # argmax over the vocabulary axis; softmax is monotonic and therefore
        # not needed (and the old dim=2 was the wrong axis for 4-D logits).
        predicted_ids = torch.argmax(logits, dim=-1)

    rouge_f1, bleus = get_nlp_metrics(predicted_ids[0], y[0], tok)
    rouge_f1_score = np.mean(rouge_f1)
    bleu_score = np.mean(bleus)
    return loss, rouge_f1_score, bleu_score

In [None]:
def _mean_validation_metrics(lm_model, v_loader, cri, tok):
    """Run one pass over ``v_loader`` and return mean (perplexity, rouge, bleu)."""
    perplexities, rouge_scores, bleu_scores = [], [], []
    for vx, vy, vpi, vam, vtti in v_loader:
        loss, rouge_f1_score, bleu_score = validate(lm_model, cri, vx, vy, vpi, vam, vtti, tok)
        perplexities.append(np.exp(loss.item()))
        rouge_scores.append(rouge_f1_score)
        bleu_scores.append(bleu_score)
    return np.mean(perplexities), np.mean(rouge_scores), np.mean(bleu_scores)


def _save_checkpoint(lm_model, opt):
    """Write model and optimizer state dicts under ``dicts/``, creating it if needed."""
    import os  # local import keeps this block self-contained

    os.makedirs('dicts', exist_ok=True)
    torch.save(lm_model.state_dict(), 'dicts/model')
    torch.save(opt.state_dict(), 'dicts/optimizer')


def iterate(lm_model, t_loader, v_loader, epochs, opt, cri, accumulation_step, clip_norm, tok,
            last_n_perpl=500, val_every=2000):
    """Train with periodic validation, checkpointing and early stopping.

    Validation runs every ``val_every`` training steps and at the end of each
    epoch. A checkpoint is saved whenever validation perplexity improves;
    training stops after more than 3 validations without improvement.

    Args:
        lm_model: language model to train.
        t_loader, v_loader: training and validation loaders yielding
            ``(x, y, pi, am, tti)``.
        epochs: maximum number of epochs.
        opt: optimizer.
        cri: loss function.
        accumulation_step: forwarded to ``train``.
        clip_norm: forwarded to ``train``.
        tok: tokenizer used for the text metrics.
        last_n_perpl: window of recent training perplexities averaged for the
            progress bar.
        val_every: number of training steps between mid-epoch validations;
            0 disables mid-epoch validation.

    Returns:
        Tuple ``(lm_model, perpl, rouges, bleus)`` with the per-step training
        perplexities, ROUGE and BLEU scores.
    """
    perpl = []
    rouges = []
    bleus = []

    best_val_perplexity = 100000000
    bad_iters = 0

    def validate_and_checkpoint():
        # Separate loop variables inside the helper keep the current training
        # batch intact (the old inline loop overwrote x, y, ... with
        # validation data).
        nonlocal best_val_perplexity, bad_iters
        curr_perpl, _, _ = _mean_validation_metrics(lm_model, v_loader, cri, tok)
        print("\nCurr val perpl: ", curr_perpl)
        if curr_perpl < best_val_perplexity:
            bad_iters = 0
            _save_checkpoint(lm_model, opt)
            best_val_perplexity = curr_perpl
        else:
            bad_iters += 1

    for n_epoch in range(epochs):
        progress_bar = tqdm(total=len(t_loader), desc=f'Epoch {n_epoch + 1} of {epochs}')
        try:
            for curr_step, (x, y, pi, am, tti) in enumerate(t_loader, start=1):
                if val_every and curr_step % val_every == 0:
                    validate_and_checkpoint()
                    if bad_iters > 3:
                        return lm_model, perpl, rouges, bleus

                batch_size = len(x)
                loss, rouge_f1_score, bleu_score = train(
                    lm_model, opt, cri, x, y, pi, am, tti,
                    batch_size, curr_step, accumulation_step, clip_norm, tok,
                )
                rouges.append(rouge_f1_score)
                bleus.append(bleu_score)
                perpl.append(np.exp(loss.item()))
                progress_bar.set_postfix(perplexity=np.mean(perpl[-last_n_perpl:]))
                progress_bar.update()
        finally:
            # Also closes the bar on early return or on an exception.
            progress_bar.close()

        validate_and_checkpoint()
        if bad_iters > 3:
            break

    return lm_model, perpl, rouges, bleus

In [None]:
# Cross-entropy that skips targets equal to 0, the padding id used by SequenceBucketingData.
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
optim = AdamW(params=model.parameters(), lr=5e-5, correct_bias=True)

In [None]:
# Fine-tune for up to 3 epochs with accumulation_step=1 and clip_norm=3.
lm_model, perpls, rouges, bleus = iterate(model, train_loader, valid_loader, 3, optim, criterion, 1, 3, tokenizer)

In [None]:
# Unregister every progress bar tqdm still tracks — presumably to clean up bars
# left over from an interrupted run; relies on tqdm's private API.
for instance in list(tqdm._instances):
    tqdm._decr_instances(instance)

# Генерация

In [None]:
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """Mask unlikely entries of a 1-D logits vector with ``filter_value``.

    The tensor is modified in place and also returned. Three filters are
    applied in order:

    Args:
        logits: 1-D tensor of logits (one entry per vocabulary item).
        top_k: if > 0, keep only the ``top_k`` largest logits.
        top_p: if > 0.0, keep the smallest prefix of the probability-sorted
            entries whose cumulative probability exceeds ``top_p`` (nucleus
            filtering); the first entry crossing the limit is kept as well.
        threshold: entries strictly below this value are masked.
        filter_value: value written into masked entries.
    """
    assert logits.dim() == 1  # batch size 1 only

    vocab_size = logits.size(-1)
    top_k = min(top_k, vocab_size)
    if top_k > 0:
        # Smallest logit that still belongs to the k best ones.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        ranked_logits, ranked_indices = torch.sort(logits, descending=True)
        ranked_probs = F.softmax(ranked_logits, dim=-1)
        cumulative_mass = torch.cumsum(ranked_probs, dim=-1)

        drop = cumulative_mass > top_p
        # Shift right by one so the entry that crosses the limit survives,
        # and never drop the best entry.
        drop[..., 1:] = drop[..., :-1].clone()
        drop[..., 0] = 0

        logits[ranked_indices[drop]] = filter_value

    logits[logits < threshold] = filter_value

    return logits

In [None]:
def generate(model, input, attention_mask, tti, next_speaker, max_len=50, top_k=10, top_p=0.9):
    """Sample a continuation token by token with top-k / nucleus filtering.

    Relies on the module-level ``tokenizer``, ``device``, ``EOS`` and
    ``top_filtering``.

    Args:
        model: language model returning an object with ``.logits``.
        input: 1-D tensor of prompt token ids.
        attention_mask: list of mask values aligned with ``input``.
        tti: list of token type ids aligned with ``input``.
        next_speaker: token type id assigned to every generated token.
        max_len: maximum number of tokens to generate.
        top_k: number of best tokens kept before sampling.
        top_p: nucleus probability mass kept before sampling.

    Returns:
        List of generated token tensors (each of shape ``(1,)``); the EOS
        token, when sampled, is the last element.
    """
    model.eval()
    eos_id = tokenizer.convert_tokens_to_ids(EOS)
    # Work on copies so the caller's lists are not extended as a side effect.
    attention_mask = list(attention_mask)
    tti = list(tti)
    generated = []
    with torch.no_grad():
        for _ in range(max_len):
            # Add an explicit batch dimension of size 1 to every model input.
            am_batch = torch.tensor(attention_mask, device=device).unsqueeze(0)
            tti_batch = torch.tensor(tti, device=device).unsqueeze(0)
            output = model(input_ids=input.unsqueeze(0),
                           attention_mask=am_batch,
                           token_type_ids=tti_batch)
            # Logits of the last position of the single batch row.
            logits = output.logits[0, -1, :]
            filtered_logits = top_filtering(logits, top_p=top_p, top_k=top_k)
            pred = F.softmax(filtered_logits, dim=0)
            new_token = torch.multinomial(pred, 1)

            input = torch.cat((input, new_token.view(1)))
            generated.append(new_token)
            if new_token.item() == eos_id:
                break
            attention_mask.append(1)
            tti.append(next_speaker)
    return generated

In [None]:
replics = [
    {"tag":HISTORY, "text": 'совсем разложились надо же бульдозерами давить'},
    {"tag":HISTORY, "text": 'или срок годности переклеить'},
    {"tag":INPUT, "text": 'а потом бульдозером раздавить и по федеральным каналам показать'},
    {"tag":REPLY, "text": ''},
]
sp1 = tokenizer.convert_tokens_to_ids(SP1)
sp2 = tokenizer.convert_tokens_to_ids(SP2)
# Speaker ids mirror the training encoding in SequenceBucketingData: history
# turns alternate starting from speaker1, the input is speaker1 and the reply
# (including every generated token) is speaker2.
curr_sp_1 = True  # speaker of the next history turn
inp = []
am = []
tti = []
for replic in replics:
    tag = replic["tag"]
    token_ids = [tokenizer.convert_tokens_to_ids(tag)] + tokenizer.encode_plus(replic['text'])['input_ids']
    if tag == INPUT:
        use = sp1
    elif tag == REPLY:
        use = sp2
    else:
        use = sp1 if curr_sp_1 else sp2
        # "not" is Python's negation; "= !name" is IPython shell-capture syntax.
        curr_sp_1 = not curr_sp_1
    am += [1] * len(token_ids)
    tti += [use] * len(token_ids)
    inp += token_ids
next_speaker = sp2
generated = generate(model, torch.tensor(inp).to(device), am, tti, next_speaker)

In [None]:
# Decode the generated token ids back into text, leaving out special tokens.
tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(generated, skip_special_tokens=True))

# Бот

In [None]:
secret = "" # paste the bot's secret code here (it is passed to telebot.TeleBot)

In [None]:
# Bot client built from the secret set in the previous cell.
bot = telebot.TeleBot(secret)

# Per-user dialogue history: maps a user id to a list of recent utterances.
history = {}

def generate_reply(input, user_history):
    """Generate the bot's reply to ``input`` given the user's recent history.

    Relies on the module-level ``tokenizer``, ``model``, ``device`` and
    ``generate``.

    Args:
        input: normalized text of the incoming message.
        user_history: list of previous utterance strings, oldest first.

    Returns:
        The decoded reply text with special tokens removed.
    """
    replics = [{"tag":HISTORY, "text": sentence} for sentence in user_history]
    replics.append({"tag":INPUT, "text": input})
    replics.append({"tag":REPLY, "text": ''})

    sp1 = tokenizer.convert_tokens_to_ids(SP1)
    sp2 = tokenizer.convert_tokens_to_ids(SP2)
    # Speaker ids mirror the training encoding in SequenceBucketingData:
    # history turns alternate starting from speaker1, the input is speaker1
    # and the reply (including every generated token) is speaker2.
    history_is_sp1 = True
    inp = []
    am = []
    tti = []
    for replic in replics:
        tag = replic["tag"]
        token_ids = [tokenizer.convert_tokens_to_ids(tag)] + tokenizer.encode_plus(replic['text'])['input_ids']
        if tag == INPUT:
            use = sp1
        elif tag == REPLY:
            use = sp2
        else:
            use = sp1 if history_is_sp1 else sp2
            # "not" is Python's negation; "= !name" is IPython shell-capture syntax.
            history_is_sp1 = not history_is_sp1
        am += [1] * len(token_ids)
        tti += [use] * len(token_ids)
        inp += token_ids
    next_speaker = sp2
    generated = generate(model, torch.tensor(inp).to(device), am, tti, next_speaker)
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(generated, skip_special_tokens=True))

@bot.message_handler(commands=['history'])
def handle_history(message):
    """Send the sender's stored history as a numbered list, or a notice if empty."""
    stored = history.get(message.from_user.id)
    if not stored:
        bot.send_message(message.chat.id, 'Истории нет')
        return
    numbered = [f'{idx}. {text}' for idx, text in enumerate(stored)]
    bot.send_message(message.chat.id, '\n'.join(numbered))
    
@bot.message_handler(commands=['ping'])
def handle_ping(message):
    """Answer the ping command with 'pong' in the same chat."""
    chat_id = message.chat.id
    bot.send_message(chat_id, 'pong')

@bot.message_handler(commands=['restart'])
def handle_start(message):
    """Clear the sender's dialogue history and confirm it in the chat."""
    history[message.from_user.id] = []
    bot.send_message(message.chat.id, 'История была сброшена')

def update_history(old_history, input, reply, max_len=3):
    """Return a new history extended with the latest exchange.

    Args:
        old_history: list of previous utterances; it is not modified.
        input: the user's latest utterance.
        reply: the bot's reply to it.
        max_len: maximum number of most recent utterances to keep.

    Returns:
        A new list with at most ``max_len`` of the most recent utterances.
    """
    if max_len <= 0:
        # A slice of [-0:] would return the whole list instead of nothing.
        return []
    return [*old_history, input, reply][-max_len:]

@bot.message_handler(content_types=['text'])
def send_text(message):
    """Reply to a text message and record the exchange in the sender's history."""
    curr_user = message.from_user.id
    if history.get(curr_user) is None:
        # First message from this user: start with an empty history.
        history[curr_user] = []
    user_history = history[curr_user]

    user_input = normalize([[message.text]])[0][0]
    reply = generate_reply(user_input, user_history)

    history[curr_user] = update_history(user_history, user_input, reply)
    bot.send_message(message.chat.id, reply)

# Start receiving updates and dispatching them to the registered handlers.
bot.polling()