In [3]:
import sqlite3
import os
import pandas as pd
import numpy as np
import warnings
from collections import Counter
warnings.filterwarnings('ignore')

connection = sqlite3.connect('instance/db.sqlite')
cur = connection.cursor()
PATH = 'project/static/images/'

In [109]:
data = cur.execute('''SELECT id, img_path, clothes, color FROM photo''').fetchall()
columns = ['id', 'img_path', 'clothes', 'color']
data = pd.DataFrame(columns=columns, data=data)
data['clothes'] = data.clothes.apply(lambda x: x.split(','))
for i in range(21):
    data[f'{i}_clothes'] = data.clothes.apply(lambda x: int(i in x))
colors = list(map(str.strip, open('img_proessing/colors.txt').readlines()))
for i in colors:
    data[i+'_color'] =  data.color.apply(lambda x: int(i == x))

In [111]:
from sklearn.metrics.pairwise import cosine_distances
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image, ImageShow

resnet = models.resnet152(pretrained=True)
for param in resnet.parameters():
    param.requires_grad = False
resnet = torch.nn.Sequential(*(list(resnet.children())[:-1]))


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225]),
])

def process_image(image_path):
    image = Image.open(PATH + image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0)
    return image


def get_resnet_embedding(image_path):
    image = process_image(image_path)
    with torch.no_grad():
        resnet.eval()
        output = resnet(image)
    return output.squeeze().numpy() 


In [112]:
data['embedings'] = data['img_path'].apply(get_resnet_embedding)
data.embedings = data.embedings.apply(lambda x: [x.tolist()])

In [None]:
columns = data.id.unique()
dist_df = pd.DataFrame(index=columns, columns=columns)
for i in range(len(columns)):
    for j in range(i, len(columns)):
        dist_df[columns[i]][columns[j]] = cosine_distances(emb[columns[i]], emb[columns[j]])[0][0]
        dist_df[columns[j]][columns[i]] = dist_df[columns[i]][columns[j]]
dist_df

In [230]:
dist_df.to_csv('dist.csv', sep=';', index_label='img_id')

In [7]:
user_data = cur.execute('''SELECT img_id, state, clothes, date_time FROM interaction WHERE user_id = 5 ''').fetchall()
user_likes = list(map(lambda i: i[0], filter(lambda x: x[1] == 1, user_data)))[-10:]
user_dislikes = list(map(lambda i: i[0],  filter(lambda x: x[1] == 2, user_data)))[-10:]
user_likes

[1149, 708, 912]

In [10]:
dist_df = pd.read_csv('dist.csv', sep=';', index_col='img_id')
dist_df['532']

img_id
532     0.000000
533     0.090881
534     0.164869
535     0.127992
536     0.117409
          ...   
1148    0.183338
1149    0.147591
1150    0.124968
1151    0.116861
1152    0.136420
Name: 532, Length: 621, dtype: float64

In [8]:
reco = []
for i in user_likes:
    reco += dist_df[i].sort_values().index[1:11].to_list()
for i, j in sorted(Counter(reco).items(), key=lambda x: -x[1])[:7]:
    if i not in user_likes and i not in user_dislikes:
        im = data[data['id'] == i]['img_path'].to_list()[0]
        ImageShow.show(Image.open(PATH+im))
# Counter(reco)

KeyError: 1149

In [373]:
im = data[data['id'] == 995]
im

Unnamed: 0,id,img_path,clothes,color,0_clothes,1_clothes,2_clothes,3_clothes,4_clothes,5_clothes,...,red_color,orange_color,peach_color,yellow_color,green_color,blue_color,brown_color,violet_color,dark_color,embedings
463,995,image10345.png,"[0, 3, 9, 13, 15]",blue,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,"[[0.241148442029953, 0.5479977130889893, 0.108..."


In [277]:
cur.execute('''SELECT id, email, login FROM user ''').fetchall()

[(1, 'lari.x@mail.ru', 'lol'),
 (2, 'new_acc@mail.ru', 'lol'),
 (3, 'lol@mail.ru', 'lol'),
 (4, 'typalollol@gamil.cpm', 'gamilalka'),
 (5, 'new@mail.ru', 'new5')]

