<a href="https://colab.research.google.com/github/dustint121/bd4h-final-project/blob/main/BD4H_Final_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
import os
from google.colab import drive

# View current working directory
print("Current Working Directory:", os.getcwd())

# Mount Google Drive
drive.mount('/content/gdrive')
# Change working directory to your file position
path = "/content/gdrive/My Drive/bd4h-final-project-data/"
os.chdir(path)

# Confirm the change
print("Working Directory:", os.getcwd())

Current Working Directory: /content/gdrive/MyDrive/bd4h-final-project-data
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
Working Directory: /content/gdrive/My Drive/bd4h-final-project-data


In [2]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

from tqdm import tqdm

# 2D UNet Model


In [8]:
for i in range(131):
  #check if file exists
  if not os.path.exists(f"/content/gdrive/My Drive/bd4h-final-project-data/LITS17/{i}.npz"):
    print(f"{i}.npz does not exist")
  # else:
  # np.load(f"/content/gdrive/My Drive/bd4h-final-project-data/LITS17/{i}.npz")

##Classes and Util functions for 2D UNet Model

In [10]:

"""
MedicalSliceDataset_2D converts 3D arrays(volumes and segmentations) into lists of 2D arrays

In model training/testing:
    self.volumes = inputs
    self.segmentations = targets
"""
class MedicalSliceDataset_2D(Dataset):
    def __init__(self, dataset_path, file_indices, max_slices=512):
        self.volumes = [] #will be a list of 2D arrays
        self.segmentations = [] #will be a list of 2D arrays


        for count, i in enumerate(file_indices):
            print(f"Loading file {count+1}/{len(file_indices)}")
            data = np.load(f"{dataset_path}/{i}.npz")
            volume, seg = None, None
            if "LITS17" in dataset_path:
                volume = data["volume"]
                seg = data["segmentation"]
            elif "KITS19" in dataset_path:
                volume = np.transpose(data["volume"], (1, 2, 0))
                seg = np.transpose(data["segmentation"], (1, 2, 0))

            # Per-volume min-max normalization (prevents data leakage)
            # volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)

            #Using z-score
            volume = (volume - volume.mean()) / (volume.std() + 1e-8)

            # Add all slices as individual samples
            for d in range(volume.shape[2]): #for each layer of "depth"
                self.volumes.append(volume[..., d]) #add the "height x weight" 2D data of the layer to the self.volume
                self.segmentations.append(seg[..., d])

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        # Remove channel dim from masks & convert to int64
        slice = torch.FloatTensor(self.volumes[idx][None, ...])  # (1,512,512)
        mask = torch.LongTensor(self.segmentations[idx])  # (512,512) NOT (1,512,512)
        return slice, mask


"""
DoubleConv:
Strenghts:
Deeper Feature Extraction: Two convolutional layers instead of one (like in ConvNet),
                            enabling richer hierarchical feature learning.
Instance Normalization: Better for medical imaging with small batch sizes compared to batch norm.
Leaky ReLU: Avoids "dead neurons" by allowing small gradients for negative inputs (unlike basic ReLU in ConvNet).

Weaknesses:
Higher computational cost due to dual convolutions.
Requires more memory for intermediate feature maps.
"""
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet2D(nn.Module):
    def __init__(self, n_channels=1, n_classes=3):
        super().__init__()
        self.inc = DoubleConv(n_channels, 32)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(32, 64))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(128, 64)
        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(64, 32)
        self.outc = nn.Conv2d(32, n_classes, kernel_size=1)  # Output 3 channels

    def forward(self, x):
        # Input shape: (batch_size, channels, height, width)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        x = self.up1(x4)
        x = torch.cat([x, x3], dim=1)
        x = self.conv1(x)
        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = self.conv2(x)
        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = self.conv3(x)
        return self.outc(x)

