In [None]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd

# Upper bound on the number of volumes load_nii_images will load in one call.
IMAGES_TO_LOAD = 10

def preprocess_nifti_image(img, max_dim = 100):
    """Centre-crop or zero-pad a 3-D volume along its last axis.

    Args:
        img: 3-D array; its last axis is the one that gets cropped/padded.
        max_dim: Length of the last axis in the returned array.

    Returns:
        A new array of shape ``(img.shape[0], img.shape[1], max_dim)`` with
        the same dtype as ``img``. The central slices of ``img`` are copied
        around index ``max_dim // 2``; every other position is zero.
    """
    rows, cols, n_slices = img.shape
    half_target = max_dim // 2
    centre = n_slices // 2

    # Source window: at most ``half_target`` slices on each side of the centre.
    lo = max(0, centre - half_target)
    hi = min(n_slices, centre + half_target)

    # Keep the copied span even so it splits symmetrically around the target
    # centre; an odd span gives up its first slice.
    if (hi - lo) % 2:
        lo += 1
    half_span = (hi - lo) // 2

    padded = np.zeros((rows, cols, max_dim), dtype=img.dtype)
    padded[:, :, half_target - half_span:half_target + half_span] = img[:, :, lo:hi]
    return padded

def reshape_input(volume):
    """Move the last axis to the front and append a singleton channel axis.

    For a 3-D ``(H, W, T)`` volume the result has shape ``(T, H, W, 1)``,
    i.e. a ``[T, H, W, C]`` layout with a single channel.

    Args:
        volume: Array whose last axis should become the leading axis.

    Returns:
        The re-ordered array with one extra trailing axis of length 1.
    """
    depth_first = np.moveaxis(volume, source=-1, destination=0)
    return depth_first[..., np.newaxis]

def load_nii_images(folder, labels_file, max_images=None):
    """Load ``.nii.gz`` volumes and their labels into model-ready tensors.

    Each volume is passed through ``preprocess_nifti_image`` and
    ``reshape_input``, the results are stacked, and the single channel is
    repeated three times so the batch matches a 3-channel video model.

    Args:
        folder: Directory containing the ``.nii.gz`` files.
        labels_file: CSV file with at least an ``id`` and a ``label`` column.
            An image's id is its file name up to the first ``.``.
        max_images: Maximum number of volumes to load. Defaults to the
            module-level ``IMAGES_TO_LOAD``.

    Returns:
        dict with
            ``"images"``: float32 tensor laid out as ``(N, 3, T, H, W)``.
            ``"labels"``: int64 tensor of shape ``(N,)``.

    Raises:
        ValueError: If no ``.nii.gz`` file is selected from ``folder``.
        IndexError: If an image id has no row in ``labels_file``.
    """
    if max_images is None:
        max_images = IMAGES_TO_LOAD

    labels_df = pd.read_csv(labels_file)

    # Filter by extension *before* applying the limit so that unrelated
    # directory entries do not eat into the image budget, and sort because
    # os.listdir returns entries in arbitrary order (the loaded subset would
    # otherwise differ between runs/machines).
    nifti_names = sorted(
        name for name in os.listdir(folder) if name.endswith('.nii.gz')
    )[:max(0, max_images)]
    if not nifti_names:
        raise ValueError(f"No .nii.gz files selected from {folder!r}")

    preprocessed_images = []
    labels = []
    for file_name in nifti_names:
        image_id = file_name.split(".")[0]

        # Resolve the label first: a missing id is reported before the
        # (expensive) volume is read from disk.
        matching_labels = labels_df.loc[labels_df.id == image_id, "label"]
        if matching_labels.empty:
            raise IndexError(
                f"No label found for image id {image_id!r} in {labels_file!r}"
            )

        image_array = nib.load(os.path.join(folder, file_name)).get_fdata()
        volume = reshape_input(preprocess_nifti_image(image_array))
        # Down-cast per volume so the float64 copies can be released early.
        preprocessed_images.append(volume.astype(np.float32))
        labels.append(int(matching_labels.iloc[0]))

    # Stack into one ndarray first: building a tensor directly from a list
    # of ndarrays is a known slow path in PyTorch.
    images_tensor = torch.from_numpy(np.stack(preprocessed_images))
    # Repeat the single channel to provide the 3 channels the model expects.
    images_tensor = images_tensor.repeat(1, 1, 1, 1, 3)
    # (N, T, H, W, C) -> (N, C, T, H, W)
    images_tensor = images_tensor.permute(0, 4, 1, 2, 3)
    labels_tensor = torch.tensor(labels, dtype=torch.int64)
    return {
        "images": images_tensor,
        "labels": labels_tensor
    }
# Read the training/validation volumes and their ground-truth labels into tensors.
dict_images = load_nii_images(
    folder='../aocr2024/1_Train,Valid_Image/',
    labels_file="../aocr2024/TrainValid_ground_truth.csv",
)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np

class ModifiedR3D18(nn.Module):
    """torchvision R3D-18 video network with a single-probability head.

    The backbone's final fully-connected layer is replaced by
    ``Linear(in_features, 1)`` followed by ``Sigmoid``, so the network
    outputs one value in ``[0, 1]`` per sample.

    Args:
        pretrained: When true (default) the backbone is initialised with
            torchvision's pretrained R3D-18 weights; when false it is
            randomly initialised.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # ``pretrained=`` is deprecated in newer torchvision releases in
        # favour of ``weights=``. Use the weights enum when it exists and
        # fall back to the legacy keyword on older installations.
        weights_enum = getattr(models.video, "R3D_18_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.KINETICS400_V1 if pretrained else None
            self.r3d_18 = models.video.r3d_18(weights=weights)
        else:
            self.r3d_18 = models.video.r3d_18(pretrained=pretrained)

        num_ftrs = self.r3d_18.fc.in_features
        self.r3d_18.fc = nn.Sequential(
            nn.Linear(num_ftrs, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the backbone and head on ``x``.

        Args:
            x: Input batch; assumed to be laid out as ``(N, 3, T, H, W)``.

        Returns:
            Tensor of shape ``(N, 1)`` holding sigmoid outputs.
        """
        return self.r3d_18(x)

# Build the network and switch it to training mode.
model = ModifiedR3D18()
model.train(mode=True)

In [None]:
# Wrap the loaded tensors so they can be served in shuffled mini-batches.
batch_size = 16
dataset = TensorDataset(*(dict_images[key] for key in ("images", "labels")))
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
num_epochs = 10
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
    running_loss = 0.0
    samples_seen = 0
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)

        # BCELoss needs floating-point targets with exactly the same shape
        # as the predictions. The dataset yields int64 labels of shape (N,)
        # while the model emits (N, 1), so cast and reshape before the loss.
        targets = labels.to(outputs.dtype).reshape(outputs.shape)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        samples_seen += inputs.size(0)

    # Report the sample-weighted mean loss so training progress is visible.
    print(f"epoch {epoch + 1}/{num_epochs} - mean loss: {running_loss / max(samples_seen, 1):.4f}")


In [None]:
# Sanity check: display the dimensions of the stacked image tensor.
dict_images["images"].size()