In [None]:
!pip install transformers telebot accelerate peft sentencepiece

In [None]:
!pip install --upgrade gradio

In [1]:
import os
import sys
import base64
import shutil
import subprocess
from typing import Union
from dataclasses import dataclass, field

In [2]:
import gradio as gr

In [3]:
from peft import PeftModel

In [4]:
import torch

In [5]:
import warnings
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, GenerationConfig

In [6]:
import telebot

In [7]:
def xor_decrypt_from_str(encrypted_str, key):
    """Decode a Base64 string and XOR it against a repeating UTF-8 key.

    Args:
        encrypted_str: Base64 text holding the XOR-ed payload.
        key: Key string; its UTF-8 bytes are repeated to cover the payload.

    Returns:
        The XOR-ed payload decoded as a UTF-8 string.

    Raises:
        ZeroDivisionError: If ``key`` is empty.
    """
    cipher = base64.b64decode(encrypted_str.encode('utf-8'))
    secret = key.encode('utf-8')

    # Tile the key over the payload: `whole` full copies plus a `tail`-byte prefix.
    whole, tail = divmod(len(cipher), len(secret))
    keystream = secret * whole + secret[:tail]

    plain = bytes(c ^ k for c, k in zip(cipher, keystream))
    return plain.decode('utf-8')

# SECURITY: XOR with a key stored right next to the ciphertext is obfuscation,
# not encryption -- anyone who can read this notebook can recover the bot token.
# Prefer supplying the token through the TELEGRAM_API_KEY environment variable.
# The embedded value is kept only as a fallback so existing runs keep working;
# it should be rotated and removed from the notebook.
encrypted_telegram_api_key = "BggEAgMHCQUAAgp2cXgEc19oVwdnVgBWZmBbY0ZAYgYCVQFjQXp8XQFcV2AOZQ=="
telegram_api_key = os.environ.get("TELEGRAM_API_KEY") or xor_decrypt_from_str(
    encrypted_telegram_api_key, "007"
)

In [8]:
@dataclass
class Params:
    """Configuration shared by the notebook cells.

    Every attribute carries a type annotation so that it is a real dataclass
    field: it can be overridden through the constructor
    (e.g. ``Params(device="cpu")``) and takes part in ``repr``/``==``.
    ``Params()`` with no arguments yields the same values as before.

    ``device_map`` is not a constructor argument; it is derived from ``ddp``
    in ``__post_init__``.
    """

    # Repository the weights and helper modules are copied from, and where
    # they end up locally.
    repo_url: str = "https://github.com/dmitryilyn/MIPT.git"
    temp_path: str = "/tmp/MIPT_repo"
    git_weights_dir: str = "NLPgen/HW2/weights"
    git_utils_dir: str = "NLPgen/HW2/utils"
    weights_path: str = "weights"
    utils_path: str = "utils"

    base_model: str = "nickypro/tinyllama-15M"
    tokenizer_name: str = "hf-internal-testing/llama-tokenizer"

    # Device the tokenized prompt is moved to before generation.
    device: str = "cuda"

    load_in_8bit: bool = False
    torch_dtype: torch.dtype = torch.float32
    # Either the string "auto" or a {"": device_index} mapping (see __post_init__).
    device_map: Union[str, dict] = field(init=False)
    low_cpu_mem_usage: bool = True

    # Defaults are read from the environment once, when the class is defined.
    world_size: int = int(os.environ.get("WORLD_SIZE", 1))
    ddp: bool = world_size != 1

    def __post_init__(self):
        """Pick the device map: a LOCAL_RANK mapping when ``ddp`` is set, else "auto"."""
        if self.ddp:
            self.device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        else:
            self.device_map = "auto"


# Single configuration instance read by the cells below.
params = Params()

# Загружаем модель

In [None]:
# Fetch the adapter weights and helper modules from the repository.
# Remove any checkout left behind by an interrupted run, otherwise
# `git clone` refuses to write into the existing directory.
shutil.rmtree(params.temp_path, ignore_errors=True)
try:
    # check=True: fail here with CalledProcessError if the clone does not
    # succeed, instead of later with a confusing FileNotFoundError from copytree.
    subprocess.run(["git", "clone", params.repo_url, params.temp_path], check=True)
    # dirs_exist_ok=True keeps this cell re-runnable when the destination
    # directories already exist from a previous execution.
    shutil.copytree(
        os.path.join(params.temp_path, params.git_weights_dir),
        params.weights_path,
        dirs_exist_ok=True,
    )
    shutil.copytree(
        os.path.join(params.temp_path, params.git_utils_dir),
        params.utils_path,
        dirs_exist_ok=True,
    )
finally:
    # Always drop the temporary checkout, even when cloning or copying failed.
    shutil.rmtree(params.temp_path, ignore_errors=True)

In [9]:
from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter

In [10]:
# Prompt helper from utils.prompter; used below via generate_prompt() and
# get_response().
prompter = Prompter()

tokenizer = LlamaTokenizer.from_pretrained(params.tokenizer_name)

# Load the base causal LM first; the PEFT adapter is applied on top of it next.
joey_model = LlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path=params.base_model,
    load_in_8bit=params.load_in_8bit,
    torch_dtype=params.torch_dtype,
    device_map=params.device_map,
    low_cpu_mem_usage=params.low_cpu_mem_usage,
)
# Wrap the base model with the adapter stored in `params.weights_path`.
# NOTE(review): the adapter's device_map is hard-coded to {'': 0}, while the
# base model uses params.device_map -- confirm this is intended under DDP.
joey_model = PeftModel.from_pretrained(
    joey_model,
    params.weights_path,
    torch_dtype=params.torch_dtype,
    device_map={'': 0},
)
joey_model.eval()
# torch.compile is applied only on torch 2+ and only off Windows.
if torch.__version__ >= "2" and sys.platform != "win32":
    joey_model = torch.compile(joey_model)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
  return self.fget.__get__(instance, owner)()


