In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import auc as calc_auc, precision_recall_curve, average_precision_score
import random
import glob
import ngsci
import hydra
import h5py
sys.path.append("../../../03_training/transformer_on_embeddings_bag/")
from transformer_model_cls import ClsTokenTransformerClassifier

from omegaconf import OmegaConf
from omegaconf import DictConfig

## Load the saved config files that correspond to the evaluated training run

In [None]:
# Preprocessing settings: only the sub-config of the UNI 224x224 patch-embedding
# pipeline is kept.
conf_preproc = OmegaConf.load(
    "/home/ngsci/project/tuberculosis_detection/conf/preproc.yaml"
)["transformer_on_embeddings_bag"]["uni_224_224_patches"]

# Training settings saved next to the checkpoints of the run being evaluated.
conf_train = OmegaConf.load(
    "/home/ngsci/project/tuberculosis_detection/03_training/transformer_on_embeddings_bag/uni_224_224_patches_cls/runs/run_4_sqrt_sampling/conf_train.yaml"
)

In [None]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print("Device:", DEVICE)

## Load models from CV folds

In [None]:
# Root directory of the training run. The per-fold checkpoint folders are built
# by plain string concatenation (f'{results_dir}cv_0/'), so this value has to
# end with a path separator.
results_dir = conf_train.results_dir
results_dir

In [None]:
# One checkpoint directory per cross-validation fold (cv_0 ... cv_9), all located
# directly under results_dir.
(
    checkpoints_dir_cv_0,
    checkpoints_dir_cv_1,
    checkpoints_dir_cv_2,
    checkpoints_dir_cv_3,
    checkpoints_dir_cv_4,
    checkpoints_dir_cv_5,
    checkpoints_dir_cv_6,
    checkpoints_dir_cv_7,
    checkpoints_dir_cv_8,
    checkpoints_dir_cv_9,
) = (f'{results_dir}cv_{fold}/' for fold in range(10))

In [None]:
# Collect the sorted checkpoint paths (*.pt) of every fold.
# The fold directory is built directly from results_dir instead of looking up the
# numbered checkpoints_dir_cv_<i> variables through eval(): same paths, no dynamic
# code evaluation, and no hidden dependency on ten separately defined names.
# dtype=object keeps one entry per fold even when folds hold different numbers of
# checkpoints.
file_names_all_cv = np.array(
    [
        np.array(sorted(glob.glob(os.path.join(f"{results_dir}cv_{fold}/", "*.pt"))))
        for fold in range(10)
    ],
    dtype=object,
)
file_names_all_cv.shape

In [None]:
# Sanity check: show the first checkpoint path of fold 0 (the selection cells
# parse the metric values out of these file names).
file_names_all_cv[0][0]

## Load input data

In [None]:
# Directory holding one embedding file (.h5) per image.
embeddings_bag_input_path = conf_preproc.emb_dir

# Local test split, ordered by the 'image' column.
tb_df_local_test = pd.read_csv(
    conf_preproc["cv_split_dir"] + 'test_split_stratified.csv'
).sort_values('image')

# Matching embedding file for every test image: same base name, .jpg swapped for .h5.
local_test_h5_paths = [
    embeddings_bag_input_path + os.path.basename(jpg_path).replace(".jpg", ".h5")
    for jpg_path in tb_df_local_test.file_path.values
]
embeddings_bag_input_files_local_test = np.array(sorted(local_test_h5_paths))

embeddings_bag_input_files_local_test.shape

## Look for best models based on ROC AUC or PR AUC

#### AUC

In [None]:
# Keep the `select` checkpoints with the highest AUC from every fold.
# The AUC is read from the checkpoint file name: 6th '_'-separated token.
select = 10
best_models_on_val_auc = []

for fold_files in file_names_all_cv:
    fold_aucs = [float(os.path.basename(path).split('_')[5]) for path in fold_files]
    # argsort is ascending, so reverse it to rank from best to worst
    ranking = np.argsort(fold_aucs)[::-1]
    best_models_on_val_auc.extend(fold_files[ranking[rank]] for rank in range(select))

best_models_on_val_auc = np.array(best_models_on_val_auc)

In [None]:
# Total number of checkpoints selected across all folds.
best_models_on_val_auc.shape

In [None]:
# The trailing semicolon suppresses the output; remove it to list the selected paths.
best_models_on_val_auc;

