<a href="https://colab.research.google.com/github/pedrohtg/weasel/blob/main/weasel_ft_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive and switch into the project directory so that relative
# paths used later (such as 'dataset') resolve inside it.
from google.colab import drive
drive.mount('/content/drive/')

%cd /content/drive/MyDrive/Colab Notebooks/Bakalauras/SPSMM

Mounted at /content/drive/


In [None]:
# Install runtime dependencies. torchmeta, ordered_set and torchvision are
# installed with --no-deps, so pip does not pull in their declared dependencies
# (NOTE(review): presumably to keep the preinstalled torch build intact — confirm).
!pip install SimpleITK nibabel pillow tqdm nilearn
!pip install torchmeta --no-deps
!pip install ordered_set --no-deps
!pip install torchvision --no-deps

In [10]:
# Basic imports
import sys
import os
import numpy as np
import logging
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim

# Custom implementations of the torchmeta required functions
def _get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None

def _save_response_content(response, destination):
    chunk_size = 32768
    with open(destination, 'wb') as f:
        for chunk in response.iter_content(chunk_size):
            if chunk:
                f.write(chunk)

# Delete torchmeta's import of two private torchvision helpers from its installed
# datasets/utils.py (NOTE(review): presumably because the installed torchvision no
# longer provides them — confirm). The path is pinned to a Python 3.11
# site-packages layout and must be changed if the runtime's Python version differs.
# NOTE(review): the _get_confirm_token/_save_response_content functions defined in
# this cell live only in the notebook namespace; if torchmeta's utils module calls
# them, it will not find these versions unless they are assigned onto that module.
!sed -i 's/from torchvision.datasets.utils import _get_confirm_token, _save_response_content//' /usr/local/lib/python3.11/dist-packages/torchmeta/datasets/utils.py

from PIL import Image
from sklearn import metrics  # sklearn metrics module; used by accuracy()
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from torchmeta import modules
from collections import OrderedDict


%matplotlib inline

# Dataset root, relative to the current working directory.
dataset_root = 'dataset'

In [15]:
# Configure the root logger for INFO-and-above records and create the module
# logger that the dataset code uses.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

