In [1]:
import sys
sys.path.insert(0, "..")

import os, glob
import yaml
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from cutmix.cutmix import CutMix
from cutmix.utils import CutMixCrossEntropyLoss

from utils.models import get_model
from utils.data import CustomImageDataset, CustomImageDatasetV2
from utils.log import TextDocument

In [2]:
NUM_CLASSES = 4 
RESUME = False
epochs = 40
IMG_SIZE = 640
BATCH_SIZE = 4
ACCUM_STEPS = 16
WEIGHTS_DIR = "../weights"
CUTMIX = True
#model_names = ["resnet50", "cspresnet50", "efficientnet_b1", "dpn68"]

model_names = ["resnet50", "cspresnet50", "efficientnet_b1", "dpn68"]
#model_names = ["efficientnet_b1", "dpn68"]


TRAIN_DATASET = "../../../Dataset/Covid19/train_test_classification_quarter_size/train"
VALID_DATASET = "../../../Dataset/Covid19/train_test_classification_quarter_size/valid"
FILENAME_SUFFIX = "_lr1e-5"

TRAIN_DATASET += "/*/*.jpg"
VALID_DATASET += "/*/*.jpg"
Path(WEIGHTS_DIR).mkdir(exist_ok=True, parents=True)

In [3]:
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import cv2
import numpy as np

from torchvision.transforms import Compose, Resize, Normalize, ToTensor, Lambda
from torchvision.transforms import ColorJitter, RandomAffine, RandomPerspective, RandomRotation, RandomErasing, RandomCrop, Grayscale
from torchvision.transforms import RandomChoice, RandomApply

from albumentations.core.composition import Compose as ComposeV2
import albumentations.augmentations as A 
from albumentations.pytorch.transforms import ToTensorV2

def get_train_grayscale_transforms_V2(img_size: int) -> Compose:
    """Returns data transformations/augmentations for train dataset.
    
    Args:
        img_size: The resolution of the input image (img_size x img_size)
    """
    return ComposeV2([
        A.geometric.resize.LongestMaxSize(img_size),
        A.geometric.rotate.Rotate(limit=30, border_mode=cv2.BORDER_CONSTANT, value=0, p=0.5),
        A.crops.transforms.RandomResizedCrop(img_size, img_size, scale=(0.7, 1.0), ratio=(0.95, 1.05)),
        A.transforms.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT, value=0),
        A.transforms.Normalize(
            mean=[0.5203580774185134],
            std=[0.24102417452995067]),
        ToTensorV2()
    ])
def get_test_grayscale_transforms(img_size: int) -> Compose:
    """Returns data transformations/augmentations for train dataset.
    
    Args:
        img_size: The resolution of the input image (img_size x img_size)
    """
    return Compose([
        Resize([img_size, img_size], interpolation=3),
        ToTensor(),
        Normalize(
            mean=[0.5203580774185134],
            std=[0.24102417452995067])
    ])

def get_test_grayscale_transforms_V2(img_size: int) -> Compose:
    """Returns data transformations/augmentations for train dataset.
    
    Args:
        img_size: The resolution of the input image (img_size x img_size)
    """
    return ComposeV2([
        A.geometric.resize.LongestMaxSize(img_size),
        A.transforms.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT, value=0),
        A.transforms.Normalize(
            mean=[0.5203580774185134],
            std=[0.24102417452995067]),
        ToTensorV2()
    ])

def show_confusion_matrix(matrix: List[List], labels: List[str]):
    """Display a nice confusion matrix given
    the confusion matrix in a 2D list + list of labels (decoder)
    
    Args:
        matrix: 2D array containing the values to display (confusion matrix)
        labels: Array containing the labels (indexed by corresponding label idx)
    """
    fig, ax = plt.subplots()
    fig.set_figheight(15)
    fig.set_figwidth(15)

    min_val, max_val = 0, len(labels)

    for i in range(max_val):
        for j in range(max_val):
            c = matrix[i][j]
            ax.text(i, j, str(int(c)), va='center', ha='center')

    ax.matshow(matrix, cmap=plt.cm.Blues)

    # Set number of ticks for x-axis001
    ax.set_xticks(np.arange(max_val))
    # Set ticks labels for x-axis
    ax.set_xticklabels(labels, rotation='vertical', fontsize=16)

    # Set number of ticks for x-axis
    ax.set_yticks(np.arange(max_val))
    # Set ticks labels for x-axis
    ax.set_yticklabels(labels, rotation='horizontal', fontsize=16)
                    
    #ax.set_xlim(min_val, max_val)
    ax.set_ylim(max_val - 0.5, min_val - 0.5)
    plt.show()
    
