In [None]:
!git clone https://github.com/JJGO/UniverSeg
!python -m pip install -r ./UniverSeg/requirements.txt

In [None]:
import math
import itertools
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import einops as E
from collections import defaultdict
import pathlib
import os
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
from torch.utils.data import Dataset
import PIL
from PIL import Image
import pandas as pd
import seaborn as sns
import numpy as np
import PIL.Image
import torch.nn.functional as F
import torch
import logging
from sklearn.model_selection import ParameterGrid
from time import time
from tqdm import tqdm
import sys


In [None]:

# Make the cloned UniverSeg checkout importable before importing from it.
sys.path.append('UniverSeg')
from universeg import universeg

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Build the model with its pretrained weights and move it to the chosen device.
model = universeg(pretrained=True)
_ = model.to(device)

In [None]:


# Root of the partitioned dataset; images and masks live in parallel
# 'images/<split>' and 'masks/<split>' sub-folders.
dataset_dir = '/kaggle/input/partitioned-data/partitioned_dataset_original'

# The test split is what gets evaluated.
test_input_folder, test_mask_folder = (
    os.path.join(dataset_dir, kind, 'test') for kind in ('images', 'masks')
)

# The train split is the pool that support examples are drawn from.
support_input_folder, support_mask_folder = (
    os.path.join(dataset_dir, kind, 'train') for kind in ('images', 'masks')
)

# Print the directories to verify
for caption, folder in (
    ("Test Input Folder:", test_input_folder),
    ("Test Mask Folder:", test_mask_folder),
    ("Support Input Folder:", support_input_folder),
    ("Support Mask Folder:", support_mask_folder),
):
    print(caption, folder)



In [None]:

def process_image(image_path: pathlib.Path, size: Tuple[int, int]):
    """Load an image as a resized, single-channel float32 array.

    The file is opened with PIL, resized to ``size`` with bilinear
    resampling, converted to 8-bit grayscale (mode "L") and returned as a
    ``np.float32`` array. Values are not rescaled here, so they stay in the
    0-255 range produced by the "L" conversion.

    Args:
        image_path: Path of the image file to read.
        size: Target size passed straight to ``PIL.Image.Image.resize``.

    Returns:
        A 2-D ``np.float32`` array holding the grayscale image.
    """
    resized = PIL.Image.open(image_path).resize(size, resample=PIL.Image.BILINEAR)
    grayscale = resized.convert("L")
    return np.array(grayscale).astype(np.float32)


def process_seg(
    path: pathlib.Path,
    size: Tuple[int, int],
    class_values: Tuple[int, ...] = (0, 150, 76),
) -> np.ndarray:
    """Load a segmentation mask and one-hot encode it by grayscale value.

    The mask is resized with nearest-neighbour resampling (so no new,
    interpolated label values are created), converted to 8-bit grayscale
    and compared against each entry of ``class_values``.

    Args:
        path: Path of the mask image to read.
        size: Target size passed straight to ``PIL.Image.Image.resize``.
        class_values: Grayscale value that identifies each class, in output
            channel order. The default ``(0, 150, 76)`` reproduces the
            values that were previously hard-coded in this function.

    Returns:
        A ``np.float32`` array with one leading channel per entry of
        ``class_values``; channel ``k`` is 1.0 where the grayscale mask
        equals ``class_values[k]`` and 0.0 elsewhere. Pixels matching none
        of the values are 0.0 in every channel.
    """
    # The context manager releases the file handle as soon as the resized
    # copy exists; ``resize`` returns a new, independent image.
    with PIL.Image.open(path) as raw_mask:
        resized = raw_mask.resize(size, resample=PIL.Image.NEAREST)
    gray = np.array(resized.convert("L"))

    # One boolean plane per class value, stacked along a new first axis.
    one_hot = np.stack([gray == value for value in class_values])
    return one_hot.astype(np.float32)


def _numeric_suffix_key(path: pathlib.Path):
    """Sort key: files ending in ``_<int>`` first (by that int), others after (by name)."""
    suffix = path.stem.split('_')[-1]
    try:
        return (0, int(suffix), path.name)
    except ValueError:
        # A stem without a numeric suffix must not abort the whole load.
        return (1, 0, path.name)