In [360]:
# connection.executescript('''DELETE FROM interaction''')
# connection.commit()
# connection.close()

In [375]:
def reco1(user):
    user_data = cur.execute(f'''SELECT img_id, state, clothes, date_time 
                            FROM interaction WHERE user_id = {user} ''').fetchall()
    user_likes = list(map(lambda i: i[0], filter(lambda x: x[1] == 1, user_data)))[-10:]
    user_dislikes = list(map(lambda i: i[0],  filter(lambda x: x[1] == 2, user_data)))[-10:]
    reco = []
    for i in user_likes:
        reco += dist_df[i].sort_values().index[1:7].to_list()
    reco_im_id = -1
    for i, j in sorted(Counter(reco).items(), key=lambda x: -x[1]):
        if i not in user_likes and i not in user_dislikes:
            reco_im_id = i
            break
    im_id, path, clothes = cur.execute(f'''SELECT id, img_path, clothes 
                            FROM photo WHERE id = {reco_im_id} ''').fetchone()
    clothes = cur.execute(f'''SELECT clothes, clothes_id FROM clothes WHERE clothes_id IN ({clothes})''').fetchall()
    clothes = set(clothes)
    return im_id, path, clothes


760 image10102.png {('туфли', 15), ('аксесуары', 0), ('платье', 10), ('сумка', 13)}


In [378]:
pd.read_csv('dist.csv', sep=';', index_col='img_id')

Unnamed: 0_level_0,532,533,534,535,536,537,538,539,540,541,...,1143,1144,1145,1146,1147,1148,1149,1150,1151,1152
img_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
532,0.000000,0.090881,0.164869,0.127992,0.117409,0.166112,0.110936,0.189296,0.152014,0.104161,...,0.167234,0.193522,0.173963,0.117299,0.137659,0.183338,0.147591,0.124968,0.116861,0.136420
533,0.090881,0.000000,0.134439,0.092319,0.114742,0.143979,0.101152,0.146070,0.117258,0.100623,...,0.123271,0.150577,0.131973,0.078333,0.082995,0.147360,0.117034,0.106618,0.080257,0.105393
534,0.164869,0.134439,0.000000,0.099290,0.183802,0.148955,0.133981,0.180000,0.140382,0.104938,...,0.098025,0.167492,0.118869,0.128477,0.105301,0.166930,0.191475,0.072936,0.166619,0.138721
535,0.127992,0.092319,0.099290,0.000000,0.140338,0.138159,0.086395,0.146765,0.133230,0.102183,...,0.114693,0.149887,0.101645,0.086789,0.105816,0.135424,0.141160,0.087851,0.106675,0.108290
536,0.117409,0.114742,0.183802,0.140338,0.000000,0.206305,0.083982,0.162764,0.148409,0.142742,...,0.174547,0.145667,0.124297,0.113354,0.127810,0.125042,0.132809,0.129681,0.075381,0.118373
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1148,0.183338,0.147360,0.166930,0.135424,0.125042,0.195683,0.129044,0.130707,0.148501,0.134127,...,0.186252,0.137775,0.126324,0.121037,0.115543,0.000000,0.160170,0.129809,0.140880,0.123167
1149,0.147591,0.117034,0.191475,0.141160,0.132809,0.219953,0.146317,0.172678,0.161601,0.149662,...,0.197041,0.210591,0.161311,0.135515,0.143837,0.160170,0.000000,0.160192,0.135383,0.177170
1150,0.124968,0.106618,0.072936,0.087851,0.129681,0.149962,0.129570,0.150356,0.135754,0.123838,...,0.089299,0.158433,0.078261,0.114586,0.075295,0.129809,0.160192,0.000000,0.123825,0.091584
1151,0.116861,0.080257,0.166619,0.106675,0.075381,0.164806,0.099749,0.145738,0.125891,0.122169,...,0.123948,0.144821,0.119636,0.107877,0.129634,0.140880,0.135383,0.123825,0.000000,0.124019
