<a href="https://colab.research.google.com/github/dTenebrae/nlp_course/blob/main/nlp_course.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Телеграм-бот на основе кастомной ruDalle

In [None]:
# Installation
!nvidia-smi -L
!pip install rudalle==0.0.1rc7 > /dev/null

GPU 0: Tesla T4 (UUID: GPU-0645c37d-9806-215a-30e2-0ca557a33068)
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.[0m


In [None]:
!pip install pyTelegramBotAPI > /dev/null

In [None]:
# Imports
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle import utils
from rudalle.utils import seed_everything
from rudalle.dalle.utils import divide, split_tensor_along_last_dim

import torch
import torchvision
import torch.nn.functional as F
import transformers

import math
import re
import numpy as np

import inspect
from functools import partial

import more_itertools
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import telebot

import warnings
warnings.simplefilter("ignore")

In [None]:
# Device that every model in this notebook is placed on.
device = 'cuda'
# Main text-to-image transformer (Malevich checkpoint), loaded in half precision.
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
# Re-running this cell must not reload the auxiliary models. Evaluating the
# names raises NameError when any of them is missing (fresh kernel, or a name
# that was deleted), which triggers the load below. `vae` is part of the check
# because it is assigned in the same branch and is required by generate_images.
try:
    realesrgan, tokenizer, vae, ruclip, ruclip_processor
except NameError:
    realesrgan = get_realesrgan('x4', device=device)
    tokenizer = get_tokenizer()
    vae = get_vae().to(device)
    ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
    ruclip = ruclip.to(device)

In [None]:
def show(pil_images, nrow=4):
    """Render a list of PIL images as a single grid figure.

    This definition replaces the `show` imported from `rudalle.pipelines`
    for the rest of the notebook.

    Args:
        pil_images: list of PIL images to display.
        nrow: forwarded to `torchvision.utils.make_grid` as `nrow`.
    """
    grid = torchvision.utils.make_grid(
        utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
    # A non-list result is moved to the CPU and wrapped so the loop below is uniform.
    panels = grid if isinstance(grid, list) else [grid.cpu()]
    fig, axs = plt.subplots(ncols=len(panels), squeeze=False, figsize=(14, 14))
    for ax, panel in zip(axs[0], panels):
        picture = torchvision.transforms.functional.to_pil_image(panel.detach())
        ax.imshow(np.asarray(picture))
        # Hide ticks and tick labels; only the picture matters.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.show()
    plt.show()

## Кастомизация модели

In [None]:
def generate_images(text,
                    tokenizer,
                    dalle,
                    vae,
                    top_k,
                    top_p,
                    images_num,
                    temperature=1.0,
                    bs=8,
                    seed=None,
                    use_cache=True):
    """Sample images for a text prompt, one image token at a time.

    This definition replaces the `generate_images` imported from
    `rudalle.pipelines` and adds the `temperature` argument.

    Args:
        text: prompt; lower-cased and stripped before tokenization.
        tokenizer: object exposing `encode_text(text, text_seq_length=...)`.
        dalle: model; called once per generated token and queried via `get_param`.
        vae: decoder exposing `decode(codebooks)`.
        top_k, top_p: forwarded to `transformers.top_k_top_p_filtering`.
        images_num: total number of images to generate.
        temperature: the image-token logits are divided by this value before filtering.
        bs: maximum number of images generated per batch.
        seed: when not None, passed to `utils.seed_everything` first.
        use_cache: forwarded to the model on every step.

    Returns:
        (pil_images, scores): the decoded PIL images and, for each image, the
        sum over all sampling steps of the probability of the sampled token.
    """
    if seed is not None:
        utils.seed_everything(seed)

    vocab_size = dalle.get_param('vocab_size')
    text_seq_length = dalle.get_param('text_seq_length')
    image_seq_length = dalle.get_param('image_seq_length')
    total_seq_length = dalle.get_param('total_seq_length')
    device = dalle.get_param('device')

    prompt_ids = tokenizer.encode_text(text.lower().strip(), text_seq_length=text_seq_length)
    prompt_len = len(prompt_ids)

    pil_images, scores = [], []
    for chunk in more_itertools.chunked(range(images_num), bs):
        chunk_bs = len(chunk)
        # inference
        with torch.no_grad():
            mask_shape = (chunk_bs, 1, total_seq_length, total_seq_length)
            causal_mask = torch.tril(torch.ones(mask_shape, device=device))
            tokens = prompt_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
            cache_ready = False
            step_probs = []
            for position in tqdm(range(prompt_len, total_seq_length)):
                logits, cache_ready = dalle(
                    tokens[:, :position], causal_mask,
                    has_cache=cache_ready, use_cache=use_cache, return_loss=False)
                # Keep only the last position and the logits past the text vocabulary.
                logits = logits[:, -1, vocab_size:]
                logits /= temperature
                filtered = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probs = torch.nn.functional.softmax(filtered, dim=-1)
                picked = torch.multinomial(probs, 1)
                # Probability of the token that was actually drawn, per batch item.
                step_probs.append(probs[torch.arange(probs.size(0)), picked.transpose(0, 1)])
                tokens = torch.cat((tokens, picked), dim=-1)
            codebooks = tokens[:, -image_seq_length:]
            decoded = vae.decode(codebooks)
            pil_images.extend(utils.torch_tensors_to_pil_list(decoded))
            scores.extend(torch.cat(step_probs).sum(0).detach().cpu().numpy().tolist())
    return pil_images, scores

In [None]:
@torch.jit.script
def gelu_impl(x):
    """Tanh-based GELU approximation (OpenAI's formulation), compiled with TorchScript."""
    inner = 0.7978845608028654 * x * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(inner))


def gelu(x):
    """Apply the scripted GELU approximation `gelu_impl` to `x`."""
    activated = gelu_impl(x)
    return activated

def dalle_layer_forward(self, hidden_states, ltor_mask, has_cache, use_cache):
    """Run one transformer layer: attention and MLP, each followed by a residual add.

    Args:
        self: the layer whose sub-modules are called.
        hidden_states: layer input, [b, s, h].
        ltor_mask: attention mask passed through to `self.attention`, [1, 1, s, s].
        has_cache, use_cache: cache flags forwarded to the attention, the MLP and
            every layer norm except `input_layernorm`, which is called without them.

    Returns:
        (output, has_cache): the layer output and `att_has_cache and mlp_has_cache`
        as reported by the attention and MLP sub-modules.
    """
    cache_flags = dict(has_cache=has_cache, use_cache=use_cache)

    # Attention sub-block: pre-norm -> self attention -> optional sandwich norm.
    normed_input = self.input_layernorm(hidden_states)
    attn_out, attn_cached = self.attention(normed_input, ltor_mask, **cache_flags)
    if self.cogview_sandwich_layernorm:
        attn_out = self.before_first_addition_layernorm(attn_out, **cache_flags)

    # First residual connection.
    after_attention = hidden_states + attn_out

    # MLP sub-block: post-attention norm -> MLP -> optional sandwich norm.
    normed_attention = self.post_attention_layernorm(after_attention, **cache_flags)
    mlp_out, mlp_cached = self.mlp(normed_attention, **cache_flags)
    if self.cogview_sandwich_layernorm:
        mlp_out = self.before_second_addition_layernorm(mlp_out, **cache_flags)

    # Second residual connection.
    return after_attention + mlp_out, attn_cached and mlp_cached

In [None]:
# About 1.3x speedup. Query/key/value cat is surprisingly fast.
def dalle_sa_forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
    """Self-attention forward pass that can reuse keys, values and outputs cached on `self`.

    When both `has_cache` and `use_cache` are truthy, only the positions after
    `self.past_output.shape[-2]` are projected through `query_key_value`; the new
    keys/values are appended to `self.past_key` / `self.past_value`, and the new
    output rows are appended to `self.past_output`.

    Args:
        hidden_states: layer input (see the shape comments below).
        ltor_mask: multiplicative attention mask; positions where it is 0 receive
            a score of -10000 before the softmax.
        has_cache: whether the cache attributes on `self` were filled by an earlier call.
        use_cache: whether to read from and write to the cache attributes.

    Returns:
        (output, has_cache): the attention output after `output_dropout`, and a
        flag that is True whenever `use_cache` is truthy and False otherwise.
    """
    # hidden_states: [b, s, h]
    # ltor_mask: [1, 1, s, s]
    # Attention heads. [b, s, hp]
    
    def calculate_attention_scores(query_layer, key_layer, ltor_mask):
        """Return masked, scaled query-key scores (max-shifted when `cogview_pb_relax` is set)."""
        key_t = key_layer.transpose(-1, -2)
        if self.cogview_pb_relax:
            # Scale the query before the matmul instead of scaling the product.
            attention_scores = torch.matmul(
                query_layer / math.sqrt(self.hidden_size_per_attention_head),
                key_t
            )
        else:
            attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
        # Keep only as many trailing mask rows as there are query rows.
        ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
        attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
        if self.cogview_pb_relax:
            # normalize attention scores. Should not affect resulting softmax value
            alpha = 32
            attention_scores_scaled = attention_scores / alpha
            attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view(
                [attention_scores.size(0), attention_scores.size(1), -1]
            ).max(dim=-1)  # max per head per sample
            attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand(
                [-1, -1, attention_scores.size(2), attention_scores.size(3)]
            )  # expand to [b, np, s, s]
            attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha
        return attention_scores
    
    # NOTE(review): `t` is not read anywhere else in this function.
    t = hidden_states.shape[-2]
    if has_cache and use_cache:
        # Project only the positions that are not covered by the cached output.
        mixed_x_layer = self.query_key_value(hidden_states[:, self.past_output.shape[-2]:, :])
    else:
        mixed_x_layer = self.query_key_value(hidden_states)

    (mixed_query_layer,
        mixed_key_layer,
        mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

    query_layer = self._transpose_for_scores(mixed_query_layer)
    key_layer = self._transpose_for_scores(mixed_key_layer)
    value_layer = self._transpose_for_scores(mixed_value_layer)

    if use_cache and has_cache:
        # Extend the cached keys/values with the newly projected positions.
        value_layer = torch.cat((self.past_value, value_layer), dim=-2)
        key_layer = torch.cat((self.past_key, key_layer), dim=-2)
    attention_scores = calculate_attention_scores(
        query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
    )

    if use_cache:
        # self.past_query = query_layer
        self.past_key = key_layer
        self.past_value = value_layer
    else:
        # Without caching the incoming flag is discarded, so the branches below are skipped.
        has_cache = False

    if use_cache and has_cache:
        # Only the last query row is kept when continuing from a cache.
        attention_scores = attention_scores[..., -1:, :]
        # value_layer = value_layer[..., -1:, :]
    
    # Attention probabilities. [b, np, s, s]
    attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.attention_dropout(attention_probs)

    # Context layer.
    # [b, np, s, hn]
    context_layer = torch.matmul(attention_probs, value_layer)

    # [b, s, np, hn]
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
    # [b, s, hp]
    context_layer = context_layer.view(*new_context_layer_shape)

    # Output. [b, s, h]
    output = self.dense(context_layer)

    if use_cache:
        if has_cache:
            # Prepend the cached output rows so the caller receives the full sequence.
            output = torch.cat((self.past_output, output), dim=-2)
            self.past_output = output
        else:
            self.past_output = output
        has_cache = True 
    output = self.output_dropout(output)
    return output, has_cache

In [None]:
def dalle_mlp_forward(self, hidden_states, has_cache=False, use_cache=False):
    """Feed-forward block that can reuse the output cached in `self.past_x`.

    When both flags are truthy, only the positions after `self.past_x.shape[1]`
    are computed and the result is appended to the cached output.

    Args:
        hidden_states: block input.
        has_cache: whether `self.past_x` was filled by an earlier call.
        use_cache: whether to read from and write to `self.past_x`.

    Returns:
        (output, has_cache): the output after `self.dropout`, and a flag equal
        to True when `use_cache` is truthy and False otherwise.
    """
    if has_cache and use_cache:
        hidden_states = hidden_states[:, self.past_x.shape[1]:]

    # [b, s, 4hp] -> gelu -> [b, s, h]
    projected = self.dense_4h_to_h(gelu(self.dense_h_to_4h(hidden_states)))

    if not use_cache:
        return self.dropout(projected), False

    if has_cache:
        projected = torch.cat((self.past_x, projected), dim=-2)
    self.past_x = projected
    return self.dropout(projected), True

In [None]:
# Speeds up like 6 seconds.
def ln_forward(self, input, has_cache=False, use_cache=False):
    """Layer norm that can reuse the output cached in `self.past_output`.

    When both flags are truthy, only the positions after
    `self.past_output.shape[1]` are normalized and the result is appended to the
    cached output along dim 1. Unlike the attention and MLP replacements, this
    function returns only the tensor, not a cache flag.

    Args:
        input: tensor to normalize over `self.normalized_shape`.
        has_cache: whether `self.past_output` was filled by an earlier call.
        use_cache: whether to read from and write to `self.past_output`.

    Returns:
        The normalized tensor, including the cached prefix when the cache was used.
    """
    if has_cache and use_cache:
        input = input[:, self.past_output.shape[1]:]

    output = F.layer_norm(
        input, self.normalized_shape, self.weight, self.bias, self.eps)

    if use_cache:
        if has_cache:
            output = torch.cat((self.past_output, output), dim=1)
        # Remember the full output so the next call can skip these positions.
        self.past_output = output
    return output

### Собираем сеть

In [None]:
# Replace the forward of every transformer layer, and of its attention, MLP and
# layer-norm sub-modules, with the caching implementations defined above.
for layer in dalle.module.transformer.layers:
    layer.forward = partial(dalle_layer_forward, layer)
    layer.mlp.forward = partial(dalle_mlp_forward, layer.mlp)
    # Initialised for completeness; the replacement forwards do not read these two.
    layer.attention.past_attentions = None
    layer.attention.past_query = None
    layer.attention.forward = partial(dalle_sa_forward, layer.attention)

    layernorms = [layer.input_layernorm, layer.post_attention_layernorm]
    # dalle_layer_forward only calls the sandwich layer norms when this flag is
    # set, so only touch them in that case instead of assuming they exist.
    if layer.cogview_sandwich_layernorm:
        layernorms += [layer.before_first_addition_layernorm,
                       layer.before_second_addition_layernorm]
    for ln in layernorms:
        ln.forward = partial(ln_forward, ln)

## Бот

In [None]:
#@markdown Введите ID вашего бота.
# BOT_ID is a Colab form field (see the #@param annotation); replace the
# placeholder before running. It is passed to TeleBot as its first argument.
# NOTE(review): this value is presumably the bot's secret API token - avoid
# committing a real value with the notebook.
BOT_ID = "Ваш ID"#@param {type:"string"}
# Client object that the handlers below are registered on.
bot = telebot.TeleBot(BOT_ID, parse_mode=None) 

def if_russian(check_str: str) -> bool:
    """Check whether a message is written entirely in Russian.

    A message qualifies when it contains at least one letter and every letter
    is a Russian Cyrillic one (а-я, ё, in either case). Whitespace, digits and
    punctuation are allowed anywhere, so ordinary sentences pass the check.

    The previous length-comparison heuristic rejected words that start or end
    with "ё" and any text with punctuation or surrounding whitespace, while
    accepting an empty string and single non-Russian separator characters.

    Args:
        check_str: message text; `None` or an empty string is treated as
            "not Russian".

    Returns:
        True if the text has letters and all of them are Russian, else False.
    """
    if not check_str:
        return False
    letters = [ch.lower() for ch in check_str if ch.isalpha()]
    if not letters:
        # Digits, punctuation or emoji alone carry no text to generate from.
        return False
    # "ё" lies outside the contiguous а-я code point range, so test it separately.
    return all("а" <= ch <= "я" or ch == "ё" for ch in letters)

# reply to the /start command
@bot.message_handler(commands=['start'])
def send_welcome(message):
    """Reply to the incoming message with a short greeting that introduces the bot."""
    greeting = "Привет. Это бот для генерации изображений с помощью ruDalle"
    bot.reply_to(message, greeting)

# reply to the /help command
@bot.message_handler(commands=['help'])
def send_help(message):
    """Reply to the incoming message with usage instructions.

    Named distinctly from the /start handler `send_welcome` so that this
    definition no longer shadows it in the notebook namespace. The handler is
    registered through the decorator, so the rename does not affect dispatch.
    """
    bot.reply_to(message, "Чтобы сгенерировать изображение, отправьте текст на русском языке. Генерация изображения занимает до 2-х минут.")

@bot.message_handler(func=lambda message: if_russian(message.text))
def echo_all(message):
    """Generate an image from the message text, upscale it and send it back.

    Runs for messages whose text passes `if_russian`. Uses the module-level
    `tokenizer`, `dalle`, `vae` and `realesrgan` objects.
    """
    # One (top_k, top_p, images_num) entry per generation pass; here a single image.
    sampling_passes = [(2048, 0.995, 1)]

    generated = []
    for top_k, top_p, images_num in sampling_passes:
        batch, _scores = generate_images(
            message.text,  # the user's prompt
            tokenizer,
            dalle,
            vae,
            top_k=top_k,
            images_num=images_num,
            bs=8,
            top_p=top_p,
        )
        generated.extend(batch)

    # Upscale the generated pictures.
    upscaled = super_resolution(generated, realesrgan)
    # Send the first picture back to the chat the message came from.
    bot.send_photo(message.chat.id, upscaled[0])

# Start polling Telegram for updates and dispatch them to the handlers above.
# NOTE(review): this call appears to block the cell until polling is stopped
# (the saved log shows "polling exited" only at the end) - confirm.
bot.infinity_polling()

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

2022-02-11 15:44:06,777 (__init__.py:621 MainThread) ERROR - TeleBot: "Infinity polling: polling exited"
2022-02-11 15:44:06,780 (__init__.py:623 MainThread) ERROR - TeleBot: "Break infinity polling"
