In [1]:
import zipfile
import os
import nibabel as nib
from torch.utils.data import DataLoader, random_split
import os
import torch
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from tqdm import tqdm
import numpy as np

In [3]:
# Resolve the root directories of the BraTS 2021 and BraTS 2023 datasets.
# Both names start as the placeholder 0 and are replaced by a path string below.
BrATS2021 = 0
BrATS2023 = 0
# The existence of '/content/' is used as the signal that the notebook runs on Colab.
if os.path.isdir('/content/'):# using COLAB
  if not os.path.isdir("/content/drive/MyDrive/datasets/BrATS2021/"):
    # BraTS 2021 is not on Google Drive: download it from Kaggle and unpack it
    # under /content/BrATS2021.
    print("No se encuentra")
    # Install the Kaggle API token; kaggle.json is copied from the current directory.
    !mkdir -p ~/.kaggle
    !cp kaggle.json ~/.kaggle/
    !chmod 600 /root/.kaggle/kaggle.json
    # Download the archive, extract the training tarball, then delete the
    # intermediate files.
    !kaggle datasets download -d dschettler8845/brats-2021-task1
    !unzip  /content/brats-2021-task1.zip -d /content/BrATS
    !mkdir -p /content/BrATS2021
    !tar -xvf /content/BrATS/BraTS2021_Training_Data.tar -C /content/BrATS2021
    !rm -R /content/BrATS
    !rm /content/brats-2021-task1.zip
    BrATS2021 ="/content/BrATS2021"
  else:
    # BraTS 2021 is already available on Google Drive.
    print("Si se encuentra la ruta")
    BrATS2021 ="/content/drive/MyDrive/datasets/BrATS2021/"

  if not os.path.isdir("/content/drive/MyDrive/datasets/BrATS2023/"):
    # Only the download runs here. Extraction and the path assignment are
    # commented out below, so BrATS2023 keeps the placeholder value 0 on this branch.
    !kaggle datasets download -d shakilrana/brats-2023-adult-glioma -p ./BraTS2023
#    !unzip BraTS2023/brats2023-part-1.zip -d /content/BrATS2023
#    BrATS2023 ="/content/BrATS2023"
  else:
    BrATS2023 ="/content/drive/MyDrive/datasets/BrATS2023/"
else:# using local
    # Outside Colab the datasets are expected in directories relative to the
    # current working directory.
    BrATS2021 = "BrATS2021/"
    BrATS2023 = "BrATS2023/"
print("Path BrATS2021:",BrATS2021 )
print("Path BrATS2023:",BrATS2023 )

using local
Local
Path BrATS2021: BrATS2021/
Path BrATS2023: BrATS2023/


In [None]:
import os
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

