In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import torch

In [3]:
test_sample=pd.read_pickle('/Users/lison/code/hackathon_bio_ai/example_data/conch_embedding/embeddings.pkl')

In [5]:
test_sample.keys()

dict_keys(['NCBI591', 'SPA101', 'NCBI293', 'MISC19', 'NCBI672'])

In [39]:
from tqdm.auto import tqdm
from itertools import islice


def retrieve_score_dict_for_each_dataset(test_image_embedding, patches_embeddings, dataset_name ):
    """
    Iter on each dataset patches and calculate the similarity score with the embeddings

    return a dictionary containing one column with the dataset_name+patch_id : similarity_score
    """
    all_candidates_embeddings=torch.stack(patches_embeddings['features']).squeeze(1)
    sim_scores = compute_scores(test_image_embedding, all_candidates_embeddings)
    candidate_ids=[]
    for id in tqdm(range(len(patches_embeddings['features']))):
        # Create a unique indentifier.
        entry = str(dataset_name) + "_" + str(id)

        candidate_ids.append(entry)

    similarity_mapping = dict(zip(candidate_ids, sim_scores))

    return similarity_mapping


def compute_scores(emb_one, emb_two):
    """Computes cosine similarity between two vectors"""
    scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
    return scores.numpy().tolist()



def fetch_similar(image_embedding, candidate_images, top_k=5):
    """Fetches the `top_k` similar images with `image` as the query."""

    all_scores={}
    for dataset_name, patches_embeddings in candidate_images.items():

        output_simscores= retrieve_score_dict_for_each_dataset(image_embedding, patches_embeddings, dataset_name )

        all_scores.update(output_simscores)


    # Sort the mapping dictionary and return top_k candidates.
    similarity_mapping_sorted = dict(
        sorted(all_scores.items(), key=lambda x: x[1], reverse=True)
    )
    id_entries = list(similarity_mapping_sorted.keys())[:top_k]

    selected_candidate_datasets = list(map(lambda x: x.split("_")[0], id_entries))
    patches_idx = list(map(lambda x: int(x.split("_")[-1]), id_entries))
    similarity_score=list(similarity_mapping_sorted.values())[:top_k]

    results_dict={}
    for i,names in enumerate(id_entries):
        results_dict[names]={'dataset_name':selected_candidate_datasets[i],
                             'patch_id':patches_idx[i],
                             'similarity_score':similarity_score[i]}

    #if we want to output lists then use this
    # return selected_candidate_datasets, patches_idx, similarity_score

    #if prefer to retrieve a dictionnary then output this
    return results_dict


In [40]:
test_set=test_sample['NCBI672']['features'][10][0]

In [41]:
# candidates, patches, similarity_score=fetch_similar(test_set, test_sample, top_k=100)
results_dict=fetch_similar(test_set, test_sample, top_k=10)

100%|██████████| 644/644 [00:00<00:00, 252820.27it/s]
100%|██████████| 691/691 [00:00<00:00, 738787.68it/s]


100%|██████████| 157/157 [00:00<00:00, 204441.39it/s]
100%|██████████| 2422/2422 [00:00<00:00, 499611.68it/s]
100%|██████████| 2408/2408 [00:00<00:00, 454581.15it/s]


In [42]:
results_dict

{'NCBI672_10': {'dataset_name': 'NCBI672',
  'patch_id': 10,
  'similarity_score': 1.0},
 'NCBI672_536': {'dataset_name': 'NCBI672',
  'patch_id': 536,
  'similarity_score': 0.9361788034439087},
 'NCBI672_2127': {'dataset_name': 'NCBI672',
  'patch_id': 2127,
  'similarity_score': 0.9288486242294312},
 'NCBI672_1666': {'dataset_name': 'NCBI672',
  'patch_id': 1666,
  'similarity_score': 0.9286186695098877},
 'NCBI672_912': {'dataset_name': 'NCBI672',
  'patch_id': 912,
  'similarity_score': 0.92724609375},
 'NCBI672_163': {'dataset_name': 'NCBI672',
  'patch_id': 163,
  'similarity_score': 0.9261435270309448},
 'NCBI672_441': {'dataset_name': 'NCBI672',
  'patch_id': 441,
  'similarity_score': 0.9201226830482483},
 'NCBI672_1244': {'dataset_name': 'NCBI672',
  'patch_id': 1244,
  'similarity_score': 0.9145873188972473},
 'NCBI672_519': {'dataset_name': 'NCBI672',
  'patch_id': 519,
  'similarity_score': 0.9132484197616577},
 'NCBI672_768': {'dataset_name': 'NCBI672',
  'patch_id': 768,

In [15]:
candidates

['NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NCBI672',
 'NC

In [16]:
patches

[100,
 2152,
 2140,
 660,
 2304,
 1813,
 97,
 339,
 748,
 1647,
 2161,
 1103,
 91,
 391,
 12,
 1386,
 1624,
 978,
 2350,
 1312,
 1547,
 2136,
 607,
 1403,
 1332,
 2069,
 213,
 1944,
 2230,
 2131,
 366,
 69,
 1026,
 2176,
 794,
 1512,
 770,
 1586,
 398,
 1152,
 1966,
 666,
 463,
 501,
 1990,
 2360,
 735,
 917,
 1158,
 1107,
 840,
 306,
 2338,
 150,
 2263,
 473,
 2026,
 706,
 1500,
 881,
 1711,
 2061,
 1391,
 858,
 2033,
 1196,
 2392,
 886,
 2355,
 219,
 41,
 643,
 1996,
 426,
 1199,
 2085,
 2399,
 124,
 837,
 2352,
 899,
 1219,
 1027,
 2047,
 636,
 2278,
 491,
 1733,
 915,
 1628,
 767,
 693,
 496,
 234,
 2323,
 1802,
 1957,
 38,
 2125,
 1240]

In [17]:
similarity_score

[1.0,
 0.9290846586227417,
 0.9033427238464355,
 0.888172447681427,
 0.8758412599563599,
 0.8710325360298157,
 0.8646895885467529,
 0.8638283014297485,
 0.8623823523521423,
 0.861164927482605,
 0.8598006963729858,
 0.8592799305915833,
 0.8589239120483398,
 0.8566622138023376,
 0.8566621541976929,
 0.8564107418060303,
 0.8554733991622925,
 0.8545252084732056,
 0.8542725443840027,
 0.8538774251937866,
 0.8537330031394958,
 0.852072536945343,
 0.8519091606140137,
 0.8518979549407959,
 0.8516289591789246,
 0.850304126739502,
 0.849414587020874,
 0.8488729000091553,
 0.8485807776451111,
 0.8482584953308105,
 0.8475426435470581,
 0.8474446535110474,
 0.8472878336906433,
 0.8471999764442444,
 0.8446747064590454,
 0.8444631695747375,
 0.8439488410949707,
 0.8433664441108704,
 0.8430297374725342,
 0.8425158262252808,
 0.8417850732803345,
 0.8417149782180786,
 0.8416786193847656,
 0.841031014919281,
 0.8406308889389038,
 0.8404603600502014,
 0.8400219678878784,
 0.8390124440193176,
 0.8388375043

In [None]:
import matplotlib.pyplot as plt

#Go get the patches corresponding to the results dict: results_dict
#tok_k: number of most similar patches to retreive
#candidate_subset: sdata containing the patches

fig = plt.figure(figsize= (10, 10))
for i in range (0,12):

  ax = fig.add_subplot(top_k/4, 4, i+1)
  # go get the the patches for the retrieved top k most similar patches
  ax.imshow(candidate_subset["image"][i])