In [1]:
from tqdm import tqdm
import numpy as np
from glob import glob
import torch
import tensorflow as tf
import faiss

import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

2023-09-19 15:02:13.183391: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [1]:
import faiss 
import numpy as np

In [2]:
embeddings = np.arange(10).reshape(-1 ,1).astype('float32')
embeddings

array([[0.],
       [1.],
       [2.],
       [3.],
       [4.],
       [5.],
       [6.],
       [7.],
       [8.],
       [9.]], dtype=float32)

In [3]:
index = faiss.IndexFlatL2(1)
index.add(embeddings)
index.ntotal

10

In [4]:
D, I = index.search(embeddings, 3)

In [5]:
D.shape, I.shape

((10, 3), (10, 3))

In [10]:
for num, (d, i) in enumerate(zip(D, I)):
    print(f'query: {num}')
    for val, dist in zip(embeddings[i], d):
        print(f'{val}: {dist}')
    print()

query: 0
[0.]: 0.0
[1.]: 1.0
[2.]: 4.0

query: 1
[1.]: 0.0
[0.]: 1.0
[2.]: 1.0

query: 2
[2.]: 0.0
[1.]: 1.0
[3.]: 1.0

query: 3
[3.]: 0.0
[2.]: 1.0
[4.]: 1.0

query: 4
[4.]: 0.0
[3.]: 1.0
[5.]: 1.0

query: 5
[5.]: 0.0
[4.]: 1.0
[6.]: 1.0

query: 6
[6.]: 0.0
[5.]: 1.0
[7.]: 1.0

query: 7
[7.]: 0.0
[6.]: 1.0
[8.]: 1.0

query: 8
[8.]: 0.0
[7.]: 1.0
[9.]: 1.0

query: 9
[9.]: 0.0
[8.]: 1.0
[7.]: 4.0



In [2]:
def get_metrics(query_label, labels):
    
    # initialise mean hit ratio, mean reciprocal rank, and mean average precision
    MHR, MRR, MAP = [], [], []
    
    # position, rank, and flag
    pos, mrr_flag = 0, False
    
    # iterate over the neighbors
    for rank, label in enumerate(labels):

        # its a hit
        if (query_label == label).all():
            pos += 1
            MHR.append(1)
            MAP.append(pos/(rank+1))

            # its the first hit
            if not mrr_flag:
                MRR.append(pos/(rank+1))
                mrr_flag = True
        
        # its a miss
        else:
            MHR.append(0)
            MAP.append(0)
    
    MRR = MRR[0] if len(MRR) else 0
    
    return sum(MAP)/len(MAP), sum(MHR)/len(MHR), MRR

In [3]:
file_path = '/ssd_scratch/cvit/arihanth/physionet.org/files/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/files'
all_files = sorted(glob(f'{file_path}/**/*.tfrecord', recursive=True))
print(len(all_files))

243324


In [4]:
raw_dataset = tf.data.TFRecordDataset(all_files)

2023-09-19 15:02:37.921431: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 2454 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:02:00.0, compute capability: 7.5
2023-09-19 15:02:37.922171: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 2568 MB memory:  -> device: 1, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:03:00.0, compute capability: 7.5
2023-09-19 15:02:37.922927: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 2568 MB memory:  -> device: 2, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:82:00.0, compute capability: 7.5
2023-09-19 15:02:37.923486: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 2568 MB memory:  -> device: 3, name: NVIDIA GeForce RTX

In [5]:
d = 1376

In [6]:
index = faiss.IndexFlatL2(d)   # build the index
print(index.is_trained)

True


In [7]:
all_emb = []

for i, test in enumerate(tqdm(raw_dataset.take(-1), total=len(all_files))):
    example = tf.train.Example()
    example.ParseFromString(test.numpy())
    f_name = example.features.feature['image/id'].bytes_list.value[0].decode('utf-8').split('/')
    assert f_name[6:9] == all_files[i].split('/')[8:11], f'{f_name[6:9]} != {all_files[i].split("/")[8:11]}'
    emb = np.array(example.features.feature['embedding'].float_list.value).astype(np.float32).reshape(1, -1)
    all_emb.append(emb)
    index.add(emb) 

100%|██████████| 243324/243324 [01:23<00:00, 2905.49it/s]


In [8]:
print(index.ntotal)

243324