def load_dataset(input_folder: str, mask_folder: str, size: Tuple[int, int] = (128, 128)):
    """Load paired images and one-hot masks from two folders.

    Every ``*.png`` file in ``input_folder`` is paired with
    ``<stem>_mask.png`` in ``mask_folder``. Images without a mask are
    reported on stdout and skipped.

    Args:
        input_folder: Directory containing the input ``*.png`` images.
        mask_folder: Directory containing the ``<stem>_mask.png`` masks.
        size: Size both image and mask are resized to.

    Returns:
        A list of ``(image, seg)`` tuples, where ``image`` is the output of
        ``process_image`` divided by 255.0 and ``seg`` is the output of
        ``process_seg``. Files are ordered by the integer after the last
        underscore of the file stem; files without such a suffix come last,
        ordered by name.
    """
    data = []
    input_path = pathlib.Path(input_folder)
    mask_path = pathlib.Path(mask_folder)

    # Numeric ordering keeps e.g. "img_2" before "img_10".
    input_files = sorted(input_path.glob("*.png"), key=_numeric_suffix_key)

    for file in input_files:
        seg_file = mask_path / f"{file.stem}_mask.png"

        # Check for the mask first so an unpaired image is never decoded.
        if not seg_file.exists():
            print(f"Mask file '{seg_file}' not found. Skipping.")
            continue

        img = process_image(file, size=size)
        seg = process_seg(seg_file, size=size)

        # Scale the 0-255 grayscale image down by 255.
        data.append((img / 255.0, seg))

    return data


@dataclass
class JNU_FMI(Dataset):
    """In-memory dataset of (image, segmentation) tensor pairs.

    All samples are loaded eagerly with ``load_dataset`` when the instance
    is created. When ``label`` is set, ``__getitem__`` returns only that
    class's channel of the one-hot mask (with a leading channel axis);
    otherwise the full one-hot mask is returned.
    """

    # Folder holding the input images.
    input_folder: str
    # Folder holding the segmentation masks.
    mask: str
    # Size every image and mask is resized to.
    size: Tuple[int, int] = (128, 128)
    # Class to extract from the one-hot mask, or None to keep all channels.
    label: Optional[Literal["head", "symp", "background"]] = None

    def __post_init__(self):
        """Load every sample, convert it to tensors and resolve the label channel."""
        samples = load_dataset(self.input_folder, self.mask, size=self.size)

        # Images gain a leading channel axis; masks already have one per class.
        self._data = [
            (torch.from_numpy(image)[None], torch.from_numpy(seg))
            for image, seg in samples
        ]

        if self.label is not None:
            channel_of = {"head": 1, "symp": 2, "background": 0}
            self._ilabel = channel_of[self.label]

        self.idxs = list(range(len(self._data)))

    def __len__(self):
        """Return the number of available samples."""
        return len(self.idxs)

    def __getitem__(self, idx):
        """Return the ``(image, seg)`` pair at position ``idx``."""
        img, seg = self._data[self.idxs[idx]]
        if self.label is None:
            return img, seg
        # Keep a leading channel axis on the single selected class plane.
        return img, seg[self._ilabel][None]

In [None]:

 # Change to 'symp' or 'head' for other labels
# Class whose mask channel is evaluated.
label = 'background'

# Support pool, built from the train split at 128x128.
d_support = JNU_FMI(support_input_folder, support_mask_folder, (128, 128), label)
print(f"Support set length: {len(d_support)}")

# Evaluation set, built from the test split at 128x128.
d_test = JNU_FMI(test_input_folder, test_mask_folder, (128, 128), label)
print(f"Test set length: {len(d_test)}")

In [None]:
def dice_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Compute the Dice coefficient between a prediction and a binary target.

    The prediction is binarised at 0.5, the same rule the other metric
    functions in this notebook use, so that a soft prediction such as 0.9
    counts as foreground instead of being truncated to 0 by an integer cast.
    Already-binary predictions (0/1) give the same result as before.

    Args:
        y_pred: Predicted mask; values above 0.5 count as foreground.
        y_true: Ground-truth mask; cast to integers as-is.

    Returns:
        ``2 * |pred AND true| / (|pred| + |true|)`` as a Python float, or
        1.0 when both masks are empty.
    """
    y_pred = (y_pred > 0.5).long()
    y_true = y_true.long()
    intersection = (y_pred * y_true).sum().item()
    total = y_pred.sum().item() + y_true.sum().item()
    if total == 0:
        # Nothing predicted and nothing to find: treat as a perfect match.
        return 1.0
    return 2 * intersection / total

def accuracy_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return the fraction of elements where the thresholded prediction equals the target.

    The prediction is binarised at 0.5. If neither the target nor the
    binarised prediction contains any foreground, 1.0 is returned directly
    (this also covers empty tensors).
    """
    binarized = (y_pred > 0.5).float()
    nothing_marked = y_true.sum().item() == 0 and binarized.sum().item() == 0
    if nothing_marked:
        return 1.0
    n_correct = (binarized == y_true).float().sum()
    return (n_correct / y_true.numel()).item()