def display_missclassified(class_to_idx: Dict[str,int], 
                           targets: List[int], 
                           predictions: List[int], 
                           images: List[np.ndarray], 
                           gridsize: Tuple[int] = (4,4)):
    """Display a grid with missclassified samples from test set.
    
    Args:
        class_to_idx: Class to idx encoder
        targets:      List containing all ground truths
        predictions:  List containing all predictions
        images:       List containing image arrays
        gridsize:     Tuple describing the final image grid
    """
    fig = plt.figure()
    plot_counter = 1
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    idx_to_class = {i:label for i, label in enumerate(class_to_idx)}
    for i in range(len(targets)):
        if plot_counter > gridsize[0]*gridsize[1]:
            break
        
        image = images[i].transpose(1, 2, 0)
        image = ((image * std) + mean) * 255
        image = image.astype("uint8")
    
        image = cv2.resize(image, (128, 128))
        image = cv2.putText(image, idx_to_class[predictions[i]], (0,20), 3, 0.4, (0,0,255), 1)
        if predictions[i] == targets[i]:
            pass
        else:
            ax = fig.add_subplot(gridsize[0], gridsize[1], plot_counter)
            ax.imshow(image)
            plot_counter += 1
    plt.show()

In [4]:
def train_one_epoch(model, train_dataloader, device, accumulate_steps=1):
    model.train()
    results = {
        "running_loss": 0
    }
    t = tqdm(train_dataloader)
    for i, (X, y) in enumerate(t):

        X = X.to(device)
        y = y.to(device)

        preds = model(X)
        loss = criterion(preds, y)
        
        results["running_loss"] += loss.cpu().detach()
        loss = loss/accumulate_steps
        loss.backward()
        
        if ((i+1) % accumulate_steps) == 0:
            optimizer.step()
            optimizer.zero_grad()
            
        t.set_description(f"{epoch+1}/{epochs} Train: {round(float(results['running_loss'])/(i+1), 4)}")
    optimizer.step()
    optimizer.zero_grad()
    
    return results

def evaluate_model(model, valid_dataloader, device, save_images=False):
    results = {
        "running_loss": 0,
        "targets": list(),
        "predictions": list()
    }    

    model.eval()
    with torch.no_grad():

        if save_images:
            results["images"] = list() 
            
        t = tqdm(valid_dataloader)
        for i, (X, y) in enumerate(t):
            X = X.to(device)
            y = y.to(device)

            preds = model(X)
            results["predictions"] += list(preds.argmax(axis=1).cpu().detach().numpy())
            results["targets"] += list(np.array(y.cpu()))
            if save_images:
                results["images"] += list(np.array(X.cpu()))
        
            loss = criterion(preds, y)

            results["running_loss"] += loss.cpu().detach()
            t.set_description(f"Test: {round(float(results['running_loss']/(i+1)), 4)}")
            
    return results

def calculate_metrics():
    pass

class TrainingResults():
    
    def __init__(self, metrics):
        self.best_results = {metric: [1e99, -1e99] for metric in metrics}
    
    def isHighest(self, metric, value):
        if self.best_results[metric][1] < value:
            self.best_results[metric][1] = value
            return True
        return False
    
    def isLowest(self, metric, value):
        if self.best_results[metric][0] > value:
            self.best_results[metric][0] = value
            return True
        return False
    
    def loadCheckpoint(self, checkpoint):
        self.best_results = checkpoint

In [5]:
train_imgs = glob.glob(TRAIN_DATASET)
valid_imgs = glob.glob(VALID_DATASET)

train_labels = set([os.path.basename(os.path.dirname(img_path)) for img_path in train_imgs])
valid_labels = set([os.path.basename(os.path.dirname(img_path)) for img_path in valid_imgs])
class_to_idx = {label: idx for idx, label in enumerate(train_labels)}

