In [None]:
import os
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
import torchvision.models as models

# Number of samples per mini-batch; used by both DataLoaders below.
batch_size: int = 16
# Device that batches are moved to with `.to(device)` in the training loop.
device: str = "cpu"

class ModifiedR3D18(nn.Module):
    """torchvision R3D-18 video network with a single-output head.

    The backbone is created with `R3D_18_Weights.DEFAULT` weights and its
    final `fc` layer is replaced by a linear layer producing one value per
    sample (used as a logit with `BCEWithLogitsLoss` in the training cell).
    """

    def __init__(self):
        super().__init__()
        pretrained_weights = models.video.R3D_18_Weights.DEFAULT
        backbone = models.video.r3d_18(weights=pretrained_weights)
        head_inputs = backbone.fc.in_features
        # The nn.Sequential wrapper is kept on purpose: it determines the
        # parameter names in the state_dict (`r3d_18.fc.0.*`), so removing it
        # would make previously saved checkpoints fail to load.
        backbone.fc = nn.Sequential(nn.Linear(head_inputs, 1))
        self.r3d_18 = backbone

    def forward(self, x):
        """Run `x` through the backbone and return its raw (un-sigmoided) output."""
        logits = self.r3d_18(x)
        return logits

class LazyImageDataset(Dataset):
    """Dataset that loads one NIfTI image from disk each time an item is requested.

    Only file names are collected at construction time; the image data itself
    is read inside `__getitem__`.

    Args:
        image_dir: Directory containing the image files, named `<id>.nii.gz`.
        labels_file: CSV file with at least the columns `id` and `label`.
        validation_file: CSV file with at least the columns `id` and `group`.
        split: Value of the `group` column selecting which ids belong to this
            dataset (the notebook uses "Train" and "Valid").

    Attributes:
        image_dir: The directory passed in.
        df_labels: DataFrame read from `labels_file`.
        image_filenames: Sorted list of `<id>.nii.gz` names whose id is both
            present in `image_dir` and assigned to `split`.
    """

    def __init__(self, image_dir, labels_file, validation_file, split = "Train"):
        self.image_dir = image_dir
        self.__type_of_data = split
        self.__df_validation = pd.read_csv(
            validation_file,
        )
        self.df_labels = pd.read_csv(
            labels_file
        )
        # Everything before the first "." of a file name is taken as the id.
        ids_on_disk = {filename.split(".")[0] for filename in os.listdir(image_dir)}
        ids_in_split = set(
            self.__df_validation.loc[
                self.__df_validation.group == self.__type_of_data,
                "id"
            ]
        )
        # Sort the ids: iterating a set of strings yields an order that can
        # change between interpreter runs, which would make `dataset[i]` refer
        # to a different file on every run.
        self.image_filenames = [
            f"{image_id}.nii.gz" for image_id in sorted(ids_on_disk & ids_in_split)
        ]

    def __len__(self):
        """Return the number of images selected for this split."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the image at position `idx` and look up its label.

        Returns:
            A tuple `(image, label)`: `image` is a float32 tensor built from
            the NIfTI data array, `label` is a Python int taken from the first
            row of `df_labels` whose `id` matches.

        Raises:
            IndexError: If `idx` is out of range or no label row matches the id.
        """
        filename = self.image_filenames[idx]
        img_name = os.path.join(self.image_dir, filename)
        image_nifti = nib.load(img_name)
        image = torch.from_numpy(
            image_nifti.get_fdata()
        ).float()
        # Get label. The id is derived from the stored file name rather than by
        # splitting the joined path on "/", which breaks when os.path.join
        # uses a different separator.
        image_id = filename.split(".")[0]
        matching_labels = self.df_labels.loc[
            self.df_labels.id == image_id,
            "label"
        ]
        if matching_labels.empty:
            # Same exception type as the positional lookup below would raise,
            # but with a message that names the offending id.
            raise IndexError(f"No label found for image id {image_id!r}")
        label = int(matching_labels.iloc[0])
        return image, label
    
