<a href="https://colab.research.google.com/github/pedrohtg/weasel/blob/main/weasel_ft_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo WeaSeL Fine-tuning
In this notebook we show how to take a network pretrained with the WeaSeL method and fine-tune it on a new dataset in a few-shot regime.

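The original demo uses the OpenIST dataset for lung segmentation. The cells below have been adapted to a multi-class brain segmentation task (task `brains`, labels 1–3), whose NIfTI volumes are first converted to 2D PNG slices.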


### Note: run this notebook in a session with a GPU.

---

Links:

- **Github** : https://github.com/pedrohtg/weasel.git
- **OpenIST** Dataset : https://github.com/pi-null-mezon/OpenIST

In [None]:
from google.colab import drive
# Mount Google Drive; the dataset paths used later in this notebook live under /content/drive/MyDrive.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
!pip install SimpleITK nibabel pillow tqdm


Collecting SimpleITK
  Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.4/52.4 MB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.4.0


In [None]:
%%writefile nii_to_png_converter.py



import os
import nibabel as nib
import numpy as np
from PIL import Image
import argparse
from tqdm import tqdm  # Progress bars
import re
from nilearn.image import resample_img

def parse_arguments():
    """Build the converter's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``input_dir`` and ``output_dir`` (str, required),
        ``resize`` (list of two ints, or None) and ``save_all_slices`` (bool).
    """
    cli = argparse.ArgumentParser(
        description='Convert 3D NIfTI (.nii) files to 2D PNG slices.'
    )

    # Both directory options share the same shape: required string paths.
    required_paths = (
        ('--input_dir', 'Path to the source data directory.'),
        ('--output_dir', 'Path to the output data directory.'),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        '--resize',
        type=int,
        nargs=2,
        default=None,
        help='Resize image to (width height).',
    )
    cli.add_argument(
        '--save_all_slices',
        action='store_true',
        help='Save all slices. If not set, save only slices with at least one class in {1,2,3}.',
    )

    namespace = cli.parse_args()
    return namespace

def normalize_image(image):
    """Normalize image to 0-1 scale (min-max over the finite voxels).

    NaN/inf voxels are ignored when computing the range and are set to 0 in the
    result, so a single non-finite voxel no longer turns the whole slice into NaN
    (which would make the later uint8 conversion meaningless).

    Args:
        image (numpy.ndarray): array of any shape.

    Returns:
        numpy.ndarray: array with the same shape as ``image``. A constant,
        empty, or entirely non-finite input yields an all-zero float32 array.
    """
    image = np.asarray(image)
    finite = np.isfinite(image)
    if not finite.any():
        # Empty or all-NaN/inf input: nothing to scale.
        return np.zeros(image.shape, dtype=np.float32)

    img_min = image[finite].min()
    img_max = image[finite].max()
    if img_max - img_min == 0:
        return np.zeros(image.shape, dtype=np.float32)

    normalized = (image - img_min) / (img_max - img_min)
    if not finite.all():
        # Non-finite voxels produced NaN/inf above; map them to 0.
        normalized = np.where(finite, normalized, 0.0)
    return normalized

def process_and_save_slice(image_slice, mask_slice, image_save_path, mask_save_path, resize=None):
    """Write one 2D image slice and its mask to disk as PNG files.

    The image is min-max scaled with ``normalize_image`` and stored as uint8
    (0-255). The mask is cast to uint8 unchanged. When ``resize`` is given, it
    is passed to PIL ``Image.resize`` (i.e. (width, height)): bilinear for the
    image, nearest-neighbour for the mask so label values are never blended.
    The image file is written before the mask file.
    """
    scaled_pixels = (normalize_image(image_slice) * 255).astype(np.uint8)
    image_png = Image.fromarray(scaled_pixels)
    if resize:
        image_png = image_png.resize(resize, Image.BILINEAR)

    mask_png = Image.fromarray(mask_slice.astype(np.uint8))
    if resize:
        mask_png = mask_png.resize(resize, Image.NEAREST)

    for png, destination in ((image_png, image_save_path), (mask_png, mask_save_path)):
        png.save(destination)

def resample_mask_nilearn(mask_path, reference_image_path):
    """
    Put a label volume onto the voxel grid (affine and shape) of a reference image.

    Nearest-neighbour interpolation is used so integer label values are copied,
    not averaged.

    Args:
        mask_path (str): Path to the input mask file (.nii).
        reference_image_path (str): Path to the reference image file (.nii) whose
            affine and shape define the target grid.

    Returns:
        numpy.ndarray: uint8 mask data with the reference image's shape.
    """
    label_volume = nib.load(mask_path)
    reference_volume = nib.load(reference_image_path)

    on_reference_grid = resample_img(
        label_volume,
        target_affine=reference_volume.affine,
        target_shape=reference_volume.shape,
        interpolation='nearest',
        # Both options are passed explicitly; they require a nilearn version
        # that accepts these keyword arguments.
        force_resample=True,
        copy_header=True,
    )
    return on_reference_grid.get_fdata().astype(np.uint8)

def _load_aligned_volumes(image_path, label_path, image_filename, label_filename):
    """Load an image volume and its label volume so that their shapes match.

    If the shapes differ, the label is resampled onto the image grid with
    ``resample_mask_nilearn``; an axis-reversed result is transposed back.

    Returns:
        tuple: ``(image_data, mask_data)`` with identical shapes, or ``None`` if
        loading or alignment failed (the reason is printed).
    """
    try:
        image_data = nib.load(image_path).get_fdata()
    except Exception as e:
        print(f"Error loading image file '{image_filename}': {e}")
        return None

    try:
        label_data = nib.load(label_path).get_fdata().astype(np.uint8)
    except Exception as e:
        print(f"Error loading label file '{label_filename}': {e}")
        return None

    if image_data.shape == label_data.shape:
        return image_data, label_data

    print(f"Warning: Dimensions mismatch between '{image_filename}' and '{label_filename}'. Resampling mask.")
    try:
        mask_data = resample_mask_nilearn(label_path, image_path)
        print(f"Resampled mask shape: {mask_data.shape}")
        print(f"Image shape: {image_data.shape}")
    except Exception as e:
        print(f"Error resampling mask '{label_filename}': {e}")
        return None

    # Fallback: accept a mask whose axes come out in reversed order.
    if mask_data.shape != image_data.shape and mask_data.shape[::-1] == image_data.shape:
        mask_data = mask_data.transpose(2, 1, 0)
        print(f"Transposed mask shape: {mask_data.shape}")

    if mask_data.shape != image_data.shape:
        print(f"Error: Resampled mask shape {mask_data.shape} does not match image shape {image_data.shape}. Skipping this subject.")
        return None
    return image_data, mask_data


def _export_slices(image_data, mask_data, image_dir, mask_dir, prefix, subject_num, resize, save_all_slices):
    """Save 2D slices taken along the third axis of a subject's volumes as PNGs.

    Unless ``save_all_slices`` is set, slices whose mask contains none of the
    classes {1, 2, 3} are skipped. Save errors are printed and the slice skipped.
    """
    for slice_idx in range(image_data.shape[2]):
        mask_slice = mask_data[:, :, slice_idx]
        if not save_all_slices and not np.any(np.isin(mask_slice, [1, 2, 3])):
            continue  # Skip slices with only background

        slice_name = f"{prefix}_subject{subject_num}_slice_{slice_idx:03d}.png"
        try:
            process_and_save_slice(
                image_data[:, :, slice_idx],
                mask_slice,
                os.path.join(image_dir, slice_name),
                os.path.join(mask_dir, slice_name),
                resize=resize,
            )
        except Exception as e:
            print(f"Error saving slice '{slice_name}': {e}")


def convert_dataset(input_dir, output_dir, resize=None, save_all_slices=False):
    """
    Convert all .nii files to 2D PNG slices and save them into separate directories.

    Expects ``<input_dir>/<Training|Testing>/T<k>/<train|test>_T<k>_subject<NN>.nii``
    images and ``<input_dir>/<split>/label/<train|test>_label_subject<NN>.nii`` labels.
    Writes ``<output_dir>/<split>/images/T<k>/<slice>.png`` and
    ``<output_dir>/<split>/masks/<slice>.png``.

    NOTE(review): masks are written to one ``masks`` folder shared by all
    modalities, so for a given slice name the mask from the last processed
    modality (sorted name order) is the one kept. If modalities need different
    resampling (or ``resize`` is None and their shapes differ), that mask may not
    fit the other modalities' images — confirm all modalities share one grid.

    Args:
        input_dir (str): Path to the source data directory.
        output_dir (str): Path to the output data directory.
        resize (tuple, optional): Resize image to (width, height). Defaults to None.
        save_all_slices (bool, optional): Save all slices. If False, save only slices with at least one class in {1,2,3}. Defaults to False.
    """
    splits = ['Training', 'Testing']
    modality_pattern = re.compile(r'^T\d+$')  # Match directories starting with 'T' followed by digits
    image_pattern = re.compile(r'^(train|test)_T(\d+)_subject(\d{2})\.nii$', re.IGNORECASE)

    for split in splits:
        split_input_path = os.path.join(input_dir, split)
        split_output_path = os.path.join(output_dir, split)

        print(f"\nProcessing split: {split}")

        if not os.path.exists(split_input_path):
            print(f"Warning: Split directory '{split_input_path}' does not exist. Skipping.")
            continue

        # Sorted so the processing order (and hence which modality's mask ends up
        # in the shared masks folder) is deterministic across runs/filesystems.
        images_modalities = sorted(
            d for d in os.listdir(split_input_path)
            if os.path.isdir(os.path.join(split_input_path, d)) and modality_pattern.match(d)
        )
        labels_dir = os.path.join(split_input_path, 'label')

        if not os.path.exists(labels_dir):
            print(f"Warning: 'label' directory not found in '{split_input_path}'. Skipping.")
            continue

        images_output_dir = os.path.join(split_output_path, 'images')
        masks_output_dir = os.path.join(split_output_path, 'masks')
        os.makedirs(images_output_dir, exist_ok=True)
        os.makedirs(masks_output_dir, exist_ok=True)

        for modality in images_modalities:
            modality_input_dir = os.path.join(split_input_path, modality)
            modality_output_dir = os.path.join(images_output_dir, modality)
            # Created once here, so the per-slice loop needs no makedirs call.
            os.makedirs(modality_output_dir, exist_ok=True)

            image_matches = []
            for filename in sorted(os.listdir(modality_input_dir)):
                match = image_pattern.match(filename)
                if match:
                    image_matches.append((filename, match))

            print(f"\nProcessing modality: {modality} with {len(image_matches)} image files")

            for image_filename, match in tqdm(image_matches, desc=f'Processing {split}/{modality}'):
                prefix = match.group(1).lower()   # 'train' or 'test'
                subject_num = match.group(3)      # e.g., '01', '02', etc.

                label_filename = f"{prefix}_label_subject{subject_num}.nii"
                label_path = os.path.join(labels_dir, label_filename)
                if not os.path.exists(label_path):
                    print(f"Warning: Label file '{label_path}' does not exist for image '{image_filename}'. Skipping.")
                    continue

                volumes = _load_aligned_volumes(
                    os.path.join(modality_input_dir, image_filename),
                    label_path,
                    image_filename,
                    label_filename,
                )
                if volumes is None:
                    continue
                image_data, mask_data = volumes

                _export_slices(
                    image_data,
                    mask_data,
                    modality_output_dir,
                    masks_output_dir,
                    prefix,
                    subject_num,
                    resize,
                    save_all_slices,
                )

def main():
    """CLI entry point: parse the arguments and run the NIfTI -> PNG conversion."""
    cli_args = parse_arguments()

    target_size = None
    if cli_args.resize:
        target_size = tuple(cli_args.resize)

    convert_dataset(
        input_dir=cli_args.input_dir,
        output_dir=cli_args.output_dir,
        resize=target_size,
        save_all_slices=cli_args.save_all_slices,
    )
    print("Data conversion completed!")


if __name__ == "__main__":
    main()

In [None]:
# !python nii_to_png_converter.py \
#     --input_dir "/content/drive/MyDrive/Colab Notebooks/Kursinis/dataverse_files" \
#     --output_dir "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset" \
#     --resize 256 256

In [None]:
# import os
# import shutil
# from random import sample

# def move_images_with_masks(source_dir, target_dir, t_folders, num_images):
#     """
#     Move a specified number of images and their masks from source folders to target folders.

#     Args:
#         source_dir (str): Path to the source directory (e.g., 'Training').
#         target_dir (str): Path to the target directory where images and masks will be moved.
#         t_folders (list): List of subfolder names in 'images' (e.g., ['T1', 'T2']).
#         num_images (int): Number of images to move per folder.
#     """
#     # Ensure target directory exists
#     os.makedirs(target_dir, exist_ok=True)

#     for t_folder in t_folders:
#         # Define paths for the current T folder
#         images_folder = os.path.join(source_dir, "images", t_folder)
#         masks_folder = os.path.join(source_dir, "masks")
#         target_images_folder = os.path.join(target_dir, "images", t_folder)
#         target_masks_folder = os.path.join(target_dir, "masks")

#         # Create target subdirectories
#         os.makedirs(target_images_folder, exist_ok=True)
#         os.makedirs(target_masks_folder, exist_ok=True)

#         # List all image files in the current T folder
#         image_files = [f for f in os.listdir(images_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]

#         # Ensure there are enough images to sample
#         if len(image_files) < num_images:
#             print(f"Not enough images in {images_folder}. Skipping.")
#             continue

#         # Randomly select the specified number of images
#         selected_files = sample(image_files, num_images)

#         for file in selected_files:
#             image_source_path = os.path.join(images_folder, file)
#             mask_source_path = os.path.join(masks_folder, file)  # Masks are assumed to have the same filename
#             image_target_path = os.path.join(target_images_folder, file)
#             mask_target_path = os.path.join(target_masks_folder, file)

#             try:
#                 # Move image and its mask
#                 shutil.move(image_source_path, image_target_path)
#                 shutil.move(mask_source_path, mask_target_path)
#                 print(f"Moved: {file} from {t_folder} with its mask to {target_dir}")
#             except FileNotFoundError as e:
#                 print(f"Error: Mask for {file} not found. Skipping this file.")
#             except Exception as e:
#                 print(f"Unexpected error: {e}")

# # Define source and target paths
# source_dir = "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset/Testing"  # Source directory
# target_dir = "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset/SelectedDataTesting"  # Target directory

# # List of T folders to process
# folders_to_process = ["T1", "T2", "T3", "T4", "T6", "T7", "T8", "T9", "T10", "T11", "T12", "T13", "T14", "T15"]

# # Number of images to move per folder
# num_images_per_folder = 2

# # Execute the function
# move_images_with_masks(source_dir, target_dir, folders_to_process, num_images_per_folder)


In [None]:
# import os

# def rename_images_and_masks_in_place(images_folder, masks_folder):
#     """
#     Rename all images and their corresponding masks to image1, image2, ..., and mask1, mask2, ... respectively.

#     Args:
#         images_folder (str): Path to the folder containing the images.
#         masks_folder (str): Path to the folder containing the masks.
#     """
#     image_counter = 1

#     # List all image files in the images folder
#     image_files = [f for f in os.listdir(images_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]

#     for file in image_files:
#         # Define current file paths
#         image_source_path = os.path.join(images_folder, file)
#         mask_source_path = os.path.join(masks_folder, file)  # Masks are assumed to have the same filename

#         # Define new names for the image and mask
#         new_image_name = f"image{image_counter}.png"
#         new_mask_name = f"mask{image_counter}.png"

#         # Define new file paths
#         image_target_path = os.path.join(images_folder, new_image_name)
#         mask_target_path = os.path.join(masks_folder, new_mask_name)

#         try:
#             # Rename image and its mask
#             os.rename(image_source_path, image_target_path)
#             os.rename(mask_source_path, mask_target_path)
#             print(f"Renamed: {file} to {new_image_name} and corresponding mask to {new_mask_name}")
#             image_counter += 1
#         except FileNotFoundError:
#             print(f"Error: Mask for {file} not found. Skipping this file.")
#         except Exception as e:
#             print(f"Unexpected error: {e}")

# # Define paths to images and masks
# images_folder = "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset/Testing/images"  # Images folder
# masks_folder = "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset/Testing/masks"  # Masks folder

# # Execute the function
# rename_images_and_masks_in_place(images_folder, masks_folder)

In [None]:
# import os
# import numpy as np
# from PIL import Image

# def remap_labels_in_masks(mask_folder, output_folder, class_mapping):
#     """
#     Remap labels in mask files based on the provided class mapping.

#     Args:
#         mask_folder (str): Path to the folder containing the mask files.
#         output_folder (str): Path to save the remapped masks.
#         class_mapping (dict): Dictionary mapping original labels to new labels.
#     """
#     os.makedirs(output_folder, exist_ok=True)

#     # Process each mask file
#     for mask_file in os.listdir(mask_folder):
#         mask_path = os.path.join(mask_folder, mask_file)

#         if mask_file.endswith('.png') or mask_file.endswith('.jpg') or mask_file.endswith('.jpeg'):
#             try:
#                 # Load the mask
#                 mask = np.array(Image.open(mask_path))

#                 # Remap labels
#                 remapped_mask = np.copy(mask)
#                 for original_label, new_label in class_mapping.items():
#                     remapped_mask[mask == original_label] = new_label

#                 # Save the remapped mask
#                 remapped_mask_path = os.path.join(output_folder, mask_file)
#                 Image.fromarray(remapped_mask.astype(np.uint8)).save(remapped_mask_path)
#                 print(f"Remapped and saved: {mask_file}")
#             except Exception as e:
#                 print(f"Error processing {mask_file}: {e}")

# # Define paths
# training_mask_folder = "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset/Training/masks"  # Path to Training masks
# output_mask_folder = "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset/Training/masks_remapped"  # Path to save remapped masks

# # Define class mapping
# class_mapping = {
#     0: 0,  # Background
#     1: 1,  # White matter
#     2: 2,  # Gray matter
#     3: 3,  # CFE
#     4: 0,  # Extra classes → Background
#     5: 0,
#     6: 0,
#     7: 0,
#     8: 0
# }

# # Execute remapping
# remap_labels_in_masks(training_mask_folder, output_mask_folder, class_mapping)


In [None]:
# import numpy as np
# import os
# from PIL import Image

# mask_dir = "/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset/Training/masks/"  # Nurodykite kelią į kaukių katalogą
# unique_values = set()

# for mask_file in os.listdir(mask_dir):
#     mask_path = os.path.join(mask_dir, mask_file)
#     mask = np.array(Image.open(mask_path))
#     unique_values.update(np.unique(mask))

# print(f"Unikalios reikšmės kaukėse: {unique_values}")


In [None]:
# Basic imports
import os
import sys
import torch
from sklearn import metrics
import numpy as np
import random

from torch import optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader
import torch.nn.functional as F

from matplotlib import pyplot as plt

%matplotlib inline

## Step 0:
Install the torchmeta module. It is not needed for training itself, but the
network model was implemented with this module.

In [None]:
!pip install torchmeta --no-deps

Collecting torchmeta
  Downloading torchmeta-1.8.0-py3-none-any.whl.metadata (8.2 kB)
Downloading torchmeta-1.8.0-py3-none-any.whl (210 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/210.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━[0m [32m184.3/210.4 kB[0m [31m5.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.4/210.4 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchmeta
Successfully installed torchmeta-1.8.0


## Step 1:
Clone the source code repository from https://github.com/pedrohtg/weasel.git

In [None]:
# !rm -r weasel
# !git clone https://github.com/pedrohtg/weasel.git

## Step 2:
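Download and unzip the dataset (the download cells are commented out here; the converted dataset is read from Google Drive instead).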

In [None]:
# Auxiliar module to download files from Google Drive
!pip install googledrivedownloader
from google_drive_downloader import GoogleDriveDownloader as gdd




In [None]:

# gdd.download_file_from_google_drive(file_id='1Z3VlJ8h7EDDeNlsJDP0GrnnQBKM-iUjR',
#                                     dest_path='./openist.zip',
#                                     unzip=True)

# !unzip -q -o openist.zip

# # Create a list of tuples (img, mask)
# # This will be used below
# openist_list = [(f, f.replace('.jpg', '@.png')) for f in os.listdir('openist') if '@' not in f]

In [None]:
# Root folder (on the mounted Google Drive) holding the converted PNG dataset.
dataset_root = (
    '/content/drive/MyDrive/Colab Notebooks/'
    'Kursinis/weasel/dataset'
)

## Step 3:
Modify the ListDataset class.

As mentioned in the GitHub repository, this class was implemented for a fixed folder organization. The dataset used here has a different folder structure, so we demonstrate how to adapt the class.

There are two main functions that need adaptation, namely,

```
make_dataset()  # This function generates the list of image files that compose the dataset
get_data()      # This function reads the image, and ground truth files
```
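In the original demo, the dataset consisted of grayscale images and binary masks, so only `make_dataset()` needed changes, not `get_data()`. In the version below, loading happens in `__getitem__`, masks may contain several classes, and the masks can optionally be sparsified with point or grid sampling.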


In [None]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms  # Import transforms
import logging
import torch

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ListDataset(Dataset):
    """
    A PyTorch Dataset class to load images and corresponding masks for training or testing.

    This class supports various annotation sparsity modes such as 'points' and 'grid', allowing for few-shot
    learning tasks. It pairs images with their respective masks and applies optional resizing and sparsity transformations.

    Expected directory layout::

        <dataset_root>/<task>/<fold>/Training/images/   image files, optionally inside sub-folders (e.g. T1/, T2/)
        <dataset_root>/<task>/<fold>/Training/masks/    mask files (one flat folder)
        <dataset_root>/<task>/<fold>/Testing/...        same structure, used when 'train' is not in ``mode``

    Images are paired with the mask that has the same file name. If no file name
    matches at all (e.g. ``image1.png`` / ``mask1.png``), images and masks are
    paired by sorted order instead.
    """

    # Only files with these extensions count as images/masks, so stray sub-folders
    # or files such as desktop.ini are never paired.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff')

    def __init__(self, mode, dataset_root, task, fold, resize_to, num_shots, sparsity_mode, sparsity_param, imgtype, make=True, shot_seed=0):
        """
        Initialize the ListDataset.

        Args:
            mode (str): Dataset usage mode - 'train', 'test', 'tune_train', or 'tune_test'.
                Modes containing 'train' read the ``Training`` folder, others ``Testing``.
            dataset_root (str): Root directory of the dataset.
            task (str): Task name (e.g., 'brains').
            fold (int or str): Fold identifier (e.g., '0').
            resize_to (tuple): Target size passed to PIL ``Image.resize``, i.e. (width, height).
                Pass None to keep original size.
            num_shots (int): Number of shots for few-shot learning. In training modes a
                positive value keeps only that many randomly chosen image/mask pairs;
                -1 (or any value <= 0) keeps all of them (dense mode).
            sparsity_mode (str): Sparsity mode ('points', 'grid', or 'dense').
            sparsity_param (float or int): Parameter controlling sparsity (e.g., number of points or grid spacing).
            imgtype (str): Image type (e.g., 'med').
            make (bool): Whether to initialize and load the dataset.
            shot_seed (int): Seed used to choose the ``num_shots`` samples, so datasets
                built with the same arguments use the same shots.
        """
        self.mode = mode
        self.dataset_root = dataset_root
        self.task = task
        self.fold = str(fold)  # Convert fold to string
        self.resize_to = resize_to
        self.num_shots = num_shots
        self.sparsity_mode = sparsity_mode
        self.sparsity_param = sparsity_param
        self.imgtype = imgtype
        self.make = make
        self.shot_seed = shot_seed

        # Root directory where the dataset is stored
        self.root = os.path.join(self.dataset_root, self.task, self.fold)

        # Initialize the dataset by loading image-mask pairs
        if make:
            self.imgs = self.make_dataset()
        else:
            self.imgs = []

    def make_dataset(self):
        """
        Create the dataset by pairing images and masks.

        Returns:
            list: List of ``(image_path, mask_path)`` tuples; empty if the images or
            masks directory is missing.
        """
        mode_dir = os.path.join(self.root, "Training" if 'train' in self.mode.lower() else "Testing")
        images_dir = os.path.join(mode_dir, "images")
        masks_dir = os.path.join(mode_dir, "masks")

        # Verify that the directories exist
        if not os.path.exists(images_dir):
            logger.error(f"Images directory does not exist: {images_dir}")
            return []

        if not os.path.exists(masks_dir):
            logger.error(f"Masks directory does not exist: {masks_dir}")
            return []

        # Images may sit in per-modality sub-folders; masks are read from one flat folder.
        img_paths = self._list_image_files(images_dir, recursive=True)
        mask_paths = self._list_image_files(masks_dir, recursive=False)

        data_list = self._pair_images_and_masks(img_paths, mask_paths)
        data_list = self._select_shots(data_list)

        logger.info(f"Loaded {len(data_list)} samples for task '{self.task}' and mode '{self.mode}' from '{self.dataset_root}'")
        return data_list

    @classmethod
    def _list_image_files(cls, directory, recursive=False):
        """Return the sorted paths of image files in ``directory`` (optionally including sub-folders)."""
        if recursive:
            candidates = (
                os.path.join(folder, name)
                for folder, _, names in os.walk(directory)
                for name in names
            )
        else:
            candidates = (os.path.join(directory, name) for name in os.listdir(directory))
        return sorted(
            path for path in candidates
            if os.path.isfile(path)
            and not os.path.basename(path).startswith('.')
            and path.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _pair_images_and_masks(img_paths, mask_paths):
        """Pair images with masks by file name when possible, otherwise by sorted position."""
        masks_by_name = {os.path.basename(p): p for p in mask_paths}

        if any(os.path.basename(p) in masks_by_name for p in img_paths):
            # Name-based pairing; several images (e.g. one per modality) may share a mask.
            pairs = []
            for img_path in img_paths:
                img_name = os.path.basename(img_path)
                mask_path = masks_by_name.get(img_name)
                if mask_path is None:
                    logger.warning(f"Missing file pair: Image={img_name}, Mask=None")
                else:
                    pairs.append((img_path, mask_path))
            return pairs

        # Positional pairing (e.g. image1.png <-> mask1.png): relies on both folders
        # sorting in the same order and containing the same number of files.
        if len(img_paths) != len(mask_paths):
            logger.warning("Mismatch between the number of images and masks!")
        return list(zip(img_paths, mask_paths))

    def _select_shots(self, data_list):
        """Keep ``num_shots`` randomly chosen pairs in training modes (seeded by ``shot_seed``)."""
        n_shots = self.num_shots
        if 'train' not in self.mode.lower() or n_shots is None or n_shots <= 0 or n_shots >= len(data_list):
            return data_list
        rng = np.random.RandomState(self.shot_seed)
        chosen = np.sort(rng.choice(len(data_list), size=int(n_shots), replace=False))
        return [data_list[i] for i in chosen]

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """
        Get a single sample (image and mask) from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple containing:
                - image (torch.Tensor): The image tensor of shape [1, H, W] (grayscale, scaled by ToTensor).
                - y_dense (torch.Tensor): The dense mask tensor of shape [H, W].
                - y_tr (torch.Tensor): The sparse mask tensor (after sparsity transformation).
                - img_name (str): Name of the image file.
        """
        img_path, mask_path = self.imgs[idx]

        # Load image and mask as single-channel images
        image = Image.open(img_path).convert('L')
        mask = Image.open(mask_path).convert('L')

        # Resize if necessary; nearest-neighbour keeps mask labels intact
        if self.resize_to:
            image = image.resize(self.resize_to, Image.BILINEAR)
            mask = mask.resize(self.resize_to, Image.NEAREST)

        image = np.array(image)
        mask = np.array(mask)

        # Apply sparsity transformations if specified
        if self.sparsity_mode == 'points':
            y_tr = self._apply_point_sparsity(mask)
        elif self.sparsity_mode == 'grid':
            y_tr = self._apply_grid_sparsity(mask)
        else:
            y_tr = mask  # Dense mask without sparsity

        # Convert to PyTorch tensors
        image = transforms.ToTensor()(image)  # Shape: [1, H, W]
        y_dense = torch.tensor(mask, dtype=torch.long)  # Shape: [H, W]
        y_tr = torch.tensor(y_tr, dtype=torch.long)  # Sparse mask

        img_name = os.path.basename(img_path)
        return image, y_dense, y_tr, img_name

    def _apply_point_sparsity(self, mask):
        """
        Apply sparsity by selecting a fixed number of points for each non-background class.

        NOTE(review): unselected pixels are set to 0, the same value as background,
        and points are re-drawn from the global NumPy RNG on every call — confirm
        the training loss treats 0 accordingly.

        Args:
            mask (numpy.ndarray): Original dense mask.

        Returns:
            numpy.ndarray: Sparse mask with selected points.
        """
        # DEBUG level: this runs once per sample and would otherwise flood the log.
        logger.debug(f"Applying point sparsity: {self.sparsity_param}")
        sparse_mask = np.zeros_like(mask)
        num_points = self.sparsity_param if self.sparsity_param else 10

        for cls in np.unique(mask):
            if cls == 0:  # Ignore background
                continue
            cls_points = np.argwhere(mask == cls)
            chosen = np.random.choice(len(cls_points), min(num_points, len(cls_points)), replace=False)
            sparse_mask[tuple(cls_points[chosen].T)] = cls
        return sparse_mask

    def _apply_grid_sparsity(self, mask):
        """
        Apply sparsity by keeping pixels at regular grid intervals (all other pixels become 0).

        Args:
            mask (numpy.ndarray): Original dense mask.

        Returns:
            numpy.ndarray: Sparse mask with grid sampling.
        """
        logger.debug(f"Applying grid sparsity: {self.sparsity_param}")
        sparse_mask = np.zeros_like(mask)
        grid_spacing = self.sparsity_param if self.sparsity_param else 5
        sparse_mask[::grid_spacing, ::grid_spacing] = mask[::grid_spacing, ::grid_spacing]
        return sparse_mask


## Step 4:

In this step, we create different loaders for multiple few-shot experiments, varying the number of shots and the sparse label parameters.

In [None]:
# Experiment grid for the fine-tuning runs.

# Number of shots per task, i.e. the total number of sparsely annotated training images.
list_shots = [5]

# Point annotation: number of labeled pixels kept per non-background class in each mask.
list_sparsity_points = [1, 5, 10, 20]

# Grid annotation: spacing (in pixels) between the labeled grid pixels.
list_sparsity_grid = [8, 12, 16, 20]

In [None]:
def get_tune_loaders(shots, points, grid, fold_name, resize_to, args, imgtype='med',
                     dataset_root='/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/dataset',
                     task_name='brains'):
    """Build fine-tuning data loaders for every (sparsity mode, n_shots, sparsity) combination.

    Args:
        shots (iterable of int): numbers of shots to try.
        points (iterable): sparsity values for the 'points' mode.
        grid (iterable): sparsity values for the 'grid' mode.
        fold_name (int or str): fold identifier passed to ``ListDataset``.
        resize_to (tuple or None): size passed to ``ListDataset``.
        args (dict): must provide 'batch_size' and 'num_workers'.
        imgtype (str): image type passed to ``ListDataset``.
        dataset_root (str): dataset root folder (defaults to the notebook's Drive path).
        task_name (str): task sub-folder name.

    Returns:
        dict: keys 'points', 'grid' and 'dense'. 'points' and 'grid' map to lists of
        ``{'n_shots', 'sparsity', 'train', 'test'}`` dicts; 'dense' stays empty.
        The dense test loader does not depend on the training configuration, so a
        single one is built and shared by all entries.
    """
    loaders = {'points': [], 'grid': [], 'dense': []}
    tune_test_loader = None

    for sparsity_mode, sparsity_values in (('points', points), ('grid', grid)):
        for n_shots in shots:
            for sparsity in sparsity_values:
                tune_train_set = ListDataset(
                    mode='tune_train',
                    dataset_root=dataset_root,
                    task=task_name,
                    fold=fold_name,
                    resize_to=resize_to,
                    num_shots=n_shots,
                    sparsity_mode=sparsity_mode,
                    sparsity_param=sparsity,
                    imgtype=imgtype
                )
                tune_train_loader = DataLoader(tune_train_set, batch_size=args['batch_size'], num_workers=args['num_workers'], shuffle=True)

                # Built lazily so no test set is created when there is nothing to evaluate.
                if tune_test_loader is None:
                    tune_test_set = ListDataset(
                        mode='tune_test',
                        dataset_root=dataset_root,
                        task=task_name,
                        fold=fold_name,
                        resize_to=resize_to,
                        num_shots=-1,
                        sparsity_mode='dense',
                        sparsity_param=None,
                        imgtype=imgtype
                    )
                    tune_test_loader = DataLoader(tune_test_set, batch_size=1, num_workers=args['num_workers'], shuffle=False)

                loaders[sparsity_mode].append({
                    'n_shots': n_shots,
                    'sparsity': sparsity,
                    'train': tune_train_loader,
                    'test': tune_test_loader
                })

    return loaders


## Step 5:
Training and validation functions.
Both steps are grouped into a single function, `tune_train_test`; the cells before it define helpers for per-class metrics and for displaying examples.

In [None]:
from sklearn.metrics import confusion_matrix
import numpy as np

def compute_metrics_per_class(y_true, y_pred, ignore_index=0, classes=(1, 2, 3)):
    """Per-class IoU, Dice, sensitivity and specificity for label maps.

    Pixels whose ground-truth label equals ``ignore_index`` are removed before any
    counting. With the default ``ignore_index=0``, background pixels therefore never
    contribute false positives, and specificity is measured only among the remaining
    (non-background) pixels.

    Args:
        y_true: ground-truth labels (array-like of ints, any shape).
        y_pred: predicted labels with the same number of elements as ``y_true``.
        ignore_index: label value to exclude; ``None`` keeps every pixel.
        classes: class ids to score (default 1, 2, 3).

    Returns:
        dict mapping each class id to {'IoU', 'Dice', 'Sensitivity', 'Specificity'}
        as Python floats. A metric whose denominator is zero is reported as 0.0.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different numbers of elements.
    """
    # np.asarray makes plain Python lists work (list != scalar is a single bool);
    # ravel makes the masking independent of the input shape.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f'y_true and y_pred must have the same number of elements, got {y_true.size} and {y_pred.size}')

    if ignore_index is not None:
        keep = y_true != ignore_index
        y_true = y_true[keep]
        y_pred = y_pred[keep]

    def _ratio(num, den):
        return float(num) / float(den) if den != 0 else 0.0

    metrics_dict = {}
    for cls in classes:
        true_cls = y_true == cls
        pred_cls = y_pred == cls

        # Confusion counts for this class versus all other (non-ignored) pixels.
        tp = int(np.sum(pred_cls & true_cls))
        fp = int(np.sum(pred_cls & ~true_cls))
        tn = int(np.sum(~pred_cls & ~true_cls))
        fn = int(np.sum(~pred_cls & true_cls))

        metrics_dict[cls] = {
            'IoU': _ratio(tp, tp + fp + fn),
            'Dice': _ratio(2 * tp, 2 * tp + fp + fn),
            'Sensitivity': _ratio(tp, tp + fn),
            'Specificity': _ratio(tn, tn + fp),
        }

    return metrics_dict


In [None]:
def display_samples(list_imgs, list_labels, list_preds, title="Examples"):
    """Show images, ground-truth masks and predictions in a 3-row grid, one column per sample.

    Row 0 is the input image, row 1 the ground-truth label and row 2 the prediction.
    Previously the label and prediction rows were drawn in the opposite order from
    their row titles.

    Args:
        list_imgs: sequence of 2-D images.
        list_labels: sequence of 2-D ground-truth masks (same length as ``list_imgs``).
        list_preds: sequence of 2-D predicted masks (same length as ``list_imgs``).
        title: figure title.
    """
    n_samples = len(list_imgs)
    if n_samples == 0:
        return

    # squeeze=False keeps `ax` 2-D even when there is only one sample.
    fig, ax = plt.subplots(3, n_samples, figsize=(16, 16), squeeze=False)
    fig.suptitle(title)

    rows = (('Image', list_imgs), ('Ground-truth', list_labels), ('Prediction', list_preds))
    for r, (row_name, row_data) in enumerate(rows):
        for i in range(n_samples):
            ax[r, i].imshow(row_data[i], cmap='gray')
            ax[r, i].set_title('')
            ax[r, i].set_xticks([])
            ax[r, i].set_yticks([])
        ax[r, 0].set(ylabel=row_name)

    plt.tight_layout()
    plt.show()

In [None]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.tensorboard import SummaryWriter

def tune_train_test(tune_train_loader, tune_test_loader, net, optimizer, args, sparsity_mode, best_weights_path):
    """Fine-tune `net` on one annotation configuration and evaluate it periodically.

    Training minimises pixel-wise cross-entropy against the third element of each
    training batch (the sparse label map); label value -1 is ignored. The network is
    evaluated on ``tune_test_loader`` every ``args['val_freq']`` epochs, and once more
    after the last epoch if that was not a validation epoch. Per-class metrics are
    printed and written to TensorBoard under the ``sparsity_mode`` tag prefix.

    Args:
        tune_train_loader: iterable of (image, dense_label, sparse_label, name) batches.
        tune_test_loader: iterable of (image, dense_label, _, name) batches.
        net: segmentation network, presumably returning (N, C, H, W) logits (TODO confirm).
        optimizer: optimizer over ``net``'s parameters.
        args: dict providing 'tuning_epochs' and 'val_freq'.
        sparsity_mode: experiment identifier used as the TensorBoard tag prefix.
        best_weights_path: if this file exists, its state dict is loaded into ``net``
            before tuning. This function never writes to it.

    Returns:
        dict with 'train_loss' (mean loss per epoch) and 'test_loss' (mean loss per
        evaluation).
    """
    writer = SummaryWriter()
    train_losses, test_losses = [], []
    try:
        net.train()

        if os.path.exists(best_weights_path):
            net.load_state_dict(torch.load(best_weights_path))
            print(f"Loaded best weights from {best_weights_path}")
        else:
            # The network is not re-initialised here: tuning continues from its current weights.
            print("No pre-saved weights found. Continuing from the current network weights.")

        n_epochs = args['tuning_epochs']
        val_freq = args['val_freq']
        for epoch in range(1, n_epochs + 1):
            train_losses.append(_tune_epoch(tune_train_loader, net, optimizer))
            if epoch % val_freq == 0:
                test_losses.append(_evaluate_and_log(tune_test_loader, net, writer, sparsity_mode, epoch))

        # Always evaluate the final weights, even when the last epoch was not a validation epoch.
        if n_epochs % val_freq != 0:
            test_losses.append(_evaluate_and_log(tune_test_loader, net, writer, sparsity_mode, n_epochs))
    finally:
        writer.close()

    return {'train_loss': train_losses, 'test_loss': test_losses}


def _tune_epoch(loader, net, optimizer):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    batch_losses = []
    for x_tr, _y_dense, y_tr, _img_name in loader:
        x_tr = x_tr.cuda()
        # cross_entropy needs integer class targets; -1 marks unannotated pixels.
        y_tr = y_tr.cuda().long()

        optimizer.zero_grad()
        loss = F.cross_entropy(net(x_tr), y_tr, ignore_index=-1)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.detach().item())
    return np.mean(batch_losses)


def _collect_predictions(loader, net):
    """Evaluate `net` on `loader`; return (mean loss, flat labels, flat arg-max predictions).

    The network runs in eval mode for this pass and is then restored to its previous
    mode. Batches of any size are supported.
    """
    losses, labels, preds = [], [], []
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for x_ts, y_ts, _, _img_name in loader:
                x_ts = x_ts.cuda()
                y_ts = y_ts.cuda().long()
                logits = net(x_ts)
                losses.append(F.cross_entropy(logits, y_ts, ignore_index=-1).item())
                preds.append(logits.argmax(1).cpu().numpy().ravel())
                labels.append(y_ts.cpu().numpy().ravel())
    finally:
        net.train(was_training)

    if not labels:
        empty = np.empty(0, dtype=np.int64)
        return float('nan'), empty, empty
    return np.mean(losses), np.concatenate(labels), np.concatenate(preds)


def _evaluate_and_log(loader, net, writer, tag, step):
    """Evaluate `net`, print per-class metrics, log them to TensorBoard and return the mean loss."""
    import sys

    avg_loss, labels, preds = _collect_predictions(loader, net)
    writer.add_scalar(f'{tag}/Validation Loss', avg_loss, step)

    # Background (label 0) pixels are excluded from the per-class metrics.
    metrics_per_class = compute_metrics_per_class(labels, preds, ignore_index=0)

    separator = '--------------------------------------------------------------------'
    print(separator)
    for cls, m in metrics_per_class.items():
        print(f'Class {cls}: IoU: {m["IoU"]*100:.2f}%, Dice: {m["Dice"]*100:.2f}%, Sensitivity: {m["Sensitivity"]*100:.2f}%, Specificity: {m["Specificity"]*100:.2f}%')
        for key in ('IoU', 'Dice', 'Sensitivity', 'Specificity'):
            writer.add_scalar(f'{tag}/Class_{cls}_{key}', m[key], step)
    print(separator)
    sys.stdout.flush()

    avg_iou = np.mean([m['IoU'] for m in metrics_per_class.values()])
    writer.add_scalar(f'{tag}/Average IoU', avg_iou, step)
    return avg_loss


In [None]:
import os

def run_sparse_tuning(loader_dict, net, optimizer, args, model_weights_dir='best_models', reset_between_runs=True):
    """Run `tune_train_test` for every loader configuration in `loader_dict`.

    Experiments run in the order 'points', 'grid', 'dense'; missing keys are skipped.
    Each experiment gets an identifier such as 'points_5_shots_1_points' and a weight
    path ``<model_weights_dir>/best_model_<identifier>.pth``. If that file exists,
    `tune_train_test` loads it.

    Args:
        loader_dict: dict as returned by `get_tune_loaders`.
        net: network to tune (shared by all experiments).
        optimizer: optimizer over ``net``'s parameters.
        args: training arguments forwarded to `tune_train_test`.
        model_weights_dir: directory for per-experiment weight files (created if needed).
        reset_between_runs: if True (default), the network and optimizer state captured
            at the start of this call is restored before every experiment. Each
            configuration is then tuned from the same starting point instead of from the
            previous experiment's fine-tuned weights and optimizer moments. Pass False
            for the old cumulative behaviour.
    """
    import copy
    import sys

    os.makedirs(model_weights_dir, exist_ok=True)

    # (loader entry, identifier, human-readable description) for each experiment.
    experiments = []
    for entry in loader_dict.get('points', []):
        n_shots, sparsity = entry['n_shots'], entry['sparsity']
        experiments.append((entry, f'points_{n_shots}_shots_{sparsity}_points',
                            f"'points' ({n_shots}-shot, {sparsity}-points)"))
    for entry in loader_dict.get('grid', []):
        n_shots, sparsity = entry['n_shots'], entry['sparsity']
        experiments.append((entry, f'grid_{n_shots}_shots_{sparsity}_spacing',
                            f"'grid' ({n_shots}-shot, {sparsity}-spacing)"))
    for entry in loader_dict.get('dense', []):
        n_shots = entry['n_shots']
        experiments.append((entry, f'dense_{n_shots}_shots', f"'dense' ({n_shots}-shot)"))

    if reset_between_runs:
        initial_net_state = copy.deepcopy(net.state_dict())
        initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

    for entry, mode_identifier, description in experiments:
        best_weights_path = os.path.join(model_weights_dir, f'best_model_{mode_identifier}.pth')

        print(f"Evaluating {description} with identifier '{mode_identifier}'")
        sys.stdout.flush()

        if reset_between_runs:
            net.load_state_dict(initial_net_state)
            # Hand the optimizer a fresh copy: its loaded state tensors may be updated in
            # place during training, which would otherwise corrupt the snapshot.
            optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))

        tune_train_test(entry['train'], entry['test'], net, optimizer, args, mode_identifier, best_weights_path)


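## Step 6:
Download the pretrained weights and set up the model, optimizer and other training parameters. (The download cell below is currently commented out, so the network is built from scratch with random initialization.)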

In [None]:
# # Download weights from gdrive

# gdd.download_file_from_google_drive(file_id='12wXceabolEv9H-N8EebPMM7a6A3U1QMM',
#                                     dest_path='./weasel_pretrained.zip',
#                                     unzip=True)

# !unzip -q -o weasel_pretrained.zip

# model_weights = 'weasel_pretrained/weasel_unet_openist_brains_f0/meta.pth' # Choose an pretrained weight fold


In [None]:
# General arguments for training: tuning-phase hyper-parameters and data settings.
args = dict(
    tuning_epochs=500,    # Number of epochs in the tuning phase.
    val_freq=5,           # Evaluate on the test split every `val_freq` tuning epochs.
    vis_freq=25,          # Intended interval (epochs) for visualizing predictions; currently not read by tune_train_test.
    lr=1e-6,              # Base learning rate.
    weight_decay=5e-5,    # L2 penalty.
    momentum=0.9,         # Momentum term (passed as Adam's beta1 in the optimizer cell).
    num_workers=0,        # Number of DataLoader worker processes.
    batch_size=5,         # Mini-batch size for tuning.
    w_size=128,           # Width used for image resizing.
    h_size=128,           # Height used for image resizing.
    num_channels=1,       # Number of input channels.
    num_class=4,          # Number of segmentation classes.
)

fold = 0  # Fold number [0-4]

resize_to = (args['h_size'], args['w_size'])  # (height, width)

In [None]:
# Custom implementations of the required functions
def _get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None

def _save_response_content(response, destination):
    chunk_size = 32768
    with open(destination, 'wb') as f:
        for chunk in response.iter_content(chunk_size):
            if chunk:
                f.write(chunk)

In [None]:
!sed -i 's/from torchvision.datasets.utils import _get_confirm_token, _save_response_content//' /usr/local/lib/python3.10/dist-packages/torchmeta/datasets/utils.py


In [None]:
!pip install ordered_set --no-deps

Collecting ordered_set
  Downloading ordered_set-4.1.0-py3-none-any.whl.metadata (5.3 kB)
Downloading ordered_set-4.1.0-py3-none-any.whl (7.6 kB)
Installing collected packages: ordered_set
Successfully installed ordered_set-4.1.0


In [None]:
import torchmeta
print("Torchmeta version:", torchmeta.__version__)


Torchmeta version: 1.8.0


In [None]:
# NOTE(review): this cell re-defines compute_metrics_per_class from Step 5 and shadows it;
# keep the two copies identical or delete one of them.
def compute_metrics_per_class(y_true, y_pred, ignore_index=0, classes=(1, 2, 3)):
    """Per-class IoU, Dice, sensitivity and specificity for label maps.

    Pixels whose ground-truth label equals ``ignore_index`` are removed before any
    counting. With the default ``ignore_index=0``, background pixels therefore never
    contribute false positives, and specificity is measured only among the remaining
    (non-background) pixels.

    Args:
        y_true: ground-truth labels (array-like of ints, any shape).
        y_pred: predicted labels with the same number of elements as ``y_true``.
        ignore_index: label value to exclude; ``None`` keeps every pixel.
        classes: class ids to score (default 1, 2, 3).

    Returns:
        dict mapping each class id to {'IoU', 'Dice', 'Sensitivity', 'Specificity'}
        as Python floats. A metric whose denominator is zero is reported as 0.0.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different numbers of elements.
    """
    # np.asarray makes plain Python lists work (list != scalar is a single bool);
    # ravel makes the masking independent of the input shape.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f'y_true and y_pred must have the same number of elements, got {y_true.size} and {y_pred.size}')

    if ignore_index is not None:
        keep = y_true != ignore_index
        y_true = y_true[keep]
        y_pred = y_pred[keep]

    def _ratio(num, den):
        return float(num) / float(den) if den != 0 else 0.0

    metrics_dict = {}
    for cls in classes:
        true_cls = y_true == cls
        pred_cls = y_pred == cls

        # Confusion counts for this class versus all other (non-ignored) pixels.
        tp = int(np.sum(pred_cls & true_cls))
        fp = int(np.sum(pred_cls & ~true_cls))
        tn = int(np.sum(~pred_cls & ~true_cls))
        fn = int(np.sum(~pred_cls & true_cls))

        metrics_dict[cls] = {
            'IoU': _ratio(tp, tp + fp + fn),
            'Dice': _ratio(2 * tp, 2 * tp + fp + fn),
            'Sensitivity': _ratio(tp, tp + fn),
            'Specificity': _ratio(tn, tn + fp),
        }

    return metrics_dict


In [None]:
import os
from math import ceil

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from sklearn import metrics
from matplotlib import pyplot as plt

from torchmeta import modules

from collections import OrderedDict

def check_mkdir(dir_name):
    """Create directory `dir_name` (including missing parents) if nothing exists at that path.

    An existing path, whether directory or file, is left untouched, as before.
    ``exist_ok=True`` tolerates another process creating the directory between the
    existence check and the creation call.
    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name, exist_ok=True)

def prepare_meta_batch(meta_train_set, meta_test_set, index, batch_size=5, device='cuda'):
    """Sample a random meta-batch from task `index` of the meta-train and meta-test sets.

    Samples are drawn without replacement using a random permutation of each task.
    Each training input is paired with element 2 of its sample (presumably the sparse
    label map; TODO confirm against the dataset). Each test input is paired with
    element 1 (presumably the dense label map).

    Args:
        meta_train_set: indexable collection of tasks; ``meta_train_set[index]`` must
            support ``len`` and integer indexing, yielding tuples of tensors.
        meta_test_set: same structure as ``meta_train_set``.
        index: task index.
        batch_size: number of samples drawn from each task.
        device: device the tensors are moved to (default 'cuda', as before).

    Returns:
        (x_train, y_train, x_test, y_test), each stacked along a new leading batch axis.

    Raises:
        ValueError: if ``batch_size`` < 1 or either task has fewer than ``batch_size`` samples.
    """
    # Look the tasks up once instead of re-indexing the meta-datasets for every sample.
    train_task = meta_train_set[index]
    test_task = meta_test_set[index]

    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    if batch_size > len(train_task) or batch_size > len(test_task):
        raise ValueError(
            f'batch_size={batch_size} exceeds task size (train: {len(train_task)}, test: {len(test_task)})')

    perm_train = torch.randperm(len(train_task)).tolist()
    perm_test = torch.randperm(len(test_task)).tolist()

    x_train, y_train, x_test, y_test = [], [], [], []
    for b in range(batch_size):
        d_tr = train_task[perm_train[b]]
        d_ts = test_task[perm_test[b]]

        x_train.append(d_tr[0].to(device))
        y_train.append(d_tr[2].to(device))

        x_test.append(d_ts[0].to(device))
        y_test.append(d_ts[1].to(device))

    return (torch.stack(x_train, dim=0), torch.stack(y_train, dim=0),
            torch.stack(x_test, dim=0), torch.stack(y_test, dim=0))

def plot_kernels(kernel, idx, epoch, norm='mean0', out_dir='kernels'):
    """Save a grid image of convolution kernels to ``<out_dir>/kernel<idx>_ep<epoch>.png``.

    Args:
        kernel: array/tensor indexed as ``kernel[i][j]`` -> 2-D filter, presumably conv
            weights of shape (out_ch, in_ch, kH, kW) already on the CPU (TODO confirm).
        idx: identifier used in the output filename.
        epoch: epoch number used in the output filename.
        norm: 'mean0' scales by 1 / (2 * |min|) and shifts by 0.5; '01' applies
            min-max scaling to [0, 1].
        out_dir: output directory (created if missing; default 'kernels', as before).

    Raises:
        ValueError: if ``norm`` is not 'mean0' or '01'.
    """
    if norm == 'mean0':
        tensor = (1 / (abs(kernel.min()) * 2)) * kernel + 0.5
    elif norm == '01':
        tensor = (kernel - kernel.min()) / (kernel.max() - kernel.min())
    else:
        raise ValueError(f"Unknown norm '{norm}'; expected 'mean0' or '01'.")

    num_rows = tensor.shape[0]
    num_cols = tensor.shape[1]
    fig = plt.figure(figsize=(16, 16))

    for i in range(num_rows):
        for j in range(num_cols):
            # Subplot positions are 1-based, row-major.
            ax1 = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
            ax1.imshow(tensor[i][j], cmap='gray')
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'kernel' + str(idx) + '_ep' + str(epoch) + '.png'), format='png')
    # Close the figure: it is never shown, so open figures would otherwise accumulate.
    plt.close(fig)

def accuracy(lab, prd, average='binary', ignore_index=None):
    """Jaccard index (IoU) between labels and the arg-max over dimension 1 of `prd`.

    Args:
        lab: integer label tensor; flattened before scoring.
        prd: score/logit tensor whose dimension 1 indexes classes.
        average: forwarded to sklearn's ``jaccard_score``. The default 'binary' only
            accepts two-class (0/1) labels; pass e.g. 'macro' for multi-class labels.
        ignore_index: if given, pixels whose label equals this value are excluded.

    Returns:
        The Jaccard score returned by ``sklearn.metrics.jaccard_score``.
    """
    # Obtaining class from prediction.
    prd = prd.argmax(1)

    # reshape (unlike view) also works on non-contiguous tensors.
    lab_np = lab.reshape(-1).detach().cpu().numpy()
    prd_np = prd.reshape(-1).detach().cpu().numpy()

    if ignore_index is not None:
        keep = lab_np != ignore_index
        lab_np, prd_np = lab_np[keep], prd_np[keep]

    return metrics.jaccard_score(lab_np, prd_np, average=average)

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

from torchmeta import modules

from collections import OrderedDict

def initialize_weights(*models):
    """Re-initialize the layers of every given model in place.

    Conv2d/Linear layers (plain or torchmeta Meta* variants) get Kaiming-normal weights
    and zero bias. BatchNorm2d layers get weight 1 and bias 0. Other module types,
    such as ConvTranspose2d, keep their default initialization.
    """
    kaiming_types = (nn.Conv2d, nn.Linear, modules.MetaConv2d, modules.MetaLinear)
    norm_types = (nn.BatchNorm2d, modules.MetaBatchNorm2d)

    for layer in (sub for model in models for sub in model.modules()):
        if isinstance(layer, kaiming_types):
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                layer.bias.data.zero_()
        elif isinstance(layer, norm_types):
            layer.weight.data.fill_(1)
            layer.bias.data.zero_()

class MetaConvTranspose2d(nn.ConvTranspose2d, modules.MetaModule):
    __doc__ = nn.ConvTranspose2d.__doc__

    def forward(self, input, output_size=None, params=None):
        """Transposed 2-D convolution that can run with externally supplied parameters.

        Args:
            input: 4-D input tensor (N, C, H, W).
            output_size: optional target size, either spatial (H, W) or full
                (N, C, H, W) as accepted by nn.ConvTranspose2d. When given, the
                output padding is derived from it instead of ``self.output_padding``.
            params: optional mapping with 'weight' and 'bias' (e.g. fast weights);
                defaults to the module's own parameters.

        Raises:
            ValueError: for a non-'zeros' padding mode, or an ``output_size`` of the
                wrong length or outside the range reachable by this layer.
        """
        if params is None:
            params = OrderedDict(self.named_parameters())
        weights = params.get('weight', None)
        bias = params.get('bias', None)

        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        if output_size is not None:
            output_padding = self._manual_output_padding(input, output_size)
        else:
            # Use predefined output padding
            output_padding = self.output_padding

        # Perform convolution transpose
        return F.conv_transpose2d(
            input, weights, bias, self.stride, self.padding,
            tuple(output_padding), self.groups, self.dilation
        )

    def _manual_output_padding(self, input, output_size):
        """Output padding that makes the result have the requested spatial size.

        Only the trailing spatial entries of a full (N, C, H, W) ``output_size`` are
        used. Previously its batch/channel entries were read as H and W.
        """
        input_size = input.size()[2:]  # Spatial dimensions
        n_spatial = len(input_size)

        output_size = list(output_size)
        if len(output_size) == n_spatial + 2:
            output_size = output_size[2:]
        if len(output_size) != n_spatial:
            raise ValueError(
                f'output_size must have {n_spatial} or {n_spatial + 2} elements, got {len(output_size)}')

        output_padding = []
        for i in range(n_spatial):
            # Smallest output this layer produces for the given input (no output padding).
            min_size = ((input_size[i] - 1) * self.stride[i] - 2 * self.padding[i]
                        + self.dilation[i] * (self.kernel_size[i] - 1) + 1)
            extra = output_size[i] - min_size
            # Same reachable range nn.ConvTranspose2d enforces: [min_size, min_size + stride - 1].
            if not 0 <= extra < self.stride[i]:
                raise ValueError(
                    f'requested output size {output_size[i]} in spatial dim {i} is outside the '
                    f'reachable range [{min_size}, {min_size + self.stride[i] - 1}]')
            output_padding.append(extra)
        return output_padding

class _MetaEncoderBlock(modules.MetaModule):
    """Encoder stage: two (conv3x3 -> BatchNorm -> ReLU) units, optional dropout, then 2x2 max-pooling.

    The max-pool (kernel 2, stride 2) halves the spatial size.
    """

    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()

        stages = []
        # First unit maps in_channels -> out_channels; second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                modules.MetaConv2d(stage_in, out_channels, kernel_size=3, padding=1),
                modules.MetaBatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        if dropout:
            stages.append(nn.Dropout())
        stages.append(nn.MaxPool2d(kernel_size=2, stride=2))

        self.encode = modules.MetaSequential(*stages)

    def forward(self, x, params=None):
        encode_params = self.get_subdict(params, 'encode')
        return self.encode(x, encode_params)

class _MetaDecoderBlock(modules.MetaModule):
    """Decoder stage: Dropout2d, two (conv3x3 -> BatchNorm -> ReLU) units, then a 2x2 stride-2 transposed conv.

    The transposed convolution doubles the spatial size and maps
    ``middle_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()

        layers = [nn.Dropout2d()]
        for stage_in in (in_channels, middle_channels):
            layers.extend([
                modules.MetaConv2d(stage_in, middle_channels, kernel_size=3, padding=1),
                modules.MetaBatchNorm2d(middle_channels),
                nn.ReLU(inplace=True),
            ])
        layers.append(
            MetaConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2, padding=0, output_padding=0)
        )

        self.decode = modules.MetaSequential(*layers)

    def forward(self, x, params=None):
        decode_params = self.get_subdict(params, 'decode')
        return self.decode(x, decode_params)


class UNet(modules.MetaModule):
    """Three-level meta-learnable U-Net built from torchmeta modules.

    Args:
        input_channels: channels of the input image.
        num_classes: output channels of the final 1x1 convolution.
        prototype: when True no classifier head is built, and forward returns the
            32-channel decoder features resized to the input resolution.
    """

    def __init__(self, input_channels, num_classes, prototype=False):
        super().__init__()

        self.prototype = prototype

        # Encoder: 32 -> 64 -> 128 channels; each block ends with 2x2 max-pooling.
        self.enc1 = _MetaEncoderBlock(input_channels, 32)
        self.enc2 = _MetaEncoderBlock(32, 64)
        self.enc3 = _MetaEncoderBlock(64, 128, dropout=True)

        # Bottleneck and decoder blocks; each ends with a stride-2 transposed conv.
        self.center = _MetaDecoderBlock(128, 256, 128)
        self.dec3 = _MetaDecoderBlock(256, 128, 64)
        self.dec2 = _MetaDecoderBlock(128, 64, 32)

        # Last decoder stage: no upsampling, two (conv3x3 -> BN -> ReLU) units.
        head = [nn.Dropout2d()]
        for stage_in in (64, 32):
            head += [
                modules.MetaConv2d(stage_in, 32, kernel_size=3, padding=1),
                modules.MetaBatchNorm2d(32),
                nn.ReLU(inplace=True),
            ]
        self.dec1 = modules.MetaSequential(*head)

        if not prototype:
            self.final = modules.MetaConv2d(32, num_classes, kernel_size=1)

        initialize_weights(self)

    def forward(self, x, feat=False, params=None):
        def sub(name):
            return self.get_subdict(params, name)

        def fuse(deep, skip):
            # Concatenate along channels with the skip features resized to `deep`'s spatial size.
            return torch.cat([deep, F.interpolate(skip, deep.size()[2:], mode='bilinear')], 1)

        out_size = x.size()[2:]

        enc1 = self.enc1(x, sub('enc1'))
        enc2 = self.enc2(enc1, sub('enc2'))
        enc3 = self.enc3(enc2, sub('enc3'))

        center = self.center(enc3, sub('center'))

        dec3 = self.dec3(fuse(center, enc3), sub('dec3'))
        dec2 = self.dec2(fuse(dec3, enc2), sub('dec2'))
        dec1 = self.dec1(fuse(dec2, enc1), sub('dec1'))

        if self.prototype:
            return F.interpolate(dec1, out_size, mode='bilinear')

        logits = F.interpolate(self.final(dec1, sub('final')), out_size, mode='bilinear')
        if not feat:
            return logits
        return (logits,
                dec1,
                F.interpolate(dec2, out_size, mode='bilinear'),
                F.interpolate(dec3, out_size, mode='bilinear'),
               )


In [None]:
# Network and optimizer
# from weasel.utils import *
# from weasel.models.u_net import *
import torch
import torch.optim as optim

# Take the class count from `args` rather than repeating the literal 4.
net = UNet(args['num_channels'], num_classes=args['num_class']).cuda()

# Two parameter groups: biases use twice the base learning rate and no weight decay;
# all other parameters use the base learning rate with L2 weight decay.
# args['momentum'] is used as Adam's beta1.
optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], betas=(args['momentum'], 0.99))

In [None]:
# Imports
from torch.utils.data import DataLoader

# Optional Drive directory for the best weights. It is currently commented out, so
# run_sparse_tuning below uses its default 'best_models' directory.
# model_weights_dir = '/content/drive/MyDrive/Colab Notebooks/Kursinis/weasel/best_model'

# Build tuning/testing DataLoaders for every combination of shot count (list_shots)
# and sparsity (list_sparsity_points for 'points', list_sparsity_grid for 'grid').
loaders_dict = get_tune_loaders(
    shots=list_shots,
    points=list_sparsity_points,
    grid=list_sparsity_grid,
    fold_name=fold,
    resize_to=resize_to,
    args=args,
    imgtype='med'
)


# Run all tuning experiments; per-experiment weight files live in the default directory.
run_sparse_tuning(loaders_dict, net, optimizer, args)


Evaluating 'points' (5-shot, 1-points) with identifier 'points_5_shots_1_points'
No pre-saved weights found. Starting training from scratch.
--------------------------------------------------------------------
Class 1: IoU: 25.93%, Dice: 41.18%, Sensitivity: 73.13%, Specificity: 23.16%
Class 2: IoU: 1.67%, Dice: 3.29%, Sensitivity: 1.71%, Specificity: 98.25%
Class 3: IoU: 4.04%, Dice: 7.77%, Sensitivity: 4.96%, Specificity: 91.50%
--------------------------------------------------------------------
--------------------------------------------------------------------
Class 1: IoU: 26.38%, Dice: 41.75%, Sensitivity: 77.40%, Specificity: 18.36%
Class 2: IoU: 1.26%, Dice: 2.48%, Sensitivity: 1.28%, Specificity: 98.76%
Class 3: IoU: 3.41%, Dice: 6.60%, Sensitivity: 4.07%, Specificity: 92.76%
--------------------------------------------------------------------
--------------------------------------------------------------------
Class 1: IoU: 26.42%, Dice: 41.80%, Sensitivity: 77.48%, Specifi