def sensitivity_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return recall, TP / (TP + FN), of the prediction thresholded at 0.5.

    Returns 1.0 when the target contributes no positives (TP + FN == 0).
    """
    predicted = (y_pred > 0.5).float()
    hits = (predicted * y_true).sum().item()
    misses = ((1 - predicted) * y_true).sum().item()
    actual_positives = hits + misses
    return hits / actual_positives if actual_positives != 0 else 1.0

def precision_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return precision, TP / (TP + FP), of the prediction thresholded at 0.5.

    Returns 1.0 when nothing is predicted as positive (TP + FP == 0).
    """
    predicted = (y_pred > 0.5).float()
    hits = (predicted * y_true).sum().item()
    false_alarms = (predicted * (1 - y_true)).sum().item()
    predicted_positives = hits + false_alarms
    return hits / predicted_positives if predicted_positives != 0 else 1.0

def jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """Return intersection-over-union of the prediction thresholded at 0.5 and the target.

    Returns 1.0 when the union is empty.
    """
    predicted = (y_pred > 0.5).float()
    overlap = (predicted * y_true).sum().item()
    combined = predicted.sum().item() + y_true.sum().item() - overlap
    return overlap / combined if combined != 0 else 1.0


In [None]:

@torch.no_grad()
def inferencesupport(model, image, label, support_images, support_labels, threshold=0.5):
    """Segment one image with one support set and score the result.

    ``image`` and ``label`` are moved to the notebook-level ``device``; the
    support tensors are used as given. A batch axis is added to all three
    model inputs and removed from the output again.

    Args:
        model: Callable taking ``(image, support_images, support_labels)``
            batches and returning logits.
        image: Query image tensor.
        label: Ground-truth mask for ``image``.
        support_images: Stacked support images.
        support_labels: Stacked support masks.
        threshold: Cut-off applied to the sigmoid output.

    Returns:
        Dict with the tensors 'Image', 'Soft Prediction', 'Prediction' and
        'Ground Truth', plus the float metrics 'score' (Dice), 'accuracy',
        'sensitivity', 'precision' and 'jaccard' of the thresholded
        prediction against ``label``.
    """
    image = image.to(device)
    label = label.to(device)

    # Forward pass on a batch of one; [0] strips the batch axis again.
    logits = model(image[None], support_images[None], support_labels[None])[0]

    soft_pred = torch.sigmoid(logits)
    # The boolean mask cast to float is already exactly 0.0 / 1.0.
    hard_pred = (soft_pred > threshold).float()

    outputs = {
        'Image': image,
        'Soft Prediction': soft_pred,
        'Prediction': hard_pred,
        'Ground Truth': label,
    }

    # Metrics are evaluated on the thresholded prediction, in this key order.
    scorers = (
        ('score', dice_score),
        ('accuracy', accuracy_score),
        ('sensitivity', sensitivity_score),
        ('precision', precision_score),
        ('jaccard', jaccard_score),
    )
    for key, scorer in scorers:
        outputs[key] = scorer(hard_pred, label)

    return outputs


In [None]:
#code for running experiments

In [None]:
# Send INFO-level log records to a file in the working directory.
logging.basicConfig(
    filename='experiment.log',
    level=logging.INFO,
    format='%(asctime)s %(message)s',
)

# Seed NumPy and torch so that repeated runs draw the same random numbers.
global_seed = 42
np.random.seed(global_seed)
torch.manual_seed(global_seed)