train_dataset = CustomImageDatasetV2(train_imgs, get_train_grayscale_transforms_V2(IMG_SIZE), train_labels)
if CUTMIX:
    train_dataset = CutMix(train_dataset, num_class=NUM_CLASSES, beta=1.0, prob=0.5, num_mix=2)    # this is paper's original setting for cifar.
train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)

valid_dataset = CustomImageDatasetV2(valid_imgs, get_test_grayscale_transforms_V2(IMG_SIZE), valid_labels)
valid_dataloader = DataLoader(valid_dataset, batch_size = BATCH_SIZE, shuffle=False)
print(class_to_idx)

{'typical': 0, 'negative': 1, 'indeterminate': 2, 'atypical': 3}
{'typical': 0, 'negative': 1, 'indeterminate': 2, 'atypical': 3}
{'typical': 0, 'negative': 1, 'indeterminate': 2, 'atypical': 3}


In [6]:
# Download all model once
for model_name in model_names:
    model = get_model(model_name, NUM_CLASSES, 1)

In [None]:
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix, balanced_accuracy_score

# Using gpu or not
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print(f"Using CPU")
    
for model_name in model_names:
    model = get_model(model_name, NUM_CLASSES, 1)
    model.to(device)
    print(list(class_to_idx.keys()))

    results_document = TextDocument(f"{model_name}{FILENAME_SUFFIX}_results.txt")
    results_document.add_line(f"acc balanced_acc f1 recall precision valid_loss train_loss") 
    metrics = ["balanced_acc"]
    training_results = TrainingResults(metrics)
    
    if RESUME:
        pass
        #start_epoch = state_dict["epoch"]
        #optimizer_state_dict = state_dict["optimizer_state_dict"]
        
        #model.load_state_dict(state_dict["model_state_dict"])
        #optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        training_results.loadCheckpoint(state_dict["training_results"])
    else:
        
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
        if CUTMIX:
            criterion = CutMixCrossEntropyLoss(True)
        else:
            criterion = torch.nn.CrossEntropyLoss()
        start_epoch = 0

    for epoch in range(start_epoch, epochs):
        
        # Train one epoch    
        results = train_one_epoch(model, train_dataloader, device, accumulate_steps=ACCUM_STEPS)
        train_loss = float(results["running_loss"])


        if epoch+1 == epochs:
            results = evaluate_model(model, valid_dataloader, device, save_images=True)
            images = results["images"]
        else:
            results = evaluate_model(model, valid_dataloader, device, save_images=False)

        valid_loss = float(results["running_loss"])
        acc = accuracy_score(results["targets"], results["predictions"])
        f1 = f1_score(results["targets"], results["predictions"], average="macro", labels=np.unique(results["predictions"]))
        recall = recall_score(results["targets"], results["predictions"], average="macro", labels=np.unique(results["predictions"]))
        precision = precision_score(results["targets"], results["predictions"], average="macro", labels=np.unique(results["predictions"]))
        balanced_acc = balanced_accuracy_score(results["targets"], results["predictions"])
        
        
        results_document.add_line(f"{float(acc)} {float(balanced_acc)} {float(f1)} {float(recall)} {float(precision)} {valid_loss} {train_loss}")
        
        if training_results.isHighest('balanced_acc',  balanced_acc):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'training_results': training_results.best_results,
                'train_dataloader': train_dataloader,
                'test_dataloader': valid_dataloader,
                'class_to_idx': class_to_idx
            }, os.path.join(WEIGHTS_DIR, f"{model_name}{FILENAME_SUFFIX}_best_balanced_acc.pt"))
        """
        if training_results.isHighest('f1',  f1):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'training_results': training_results.best_results,
                'train_dataloader': train_dataloader,
                'test_dataloader': valid_dataloader,
                'class_to_idx': class_to_idx
            }, os.path.join(WEIGHTS_DIR, f"{model_name}{FILENAME_SUFFIX}_best_f1.pt"))
        if training_results.isHighest('acc', acc):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'training_results': training_results.best_results,
                'train_dataloader': train_dataloader,
                'test_dataloader': valid_dataloader,
                'class_to_idx': class_to_idx
            }, os.path.join(WEIGHTS_DIR, f"{model_name}{FILENAME_SUFFIX}_best_acc.pt"))
        if training_results.isLowest('valid_loss', valid_loss):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'training_results': training_results.best_results,
                'train_dataloader': train_dataloader,
                'test_dataloader': valid_dataloader,
                'class_to_idx': class_to_idx
            }, os.path.join(WEIGHTS_DIR, f"{model_name}{FILENAME_SUFFIX}_best_valid_loss.pt"))
        """
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_results': training_results.best_results,
            'train_dataloader': train_dataloader,
            'test_dataloader': valid_dataloader,
            'class_to_idx': class_to_idx
        }, os.path.join(WEIGHTS_DIR, f"{model_name}{FILENAME_SUFFIX}_last.pt"))
        print(f"Final balanced accuracy: {balanced_acc}")
        
    #display_missclassified(class_to_idx, targets, predictions, images, gridsize=(4,4))
    #show_confusion_matrix(confusion_matrix(targets, predictions), list(class_to_idx.keys()))

Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU


  0%|                                                                                         | 0/1334 [00:00<?, ?it/s]

