# Dataset

In [42]:
from torch.utils.data import Dataset
from PIL import Image
import torch
import numpy as np
from torchvision import transforms
from torchvision.transforms import v2
from tqdm import tqdm

class CustomImageDataset(Dataset):
    def __init__(self, dataframe, images_path, transform=None):
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform
        self.images_path = images_path

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        image_path = self.dataframe.loc[idx, 'isic_id']
        image = Image.open(f'{self.images_path}/{image_path}.jpg').convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, int(self.dataframe.loc[self.dataframe.index[idx], 'target'])

# Data module

In [43]:
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torch
from torch.utils.data import WeightedRandomSampler
from PIL import Image
import numpy as np

from torchvision import transforms

class ISICDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dataframe = pd.read_csv(config.annotations_path)
        self.dataframe['path'] = config.path
        self.dataframe['target_name'] = config.target_name
        self.transform = {
            'train': transforms.Compose(
                    [
                    transforms.RandomResizedCrop(config.img_size, scale=config.crop_scale, antialias=True),
                    transforms.Resize((self.config.img_size, self.config.img_size), interpolation=Image.BILINEAR),
                    transforms.ToTensor(),
                    #transforms.Lambda(lambda x: x / 255.0),
                    #transforms.Normalize(mean, std)
                    ]),
            'test': transforms.Compose(
                    [
                    #transforms.RandomResizedCrop(config.img_size ,scale=config.crop_scale_tta, antialias=True),
                    transforms.Resize((self.config.img_size, self.config.img_size), interpolation=Image.BILINEAR),
                    transforms.ToTensor(),
                    #transforms.Lambda(lambda x: x / 255.0),
                    
                    #transforms.Normalize(mean, std)
                    ]),
            'test_tta':transforms.Compose(
                    [
                    transforms.RandomResizedCrop(config.img_size, scale=config.crop_scale_tta, antialias=True),
                    transforms.Resize((self.config.img_size, self.config.img_size), interpolation=Image.BILINEAR),
                    transforms.ToTensor(),
                    #transforms.Lambda(lambda x: x / 255.0),
                    #transforms.Normalize(mean, std)
                    ])
                }  # Define your transforms here

    def add_label_noise(self, labels, noise_level):
        """
        Adds noise to the labels based on the noise_level.
        noise_level: Probability of flipping the label.
        """
        noisy_labels = labels.copy()
        for i in range(len(noisy_labels)):
            if np.random.rand() < noise_level:
                noisy_labels[i] = 1 - noisy_labels[i]  # Flip the label
        return noisy_labels

    def setup(self, stage=None):
        if self.config.bagging == True:
            self.dataframe = self.dataframe.sample(n=self.config.bagging_size, replace=True, random_state=self.config.seed, ignore_index=True).reset_index(drop=True)
            
        # Apply label noise to the training dataset
        if self.config.noise > 0:
            self.dataframe['target'] = self.add_label_noise(self.dataframe['target'].values, self.config.noise)
            
        train_val_df, test_df = train_test_split(self.dataframe, test_size=self.config.test_size, random_state=self.config.seed, stratify=self.dataframe['target'])
        train_df, val_df = train_test_split(train_val_df, test_size=self.config.val_size / (self.config.train_size + self.config.val_size), random_state=self.config.seed, stratify=train_val_df['target'])
        self.train_dataset = CustomImageDataset(train_df, self.config.path, self.transform['train'])
        self.val_dataset = CustomImageDataset(val_df, self.config.path, self.transform['test'])
        
        
        # Fixed test set
        if self.config.fixed == True:
            test_df = pd.read_csv('/repo/uncertainty_skin/data/isic_balanced/test.csv').reset_index(drop=True)
            test_df['path'] = self.config.path
            test_df['target_name'] = self.config.target_name        
        
        self.test_dataset = CustomImageDataset(test_df, self.config.path, self.transform['test'])
        self.test_tta_dataset = CustomImageDataset(test_df, self.config.path, self.transform['test_tta'])
        
        print(self.config.bagging_size)
        print(self.dataframe.shape)
        
        print(train_df['target'].value_counts())
        print(val_df['target'].value_counts())
        print(test_df['target'].value_counts())
        
        print(test_df.index)
        
        #print(f'Train idx: {train_df.reset_index(drop=True).index}')
        #print(f'Dataframe shape: {train_df.shape}')
        #print(f'Dataframe head: {train_df.head()}')
        #print(f'Dataframe head: {train_df.reset_index(drop=True).head()}')
        
        labels = [int(label) for _, label in zip(train_df.index, train_df.target)]
        class_sample_count = [labels.count(i) for i in [0,1]]
        weight = 1. / torch.tensor(class_sample_count, dtype=torch.float)
        samples_weight = torch.tensor([weight[t] for t in labels])
        self.train_sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
        
        labels = [int(label) for _, label in zip(val_df.index, val_df.target)]
        class_sample_count = [labels.count(i) for i in [0,1]]
        weight = 1. / torch.tensor(class_sample_count, dtype=torch.float)
        samples_weight = torch.tensor([weight[t] for t in labels])
        self.val_sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
        
        labels = [int(label) for _, label in zip(test_df.index, test_df.target)]
        class_sample_count = [labels.count(i) for i in [0,1]]
        weight = 1. / torch.tensor(class_sample_count, dtype=torch.float)
        samples_weight = torch.tensor([weight[t] for t in labels])
        self.test_sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
        
        self.g = torch.Generator()
        self.g.manual_seed(self.config.seed)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.config.batch_size,
                          num_workers=self.config.num_workers,
                          sampler=self.train_sampler,
                          pin_memory=True,
                          worker_init_fn=self.config.seed % 2**32,
                          generator=self.g)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.config.batch_size,
                          num_workers=self.config.num_workers,
                          sampler=self.val_sampler,
                          pin_memory=True,
                          worker_init_fn=self.config.seed % 2**32,
                          generator=self.g)

    def test_dataloader(self):
        return DataLoader(self.test_tta_dataset,
                            batch_size=self.config.batch_size,
                            num_workers=self.config.num_workers,
                            sampler=self.test_sampler,
                            pin_memory=True,
                            worker_init_fn=self.config.seed % 2**32,
                            generator=self.g)
    def tta_dataloader(self):
        return DataLoader(self.test_dataset,
                            batch_size=self.config.batch_size,
                            num_workers=self.config.num_workers,
                            sampler=self.test_sampler,
                            pin_memory=True,
                            worker_init_fn=self.config.seed % 2**32,
                            generator=self.g)
     

# Model

In [44]:
from pytorch_metric_learning import distances, losses, miners, reducers, testers, regularizers

from src.functional.criterion import UANLLloss

import logging
import sys

from typing import List, Tuple, Dict, Any, Optional
from torch.utils.data import DataLoader

import numpy as np

from src.utils.metrics import calculate_ece, calculate_accuracy, calculate_f1_score_binary, certain_predictions, accuracy_tta, test_vis_tta, ttac, ttaWeightedPred, hist

from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt


logging.basicConfig(
    stream=sys.stderr, 
    level=logging.DEBUG, 
    format="%(asctime)s %(levelname)s: %(message)s"
)