In [None]:
# Average the metrics stored in the file names of the AUC-selected checkpoints.
selected_roc_aucs = [float(path.split('_auc_')[1].split('_')[0]) for path in best_models_on_val_auc]
selected_pr_aucs = [float(path.split('_prauc_')[1].split('_')[0]) for path in best_models_on_val_auc]

print("ROC AUC of selected models: ", np.mean(selected_roc_aucs))
print("PR AUC of selected models: ", np.mean(selected_pr_aucs))

#### PR AUC

In [None]:
# Keep the `select` checkpoints with the highest PR AUC from every fold.
# The PR AUC is read from the checkpoint file name: 8th '_'-separated token.
select = 10
best_models_on_prauc = []

for fold_files in file_names_all_cv:
    fold_praucs = [float(os.path.basename(path).split('_')[7]) for path in fold_files]
    # argsort is ascending, so reverse it to rank from best to worst
    ranking = np.argsort(fold_praucs)[::-1]
    best_models_on_prauc.extend(fold_files[ranking[rank]] for rank in range(select))

best_models_on_prauc = np.array(best_models_on_prauc)

In [None]:
# Total number of checkpoints selected across all folds.
best_models_on_prauc.shape

In [None]:
# The trailing semicolon suppresses the output; remove it to list the selected paths.
best_models_on_prauc;

In [None]:
# Average the metrics stored in the file names of the PR-AUC-selected checkpoints.
selected_roc_aucs = [float(path.split('_auc_')[1].split('_')[0]) for path in best_models_on_prauc]
selected_pr_aucs = [float(path.split('_prauc_')[1].split('_')[0]) for path in best_models_on_prauc]

print("ROC AUC of selected models: ", np.mean(selected_roc_aucs))
print("PR AUC of selected models: ", np.mean(selected_pr_aucs))

## Predict with model ensemble

In [None]:
def plot_roc(y_true, y_pred):
    """Plot one ROC curve per class and return the per-class ROC AUC values.

    Parameters
    ----------
    y_true : np.ndarray
        Ground truth. If its shape equals ``y_pred.shape`` it is used as is
        (one indicator column per class); otherwise it is treated as integer
        class labels and one-hot encoded.
    y_pred : np.ndarray
        Predicted scores; the last axis indexes the classes.

    Returns
    -------
    list of float
        ROC AUC of every class, in class order.
    """
    if y_pred.shape != y_true.shape:
        # One-hot width follows y_pred rather than a hard-coded 2, so the columns
        # always match the classes iterated below. Converted back to NumPy so the
        # sklearn metrics receive plain arrays.
        n_classes = y_pred.shape[-1]
        y_true = F.one_hot(torch.from_numpy(y_true).to(torch.int64), n_classes).numpy()

    plt.figure(figsize=(6, 6))
    auc_all = []
    for class_ind in range(y_pred.shape[-1]):
        fpr, tpr, _ = roc_curve(y_true[:, class_ind], y_pred[:, class_ind])
        auc = roc_auc_score(y_true[:, class_ind], y_pred[:, class_ind])
        auc_all.append(auc)
        plt.plot(fpr, tpr, '-', label='AUC : %.3f, label : %d' % (auc, class_ind))
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.legend()
    plt.show()
    return auc_all

In [None]:
def load_merged_h5_file(filename):
    """Read the 'coords', 'features' and 'tb_positive' datasets of a merged h5 file.

    Each dataset is read in full with ``[()]`` while the file is open, so the
    returned values stay usable after the file is closed.

    Returns
    -------
    tuple
        ``(coords, features, tb_positive)``
    """
    with h5py.File(filename, "r") as h5_file:
        return tuple(h5_file[key][()] for key in ('coords', 'features', 'tb_positive'))

In [None]:
class h5_Dataset(Dataset):
    """Dataset serving the rows of an in-memory embedding array listed in a split table.

    The row of an image inside ``emb_file_in_memory`` is derived from its file
    name: for ``tb<number>.jpg`` the row is ``<number> - 1``.

    Parameters
    ----------
    emb_file_in_memory : indexable
        Embeddings of all images, indexed by row.
    cv_df : table with a "file_path" column
        Images that belong to this split, in serving order.
    transform : callable, optional
        Applied to every item before it is returned.
    """

    def __init__(self, emb_file_in_memory, cv_df, transform=None):
        self.emb_file_in_memory = emb_file_in_memory
        self.cv_df = cv_df
        self.transform = transform

        # "…/tb0042.jpg" -> 41 : zero-based row inside emb_file_in_memory
        rows = []
        for path in self.cv_df["file_path"]:
            stem = os.path.basename(path).replace(".jpg", "").replace("tb", "")
            rows.append(int(stem) - 1)
        self.cv_samples_index = np.array(rows)

        # position inside this split -> row inside emb_file_in_memory
        split_positions = np.arange(self.cv_samples_index.shape[0])
        self.cv_idx_to_all_idx = dict(zip(split_positions, self.cv_samples_index))

    def __len__(self):
        return len(self.cv_samples_index)

    def __getitem__(self, idx):
        item = self.emb_file_in_memory[self.cv_idx_to_all_idx[idx]]
        if self.transform:
            item = self.transform(item)
        return item