['typical', 'negative', 'indeterminate', 'atypical']


1/40 Train: 1.226: 100%|███████████████████████████████████████████████████████████| 1334/1334 [05:34<00:00,  3.99it/s]
Test: 1.179: 100%|███████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.24it/s]
2/40 Train: 0.978:   0%|                                                                      | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.38061837415376165


2/40 Train: 1.1352: 100%|██████████████████████████████████████████████████████████| 1334/1334 [05:32<00:00,  4.02it/s]
Test: 1.0543: 100%|██████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.23it/s]
  0%|                                                                                         | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.3908648231962815


3/40 Train: 1.1015: 100%|██████████████████████████████████████████████████████████| 1334/1334 [05:36<00:00,  3.97it/s]
Test: 1.0177: 100%|██████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.26it/s]
  0%|                                                                                         | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.3856786316697324


4/40 Train: 1.0903: 100%|██████████████████████████████████████████████████████████| 1334/1334 [05:35<00:00,  3.98it/s]
Test: 0.9991: 100%|██████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.25it/s]
  0%|                                                                                         | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.3955816243197655


5/40 Train: 1.079: 100%|███████████████████████████████████████████████████████████| 1334/1334 [05:33<00:00,  4.00it/s]
Test: 0.9775: 100%|██████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.17it/s]
6/40 Train: 1.1615:   0%|                                                                     | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.403704368002076


6/40 Train: 1.065: 100%|███████████████████████████████████████████████████████████| 1334/1334 [05:33<00:00,  4.00it/s]
Test: 0.9708: 100%|██████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.25it/s]
  0%|                                                                                         | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.400403370452065


7/40 Train: 1.0647: 100%|██████████████████████████████████████████████████████████| 1334/1334 [06:00<00:00,  3.70it/s]
Test: 0.9765: 100%|██████████████████████████████████████████████████████████████████| 250/250 [00:41<00:00,  6.08it/s]
  0%|                                                                                         | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.4103063631020981


8/40 Train: 1.0565: 100%|██████████████████████████████████████████████████████████| 1334/1334 [05:45<00:00,  3.87it/s]
Test: 0.9662: 100%|██████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00,  7.24it/s]
  0%|                                                                                         | 0/1334 [00:00<?, ?it/s]

Final balanced accuracy: 0.4125102082872211


9/40 Train: 1.0589: 100%|██████████████████████████████████████████████████████████| 1334/1334 [05:33<00:00,  3.99it/s]
Test: 1.2058:  61%|████████████████████████████████████████▏                         | 152/250 [00:20<00:14,  6.84it/s]

In [None]:
display_missclassified(class_to_idx, targets, predictions, images, gridsize=(4,4))
show_confusion_matrix(confusion_matrix(targets, predictions), list(class_to_idx.keys()))


In [None]:
def read_results(txt_path):
    with open(txt_path) as f:
        data = f.readlines()
    columns = data[0].replace('\n', '').split()
    rows = [line.replace('\n', '').split() for line in data[1:]]
    for idx, row in enumerate(rows):
        rows[idx] = [float(i) for i in row]
    results_dict = {column: [row[i] for row in rows] for i, column in enumerate(columns)}
    return results_dict