In [11]:
def generate_reply_async(
    text,
    context=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=512,
    **kwargs,
):
    """Generate a reply incrementally, yielding the response text as it grows.

    Args:
        text: User message passed to ``prompter.generate_prompt``.
        context: Optional second argument for ``prompter.generate_prompt``.
        temperature, top_p, top_k, num_beams: Forwarded to ``GenerationConfig``.
        max_new_tokens: Forwarded to ``joey_model.generate``.
        **kwargs: Extra ``GenerationConfig`` options.

    Yields:
        ``prompter.get_response(...)`` of the text decoded from each
        intermediate output, until an output ends with the EOS token id.
    """
    full_prompt = prompter.generate_prompt(text, context)
    prompt_ids = tokenizer(full_prompt, return_tensors="pt")["input_ids"].to(params.device)

    sampling_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )

    generate_params = dict(
        input_ids=prompt_ids,
        generation_config=sampling_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=max_new_tokens,
    )

    def run_generation(callback=None, **generate_kwargs):
        # Attach a Stream stopping criterion wrapping the callback; presumably
        # it reports each step's tokens (Stream lives in utils.callbacks -- verify).
        criteria = generate_kwargs.setdefault(
            "stopping_criteria", transformers.StoppingCriteriaList()
        )
        criteria.append(Stream(callback_func=callback))
        with torch.no_grad():
            joey_model.generate(**generate_kwargs)

    # Iteratorize (utils.callbacks) exposes the callback-driven generation as
    # an iterable context manager.
    with Iteratorize(run_generation, generate_params, callback=None) as token_stream:
        for token_ids in token_stream:
            partial_text = tokenizer.decode(token_ids)

            # Stop without yielding once the last token is the EOS id.
            if token_ids[-1] in [tokenizer.eos_token_id]:
                break

            yield prompter.get_response(partial_text)

In [14]:
def generate_reply(
    text,
    context=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=512,
    **kwargs,
):
    """Generate a complete reply in a single call (no streaming).

    Args:
        text: User message passed to ``prompter.generate_prompt``.
        context: Optional second argument for ``prompter.generate_prompt``.
        temperature, top_p, top_k, num_beams: Forwarded to ``GenerationConfig``.
        max_new_tokens: Upper bound on generated tokens, passed to ``generate``.
        **kwargs: Extra ``GenerationConfig`` options.

    Returns:
        ``prompter.get_response(...)`` applied to the decoded first sequence,
        with special tokens skipped and surrounding whitespace stripped.
    """
    prompt = prompter.generate_prompt(text, context)
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(params.device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        **kwargs,
    )

    # Inference without streaming.
    with torch.no_grad():
        generation_output = joey_model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    decoded = tokenizer.decode(
        generation_output.sequences[0], skip_special_tokens=True
    ).strip()

    return prompter.get_response(decoded)

# Тест на Gradio с асинхронным выводом

In [14]:
# Install a blanket "ignore" filter: no Python warnings are shown from here on.
warnings.filterwarnings("ignore")

In [17]:
# Gradio demo around the streaming generator. The inputs are positional and map
# onto generate_reply_async(text, context, temperature, top_p, top_k,
# num_beams, max_new_tokens) in that order.
iface = gr.Interface(
    fn=generate_reply_async,
    inputs=[
        gr.Textbox(label="Ваше сообщение", placeholder="Hi!", lines=1),
        gr.Textbox(label="Контекст", lines=1),
        gr.Slider(label="Температура", minimum=0, maximum=1, value=0.1),
        gr.Slider(label="Top p", minimum=0, maximum=1, value=0.75),
        gr.Slider(label="Top k", minimum=0, maximum=100, step=1, value=40),
        gr.Slider(label="Beams", minimum=1, maximum=4, step=1, value=4),
        gr.Slider(
            label="Максимальное число токенов",
            minimum=1,
            maximum=1024,
            step=1,
            value=512,
        ),
    ],
    outputs=[gr.Textbox(label="Ответ", lines=5)],
)
# The queue is enabled before launch because fn is a generator function.
iface.queue()
iface.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://a96d1846cf728e0500.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://a96d1846cf728e0500.gradio.live




In [16]:
# Close the Gradio interface launched above.
iface.close()

Closing server running on port: 7860


# 4. Запускаем бота

In [13]:
# Install a blanket "ignore" filter before the bot starts; the same call also
# precedes the Gradio demo.
warnings.filterwarnings("ignore")

In [16]:
bot = telebot.TeleBot(telegram_api_key)

# Last message sent by the bot; passed as `context` to the next generate_reply call.
# NOTE(review): this single global is shared by every chat, so concurrent
# conversations overwrite each other's context -- consider keying it by chat id.
context = ""


@bot.message_handler(commands=["start"])
def start(m, res=False):
    """Handle /start: reset the context to the greeting and send it to the chat."""
    global context
    # Was misspelled `contex`, which left the global untouched and sent the
    # empty (or stale) context instead of the greeting.
    context = "Hi! I'm Joe!"
    bot.send_message(m.chat.id, context)


@bot.message_handler(content_types=["text"])
def process_message(message):
    """Reply to a text message and remember the reply as the next context."""
    global context
    reply = generate_reply(
        message.text,
        context
    )
    context = reply
    bot.send_message(message.chat.id, reply)

In [17]:
# Print the bot's public link, then start polling Telegram for updates.
# NOTE(review): none_stop=False presumably makes polling stop on the first
# error -- confirm against the telebot documentation.
print("https://t.me/JoeyGeneratorBot")
bot.polling(none_stop=False)

https://t.me/JoeyGeneratorBot
