In [5]:
import sqlite3
import os
import pandas as pd
import numpy as np
import warnings
from collections import Counter
warnings.filterwarnings('ignore')

# Directory prefix that is prepended to image file names before opening them.
PATH = 'project/static/images/'

# One SQLite connection and cursor shared by every query cell below.
connection = sqlite3.connect('instance/db.sqlite')
cur = connection.cursor()

In [109]:
# Load every photo row and expand its clothes / color labels into one-hot columns.
photo_rows = cur.execute('''SELECT id, img_path, clothes, color FROM photo''').fetchall()
columns = ['id', 'img_path', 'clothes', 'color']
data = pd.DataFrame(columns=columns, data=photo_rows)

# `clothes` holds a comma-separated string of label ids; split it into a list
# of *strings* (whitespace-stripped so "20, 6" and "20,6" behave the same).
data['clothes'] = data.clothes.apply(lambda x: [label.strip() for label in x.split(',')])

# One indicator column per clothes label 0..20.
# The labels in the list are strings, so the membership test must compare
# against str(i); comparing the int directly never matches and yields all zeros.
for i in range(21):  # NOTE(review): 21 is presumably the number of clothes classes - confirm
    data[f'{i}_clothes'] = data.clothes.apply(lambda x, label=str(i): int(label in x))

# One indicator column per known color name (one name per line in colors.txt).
# The context manager closes the file handle deterministically.
with open('img_proessing/colors.txt') as colors_file:
    colors = [line.strip() for line in colors_file]
for i in colors:
    data[i+'_color'] = data.color.apply(lambda x, name=i: int(name == x))

In [111]:
from sklearn.metrics.pairwise import cosine_distances
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, ImageShow

# Pretrained ResNet-152 used purely as a frozen feature extractor.
resnet = models.resnet152(pretrained=True)
# Freeze every parameter in one call: no gradients are needed.
resnet.requires_grad_(False)
# Keep all child modules except the last one, so the forward pass returns the
# output of the layer just before it.
backbone_layers = list(resnet.children())[:-1]
resnet = torch.nn.Sequential(*backbone_layers)


# Preprocessing pipeline: PIL image -> float tensor -> per-channel normalisation.
# No resize/crop step is applied, so images are fed at their native resolution.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_to_tensor, _normalize])

def process_image(image_path):
    """Load an image and turn it into a single-item batch tensor.

    Args:
        image_path: File name relative to the module-level ``PATH`` prefix.

    Returns:
        The result of ``transform`` applied to the RGB image, with an extra
        leading dimension added via ``unsqueeze(0)``.
    """
    rgb_image = Image.open(PATH + image_path).convert('RGB')
    return transform(rgb_image).unsqueeze(0)


def get_resnet_embedding(image_path):
    """Compute the truncated-ResNet feature vector for one image.

    Args:
        image_path: File name relative to ``PATH`` (forwarded to ``process_image``).

    Returns:
        A NumPy array: the network output with all size-1 dimensions squeezed out.
    """
    batch = process_image(image_path)
    # Inference mode for the model, and no autograd bookkeeping for the forward pass.
    resnet.eval()
    with torch.no_grad():
        features = resnet(batch)
    return features.squeeze().numpy()


In [112]:
# Embed every image and store each vector as a one-row nested list
# ([[f0, f1, ...]]), i.e. already in the 2-D layout pairwise metrics expect.
data['embedings'] = data['img_path'].apply(
    lambda img_path: [get_resnet_embedding(img_path).tolist()]
)

In [None]:
# Pairwise cosine-distance matrix between all image embeddings,
# indexed and labelled by photo id.
columns = data.id.unique()

# `emb` was referenced here without being defined in any visible cell;
# build it explicitly as an id -> embedding lookup so the cell runs on a fresh kernel.
emb = data.drop_duplicates(subset='id').set_index('id')['embedings']

if len(columns):
    # Each stored embedding is a one-row nested list; flatten it to 1-D and stack
    # all of them into an (n_images, n_features) matrix.
    embedding_matrix = np.array([np.ravel(vector) for vector in emb.loc[columns]], dtype=float)
    # A single vectorised call replaces the O(n^2) Python loop of 1x1 distance
    # computations, and avoids chained-indexing assignment (`df[a][b] = ...`),
    # which pandas does not guarantee to write through to the frame.
    dist_df = pd.DataFrame(cosine_distances(embedding_matrix), index=columns, columns=columns)
else:
    dist_df = pd.DataFrame(index=columns, columns=columns)
dist_df

In [230]:
# Cache the distance matrix on disk; the row index is written under the header 'img_id'.
dist_df.to_csv(path_or_buf='dist.csv', index_label='img_id', sep=';')

In [7]:
# Interaction history of user 3, split by state: 1 -> liked, 2 -> disliked.
user_data = cur.execute('''SELECT img_id, state, clothes, date_time FROM interaction WHERE user_id = 3 ''').fetchall()
# Keep only the last ten image ids of each kind, in the order the query returned them.
user_likes = [row[0] for row in user_data if row[1] == 1][-10:]
user_dislikes = [row[0] for row in user_data if row[1] == 2][-10:]
user_likes, user_dislikes

([1077, 788, 611, 538, 959, 1086, 818, 935, 1036, 1002],
 [617, 1003, 926, 924, 595, 858, 895, 1081, 592, 765])

In [9]:
# Reload the cached distance matrix. The index comes from the 'img_id' column;
# the column labels come from the CSV header, so they are looked up as strings.
dist_df = pd.read_csv(filepath_or_buffer='dist.csv', index_col='img_id', sep=';')
dist_df.loc[:, '532']

