In [1]:
import razdel
import re
import torch
import pandas as pd
from collections import defaultdict
import telebot
from telebot import types
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image 
import torchvision
import torchvision.transforms as transforms 
import requests


class Vectorizer:
    pad = "<PAD>"
    unk = "<UNK>"
    sos = "<SOS>"
    eos = "<EOS>"

    def __init__(self, annotations):

        words_with_dot_list = annotations.apply(
            lambda x: self.tokenize(x))
        words_with_dot = words_with_dot_list.explode()
        words = words_with_dot.apply(lambda x: re.sub(r'[^\w\s]', '', x))
        self.counts = words.value_counts()
        words = list(self.counts[self.counts > 2].index)
        self.vocabulary = [Vectorizer.pad, Vectorizer.unk,
                           Vectorizer.sos, Vectorizer.eos, *words]

        text2seq = {word: i for i, word in enumerate(self.vocabulary)}
        self.padding_idx = text2seq[Vectorizer.pad]
        self.unknown_idx = text2seq[Vectorizer.unk]
        self.start_of_sentance_idx = text2seq[Vectorizer.sos]
        self.end_of_sentance_idx = text2seq[Vectorizer.eos]
        self.text2seq = defaultdict(lambda: self.unknown_idx,  text2seq)
        self.seq2text = {i: word for i, word in enumerate(self.vocabulary)}
        max_len = max(words_with_dot_list.apply(lambda x: len(x)))
        self.max_len = max_len + 2

    def __len__(self):
        return len(self.vocabulary)

    def tokenize(self, text):
        text_sub = re.sub(r'[^\w\s]', '', text)
        text_list = [_.text for _ in razdel.tokenize(text_sub.lower())]
        return text_list

    def encode(self, text):
        no_pad = [self.start_of_sentance_idx] + list(map(lambda x: self.text2seq.get(
            x, self.unknown_idx), self.tokenize(text))) + [self.end_of_sentance_idx]
        len_pad = self.max_len - len(no_pad)
        return torch.tensor(no_pad + [self.text2seq['<PAD>']]*len_pad)

    def decode(self, encode_text):
        with_pad = list(map(self.seq2text.get, encode_text.tolist(
        ) if not isinstance(encode_text, list) else encode_text))
        return ' '.join(list(filter(lambda x: x != '<PAD>', with_pad)))

all_captions_path = 'C:/Users/ivanb/Downloads/Telegram Desktop/all_captions.csv'
vectorizer = Vectorizer(pd.read_csv(all_captions_path)['translations'])

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, 512)

        self.decoder = torch.nn.Sequential(
            nn.Linear(256*8, 128*14*14),
            nn.Unflatten(1, (128, 14, 14)),

            nn.Upsample(scale_factor=2.0, mode='nearest'),

            nn.Conv2d(128, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=2.0, mode='nearest'),

            nn.Conv2d(64, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=2.0, mode='nearest'),

            nn.Conv2d(32, 16, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Upsample(scale_factor=2.0, mode='nearest'),

            nn.ConvTranspose2d(16, 8, 3, stride=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),

            nn.ConvTranspose2d(8, 3, 3, stride=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x) * 255
        return x

class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()

        self.attention_dim = attention_dim

        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)

        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):

        u_hs = self.U(features)
        w_ah = self.W(hidden_state)

        combined_states = torch.tanh(u_hs + w_ah.unsqueeze(1))

        attention_scores = self.A(combined_states)
        attention_scores = attention_scores.squeeze(2)

        alpha = F.softmax(attention_scores, dim=1)

        attention_weights = features * alpha.unsqueeze(2)
        attention_weights = attention_weights.sum(dim=1)

        return alpha, attention_weights