class BRATSDataset_2(Dataset):
    """2D BraTS segmentation dataset: one (image slice, binary tumour mask) pair per case.

    Every sub-directory of ``base_path`` is treated as one case. For each case
    the T2-type volume and the segmentation volume are read, and the slice
    along the third axis with the largest tumour area (most non-zero mask
    pixels) is returned.

    File matching is done on the file name:
      * ``*.nii``    -> image if the name contains ``t2f``, else mask if it contains ``seg``
      * ``*.nii.gz`` -> image if the name contains ``t2``,  else mask if it contains ``seg``

    Args:
        base_path: Directory containing one sub-directory per case.
        img_transform: Optional callable applied to the image tensor.
        mask_transform: Optional callable applied to the mask tensor.
        year: Accepted for backward compatibility; it is not used.

    Returns (from ``__getitem__``):
        ``(image, mask)`` as ``float32`` tensors. For 3-D volumes both have a
        leading channel dimension of size 1; the mask holds 0/1 values before
        ``mask_transform`` is applied.

    Notes:
        A case that cannot be used (read error, or missing image/mask file) is
        reported with ``print`` and replaced by the next usable case, so
        ``dataset[i]`` may return the sample of another index.
    """

    def __init__(self, base_path, img_transform=None, mask_transform=None, year=2023):
        self.base_path = base_path
        # Sorted so the index -> case mapping is reproducible between runs;
        # os.listdir returns entries in arbitrary order.
        self.folders = sorted(
            f for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))
        )
        self.transform = img_transform
        self.mask_transform = mask_transform  # transform applied to the mask only

    def __len__(self):
        return len(self.folders)

    @staticmethod
    def _read_volume(file_path):
        """Read a NIfTI file and return its voxel data as an array."""
        return nib.load(file_path).get_fdata()

    def _load_case(self, folder_path):
        """Return ``(image_volume, mask_volume)`` for one case, or ``None`` if it is unusable."""
        t2f_file, seg_file = None, None

        # Sorted so the result is deterministic if several files match the same role.
        for nifty_file in sorted(os.listdir(folder_path)):
            if nifty_file.endswith('.nii'):
                image_tag = 't2f'
            elif nifty_file.endswith('.nii.gz'):
                image_tag = 't2'
            else:
                continue

            file_path = os.path.join(folder_path, nifty_file)
            try:
                if image_tag in nifty_file:
                    t2f_file = self._read_volume(file_path)
                elif 'seg' in nifty_file:
                    seg_file = self._read_volume(file_path)
            except Exception as e:
                # Best effort: report the unreadable file and let the caller skip the case.
                print(f"Error al cargar {file_path}: {e}")
                return None

        if t2f_file is None or seg_file is None:
            print(f"Advertencia: Archivo t2f o segmentación faltante en {folder_path}.")
            return None
        return t2f_file, seg_file

    def _build_sample(self, t2f_file, seg_file):
        """Pick the slice with the largest tumour area and convert it to tensors."""
        # Count tumour pixels per slice. Summing the raw label values instead
        # would favour slices with high label ids over slices with more tumour.
        tumor_area = np.sum(seg_file > 0, axis=(0, 1))
        max_content_slice = int(np.argmax(tumor_area))

        t2f_image = t2f_file[:, :, max_content_slice]
        seg_image = (seg_file[:, :, max_content_slice] > 0).astype(np.uint8)  # binary mask

        # Add the channel dimension in front.
        t2f_image = torch.tensor(t2f_image, dtype=torch.float32).unsqueeze(0)
        seg_image = torch.tensor(seg_image, dtype=torch.float32).unsqueeze(0)

        if self.transform:
            t2f_image = self.transform(t2f_image)
        if self.mask_transform:
            seg_image = self.mask_transform(seg_image)

        return t2f_image, seg_image

    def __getitem__(self, idx):
        """Return the sample at ``idx``, or the next usable case if that one is unusable.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: if no case in the dataset can be loaded.
        """
        n_cases = len(self.folders)
        if not -n_cases <= idx < n_cases:
            raise IndexError(f"index {idx} out of range for dataset of size {n_cases}")

        # Visit each case at most once, starting at idx and wrapping around.
        # This replaces an unbounded recursion that never terminated when
        # every case was unusable.
        start = idx % n_cases
        for offset in range(n_cases):
            folder = self.folders[(start + offset) % n_cases]
            volumes = self._load_case(os.path.join(self.base_path, folder))
            if volumes is not None:
                return self._build_sample(*volumes)

        raise RuntimeError(f"No loadable case found under {self.base_path}")