model_names = ["resnet50", "cspresnet50", "efficientnet_b1", "dpn68"]
fig, axs = plt.subplots(1,3, figsize=(30,10))
for model_name in model_names:
    txt_path = model_name + "_results.txt"
    data = read_results(txt_path)

    axs[0].plot(data['f1'])
    axs[1].plot(data['acc'])
    axs[2].plot(data['valid_loss'])
axs[0].set_title("F1")
axs[1].set_title("Accuracy")
axs[2].set_title("valid_loss")
axs[0].legend(model_names)
axs[1].legend(model_names)
axs[2].legend(model_names)

In [None]:
print(data.keys())

# Inference

In [None]:
import sys
sys.path.insert(0, "../")

import glob

from tqdm import tqdm
import cv2
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
import albumentations.augmentations as A 
from albumentations.pytorch.transforms import ToTensorV2
from albumentations.core.composition import Compose as ComposeV2

from utils.models import get_model
from utils.data import InferenceImageDatasetV2

In [None]:
VALID_DATASET = "../../../Dataset/Covid19/train_test_classification_quarter_size/valid"
model_names = ["resnet50", "cspresnet50", "efficientnet_b1", "dpn68"]
NUM_CLASSES = 4

#{'negative': 0, 'indeterminate': 1, 'atypical': 2, 'typical': 3}
label_names = ['indeterminate', 'negative', 'atypical', 'typical']
IMG_SIZE = 640

pt_paths = glob.glob("../weights/**/*.pt", recursive=True) 
model_pt_dict = dict()
for pt_path in pt_paths:
    for model_name in model_names:
        if model_name in pt_path:
            #print(model_name, pt_path)
            model_pt_dict[pt_path] = model_name
            
    

In [None]:

def get_test_grayscale_transforms_V2(img_size: int) -> Compose:
    """Returns data transformations/augmentations for train dataset.
    
    Args:
        img_size: The resolution of the input image (img_size x img_size)
    """
    return ComposeV2([
        A.geometric.resize.LongestMaxSize(img_size),
        A.transforms.PadIfNeeded(min_height=img_size, min_width=img_size, border_mode=cv2.BORDER_CONSTANT, value=0),
        A.transforms.Normalize(
            mean=[0.5203580774185134],
            std=[0.24102417452995067]),
        ToTensorV2()
    ])

def inference(model, img_paths, device="cpu"):
    
    dataset = InferenceImageDatasetV2(img_paths, get_test_grayscale_transforms_V2(IMG_SIZE), label_names)
    dataloader = DataLoader(dataset, batch_size=32)
    print(dataloader)
    model.to(device)
    model.eval()
    predictions = list() 
    image_paths = list()
    with torch.no_grad():
        for X, img_paths in tqdm(dataloader):

            X = X.to(device)
            preds = model(X)
            predictions += list(preds.cpu().detach().numpy())
            image_paths += list(img_paths)
    model.cpu()

    return predictions, image_paths

In [None]:
# Using gpu or not
CUDA = "cuda" if torch.cuda.is_available() else "cpu"
if CUDA == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")

ensemble_models = list()
ensemble_predictions = list()
ensemble_img_paths = list()
valid_imgs = glob.glob(VALID_DATASET + "/*/*.jpg")
for pt_path in pt_paths:
    if model_pt_dict.get(pt_path) is not None:
        model_name = model_pt_dict[pt_path]
        model = get_model(model_name, input_channels=1, num_classes=NUM_CLASSES)
        state_dict = torch.load(pt_path)
        model.load_state_dict(state_dict["model_state_dict"])  
        print(f"'{pt_path}' loaded into '{model_name}'")
        #print(valid_imgs)
        predictions, image_paths = inference(model, valid_imgs, device="cuda")
        ensemble_img_paths.append(image_paths)
        ensemble_predictions.append(predictions)
        ensemble_models.append((model_name, pt_path))
    else:
        print(f"'{pt_path}' cannot be linked to model")


In [None]:
import os
from pathlib import Path
"""
ensemble_img_paths # List of paths
ensemble_predictions # List of predictions
ensemble_models #list of model and pt names 
"""
results = dict()
for img_path in set([path for paths in ensemble_img_paths for path in paths]):
    results[img_path] = list()

    