In [11]:
def dice_score(pred, target):
    smooth = 1e-6
    pred = F.softmax(pred, dim=1)
    scores = []
    # Calculate Dice for each class (background, organ, tumor)
    for class_idx in range(3):
        pred_mask = pred[:, class_idx]  # Get probability maps for this class
        target_mask = (target == class_idx).float()

        intersection = (pred_mask * target_mask).sum()
        union = pred_mask.sum() + target_mask.sum()
        scores.append((2. * intersection + smooth) / (union + smooth))

    return torch.mean(torch.stack(scores))  # Return mean Dice across classes


# Additional helper function for class-wise Dice reporting
def dice_score_per_class(pred, target):
    smooth = 1e-6
    pred = F.softmax(pred, dim=1).argmax(1)
    scores = []
    for class_idx in range(3):
        pred_mask = (pred == class_idx).float()
        target_mask = (target == class_idx).float()
        intersection = (pred_mask * target_mask).sum()
        union = pred_mask.sum() + target_mask.sum()
        scores.append((2.*intersection + smooth)/(union + smooth))
    return scores


def hybrid_loss(pred, target):
    # Cross Entropy
    ce = F.cross_entropy(pred, target)

    # Dice Loss (1 - mean Dice score)
    dice_loss = 1 - dice_score(pred, target)

    return ce + dice_loss  # Paper's λ=1 for both terms

## Model Testing Here

In [12]:
#Lits Dataset wrapper

# volume_indices = list(range(131))  # All LITS17 volumes
volume_indices = list(range(5)) #test small subset

# data files with irregular data
# lits_17_list = [0, 1, 4, 16, 17, 18, 22, 35, 37, 38, 48, 50, 52, 53, 54, 55, 57, 63, 65, 68, 69, 70, 71, 72, 74, 76, 77,
#                 78, 80, 81, 82, 87, 88, 89, 90, 91, 92, 93, 95, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 117]

# kits_19_list = [15, 18, 19, 23, 25, 31, 32, 40, 43, 45, 48, 50, 61, 64, 65, 66, 81, 85, 86, 94, 97, 99, 107, 109, 111, 117,
#                 121, 123, 124, 128, 131, 133, 150, 163, 166, 167, 168, 169, 172, 180, 185, 191, 192, 193, 194, 199, 202]

# for i in lits_17_list:
#   volume_indices.remove(i)

train_volumes, test_volumes = train_test_split(volume_indices, test_size=0.2, random_state=42)

lits_training_dataset = MedicalSliceDataset_2D("/content/gdrive/My Drive/bd4h-final-project-data/LITS17", file_indices = train_volumes)
lits_testing_dataset = MedicalSliceDataset_2D("/content/gdrive/My Drive/bd4h-final-project-data/LITS17", file_indices = test_volumes)

Loading file 1/4


  arrmean = umr_sum(arr, axis, dtype, keepdims=True, where=where)


Loading file 2/4
Loading file 3/4
Loading file 4/4
Loading file 1/1


In [18]:
#20 minutes

data_path = "/content/gdrive/My Drive/bd4h-final-project-data/"


train_loader = DataLoader(lits_training_dataset, batch_size=16, shuffle=True,
                                                  pin_memory=True, num_workers=12, persistent_workers=True)

test_loader = DataLoader(lits_testing_dataset, batch_size=16, shuffle=True, pin_memory=True, num_workers=12, persistent_workers=True)


# train_loader = DataLoader(lits_training_dataset, batch_size=16, shuffle=True, pin_memory=True)

# test_loader = DataLoader(lits_testing_dataset, batch_size=16, shuffle=True, pin_memory=True)


In [None]:
# model = UNet2D(n_channels=1, n_classes=3)
# criterion = hybrid_loss
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# #use gpu if available
# model.to(device)

