In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.utils.data import Dataset, DataLoader

In [18]:
# Dataset: one .npy MRI volume (single plane) per exam, paired with its label from a CSV
class MRNetDataset(Dataset):
    """Map-style dataset of single-plane MRI volumes with per-exam labels.

    Labels come from a header-less two-column CSV (``id,label``) located in
    ``data_dir``; each item's volume is read from
    ``<data_dir>/<split>/<plane>/<id>.npy``.

    Args:
        data_dir: Root directory holding the CSV and the ``<split>`` folders.
        csv_file: CSV file name, relative to ``data_dir``.
        split: Sub-directory of ``data_dir`` holding this split's volumes.
        plane: Sub-directory under ``split`` selecting the imaging plane.
        transform: Optional callable applied to the ``(1, ...)`` float32
            volume tensor after standardization.
    """

    def __init__(self, data_dir, csv_file, split='train', plane="axial", transform=None):
        self.data_dir = data_dir
        self.split = split
        self.plane = plane
        # Read ids as strings so zero-padded ids (e.g. "0007") still match their
        # file names; parsing them as ints would turn "0007" into 7.
        df = pd.read_csv(
            os.path.join(data_dir, csv_file),
            header=None,
            names=['id', 'label'],
            dtype={'id': str},
        )
        self.ids = df['id'].values
        self.labels = df['label'].values
        self.transform = transform

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(volume, label)`` for exam ``idx``.

        ``volume`` is a float32 tensor with a leading channel dim of size 1,
        standardized to zero mean / unit std over the whole volume (then passed
        through ``transform`` if one was given); ``label`` is a scalar int64 tensor.
        """
        exam_id = self.ids[idx]  # not `id`: avoid shadowing the builtin
        path = os.path.join(self.data_dir, self.split, self.plane, f"{exam_id}.npy")
        # n x 256 x 256 [depth of the mri, height, width] per the author — not checked here
        volume = np.load(path)
        mean, std = volume.mean(), volume.std()
        # A constant volume has std == 0; dividing by 1 keeps it all-zero instead of NaN.
        volume = (volume - mean) / (std if std > 0 else 1.0)
        # torch.from_numpy has no dtype argument, so cast in numpy first.
        volume = torch.from_numpy(volume.astype(np.float32)).unsqueeze(0)  # add channel dim
        if self.transform is not None:
            volume = self.transform(volume)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return volume, label

In [None]:
# Paths and training hyper-parameters.
data_root = '../data'
batch_size = 16
lr = 1e-3
num_epochs = 50
seed = 8

# Seed torch's global RNG so DataLoader shuffling (no explicit generator is passed)
# and layer initialisation are reproducible across Restart & Run All.
torch.manual_seed(seed)

train_dataset = MRNetDataset(data_dir=data_root, csv_file='train-acl.csv', split='train', plane="axial")
val_dataset = MRNetDataset(data_dir=data_root, csv_file='valid-acl.csv', split='val', plane="axial")

In [21]:
# Number of exams listed in the validation CSV; this pool is split into val/test in the next cell.
n = len(val_dataset)
n

120

In [27]:
# Carve 40% of the validation exams out as a held-out test set. The dedicated,
# fixed-seed generator keeps the partition identical across kernel restarts.
test_size = int(n * 0.4)
val_size = n - test_size

generator = torch.Generator()
generator.manual_seed(8)
val_ds, test_ds = torch.utils.data.random_split(
    val_dataset,
    [val_size, test_size],
    generator=generator,
)

In [None]:
# Worker processes used by every loader.
# NOTE(review): where workers are started with "spawn" (Windows; macOS by default),
# classes defined inside a notebook may not be importable by the workers — set this
# to 0 if loading fails there.
num_workers = 4

# Only the training loader shuffles; val/test order stays fixed for evaluation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)


In [None]:
class SqueezeModule(nn.Module):
    """1x1x1 "squeeze" conv followed by parallel 1x1x1 and 3x3x3 expand branches.

    Both expand branches read the (ReLU'd) squeezed features, and their outputs
    are concatenated along the channel dimension, so the module maps
    ``(B, in_channels, D, H, W)`` to ``(B, out_channels, D, H, W)``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels after the squeeze conv and of the final output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Split the output channels between the two branches; giving the 3x3x3
        # branch the remainder keeps the total equal to out_channels when it is odd.
        c1 = out_channels // 2
        c3 = out_channels - c1
        self.squeeze = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.conv1 = nn.Conv3d(out_channels, c1, kernel_size=1)
        # padding=1 is "same" padding for kernel_size=3; any other value changes
        # D/H/W and the channel concat with the 1x1x1 branch below would fail.
        self.conv3 = nn.Conv3d(out_channels, c3, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.squeeze(x))
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv3(x))
        return torch.cat([x1, x2], dim=1)

In [None]:
class AttentionModule(nn.Module):
    """Spatial attention map computed from two concatenated 3D feature maps.

    ``low_features`` and ``high_features`` are concatenated on the channel axis
    (so they must share batch and D/H/W sizes), passed through a 1x1x1 conv +
    ReLU and a second 1x1x1 conv down to one channel, and softmax-normalised
    over every spatial position of each sample.

    Args:
        lower_channels: Channels of ``low_features``.
        high_channels: Channels of ``high_features``.
        intermediate_channels: Channels of the hidden 1x1x1 conv.

    Returns (forward):
        Tensor of shape ``(B, 1, D, H, W)`` whose entries sum to 1 per sample.
    """

    def __init__(self, lower_channels, high_channels, intermediate_channels):
        super().__init__()
        in_channels = lower_channels + high_channels
        self.win = nn.Conv3d(in_channels, intermediate_channels, kernel_size=1)
        self.wout = nn.Conv3d(intermediate_channels, 1, kernel_size=1)

    def forward(self, low_features, high_features):
        fused = torch.cat((low_features, high_features), dim=1)
        scores = self.wout(F.relu(self.win(fused)))  # (B, 1, D, H, W)
        # Flatten every position of a sample so the softmax normalises across all of them.
        weights = F.softmax(scores.flatten(start_dim=1), dim=1)
        return weights.view_as(scores)

In [None]:
class ACL3DModel(nn.Module):
    """Top-level 3D ACL classifier — placeholder, not implemented yet.

    TODO: assemble the network (presumably from the ``SqueezeModule`` /
    ``AttentionModule`` building blocks in this notebook — confirm) and
    implement ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Fail loudly instead of silently returning None into a training loop;
        # this mirrors nn.Module's own behaviour when forward is not overridden.
        raise NotImplementedError("ACL3DModel.forward is not implemented yet")
