In [None]:
import os
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd

# Upper bound on how many scans load_nii_images reads per call. It is kept small
# here, presumably for quick iteration; raise it for a full training run.
IMAGES_TO_LOAD = 10

def preprocess_nifti_image(img, max_dim=100):
    """Centre a 3-D volume's last axis inside a zero-filled array of depth ``max_dim``.

    At most ``max_dim`` slices around the middle of the last axis are kept, and
    shorter volumes are zero-padded symmetrically. When the number of kept slices
    is odd, the first of them is dropped so the block splits evenly about the centre.

    Args:
        img: Array of shape ``(height, width, depth)``.
        max_dim: Depth of the returned array.

    Returns:
        Array of shape ``(height, width, max_dim)`` with the same dtype as ``img``.
    """
    rows, cols, n_slices = img.shape
    centre, half_window = n_slices // 2, max_dim // 2

    lo = max(0, centre - half_window)
    hi = min(n_slices, centre + half_window)
    # An odd-sized range cannot be centred symmetrically: skip its first slice.
    lo += (hi - lo) % 2
    half_kept = (hi - lo) // 2

    out = np.zeros((rows, cols, max_dim), dtype=img.dtype)
    out[..., half_window - half_kept:half_window + half_kept] = img[..., lo:hi]
    return out

def load_nii_images(folder, labels_file, max_images=None):
    """Load ``.nii.gz`` volumes from ``folder`` together with their integer labels.

    Each scan is read with nibabel, cropped/padded along its last axis by
    ``preprocess_nifti_image``, and paired with the ``label`` from the row of
    ``labels_file`` whose ``id`` equals the file name up to its first ``'.'``.

    Args:
        folder: Directory containing the ``.nii.gz`` scans.
        labels_file: CSV file with at least ``id`` and ``label`` columns.
        max_images: Maximum number of scans to load. ``None`` (the default) uses
            the module-level ``IMAGES_TO_LOAD``.

    Returns:
        dict with ``"images"``, a float32 tensor with the preprocessed volumes
        stacked along dim 0, and ``"labels"``, an int64 tensor of matching length.

    Raises:
        KeyError: if a scan has no matching ``id`` row in ``labels_file``.
    """
    limit = IMAGES_TO_LOAD if max_images is None else max_images
    labels_df = pd.read_csv(labels_file)
    # Compare ids as strings so the lookup also works if pandas parsed the column as numbers.
    label_ids = labels_df["id"].astype(str)

    # Only NIfTI files count toward the limit. Sorting makes the selected subset
    # identical on every run, because os.listdir order is arbitrary.
    nifti_names = sorted(
        name for name in os.listdir(folder) if name.endswith(".nii.gz")
    )[:max(limit, 0)]

    images = []
    labels = []
    for file_name in nifti_names:
        scan_id = file_name.split(".")[0]
        matches = labels_df.loc[label_ids == scan_id, "label"]
        if matches.empty:
            raise KeyError(
                f"No label for scan id {scan_id!r} ({file_name}) in {labels_file}"
            )

        volume = nib.load(os.path.join(folder, file_name)).get_fdata()
        # Downcast each scan right away so all volumes are never held in float64 at once.
        images.append(preprocess_nifti_image(volume).astype(np.float32, copy=False))
        labels.append(int(matches.iloc[0]))

    if images:
        # Stack in NumPy first: torch.tensor() on a list of ndarrays is very slow.
        images_tensor = torch.from_numpy(np.stack(images))
    else:
        images_tensor = torch.empty(0, dtype=torch.float32)
    labels_tensor = torch.tensor(labels, dtype=torch.int64)
    return {
        "images": images_tensor,
        "labels": labels_tensor,
    }
# Load up to IMAGES_TO_LOAD training/validation scans, each cropped/padded along
# its last axis by preprocess_nifti_image, together with their labels.
dict_images = load_nii_images(
    folder='../aocr2024/1_Train,Valid_Image/',
    labels_file="../aocr2024/TrainValid_ground_truth.csv",
)

In [50]:
import torch
import torch.nn as nn