class AttentionDecoder(nn.Module):
    def __init__(self, embed_size: int, hidden_size: int, encoder_size: int, attention_dim: int, vocab: Vectorizer):
        super(AttentionDecoder, self).__init__()

        self.encoder = nn.Sequential(
            *list(Autoencoder().encoder.children())[:-2])

        for param in list(self.encoder.parameters())[:-1]:
            param.requires_grad = False

        self.vocab = vocab
        self.embedding = nn.Embedding(
            len(self.vocab), embed_size, padding_idx=self.vocab.padding_idx)
        self.pos_embeddings = nn.Embedding(self.vocab.max_len, embed_size)
        self.attention = Attention(encoder_size, hidden_size, attention_dim)

        self.init_h = nn.Linear(encoder_size, hidden_size)
        self.init_c = nn.Linear(encoder_size, hidden_size)

        self.lstm_cell = nn.LSTMCell(
            embed_size+encoder_size, hidden_size, bias=True)
        self.f_beta = nn.Linear(hidden_size, encoder_size)

        self.fcn = nn.Linear(hidden_size, len(self.vocab))
        self.drop = nn.Dropout(0.3)

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward_step(self, decoder_input, encoder_outputs, last_hidden, last_cell, position, device='cuda:0'):

        embeds = self.embedding(
            decoder_input) + self.pos_embeddings(torch.tensor(position).long().to(device))
        alpha, context = self.attention(encoder_outputs, last_hidden)

        lstm_input = torch.cat((embeds, context), dim=1)

        hidden, cell = self.lstm_cell(lstm_input, (last_hidden, last_cell))

        output = self.fcn(self.drop(hidden))

        return output, alpha, hidden, cell

    def forward(self, imgs, decoder_input, device='cuda:0'):

        encoder_outputs = self.encoder(imgs)
        encoder_outputs = encoder_outputs.permute(0, 2, 3, 1)
        encoder_outputs = encoder_outputs.view(
            encoder_outputs.size(0), -1, encoder_outputs.size(3))

        hidden, cell = self.init_hidden_state(encoder_outputs)

        seq_length = len(decoder_input[0])
        batch_size = decoder_input.size(0)
        num_features = encoder_outputs.size(1)

        outputs = torch.zeros(batch_size, seq_length,
                              len(self.vocab)).to(device)
        alphas = torch.zeros(batch_size, seq_length, num_features).to(device)

        for s in range(seq_length):
            output, alpha, hidden, cell = self.forward_step(
                decoder_input[:, s].to(device), encoder_outputs, hidden, cell, s)
            outputs[:, s] = output
            alphas[:, s] = alpha

        return outputs, alphas

    def greedy_decode(self, imgs, device='cpu'):

        encoder_outputs = self.encoder.to(device)(imgs)
        encoder_outputs = encoder_outputs.permute(0, 2, 3, 1)
        encoder_outputs = encoder_outputs.view(
            encoder_outputs.size(0), -1, encoder_outputs.size(3))

        batch_size = encoder_outputs.size(0)
        hidden, cell = self.init_hidden_state(encoder_outputs)
        decoder_input = torch.tensor(self.vocab.text2seq['<SOS>']).to(device)
        decoder_input = torch.LongTensor([decoder_input]).to(device)

        decoded_batch = [self.vocab.text2seq['<SOS>']]
        for i in range(self.vocab.max_len):
            decoder_output, alpha, hidden, cell = self.forward_step(
                decoder_input, encoder_outputs, hidden, cell, i, 'cpu')

            decoder_output = decoder_output.view(batch_size, -1)
            predicted_word_idx = decoder_output.argmax(dim=1)
            decoded_batch.append(predicted_word_idx.item())
            if self.vocab.seq2text[predicted_word_idx.item()] == "<EOS>":
                break
            decoder_input = predicted_word_idx
        return decoded_batch

In [3]:
transform = transforms.Compose([transforms.Resize((224, 224))]) 
CAPTION_MODEL_weights = torch.load("./Caption Image/models/captioning_flickr_attention_pos_embedding_weights.pt", map_location='cpu')


embed_size=512
attention_dim=512
encoder_size=512
hidden_size=512

CAPTION_MODEL = AttentionDecoder(embed_size=embed_size, hidden_size=hidden_size, encoder_size=encoder_size, attention_dim=attention_dim, vocab=vectorizer)
CAPTION_MODEL.load_state_dict(CAPTION_MODEL_weights)
CAPTION_MODEL.encoder.eval()
CAPTION_MODEL.eval()


AttentionDecoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [5]:
bot = telebot.TeleBot('7090775661:AAHIZsPF22Dn2_9_5ls_iIFUnNO0jq0gc28')

@bot.message_handler(commands=['start'])
def start(message):
    bot.send_message(message.chat.id, 'Привет! Я - модель глубокого обучения, созданная для описания изображений')
    markup_inline = types.InlineKeyboardMarkup()
    caption = types.InlineKeyboardButton(text = 'Описание фотографии', callback_data = 'caption')
    llm = types.InlineKeyboardButton(text = 'Спросить у Yandex-GPT', callback_data = 'llm')
    markup_inline.add(caption, llm)
    bot.send_message(message.chat.id, 'Что Вы хотите сделать?', reply_markup=markup_inline)

