In [192]:
import sqlite3
import os
import pandas as pd
import numpy as np
import warnings
from collections import Counter
warnings.filterwarnings('ignore')

connection = sqlite3.connect('instance/db.sqlite')
cur = connection.cursor()
PATH = 'project/static/images/'

In [109]:
data = cur.execute('''SELECT id, img_path, clothes, color FROM photo''').fetchall()
columns = ['id', 'img_path', 'clothes', 'color']
data = pd.DataFrame(columns=columns, data=data)
data['clothes'] = data.clothes.apply(lambda x: x.split(','))
for i in range(21):
    data[f'{i}_clothes'] = data.clothes.apply(lambda x: int(i in x))
colors = list(map(str.strip, open('img_proessing/colors.txt').readlines()))
for i in colors:
    data[i+'_color'] =  data.color.apply(lambda x: int(i == x))

In [111]:
from sklearn.metrics.pairwise import cosine_distances
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, ImageShow

resnet = models.resnet152(pretrained=True)
for param in resnet.parameters():
    param.requires_grad = False
resnet = torch.nn.Sequential(*(list(resnet.children())[:-1]))


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225]),
])

def process_image(image_path):
    image = Image.open(PATH + image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0)
    return image


def get_resnet_embedding(image_path):
    image = process_image(image_path)
    with torch.no_grad():
        resnet.eval()
        output = resnet(image)
    return output.squeeze().numpy() 


In [112]:
data['embedings'] = data['img_path'].apply(get_resnet_embedding)
data.embedings = data.embedings.apply(lambda x: [x.tolist()])

In [None]:
columns = data.id.unique()
dist_df = pd.DataFrame(index=columns, columns=columns)
for i in range(len(columns)):
    for j in range(i, len(columns)):
        dist_df[columns[i]][columns[j]] = cosine_distances(emb[columns[i]], emb[columns[j]])[0][0]
        dist_df[columns[j]][columns[i]] = dist_df[columns[i]][columns[j]]
dist_df

In [230]:
dist_df.to_csv('dist.csv', sep=';', index_label='img_id')

In [316]:
user_data = cur.execute('''SELECT img_id, state, clothes, date_time FROM interaction WHERE user_id = 3 ''').fetchall()
user_likes = list(map(lambda i: i[0], filter(lambda x: x[1] == 1, user_data)))[-10:]
user_dislikes = list(map(lambda i: i[0],  filter(lambda x: x[1] == 2, user_data)))[-10:]
user_likes

[1020, 615, 663, 887, 835, 1113, 1138, 764, 1106]

In [317]:
reco = []
for i in user_likes:
    reco += dist_df[i].sort_values().index[1:11].to_list()
for i, j in sorted(Counter(reco).items(), key=lambda x: -x[1]):
    if i not in user_likes and i not in user_dislikes:
        new_ans = i
        break
new_ans

1048

In [318]:
data[data['id'] == new_ans]

Unnamed: 0,id,img_path,clothes,color,0_clothes,1_clothes,2_clothes,3_clothes,4_clothes,5_clothes,...,red_color,orange_color,peach_color,yellow_color,green_color,blue_color,brown_color,violet_color,dark_color,embedings
516,1048,image10151.png,"[0, 3, 7, 8, 17]",blue,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,"[[0.21100880205631256, 0.3498663604259491, 0.1..."


In [286]:
user_likes

[918, 820, 1079, 714, 928, 1070, 780]

In [277]:
cur.execute('''SELECT id, email, login FROM user ''').fetchall()


[(1, 'lari.x@mail.ru', 'lol'),
 (2, 'new_acc@mail.ru', 'lol'),
 (3, 'lol@mail.ru', 'lol'),
 (4, 'typalollol@gamil.cpm', 'gamilalka'),
 (5, 'new@mail.ru', 'new5')]