class ListDataset(Dataset):
    """
    A PyTorch Dataset that pairs images with segmentation masks for few-shot tuning.

    Images and masks are read from
    ``<dataset_root>/<task>/<fold>/{Training,Testing}/{images,masks}`` and paired
    position by position after sorting by file name. Targets can be sparsified
    with 'points' or 'grid' annotation. In a sparse mask every unannotated pixel
    holds ``IGNORE_LABEL`` (-1), matching the ``ignore_index=-1`` used by the
    cross-entropy calls in this notebook; background (label 0) is an ordinary,
    annotatable class.
    """

    # Label stored at pixels that carry no annotation in a sparse mask.
    IGNORE_LABEL = -1

    def __init__(self, mode, dataset_root, task, fold, resize_to, num_shots, sparsity_mode, sparsity_param, imgtype, make=True, seed=0):
        """
        Initialize the ListDataset.

        Args:
            mode (str): Dataset usage mode - 'train', 'test', 'tune_train', or 'tune_test'.
                Modes containing 'train' read the "Training" split, others "Testing".
            dataset_root (str): Root directory of the dataset.
            task (str): Task name (e.g., 'brains').
            fold (int or str): Fold identifier (e.g., '0').
            resize_to (tuple): Target size as (height, width). Pass None to keep the original size.
            num_shots (int): Number of image/mask pairs to keep; a value <= 0 (e.g. -1)
                keeps every pair.
            sparsity_mode (str): 'points', 'grid', or anything else for dense targets.
            sparsity_param (int): Points per class ('points', default 10) or grid
                spacing in pixels ('grid', default 5).
            imgtype (str): Image type (e.g., 'med'); stored but not used here.
            make (bool): Whether to scan the dataset directories now.
            seed (int): Non-negative seed. It fixes which pairs are kept when
                ``num_shots`` is positive and which pixels are annotated in 'points'
                mode, so each sample always gets the same sparse mask.
        """
        self.mode = mode
        self.dataset_root = dataset_root
        self.task = task
        self.fold = str(fold)  # Convert fold to string
        self.resize_to = resize_to
        self.num_shots = num_shots
        self.sparsity_mode = sparsity_mode
        self.sparsity_param = sparsity_param
        self.imgtype = imgtype
        self.make = make
        self.seed = seed

        # Root directory where the dataset is stored
        self.root = os.path.join(self.dataset_root, self.task, self.fold)

        # Initialize the dataset by loading image-mask pairs
        if make:
            self.imgs = self.make_dataset()
        else:
            self.imgs = []

    def make_dataset(self):
        """
        Pair image and mask files for the current mode.

        When ``num_shots`` is positive and more pairs are available, a subset of
        ``num_shots`` pairs is drawn reproducibly from ``seed`` (kept in file order).

        Returns:
            list: (image_path, mask_path) tuples; empty if a directory is missing.
        """
        data_list = []
        mode_dir = os.path.join(self.root, "Training" if 'train' in self.mode.lower() else "Testing")

        images_dir = os.path.join(mode_dir, "images")
        masks_dir = os.path.join(mode_dir, "masks")

        # Verify that the directories exist
        if not os.path.exists(images_dir):
            logger.error(f"Images directory does not exist: {images_dir}")
            return data_list

        if not os.path.exists(masks_dir):
            logger.error(f"Masks directory does not exist: {masks_dir}")
            return data_list

        # Get and sort image and mask files
        img_files = sorted(os.listdir(images_dir))
        mask_files = sorted(os.listdir(masks_dir))

        # Warn if the number of images and masks does not match
        if len(img_files) != len(mask_files):
            logger.warning("Mismatch between the number of images and masks!")

        # Pair images and masks by sorted position
        for img_file, mask_file in zip(img_files, mask_files):
            img_path = os.path.join(images_dir, img_file)
            mask_path = os.path.join(masks_dir, mask_file)

            if os.path.exists(img_path) and os.path.exists(mask_path):
                data_list.append((img_path, mask_path))
            else:
                logger.warning(f"Missing file pair: Image={img_file}, Mask={mask_file}")

        # Few-shot: keep only num_shots annotated samples, chosen reproducibly.
        if self.num_shots is not None and 0 < self.num_shots < len(data_list):
            rng = np.random.default_rng(self.seed)
            keep = np.sort(rng.choice(len(data_list), self.num_shots, replace=False))
            data_list = [data_list[i] for i in keep]

        logger.info(f"Loaded {len(data_list)} samples for task '{self.task}' and mode '{self.mode}' from '{self.dataset_root}'")
        return data_list

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Get a single sample (image and mask) from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple containing:
                - image (torch.Tensor): Grayscale image tensor of shape [1, H, W].
                - y_dense (torch.Tensor): Dense label mask of shape [H, W] (long).
                - y_tr (torch.Tensor): Training target of shape [H, W] (long); in
                  sparse modes unannotated pixels hold IGNORE_LABEL (-1).
                - img_name (str): Name of the image file.
        """
        img_path, mask_path = self.imgs[idx]

        # Load image and mask as single-channel 8-bit images
        image = Image.open(img_path).convert('L')
        mask = Image.open(mask_path).convert('L')

        # Resize if necessary. PIL expects (width, height) while resize_to is (height, width).
        if self.resize_to:
            size_wh = (self.resize_to[1], self.resize_to[0])
            image = image.resize(size_wh, Image.BILINEAR)
            mask = mask.resize(size_wh, Image.NEAREST)

        # Convert image and mask to numpy arrays
        image = np.array(image)
        mask = np.array(mask)

        # Apply sparsity transformations if specified
        if self.sparsity_mode == 'points':
            # Seeding per sample keeps the annotated points identical across epochs.
            rng = np.random.default_rng([self.seed, int(idx) % len(self.imgs)])
            y_tr = self._apply_point_sparsity(mask, rng=rng)
        elif self.sparsity_mode == 'grid':
            y_tr = self._apply_grid_sparsity(mask)
        else:
            y_tr = mask  # Dense mask without sparsity

        # Convert to PyTorch tensors
        image = transforms.ToTensor()(image)  # Shape: [1, H, W]
        y_dense = torch.tensor(mask, dtype=torch.long)  # Shape: [H, W]
        y_tr = torch.tensor(y_tr, dtype=torch.long)  # Sparse or dense target

        img_name = os.path.basename(img_path)
        return image, y_dense, y_tr, img_name

    def _apply_point_sparsity(self, mask, rng=None):
        """
        Keep a fixed number of randomly chosen pixels per class and ignore the rest.

        Every class present in ``mask`` (background included) gets up to
        ``sparsity_param`` points (default 10). Without background points the model
        would never see a background label once unannotated pixels are ignored.

        Args:
            mask (numpy.ndarray): Dense 2-D label mask.
            rng (numpy.random.Generator, optional): Randomness source; pass a seeded
                generator for reproducible points. A fresh generator is used if omitted.

        Returns:
            numpy.ndarray: int64 mask of the same shape with the original label at the
            selected pixels and IGNORE_LABEL everywhere else.
        """
        if rng is None:
            rng = np.random.default_rng()
        num_points = self.sparsity_param if self.sparsity_param else 10
        sparse_mask = np.full(mask.shape, self.IGNORE_LABEL, dtype=np.int64)

        for cls in np.unique(mask):
            rows, cols = np.nonzero(mask == cls)
            chosen = rng.choice(rows.size, min(num_points, rows.size), replace=False)
            sparse_mask[rows[chosen], cols[chosen]] = cls
        return sparse_mask

    def _apply_grid_sparsity(self, mask):
        """
        Keep only pixels on a regular grid and ignore the rest.

        Args:
            mask (numpy.ndarray): Dense 2-D label mask.

        Returns:
            numpy.ndarray: int64 mask holding the original labels at every
            ``sparsity_param``-th row and column (default spacing 5, starting at 0)
            and IGNORE_LABEL elsewhere.
        """
        grid_spacing = self.sparsity_param if self.sparsity_param else 5
        sparse_mask = np.full(mask.shape, self.IGNORE_LABEL, dtype=np.int64)
        sparse_mask[::grid_spacing, ::grid_spacing] = mask[::grid_spacing, ::grid_spacing]
        return sparse_mask


In [16]:
# Experiment grid for the tuning runs.
list_shots = [5]                            # annotated (sparse) training samples per task
list_sparsity_points = [1, 5, 10, 20]       # labelled pixels per class in 'points' annotation
list_sparsity_grid = list(range(8, 21, 4))  # pixel spacing in 'grid' annotation: 8, 12, 16, 20

In [18]:
def get_tune_loaders(shots, points, grid, fold_name, resize_to, args, imgtype='med',
                     dataset_root='dataset', task_name='brains'):
    """Build tuning loaders for every (sparsity mode, shots, sparsity) combination.

    Args:
        shots (iterable of int): Numbers of annotated training samples to try.
        points (iterable of int): Points per class for 'points' annotation.
        grid (iterable of int): Grid spacings for 'grid' annotation.
        fold_name (int or str): Dataset fold identifier.
        resize_to (tuple or None): Passed to ListDataset.
        args (dict): Must contain 'batch_size' and 'num_workers'.
        imgtype (str): Passed to ListDataset.
        dataset_root (str): Dataset root directory (default 'dataset').
        task_name (str): Task sub-directory (default 'brains').

    Returns:
        dict: Keys 'points', 'grid' and 'dense'. Each maps to a list of dicts with
        keys 'n_shots', 'sparsity', 'train' and 'test'. This function leaves
        'dense' empty.
    """
    loaders = {'points': [], 'grid': [], 'dense': []}

    # The evaluation split is dense and does not depend on any loop variable, so it
    # is built once and shared by every configuration.
    tune_test_set = ListDataset(
        mode='tune_test',
        dataset_root=dataset_root,
        task=task_name,
        fold=fold_name,
        resize_to=resize_to,
        num_shots=-1,
        sparsity_mode='dense',
        sparsity_param=None,
        imgtype=imgtype
    )
    tune_test_loader = DataLoader(tune_test_set, batch_size=1, num_workers=args['num_workers'], shuffle=False)

    for sparsity_mode, sparsity_values in (('points', points), ('grid', grid)):
        for n_shots in shots:
            for sparsity in sparsity_values:
                tune_train_set = ListDataset(
                    mode='tune_train',
                    dataset_root=dataset_root,
                    task=task_name,
                    fold=fold_name,
                    resize_to=resize_to,
                    num_shots=n_shots,
                    sparsity_mode=sparsity_mode,
                    sparsity_param=sparsity,
                    imgtype=imgtype
                )
                tune_train_loader = DataLoader(tune_train_set, batch_size=args['batch_size'], num_workers=args['num_workers'], shuffle=True)

                loaders[sparsity_mode].append({
                    'n_shots': n_shots,
                    'sparsity': sparsity,
                    'train': tune_train_loader,
                    'test': tune_test_loader
                })

    return loaders


In [19]:
def compute_metrics_per_class(y_true, y_pred, ignore_index=0, classes=(1, 2, 3)):
    """Compute IoU, Dice, sensitivity and specificity for each requested class.

    Pixels whose ground-truth label equals ``ignore_index`` are removed from both
    arrays before counting. With the default ``ignore_index=0`` this drops every
    ground-truth background pixel. A foreground prediction on background is then
    not counted as a false positive, and specificity is measured only among the
    remaining pixels. Pass ``ignore_index=None`` to score all pixels.

    Args:
        y_true (array-like): Ground-truth labels.
        y_pred (array-like): Predicted labels, same shape as ``y_true``.
        ignore_index (int or None): Ground-truth label to exclude; None keeps everything.
        classes (iterable of int): Class ids to report (default 1, 2, 3).

    Returns:
        dict: ``{cls: {'IoU', 'Dice', 'Sensitivity', 'Specificity'}}`` with float
        values. A ratio whose denominator is zero is reported as 0.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if ignore_index is not None:
        keep = y_true != ignore_index
        y_true = y_true[keep]
        y_pred = y_pred[keep]

    def _ratio(num, den):
        # Guard empty classes: report 0 instead of dividing by zero.
        return float(num) / float(den) if den else 0.0

    metrics_dict = {}
    for cls in classes:
        true_cls = y_true == cls
        pred_cls = y_pred == cls

        tp = np.count_nonzero(pred_cls & true_cls)
        fp = np.count_nonzero(pred_cls & ~true_cls)
        tn = np.count_nonzero(~pred_cls & ~true_cls)
        fn = np.count_nonzero(~pred_cls & true_cls)

        metrics_dict[cls] = {
            'IoU': _ratio(tp, tp + fp + fn),
            'Dice': _ratio(2 * tp, 2 * tp + fp + fn),
            'Sensitivity': _ratio(tp, tp + fn),
            'Specificity': _ratio(tn, tn + fp),
        }

    return metrics_dict