@bot.callback_query_handler(func=lambda call: 'caption' in call.data)   
def call_caption(call):
    action = 'описание'
    msg = bot.send_message(call.from_user.id, 'Пришлите фото')
    tensor = bot.register_next_step_handler(msg, partial(get_photo, call, action))
    
@bot.callback_query_handler(func=lambda call: 'llm' in call.data)   
def call_caption(call):
    action = 'вопрос'
    msg = bot.send_message(call.from_user.id, 'Пришлите фото')
    bot.register_next_step_handler(msg, partial(get_photo, call, action))
    
def get_photo(call, action, message):
    fileID = message.photo[-1].file_id   
    file_info = bot.get_file(fileID)
    downloaded_file = bot.download_file(file_info.file_path)
    with open("image.jpg", 'wb') as new_file:
        new_file.write(downloaded_file)
    image = torchvision.io.read_image('image.jpg') 
    img_tensor = transform(image)
    if action == 'описание':
        make_caption(call, img_tensor)
    elif action == 'вопрос':
        msg = bot.send_message(call.from_user.id, 'Какой вопрос вы хотите задать?')
        bot.register_next_step_handler(msg, partial(ask_llm, call, img_tensor))

def make_caption(call, img_tensor):
    caption_ready = CAPTION_MODEL.greedy_decode(img_tensor.float().unsqueeze(0))
    bot.send_message(call.from_user.id, vectorizer.decode(caption_ready[1:-1]))

def ask_llm(call, img_tensor, message):
    question = message.text
    caption_ready = CAPTION_MODEL.greedy_decode(img_tensor.float().unsqueeze(0))
    ask_to_llm(call, vectorizer.decode(caption_ready[1:-1]), question)
    
def ask_to_llm(call, caption, question):
    # api_key = 'AQVNw8Q_Vhgtj0fvm1g5qfS9vhGQ96UvlsVADQrb'
    api_key = 'AQVN1wuefTqX_WAQa6YAvx4mduQtzaX1OAXOcso2'

    url = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion"

    headers = {
            "Content-Type": "application/json",
            "Authorization": "Api-Key {0}".format(api_key)
    }

    json_body = {
      "modelUri": "gpt://b1gnc3j9s2j025i5f63o/yandexgpt/latest",
      "completionOptions": {
        "stream": False,
        "temperature": 0.15,
        "maxTokens": 300
      },
      "messages": [
        {
          "role": "system",
          "text": "Тебе на вход поступит вопрос и описание картинки, полученное с помощью другой нейросети Image Captioning. Постарайся максимально точно ответить на вопрос о картинке исходя из её описания. Не обязательно использовать все распознанные объекты на изображении для формирования ответа."
        },
        {
          "role": "user",
          "text": f"{question}\nОписание картинки: {caption}"
        }
      ]
    }   
    responce = requests.post(url, headers=headers, json=json_body)
    print(responce.json())
    bot.send_message(call.from_user.id, responce.json()['result']['alternatives'][0]['message']['text'])    
    
bot.polling(none_stop=True, interval=0)

{'result': {'alternatives': [{'message': {'role': 'assistant', 'text': 'Исходя из описания, можно предположить, что поезд едет по железнодорожным путям.\n\nОпределить конкретное направление движения поезда по описанию затруднительно.\n\nОбъект, обозначенный как <UNK>, не позволяет сделать более точный вывод о том, куда может ехать этот поезд.'}, 'status': 'ALTERNATIVE_STATUS_FINAL'}], 'usage': {'inputTextTokens': '80', 'completionTokens': '54', 'totalTokens': '134'}, 'modelVersion': '07.03.2024'}}
{'result': {'alternatives': [{'message': {'role': 'assistant', 'text': 'Загадочный мужчина в стильном образе держит табличку с интригующей надписью. Кто знает, что там написано? Может быть, это приглашение на секретное мероприятие или просто шутка? Снимок создаёт атмосферу тайны и вызывает желание разгадать загадку. #загадочныймужчина #секрет #интрига'}, 'status': 'ALTERNATIVE_STATUS_FINAL'}], 'usage': {'inputTextTokens': '87', 'completionTokens': '60', 'totalTokens': '147'}, 'modelVersion': 

ReadTimeout: HTTPSConnectionPool(host='api.telegram.org', port=443): Read timed out. (read timeout=25)