In [None]:
!pip install git+https://github.com/cene555/ru-clip-tiny.git

In [None]:
!gdown -O ru-clip-tiny.pkl https://drive.google.com/uc?id=1-3g3J90pZmHo9jbBzsEmr7ei5zm3VXOL

In [None]:
import csv
import cv2
import faiss
import os
import torch
import transformers

import pandas as pd
import numpy as np

from PIL import Image
from sklearn.model_selection import train_test_split
from torchvision import transforms
from transformers import BertTokenizer

from rucliptiny import RuCLIPtiny
from rucliptiny.trainer import Trainer
from rucliptiny.predictor import Predictor

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Prepare the fine-tuning data: map the competition columns onto the (image_name, text)
# schema that is written to train.csv / val.csv below and consumed by the Trainer.
# Selecting only the two needed columns also discards the 'id' column.
df = (
    pd.read_csv('ds/train.csv')
    .rename(columns={'object_img': 'image_name', 'description': 'text'})
    # Image files are expected in ds/train/ as "<object_img>.png".
    .assign(image_name=lambda frame: frame['image_name'].astype(str) + '.png')
    [['image_name', 'text']]
)
# Fixed random_state keeps the 80/20 split reproducible across runs.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
train_df.to_csv('train.csv', index=False)
val_df.to_csv('val.csv', index=False)

In [None]:
# Fine-tune RuCLIP-tiny from the downloaded pretrained checkpoint, using gradient accumulation.

# Seed torch, and numpy in case the Trainer's data pipeline draws from it (unverified).
torch.manual_seed(1576)
np.random.seed(1576)
device = torch.device('cuda')
model = RuCLIPtiny().to(device)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code. The checkpoint
# comes from a public Google Drive link, so only load it if you trust that source.
model.load_state_dict(torch.load('ru-clip-tiny.pkl', map_location=device))

trainer = Trainer(
    train_dataframe='train.csv',
    train_dir='ds/train/',
    # val.csv is a split of ds/train.csv, so its images also live in ds/train/.
    val_dataframe='val.csv',
    val_dir='ds/train/',
    # presumably gradients are accumulated over 16 batches of 64 per optimizer step — TODO confirm
    train_batch_size=64, grad_accum=16
)
model = trainer.train_model(model, epochs_num=100, device=device, verbose=2)
torch.save(model.state_dict(), 'model.pkl')

In [None]:
# Prepare the test data:
#   texts / text_ids          — cleaned descriptions and their IDs, in CSV order;
#   image_patches / image_ids — image file paths and their IDs, in sorted file-name order.
texts = []
text_ids = []
# `with` closes the file; newline='' is what the csv module expects for correct quoting handling.
with open('ds/test.csv', newline='', encoding='utf-8') as test_file:
    for row in csv.DictReader(test_file):
        # Drop empty tokens (from repeated spaces) and any token containing '/',
        # then keep at most the first 75 characters of the cleaned description.
        words = [word for word in row['description'].split(' ') if word and '/' not in word]
        texts.append(' '.join(words)[:75])
        text_ids.append(row['id'])

image_patches = []  # full image file paths (name kept because later cells use it)
image_ids = []
images_path = 'ds/test/'
# os.listdir order is arbitrary; sorting makes image positions reproducible across runs.
for image_file in sorted(os.listdir(images_path)):
    image_patches.append(images_path + image_file)
    # The image ID is the file name without its extension.
    image_ids.append(image_file.rsplit('.', 1)[0])

In [None]:
# Load the fine-tuned weights for inference: move to the GPU, switch to eval mode,
# and freeze all parameters so no gradients are tracked.
model = RuCLIPtiny()
model.load_state_dict(torch.load('model.pkl', map_location='cuda'))
model = model.to('cuda')
model.eval()
model.requires_grad_(False)
# Release cached GPU memory left over from earlier cells.
torch.cuda.empty_cache()

predictor = Predictor()