# Seed every CUDA device when available and pick the compute device to match.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(global_seed)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def sample_support(seed, support_size):
    """Draw a support set of ``support_size`` examples from ``d_support``.

    Indices are drawn by a NumPy generator seeded with ``seed`` (sampling
    with replacement, so an example can appear more than once). The global
    NumPy and torch seeds are also reset to ``seed`` as a side effect.

    Returns:
        ``(support_images, support_labels)``: the selected examples stacked
        along a new first axis and moved to ``device``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    picks = np.random.default_rng(seed).integers(0, len(d_support), size=support_size)
    images, labels = zip(*(d_support[pick] for pick in picks))

    stacked_images = torch.stack(images).to(device)
    stacked_labels = torch.stack(labels).to(device)
    return stacked_images, stacked_labels

def get_model_size(model):
    """Return the size in MiB of the model's serialised ``state_dict``.

    The state dict is written with ``torch.save`` into a private temporary
    directory, which is removed again even if saving or measuring fails.
    Nothing is written to the current working directory, so an existing
    ``temp.pth`` there is never overwritten or deleted.

    Args:
        model: Object exposing ``state_dict()``.

    Returns:
        File size of the saved state dict divided by 1024 * 1024.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        checkpoint_path = os.path.join(scratch_dir, "temp.pth")
        torch.save(model.state_dict(), checkpoint_path)
        return os.path.getsize(checkpoint_path) / (1024 * 1024)

def run_experiment(support_size, n_ensemble, threshold=0.5):
    """Evaluate an ensemble of support sets on every image of ``d_test``.

    For each test image the model is run once per ensemble member, each
    member using the support set drawn with seed ``global_seed + j``. The
    members' soft predictions are averaged, thresholded, and scored against
    the ground truth. All reported metrics, including accuracy, refer to
    this ensembled prediction.

    Relies on the notebook-level names ``model``, ``d_test``, ``device``,
    ``global_seed``, ``sample_support`` and ``inferencesupport``.

    Args:
        support_size: Number of support examples per ensemble member.
        n_ensemble: Number of ensemble members (support sets).
        threshold: Cut-off applied to individual and ensembled predictions.

    Returns:
        ``(df_dicescore, df_accuracy, df_sensitivity, df_precision,
        df_jaccard, mean_df)``. The first five hold one row per test image;
        ``df_dicescore`` additionally carries the ``support_size``,
        ``ensemble_count`` and ``threshold`` columns. ``mean_df`` is a
        single-row frame with the metric means, the configuration, the mean
        time per forward pass, the throughput (forward passes per second)
        and the model size in MiB.
    """
    # A support set depends only on its seed, so draw each one once up front
    # instead of re-sampling it for every test image.
    member_seeds = [global_seed + j for j in range(n_ensemble)]
    supports = {seed: sample_support(seed, support_size) for seed in member_seeds}

    rows = []
    inference_times = []

    for i in tqdm(range(len(d_test)), desc="Processing images"):
        image, label = d_test[i]

        soft_predictions = []
        for seed in member_seeds:
            # Re-seed before each forward pass so it starts from a known RNG state.
            np.random.seed(seed)
            torch.manual_seed(seed)
            support_images, support_labels = supports[seed]

            start_time = time()
            vals = inferencesupport(model, image, label, support_images, support_labels, threshold)
            inference_times.append(time() - start_time)

            soft_predictions.append(vals['Soft Prediction'])

        # Ensemble = mean of the members' soft predictions, then threshold.
        ensemble = torch.mean(torch.stack(soft_predictions), dim=0)
        ensemble_pred = (ensemble > threshold).float()
        target = label.to(device)

        rows.append({
            'dice_score': dice_score(ensemble_pred, target),
            'accuracy': accuracy_score(ensemble_pred, target),
            'sensitivity': sensitivity_score(ensemble_pred, target),
            'precision': precision_score(ensemble_pred, target),
            'jaccard': jaccard_score(ensemble_pred, target),
            'support_size': support_size,
            'ensemble_count': n_ensemble,
            'threshold': threshold,
        })

    # Build each frame once from the collected records; growing a frame with
    # pd.concat inside the loop copies all previous rows on every iteration.
    df_dicescore = pd.DataFrame(rows, columns=['dice_score', 'support_size', 'ensemble_count', 'threshold'])
    df_accuracy = pd.DataFrame(rows, columns=['accuracy'])
    df_sensitivity = pd.DataFrame(rows, columns=['sensitivity'])
    df_precision = pd.DataFrame(rows, columns=['precision'])
    df_jaccard = pd.DataFrame(rows, columns=['jaccard'])

    # Compute mean values and save to a single DataFrame
    mean_values = {
        'dice_score': df_dicescore['dice_score'].mean(),
        'accuracy': df_accuracy['accuracy'].mean(),
        'sensitivity': df_sensitivity['sensitivity'].mean(),
        'precision': df_precision['precision'].mean(),
        'jaccard': df_jaccard['jaccard'].mean(),
        'support_size': support_size,
        'ensemble_count': n_ensemble,
        'inference_time': np.mean(inference_times),  # mean time per forward pass
        'threshold': threshold
    }
    mean_df = pd.DataFrame(mean_values, index=[0])

    # Throughput counts every forward pass (test images x ensemble members).
    total_images = len(d_test) * n_ensemble
    total_time = np.sum(inference_times)
    throughput = total_images / total_time if total_time > 0 else float('nan')

    mean_df['throughput'] = throughput
    mean_df['model_size'] = get_model_size(model)

    return df_dicescore, df_accuracy, df_sensitivity, df_precision, df_jaccard, mean_df