In [20]:
def tune_train_test(tune_train_loader, tune_test_loader, net, optimizer, args, sparsity_mode, best_weights_path):
    """Fine-tune ``net`` on sparse labels, evaluate it densely and plot a few predictions.

    NOTE: a later notebook cell redefines ``tune_train_test`` with the same
    signature. When the cells run in order, callers get that later definition.

    Args:
        tune_train_loader: yields (image, dense_mask, sparse_mask, name) batches; the
            sparse mask is the training target and -1 marks pixels the loss ignores.
        tune_test_loader: yields the same tuple; the dense mask is the target.
        net (torch.nn.Module): model returning per-pixel class logits.
        optimizer (torch.optim.Optimizer): optimizer over ``net``'s parameters.
        args (dict): needs 'tuning_epochs' and 'val_freq'.
        sparsity_mode (str): prefix for TensorBoard tags.
        best_weights_path (str): checkpoint loaded before training if it exists.
    """
    color_map = {
        0: [0, 0, 0],    # Black (background)
        1: [0, 0, 255],  # Blue (CSF)
        2: [255, 0, 0],  # Red (GM)
        3: [0, 255, 0],  # Green (WM)
    }
    device = next(net.parameters()).device

    def to_2d(arr):
        """Return ``arr`` as a numpy array, dropping a leading singleton axis of a 3-D array."""
        arr = arr.numpy() if isinstance(arr, torch.Tensor) else np.asarray(arr)
        return arr[0] if arr.ndim == 3 and arr.shape[0] == 1 else arr

    # Convert a segmentation map into an RGB image.
    def map_to_color(segmentation, color_map):
        colored = np.zeros(segmentation.shape + (3,), dtype=np.uint8)
        for cls, color in color_map.items():
            colored[segmentation == cls] = color
        return colored

    # Visualization: input / ground truth / prediction side by side.
    def visualize_segmentation(inputs, labels, preds, color_map, num_samples=3):
        num_samples = min(num_samples, len(inputs))
        if num_samples == 0:
            return
        # squeeze=False keeps ``axes`` 2-D even when only one row is drawn.
        fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)
        for i in range(num_samples):
            axes[i, 0].imshow(to_2d(inputs[i]), cmap='gray')
            axes[i, 0].set_title('Input Image')

            # Labels keep their batch axis ((1, H, W) for batch size 1); reduce to 2-D.
            axes[i, 1].imshow(map_to_color(to_2d(labels[i]), color_map))
            axes[i, 1].set_title('Ground Truth')

            axes[i, 2].imshow(map_to_color(to_2d(preds[i]), color_map))
            axes[i, 2].set_title('Prediction')

            for ax in axes[i]:
                ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)

    def evaluate():
        """Run ``net`` over the test loader in eval mode.

        Returns (mean loss, inputs, labels, predictions).
        """
        losses, inputs, labels, preds = [], [], [], []
        net.eval()
        with torch.no_grad():
            for x_ts, y_ts, _, _ in tune_test_loader:
                x_ts = x_ts.to(device)
                y_ts = y_ts.to(device).long()
                p_ts = net(x_ts)
                losses.append(F.cross_entropy(p_ts, y_ts, ignore_index=-1).item())
                inputs.append(x_ts.squeeze(1).squeeze(0).cpu())
                labels.append(y_ts.cpu().numpy())
                preds.append(p_ts.argmax(1).squeeze(1).squeeze(0).cpu().numpy())
        return np.mean(losses), inputs, labels, preds

    writer = SummaryWriter()
    try:
        if os.path.exists(best_weights_path):
            net.load_state_dict(torch.load(best_weights_path, map_location=device))
            print(f"Loaded best weights from {best_weights_path}")
        else:
            print("No pre-saved weights found. Starting training from scratch.")

        net.train()
        for epoch in range(1, args['tuning_epochs'] + 1):
            sys.stdout.flush()
            train_loss_list = []

            for x_tr, _, y_tr, _ in tune_train_loader:
                x_tr, y_tr = x_tr.to(device), y_tr.to(device)
                optimizer.zero_grad()
                tune_train_loss = F.cross_entropy(net(x_tr), y_tr, ignore_index=-1)
                tune_train_loss.backward()
                optimizer.step()
                train_loss_list.append(tune_train_loss.item())

            writer.add_scalar(f'{sparsity_mode}/Training Loss', np.mean(train_loss_list), epoch)

            if epoch % args['val_freq'] == 0:
                avg_test_loss, _, _, _ = evaluate()
                writer.add_scalar(f'{sparsity_mode}/Validation Loss', avg_test_loss, epoch)
                net.train()

        avg_test_loss, inps_all, labs_all, prds_all = evaluate()
        writer.add_scalar(f'{sparsity_mode}/Final Validation Loss', avg_test_loss, args['tuning_epochs'])

        # Flatten per batch so batches of different sizes can be concatenated.
        labs_np = np.concatenate([lab.ravel() for lab in labs_all]) if labs_all else np.empty(0, dtype=np.int64)
        prds_np = np.concatenate([prd.ravel() for prd in prds_all]) if prds_all else np.empty(0, dtype=np.int64)
        metrics_per_class = compute_metrics_per_class(labs_np, prds_np, ignore_index=0)

        print('--------------------------------------------------------------------')
        for cls, cls_metrics in metrics_per_class.items():
            print(f'Class {cls}: IoU: {cls_metrics["IoU"]*100:.2f}%, Dice: {cls_metrics["Dice"]*100:.2f}%, Sensitivity: {cls_metrics["Sensitivity"]*100:.2f}%, Specificity: {cls_metrics["Specificity"]*100:.2f}%')
            writer.add_scalar(f'{sparsity_mode}/Final Class_{cls}_IoU', cls_metrics["IoU"], args['tuning_epochs'])
            writer.add_scalar(f'{sparsity_mode}/Final Class_{cls}_Dice', cls_metrics["Dice"], args['tuning_epochs'])
            writer.add_scalar(f'{sparsity_mode}/Final Class_{cls}_Sensitivity', cls_metrics["Sensitivity"], args['tuning_epochs'])
            writer.add_scalar(f'{sparsity_mode}/Final Class_{cls}_Specificity', cls_metrics["Specificity"], args['tuning_epochs'])
        print('--------------------------------------------------------------------')
        sys.stdout.flush()

        num_samples = min(3, len(inps_all))
        visualize_segmentation(inps_all[:num_samples], labs_all[:num_samples], prds_all[:num_samples], color_map)
    finally:
        writer.close()