for img_paths, predictions, (model_name, pt_path) in tqdm(zip(ensemble_img_paths, ensemble_predictions, ensemble_models)):
    for img_path, prediction in zip(img_paths, predictions):
        filename = f"{os.path.basename(img_path)}#{model_name}#{os.path.basename(pt_path)}"
        save_path = f"../results/predictions/{os.path.basename(img_path)}"
        Path(save_path).mkdir(exist_ok=True, parents=True)
        torch.save(prediction, os.path.join(save_path, filename))
    

In [None]:
ensemble_img_paths

In [None]:
prediction

In [None]:
img_path

In [None]:
import pandas as pd

label_names = ['indeterminate', 'negative', 'atypical', 'typical']

In [None]:
temp_img_path = list()
temp_0 = list()
temp_1 = list()  
temp_2 = list()
temp_3 = list()
temp_model_name = list()
temp_model_pt_path = list()

for img_paths, predictions, (model_name, pt_path) in tqdm(zip(ensemble_img_paths, ensemble_predictions, ensemble_models)):
    for img_path, prediction in zip(img_paths, predictions):
        temp_img_path.append(img_path)
        temp_0.append(prediction[0])
        temp_1.append(prediction[1])    
        temp_2.append(prediction[2])
        temp_3.append(prediction[3])
        temp_model_name.append(model_name)
        temp_model_pt_path.append(pt_path)
        
        
df = pd.DataFrame({
    "image_path":       temp_img_path,
    "00_indeterminate": temp_0,
    "01_negative":      temp_1,
    "02_atypical":      temp_2,
    "03_typical":       temp_3,
    "model_name":       temp_model_name,
    "model_pt_path":    temp_model_pt_path,
})
df.to_csv("model_logits.csv", index=False)

In [None]:
df.head()

# Ensemble calibration

In [None]:
import pandas as pd
from scipy.special import softmax

In [None]:
df = pd.read_csv("model_logits.csv")
df

In [None]:
logit_cols = ["00_indeterminate", "01_negative", "02_atypical", "03_typical"]  
#logit_cols = ["0_prob", "1_prob", "2_prob", "3_prob"]
temp = df[logit_cols].apply(softmax, axis=1)

In [None]:
for col in logit_cols:
    temp.rename(columns={
        col: col+"_prob",
    }, inplace=True)
df = df.join(temp)

In [None]:
ground_truth = df["image_path"].str.split("\\").apply(lambda x: x[-2])
ground_truth.head()

In [None]:
class_names = ["indeterminate", "negative", "atypical", "typical"]  
class2idx = {col:idx for idx, col in enumerate(class_names)}
ground_truth_indices = ground_truth.apply(lambda x: class2idx[x])
ground_truth_indices.head()

In [None]:
df["ground_truth"] = ground_truth
df["ground_truth_indices"] = ground_truth_indices

In [None]:
df["prediction"] = df.apply(lambda x: x[["00_indeterminate_prob", "01_negative_prob", "02_atypical_prob", "03_typical_prob"]].to_numpy().argmax(), axis=1)

In [None]:
models = df["model_pt_path"].unique()
image = df["image_path"].unique()

def get_sample(image_name, model_name):
    row = df[df["image_path"] == image_name]
    X = row[row["model_pt_path"] == model_name][["00_indeterminate_prob", "01_negative_prob", "02_atypical_prob", "03_typical_prob"]]
    y = row[row["model_pt_path"] == model_name]["ground_truth"]
    #print(row[row["model_pt_path"] == model_name]["prediction"])
    return X, y
get_sample(image[0], models[0])
    

In [None]:
from sklearn.metrics import accuracy_score , classification_report, balanced_accuracy_score

In [None]:
for model_pt_path in models:
    temp_df = df[df["model_pt_path"] == model_pt_path]
    y_pred = temp_df["prediction"]
    y_true = temp_df["ground_truth_indices"]
    
    acc = balanced_accuracy_score(y_true, y_pred)
    print(acc, model_pt_path)
    #print(classification_report(y_true, y_pred, zero_division=1))

In [None]:
weights = torch.rand([num_models])

final_prediction = None
for weight, prediction in zip(weights, predictions):
    if final_prediction is None:
        final_prediction = weight * prediction
    else:
        final_prediction += weight * prediction