# Load the dataset
# The two splits read the same image folder and CSV files; only `split` differs.
dataset_kwargs = dict(
    image_dir = "../aocr2024/preprocessed_images/",
    labels_file = "../aocr2024/TrainValid_ground_truth.csv",
    validation_file = "../aocr2024/TrainValid_split.csv",
)
train_dataset = LazyImageDataset(split = "Train", **dataset_kwargs)
val_dataset = LazyImageDataset(split = "Valid", **dataset_kwargs)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle = False)

In [None]:
# Model
model = ModifiedR3D18()
# Keep the parameters on the same device the batches are moved to below.
model = model.to(device)

# Load model if exists
file_name = "../model_store/execution_10.pth"
try:
    # map_location lets a checkpoint that was saved from another device be
    # loaded onto `device` instead of failing and being treated as "missing".
    state_dict = torch.load(file_name, map_location=device)
    model.load_state_dict(state_dict)
except Exception as e:
    # Best effort: report the problem and continue from the weights that
    # ModifiedR3D18() was constructed with.
    original_epochs = 0
    print(f"Exception {e}")
else:
    # The text after the last "_" of the file stem is taken as the number of
    # epochs already trained ("execution_10.pth" -> "10"). Parsing the stem,
    # not the full path, keeps this independent of dots in directory names.
    checkpoint_stem = os.path.splitext(os.path.basename(file_name))[0]
    original_epochs = checkpoint_stem.split("_")[-1]
    print(f"Loaded the model from {file_name}")

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3
)


def _count_correct(logits, targets):
    """Return how many sigmoid(logits) > 0.5 predictions are equal to `targets`."""
    predicted = (torch.sigmoid(logits) > 0.5).float()
    return (predicted == targets).sum().item()


# Number of epochs
num_epochs = 10
print(f"Continue training after {original_epochs} for epochs {num_epochs}")
for epoch in range(num_epochs):
    # Validation at the end of the previous epoch leaves the model in eval
    # mode, so training mode is re-enabled at the start of every epoch.
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}")
    for batch_idx, (images, labels) in progress_bar:
        # Labels arrive as a 1-D batch; reshape to (batch, 1) to match the model output.
        images, labels = images.to(device), labels.reshape((labels.size(0),1)).to(device)
        outputs = model(images)
        loss = criterion(outputs.float(), labels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Accuracy (computed on detached outputs; it is bookkeeping only and
        # does not need to be part of the autograd graph).
        correct_predictions += _count_correct(outputs.detach(), labels)
        total_predictions += labels.size(0)
        average_accuracy = correct_predictions / total_predictions
        progress_bar.set_postfix({'loss': total_loss / (batch_idx + 1), 'accuracy': average_accuracy})
    # Validation
    # eval() switches any layers that behave differently during training
    # (e.g. batch normalisation, dropout) to inference behaviour, so the
    # validation batches neither use batch statistics nor update running ones.
    model.eval()
    valid_correct_predictions = 0
    valid_total_predictions = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.reshape((labels.size(0),1)).to(device)
            outputs = model(images)
            valid_correct_predictions += _count_correct(outputs, labels)
            valid_total_predictions += labels.size(0)
    if valid_total_predictions:
        average_accuracy = valid_correct_predictions / valid_total_predictions
        print(f"Validation accuracy: {average_accuracy * 100:.2f}%")
    else:
        # Avoid a ZeroDivisionError when the validation split has no samples.
        print("Validation accuracy: n/a (validation set is empty)")

In [None]:
# Create the checkpoint folder if it is missing; otherwise torch.save would
# fail here and the weights trained above would be lost.
os.makedirs('../model_store', exist_ok=True)
# NOTE(review): the file name contains only `num_epochs`, not the number of
# epochs of the checkpoint that was loaded. With the current settings
# (loading execution_10.pth and num_epochs = 10) this overwrites the loaded file.
torch.save(model.state_dict(), f'../model_store/execution_{num_epochs}.pth')