In [9]:
idx, k = 13245, 30
for _, test in enumerate(tqdm(raw_dataset.skip(idx).take(1))):
    example = tf.train.Example()
    example.ParseFromString(test.numpy())
    f_name = example.features.feature['image/id'].bytes_list.value[0].decode('utf-8').split('/')
    assert f_name[6:9] == all_files[idx].split('/')[8:11], f'{f_name[6:9]} != {all_files[idx].split("/")[8:11]}'
    emb = np.array(example.features.feature['embedding'].float_list.value).astype(np.float32).reshape(1, -1)
    D, I = index.search(emb, k) # sanity check

1it [00:01,  1.00s/it]


In [10]:
D.shape, I.shape, D[0], I[0]

((1, 30),
 (1, 30),
 array([  0.     , 105.08956, 221.67871, 228.2427 , 240.05519, 243.17981,
        244.0354 , 249.12021, 250.36957, 250.50313, 252.62714, 254.26682,
        255.83916, 256.03058, 263.09354, 264.16855, 267.72662, 268.4331 ,
        268.6006 , 270.77734, 271.50287, 273.61182, 274.33374, 274.4439 ,
        274.94873, 275.5473 , 275.92383, 275.99344, 276.56293, 276.8836 ],
       dtype=float32),
 array([ 13245,  13246, 153752,  81398,   6727, 201689, 123850,  30607,
        108081,  45245,  63107,  66697, 123681, 219862,   1296,  95856,
         51507, 190085, 158009, 132451,  27778, 137348, 118796,  71700,
         59210,  44520, 110218, 204313,  98682, 171858]))

In [11]:
from dataloader.mimic_cxr_emb import CustomDataset

In [21]:
my_dataset = CustomDataset('/ssd_scratch/cvit/arihanth/physionet.org/files/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/files', 'all', -1)
my_loader  = torch.utils.data.DataLoader(my_dataset, batch_size=512, shuffle=False)

In [13]:
sample_emb, sample_label = next(iter(my_loader))
print(sample_emb, sample_label)

tensor([[ 0.1257, -1.8030,  1.2843,  ..., -0.7077,  1.0860,  0.0256]]) tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])


In [14]:
D[0], I[0]

(array([  0.     , 105.08956, 221.67871, 228.2427 , 240.05519, 243.17981,
        244.0354 , 249.12021, 250.36957, 250.50313, 252.62714, 254.26682,
        255.83916, 256.03058, 263.09354, 264.16855, 267.72662, 268.4331 ,
        268.6006 , 270.77734, 271.50287, 273.61182, 274.33374, 274.4439 ,
        274.94873, 275.5473 , 275.92383, 275.99344, 276.56293, 276.8836 ],
       dtype=float32),
 array([ 13245,  13246, 153752,  81398,   6727, 201689, 123850,  30607,
        108081,  45245,  63107,  66697, 123681, 219862,   1296,  95856,
         51507, 190085, 158009, 132451,  27778, 137348, 118796,  71700,
         59210,  44520, 110218, 204313,  98682, 171858]))

In [15]:
_, query_label = my_dataset.__getitem__(I[0][0])
labels = [my_dataset.__getitem__(i)[1] for i in I[0][1:]]

query_label, labels

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
 [tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,

In [16]:
get_metrics(query_label, labels)

(0.7860952051160411, 0.896551724137931, 1.0)

In [22]:
mAP, mHR, mRR = [], [], []

with tqdm(my_loader) as pbar:
    for emb, query_labels in pbar:
        D, I = index.search(emb, 5)

        labels = [[my_dataset.__getitem__(i)[1] for i in I[j][1:]] for j in range(I.shape[0])]

        for query_label, target_label in zip(query_labels, labels):
            MAP, MHR, MRR = get_metrics(query_label, target_label)
            mAP.append(MAP)
            mHR.append(MHR)
            if MRR:
                mRR.append(MRR)
        
        pbar.set_postfix({'mAP': sum(mAP)/len(mAP), 'mHR': sum(mHR)/len(mHR), 'mRR': sum(mRR)/len(mRR)})

100%|██████████| 476/476 [21:44<00:00,  2.74s/it, mAP=0.221, mHR=0.258, mRR=0.812]


In [None]:
# not batched, takes too long

mAP, mHR, mRR = [], [], []

with tqdm(all_emb) as pbar:
    for emb in pbar:
        D, I = index.search(emb, 5)

        _, query_label = my_dataset.__getitem__(I[0][0])
        labels = [my_dataset.__getitem__(i)[1] for i in I[0][1:]]

        MAP, MHR, MRR = get_metrics(query_label, labels)
        mAP.append(MAP)
        mHR.append(MHR)
        if MRR:
            mRR.append(MRR)
        
        pbar.set_postfix({'mAP': sum(mAP)/len(mAP), 'mHR': sum(mHR)/len(mHR), 'mRR': sum(mRR)/len(mRR)})