## Dataset creation
Adapted from Leo's code.

In [None]:
from pathlib import Path
import sys

# Make the repository root and its code/ directory importable.
# The root is taken to be two directory levels above the current working directory.
repo_root = Path.cwd().parents[1]
# Prepend both entries in one step; code/ ends up ahead of the repo root on sys.path.
sys.path[:0] = [str(repo_root / "code"), str(repo_root)]

from utils import helper_functions as hf
# Package import
import scipy.io
import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from skimage.registration import phase_cross_correlation
import os
import time
import importlib
import torch.nn.functional as F
import glob
from scipy import signal
import random
import pickle
import math

In [None]:
# Dataset creation class by Leo

def gaussian_kernel(size, sigma):
    """Return a normalized 2-D Gaussian kernel as a float32 tensor.

    The kernel is evaluated on an integer grid with ``size`` points per axis.
    For odd ``size`` the grid is centred on 0; for even ``size`` it runs one
    step further on the positive side (e.g. -1..2 for ``size=4``).

    Args:
        size: Number of grid points along each axis.
        sigma: Standard deviation of the Gaussian, in grid units.

    Returns:
        A ``size`` x ``size`` tensor whose entries sum to 1.
    """
    # Note: ``-size // 2`` floors the *negated* size, so it is kept as written.
    first = -size // 2 + 1
    stop = size // 2 + 1
    coords = torch.arange(first, stop, dtype=torch.float32)  # float32 grid
    rows, cols = torch.meshgrid(coords, coords, indexing='ij')
    weights = torch.exp(-(rows**2 + cols**2) / (2 * sigma**2))
    # Normalize so the weights sum to one.
    return weights / weights.sum()
    
class GaussianNoise:
    """Callable transform that perturbs a tensor with additive Gaussian noise.

    Calling the instance returns ``tensor + N(0, 1) * std + mean``, where the
    standard-normal sample has the same shape as ``tensor``. The input tensor
    is not modified.
    """

    def __init__(self, mean=0., std=0.01):
        # Offset added to every element, and scale of the random component.
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        # One standard-normal draw per element, scaled by std.
        scaled_noise = torch.randn_like(tensor) * self.std
        return tensor + scaled_noise + self.mean

class FilteredFUSWindowDataset(Dataset):
    """Sliding-window dataset over a sequence of FUS frames.

    Each item is a window of ``window_size`` consecutive frames, labelled with
    the label of the window's last frame.
    """

    def __init__(self, data_tensor, labels_tensor, window_size=8, stride=1, image_size=112, mode='train'):
        """
        Custom Dataset for sliding windows of FUS frames with resizing and optional augmentation.
        For 'train' mode: applies data augmentation (one random affine shared by all
        frames of a window, then additive Gaussian noise).
        For 'test' mode: no augmentation.

        NOTE: this class performs no class balancing in either mode; every window
        whose last frame has a label other than -1 is included. Balance upstream
        (or with a sampler) if the classes must be balanced.

        Args:
            data_tensor (torch.Tensor): Expected shape [N, 1, H, W] (e.g., [6000, 1, 112, 112]).
            labels_tensor (torch.Tensor): Expected shape [N] with labels (0, 1). Windows
                ending on a frame labelled -1 are skipped.
            window_size (int): Number of frames per window (e.g., 8).
            stride (int): Step size between window end positions (e.g., 9 for
                non-overlapping windows with a one-frame gap when window_size is 8).
            image_size (int): Target square size for frames (e.g., 112).
            mode (str): 'train' or 'test'; controls whether augmentation is applied.
        """
        assert data_tensor.shape[0] == labels_tensor.shape[0], "Data and labels length mismatch"
        assert window_size > 0, "Window size must be positive"
        assert mode in ['train', 'test'], "Mode must be 'train' or 'test'"

        self.data = data_tensor
        self.labels = labels_tensor
        self.window_size = window_size
        self.stride = stride
        self.image_size = image_size
        self.mode = mode

        self.transform = T.Compose([
            T.Resize((image_size, image_size))  # Ensure square (redundant if already sized)
        ])

        # Each entry is the index of the LAST frame of a window. Starting at
        # window_size - 1 guarantees a full window fits before it. Only the last
        # frame's label is checked against -1; earlier frames are not inspected.
        self.valid_indices = [
            i for i in range(window_size - 1, len(self.labels), stride)
            if self.labels[i] != -1
        ]

        if not self.valid_indices:
            print("Warning: No valid windows found. Check labels or window_size.")

    def __len__(self):
        """Return the number of windows in the dataset."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(window, label)`` for the ``idx``-th window.

        Returns:
            window: Tensor of ``window_size`` frames after resizing (and, in
                'train' mode, augmentation).
            label: ``self.labels`` entry of the window's last frame.

        Raises:
            IndexError: If ``idx`` is out of range, or the window would fall
                outside the data.
            ValueError: If the assembled window does not have ``window_size`` frames.
        """
        # Resolve the window end outside the try block: an out-of-range idx must
        # surface as a plain IndexError (the sequence-iteration protocol relies
        # on it), and the handler below needs end_idx to be bound.
        end_idx = self.valid_indices[idx]
        try:
            start_idx = end_idx - self.window_size + 1
            if start_idx < 0 or end_idx >= len(self.data):
                raise IndexError(f"Invalid window range: [{start_idx}:{end_idx + 1}]")

            window = self.data[start_idx:end_idx + 1]  # [T, 1, H, W]
            # torch.stack builds a new tensor, so later in-place edits never touch self.data.
            window = torch.stack([self.transform(frame) for frame in window])  # [T, 1, image_size, image_size]

            # Verify window shape
            if window.shape[0] != self.window_size:
                raise ValueError(f"Expected {self.window_size} frames, got {window.shape[0]}")

            # === DATA AUGMENTATION (ONLY IN TRAIN MODE) ===
            if self.mode == 'train':
                window = self._augment(window)

            label = self.labels[end_idx]
            return window, label
        except Exception as e:
            print(f"Error in __getitem__ at idx {idx}, end_idx {end_idx}: {e}")
            raise

    def _augment(self, window):
        """Apply train-time augmentation to a stacked window and return the result."""
        # 1. Random affine: parameters are sampled once so every frame of the
        #    window receives the same geometric transformation.
        affine = T.RandomAffine(degrees=30, translate=(0.2, 0.2), scale=(0.8, 1.2))
        angle, translations, scale, shear = affine.get_params(
            affine.degrees, affine.translate, affine.scale, affine.shear,
            (self.image_size, self.image_size)
        )
        for t in range(window.size(0)):
            window[t] = TF.affine(
                window[t], angle=angle, translate=translations,
                scale=scale, shear=shear,
                interpolation=TF.InterpolationMode.BILINEAR, fill=0
            )
        # 2. Gaussian noise (independent per frame and per pixel)
        noise = GaussianNoise(std=0.05)
        return noise(window)

In [None]:
# Build the train and test window datasets from the prepared tensors.
# The two splits share window_size, stride, and image_size and differ only in
# their data/labels and in `mode`.
train_dataset = FilteredFUSWindowDataset(
    train_images, train_labels,
    window_size=window_size, stride=stride, image_size=image_size, mode='train',
)

test_dataset = FilteredFUSWindowDataset(
    test_images, test_labels,
    window_size=window_size, stride=stride, image_size=image_size, mode='test',
)

# Report how many windows each split yields.
print(f"Train dataset size: {len(train_dataset)} windows")
print(f"Test dataset size: {len(test_dataset)} windows")