In [21]:
def tune_train_test(tune_train_loader, tune_test_loader, net, optimizer, args, sparsity_mode, best_weights_path):
    """Fine-tune ``net`` on sparsely labelled data and evaluate it on the dense test split.

    If ``best_weights_path`` exists, ``net`` is initialized from it; the file is only
    read, never written. The model is evaluated every ``args['val_freq']`` epochs,
    and once more after the last epoch if that epoch was not a validation epoch.
    Each evaluation prints the loss and per-class metrics and writes them to
    TensorBoard under the ``sparsity_mode`` prefix.

    Args:
        tune_train_loader: yields (image, dense_mask, sparse_mask, name) batches; the
            sparse mask is the training target and -1 marks pixels the loss ignores.
        tune_test_loader: yields the same tuple; the dense mask is the target.
        net (torch.nn.Module): model returning per-pixel class logits.
        optimizer (torch.optim.Optimizer): optimizer over ``net``'s parameters.
        args (dict): needs 'tuning_epochs' and 'val_freq'.
        sparsity_mode (str): prefix for TensorBoard tags.
        best_weights_path (str): checkpoint loaded before training if it exists.

    Returns:
        dict or None: per-class metrics of the last evaluation, or None if no
        evaluation ran (e.g. zero tuning epochs).
    """
    device = next(net.parameters()).device
    num_epochs = args['tuning_epochs']
    val_freq = args['val_freq']

    def run_evaluation(step):
        """Evaluate on the test loader, print and log loss/metrics at ``step``; return the metrics."""
        test_loss_list, labs_all, prds_all = [], [], []
        net.eval()
        with torch.no_grad():
            for x_ts, y_ts, _, _ in tune_test_loader:
                x_ts = x_ts.to(device)
                y_ts = y_ts.to(device).long()
                p_ts = net(x_ts)
                test_loss_list.append(F.cross_entropy(p_ts, y_ts, ignore_index=-1).item())
                # Flatten per batch so batches of different sizes can be concatenated.
                prds_all.append(p_ts.argmax(1).cpu().numpy().ravel())
                labs_all.append(y_ts.cpu().numpy().ravel())

        avg_test_loss = np.mean(test_loss_list)
        writer.add_scalar(f'{sparsity_mode}/Validation Loss', avg_test_loss, step)

        labs_np = np.concatenate(labs_all) if labs_all else np.empty(0, dtype=np.int64)
        prds_np = np.concatenate(prds_all) if prds_all else np.empty(0, dtype=np.int64)
        metrics_per_class = compute_metrics_per_class(labs_np, prds_np, ignore_index=0)

        print('--------------------------------------------------------------------')
        for cls, cls_metrics in metrics_per_class.items():
            print(f'Class {cls}: IoU: {cls_metrics["IoU"]*100:.2f}%, Dice: {cls_metrics["Dice"]*100:.2f}%, Sensitivity: {cls_metrics["Sensitivity"]*100:.2f}%, Specificity: {cls_metrics["Specificity"]*100:.2f}%')
            # Log the per-class metrics to TensorBoard.
            writer.add_scalar(f'{sparsity_mode}/Class_{cls}_IoU', cls_metrics["IoU"], step)
            writer.add_scalar(f'{sparsity_mode}/Class_{cls}_Dice', cls_metrics["Dice"], step)
            writer.add_scalar(f'{sparsity_mode}/Class_{cls}_Sensitivity', cls_metrics["Sensitivity"], step)
            writer.add_scalar(f'{sparsity_mode}/Class_{cls}_Specificity', cls_metrics["Specificity"], step)
        print('--------------------------------------------------------------------')
        sys.stdout.flush()

        avg_iou = np.mean([m['IoU'] for m in metrics_per_class.values()])
        writer.add_scalar(f'{sparsity_mode}/Average IoU', avg_iou, step)
        return metrics_per_class

    writer = SummaryWriter()
    last_metrics = None
    try:
        if os.path.exists(best_weights_path):
            net.load_state_dict(torch.load(best_weights_path, map_location=device))
            print(f"Loaded best weights from {best_weights_path}")
        else:
            print("No pre-saved weights found. Starting training from scratch.")

        net.train()
        for epoch in range(1, num_epochs + 1):
            sys.stdout.flush()
            train_loss_list = []

            for x_tr, _, y_tr, _ in tune_train_loader:
                x_tr, y_tr = x_tr.to(device), y_tr.to(device)
                optimizer.zero_grad()
                tune_train_loss = F.cross_entropy(net(x_tr), y_tr, ignore_index=-1)
                tune_train_loss.backward()
                optimizer.step()
                train_loss_list.append(tune_train_loss.item())

            writer.add_scalar(f'{sparsity_mode}/Training Loss', np.mean(train_loss_list), epoch)

            if epoch % val_freq == 0:
                last_metrics = run_evaluation(epoch)
                net.train()

        # Evaluate the final weights unless the last epoch already was a validation epoch.
        if num_epochs % val_freq != 0:
            print('------------------------TESTING------------------------')
            last_metrics = run_evaluation(num_epochs)
    finally:
        writer.close()

    return last_metrics