class BaseModel(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.build_model()
        self.loss_fun, self.loss_module_1 = self.build_loss()
        self.mining_func = miners.TripletMarginMiner(margin=0.2, distance=distances.CosineSimilarity(), type_of_triplets="semihard")

        self.epoch_val_loss = []
        self.epoch_train_loss = []

    def build_model(self):
        raise NotImplementedError

    def build_loss(self):
        if self.config.model.loss_fun == 'TM+UANLL':
            distance = distances.CosineSimilarity()
            reducer = reducers.ThresholdReducer(low=0)
            eRegularizer = regularizers.LpRegularizer()
            loss_fun = losses.TripletMarginLoss(margin=self.config.model.margin, distance=distance, reducer=reducer, embedding_regularizer=eRegularizer)
            loss_module_1 = UANLLloss(smoothing=self.config.model.label_smoothing)
            print('UANLL loss is an additional loss term (module 1)')
        elif self.config.model.loss_fun == 'TM+CE':
            distance = distances.CosineSimilarity()
            reducer = reducers.ThresholdReducer(low=0)
            eRegularizer = regularizers.LpRegularizer()
            loss_fun = losses.TripletMarginLoss(margin=self.config.model.margin, distance=distance, reducer=reducer, embedding_regularizer=eRegularizer)
            loss_module_1 = nn.CrossEntropyLoss(label_smoothing=self.config.model.label_smoothing)
            print('CE loss is an additional loss term (module 1)')
        elif self.config.model.loss_fun == 'CE':
            loss_fun = None
            loss_module_1 = nn.CrossEntropyLoss(label_smoothing=self.config.model.label_smoothing)
            print('Single CE loss')
        elif self.config.model.loss_fun == 'UANLL':
            loss_fun = None
            loss_module_1 = UANLLloss(smoothing=self.config.model.label_smoothing)
            print('Single UANLL loss')
        else:
            raise ValueError(f"Unknown loss function: {self.config.model.loss_fun}")
        return loss_fun, loss_module_1

    def forward(self, x):
        return self.model(x)
    
    def lr_lambda(self, epoch, n=150, delay=30, stop_lr=0):
        n = self.config.trainer.max_epochs
        start_lr = self.config.model.lr
        learning_rate = start_lr if epoch < delay else start_lr - (epoch - delay) * (start_lr - stop_lr) / (n - 1 - delay)
        return learning_rate / start_lr

    def configure_optimizers(self):
        if self.config.model.optimizer_name == "Adam":
            optimizer = optim.AdamW(self.parameters(), **self.config.model.optimizer_hparams)
        elif self.config.model.optimizer_name == "SGD":
            optimizer = optim.SGD(self.parameters(), **self.config.model.optimizer_hparams)
        else:
            assert False, f'Unknown optimizer: "{self.config.model.optimizer_name}"'

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        embeddings = self(x)
        if self.loss_fun is not None:
            indices_tuple = self.mining_func(embeddings[:, :self.config.model.num_classes], y)
            loss = self.loss_fun(embeddings[:, :self.config.model.num_classes], y, indices_tuple)
        else:
            loss = 0
        loss += self.loss_module_1(embeddings, y)
        if self.config.model.loss_fun in ['TM+CE', 'CE']:
            preds = nn.functional.softmax(embeddings, 1)
            acc = (preds.argmax(dim=-1) == y).float().mean()
        elif self.config.model.loss_fun in ['TM+UANLL', 'UANLL']:
            preds = nn.functional.softmax(embeddings[:, :self.config.model.num_classes], 1)
            acc = (preds.argmax(dim=-1) == y).float().mean()
        self.log("train_loss", loss.float(), on_epoch=True)
        self.log("train_acc", acc, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        embeddings = self(x)
        if self.loss_fun is not None:
            indices_tuple = self.mining_func(embeddings[:, :self.config.model.num_classes], y)
            loss = self.loss_fun(embeddings[:, :self.config.model.num_classes], y, indices_tuple)
        else:
            loss = 0
        loss += self.loss_module_1(embeddings, y)
        if self.config.model.loss_fun in ['TM+CE', 'CE']:
            preds = nn.functional.softmax(embeddings[:, :self.config.model.num_classes], 1)
            acc = (preds.argmax(dim=-1) == y).float().mean()
        else:
            acc = (embeddings[:, :self.config.model.num_classes].argmax(dim=-1) == y).float().mean()
        self.log("val_loss", loss.float(), on_epoch=True)
        self.log("val_acc", acc, on_epoch=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        embeddings = self(x)
        if self.config.model.loss_fun in ['TM+CE', 'CE']:
            preds = nn.functional.softmax(embeddings[:, :self.config.model.num_classes], 1)
            acc = (preds.argmax(dim=-1) == y).float().mean()
        else:
            acc = (embeddings[:, :self.config.model.num_classes].argmax(dim=-1) == y).float().mean()
        print("Test set accuracy without TTA (Precision@1) = {}".format(acc))
        self.log("test_acc", acc, on_epoch=True)

    def on_test_epoch_end(self) -> None:
        """
        Perform predictions, TTA, ensembling, and weighted predictions at the end of the test epoch.
        """
        # Collect predictions without TTA
        test_predictions_no_tta, test_labels_no_tta = self.collect_predictions_no_tta()
        self.evaluate_no_tta(test_predictions_no_tta, test_labels_no_tta)

        # Perform TTA and ensembling
        test_predictions_tta, test_labels_tta, test_confidences_tta, test_uncertainties_tta = self.perform_tta_and_ensembling()
        self.evaluate_with_tta(test_predictions_tta, test_labels_tta)

        # Save predictions to CSV
        self.save_predictions_to_csv(test_predictions_no_tta, test_labels_no_tta, test_predictions_tta, test_labels_tta, test_confidences_tta, test_uncertainties_tta)

        # Handle weighted predictions based on confidence and certainty
        self.handle_weighted_predictions(test_predictions_tta, test_labels_tta, test_confidences_tta, test_uncertainties_tta)

        # Compare metrics for different approaches
        self.compare_metrics(test_predictions_no_tta, test_labels_no_tta, test_predictions_tta, test_labels_tta, test_confidences_tta, test_uncertainties_tta)

    def compare_metrics(self, test_predictions_no_tta: torch.Tensor, test_labels_no_tta: torch.Tensor, test_predictions_tta: torch.Tensor, test_labels_tta: torch.Tensor, test_confidences_tta: torch.Tensor, test_uncertainties_tta: torch.Tensor) -> None:
        """
        Compare metrics for different approaches:
        - Without TTA
        - TTAM (Mode-based TTA)
        - TTAWCo-S (Weighted with Confidences)
        - TTAWCe-S (Weighted with Certainties)
        - Ensembling (Simple)
        - Ensembling with Confidences
        - Ensembling with Certainties
        - Ensembling with TTA

        Args:
            test_predictions_no_tta (torch.Tensor): Predictions without TTA.
            test_labels_no_tta (torch.Tensor): True labels without TTA.
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_labels_tta (torch.Tensor): True labels with TTA.
            test_confidences_tta (torch.Tensor): Confidences with TTA.
            test_uncertainties_tta (torch.Tensor): Uncertainties with TTA.
        """
        #mode_predictions_tta = self.mode_based_tta(test_predictions_tta)
        
        # Convert tensors to numpy for sklearn metrics
        test_labels_no_tta = test_labels_no_tta.cpu().numpy()
        test_labels_tta = test_labels_tta.cpu().numpy()
        test_predictions_no_tta = test_predictions_no_tta.argmax(dim=1).cpu().numpy()
        test_predictions_tta = test_predictions_tta.argmax(dim=1).cpu().numpy()
        test_confidences_tta = test_confidences_tta.cpu().numpy()
        test_uncertainties_tta = test_uncertainties_tta.cpu().numpy()

        # 1. Without TTA
        accuracy_no_tta = accuracy_score(test_labels_no_tta, test_predictions_no_tta)
        f1_no_tta = f1_score(test_labels_no_tta, test_predictions_no_tta, average='weighted')
        roc_auc_no_tta = roc_auc_score(test_labels_no_tta, test_predictions_no_tta, average='weighted', multi_class='ovr')

        # 2. TTAM (Mode-based TTA)
        # accuracy_ttam = accuracy_score(test_labels_tta, mode_predictions_tta)
        # f1_ttam = f1_score(test_labels_tta, mode_predictions_tta, average='weighted')
        # roc_auc_ttam = roc_auc_score(test_labels_tta, mode_predictions_tta, average='weighted', multi_class='ovr')
        accuracy_ttam = 0
        f1_ttam = 0
        roc_auc_ttam = 0

        # 3. TTAWCo-S (Weighted with Confidences)
        weighted_predictions_co = self.weighted_predictions_with_confidence(test_predictions_tta, test_confidences_tta)
        accuracy_tta_co = accuracy_score(test_labels_tta, weighted_predictions_co)
        f1_tta_co = f1_score(test_labels_tta, weighted_predictions_co, average='weighted')
        roc_auc_tta_co = roc_auc_score(test_labels_tta, weighted_predictions_co, average='weighted', multi_class='ovr')

        # 4. TTAWCe-S (Weighted with Certainties)
        weighted_predictions_ce = self.weighted_predictions_with_certainty(test_predictions_tta, test_uncertainties_tta)
        accuracy_tta_ce = accuracy_score(test_labels_tta, weighted_predictions_ce)
        f1_tta_ce = f1_score(test_labels_tta, weighted_predictions_ce, average='weighted')
        roc_auc_tta_ce = roc_auc_score(test_labels_tta, weighted_predictions_ce, average='weighted', multi_class='ovr')

        # 5. Ensembling (Simple)
        ensembled_predictions = self.ensemble_predictions(test_predictions_tta)
        accuracy_ensemble = accuracy_score(test_labels_tta, ensembled_predictions)
        f1_ensemble = f1_score(test_labels_tta, ensembled_predictions, average='weighted')
        roc_auc_ensemble = roc_auc_score(test_labels_tta, ensembled_predictions, average='weighted', multi_class='ovr')

        # 6. Ensembling with Confidences
        ensembled_predictions_co = self.ensemble_predictions_with_confidence(test_predictions_tta, test_confidences_tta)
        accuracy_ensemble_co = accuracy_score(test_labels_tta, ensembled_predictions_co)
        f1_ensemble_co = f1_score(test_labels_tta, ensembled_predictions_co, average='weighted')
        roc_auc_ensemble_co = roc_auc_score(test_labels_tta, ensembled_predictions_co, average='weighted', multi_class='ovr')

        # 7. Ensembling with Certainties
        ensembled_predictions_ce = self.ensemble_predictions_with_certainty(test_predictions_tta, test_uncertainties_tta)
        accuracy_ensemble_ce = accuracy_score(test_labels_tta, ensembled_predictions_ce)
        f1_ensemble_ce = f1_score(test_labels_tta, ensembled_predictions_ce, average='weighted')
        roc_auc_ensemble_ce = roc_auc_score(test_labels_tta, ensembled_predictions_ce, average='weighted', multi_class='ovr')

        # 8. Ensembling with TTA
        ensembled_tta_predictions = self.ensemble_tta_predictions(test_predictions_tta)
        accuracy_ensemble_tta = accuracy_score(test_labels_tta, ensembled_tta_predictions)
        f1_ensemble_tta = f1_score(test_labels_tta, ensembled_tta_predictions, average='weighted')
        roc_auc_ensemble_tta = roc_auc_score(test_labels_tta, ensembled_tta_predictions, average='weighted', multi_class='ovr')

        # Create a metrics comparison table
        metrics_table = pd.DataFrame({
            'Approach': ['Without TTA', 'TTAM', 'TTAWCo-S', 'TTAWCe-S', 'Ensembling (Simple)', 'Ensembling with Confidences', 'Ensembling with Certainties', 'Ensembling with TTA'],
            'Accuracy': [accuracy_no_tta, accuracy_ttam, accuracy_tta_co, accuracy_tta_ce, accuracy_ensemble, accuracy_ensemble_co, accuracy_ensemble_ce, accuracy_ensemble_tta],
            'F1 Score': [f1_no_tta, f1_ttam, f1_tta_co, f1_tta_ce, f1_ensemble, f1_ensemble_co, f1_ensemble_ce, f1_ensemble_tta],
            'ROC-AUC': [roc_auc_no_tta, roc_auc_ttam, roc_auc_tta_co, roc_auc_tta_ce, roc_auc_ensemble, roc_auc_ensemble_co, roc_auc_ensemble_ce, roc_auc_ensemble_tta]
        })

        # Save the metrics table to a CSV file
        metrics_table.to_csv(f"{self.config.model.name}_{self.config.dataset.seed}_metrics_comparison.csv", index=False)

        # Print the metrics table
        print(metrics_table)

    def mode_based_tta(self, test_predictions_tta: torch.Tensor) -> np.ndarray:
        """
        Compute mode-based TTA predictions.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.

        Returns:
            np.ndarray: Mode-based TTA predictions.
        """
        return torch.mode(test_predictions_tta, dim=0).values.cpu().numpy()

    def weighted_predictions_with_confidence(self, test_predictions_tta: torch.Tensor, test_confidences_tta: torch.Tensor) -> np.ndarray:
        """
        Compute weighted predictions based on confidence.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_confidences_tta (torch.Tensor): Confidences with TTA.

        Returns:
            np.ndarray: Weighted predictions based on confidence.
        """
        weighted_predictions_co = test_predictions_tta
        weighted_predictions_co[test_confidences_tta < 0.5] = -1  # Ignore low-confidence predictions
        return weighted_predictions_co
    
    def weighted_predictions_with_certainty(self, test_predictions_tta: torch.Tensor, test_uncertainties_tta: torch.Tensor) -> np.ndarray:
        """
        Compute weighted predictions based on certainty.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_uncertainties_tta (torch.Tensor): Uncertainties with TTA.

        Returns:
            np.ndarray: Weighted predictions based on certainty.
        """
        weighted_predictions_ce = test_predictions_tta
        weighted_predictions_ce[test_uncertainties_tta > 0.5] = -1  # Ignore high-uncertainty predictions
        return weighted_predictions_ce

    def ensemble_predictions(self, test_predictions_tta: torch.Tensor) -> np.ndarray:
        """
        Compute simple ensembled predictions.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.

        Returns:
            np.ndarray: Ensembled predictions.
        """
        return test_predictions_tta.mean(dim=0).argmax(dim=1).cpu().numpy()

    def ensemble_predictions_with_confidence(self, test_predictions_tta: torch.Tensor, test_confidences_tta: torch.Tensor) -> np.ndarray:
        """
        Compute ensembled predictions weighted by confidence.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_confidences_tta (torch.Tensor): Confidences with TTA.

        Returns:
            np.ndarray: Ensembled predictions weighted by confidence.
        """
        weighted_predictions = test_predictions_tta * test_confidences_tta.unsqueeze(1)
        return weighted_predictions.sum(dim=0).argmax(dim=1).cpu().numpy()

    def ensemble_predictions_with_certainty(self, test_predictions_tta: torch.Tensor, test_uncertainties_tta: torch.Tensor) -> np.ndarray:
        """
        Compute ensembled predictions weighted by certainty.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_uncertainties_tta (torch.Tensor): Certainties with TTA.

        Returns:
            np.ndarray: Ensembled predictions weighted by certainty.
        """
        weighted_predictions = test_predictions_tta * test_uncertainties_tta.unsqueeze(1)
        return weighted_predictions.sum(dim=0).argmax(dim=1).cpu().numpy()

    def ensemble_tta_predictions(self, test_predictions_tta: torch.Tensor) -> np.ndarray:
        """
        Compute ensembled predictions with TTA.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.

        Returns:
            np.ndarray: Ensembled predictions with TTA.
        """
        return test_predictions_tta.mean(dim=0).argmax(dim=1).cpu().numpy()

    def collect_predictions_no_tta(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Collect predictions without Test-Time Augmentation (TTA).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Predictions and labels without TTA.
        """
        test_predictions_no_tta, test_labels_no_tta = [], []
        for batch in self.trainer.datamodule.test_dataloader():
            inputs, labels = batch
            inputs = inputs.to(self.device)
            outputs = self(inputs)

            # Handle model output as in test_step
            if self.config.model.loss_fun in ['TM+CE', 'CE']:
                preds = torch.softmax(outputs[:, :self.config.model.num_classes], dim=1)
            else:
                preds = torch.softmax(outputs[:, :self.config.model.num_classes], dim=1)

            test_predictions_no_tta.append(preds)
            test_labels_no_tta.append(labels)

        test_predictions_no_tta = torch.cat(test_predictions_no_tta)
        test_labels_no_tta = torch.cat(test_labels_no_tta)
        return test_predictions_no_tta, test_labels_no_tta

    def evaluate_no_tta(self, test_predictions_no_tta: torch.Tensor, test_labels_no_tta: torch.Tensor) -> None:
        """
        Evaluate predictions without TTA.

        Args:
            test_predictions_no_tta (torch.Tensor): Predictions without TTA.
            test_labels_no_tta (torch.Tensor): True labels without TTA.
        """
        test_predictions_no_tta = test_predictions_no_tta.argmax(dim=1).cpu()
        test_labels_no_tta = test_labels_no_tta.squeeze().cpu()

        accuracy_no_tta = accuracy_score(test_labels_no_tta, test_predictions_no_tta)
        f1_no_tta = f1_score(test_labels_no_tta, test_predictions_no_tta, average='weighted')
        roc_auc_no_tta = roc_auc_score(test_labels_no_tta, test_predictions_no_tta, average='weighted', multi_class='ovr')

        print(f"Metrics without TTA: Accuracy={accuracy_no_tta}, F1={f1_no_tta}, ROC-AUC={roc_auc_no_tta}")

    def perform_tta_and_ensembling(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Perform Test-Time Augmentation (TTA) and ensembling.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                Ensembled predictions, labels, confidences, and uncertainties.
        """
        num_tta = self.config.dataset.num_tta
        checkpoint_paths = list(self.config.model.checkpoint_path)

        all_test_predictions_tta = []
        all_test_labels_tta = []
        all_test_confidences_tta = []
        all_test_uncertainties_tta = []

        for checkpoint_path in checkpoint_paths:
            # Reinitialize the model and datamodule for each checkpoint
            self.build_model()
            self.load_state_dict(torch.load(checkpoint_path, weights_only=True)['state_dict'])
            self.to(self.device)
            self.eval()

            datamodule = self.trainer.datamodule  # Reinitialize datamodule if needed
            test_loader = datamodule.test_dataloader()

            test_predictions_tta, test_labels_tta, test_confidences_tta, test_uncertainties_tta = self.collect_tta_predictions(self, test_loader, num_tta)
            all_test_predictions_tta.append(test_predictions_tta)
            all_test_labels_tta.append(test_labels_tta)
            all_test_confidences_tta.append(test_confidences_tta)
            all_test_uncertainties_tta.append(test_uncertainties_tta)

        # Ensemble predictions
        all_test_predictions_tta = torch.stack(all_test_predictions_tta).mean(dim=0)
        all_test_labels_tta = torch.stack(all_test_labels_tta).mode(dim=0).values
        all_test_confidences_tta = torch.stack(all_test_confidences_tta).mean(dim=0)
        all_test_uncertainties_tta = torch.stack(all_test_uncertainties_tta).mean(dim=0)

        return all_test_predictions_tta, all_test_labels_tta, all_test_confidences_tta, all_test_uncertainties_tta

    def collect_tta_predictions(self, model: torch.nn.Module, test_loader: DataLoader, num_tta: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Collect predictions with Test-Time Augmentation (TTA) for a single checkpoint.

        Args:
            model (torch.nn.Module): The model to use for predictions.
            test_loader (DataLoader): The test dataloader.
            num_tta (int): Number of TTA iterations.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                Predictions, labels, confidences, and uncertainties with TTA.
        """
        test_predictions_tta = []
        test_labels_tta = []
        test_confidences_tta = []
        test_uncertainties_tta = []

        for _ in range(num_tta):
            for batch in test_loader:
                inputs, labels = batch
                inputs = inputs.to(self.device)
                outputs = model(inputs)

                # Handle model output as in test_step
                if self.config.model.loss_fun in ['TM+CE', 'CE']:
                    preds = torch.softmax(outputs[:, :self.config.model.num_classes], dim=1)
                else:
                    preds = torch.softmax(outputs[:, :self.config.model.num_classes], dim=1)

                confidences = preds.max(dim=1).values
                uncertainties = outputs[:, -1] if self.config.model.loss_fun in ['UANLL', 'TM+UANLL'] else torch.zeros_like(confidences)

                test_predictions_tta.append(preds)
                test_labels_tta.append(labels)
                test_confidences_tta.append(confidences)
                test_uncertainties_tta.append(uncertainties)

        test_predictions_tta = torch.cat(test_predictions_tta)
        test_labels_tta = torch.cat(test_labels_tta)
        test_confidences_tta = torch.cat(test_confidences_tta)
        test_uncertainties_tta = torch.cat(test_uncertainties_tta)

        return test_predictions_tta, test_labels_tta, test_confidences_tta, test_uncertainties_tta

    def evaluate_with_tta(self, test_predictions_tta: torch.Tensor, test_labels_tta: torch.Tensor) -> None:
        """
        Evaluate predictions with TTA and ensembling.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_labels_tta (torch.Tensor): True labels with TTA.
        """
        print(test_predictions_tta)
        test_predictions_tta = test_predictions_tta.argmax(dim=1).cpu()
        test_labels_tta = test_labels_tta.squeeze().cpu()

        accuracy_tta = accuracy_score(test_labels_tta, test_predictions_tta)
        f1_tta = f1_score(test_labels_tta, test_predictions_tta, average='weighted')
        roc_auc_tta = roc_auc_score(test_labels_tta, test_predictions_tta, average='weighted', multi_class='ovr')

        print(f"Metrics with TTA: Accuracy={accuracy_tta}, F1={f1_tta}, ROC-AUC={roc_auc_tta}")

    def save_predictions_to_csv(self, test_predictions_no_tta: torch.Tensor, test_labels_no_tta: torch.Tensor, test_predictions_tta: torch.Tensor, test_labels_tta: torch.Tensor, test_confidences_tta: torch.Tensor, test_uncertainties_tta: torch.Tensor) -> None:
        """
        Save predictions (without TTA and with TTA) to CSV files.

        Args:
            test_predictions_no_tta (torch.Tensor): Predictions without TTA.
            test_labels_no_tta (torch.Tensor): True labels without TTA.
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_labels_tta (torch.Tensor): True labels with TTA.
            test_confidences_tta (torch.Tensor): Confidences with TTA.
            test_uncertainties_tta (torch.Tensor): Uncertainties with TTA.
        """
        # Save predictions without TTA
        no_tta_df = pd.DataFrame({
            'true_labels': test_labels_no_tta.tolist(),
            'predictions': test_predictions_no_tta.argmax(dim=1).tolist()
        })
        no_tta_df.to_csv(f"{self.config.model.name}_{self.config.dataset.seed}_predictions_no_tta.csv", index=False)

        # Save predictions with TTA
        tta_df = pd.DataFrame({
            'true_labels': test_labels_tta.tolist(),
            'predictions': test_predictions_tta.argmax(dim=1).tolist(),
            'confidences': test_confidences_tta.tolist(),
            'uncertainties': test_uncertainties_tta.tolist()
        })
        tta_df.to_csv(f"{self.config.model.name}_{self.config.dataset.seed}_predictions_tta.csv", index=False)

    def handle_weighted_predictions(self, test_predictions_tta: torch.Tensor, test_labels_tta: torch.Tensor, test_confidences_tta: torch.Tensor, test_uncertainties_tta: torch.Tensor) -> None:
        """
        Handle weighted predictions based on confidence and certainty.

        Args:
            test_predictions_tta (torch.Tensor): Predictions with TTA.
            test_labels_tta (torch.Tensor): True labels with TTA.
            test_confidences_tta (torch.Tensor): Confidences with TTA.
            test_uncertainties_tta (torch.Tensor): Uncertainties with TTA.
        """
        # Weighted predictions based on confidence
        weighted_predictions_co = test_predictions_tta.argmax(dim=1)
        weighted_predictions_co[test_confidences_tta < 0.5] = -1  # Example: Ignore low-confidence predictions

        # Weighted predictions based on certainty
        weighted_predictions_ce = test_predictions_tta.argmax(dim=1)
        weighted_predictions_ce[test_uncertainties_tta > 0.5] = -1  # Example: Ignore high-uncertainty predictions

        # Evaluate weighted predictions
        self.evaluate_weighted_predictions(weighted_predictions_co, weighted_predictions_ce, test_labels_tta)

    def evaluate_weighted_predictions(self, weighted_predictions_co: torch.Tensor, weighted_predictions_ce: torch.Tensor, test_labels_tta: torch.Tensor) -> None:
        """
        Evaluate weighted predictions based on confidence and certainty.

        Args:
            weighted_predictions_co (torch.Tensor): Weighted predictions based on confidence.
            weighted_predictions_ce (torch.Tensor): Weighted predictions based on certainty.
            test_labels_tta (torch.Tensor): True labels with TTA.
        """
        # Filter out ignored predictions
        valid_indices_co = weighted_predictions_co != -1
        valid_indices_ce = weighted_predictions_ce != -1
        
        valid_indices_ce = valid_indices_ce.cpu().numpy()
        valid_indices_co = valid_indices_co.cpu().numpy()
        
        test_labels_tta = test_labels_tta.cpu().numpy()
        weighted_predictions_co = weighted_predictions_co.cpu().numpy()
        weighted_predictions_ce = weighted_predictions_ce.cpu().numpy()

        # Evaluate confidence-weighted predictions
        accuracy_co = accuracy_score(test_labels_tta[valid_indices_co], weighted_predictions_co[valid_indices_co])
        f1_co = f1_score(test_labels_tta[valid_indices_co], weighted_predictions_co[valid_indices_co], average='weighted')
        roc_auc_co = roc_auc_score(test_labels_tta[valid_indices_co], weighted_predictions_co[valid_indices_co], average='weighted', multi_class='ovr')

        # Evaluate certainty-weighted predictions
        accuracy_ce = accuracy_score(test_labels_tta[valid_indices_ce], weighted_predictions_ce[valid_indices_ce])
        f1_ce = f1_score(test_labels_tta[valid_indices_ce], weighted_predictions_ce[valid_indices_ce], average='weighted')
        roc_auc_ce = roc_auc_score(test_labels_tta[valid_indices_ce], weighted_predictions_ce[valid_indices_ce], average='weighted', multi_class='ovr')

        print(f"Confidence-Weighted Metrics: Accuracy={accuracy_co}, F1={f1_co}, ROC-AUC={roc_auc_co}")
        print(f"Certainty-Weighted Metrics: Accuracy={accuracy_ce}, F1={f1_ce}, ROC-AUC={roc_auc_ce}")

In [45]:
import timm
import torch.nn as nn

class TimmModel(BaseModel):
    def build_model(self):
        if self.config.model.loss_fun in ['TM+UANLL', 'UANLL']:
            num_classes = self.config.model.num_classes + 1
        else:
            num_classes = self.config.model.num_classes

        self.model = timm.create_model(self.config.model.name, pretrained=self.config.model.pretrained, num_classes=num_classes)
        #self.model.conv1 = nn.Conv2d(self.config.model.input_channel, self.model.conv1.out_channels, kernel_size=self.model.conv1.kernel_size, 
        #                stride=self.model.conv1.stride, padding=self.model.conv1.padding, bias=False)
        
        return self.model

    def forward(self, x):
        return self.model(x)

    def intermediate_forward(self, x):
        # Assuming the last layer before the classifier is the penultimate layer
        for name, module in self.model.named_children():
            if name == 'head':
                break
            x = module(x)
        return x

In [46]:
import torch.nn as nn
import torch.nn.functional as F

class CNN(BaseModel):
    def build_model(self):
        input_channel = self.config.model.input_channel
        if self.config.model.loss_fun in ['TM+UANLL', 'UANLL']:
            print('CNN. Uncertainty toggled')
            n_outputs = self.config.model.num_classes + 1
        else:
            n_outputs = self.config.model.num_classes

        dropout_rate = self.config.model.dropout_rate
        top_bn = self.config.model.top_bn

        self.dropout_rate = dropout_rate
        self.top_bn = top_bn
        self.c1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.c3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.c4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.c5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.c6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.c7 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0)
        self.c8 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
        self.c9 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)
        self.l_c1 = nn.Linear(128, n_outputs)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(256)
        self.bn6 = nn.BatchNorm2d(256)
        self.bn7 = nn.BatchNorm2d(512)
        self.bn8 = nn.BatchNorm2d(256)
        self.bn9 = nn.BatchNorm2d(128)
        
        # Define batch normalization for the linear layer if needed
        if self.top_bn:
            self.bn_c1 = nn.BatchNorm1d(n_outputs)

    def call_bn(self, bn, x):
        return bn(x)

    def forward(self, x):
        h = x
        h = self.c1(h)
        h = F.leaky_relu(self.call_bn(self.bn1, h), negative_slope=0.01)
        h = self.c2(h)
        h = F.leaky_relu(self.call_bn(self.bn2, h), negative_slope=0.01)
        h = self.c3(h)
        h = F.leaky_relu(self.call_bn(self.bn3, h), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.dropout2d(h, p=self.dropout_rate)

        h = self.c4(h)
        h = F.leaky_relu(self.call_bn(self.bn4, h), negative_slope=0.01)
        h = self.c5(h)
        h = F.leaky_relu(self.call_bn(self.bn5, h), negative_slope=0.01)
        h = self.c6(h)
        h = F.leaky_relu(self.call_bn(self.bn6, h), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.dropout2d(h, p=self.dropout_rate)

        h = self.c7(h)
        h = F.leaky_relu(self.call_bn(self.bn7, h), negative_slope=0.01)
        h = self.c8(h)
        h = F.leaky_relu(self.call_bn(self.bn8, h), negative_slope=0.01)
        h = self.c9(h)
        h = F.leaky_relu(self.call_bn(self.bn9, h), negative_slope=0.01)
        h = F.avg_pool2d(h, kernel_size=h.data.shape[2])

        h = h.view(h.size(0), h.size(1))
        logit = self.l_c1(h)
        if self.top_bn:
            logit = self.call_bn(self.bn_c1, logit)
        return logit

    def intermediate_forward(self, x):
        h = x
        h = self.c1(h)
        h = F.leaky_relu(self.call_bn(self.bn1, h), negative_slope=0.01)
        h = self.c2(h)
        h = F.leaky_relu(self.call_bn(self.bn2, h), negative_slope=0.01)
        h = self.c3(h)
        h = F.leaky_relu(self.call_bn(self.bn3, h), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.dropout2d(h, p=self.dropout_rate)

        h = self.c4(h)
        h = F.leaky_relu(self.call_bn(self.bn4, h), negative_slope=0.01)
        h = self.c5(h)
        h = F.leaky_relu(self.call_bn(self.bn5, h), negative_slope=0.01)
        h = self.c6(h)
        h = F.leaky_relu(self.call_bn(self.bn6, h), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.dropout2d(h, p=self.dropout_rate)

        h = self.c7(h)
        h = F.leaky_relu(self.call_bn(self.bn7, h), negative_slope=0.01)
        h = self.c8(h)
        h = F.leaky_relu(self.call_bn(self.bn8, h), negative_slope=0.01)
        h = self.c9(h)
        h = F.leaky_relu(self.call_bn(self.bn9, h), negative_slope=0.01)
        h = F.avg_pool2d(h, kernel_size=h.data.shape[2])

        h = h.view(h.size(0), h.size(1))
        return h

In [47]:
import hydra
from omegaconf import DictConfig
import pytorch_lightning as pl
from src.utils.clearml_logger import ClearMLLogger
from src.utils.utils import set_seed
import torch
import os
import pandas as pd

import logging
import sys


logging.basicConfig(
    stream=sys.stderr, 
    level=logging.DEBUG, 
    format="%(asctime)s %(levelname)s: %(message)s"
)

def test(config: DictConfig,
         seed: int,
         checkpoint_path: str,
         logger: callable,
         unique_id: str):
    
    set_seed(seed)
    data_module = ISICDataModule(config.dataset)
    data_module.setup()
    
    config.dataset.seed = seed

    print(f"Testing model trained on seed: {seed}")

    if config.model.name == 'CNN':
        model = CNN.load_from_checkpoint(checkpoint_path, config=config)
    elif config.model.name in config.timm_models:
        model = TimmModel.load_from_checkpoint(checkpoint_path[0], config=config)
    else:
        raise ValueError(f"Unknown model: {config.model.name}")

    trainer = pl.Trainer(**config.trainer,
                         logger=logger,
                         deterministic=True)

    trainer.test(model, datamodule=data_module)

    # Load the summary CSV generated by the model
    #summary_df = pd.read_csv(f"{config.model.name}_{config.dataset.seed}_summary.csv")
    #return summary_df

# Testing

In [48]:
hydra.core.global_hydra.GlobalHydra.instance().clear()

In [49]:
import hydra
from hydra import compose, initialize
from src.utils.clearml_logger import ClearMLLogger
from omegaconf import OmegaConf, DictConfig
import uuid
import logging 
import pandas as pd

logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)

def main():

    # global initialization
    with initialize(version_base=None, config_path="configs", job_name="test_app"):
        config = compose(config_name="config", overrides=["+experiment=tm_ce_resnet"])
        print(OmegaConf.to_yaml(config))
    
    all_predictions_df = pd.DataFrame()
    
    logging.info(f'Checkpoint path: {config.model.checkpoint_path[0]}')
    #checkpoint_path = config.model.checkpoint_path
    
    #logger = ClearMLLogger(project_name="ISIC_2024",
        #                   task_name=f"{config.model.name}_{seed}_{config.model.loss_fun}_testing",
    #                    offline=config.offline)
    
    # _, predictions_df = test(config = config,
    #                         seed = 42,
    #                         checkpoint_path = config.model.checkpoint_path,
    #                         logger = None,
    #                         unique_id=uuid.uuid4)
    
    # all_predictions_df = pd.concat([all_predictions_df, predictions_df])
                        
        

if __name__ == "__main__":
    main()

2024-12-24 16:06:42,876 DEBUG: Setting JobRuntime:name=test_app
2024-12-24 16:06:42,917 INFO: Checkpoint path: /repo/uncertainty_skin/multirun/checkpoints_multirun/resnet18_0_TM+CE_61539e93-55c3-43e5-aca5-4e35c95aa493_epoch=99.ckpt


device: cuda
timm_models:
- resnet50
- resnet18
offline: false
clearml_model_id: a65ecaf5e7644d8398383cfc67385ad6
dataset:
  name: isic_balanced
  path: /repo/uncertainty_skin/data/isic_balanced/images
  annotations_path: /repo/uncertainty_skin/data/isic_balanced/dataset.csv
  img_size: 224
  crop_scale:
  - 0.8
  - 1.0
  crop_scale_tta:
  - 0.8
  - 1.0
  batch_size: 32
  num_workers: 0
  target_name: target
  class_names:
  - benign
  - malignant
  train_size: 0.7
  val_size: 0.15
  test_size: 0.15
  seed: 42
  bagging: true
  bagging_size: 15000
  num_tta: 100
  noise: 0.0
  fixed: false
model:
  name: resnet18
  num_classes: 2
  pretrained: true
  input_channel: 3
  dropout_rate: 0.25
  top_bn: false
  loss_fun: TM+CE
  label_smoothing: 0.1
  lr: 0.0001
  checkpoint_path:
  - /repo/uncertainty_skin/multirun/checkpoints_multirun/resnet18_0_TM+CE_61539e93-55c3-43e5-aca5-4e35c95aa493_epoch=99.ckpt
  - /repo/uncertainty_skin/multirun/checkpoints_multirun/resnet18_3_TM+CE_61539e93-55c3-4

In [50]:
with initialize(version_base=None, config_path="configs", job_name="test_app"):
    config = compose(config_name="config", overrides=["+experiment=tm_ce_resnet"])
    print(OmegaConf.to_yaml(config))

2024-12-24 16:06:42,931 DEBUG: Setting JobRuntime:name=test_app


device: cuda
timm_models:
- resnet50
- resnet18
offline: false
clearml_model_id: a65ecaf5e7644d8398383cfc67385ad6
dataset:
  name: isic_balanced
  path: /repo/uncertainty_skin/data/isic_balanced/images
  annotations_path: /repo/uncertainty_skin/data/isic_balanced/dataset.csv
  img_size: 224
  crop_scale:
  - 0.8
  - 1.0
  crop_scale_tta:
  - 0.8
  - 1.0
  batch_size: 32
  num_workers: 0
  target_name: target
  class_names:
  - benign
  - malignant
  train_size: 0.7
  val_size: 0.15
  test_size: 0.15
  seed: 42
  bagging: true
  bagging_size: 15000
  num_tta: 100
  noise: 0.0
  fixed: false
model:
  name: resnet18
  num_classes: 2
  pretrained: true
  input_channel: 3
  dropout_rate: 0.25
  top_bn: false
  loss_fun: TM+CE
  label_smoothing: 0.1
  lr: 0.0001
  checkpoint_path:
  - /repo/uncertainty_skin/multirun/checkpoints_multirun/resnet18_0_TM+CE_61539e93-55c3-43e5-aca5-4e35c95aa493_epoch=99.ckpt
  - /repo/uncertainty_skin/multirun/checkpoints_multirun/resnet18_3_TM+CE_61539e93-55c3-4

In [51]:
model = TimmModel(config)

2024-12-24 16:06:43,035 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)
2024-12-24 16:06:43,036 DEBUG: Resetting dropped connection: huggingface.co
2024-12-24 16:06:43,438 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:06:43,443 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


CE loss is an additional loss term (module 1)


In [52]:
datamodule = ISICDataModule(config.dataset)
datamodule.setup()

15000
(15000, 8)
target
0.0    5361
1.0    5139
Name: count, dtype: int64
target
0.0    1149
1.0    1101
Name: count, dtype: int64
target
0.0    1149
1.0    1101
Name: count, dtype: int64
Index([ 2442,  3224,   706, 12004,  9303,  5278, 13743, 13692,  4484,  4540,
       ...
       12959,  1779,   304, 14470,  3802,  1981, 10774,  3524,  9400, 12538],
      dtype='int64', length=2250)


## Single

In [53]:
device = torch.device('cuda')

In [54]:
models = []

# Load each checkpoint and store the model
for checkpoint_path in config.model.checkpoint_path:
    # Create a new instance of the model
    model = TimmModel(config)
    model.load_state_dict(torch.load(checkpoint_path, weights_only=True)['state_dict'])  # Replace with your actual model class
    
    # Set the model to evaluation mode
    model.eval()
    
    # Move the model to GPU
    model = model.to('cuda')
    
    # Add the model to the list
    models.append(model)

2024-12-24 16:06:43,609 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)
2024-12-24 16:06:43,791 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:06:43,792 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:06:43,908 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:06:44,099 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:06:44,104 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:06:44,215 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:06:44,407 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:06:44,412 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:06:44,516 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:06:44,716 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:06:44,720 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:06:44,822 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:06:45,006 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:06:45,007 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


CE loss is an additional loss term (module 1)


In [55]:
from torchmetrics import Accuracy, F1Score, AUROC
from tqdm import tqdm

# Initialize lists to store metrics for individual models
individual_accuracies = []
individual_f1_scores = []
individual_rocaucs = []

# Test loop for individual models
with torch.no_grad():
    for i, model in enumerate(models):
        # Reset metrics for each model
        accuracy = Accuracy(task="multiclass", num_classes=2).to(device)
        f1 = F1Score(task="multiclass", num_classes=2).to(device)
        rocauc = AUROC(task="multiclass", num_classes=2).to(device)
        
        for image, label in datamodule.test_dataloader():
            image = image.to(device)
            label = label.to(device)
            
            # Get predictions for the current model
            predictions = model(image)
            _, predicted_labels = torch.max(predictions, dim=1)
            
            # Update metrics
            accuracy.update(predicted_labels, label)
            f1.update(predicted_labels, label)
            rocauc.update(predictions, label)  # ROC-AUC uses raw predictions
        
        # Compute and store metrics
        individual_accuracies.append(accuracy.compute())
        individual_f1_scores.append(f1.compute())
        individual_rocaucs.append(rocauc.compute())

# Compute mean and standard deviation for individual metrics
mean_individual_accuracy = torch.mean(torch.tensor(individual_accuracies))
std_individual_accuracy = torch.std(torch.tensor(individual_accuracies))

mean_individual_f1 = torch.mean(torch.tensor(individual_f1_scores))
std_individual_f1 = torch.std(torch.tensor(individual_f1_scores))

mean_individual_rocauc = torch.mean(torch.tensor(individual_rocaucs))
std_individual_rocauc = torch.std(torch.tensor(individual_rocaucs))

print(f"Individual Metrics:")
print(f"  Accuracy: Mean = {mean_individual_accuracy:.4f}, Std = {std_individual_accuracy:.4f}")
print(f"  F1 Score: Mean = {mean_individual_f1:.4f}, Std = {std_individual_f1:.4f}")
print(f"  ROC-AUC: Mean = {mean_individual_rocauc:.4f}, Std = {std_individual_rocauc:.4f}")



Individual Metrics:
  Accuracy: Mean = 0.8052, Std = 0.1778
  F1 Score: Mean = 0.8052, Std = 0.1778
  ROC-AUC: Mean = 0.8469, Std = 0.1943


## TTA

In [56]:
from tqdm import tqdm

# Initialize lists to store metrics for TTA
tta_accuracies = []
tta_f1_scores = []
tta_rocaucs = []

#num_tta = config.dataset.num_tta
num_tta = 10

# Test loop for TTA
with torch.no_grad():
    for i, model in enumerate(models):
        # Reset metrics for each model
        tta_accuracy = Accuracy(task="multiclass", num_classes=2).to(device)
        tta_f1 = F1Score(task="multiclass", num_classes=2).to(device)
        tta_rocauc = AUROC(task="multiclass", num_classes=2).to(device)
        
        for image, label in tqdm(datamodule.tta_dataloader()):
            image = image.to(device)
            label = label.to(device)
            
            # Initialize a tensor to store predictions for all TTA iterations
            all_tta_predictions = []
            
            # Apply TTA for `num_tta` iterations
            for _ in range(num_tta):
                # Apply augmentation to the input image
                # Get predictions for the current model (or ensemble)
                predictions = model(image) 
                
                # Store predictions
                all_tta_predictions.append(predictions)
            
            # Stack predictions from all TTA iterations
            all_tta_predictions = torch.stack(all_tta_predictions, dim=0)  # Shape: (num_tta, batch_size, num_classes)
            
            # Combine predictions (e.g., average them)
            tta_prediction = torch.mean(all_tta_predictions, dim=0)  # Average over TTA iterations
            
            # Get the predicted class labels
            _, predicted_labels = torch.max(tta_prediction, dim=1)
            
            # Update metrics
            tta_accuracy.update(predicted_labels, label)
            tta_f1.update(predicted_labels, label)
            tta_rocauc.update(tta_prediction, label)  # ROC-AUC uses raw predictions

# Compute final TTA metrics
final_tta_accuracy = tta_accuracy.compute()
final_tta_f1 = tta_f1.compute()
final_tta_rocauc = tta_rocauc.compute()

print(f"TTA Metrics (num_tta={num_tta}):")
print(f"  Accuracy: {final_tta_accuracy:.4f}")
print(f"  F1 Score: {final_tta_f1:.4f}")
print(f"  ROC-AUC: {final_tta_rocauc:.4f}")

100%|██████████| 71/71 [00:07<00:00,  9.72it/s]
100%|██████████| 71/71 [00:07<00:00,  9.92it/s]
100%|██████████| 71/71 [00:07<00:00,  9.88it/s]
100%|██████████| 71/71 [00:07<00:00,  9.85it/s]
100%|██████████| 71/71 [00:07<00:00,  9.86it/s]

TTA Metrics (num_tta=10):
  Accuracy: 0.9036
  F1 Score: 0.9036
  ROC-AUC: 0.9504





## Ensemble

In [57]:
import torch
from torch.func import stack_module_state, functional_call
import copy
from torchmetrics import Accuracy 

# List to store individual models
pl.seed_everything(42)

models = []

# Load each checkpoint and store the model
for checkpoint_path in config.model.checkpoint_path:
    # Create a new instance of the model
    model = TimmModel(config)
    model.load_state_dict(torch.load(checkpoint_path, weights_only=True)['state_dict'])  # Replace with your actual model class
    
    # Set the model to evaluation mode
    model.eval()
    
    # Move the model to GPU
    model = model.to('cuda')
    
    # Add the model to the list
    models.append(model)

# Step 1: Stack the states of all models
params, buffers = stack_module_state(models)

# Move stacked parameters and buffers to GPU
params = {k: v.to('cuda') for k, v in params.items()}
buffers = {k: v.to('cuda') for k, v in buffers.items()}

# Step 2: Create a stateless version of the model
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

# Step 3: Define a functional model
def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

# Step 4: Use vmap to vectorize the application of the functional model
# Option 1: Use the same minibatch for all models
def ensemble_predictions(predictions):
    return torch.mean(predictions, dim=0)

# Move dataloader to GPU
device = 'cuda'

# Initialize accuracy metric
accuracy = Accuracy(task="multiclass", num_classes=2).to(device)

# Initialize metrics for ensemble
ensemble_accuracy = Accuracy(task="multiclass", num_classes=2).to(device)
ensemble_f1 = F1Score(task="multiclass", num_classes=2).to(device)
ensemble_rocauc = AUROC(task="multiclass", num_classes=2).to(device)

# Test loop for ensemble
with torch.no_grad():
    for image, label in datamodule.test_dataloader():
        image = image.to(device)
        label = label.to(device)
        
        # Apply all models to the input image
        predictions = torch.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, image)
        
        # Combine predictions (e.g., average them)
        ensemble_prediction = ensemble_predictions(predictions)
        
        # Get the predicted class labels
        _, predicted_labels = torch.max(ensemble_prediction, dim=1)
        
        # Update metrics
        ensemble_accuracy.update(predicted_labels, label)
        ensemble_f1.update(predicted_labels, label)
        ensemble_rocauc.update(ensemble_prediction, label)  # ROC-AUC uses raw predictions

# Compute final ensemble metrics
final_ensemble_accuracy = ensemble_accuracy.compute()
final_ensemble_f1 = ensemble_f1.compute()
final_ensemble_rocauc = ensemble_rocauc.compute()

print(f"Ensemble Metrics:")
print(f"  Accuracy: {final_ensemble_accuracy:.4f}")
print(f"  F1 Score: {final_ensemble_f1:.4f}")
print(f"  ROC-AUC: {final_ensemble_rocauc:.4f}")

Seed set to 42
2024-12-24 16:07:34,948 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)
2024-12-24 16:07:35,132 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:07:35,133 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:07:35,252 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:07:35,436 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:07:35,437 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:07:35,568 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:07:35,977 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:07:35,978 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:07:36,100 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:07:36,368 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:07:36,369 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2024-12-24 16:07:36,495 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet18.a1_in1k)


CE loss is an additional loss term (module 1)


2024-12-24 16:07:36,676 DEBUG: https://huggingface.co:443 "HEAD /timm/resnet18.a1_in1k/resolve/main/model.safetensors HTTP/11" 302 0
2024-12-24 16:07:36,677 INFO: [timm/resnet18.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


CE loss is an additional loss term (module 1)
Ensemble Metrics:
  Accuracy: 0.9160
  F1 Score: 0.9160
  ROC-AUC: 0.9605


## Ensemble TTA

In [58]:
from tqdm import tqdm

# Initialize lists to store metrics for TTA
tta_accuracies = []
tta_f1_scores = []
tta_rocaucs = []

#num_tta = config.dataset.num_tta
num_tta = 10

# Test loop for TTA
with torch.no_grad():
    for i, model in enumerate(models):
        # Reset metrics for each model
        tta_accuracy = Accuracy(task="multiclass", num_classes=2).to(device)
        tta_f1 = F1Score(task="multiclass", num_classes=2).to(device)
        tta_rocauc = AUROC(task="multiclass", num_classes=2).to(device)
        
        for image, label in tqdm(datamodule.tta_dataloader()):
            image = image.to(device)
            label = label.to(device)
            
            # Initialize a tensor to store predictions for all TTA iterations
            all_tta_predictions = []
            
            # Apply TTA for `num_tta` iterations
            for _ in range(num_tta):
                # Apply all models to the input image
                predictions = torch.vmap(fmodel, in_dims=(0, 0, None))(params, buffers, image)
                
                # Combine predictions (e.g., average them)
                ensemble_prediction = ensemble_predictions(predictions)
                
                ensemble_prediction = torch.mean(predictions, dim=0)
                
                # Store predictions
                all_tta_predictions.append(predictions)
            
            # Stack predictions from all TTA iterations
            all_tta_predictions = torch.stack(all_tta_predictions, dim=0)  # Shape: (num_tta, batch_size, num_classes)
            
            # Combine predictions (e.g., average them)
            tta_prediction = torch.mean(all_tta_predictions, dim=0).float()  # Average over TTA iterations
            
            # Get the predicted class labels
            _, predicted_labels = torch.max(tta_prediction, dim=1)
            
            # Update metrics
            tta_accuracy.update(predicted_labels.float(), label)
            tta_f1.update(predicted_labels.float(), label)
            tta_rocauc.update(tta_prediction, label)  # ROC-AUC uses raw predictions

# Compute final TTA metrics
final_tta_accuracy = tta_accuracy.compute()
final_tta_f1 = tta_f1.compute()
final_tta_rocauc = tta_rocauc.compute()

print(f"TTA Metrics (num_tta={num_tta}):")
print(f"  Accuracy: {final_tta_accuracy:.4f}")
print(f"  F1 Score: {final_tta_f1:.4f}")
print(f"  ROC-AUC: {final_tta_rocauc:.4f}")

  0%|          | 0/71 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (5) must match the size of tensor b (32) at non-singleton dimension 0