In [None]:
class BRATSDataset(Dataset):
    """2D BraTS dataset returning the middle slice of each case.

    Every sub-directory of ``base_path`` is treated as one case. The image is
    the middle slice (third axis) of the T2-type volume; the mask is the middle
    slice of the segmentation volume, binarised to 0/1 integers.

    File matching is done on the file name:
      * ``*.nii``    -> image if the name contains ``t2f``, else mask if it contains ``seg``
      * ``*.nii.gz`` -> image if the name contains ``t2``,  else mask if it contains ``seg``

    Args:
        base_path: Directory containing one sub-directory per case.
        transform: Optional callable. When given it is applied to the image
            AND to the mask; both are passed in as NumPy arrays.
        year: Accepted for backward compatibility; it is not used.

    Notes:
        * Without ``transform`` the sample is returned as two NumPy arrays.
        * ``self.mask_transform`` is created but never applied by ``__getitem__``.
        * NOTE(review): because the same ``transform`` is applied to the mask,
          a transform that rescales intensities would presumably also alter
          the mask values — confirm this is intended.
        * A case with a missing image or mask is reported with ``print`` and
          replaced by the next usable case.
    """

    def __init__(self, base_path, transform=None, year=2023):
        self.base_path = base_path
        # Sorted so the index -> case mapping is reproducible between runs;
        # os.listdir returns entries in arbitrary order.
        self.folders = sorted(
            f for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))
        )
        self.transform = transform
        self.mask_transform = transforms.Resize((256, 256))  # kept as an attribute; not applied

    def __len__(self):
        return len(self.folders)

    @staticmethod
    def _read_volume(file_path):
        """Read a NIfTI file and return its voxel data as an array."""
        return nib.load(file_path).get_fdata()

    def _load_case(self, folder_path):
        """Return ``(image_volume, mask_volume)`` for one case, or ``None`` if a file is missing."""
        t2f_file, seg_file = None, None

        # Sorted so the result is deterministic if several files match the same role.
        for nifty_file in sorted(os.listdir(folder_path)):
            if nifty_file.endswith('.nii'):
                image_tag = 't2f'
            elif nifty_file.endswith('.nii.gz'):
                image_tag = 't2'
            else:
                continue

            file_path = os.path.join(folder_path, nifty_file)
            if image_tag in nifty_file:
                t2f_file = self._read_volume(file_path)
            elif 'seg' in nifty_file:
                seg_file = self._read_volume(file_path)

        if t2f_file is None or seg_file is None:
            print(f"Advertencia: Archivo t2f o segmentación faltante en {folder_path}.")
            return None
        return t2f_file, seg_file

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``, or for the next usable case.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: if no case in the dataset has both files.
        """
        n_cases = len(self.folders)
        if not -n_cases <= idx < n_cases:
            raise IndexError(f"index {idx} out of range for dataset of size {n_cases}")

        # Visit each case at most once, starting at idx and wrapping around.
        # This replaces an unbounded recursion that never terminated when
        # every case was incomplete.
        start = idx % n_cases
        volumes = None
        for offset in range(n_cases):
            folder = self.folders[(start + offset) % n_cases]
            volumes = self._load_case(os.path.join(self.base_path, folder))
            if volumes is not None:
                break
        if volumes is None:
            raise RuntimeError(f"No loadable case found under {self.base_path}")
        t2f_file, seg_file = volumes

        # Middle slice of each volume along the third axis.
        t2f_image = t2f_file[:, :, t2f_file.shape[2] // 2]
        seg_image = seg_file[:, :, seg_file.shape[2] // 2]
        seg_image = (seg_image > 0).astype(int)  # 1 wherever the label is greater than 0

        if self.transform:
            t2f_image = self.transform(t2f_image)
            seg_image = self.transform(seg_image)

        return t2f_image, seg_image
# Optional preprocessing pipeline: resize to a fixed resolution, then normalise
# with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),  # change to the desired size
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# La siguiente clase (`BRATSMultiSliceDataset`) está en fase de pruebas. Ejemplos de uso:
 1️⃣  Ventana ±5 con solo T2
ds_t2 = BRATSMultiSliceDataset("/datos/BRATS", modalities="t2")

 2️⃣  Ventana ±3 tomando T1 contrastada y FLAIR juntos
ds_multi = BRATSMultiSliceDataset("/datos/BRATS",
                                  modalities=["t1ce", "flair"],
                                  window=3)

 3️⃣  Todos los slices que contengan tumor, modalidad FLAIR
ds_flair_tumor = BRATSMultiSliceDataset("/datos/BRATS",
                                        modalities="flair",
                                        only_tumor=True)

In [None]:
import os, numpy as np, nibabel as nib, torch
from torch.utils.data import Dataset

class BRATSMultiSliceDataset(Dataset):
    """BraTS dataset that returns a stack of neighbouring slices per sample.

    Features:
      * One or several modalities (``t2``, ``t1ce``, ``flair``, ...) can be selected.
      * Each sample is a window of ``2 * window + 1`` slices (third axis)
        centred either on the slice with the largest tumour area, or, with
        ``only_tumor=True``, on every slice that contains tumour.

    Sample layout:
      * image: ``float32`` tensor ``(C, H, W)`` with
        ``C = len(modalities) * (2 * window + 1)``; the slices of the first
        modality come first, then those of the second, and so on.
      * mask: ``float32`` tensor ``(2 * window + 1, H, W)`` with 0/1 values.

    Window positions that fall outside the volume are filled with zeros in
    both the image and the mask.

    Args:
        base_path: Directory containing one sub-directory per case.
        modalities: Modality name or sequence of names, used to locate the files.
        window: Number of slices taken on each side of the central slice.
        only_tumor: If True, index every slice that contains tumour instead of
            only the slice with the largest tumour area.
        img_transform: Optional callable applied to the image tensor.
        mask_transform: Optional callable applied to the mask tensor.

    Raises:
        ValueError: if ``modalities`` is empty or ``window`` is negative.

    Notes:
        Cases without a segmentation file are skipped. Every segmentation
        volume is read once at construction time to build the slice index.
    """

    def __init__(
        self,
        base_path: str,
        modalities=("t2",),
        window: int = 5,
        only_tumor: bool = False,
        img_transform=None,
        mask_transform=None,
    ):
        super().__init__()
        if isinstance(modalities, str):
            modalities = (modalities,)
        if len(modalities) == 0:
            raise ValueError("modalities must contain at least one modality name")
        if window < 0:
            raise ValueError(f"window must be >= 0, got {window}")

        self.modalities = modalities
        self.base_path  = base_path
        self.window     = window
        self.only_tumor = only_tumor
        self.img_tf     = img_transform
        self.mask_tf    = mask_transform

        # ------- index (case directory, slice index) pairs --------
        self.samples = []
        for case in sorted(os.listdir(base_path)):
            case_dir = os.path.join(base_path, case)
            if not os.path.isdir(case_dir):
                continue
            seg_path = self._find_file(case_dir, "seg")
            if seg_path is None:
                continue

            # Tumour pixels per slice. Counting pixels rather than summing the
            # label values keeps high label ids from outweighing a larger area.
            tumor_area = (self._read_nifti(seg_path) > 0).sum(axis=(0, 1))

            if self.only_tumor:
                z_idx = np.flatnonzero(tumor_area)
            else:
                z_idx = [np.argmax(tumor_area)]

            # Store plain Python ints instead of NumPy scalars.
            self.samples.extend((case_dir, int(z)) for z in z_idx)

    # ---------- helpers ----------
    @staticmethod
    def _read_nifti(path):
        """Read a NIfTI file and return its voxel data as an array."""
        return nib.load(path).get_fdata()

    @staticmethod
    def _find_file(case_dir, keyword):
        """Return the path of the NIfTI file for ``keyword`` in ``case_dir``, or ``None``.

        A file whose name, without the ``.nii`` / ``.nii.gz`` extension, equals
        ``keyword`` or ends in ``_keyword`` / ``-keyword`` is preferred, so
        ``"t1"`` does not pick up a ``t1ce`` file. If there is no such file,
        the first name (in sorted order) that merely contains ``keyword`` is
        returned.
        """
        extensions = (".nii.gz", ".nii")
        exact_suffixes = ("_" + keyword, "-" + keyword)
        fallback = None

        for fname in sorted(os.listdir(case_dir)):
            if keyword not in fname or not fname.endswith(extensions):
                continue

            stem = fname
            for ext in extensions:
                if stem.endswith(ext):
                    stem = stem[: -len(ext)]
                    break

            if stem == keyword or stem.endswith(exact_suffixes):
                return os.path.join(case_dir, fname)
            if fallback is None:
                fallback = os.path.join(case_dir, fname)

        return fallback

    def _load_vol(self, case_dir, keyword):
        """Load the volume matching ``keyword``; raise ``FileNotFoundError`` if absent."""
        path = self._find_file(case_dir, keyword)
        if path is None:
            raise FileNotFoundError(f"{keyword} no encontrado en {case_dir}")
        return self._read_nifti(path)

    # ---------- API ----------
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for the indexed (case, slice) pair."""
        case_dir, z = self.samples[idx]

        vols = [self._load_vol(case_dir, m) for m in self.modalities]
        seg  = self._load_vol(case_dir, "seg")

        # The first modality defines the expected volume shape.
        H, W, Z = vols[0].shape
        w = self.window
        num_cuts = 2 * w + 1

        # Source range inside the volume, clipped to its bounds ...
        z_min, z_max = max(0, z - w), min(Z, z + w + 1)
        # ... and the matching destination range inside the window.
        dst = slice(z_min - (z - w), z_max - (z - w))

        # ---------- images ----------
        crop = np.zeros((len(vols) * num_cuts, H, W), dtype=np.float32)
        for i, v in enumerate(vols):
            window_view = crop[i * num_cuts:(i + 1) * num_cuts]
            window_view[dst] = v[:, :, z_min:z_max].transpose(2, 0, 1)

        # ---------- masks (same padding) ----------
        mask_crop = np.zeros((num_cuts, H, W), dtype=np.float32)
        mask_crop[dst] = (seg[:, :, z_min:z_max] > 0).transpose(2, 0, 1)

        img_tensor  = torch.from_numpy(crop)
        mask_tensor = torch.from_numpy(mask_crop)

        if self.img_tf:
            img_tensor = self.img_tf(img_tensor)
        if self.mask_tf:
            mask_tensor = self.mask_tf(mask_tensor)

        return img_tensor, mask_tensor

