In [None]:
import json
import torch
import torch.nn.functional as F
from telegram import Update, ReplyKeyboardMarkup
from telegram.ext import Updater, CommandHandler, MessageHandler, Filters, CallbackContext

# Debug aid: load the raw checkpoint onto the CPU and list the entry names it holds.
checkpoint = torch.load(
    "/content/model_weights.pth",
    map_location=torch.device("cpu"),
)
print(checkpoint.keys())

import torch
import torch.nn as nn
from torch.nn import Transformer

# Same architecture as the one used in training (attribute names must match
# the keys of the saved state_dict).
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer with one token embedding shared by source and target.

    Source and target ids go through the same ``nn.Embedding`` and
    ``PositionalEncoding``, then through a batch-first ``nn.Transformer``;
    ``fc_out`` projects every decoder position to ``vocab_size`` logits.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embed_dim: model width, passed to the Transformer as ``d_model``.
        nhead: number of attention heads.
        hidden_dim: feed-forward width inside each Transformer layer.
        num_layers: number of encoder layers and, separately, of decoder layers.
        dropout: dropout probability for the Transformer and positional encoding.
    """

    def __init__(self, vocab_size, embed_dim, nhead, hidden_dim, num_layers, dropout):
        super().__init__()
        # Id 0 is the padding index: its embedding row receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        self.transformer = Transformer(
            d_model=embed_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(
        self,
        src,
        tgt,
        tgt_mask=None,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
    ):
        """Return vocabulary logits for every target position.

        Args:
            src: source token ids, presumably ``(batch, src_len)`` since the
                Transformer is batch-first — TODO confirm against callers.
            tgt: target token ids, ``(batch, tgt_len)`` under the same assumption.
            tgt_mask: optional decoder self-attention mask, e.g. a causal mask
                from ``nn.Transformer.generate_square_subsequent_mask(tgt_len)``
                so positions cannot attend to later ones. ``None`` applies no mask.
            src_key_padding_mask: optional ``(batch, src_len)`` bool mask, True at
                source padding positions; also used as the memory padding mask.
            tgt_key_padding_mask: optional ``(batch, tgt_len)`` bool mask, True at
                target padding positions.

        Returns:
            Logits from ``fc_out``; ``(batch, tgt_len, vocab_size)`` for
            batch-first inputs.
        """
        src = self.pos_encoder(self.embedding(src))
        tgt = self.pos_encoder(self.embedding(tgt))
        output = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.fc_out(output)

# Sinusoidal positional encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``, so it is stored in the
    state dict and moves with the module.

    Args:
        d_model: embedding width (odd values are supported).
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence the table covers.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for every even channel index 2i.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd channel than even channel;
        # trim div_term so the shapes line up instead of raising.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])``, with ``seq_len = x.size(1)``.

        Raises:
            ValueError: if ``x.size(1)`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding "
                f"max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Hyperparameters: these must match the trained checkpoint, otherwise
# load_state_dict below rejects the mismatched tensor shapes.
vocab_size = 762
embed_dim = 128
nhead = 4
hidden_dim = 256
num_layers = 2
dropout = 0.1

# Build the architecture, then fill it with the saved weights mapped onto the CPU.
model = TransformerModel(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    nhead=nhead,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    dropout=dropout,
)
model.load_state_dict(
    torch.load("/content/model_weights.pth", map_location=torch.device("cpu"))
)

# Evaluation mode: dropout layers become no-ops.
model.eval()


# JSON text is UTF-8 by specification; decode it explicitly instead of relying
# on the platform's default encoding.
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)


# Word -> id, using each entry's position in the loaded JSON.
# NOTE(review): assumes vocab.json is a list of words in id order (a JSON object
# would be enumerated by its keys in insertion order) — confirm against training.
vocab = {word: idx for idx, word in enumerate(vocab)}
# Id for unknown words. Id 0 is also the embedding's padding_idx and the id
# enumerate gave the first vocabulary entry, so all three share id 0.
vocab["<unk>"] = 0

# Inverse mapping (id -> word). Since "<unk>" is the last key with value 0,
# reverse_vocab[0] resolves to "<unk>".
reverse_vocab = {idx: word for word, idx in vocab.items()}

# Tokenizer: turns a string into a list of vocabulary ids.
def tokenizer(text):
    """Return the vocabulary id of every whitespace-separated word in *text*.

    Words missing from ``vocab`` are mapped to ``vocab["<unk>"]``.
    """
    ids = []
    for word in text.split():
        ids.append(vocab.get(word, vocab["<unk>"]))
    return ids

# Detokenizer: turns a sequence of ids back into a string.
def detokenizer(indices):
    """Join the words for *indices* with single spaces.

    Ids absent from ``reverse_vocab`` are rendered as ``"<unk>"``.
    """
    words = (reverse_vocab.get(token_id, "<unk>") for token_id in indices)
    return " ".join(words)


# Load the bot's model. model.pth may contain either a whole pickled nn.Module
# or just a state_dict; a state_dict has no .eval(), so it is loaded into the
# TransformerModel instance `model` constructed earlier in the notebook.
# torch.load unpickles the file, so only load checkpoints you trust.
loaded_checkpoint = torch.load("model.pth", map_location=torch.device("cpu"))
if isinstance(loaded_checkpoint, torch.nn.Module):
    model = loaded_checkpoint
else:
    # NOTE(review): assumes model.pth matches the architecture/hyperparameters
    # used to build `model` — load_state_dict raises on mismatched keys/shapes.
    model.load_state_dict(loaded_checkpoint)
model.eval()

# Load the vocabulary (JSON is UTF-8 by spec; don't depend on the locale default).
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)
# Word -> position in the loaded JSON.
# NOTE(review): assumes vocab.json is a list of words in id order — confirm.
vocab = {word: idx for idx, word in enumerate(vocab)}
# Unknown-word id; shares id 0 with the first entry and the embedding padding_idx.
vocab["<unk>"] = 0
# Id -> word; "<unk>" is the last key with value 0, so reverse_vocab[0] == "<unk>".
reverse_vocab = {idx: word for word, idx in vocab.items()}

# Tokenizer helpers
def tokenizer(text):
    """Map each whitespace-separated word of *text* to its id, falling back to ``<unk>``."""
    words = text.split()
    return [vocab.get(w, vocab["<unk>"]) for w in words]

def detokenizer(indices):
    """Rebuild a space-separated string from ids; ids not in ``reverse_vocab`` become ``"<unk>"``."""
    pieces = []
    for token_id in indices:
        pieces.append(reverse_vocab.get(token_id, "<unk>"))
    return " ".join(pieces)

# Text generation
def _sample_next_token(logits, temperature, top_p):
    """Pick one token id from a 1-D *logits* tensor via temperature + nucleus (top-p) sampling.

    A non-positive ``temperature`` falls back to greedy decoding (argmax)
    instead of dividing by zero. Tokens are sorted by descending probability and
    the kept set is the shortest prefix whose cumulative probability exceeds
    ``top_p``; if no prefix exceeds it (e.g. ``top_p >= 1``) all tokens are kept.
    """
    if temperature <= 0:
        return int(torch.argmax(logits).item())
    sorted_logits, sorted_indices = torch.sort(logits / temperature, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # cumulative_probs is non-decreasing, so the entries <= top_p form a prefix;
    # keep it plus the first token that pushes the total past top_p.
    keep = min(int((cumulative_probs <= top_p).sum().item()) + 1, sorted_logits.numel())
    probs = F.softmax(sorted_logits[:keep], dim=-1)
    choice = torch.multinomial(probs, num_samples=1).item()
    return int(sorted_indices[choice].item())


def generate_text(start_sequence, max_len=50, temperature=1.0, top_p=0.9):
    """Autoregressively extend *start_sequence* with the global ``model``.

    At each step the whole sequence so far is fed to the model as both source
    and target; the logits at the last position are sampled with temperature
    and top-p, and the sampled id is appended. Generation stops after
    ``max_len`` new tokens, or once the ``<eos>`` id (if ``vocab`` has one) is
    sampled; that ``<eos>`` token is kept in the output.

    Args:
        start_sequence: prompt text, split on whitespace by ``tokenizer``.
        max_len: maximum number of tokens to append.
        temperature: logit divisor; values <= 0 mean greedy decoding.
        top_p: nucleus-sampling probability mass threshold.

    Returns:
        The prompt plus the generated tokens, detokenized to one string.

    Raises:
        ValueError: if *start_sequence* contains no words.
    """
    generated = tokenizer(start_sequence)
    if not generated:
        raise ValueError("start_sequence must contain at least one word")
    eos_id = vocab.get("<eos>", -1)

    for _ in range(max_len):
        input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(0)
        with torch.no_grad():
            output = model(input_seq, input_seq)
        next_id = _sample_next_token(output[0, -1, :], temperature, top_p)
        generated.append(next_id)
        if next_id == eos_id:
            break
    return detokenizer(generated)

# Telegram bot handlers
def start(update: Update, context: CallbackContext) -> None:
    """Handle ``/start``: greet the user and show the one-button reply keyboard.

    The handler can be invoked for updates whose ``update.message`` is ``None``
    (e.g. edited messages), so it replies via ``update.effective_message`` and
    ignores updates that carry no message at all.
    """
    message = update.effective_message
    if message is None:
        return
    keyboard = [["Сгенерировать текст"]]
    # Ask clients to hide the keyboard after one press and to size it to fit.
    reply_markup = ReplyKeyboardMarkup(keyboard, one_time_keyboard=True, resize_keyboard=True)
    message.reply_text("Привет! Нажми кнопку, чтобы сгенерировать текст.", reply_markup=reply_markup)

def handle_message(update: Update, context: CallbackContext) -> None:
    """Answer a plain-text message.

    The keyboard button text triggers generation from the prompt "adventurer";
    any other text gets a hint to press the button. Generation failures are
    logged and reported to the user instead of leaving the chat without a reply.
    The reply goes through ``update.effective_message`` because
    ``update.message`` is ``None`` for e.g. edited messages.
    """
    message = update.effective_message
    if message is None:
        return
    if message.text != "Сгенерировать текст":
        message.reply_text("Нажмите кнопку, чтобы сгенерировать текст.")
        return
    try:
        generated_text = generate_text("adventurer", max_len=30, temperature=0.7, top_p=0.9)
    except Exception:  # handler boundary: log the traceback and tell the user
        import logging

        logging.getLogger(__name__).exception("Text generation failed")
        message.reply_text("Не удалось сгенерировать текст. Попробуйте ещё раз.")
        return
    message.reply_text(f"Сгенерированный текст: {generated_text}")

# Bot entry point
def main(token=None):
    """Run the bot with long polling until the process is stopped.

    Args:
        token: Telegram bot token. When omitted it is read from the
            ``TELEGRAM_BOT_TOKEN`` environment variable, so the secret does
            not have to be hard-coded in the notebook.

    Raises:
        RuntimeError: if no token is given and the variable is unset or empty.
    """
    import os

    token = token or os.environ.get("TELEGRAM_BOT_TOKEN")
    if not token:
        raise RuntimeError(
            "No Telegram bot token: pass main(token=...) or set TELEGRAM_BOT_TOKEN"
        )
    updater = Updater(token)
    dp = updater.dispatcher
    dp.add_handler(CommandHandler("start", start))
    # Plain text only; commands are handled by CommandHandler.
    dp.add_handler(MessageHandler(Filters.text & ~Filters.command, handle_message))
    updater.start_polling()
    # Block until a stop signal (e.g. Ctrl+C) arrives, then stop the updater.
    updater.idle()


if __name__ == "__main__":
    main()