# # Training Loop with Class-wise Dice Reporting
# best_dice = 0
# for epoch in range(500):
#     # Training
#     model.train()
#     print(len(train_loader))
#     for inputs, targets in train_loader:
#         inputs = inputs.to(device)  # Add this
#         targets = targets.to(device).squeeze(1)  # Add this
#         # print("training")
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = hybrid_loss(outputs, targets.squeeze(1))  # Remove channel dim from targets
#         loss.backward()
#         optimizer.step()

#     # Validation
#     model.eval()
#     test_dice = 0
#     class_dice = [0, 0, 0]  # [Background, Organ, Tumor]


#     with torch.no_grad():
#         for inputs, targets in test_loader:
#             inputs = inputs.to(device)  # Add this
#             targets = targets.to(device).squeeze(1)  # Add this
#             outputs = model(inputs)
#             batch_scores = dice_score_per_class(outputs, targets.squeeze(1))
#             for i in range(3):
#                 class_dice[i] += batch_scores[i]

#     # Calculate average Dice per class
#     avg_dice = [c/len(test_loader) for c in class_dice]
#     print(f"Epoch {epoch+1}")
#     print(f"Liver Dice: {avg_dice[1]:.4f}, Tumor Dice: {avg_dice[2]:.4f}, Mean Dice: {torch.mean(torch.tensor(avg_dice)):.4f}")

#     # Save best model based on tumor Dice (most clinically relevant)
#     if avg_dice[2] > best_dice:
#         best_dice = avg_dice[2]
#         torch.save(model.state_dict(), "best_model.pth")

# print(f"\nBest Tumor Dice: {best_dice*100:.2f}%")

In [None]:
model = UNet2D(n_channels=1, n_classes=3)
criterion = hybrid_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99)


#use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

best_dice = 0


#10 minutes per epoch for full dataset
for epoch in range(500):
    print(f"\nEpoch {epoch+1}/500")
    model.train()
    train_losses = []

    # Training with tqdm progress bar
    with tqdm(train_loader, desc="Training", leave=False) as pbar:
        for inputs, targets in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device).squeeze(1)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = hybrid_loss(outputs, targets)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    print(f"  Train loss: {sum(train_losses)/len(train_losses):.4f}")

    # Validation with tqdm progress bar
    model.eval()
    class_dice = [0.0, 0.0, 0.0]
    with torch.no_grad():
        with tqdm(test_loader, desc="Validation", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs = inputs.to(device)
                targets = targets.to(device).squeeze(1)
                outputs = model(inputs)
                batch_scores = dice_score_per_class(outputs, targets)
                for i in range(3):
                    class_dice[i] += batch_scores[i]

    avg_dice = [c/len(test_loader) for c in class_dice]
    mean_dice = sum(avg_dice) / len(avg_dice)
    print(f"  Dice - Background: {avg_dice[0]:.4f}, Organ: {avg_dice[1]:.4f}, Tumor: {avg_dice[2]:.4f}, Mean: {mean_dice:.4f}")

    # Save best model based on tumor Dice (class 2)
    if avg_dice[2] > best_dice:
        best_dice = avg_dice[2]
        torch.save(model.state_dict(), "best_model.pth")

print(f"\nBest Tumor Dice: {best_dice*100:.2f}%")


# 3D UNet Model

In [16]:
import torch
torch.cuda.empty_cache()

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.memory_allocated(device=device))

0


In [34]:
free_memory, total_memory = torch.cuda.mem_get_info(device=device)
print(f"Free GPU memory: {free_memory} bytes")
print(f"Total GPU memory: {total_memory} bytes")

Free GPU memory: 9306112 bytes
Total GPU memory: 42474471424 bytes


## Classes and Util functions for 3D UNet Model

In [3]:

