In [6]:
import os
import sys
src_path = os.path.split(os.getcwd())[0]
sys.path.insert(0, src_path)

import json
import random
import glob

import pandas as pd
import numpy as np
import torch
import torchvision.datasets as datasets
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score, balanced_accuracy_score

from PIL import Image

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop, InterpolationMode

import matplotlib.pyplot as plt

import clip_.clip as clip
from clip_.clip import _transform
from clip_.model import CLIPGeneral

import training.zeroshot_data_crx as zeroshot_data
import shutil

# Set the GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

In [7]:
device_count = torch.cuda.device_count()
print("Number of available GPUs:", device_count)
device_ids = list(range(device_count))
print("Device IDs:", device_ids)

Number of available GPUs: 4
Device IDs: [0, 1, 2, 3]


In [8]:
checkpoint_paths = [
    '/system/user/publicdata/CLOOB/paper/checkpoints_icml22/yfcc/clip_rn50_yfcc_epoch_28.pt',
    '/system/user/publicdata/CLOOB/paper/checkpoints_icml22/yfcc/clip_rn50x4_yfcc_epoch_28.pt',
    '/system/user/publicdata/CLOOB/paper/checkpoints_icml22/yfcc/clip_rn101_yfcc_epoch_28.pt',
    '/system/user/publicdata/CLOOB/paper/checkpoints_icml22/yfcc/cloob_rn50_yfcc_epoch_28.pt',
    '/system/user/publicdata/CLOOB/paper/checkpoints_icml22/yfcc/cloob_rn50x4_yfcc_epoch_28.pt',
    '/system/user/publicdata/CLOOB/paper/checkpoints_icml22/yfcc/cloob_rn101_yfcc_epoch_28.pt',
]

In [9]:
def load_checkpoint(checkpoint_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path)
    model_config_file = os.path.join('/system/user/hagenede/MIMIC_CXR_bachelor/Cloob_zero/cloob/src/training/model_configs', checkpoint['model_config_file'])

    print("Device is", device)

    # Load model config
    assert os.path.exists(model_config_file)
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    model = CLIPGeneral(**model_info)
    preprocess = _transform(model.visual.input_resolution, is_train=False)

    # Load model state dictionary
    sd = checkpoint["state_dict"]
    sd = {k[len('module.'):]: v for k, v in sd.items()}  # Remove 'module.' prefix from keys for DataParallel
    if 'logit_scale_hopfield' in sd:
        sd.pop('logit_scale_hopfield', None)
    model.load_state_dict(sd)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, device_ids=device_ids)
        print("Device IDs:", device_ids)
        print('model')
        #next(model.parameters()).device print
        print(model)
        print('model.parameters().device')
        print(next(model.parameters()).device)
        print('model.module.visual.conv1.parameters().device')
        print(next(model.module.visual.conv1.parameters()).device)





    model.to(device)
    model.eval()

    return model, preprocess, device

In [10]:
def load_checkpoint(checkpoint_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path)
    model_config_file = os.path.join('/system/user/hagenede/MIMIC_CXR_bachelor/Cloob_zero/cloob/src/training/model_configs', checkpoint['model_config_file'])

    print("Device is", device)

    # Load model config
    assert os.path.exists(model_config_file)
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)
    model = CLIPGeneral(**model_info)
    preprocess = _transform(model.visual.input_resolution, is_train=False)

    # Load model state dictionary
    sd = checkpoint["state_dict"]
    sd = {k[len('module.'):]: v for k, v in sd.items()}  # Remove 'module.' prefix from keys for DataParallel
    if 'logit_scale_hopfield' in sd:
        sd.pop('logit_scale_hopfield', None)
    model.load_state_dict(sd)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        model.to(device)
        print("Device IDs:", list(range(torch.cuda.device_count())))
    else:
        model.to(device)

    model.eval()

    return model, preprocess, device


In [11]:
load_checkpoint(checkpoint_paths[0])