# Hyperparameter grid: every combination below is run once.
param_grid = {
    'support_size': [3, 5, 10, 14, 20],
    'n_ensemble': [5, 8],
    'threshold': [0.3]  # use '0.5' or '0.7' to run evaluation with the other thresholds.
}

# Per-configuration result frames, collected in run order.
all_dicescores, all_accuracies, all_sensitivities = [], [], []
all_precisions, all_jaccards, all_means = [], [], []

# Run experiments sequentially
for params in ParameterGrid(param_grid):
    support_size = params['support_size']
    n_ensemble = params['n_ensemble']
    threshold = params['threshold']

    (df_dicescore, df_accuracy, df_sensitivity,
     df_precision, df_jaccard, mean_df) = run_experiment(support_size, n_ensemble, threshold)

    # One CSV per metric, named after the configuration that produced it.
    csv_outputs = {
        f'dicescore_support_{support_size}_ensemble_{n_ensemble}_threshold_{threshold}.csv': df_dicescore,
        f'accuracy_support_{support_size}_ensemble_{n_ensemble}_threshold_{threshold}.csv': df_accuracy,
        f'sensitivity_support_{support_size}_ensemble_{n_ensemble}_threshold_{threshold}.csv': df_sensitivity,
        f'precision_support_{support_size}_ensemble_{n_ensemble}_threshold_{threshold}.csv': df_precision,
        f'jaccard_support_{support_size}_ensemble_{n_ensemble}_threshold_{threshold}.csv': df_jaccard,
        f'mean_metrics_support_{support_size}_ensemble_{n_ensemble}_threshold_{threshold}.csv': mean_df,
    }
    for csv_name, frame in csv_outputs.items():
        frame.to_csv(csv_name, index=False)

    # Keep the frames in memory for the plotting cell.
    collectors = (all_dicescores, all_accuracies, all_sensitivities,
                  all_precisions, all_jaccards, all_means)
    for collector, frame in zip(collectors, csv_outputs.values()):
        collector.append(frame)

# Stack the single-row mean frames into one table and persist it.
combined_means = pd.concat(all_means, ignore_index=True)
combined_means.to_csv('combined_mean_metrics_0.3.csv', index=False)

# Print combined means
print("Combined Mean Metrics:")
print(combined_means)




In [None]:
#Plots

In [None]:
# Reload the combined mean-metrics table from the CSV in the Kaggle working directory.
combined_means = pd.read_csv('/kaggle/working/combined_mean_metrics_0.3.csv')

# "husl" palette with one colour per column, not counting the last five columns.
colors = sns.color_palette(
    "husl",
    n_colors=len(combined_means.columns[:-5]),
)

# Helpers and entry point that plot and save the metrics for one threshold
def _save_line_plot(x_labels, series, ylabel, title, filename, marker=None, legend=False):
    """Draw one line per entry of ``series`` over ``x_labels``, save the figure and show it."""
    fig, ax = plt.subplots(figsize=(10, 6))
    for name, values in series.items():
        ax.plot(x_labels, values, label=name, marker=marker)
    ax.set_xlabel('Experiment (Support Size, Ensemble Count)')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.tick_params(axis='x', labelrotation=45)
    if legend:
        ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(filename)
    plt.show()
    plt.close(fig)