In [33]:
def run_sparse_tuning(loader_dict, net, optimizer, args, model_weights_dir='best_models'):
    """Run ``tune_train_test`` once for every loader configuration in ``loader_dict``.

    Configurations are processed in the order 'points', 'grid', 'dense'. Every
    configuration starts from the network and optimizer state that was current
    when this function was called. Without this reset, each configuration would
    continue from the weights fine-tuned for the previous one and contaminate its
    results.

    Args:
        loader_dict (dict): as returned by ``get_tune_loaders``; missing keys are skipped.
        net (torch.nn.Module): model to tune (modified in place).
        optimizer (torch.optim.Optimizer): optimizer over ``net``'s parameters.
        args (dict): forwarded to ``tune_train_test``.
        model_weights_dir (str): directory holding 'best_model_<identifier>.pth' files;
            created if missing.
    """
    import copy

    os.makedirs(model_weights_dir, exist_ok=True)

    # Snapshot the starting point; tune_train_test updates net and optimizer in place.
    initial_net_state = copy.deepcopy(net.state_dict())
    initial_opt_state = copy.deepcopy(optimizer.state_dict())

    def describe(mode, entry):
        """Return (identifier, human-readable label) for one loader entry."""
        n_shots = entry['n_shots']
        if mode == 'points':
            sparsity = entry['sparsity']
            return (f'points_{n_shots}_shots_{sparsity}_points',
                    f"'points' ({n_shots}-shot, {sparsity}-points)")
        if mode == 'grid':
            sparsity = entry['sparsity']
            return (f'grid_{n_shots}_shots_{sparsity}_spacing',
                    f"'grid' ({n_shots}-shot, {sparsity}-spacing)")
        return f'dense_{n_shots}_shots', f"'dense' ({n_shots}-shot)"

    for mode in ('points', 'grid', 'dense'):
        for entry in loader_dict.get(mode, []):
            mode_identifier, label = describe(mode, entry)
            best_weights_path = os.path.join(model_weights_dir, f'best_model_{mode_identifier}.pth')

            print(f"Evaluating {label} with identifier '{mode_identifier}'")
            sys.stdout.flush()

            # Restore the initial state. The optimizer snapshot is deep-copied again
            # because load_state_dict may keep references to the tensors it is given.
            net.load_state_dict(initial_net_state)
            optimizer.load_state_dict(copy.deepcopy(initial_opt_state))

            tune_train_test(entry['train'], entry['test'], net, optimizer, args, mode_identifier, best_weights_path)


In [23]:
# Hyper-parameters and data settings shared by the tuning helpers.
args = dict(
    tuning_epochs=500,   # epochs in the tuning phase
    val_freq=5,          # evaluate every val_freq tuning epochs
    vis_freq=25,         # visualization period in epochs (NOTE(review): not read by tune_train_test)
    lr=1e-6,             # learning rate
    weight_decay=5e-5,   # L2 penalty
    momentum=0.9,        # momentum
    num_workers=0,       # DataLoader worker processes
    batch_size=5,        # mini-batch size
    w_size=128,          # resize width
    h_size=128,          # resize height
    num_channels=1,      # input channels
    num_class=4,         # output classes
)

fold = 0

# ListDataset documents resize_to as (height, width).
resize_to = (args['h_size'], args['w_size'])

In [29]:
def check_mkdir(dir_name):
    """Ensure directory ``dir_name`` exists, creating missing parent directories too.

    An existing path (directory or file) is left untouched. ``exist_ok=True`` avoids
    a failure if another process creates the directory between the check and the
    creation.
    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)

def prepare_meta_batch(meta_train_set, meta_test_set, index, batch_size=5, device='cuda'):
    """Sample a random support (train) and query (test) batch for task ``index``.

    ``meta_train_set[index]`` and ``meta_test_set[index]`` must be indexable
    collections whose items follow ListDataset's
    ``(image, dense_mask, sparse_mask, name)`` layout. Training targets are the
    sparse masks (item[2]); test targets are the dense masks (item[1]).

    Args:
        meta_train_set: indexable collection of per-task training datasets.
        meta_test_set: indexable collection of per-task test datasets.
        index: task to sample from.
        batch_size (int): samples drawn, without replacement, from each set.
        device: torch device the stacked tensors are moved to (default 'cuda').

    Returns:
        tuple: (x_train, y_train, x_test, y_test), each stacked along a new dim 0.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        IndexError: if either task set holds fewer than ``batch_size`` samples.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    train_task = meta_train_set[index]
    test_task = meta_test_set[index]
    n_train, n_test = len(train_task), len(test_task)

    if batch_size > min(n_train, n_test):
        raise IndexError(
            f'batch_size={batch_size} exceeds the task size '
            f'(train={n_train}, test={n_test})'
        )

    # Random selection without replacement.
    perm_train = torch.randperm(n_train)[:batch_size].tolist()
    perm_test = torch.randperm(n_test)[:batch_size].tolist()

    train_items = [train_task[i] for i in perm_train]
    test_items = [test_task[i] for i in perm_test]

    x_train = torch.stack([item[0] for item in train_items], dim=0).to(device)
    y_train = torch.stack([item[2] for item in train_items], dim=0).to(device)

    x_test = torch.stack([item[0] for item in test_items], dim=0).to(device)
    y_test = torch.stack([item[1] for item in test_items], dim=0).to(device)

    return x_train, y_train, x_test, y_test