class MedicalVolumeDataset_3D(Dataset):
    def __init__(self, dataset_path, file_indices):
        self.volumes = []
        self.segmentations = []

        for i in file_indices:
            data = np.load(f"{dataset_path}/{i}.npz")

            if "LITS17" in dataset_path:
                volume = data["volume"]  # (H, W, D)
                seg = data["segmentation"]
                # Reorder to (D, H, W)
                volume = np.transpose(volume, (2, 0, 1)).astype(np.int32)
                seg = np.transpose(seg, (2, 0, 1))
            elif "KITS19" in dataset_path:
                volume = data["volume"].astype(np.int32)  # (D, H, W)
                seg = data["segmentation"]

            # Z-score normalization (paper Section 4.1)
            volume = self._normalize(volume, seg)

            self.volumes.append(volume[None, ...])  # Add channel dim
            self.segmentations.append(seg)

    def _normalize(self, volume, seg):
        """Paper-style normalization using foreground voxels"""
        mask = seg > 0
        if mask.sum() == 0:  # Handle empty masks
            return volume
        mean = volume[mask].mean()
        std = volume[mask].std()
        return (volume - mean) / (std + 1e-8)

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        volume = torch.FloatTensor(self.volumes[idx])  # (1, D, H, W)
        seg = torch.LongTensor(self.segmentations[idx])  # (D, H, W)

        # Pad all spatial dims to multiples of 8
        d, h, w = volume.shape[1], volume.shape[2], volume.shape[3]
        pad_d = (8 - (d % 8)) % 8
        pad_h = (8 - (h % 8)) % 8
        pad_w = (8 - (w % 8)) % 8

        volume_padded = F.pad(volume, (0, pad_w, 0, pad_h, 0, pad_d, 0, 0))  # (1, D+pd, H+ph, W+pw)
        seg_padded = F.pad(seg, (0, pad_w, 0, pad_h, 0, pad_d))  # (D+pd, H+ph, W+pw)
        return volume_padded, seg_padded








class DoubleConv3D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)