In [None]:
class BratsDataLoader(nn.Module, BRATSDataset_2):
    """Builds train / validation / test ``DataLoader`` objects over ``BRATSDataset_2``.

    Args:
        base_path: Directory containing one sub-directory per case.
        transform: Image transform forwarded to ``BRATSDataset_2``. The default
            is the module-level ``transform`` as it exists when this class is
            defined.
        batch_size: Batch size used for all three loaders.
        train_size: Fraction of the dataset used for training.
        val_size: Fraction of the dataset used for validation. Whatever is left
            after the training and validation parts becomes the test set.
        mask_transform: Optional mask transform forwarded to ``BRATSDataset_2``.
            With the default ``None`` the mask is returned untransformed, so an
            image transform that resizes leaves image and mask with different
            sizes.
        seed: Optional integer seed for the random split. ``None`` keeps the
            previous behaviour of using the global random state.

    Raises:
        ValueError: if a fraction is negative or the two fractions sum to more than 1.

    Notes:
        NOTE(review): the class derives from both ``nn.Module`` and
        ``BRATSDataset_2`` but only calls ``super().__init__()`` without
        arguments. Presumably the dataset attributes (e.g. ``folders``) are
        never initialised on this object, so it should not be used as a
        dataset itself — confirm.
    """

    def __init__(self, base_path, transform=transform, batch_size=1, train_size=0.7, val_size=0.1,
                 mask_transform=None, seed=None):
        super(BratsDataLoader, self).__init__()
        # A negative test split would silently produce overlapping or wrong
        # subsets, so reject inconsistent fractions up front. The small
        # tolerance absorbs floating-point rounding in the sum.
        if train_size < 0 or val_size < 0 or train_size + val_size > 1 + 1e-9:
            raise ValueError(
                f"train_size and val_size must be >= 0 and sum to at most 1, "
                f"got train_size={train_size}, val_size={val_size}"
            )
        self.base_path = base_path
        self.transform = transform
        self.mask_transform = mask_transform
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size
        self.seed = seed

    def get_train_test(self):
        """Create the dataset, split it, and return ``(train_loader, val_loader, test_loader)``.

        Prints the number of batches in the training loader and in the test loader.
        """
        dataset = BRATSDataset_2(self.base_path, self.transform, self.mask_transform)

        n_total = len(dataset)
        train_size = int(self.train_size * n_total)
        val_size = int(self.val_size * n_total)
        test_size = n_total - train_size - val_size  # remainder goes to the test set

        # Only pass a generator when a seed was requested, so the default call
        # is unchanged.
        split_kwargs = {}
        if self.seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(self.seed)

        train_dataset, val_dataset, test_dataset = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )

        # Only the training loader is shuffled.
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        print(len(train_loader))
        print(len(test_loader))

        return train_loader, val_loader, test_loader