def plot_kernels(kernel, idx, epoch, norm='mean0'):
    """Save a grid image of kernels to ``kernels/kernel{idx}_ep{epoch}.png``.

    Args:
        kernel: array or tensor indexed as ``kernel[row][col]`` -> 2-D image; rows
            come from dim 0 and columns from dim 1 (presumably conv weights shaped
            (out_channels, in_channels, kH, kW) — TODO confirm).
        idx: identifier placed in the file name.
        epoch: epoch number placed in the file name.
        norm: 'mean0' scales by 1 / (2 * |min|) and adds 0.5, so 0 maps to 0.5 and
            a negative minimum maps to 0. '01' rescales [min, max] to [0, 1].

    Raises:
        ValueError: if ``norm`` is neither 'mean0' nor '01'.
    """
    # Accept torch tensors on any device as well as array-likes.
    if hasattr(kernel, 'detach'):
        kernel = kernel.detach().cpu().numpy()
    kernel = np.asarray(kernel, dtype=float)

    if norm == 'mean0':
        # Fall back to the largest magnitude (then 1) when the minimum is 0, to avoid dividing by zero.
        denom = 2 * abs(kernel.min()) or 2 * np.abs(kernel).max() or 1.0
        tensor = kernel / denom + 0.5
    elif norm == '01':
        span = kernel.max() - kernel.min()
        tensor = (kernel - kernel.min()) / span if span else np.zeros_like(kernel)
    else:
        raise ValueError(f"Unknown norm {norm!r}; expected 'mean0' or '01'")

    num_rows = tensor.shape[0]
    num_cols = tensor.shape[1]
    fig = plt.figure(figsize=(16, 16))

    for i in range(num_rows):
        for j in range(num_cols):
            ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
            ax.imshow(tensor[i][j], cmap='gray')
            ax.axis('off')

    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    os.makedirs('kernels', exist_ok=True)
    fig.savefig('kernels/kernel' + str(idx) + '_ep' + str(epoch) + '.png', format='png')
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)
    # plt.show()

def accuracy(lab, prd, ignore_index=-1, average=None):
    """Return the Jaccard index (IoU) between labels and the argmax of predictions.

    Args:
        lab (torch.Tensor): integer labels.
        prd (torch.Tensor): class scores with the class axis at dim 1.
        ignore_index (int or None): label value excluded from scoring. The default -1
            matches the ``ignore_index`` used by the cross-entropy losses here.
            None scores every pixel.
        average (str or None): forwarded to ``sklearn.metrics.jaccard_score``. When
            None, 'binary' is used if only labels 0/1 occur and 'macro' otherwise,
            because the 'binary' default raises for multiclass targets.

    Returns:
        float: the Jaccard score.
    """
    # Obtaining class from prediction.
    prd = prd.argmax(1)

    # Tensor to ndarray (reshape also handles non-contiguous tensors).
    lab_np = lab.reshape(-1).detach().cpu().numpy()
    prd_np = prd.reshape(-1).detach().cpu().numpy()

    if ignore_index is not None:
        keep = lab_np != ignore_index
        lab_np, prd_np = lab_np[keep], prd_np[keep]

    if average is None:
        present = np.union1d(lab_np, prd_np)
        average = 'binary' if np.isin(present, (0, 1)).all() else 'macro'

    return metrics.jaccard_score(lab_np, prd_np, average=average)

In [30]:
def initialize_weights(*models):
    """Re-initialize, in place, the layers of every model in ``models``.

    Conv2d/Linear layers (plain or torchmeta ``Meta*``) get Kaiming-normal weights
    and a zero bias, if they have a bias. BatchNorm2d layers (plain or Meta) get
    weight 1 and bias 0. Any other module is left untouched.
    """
    conv_like = (nn.Conv2d, nn.Linear, modules.MetaConv2d, modules.MetaLinear)
    norm_like = (nn.BatchNorm2d, modules.MetaBatchNorm2d)

    for layer in (sub for model in models for sub in model.modules()):
        if isinstance(layer, conv_like):
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                layer.bias.data.zero_()
        elif isinstance(layer, norm_like):
            layer.weight.data.fill_(1)
            layer.bias.data.zero_()

