In [1]:
import pyterrier as pt
# pt.init() raises if PyTerrier (and its JVM) is already started in this kernel,
# so guard it to keep this cell safe to re-run.
if not pt.started():
    pt.init()

PyTerrier 0.9.2 has loaded Terrier 5.7 (built by craigm on 2022-11-10 18:30) and terrier-helper 0.0.7

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


In [2]:
import json
import os
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from scipy import stats
from scipy.spatial import distance_matrix
from scipy.spatial.distance import cdist
from scipy.stats import spearmanr,kendalltau
from sklearn.cluster import KMeans
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split

In [3]:
import pyterrier_crs

  from .collection import imread_collection_wrapper


In [4]:
from pyterrier_crs.index import ResNetIndex
# Model name handed to ResNetIndex when the shoes index is opened.
model_name: str = "resnet101"

In [5]:
from pyterrier_crs import datasets
# NOTE(review): presumably registers the pyterrier_crs datasets so later cells
# can use them — confirm against pyterrier_crs.datasets.
datasets.setup_datasets()

## Load index

In [6]:
# Location of the prebuilt ResNet shoes index. It can be overridden with the
# CRS_SHOES_INDEX environment variable, so the notebook is not tied to one
# machine's NFS layout; the default keeps the original path.
SHOES_INDEX_DIR = os.environ.get("CRS_SHOES_INDEX", "/nfs/from_yashon/irecsys/shoes_test")
r = ResNetIndex(SHOES_INDEX_DIR, model_name)

In [7]:
# A second pyterrier_crs.datasets.setup_datasets() call was removed here: the
# datasets are already set up in cell In[5].

In [8]:
from pyterrier_crs.usersim import UserSim
import pyterrier_crs.models

In [9]:
from typing import List, Tuple
def parse_qid(qid : str) -> Tuple[int, List[int], int]:
    """Decode a simulated-user qid of the form ``u<target>[,<alt>...]-t<turn>``.

    Returns a ``(target, alternatives, turn)`` tuple:
      * target: the first comma-separated id;
      * alternatives: any further ids, in order, duplicates kept (a new list on
        every call, so callers may mutate it);
      * turn: the integer after ``t``.

    Raises ValueError if the qid does not split into exactly two parts on ``-``
    or if any id / the turn is not an integer.
    """
    id_part, turn_part = qid.split("-")
    turn = int(turn_part.replace("t", ""))
    ids = [int(token) for token in id_part.replace("u", "").split(",")]
    return ids[0], ids[1:], turn

## Model checkpoints: load the user simulator (the EGE recommender is built in a later section)

In [10]:
ege_checkpoint = "http://www.dcs.gla.ac.uk/~craigm/fcrs/model_checkpoints/ege-rl-10000.pt"
usersim_path = "http://www.dcs.gla.ac.uk/~craigm/fcrs/model_checkpoints/caption_model_shoes"

# Only the user simulator is constructed here. ege_checkpoint is consumed when
# the EGE recommender is built in the "Define CRS model" section.
usersim = UserSim(usersim_path, r)

print("Models loaded")

relative captioning is called
Models loaded


  "num_layers={}".format(dropout, num_layers))


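## MetaUserSim class: simulated user with a tolerance window

For turns up to `tolerance`, the wrapped simulator gives feedback about the true target. After that, the simulated user gives feedback about whichever alternative item (optionally including the target itself) is closest in feature space to the recommender's top-ranked item. `counter` records, per turn, how often that item was not the target.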

In [11]:
class MetaUserSim(pt.Transformer):
    """Wrap a ``UserSim`` so that, after a tolerance window, feedback may describe
    an alternative item rather than the true target.

    Incoming qids have the form ``u<target>,<alt1>,...-t<turn>`` (decoded with
    ``parse_qid``). For turns ``<= tolerance`` the inner simulator is asked about
    the true target. For later turns it is asked about the alternative nearest
    (squared Euclidean distance over ``ranker.feat``) to the rank-0 candidate of
    that query. If a query has no alternatives at all, the true target is used.
    The original qid is restored on every output row.

    Args:
        inner: the wrapped user simulator; called via ``inner.transform``.
        ranker: object exposing ``feat``, indexed by a single docid and by a list
            of docids. NOTE(review): assumed to be a row-per-item feature tensor
            — confirm.
        tolerance: last turn (inclusive) at which feedback is about the true target.
        target_as_alt: if True, the true target is also an eligible alternative.

    Attributes:
        counter: ``defaultdict(int)`` mapping turn -> number of queries whose
            feedback was about an item other than the true target. Accumulates
            across all ``transform`` calls on this instance.
    """

    def __init__(self, inner : UserSim, ranker, tolerance = 1, target_as_alt = False):
        self.inner = inner
        self.tolerance = tolerance
        self.ranker = ranker
        self.target_as_alt = target_as_alt
        # turn -> number of queries answered about a non-target item
        self.counter = defaultdict(int)
        # Mirror the inner simulator's image_name so code reading it from this
        # wrapper sees the same value without copying it on by hand.
        if hasattr(inner, "image_name"):
            self.image_name = inner.image_name

    def _nearest_alternative(self, qidgroup, alternatives):
        """Return the docid in ``alternatives`` closest to the rank-0 candidate of ``qidgroup``."""
        candidate_id = qidgroup[qidgroup["rank"] == 0].iloc[0].docid
        candidate_rep = self.ranker.feat[candidate_id]
        # Squared Euclidean distance from each alternative to the candidate;
        # argmin picks the first minimum, as min(0) did.
        distances = ((self.ranker.feat[alternatives] - candidate_rep) ** 2).sum(1)
        return alternatives[int(distances.argmin())]

    def _simulate(self, qidgroup, feedback_id, turn, qid):
        """Ask the inner simulator about ``feedback_id`` and put the original ``qid`` back on the result."""
        proxy = qidgroup.copy()
        proxy["qid"] = "u" + str(feedback_id) + "-t" + str(turn)
        result = self.inner.transform(proxy)
        result["qid"] = qid
        return result

    def transform(self, df):
        """Generate simulated feedback for every qid group in ``df``.

        Returns the concatenation of the inner simulator's outputs, one group per
        input qid. Raises ValueError (from ``pd.concat``) if ``df`` is empty.
        """
        rtr = []
        for qid, qidgroup in df.groupby("qid"):
            target, alternatives, turn = parse_qid(qid)
            if self.target_as_alt:
                alternatives.append(target)
            if turn <= self.tolerance or not alternatives:
                # inside the tolerance window (or nothing else to choose from):
                # describe the true target
                feedback_id = target
            else:
                feedback_id = self._nearest_alternative(qidgroup, alternatives)
                if feedback_id != target:
                    self.counter[turn] += 1
            rtr.append(self._simulate(qidgroup, feedback_id, turn, qid))

        return pd.concat(rtr)


In [12]:
from collections import defaultdict

## Define CRS model

In [13]:
import pyterrier_crs.models
# EGE recommender built from the checkpoint URL defined earlier. It uses the
# index `r` and the user simulator's vocabulary size, keeps the top 100 results
# and exports the image query representation.
transformer = pyterrier_crs.models.EGE(
    ege_checkpoint,
    r,
    usersim.vocabSize,
    top_K=100,
    export_image_query_rep=True,
)

In [14]:
# Feedback is about the true target for turns 1-3. From turn 4 on it is about
# the alternative nearest to the top-ranked item, with the target itself
# eligible as an alternative.
metasim = MetaUserSim(inner=usersim, ranker=transformer.ranker, tolerance=3, target_as_alt=True)

In [15]:
# Expose the wrapped simulator's image_name on the wrapper object as well.
# NOTE(review): presumably CRS_Experiment reads `image_name` from the simulator
# it is given — confirm.
metasim.image_name = metasim.inner.image_name

## Run CRS_Experiment with the MetaUserSim simulator

In [16]:
# Starting queries: each qid encodes target, alternatives and turn
# ("u<target>,<alt>...-t<turn>"), plus the docid of the initial candidate.
input_df = pd.read_csv('input_shoes_CRS_df.csv')
# Fail fast if the file does not have the columns the experiment relies on.
assert {"qid", "docno", "docid", "rank"} <= set(input_df.columns), "unexpected input schema"
input_df

Unnamed: 0,qid,docno,docid,rank
0,"u1231,2401,4511-t0",img_womens_clogs_783.jpg,3593,0
1,"u3915,1141,2482,96,2588,877,1202-t0",img_womens_sneakers_158.jpg,2273,0
2,"u2536,2119,1141-t0",img_womens_high_heels_958.jpg,173,0
3,"u2426,4098,230-t0",img_womens_sneakers_1098.jpg,2546,0
4,"u4441,2401,1021,1426,169-t0",img_womens_clogs_915.jpg,1843,0
...,...,...,...,...
185,"u4116,543,3536-t0",img_womens_pumps_220.jpg,2129,0
186,"u2361,1387,1390,2909,2361,3322-t0",img_womens_flats_1302.jpg,2930,0
187,"u1720,3039,4334,1817,3524-t0",img_womens_rain_boots_483.jpg,29,0
188,"u3098,3865,3678,4233,3865,3822,938,4577,1141,2...",img_womens_clogs_728.jpg,3331,0


In [17]:
from pyterrier.measures import *
from pyterrier_crs.display import CRS_Experiment

# Evaluate EGE over 10 simulated turns, with feedback produced by metasim.
exp_df = CRS_Experiment(
    input_df,
    [transformer],
    metasim,
    [NDCG@10, 'recip_rank', Success@1, Success@10],
    num_turns=10,
    test_batch_size=64,
    names=["EGE"],
    progress=True,
)
exp_df

100%|██████████| 3/3 [01:36<00:00, 32.29s/batch]


Unnamed: 0,name,measure,turn,value
0,EGE,Success@1,1,0.047368
1,EGE,Success@1,2,0.152632
2,EGE,Success@1,3,0.236842
3,EGE,Success@1,4,0.3
4,EGE,Success@1,5,0.405263
5,EGE,Success@1,6,0.442105
6,EGE,Success@1,7,0.478947
7,EGE,Success@1,8,0.489474
8,EGE,Success@1,9,0.521053
9,EGE,Success@1,10,0.568421


In [17]:
#exp_df.to_csv('exp_df_ege_tol3.csv', index = False)

In [18]:
# Per-turn count of queries for which MetaUserSim gave feedback about an item
# other than the true target. Only turns after the tolerance window can appear,
# and counts accumulate over every transform call made on this instance.
metasim.counter

defaultdict(int, {4: 95, 5: 109, 6: 105, 7: 115, 8: 121, 9: 126, 10: 124})