class UNet3D(nn.Module):
    def __init__(self, n_channels=1, n_classes=3):
        super().__init__()
        # Encoder
        self.inc = DoubleConv3D(n_channels, 32)
        self.down1 = nn.Sequential(nn.MaxPool3d(2), DoubleConv3D(32, 64))
        self.down2 = nn.Sequential(nn.MaxPool3d(2), DoubleConv3D(64, 128))
        self.down3 = nn.Sequential(nn.MaxPool3d(2), DoubleConv3D(128, 256))


        # Decoder (Fixed with output_padding)
        self.up1 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.conv1 = DoubleConv3D(256, 128)  # 128 + 128 = 256
        self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.conv2 = DoubleConv3D(128, 64)   # 64 + 64 = 128
        self.up3 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.conv3 = DoubleConv3D(64, 32)    # 32 + 32 = 64

        self.outc = nn.Conv3d(32, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        print("here")
        x1 = self.inc(x)    # (B,32,D,H,W)
        x2 = self.down1(x1) # (B,64,D/2,H/2,W/2)
        x3 = self.down2(x2) # (B,128,D/4,H/4,W/4)
        x4 = self.down3(x3) # (B,256,D/8,H/8,W/8)

        # Decoder
        x = self.up1(x4)    # (B,128,D/4,H/4,W/4)
        x = torch.cat([x, x3], dim=1)  # (B,256,D/4,H/4,W/4)
        x = self.conv1(x)   # (B,128,D/4,H/4,W/4)

        x = self.up2(x)     # (B,64,D/2,H/2,W/2)
        x = torch.cat([x, x2], dim=1)  # (B,128,D/2,H/2,W/2)
        x = self.conv2(x)   # (B,64,D/2,H/2,W/2)

        x = self.up3(x)     # (B,32,D,H,W)
        x = torch.cat([x, x1], dim=1)  # (B,64,D,H,W)
        x = self.conv3(x)   # (B,32,D,H,W)

        return self.outc(x)  # (B,n_classes,D,H,W)

In [4]:
# Hybrid Loss (CE + Dice)
def hybrid_loss(pred, target):
    ce = F.cross_entropy(pred, target)
    pred_prob = F.softmax(pred, dim=1)
    dice = 1 - dice_score_3d(pred_prob, target)
    return ce + dice

# 3D Dice Calculation
def dice_score_3d(pred, target):
    smooth = 1e-6
    pred = pred.argmax(1)
    scores = []
    for class_idx in range(3):
        pred_mask = (pred == class_idx).float()
        target_mask = (target == class_idx).float()
        intersection = (pred_mask * target_mask).sum()
        union = pred_mask.sum() + target_mask.sum()
        scores.append((2.*intersection + smooth)/(union + smooth))
    return torch.mean(torch.tensor(scores))

## Model Testing Here

In [5]:
# volume_indices = list(range(131))  # All LITS17 volumes
volume_indices = list(range(2)) #test small subset

# data files with irregular data
# lits_17_list = [0, 1, 4, 16, 17, 18, 22, 35, 37, 38, 48, 50, 52, 53, 54, 55, 57, 63, 65, 68, 69, 70, 71, 72, 74, 76, 77,
#                 78, 80, 81, 82, 87, 88, 89, 90, 91, 92, 93, 95, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 117]

# kits_19_list = [15, 18, 19, 23, 25, 31, 32, 40, 43, 45, 48, 50, 61, 64, 65, 66, 81, 85, 86, 94, 97, 99, 107, 109, 111, 117,
#                 121, 123, 124, 128, 131, 133, 150, 163, 166, 167, 168, 169, 172, 180, 185, 191, 192, 193, 194, 199, 202]

# for i in lits_17_list:
#   volume_indices.remove(i)

train_volumes, test_volumes = train_test_split(volume_indices, test_size=0.2, random_state=42)

training_dataset = MedicalVolumeDataset_3D("/content/gdrive/My Drive/bd4h-final-project-data/LITS17", file_indices = train_volumes)
testing_dataset = MedicalVolumeDataset_3D("/content/gdrive/My Drive/bd4h-final-project-data/LITS17", file_indices = test_volumes)

In [7]:

train_loader = DataLoader(
    training_dataset,
    batch_size=2,  # Matches paper's LiTS batch size
    shuffle=True
)

# test_loader = DataLoader(
#     test_dataset,
#     batch_size=2,  # Matches paper's LiTS batch size
#     shuffle=True,
#     num_workers=4,
#     pin_memory=True
# )
test_loader = DataLoader(
    testing_dataset,
    batch_size=2,  # Matches paper's LiTS batch size
    shuffle=True
)

In [None]:

model = UNet3D(n_channels=1, n_classes=3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)



best_dice = 0



for epoch in range(500):
    print(f"\nEpoch {epoch+1}/500")
    model.train()
    train_losses = []

    # Training with tqdm progress bar
    with tqdm(train_loader, desc="Training", leave=False) as pbar:
        for inputs, targets in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device).squeeze(1)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = hybrid_loss(outputs, targets)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    print(f"  Train loss: {sum(train_losses)/len(train_losses):.4f}")

    # Validation with tqdm progress bar
    model.eval()
    class_dice = [0.0, 0.0, 0.0]
    with torch.no_grad():
        with tqdm(test_loader, desc="Validation", leave=False) as pbar:
            for inputs, targets in pbar:
                inputs = inputs.to(device)
                targets = targets.to(device).squeeze(1)
                outputs = model(inputs)
                batch_scores = dice_score_3d(outputs, targets)
                for i in range(3):
                    class_dice[i] += batch_scores[i]

    avg_dice = [c/len(test_loader) for c in class_dice]
    mean_dice = sum(avg_dice) / len(avg_dice)
    print(f"  Dice - Background: {avg_dice[0]:.4f}, Organ: {avg_dice[1]:.4f}, Tumor: {avg_dice[2]:.4f}, Mean: {mean_dice:.4f}")

    # Save best model based on tumor Dice (class 2)
    if avg_dice[2] > best_dice:
        best_dice = avg_dice[2]
        torch.save(model.state_dict(), "best_model.pth")