# Transposed 2-D convolution that can run with externally supplied parameters
# (``params``, an OrderedDict with 'weight'/'bias'), as torchmeta Meta* modules do.
# Falls back to the module's own parameters when ``params`` is None.
class MetaConvTranspose2d(nn.ConvTranspose2d, modules.MetaModule):
    __doc__ = nn.ConvTranspose2d.__doc__

    def forward(self, input, output_size=None, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        weights = params.get('weight', None)
        bias = params.get('bias', None)

        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        if output_size is None:
            # Use the output padding configured on the module.
            output_padding = self.output_padding
        else:
            spatial_in = input.size()[2:]
            # Like nn.ConvTranspose2d, accept either (H, W) or a full (N, C, H, W) size.
            spatial_out = list(output_size)[-len(spatial_in):]
            if len(spatial_out) != len(spatial_in):
                raise ValueError(
                    f'output_size must have {len(spatial_in)} or {len(spatial_in) + 2} '
                    f'elements, got {len(output_size)}'
                )
            # Output padding = requested size minus the size produced without padding.
            output_padding = tuple(
                out - ((size_in - 1) * s - 2 * p + d * (k - 1) + 1)
                for out, size_in, s, p, d, k in zip(
                    spatial_out, spatial_in, self.stride, self.padding, self.dilation, self.kernel_size
                )
            )

        return F.conv_transpose2d(
            input, weights, bias, self.stride, self.padding,
            tuple(output_padding), self.groups, self.dilation
        )

class _MetaEncoderBlock(modules.MetaModule):
    """U-Net encoder stage.

    Two (3x3 conv -> batch norm -> ReLU) units, an optional dropout, then a 2x2
    stride-2 max pool. Layers are held in ``self.encode`` so torchmeta can route
    ``params`` to them by name.
    """

    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()

        layers = []
        for conv_in in (in_channels, out_channels):
            layers += [
                modules.MetaConv2d(conv_in, out_channels, kernel_size=3, padding=1),
                modules.MetaBatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        if dropout:
            layers.append(nn.Dropout())
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        self.encode = modules.MetaSequential(*layers)

    def forward(self, x, params=None):
        encode_params = self.get_subdict(params, 'encode')
        return self.encode(x, encode_params)

class _MetaDecoderBlock(modules.MetaModule):
    """U-Net decoder stage.

    Channel dropout (Dropout2d), then two (3x3 conv -> batch norm -> ReLU) units at
    ``middle_channels``. A kernel-2, stride-2 transposed convolution then doubles
    the spatial size and maps to ``out_channels``.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()

        stages = [nn.Dropout2d()]
        for conv_in in (in_channels, middle_channels):
            stages.extend((
                modules.MetaConv2d(conv_in, middle_channels, kernel_size=3, padding=1),
                modules.MetaBatchNorm2d(middle_channels),
                nn.ReLU(inplace=True),
            ))
        stages.append(
            MetaConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
        )

        self.decode = modules.MetaSequential(*stages)

    def forward(self, x, params=None):
        decode_params = self.get_subdict(params, 'decode')
        return self.decode(x, decode_params)


class UNet(modules.MetaModule):
    """Three-level U-Net made of torchmeta meta-modules.

    Every ``forward`` accepts an optional ``params`` mapping. Each submodule
    receives its slice of it via ``get_subdict``, which lets a caller run the
    network with externally provided weights.

    Args:
        input_channels: Channels of the input image.
        num_classes: Number of output classes (ignored when ``prototype`` is set).
        prototype: When true, no final 1x1 classifier is created and
            ``forward`` returns the 32-channel decoder features resized to
            the input size.
    """

    def __init__(self, input_channels, num_classes, prototype=False):
        super().__init__()

        self.prototype = prototype

        # Encoder path (each block ends with a 2x2 max-pool).
        self.enc1 = _MetaEncoderBlock(input_channels, 32)
        self.enc2 = _MetaEncoderBlock(32, 64)
        self.enc3 = _MetaEncoderBlock(64, 128, dropout=True)

        # Bottleneck; its transposed conv upsamples again.
        self.center = _MetaDecoderBlock(128, 256, 128)

        # Decoder blocks consume [previous output, resized encoder skip]
        # concatenated on the channel axis, hence the doubled input widths.
        self.dec3 = _MetaDecoderBlock(256, 128, 64)
        self.dec2 = _MetaDecoderBlock(128, 64, 32)

        self.dec1 = modules.MetaSequential(
            nn.Dropout2d(),
            modules.MetaConv2d(64, 32, kernel_size=3, padding=1),
            modules.MetaBatchNorm2d(32),
            nn.ReLU(inplace=True),
            modules.MetaConv2d(32, 32, kernel_size=3, padding=1),
            modules.MetaBatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        if not self.prototype:
            self.final = modules.MetaConv2d(32, num_classes, kernel_size=1)

        initialize_weights(self)

    @staticmethod
    def _resize_to(tensor, reference):
        """Bilinearly resize ``tensor`` to the spatial size of ``reference``."""
        return F.interpolate(tensor, reference.size()[2:], mode='bilinear')

    def forward(self, x, feat=False, params=None):
        """Segment ``x``.

        Returns:
            * ``prototype=True``: decoder features resized to ``x``'s size.
            * ``feat=False``: class logits resized to ``x``'s size.
            * ``feat=True``: a tuple ``(logits, dec1, dec2, dec3)``, where the
              logits, dec2 and dec3 are resized to ``x``'s size and dec1 is
              returned as produced.
        """
        def sub(name):
            return self.get_subdict(params, name)

        enc1 = self.enc1(x, sub('enc1'))
        enc2 = self.enc2(enc1, sub('enc2'))
        enc3 = self.enc3(enc2, sub('enc3'))

        center = self.center(enc3, sub('center'))

        # Each skip connection is resized to the decoder output it joins.
        dec3 = self.dec3(torch.cat([center, self._resize_to(enc3, center)], 1), sub('dec3'))
        dec2 = self.dec2(torch.cat([dec3, self._resize_to(enc2, dec3)], 1), sub('dec2'))
        dec1 = self.dec1(torch.cat([dec2, self._resize_to(enc1, dec2)], 1), sub('dec1'))

        if self.prototype:
            return self._resize_to(dec1, x)

        logits = self._resize_to(self.final(dec1, sub('final')), x)
        if not feat:
            return logits
        return logits, dec1, self._resize_to(dec2, x), self._resize_to(dec3, x)


In [31]:
# Build the U-Net on the GPU together with its Adam optimizer.
net = UNet(args['num_channels'], num_classes=4).cuda()

# Two parameter groups:
#   * bias parameters: twice the base learning rate, no weight decay is set;
#   * every other parameter: base learning rate with args['weight_decay'].
optimizer = optim.Adam(
    [
        {
            'params': [p for n, p in net.named_parameters() if n.endswith('bias')],
            'lr': 2 * args['lr'],
        },
        {
            'params': [p for n, p in net.named_parameters() if not n.endswith('bias')],
            'lr': args['lr'],
            'weight_decay': args['weight_decay'],
        },
    ],
    betas=(args['momentum'], 0.99),
)

In [None]:
# Build the sparse-annotation tuning loaders for the selected fold, then
# fine-tune `net` with them.
# NOTE(review): the structure of loaders_dict is defined by get_tune_loaders
# (outside this cell); later cells index it as loaders_dict['points'][0]['test'].
loaders_dict = get_tune_loaders(
    shots=list_shots, points=list_sparsity_points, grid=list_sparsity_grid,
    fold_name=fold, resize_to=resize_to, args=args, imgtype='med',
)

run_sparse_tuning(loaders_dict, net, optimizer, args)

In [None]:
# Inference: rebuild the trained network from a checkpoint and save images.

# 1. Checkpoint to evaluate, relative to the current working directory.
# NOTE(review): the notebook runs `%cd .../Bakalauras/SPSMM` at the start, so
# this resolves to .../SPSMM/SPSMM/best_model/... -- confirm that nested folder exists.
model_path = 'SPSMM/best_model/best_model_points_5_shots_1_points_epoch_1000.pth'

# 2. Rebuild exactly the architecture used for training. The input channel
#    count comes from `args`, as in the training cell, so the first
#    convolution's weight shape matches the checkpoint.
net = UNet(input_channels=args['num_channels'], num_classes=4).cuda()

# 3. Load the weights onto the GPU and switch to evaluation mode
#    (dropout disabled, BatchNorm uses its running statistics).
net.load_state_dict(torch.load(model_path, map_location='cuda'))
net.eval()

def test_model(tune_test_loader, net,
               output_dir='/content/drive/MyDrive/Colab Notebooks/Bakalauras/Paveiksleliai/SPSMM/',
               device='cuda'):
    """Evaluate a segmentation network and save per-sample PNG visualisations.

    For every batch from ``tune_test_loader`` the cross-entropy loss
    (``ignore_index=-1``) is computed and the arg-max prediction is kept.
    Afterwards three images are written per sample into ``output_dir``:
    ``test_<name>_input.png``, ``test_<name>_real_mask.png`` and
    ``test_<name>_segmented.png``. Here ``<name>`` is the sample's file name
    without directory and extension, or ``sample_<k>`` if the name is empty.

    Args:
        tune_test_loader: Iterable of ``(inputs, labels, _, img_names)``
            batches; the third element is ignored.
        net: Model returning per-class scores along dim 1. It is put in eval
            mode for the test, and its previous train/eval mode is restored
            afterwards.
        output_dir: Directory for the PNG files (created if missing).
        device: Device the batches are moved to (default ``'cuda'``).

    Returns:
        float: Mean of the per-batch losses, or ``nan`` if the loader
        yielded no batches.
    """
    # Class index -> RGB colour used when rendering masks.
    color_map = {
        0: [0, 0, 0],    # Black (background)
        1: [0, 0, 255],  # Blue (CSF)
        2: [255, 0, 0],  # Red (GM)
        3: [0, 255, 0]   # Green (WM)
    }

    def map_to_color(segmentation, color_map):
        """Convert a 2-D integer label map into an (H, W, 3) uint8 RGB image."""
        colored = np.zeros((segmentation.shape[0], segmentation.shape[1], 3), dtype=np.uint8)
        for cls, color in color_map.items():
            colored[segmentation == cls] = color
        return colored

    def save_image(image, save_path, cmap=None):
        """Render ``image`` without axes or padding and write it to ``save_path``."""
        plt.figure()
        try:
            plt.imshow(image, cmap=cmap)
            plt.axis('off')
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Always release the figure, even if rendering/saving fails.
            plt.close()

    def visualize_segmentation(inputs, labels, preds, color_map, img_names, dataset_type='test'):
        """Save input, ground-truth mask and prediction for every sample."""
        os.makedirs(output_dir, exist_ok=True)

        for i, (inp, lab, pred, name) in enumerate(zip(inputs, labels, preds, img_names)):
            # Strip directory and extension for a cleaner file name.
            img_name = os.path.splitext(os.path.basename(name))[0] if name else f"sample_{i+1}"

            inp = inp.numpy()
            if inp.ndim == 3:
                # Assumes channels-first (C, H, W) inputs.
                if inp.shape[0] == 1:
                    inp = inp.squeeze(0)
                elif inp.shape[0] in (3, 4):
                    # imshow expects channels-last for RGB(A) images.
                    inp = inp.transpose(1, 2, 0)

            save_path_input = os.path.join(output_dir, f"{dataset_type}_{img_name}_input.png")
            save_image(inp, save_path_input, cmap='gray')
            print(f"Saved input: {save_path_input}")

            save_path_real_mask = os.path.join(output_dir, f"{dataset_type}_{img_name}_real_mask.png")
            save_image(map_to_color(lab, color_map), save_path_real_mask)
            print(f"Saved real mask: {save_path_real_mask}")

            save_path_segmented = os.path.join(output_dir, f"{dataset_type}_{img_name}_segmented.png")
            save_image(map_to_color(pred, color_map), save_path_segmented)
            print(f"Saved segmented: {save_path_segmented}")

    test_loss_list = []
    inps_all, labs_all, prds_all, img_names_all = [], [], [], []

    # Evaluate with dropout disabled and BatchNorm running statistics, then
    # restore whatever mode the caller had the network in.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch_idx, (x_ts, y_ts, _, img_name) in enumerate(tune_test_loader):
                x_ts = x_ts.to(device)
                y_ts = y_ts.to(device).long()
                p_ts = net(x_ts)
                tune_test_loss = F.cross_entropy(p_ts, y_ts, ignore_index=-1)
                test_loss_list.append(tune_test_loss.item())

                # Move the whole batch to the CPU once, then split per sample.
                prds_batch = p_ts.max(1)[1].cpu().numpy()  # arg-max class per pixel
                x_cpu = x_ts.cpu()
                y_cpu = y_ts.cpu().numpy()
                print(f"Batch {batch_idx+1} image names: {img_name}")
                for j in range(x_cpu.size(0)):
                    inps_all.append(x_cpu[j].squeeze(0))
                    labs_all.append(y_cpu[j])
                    prds_all.append(prds_batch[j])
                    img_names_all.append(img_name[j])
    finally:
        net.train(was_training)

    # np.mean([]) would warn and return nan; make the empty case explicit.
    avg_test_loss = float(np.mean(test_loss_list)) if test_loss_list else float('nan')
    print(f'Average test loss: {avg_test_loss:.4f}')
    print(f"Total test images processed: {len(inps_all)}")

    visualize_segmentation(
        inps_all,
        labs_all,
        prds_all,
        color_map,
        img_names_all,
        dataset_type='test'
    )

    return avg_test_loss

# Start from an empty results folder so images from earlier runs are not
# mixed with the ones produced below. This deletes the folder recursively.
import shutil

output_dir = '/content/drive/MyDrive/Colab Notebooks/Bakalauras/Paveiksleliai/SPSMM'
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

# Evaluate the loaded network on the test split of the first 'points' entry
# of loaders_dict and save the visualisations.
test_loader = loaders_dict['points'][0]['test']
test_model(test_loader, net)