In [1]:
import os
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
import faiss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision.models import densenet121

In [2]:
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

from dataloader.mimic_cxr_jpg import CustomDataset

In [3]:
os.environ['TORCH_HOME'] = '/ssd_scratch/cvit'

In [4]:
def get_metrics(query_label, labels):
    
    # initialise mean hit ratio, mean reciprocal rank, and mean average precision
    MHR, MRR, MAP = [], [], []
    
    # position, rank, and flag
    pos, mrr_flag = 0, False
    
    # iterate over the neighbors
    for rank, label in enumerate(labels):

        # its a hit
        if (query_label == label).all():
            pos += 1
            MHR.append(1)
            MAP.append(pos/(rank+1))

            # its the first hit
            if not mrr_flag:
                MRR.append(pos/(rank+1))
                mrr_flag = True
        
        # its a miss
        else:
            MHR.append(0)
            MAP.append(0)
    
    MRR = MRR[0] if len(MRR) else 0
    
    return sum(MAP)/len(MAP), sum(MHR)/len(MHR), MRR

In [5]:
config = {
    'batch_size': 32,
    'num_workers': 4,
    'data_dir': '/ssd_scratch/cvit/arihanth/physionet.org/files/mimic-cxr-jpg/2.0.0/files/',
    'device': 'cuda:1',
    'hidden_dim': 1000,
}

train_dataset = CustomDataset(config, None, 'train')
train_loader  = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'])

val_dataset = CustomDataset(config, None, 'validate')
val_loader  = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])

test_dataset = CustomDataset(config, None, 'test')
test_loader  = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])

In [6]:
len(train_dataset), len(val_dataset), len(test_dataset)

(310756, 2539, 4240)

In [7]:
model = densenet121(pretrained=True)
model = model.to(config['device'])



In [8]:
# build datastore
datastore = []
for i, (img, label) in tqdm(enumerate(val_loader), total=len(val_loader)):
    out = model(img.to(config['device']))
    datastore.append(out.cpu().detach().numpy())

datastore = np.concatenate(datastore, axis=0)

100%|██████████| 80/80 [03:01<00:00,  2.27s/it]


In [9]:
index = faiss.IndexFlatL2(config['hidden_dim'])   # build the index
print(index.is_trained)

True


In [10]:
index.add(datastore)
index.ntotal

2539

In [1]:
mAP, mHR, mRR = [], [], []

with tqdm(val_loader) as pbar:
    for imgs, query_labels in pbar:
        emb = model(imgs.to(config['device'])).cpu().detach().numpy()
        D, I = index.search(emb, 5)

        labels = [[val_dataset.__getitem__(i)[1] for i in I[j][1:]] for j in range(I.shape[0])]

        for query_label, target_label in zip(query_labels, labels):
            MAP, MHR, MRR = get_metrics(query_label, target_label)
            mAP.append(MAP)
            mHR.append(MHR)
            if MRR:
                mRR.append(MRR)
        
        pbar.set_postfix({'mAP': sum(mAP)/len(mAP), 'mHR': sum(mHR)/len(mHR), 'mRR': sum(mRR)/len(mRR)})

NameError: name 'tqdm' is not defined