def plot_metrics_for_threshold(combined_means, threshold, dice_scores=None):
    """Plot and save the summary figures for one threshold.

    Produces four figures: mean metric values, mean inference time and
    throughput per (support size, ensemble count) configuration, and a
    boxplot of the per-image Dice scores grouped by configuration. The first
    three are saved in the working directory, the boxplot under ``plots/``
    (created if missing).

    Args:
        combined_means: Frame with one row per configuration, containing the
            metric columns plus 'support_size', 'ensemble_count',
            'threshold', 'inference_time' and 'throughput'.
        threshold: Only rows whose 'threshold' equals this value are plotted.
        dice_scores: Per-image frame with 'dice_score', 'support_size',
            'ensemble_count' and 'threshold' columns. When omitted, the
            notebook-level ``all_dicescores_combined`` is used.
    """
    # Filter the combined_means DataFrame for the current threshold
    threshold_combined_means = combined_means[combined_means['threshold'] == threshold]

    x_labels = [
        f"({int(row['support_size'])}, {int(row['ensemble_count'])})"
        for _, row in threshold_combined_means.iterrows()
    ]

    # Select the metric columns by name rather than by position, so the plot
    # does not silently change if columns are added or reordered.
    metric_names = ['dice_score', 'accuracy', 'sensitivity', 'precision', 'jaccard']
    metric_series = {
        name: threshold_combined_means[name]
        for name in metric_names
        if name in threshold_combined_means.columns
    }

    _save_line_plot(
        x_labels, metric_series,
        ylabel='Mean Value',
        title=f'Mean Values for Different Metrics (Threshold {threshold})',
        filename=f'mean_values_plot_threshold_{threshold}.png',
        legend=True,
    )
    _save_line_plot(
        x_labels, {'inference_time': threshold_combined_means['inference_time']},
        ylabel='Inference Time (s)',
        title=f'Inference Time for Different Experiments (Threshold {threshold})',
        filename=f'inference_time_plot_threshold_{threshold}.png',
        marker='o',
    )
    _save_line_plot(
        x_labels, {'throughput': threshold_combined_means['throughput']},
        ylabel='Throughput (images/s)',
        title=f'Throughput for Different Experiments (Threshold {threshold})',
        filename=f'throughput_plot_threshold_{threshold}.png',
        marker='o',
    )

    # Boxplot of per-image dice scores, restricted to this threshold.
    if dice_scores is None:
        dice_scores = all_dicescores_combined
    threshold_dice = dice_scores[dice_scores['threshold'] == threshold]
    if threshold_dice.empty:
        print(f"No dice scores recorded for threshold {threshold}; skipping boxplot.")
        return
    # Boxplots need a numeric column; the cast is a no-op if it already is one.
    threshold_dice = threshold_dice.assign(dice_score=threshold_dice['dice_score'].astype(float))

    # Hand pandas an explicit Axes: otherwise it opens its own figure and the
    # one created here would be left behind empty.
    fig, ax = plt.subplots(figsize=(12, 8))
    threshold_dice.boxplot(column='dice_score', by=['support_size', 'ensemble_count'], ax=ax)
    ax.set_xlabel('Experiment (Support Size and Ensemble Count)')
    ax.set_ylabel('Dice Score')
    ax.set_title(f'Boxplot of Dice Scores (Threshold {threshold})')
    fig.suptitle('')  # drop the automatic "Boxplot grouped by ..." heading
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(True, linestyle='--', alpha=0.7)
    fig.tight_layout()

    # savefig does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    fig.savefig(f'plots/dice_scores_boxplot_threshold_{threshold}.png')
    plt.show()
    plt.close(fig)

# Combine the per-configuration dice-score frames collected in `all_dicescores`
# (filled by the experiment cell) into a single DataFrame. This name must be
# defined before the loop below runs: plot_metrics_for_threshold reads it as a
# notebook-level global.
all_dicescores_combined = pd.concat(all_dicescores, ignore_index=True)


threshold_values = [0.3]  # use '0.5' or '0.7' for the other threshold values
# Generate and save the full set of plots once per threshold
for threshold in threshold_values:
    plot_metrics_for_threshold(combined_means, threshold)

print("Plots generated and saved with identifiers for the threshold.")