Device is cuda
Using 4 GPUs!
Device IDs: [0, 1, 2, 3]


(DataParallel(
   (module): CLIPGeneral(
     (visual): ModifiedResNet(
       (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
       (relu): ReLU(inplace=True)
       (layer1): Sequential(
         (0): Bottleneck(
           (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
           (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
           (conv2): Conv2d(64, 6

In [32]:
def zero_shot_classifier(model, classnames, templates, device):
    with torch.no_grad():
        zeroshot_weights = []
        tokenizer = clip.tokenize  # Move tokenizer definition outside the loop
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = tokenizer(texts).to(device)  # tokenize and move to CUDA device
            class_embeddings = model.module.encode_text(texts)  # Access the encode_text method from model.module
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


In [13]:
def run(model, classifier, dataloader, device, accuracy_metric):
    with torch.no_grad():
        all_logits = []
        all_targets = []
        for images, target in tqdm(dataloader):
            images = images.to(device)
            target = target.to(device)

            # predict
            image_features = model.module.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = image_features @ classifier

            all_logits.append(logits.cpu())
            all_targets.append(target.cpu())

        all_logits = torch.cat(all_logits).numpy()
        all_targets = torch.cat(all_targets).numpy()

        acc = accuracy_metric(all_targets, all_logits.argmax(axis=1)) * 100.0
        return acc

In [14]:
classnames = zeroshot_data.classes_CRX_one
prompt_templates = zeroshot_data.illness_templates_one

In [17]:
from torch.utils.data import Dataset

In [18]:
class CustomImageDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        row = self.dataframe.iloc[index]
        image_path = row['jpg_path']
        image = Image.open(image_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        label = row['Finding']
        return image, label


In [19]:
df_jpg = pd.read_csv('/system/user/publicdata/MIMIC_CXR/hageneder/jpg_path_fingings.csv')


In [20]:
df_jpg

Unnamed: 0,study_id,jpg_path,Finding
0,50414267,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,0
1,50414267,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,0
2,53189527,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,0
3,53189527,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,0
4,53911762,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,0
...,...,...,...
377090,57132437,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,0
377091,57132437,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,0
377092,55368167,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,1
377093,58621812,/system/user/publicdata/MIMIC_CXR/hageneder/JP...,1


In [33]:
df_accuracy = pd.DataFrame(columns=['model', 'accuracy'])

for checkpoint_path in checkpoint_paths:
    print("Loading checkpoint:", checkpoint_path)
    model, preprocess, device = load_checkpoint(checkpoint_path)

    #data_path = '/system/user/publicdata/MIMIC_CXR/hageneder/embeddings/embedding_0_1'
    dataset = CustomImageDataset(df_test, transform=preprocess)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=4)

    model.eval()
    accuracy_metric = accuracy_score

    print("Calculating the text embeddings for all classes of the dataset", flush=True)
    classifier = zero_shot_classifier(model, classnames, prompt_templates, device)

    print("Calculating the image embeddings for all images of the dataset", flush=True)
    accuracy = run(model, classifier, dataloader, device, accuracy_score)
    print('Zeroshot accuracy:', accuracy.round(2))
    #add the accuracy in a pandas dataframe and a column for the checkpoint path but only the last part of the path (after the last /)
    new_row = {'model': checkpoint_path.split('/')[-1], 'accuracy': round(accuracy, 2)}
    df_accuracy = pd.concat([df_accuracy, pd.DataFrame(new_row, index=[0])], ignore_index=True)
    torch.cuda.empty_cache()

Loading checkpoint: /system/user/publicdata/CLOOB/paper/checkpoints_icml22/yfcc/clip_rn50_yfcc_epoch_28.pt
Device is cuda
Using 4 GPUs!
Device IDs: [0, 1, 2, 3]
Calculating the text embeddings for all classes of the dataset


100%|██████████| 2/2 [00:00<00:00,  7.15it/s]

Calculating the image embeddings for all images of the dataset



  0%|          | 0/1474 [00:00<?, ?it/s]