print(f"\nBest Tumor Dice: {best_dice*100:.2f}%")

# 2-5D UNet Model

## Classes and Util functions for 2-5D UNet Model

In [21]:

class MedicalVolumeDataset2_5D(Dataset):
    def __init__(self, dataset_path, file_indices, num_slices=3):
        self.num_slices = num_slices
        self.half_slices = num_slices // 2
        self.volumes = []
        self.segmentations = []

        for i in file_indices:
            data = np.load(f"{dataset_path}/{i}.npz")
            volume, seg = None, None
            # volume = np.transpose(data["volume"], (2, 0, 1))  # (D, H, W)
            # seg = np.transpose(data["segmentation"], (2, 0, 1))
            if "LITS17" in dataset_path:
                volume = data["volume"]  # (H, W, D)
                seg = data["segmentation"]
                # Reorder to (D, H, W)
                volume = np.transpose(volume, (2, 0, 1)).astype(np.int32)
                seg = np.transpose(seg, (2, 0, 1))
            elif "KITS19" in dataset_path:
                volume = data["volume"].astype(np.int32)  # (D, H, W)
                seg = data["segmentation"]

            # Normalize using foreground voxels
            volume = self._normalize(volume, seg)

            self.volumes.append(volume)
            self.segmentations.append(seg)

    def _normalize(self, volume, seg):
        mask = seg > 0
        if mask.sum() == 0:
            return volume
        mean = volume[mask].mean()
        std = volume[mask].std()
        return (volume - mean) / (std + 1e-8)

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        volume = self.volumes[idx]  # (D, H, W)
        seg = self.segmentations[idx]

        # Create 2.5D slices with adjacent slices as channels
        slices = []
        masks = []
        for z in range(volume.shape[0]):
            # Handle edge cases by replicating first/last slice
            start = max(0, z - self.half_slices)
            end = min(volume.shape[0], z + self.half_slices + 1)
            slice_stack = []

            # Pad with replicated slices if needed
            while len(slice_stack) < self.num_slices:
                if start < 0:
                    slice_stack.append(volume[0])
                    start += 1
                elif end >= volume.shape[0]:
                    slice_stack.append(volume[-1])
                    end -= 1
                else:
                    slice_stack.append(volume[start])
                    start += 1

            slice_stack = np.stack(slice_stack)  # (C, H, W)
            slices.append(slice_stack)
            masks.append(seg[z])  # Current slice mask

        volume_2_5d = np.stack(slices)  # (D, C, H, W)
        masks = np.stack(masks)         # (D, H, W)

        return torch.FloatTensor(volume_2_5d), torch.LongTensor(masks)

class DoubleConv2D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet2_5D(nn.Module):
    def __init__(self, in_channels=3, n_classes=3):
        super().__init__()
        # Encoder
        self.inc = DoubleConv2D(in_channels, 32)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv2D(32, 64))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv2D(64, 128))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv2D(128, 256))

        # Decoder
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv1 = DoubleConv2D(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv2 = DoubleConv2D(128, 64)
        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.conv3 = DoubleConv2D(64, 32)

        self.outc = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x):
        # x shape: (B, C, H, W)
        x1 = self.inc(x)    # (B,32,H,W)
        x2 = self.down1(x1) # (B,64,H/2,W/2)
        x3 = self.down2(x2) # (B,128,H/4,W/4)
        x4 = self.down3(x3) # (B,256,H/8,W/8)

        x = self.up1(x4)    # (B,128,H/4,W/4)
        x = torch.cat([x, x3], dim=1) # (B,256,H/4,W/4)
        x = self.conv1(x)   # (B,128,H/4,W/4)

        x = self.up2(x)     # (B,64,H/2,W/2)
        x = torch.cat([x, x2], dim=1) # (B,128,H/2,W/2)
        x = self.conv2(x)   # (B,64,H/2,W/2)

        x = self.up3(x)     # (B,32,H,W)
        x = torch.cat([x, x1], dim=1) # (B,64,H,W)
        x = self.conv3(x)   # (B,32,H,W)

        return self.outc(x) # (B,n_classes,H,W)