img_id
532     0.000000
533     0.090881
534     0.164869
535     0.127992
536     0.117409
          ...   
1148    0.183338
1149    0.147591
1150    0.124968
1151    0.116861
1152    0.136420
Name: 532, Length: 621, dtype: float64

In [None]:
# Recommend images close to the user's recent likes: collect the 10 nearest
# neighbours of every liked image, then show the most frequently collected ones.
reco = []
for i in user_likes:
    # After the CSV round-trip the column labels are strings, while a freshly
    # computed matrix has integer labels; accept both instead of raising KeyError.
    column = i if i in dist_df.columns else str(i)
    if column not in dist_df.columns:
        continue  # liked image has no entry in the distance matrix
    # Position 0 of the sorted distances is the image itself (distance 0), so skip it.
    reco += dist_df[column].sort_values().index[1:11].to_list()

# NOTE(review): the top 7 are taken *before* filtering, so fewer than 7 images
# may be shown when some of them were already rated - confirm this is intended.
for i, _ in sorted(Counter(reco).items(), key=lambda x: -x[1])[:7]:
    if i not in user_likes and i not in user_dislikes:
        im = data.loc[data['id'] == i, 'img_path'].iloc[0]
        ImageShow.show(Image.open(PATH+im))

In [8]:
# WARNING: destructive. The DELETE has no WHERE clause, so it removes every row
# of the `interaction` table, and the commit below makes that permanent.
# Any later cell that reads `interaction` will see an empty table after this runs;
# skip this cell unless a full reset of the interaction history is really intended.
connection.executescript('''DELETE FROM interaction''')
connection.commit()
# connection.close()

In [12]:
from PIL import Image, ImageShow
def reco1(user):
    """Show images that are far away from the user's recent dislikes.

    For each of the user's last ten disliked images (state 2), the five images
    with the largest distance in ``dist_df`` are collected. Every collected
    candidate is displayed, most frequently collected first.

    Args:
        user: Value matched against ``interaction.user_id``.

    Returns:
        The ``(id, img_path, clothes)`` row of the last candidate displayed,
        or ``None`` when there is nothing to recommend (no dislikes, or no
        candidate found in ``photo``).
    """
    # Bound parameters instead of f-string interpolation: no SQL injection
    # and no quoting problems, whatever `user` contains.
    user_data = cur.execute(
        '''SELECT img_id, state FROM interaction WHERE user_id = ?''',
        (user,),
    ).fetchall()
    user_dislikes = [img_id for img_id, state in user_data if state == 2][-10:]

    reco = []
    for img_id in user_dislikes:
        # Column labels are strings after loading dist.csv and integers for a
        # freshly computed matrix; accept both.
        column = img_id if img_id in dist_df.columns else str(img_id)
        if column not in dist_df.columns:
            continue  # disliked image has no entry in the distance matrix
        reco += dist_df[column].sort_values().index[-5:].to_list()

    # Stays None when the loop never finds a photo; the original code raised
    # UnboundLocalError at `return` in that case.
    last_shown = None
    for candidate_id, _ in sorted(Counter(reco).items(), key=lambda x: -x[1]):
        photo_row = cur.execute(
            '''SELECT id, img_path, clothes FROM photo WHERE id = ?''',
            (int(candidate_id),),  # plain int: sqlite3 cannot bind NumPy integer types
        ).fetchone()
        if photo_row is None:
            continue  # id present in the distance matrix but missing from `photo`
        last_shown = photo_row
        ImageShow.show(Image.open(PATH + photo_row[1]))

    return last_shown
# Run the dislike-based recommender for user 3; the returned row is the cell output.
reco1(user=3)

(752, 'image10467.png', '20,6,15')

In [15]:
# List every interaction together with the photo it refers to.
cur.execute('''SELECT photo.id, photo.img_path, interaction.user_id, interaction.date_time FROM photo 
            JOIN interaction ON photo.id = interaction.img_id ''')
cur.fetchall()

[(620, 'image10438.png', 5, '2024-05-14 17:09:38.499397'),
 (685, 'image1019.png', 5, '2024-05-14 17:09:40.878776'),
 (715, 'image10338.png', 5, '2024-05-14 17:09:46.827129'),
 (1041, 'image10420.png', 5, '2024-05-14 17:09:52.593562'),
 (593, 'image10400.png', 5, '2024-05-14 17:09:54.368615'),
 (1111, 'image10222.png', 5, '2024-05-14 17:10:01.768712'),
 (764, 'image1036.png', 5, '2024-05-14 17:10:08.789266'),
 (746, 'image10507.png', 5, '2024-05-14 17:10:19.303854'),
 (595, 'image1078.png', 5, '2024-05-14 17:10:25.941850'),
 (955, 'image10323.png', 5, '2024-05-14 17:10:31.943794'),
 (891, 'image10046.png', 5, '2024-05-14 17:10:36.384619'),
 (1048, 'image10151.png', 5, '2024-05-14 17:10:51.135970'),
 (1109, 'image10395.png', 5, '2024-05-14 17:10:53.515902'),
 (692, 'image1018.png', 5, '2024-05-14 17:10:55.912405'),
 (622, '94ea2574-image115.png', 5, '2024-05-14 17:10:57.493350'),
 (606, 'image10799.png', 5, '2024-05-14 17:11:50.289201'),
 (635, 'image10388.png', 5, '2024-05-14 17:48:25.