class ModeloC3D(nn.Module):
    """C3D-style 3-D CNN that classifies single-channel volumes.

    The input is ``(batch, H, W, depth)``, to which a channel axis is added
    internally, or an already-shaped ``(batch, 1, H, W, depth)``. The last axis is
    tiled and/or truncated to exactly ``num_frames`` slices before the convolution
    stack. The output is ``(batch, num_classes)`` logits.

    ``nn.Conv3d`` treats the last three axes as its three spatial dims, so with this
    layout ``pool1`` (kernel ``(1, 2, 2)``) leaves the second input axis unpooled and
    halves the third axis and the slice axis.

    Args:
        num_frames: Number of slices along the last input axis fed to the network.
        num_classes: Number of logits produced per sample.
    """

    def __init__(self, num_frames, num_classes):
        super().__init__()
        self.num_frames = num_frames
        self.num_classes = num_classes

        self.conv1 = nn.Conv3d(1, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        # fc6 expects a 512 x 2 x 2 x 2 feature map, but the fixed pooling stack only
        # yields that for one particular input size. Adaptive pooling to (2, 2, 2)
        # makes the head work for any input that survives pool5. It has no
        # parameters, so state_dict keys are unchanged.
        self.adaptive_pool = nn.AdaptiveAvgPool3d((2, 2, 2))

        self.fc6 = nn.Linear(512 * 2 * 2 * 2, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for a batch of volumes.

        Raises:
            ValueError: if ``x`` is neither ``(batch, H, W, depth)`` nor
                ``(batch, 1, H, W, depth)``, or has no slices along its last axis.
        """
        # NOTE(review): at the 512 x 512 x 100 volume size shown in this notebook's
        # output, conv1's activations alone are 64 x 512 x 512 x 100 float32 values
        # (~6.7 GB) per sample. Consider downsampling the volumes.
        if x.dim() == 4:
            x = x.unsqueeze(1)  # add the single input channel
        elif x.dim() != 5 or x.size(1) != 1:
            raise ValueError(
                "expected input of shape (batch, H, W, depth) or "
                f"(batch, 1, H, W, depth), got {tuple(x.shape)}"
            )

        # The slice axis is the last one. Tile it until it holds at least
        # num_frames slices, then keep exactly num_frames.
        depth = x.size(-1)
        if depth == 0:
            raise ValueError("input has no slices along the last axis")
        if depth < self.num_frames:
            repeats = -(-self.num_frames // depth)  # ceiling division
            x = torch.cat([x] * repeats, dim=-1)
        x = x[..., :self.num_frames]

        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3b(torch.relu(self.conv3a(x)))))
        x = self.pool4(torch.relu(self.conv4b(torch.relu(self.conv4a(x)))))
        x = self.pool5(torch.relu(self.conv5b(torch.relu(self.conv5a(x)))))

        x = self.adaptive_pool(x)
        # Flatten per sample. view(-1, 4096) would silently fold extra spatial
        # cells into the batch axis and misalign logits with labels.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc6(x))
        x = torch.relu(self.fc7(x))
        x = self.fc8(x)

        return x

# Instantiate the classifier for 100-slice inputs (preprocess_nifti_image's default
# max_dim=100) with 2 output classes.
modelo = ModeloC3D(num_frames=100, num_classes=2)


In [51]:
# Wrap the loaded volumes and labels for mini-batch training.
dataset = TensorDataset(dict_images["images"], dict_images["labels"])
batch_size = 16
num_epochs = 10
SEED = 0

# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Build and train a fresh model in this cell, so re-running it starts from scratch
# instead of continuing to train a model left over from an earlier cell or run.
# `modelo` is rebound to the same object so code using either name sees the
# trained weights.
model = ModeloC3D(num_frames=100, num_classes=2)
modelo = model

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    running_loss, seen = 0.0, 0
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size, since the last batch may be smaller than batch_size.
        running_loss += loss.item() * inputs.size(0)
        seen += inputs.size(0)
    # Report the mean loss over the whole epoch, not just the final batch.
    epoch_loss = running_loss / seen if seen else float("nan")
    print(f'Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}')

torch.Size([10, 512, 512, 100])
torch.Size([10, 1, 512, 512, 100])