In [22]:
def hybrid_loss(pred, target):
    ce = F.cross_entropy(pred, target)
    pred_prob = F.softmax(pred, dim=1)
    dice = 1 - dice_score_2d(pred_prob, target)
    return ce + dice

def dice_score_2d(pred, target):
    smooth = 1e-6
    pred = pred.argmax(1)
    scores = []
    for class_idx in range(3):
        pred_mask = (pred == class_idx).float()
        target_mask = (target == class_idx).float()
        intersection = (pred_mask * target_mask).sum()
        union = pred_mask.sum() + target_mask.sum()
        scores.append((2.*intersection + smooth)/(union + smooth))
    return scores  # Return list of scores instead of mean

## Model Testing Here

In [33]:
# Training setup
volume_indices = list(range(2)) #test small subset


train_volumes, test_volumes = train_test_split(volume_indices, test_size=0.2, random_state=42)

train_dataset = MedicalVolumeDataset2_5D("/content/gdrive/My Drive/bd4h-final-project-data/LITS17", file_indices = train_volumes)
test_dataset = MedicalVolumeDataset2_5D("/content/gdrive/My Drive/bd4h-final-project-data/LITS17", file_indices = test_volumes)

In [34]:
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=True)

In [31]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet2_5D(in_channels=3, n_classes=3).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99)

for epoch in range(500):
    model.train()
    epoch_loss = 0.0

    # Training with progress bar
    with tqdm(train_loader, desc=f"Epoch {epoch+1}/500 [Train]", leave=False) as pbar:
        for batch_idx, (data, targets) in enumerate(pbar):
            # Flatten slices into batch dimension
            batch_size, num_slices, channels, height, width = data.shape
            data = data.view(-1, channels, height, width).to(device)
            targets = targets.view(-1, height, width).to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = hybrid_loss(outputs, targets)
            loss.backward()
            optimizer.step()

            # Update progress bar
            epoch_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    # Validation with progress bar
    model.eval()
    with torch.no_grad():
        class_dice = [0.0, 0.0, 0.0]  # [background, organ, tumor]
        with tqdm(test_loader, desc=f"Epoch {epoch+1}/500 [Val]", leave=False) as pbar:
            for data, targets in pbar:
                data = data.view(-1, 3, height, width).to(device)
                targets = targets.view(-1, height, width).to(device)
                outputs = model(data)

                batch_scores = dice_score_2d(outputs, targets)
                for i in range(3):
                    class_dice[i] += batch_scores[i]

                pbar.set_postfix({
                    "bg": f"{batch_scores[0]:.2f}",
                    "organ": f"{batch_scores[1]:.2f}",
                    "tumor": f"{batch_scores[2]:.2f}"
                })

        # Calculate averages
        avg_dice = [c/len(test_loader) for c in class_dice]
        mean_dice = sum(avg_dice) / len(avg_dice)
        print(f"Epoch {epoch+1} \t Loss: {epoch_loss/len(train_loader):.4f}")
        print(f"  Dice - Background: {avg_dice[0]:.4f}, Organ: {avg_dice[1]:.4f}, Tumor: {avg_dice[2]:.4f}, Mean: {mean_dice:.4f}")



OutOfMemoryError: CUDA out of memory. Tried to allocate 4.69 GiB. GPU 0 has a total capacity of 39.56 GiB of which 3.48 GiB is free. Process 43262 has 36.07 GiB memory in use. Of the allocated memory 33.98 GiB is allocated by PyTorch, and 1.60 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)