In [None]:
def pred_with_one_model(model, data_loader, device=None):
    """Run one model over every batch of ``data_loader`` and collect its outputs.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(batch)`` and expected to return a 5-tuple whose second
        element holds the per-class outputs and whose third element holds the labels.
    data_loader : iterable of torch.Tensor
        Batches to predict on; each is cast to float32 before the forward pass.
    device : str or torch.device, optional
        Device the batches are moved to. Defaults to the device of the model's
        parameters (CPU for a parameter-free model), so the function no longer
        relies on a notebook-level ``DEVICE`` variable.

    Returns
    -------
    preds_all : np.ndarray
        First two output columns of every sample, batches concatenated.
    labels_all : np.ndarray
        Third model output of every sample, batches concatenated.

    Raises
    ------
    ValueError
        From ``np.concatenate`` when ``data_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    preds_all = []
    labels_all = []

    # Inference only: a single no_grad context covers the whole loop.
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device, dtype=torch.float32, non_blocking=True)

            _, preds, label, _, _ = model(data)

            preds_all.append(preds.cpu().numpy()[:, :2])
            labels_all.append(label.cpu().numpy())

    preds_all = np.concatenate(preds_all)
    labels_all = np.concatenate(labels_all)

    return preds_all, labels_all

### Select best models

In [None]:
# Checkpoints used for the ensemble: the per-fold top-k ranked by the PR AUC stored
# in the file names. Assign best_models_on_val_auc instead to rank by AUC.
best_models_on_selected_metric = best_models_on_prauc

### Local test set

In [None]:
# Load the merged h5 file (coords, features, tb_positive) fully into memory.
# features_h5 is later indexed by image number (see h5_Dataset), so it has to
# contain every image referenced by the split tables.
print("Loading merged h5 local test file into memory...")
coords_h5, features_h5, tb_positive_h5 = load_merged_h5_file(conf_preproc["emb_h5"])

print("\nDone!")
print(features_h5.shape)

In [None]:
# Number of checkpoints in the ensemble.
nr_models = best_models_on_selected_metric.shape[0]

# One slot per (checkpoint, local test image): two output columns and one label.
preds_ensemble_local_test = np.zeros((nr_models, embeddings_bag_input_files_local_test.shape[0], 2))
labels_ensmble_local_test = np.zeros((nr_models, embeddings_bag_input_files_local_test.shape[0], 1))

# DEFINE DATALOADER
test_df = pd.read_csv(f'{conf_preproc["cv_split_dir_10fold"]}test_split_stratified.csv')

# The predictions are scored against tb_df_local_test.tb_positive, which comes from
# a separately loaded table sorted by 'image'. Put the rows of test_df in exactly
# that order (matched by file name), so prediction k always belongs to label k
# instead of silently relying on both CSV files sharing the same row order.
eval_names = [os.path.basename(path) for path in tb_df_local_test["file_path"]]
row_of = {os.path.basename(path): row for row, path in enumerate(test_df["file_path"])}
missing = [name for name in eval_names if name not in row_of]
if missing:
    raise ValueError(
        f"{len(missing)} local test images are missing from the 10-fold test split, e.g. {missing[:3]}"
    )
test_df = test_df.iloc[[row_of[name] for name in eval_names]].reset_index(drop=True)

test_dataset = h5_Dataset(features_h5, test_df)
# shuffle=False keeps the sample order fixed, so all checkpoints predict the same rows.
test_dataset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, num_workers=0, shuffle=False)

for m in tqdm(range(nr_models)):

    # Fresh network with the architecture of the training run, placed on DEVICE.
    model = ClsTokenTransformerClassifier(conf_train.emb_dim,
                                      conf_train.num_heads,
                                      conf_train.num_encoder_layers,
                                      conf_train.dim_feedforward,
                                      conf_train.dropout,
                                      conf_train.num_classes).to(DEVICE)

    best_model_path = best_models_on_selected_metric[m] # path of the m th best model
    model_state_dict = torch.load(best_model_path, map_location=torch.device(DEVICE)) # load

    model.load_state_dict(model_state_dict) # load weights
    model.eval()  # inference mode; the model already lives on DEVICE

    preds_all, labels_all = pred_with_one_model(model, data_loader=test_dataset_loader)

    preds_ensemble_local_test[m] = preds_all
    labels_ensmble_local_test[m] = labels_all

In [None]:
# Shapes are (checkpoints, local test images, 2) and (checkpoints, local test images, 1).
preds_ensemble_local_test.shape, labels_ensmble_local_test.shape

In [None]:
# Persist the per-checkpoint predictions on the local test set (written to the
# current working directory).
np.save(
    'preds_100ensemble_transformer_cls_on_embeddings_bag_uni_run_4_sqrt_sampling_10fold_local_test.npy',
    preds_ensemble_local_test,
)


#### Simple mean

In [None]:
# Ensemble prediction: plain average over the checkpoint axis.
final_pred_ensemble_local_test = preds_ensemble_local_test.mean(axis=0)
final_pred_ensemble_local_test.shape

In [None]:
# Averaged score in column 1 (the column plot_pr evaluates against tb_positive).
final_pred_ensemble_local_test[:,1]

In [None]:
# Fraction of local test samples whose highest averaged score is not in column 0.
np.mean(np.argmax(final_pred_ensemble_local_test, axis=1) > 0)

In [None]:
# ROC curves of the ensemble on the local test set. plot_roc returns one AUC per
# class, so the printed value is the mean over the classes.
# NOTE(review): the row order of the predictions has to match tb_df_local_test
# (sorted by 'image') for these labels to line up - confirm.
auc = plot_roc( tb_df_local_test.tb_positive.values, final_pred_ensemble_local_test  )
print( np.mean(auc) )

In [None]:
def plot_pr(y_true, y_pred):
    """Plot the precision-recall curve of class 1 and return its average precision.

    Parameters
    ----------
    y_true : np.ndarray
        Ground truth. If its shape equals ``y_pred.shape`` it is used as is
        (one indicator column per class); otherwise it is treated as integer
        class labels and one-hot encoded.
    y_pred : np.ndarray
        Predicted scores; column 1 is evaluated.

    Returns
    -------
    list of float
        Single-element list holding the average precision of class 1.
    """
    if y_pred.shape != y_true.shape:
        # One-hot width follows y_pred rather than a hard-coded 2; converted back
        # to NumPy so the sklearn metrics receive plain arrays.
        n_classes = y_pred.shape[-1]
        y_true = F.one_hot(torch.from_numpy(y_true).to(torch.int64), n_classes).numpy()

    plt.figure(figsize=(6, 6))
    auc_all = []

    # precision_recall_curve returns (precision, recall, thresholds); a PR curve is
    # drawn with recall on the x-axis and precision on the y-axis.
    precision, recall, _ = precision_recall_curve(y_true[:, 1], y_pred[:, 1])
    avg_precision = average_precision_score(y_true[:, 1], y_pred[:, 1])
    auc_all.append(avg_precision)
    plt.plot(recall, precision, '-', label='AUC : %.3f, label : %d' % (avg_precision, 1))
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()
    plt.show()
    return auc_all

In [None]:
# Precision-recall evaluation of the ensemble on the local test set. plot_pr
# returns a single-element list, so the mean equals that one value.
pr_auc = plot_pr( tb_df_local_test.tb_positive.values, final_pred_ensemble_local_test  )
print( np.mean(pr_auc) )

## Holdout set 

In [None]:
class h5_Dataset_holdout(Dataset):
    """Dataset serving every row of an in-memory embedding array, in storage order.

    Parameters
    ----------
    emb_file_in_memory : array-like with ``.shape``
        Embeddings; item ``idx`` is ``emb_file_in_memory[idx]``.
    transform : callable, optional
        Applied to every item before it is returned.
    """

    def __init__(self, emb_file_in_memory, transform=None):
        self.emb_file_in_memory = emb_file_in_memory
        self.transform = transform

    def __len__(self):
        # one sample per row of the embedding array
        n_samples = self.emb_file_in_memory.shape[0]
        return n_samples

    def __getitem__(self, idx):
        item = self.emb_file_in_memory[idx]
        if self.transform:
            item = self.transform(item)
        return item

In [None]:
# Load the merged holdout embeddings (coords, features, tb_positive) into memory.
# NOTE(review): the path is hard-coded here, while the local test file path is read
# from conf_preproc - consider moving this one into the config as well.
print("Loading merged h5 holdout file into memory...")

holdout_h5_file = "/home/ngsci/project/tuberculosis_detection/02_patch_embeddings/uni_224_224_patches/patch_embeddings_uni_224_holdout.h5"
coords_h5_holdout, features_h5_holdout, tb_positive_h5_holdout = load_merged_h5_file(holdout_h5_file)

print("\nDone!")

In [None]:
# Directory holding one embedding file (.h5) per holdout image.
embeddings_bag_input_path_holdout = conf_preproc["emb_dir_holdout"]

# Holdout table; unlike the local test table it is kept in file order (no sorting).
tb_df_holdout = pd.read_csv(conf_preproc["tb_labels_csv_holdout"])

# Matching embedding file for every holdout image: same base name, .jpg swapped for .h5.
holdout_h5_paths = [
    embeddings_bag_input_path_holdout + os.path.basename(jpg_path).replace(".jpg", ".h5")
    for jpg_path in tb_df_holdout.file_path.values
]
embeddings_bag_input_files_holdout = np.array(holdout_h5_paths)

embeddings_bag_input_files_holdout.shape

In [None]:
# Sanity check: first image path of the holdout table.
tb_df_holdout["file_path"][0]

In [None]:
# Sanity check: first two derived embedding file paths.
embeddings_bag_input_files_holdout[:2]

In [None]:
# Number of checkpoints in the ensemble.
nr_models = best_models_on_selected_metric.shape[0]
holdout_size = embeddings_bag_input_files_holdout.shape[0]

# NOTE: the *_local_test names are reused for the holdout predictions, so the local
# test results computed earlier are overwritten from here on.
preds_ensemble_local_test = np.zeros((nr_models, holdout_size, 2))
labels_ensmble_local_test = np.zeros((nr_models, holdout_size, 1))

# DEFINE DATALOADER
# Samples are served in the storage order of features_h5_holdout (shuffle=False).
holdout_dataset = h5_Dataset_holdout(features_h5_holdout)
holdout_dataset_loader = torch.utils.data.DataLoader(
    holdout_dataset,
    batch_size=1024,
    num_workers=0,
    shuffle=False,
)

for m in tqdm(range(nr_models)):

    # Fresh network with the architecture of the training run.
    model = ClsTokenTransformerClassifier(
        conf_train.emb_dim,
        conf_train.num_heads,
        conf_train.num_encoder_layers,
        conf_train.dim_feedforward,
        conf_train.dropout,
        conf_train.num_classes,
    ).to(DEVICE)

    # Weights of the m-th selected checkpoint.
    best_model_path = best_models_on_selected_metric[m]
    model_state_dict = torch.load(best_model_path, map_location=torch.device(DEVICE))
    model.load_state_dict(model_state_dict)

    # eval() and to() both return the module itself.
    model = model.eval().to(DEVICE)

    preds_all, labels_all = pred_with_one_model(model, data_loader=holdout_dataset_loader)

    preds_ensemble_local_test[m] = preds_all
    labels_ensmble_local_test[m] = labels_all

In [None]:
# Persist the per-checkpoint predictions on the holdout set (written to the
# current working directory).
np.save(
    'preds_100ensemble_transformer_cls_on_embeddings_bag_uni_run_4_sqrt_sampling_10fold_holdout.npy',
    preds_ensemble_local_test,
)

In [None]:
# Ensemble prediction on the holdout set: plain average over the checkpoint axis.
final_pred_ensemble_local_test = preds_ensemble_local_test.mean(axis=0)

In [None]:
# Submission table: image id plus the averaged score of column 1.
# Rows are paired by position, so tb_df_holdout has to list the images in the same
# order as features_h5_holdout.
prediction_df = tb_df_holdout[["image_id"]].assign(prob=final_pred_ensemble_local_test[:, 1])

In [None]:
# Number of holdout samples whose highest averaged score is in column 1.
final_pred_ensemble_local_test.argmax(axis=1).sum()

In [None]:
# Same count expressed as a fraction of all holdout samples.
holdout_predicted_class = final_pred_ensemble_local_test.argmax(axis=1)
holdout_predicted_class.sum() / final_pred_ensemble_local_test.shape[0]