In [None]:
# Encode the cleaned test descriptions and the test images with the fine-tuned model.
# Both encoders run on the GPU via the rucliptiny Predictor.
text_embeddings = predictor.prepare_text_features(model, texts, device='cuda')
images_embeddings = predictor.prepare_images_features(model, image_patches, device='cuda')

In [None]:
# Build a FAISS 'Flat' (exhaustive) index over the image embeddings. With METRIC_INNER_PRODUCT,
# search scores are dot products.
# NOTE(review): dot product equals cosine similarity only if the Predictor L2-normalises its
# features — confirm before interpreting scores as cosines.
# FAISS requires a float32, C-contiguous array.
items_embeddings = np.ascontiguousarray(images_embeddings.cpu().numpy(), dtype=np.float32)
# Take the dimensionality from the data instead of hard-coding it (previously 768).
faiss_index = faiss.index_factory(items_embeddings.shape[1], 'Flat', faiss.METRIC_INNER_PRODUCT)
faiss_index.add(items_embeddings)

In [None]:
# For every text, retrieve its most similar images.
# similarity_dict[text_pos] is a list of (image_pos, score) pairs in the order FAISS returns them
# (best first).
# k is clamped to the index size. When k exceeds ntotal, FAISS pads results with label -1,
# which would later be read as image_ids[-1] (the last image). Any -1 is also filtered out.
# max(1, ...) because FAISS requires k > 0.
top_k = max(1, min(900, faiss_index.ntotal))
query_embeddings = np.ascontiguousarray(text_embeddings.cpu().numpy(), dtype=np.float32)
# One batched search for all texts instead of one call per text.
similarities, indexes = faiss_index.search(query_embeddings, top_k)
similarity_dict = {
    text_pos: [
        (int(image_pos), float(score))
        for image_pos, score in zip(row_indexes, row_scores)
        if image_pos >= 0
    ]
    for text_pos, (row_indexes, row_scores) in enumerate(zip(indexes, similarities))
}

In [None]:
# Build text->image and image->text similarity lookups, then assign images to texts one-to-one.

# Text position -> {image position: similarity}, keeping FAISS order (best image first).
# Label -1 (FAISS padding for missing results) is skipped so it can never become image_ids[-1].
t2i_dict = {}
for t, t_data in similarity_dict.items():
    t2i_dict[t] = {k: v for k, v in t_data if k >= 0}

# Image position -> {text position: similarity}: the transpose of t2i_dict.
i2t_dict = {}
for t, t_data in t2i_dict.items():
    for i, sim in t_data.items():
        i2t_dict.setdefault(i, {})[t] = sim

# Order each image's candidate texts from most to least similar (stable for ties).
for i, i_data in i2t_dict.items():
    i2t_dict[i] = dict(sorted(i_data.items(), key=lambda item: item[1], reverse=True))

# Final predictions by repeated passes. Each still-unmatched image takes its most similar
# unmatched text for which it is that text's best remaining (still unmatched) image.
# A pass that matches nothing leaves the state unchanged, so later passes would loop forever.
# This happens, e.g., with more images than texts, or when all of an image's candidate texts
# are already taken. In that case, stop.
matches = {}
while i2t_dict:
    matched_in_pass = False
    for i in list(i2t_dict):
        for t, sim in i2t_dict[i].items():
            if t not in t2i_dict:
                continue  # text already assigned to another image
            remainder = [v for k, v in t2i_dict[t].items() if k in i2t_dict]
            if max(remainder) == sim:
                matches[i] = t
                del t2i_dict[t]
                del i2t_dict[i]
                matched_in_pass = True
                break
    if not matched_in_pass:
        break

# Replace positional indexes with IDs: rows of [text id, image id], in match order.
result = [[text_ids[t], image_ids[i]] for i, t in matches.items()]

# Write the submission file (header "id,object_img", '\n' line endings).
with open('solution.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(['id', 'object_img'])
    writer.writerows(result)