# setup

In [None]:
from google.colab import drive

# Mount Google Drive into the Colab filesystem at /content/gdrive.
drive.mount("/content/gdrive")
# Path on the mounted drive; presumably the root of the processed dataset
# (it is not referenced elsewhere in the visible code) — TODO confirm.
data_path = "/content/gdrive/MyDrive/austral/tesis/data/processed_data"

# classes

## Dataset

In [None]:
import os
import torch
import tifffile as tif
import pandas as pd
from torch.utils.data import Dataset
import numpy as np

# from torchvision import transforms
import matplotlib.pyplot as plt


class CellDivisionDataset(Dataset):
    """Dataset of stacked TIFF images with binary labels.

    The annotations CSV is read without a header row: column 0 holds the
    image file name (relative to ``img_dir``) and column 1 holds a class
    code. Class codes listed in ``POSITIVE_CLASSES`` become label 1, every
    other code becomes label 0.
    """

    # Stacks with more frames than this are trimmed to this depth.
    MAX_FRAMES = 20
    # Class codes mapped to the positive label ("rtcc vs all the rest").
    POSITIVE_CLASSES = [3, 7, 8]

    def __init__(self, annotations_file, img_dir, transform=None, device=None):
        """
        Args:
            annotations_file (string): Path to the header-less CSV with file
                names in column 0 and class codes in column 1.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform applied to the
                image tensor of each sample. It receives the tensor after the
                channel dimension has been added.
            device (string or torch.device, optional): Device the image
                tensors are moved to. Defaults to ``"cpu"``.
        """
        self.img_labels = pd.read_csv(annotations_file, header=None)
        self.img_dir = img_dir
        self.transform = transform
        self.device = torch.device(device if device else "cpu")

        # Binary labels: 1 for the positive class codes, 0 for everything else.
        self.img_labels_bin = (
            self.img_labels.iloc[:, 1].isin(self.POSITIVE_CLASSES).astype("int64")
        )

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.img_labels)

    @staticmethod
    def _to_unit_float(image_stack):
        """Convert an image array to float32, scaling integer data to [0, 1].

        Integer arrays are divided by the maximum value of their dtype.
        Non-integer arrays are only cast, because ``np.iinfo`` is undefined
        for them and they carry no dtype-implied range.
        """
        if np.issubdtype(image_stack.dtype, np.integer):
            return image_stack.astype(np.float32) / np.iinfo(image_stack.dtype).max
        return image_stack.astype(np.float32)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is a float32 tensor on
            ``self.device`` with a leading channel dimension of size 1 and at
            most ``MAX_FRAMES`` entries along the next dimension, and
            ``label`` is a 0-dim integer tensor (0 or 1).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(str(self.img_dir), str(self.img_labels.iloc[idx, 0]))

        # Positional lookup, consistent with the file-name lookup above.
        label = self.img_labels_bin.iloc[idx]

        # tifffile reads the whole stacked image in one call.
        image_stack = self._to_unit_float(tif.imread(img_path))
        image_stack_tensor = torch.from_numpy(image_stack).to(self.device)

        # Trim to a fixed maximum depth; shorter stacks are left untouched.
        image_stack_tensor = image_stack_tensor[: self.MAX_FRAMES]

        # Add the channel dimension at position 0.
        image_stack_tensor = image_stack_tensor.unsqueeze(0)

        label = torch.tensor(label)

        # The transform must be able to handle the full stacked tensor.
        if self.transform is not None:
            image_stack_tensor = self.transform(image_stack_tensor)

        return image_stack_tensor, label

    def show_stack(self, idx):
        """
        Display the frames of one sample in a 5x5 grid.

        Up to 25 frames of the stack are plotted; unused grid cells are left
        blank.

        Parameters:
        - idx: Index of the sample whose image stack is displayed.

        Returns:
        None. This function directly displays the plotted images.
        """
        img, _label = self[idx]

        # Drop the channel dimension, whatever the stack's depth/height/width.
        frames = img[0].cpu().numpy()

        _, axes = plt.subplots(nrows=5, ncols=5, figsize=(15, 15))

        for i, ax in enumerate(axes.flatten()):
            if i < len(frames):
                ax.imshow(frames[i], cmap="gray")
                ax.set_title(f"Frame {i+1}")
            ax.axis("off")  # Hides ticks on used cells and blanks unused ones

        plt.tight_layout()
        plt.show()

## Load config

In [None]:
import yaml


def load_config(config_path: str = "../config/base.yaml") -> dict:
    """
    Loads a YAML configuration file.

    Args:
        config_path (str, optional): The file path to the YAML configuration file.

    Returns:
        dict: The configuration settings as a dictionary. An empty YAML
        document yields an empty dictionary.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default encoding.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document; honour the dict contract.
    if config is None:
        return {}
    return config

## train

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from dataset import CellDivisionDataset
import torch
import torch.optim as optim


def train_model(config, model):
    """Train ``model`` on the cell-division dataset described by ``config``.

    Args:
        config (dict): Configuration mapping. Keys read here:
            ``config["data"]["annotations_path"]``, ``config["data"]["img_path"]``,
            ``config["env"]["device"]``, ``config["train"]["size"]`` (fraction of
            samples used for training) and ``config["train"]["epochs"]``.
            ``config["train"]["batch_size"]`` and ``config["train"]["lr"]`` are
            optional and default to 32 and 0.001.
        model (torch.nn.Module): Model to train. Its output is fed to
            ``BCELoss``, so it must produce probabilities.

    Returns:
        torch.nn.Module: The trained model, moved to the configured device.
    """
    device = config["env"]["device"]
    train_cfg = config["train"]

    dataset = CellDivisionDataset(
        config["data"]["annotations_path"],
        config["data"]["img_path"],
        transform=None,
        device=device,
    )

    # Split the dataset into training and validation subsets.
    train_size = int(train_cfg["size"] * len(dataset))
    val_size = len(dataset) - train_size

    # NOTE(review): the generator lives on the configured device; this
    # presumably relies on the default tensor device matching it — confirm.
    generator = torch.Generator(device=device)

    # The validation subset is produced by the split but not used yet.
    train_dataset, _val_dataset = random_split(
        dataset, [train_size, val_size], generator=generator
    )

    # need to explicitly pass the generator to the dataloader for mps to work
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg.get("batch_size", 32),
        shuffle=True,
        generator=generator,
    )

    model = model.to(device)
    model.train()

    # Binary cross-entropy on probabilities.
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=train_cfg.get("lr", 0.001))

    # Return the trained model so callers can rebind it.
    return train_loop(
        train_cfg["epochs"], model, criterion, optimizer, train_loader, config
    )


def train_loop(num_epochs, model, criterion, optimizer, train_loader, config):
    """Run the optimisation loop and return the trained model.

    Args:
        num_epochs (int): Number of passes over ``train_loader``.
        model (torch.nn.Module): Model being optimised.
        criterion (callable): Loss taking ``(predictions, targets)``.
        optimizer (torch.optim.Optimizer): Optimizer bound to the model's
            parameters.
        train_loader (iterable): Yields ``(inputs, labels)`` batches.
        config (dict): Configuration; only ``config["env"]["device"]`` is read.

    Returns:
        torch.nn.Module: The same ``model`` object after training.
    """
    device = config["env"]["device"]

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            # Make sure inputs and labels are on the same device as the model
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            # Flatten predictions to one value per sample. A bare .squeeze()
            # would also drop the batch dimension for a batch of size 1 and
            # make the loss reject the shape mismatch with the labels.
            loss = criterion(outputs.reshape(-1), labels.float())

            loss.backward()
            optimizer.step()

            # Report the mean loss of every 10 mini-batches
            running_loss += loss.item()
            if i % 10 == 9:
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 10))
                running_loss = 0.0

    print("Finished Training")
    return model

## Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Simple3DCNN(nn.Module):
    """Small 3D CNN for binary classification of single-channel image stacks.

    Three ``Conv3d`` + ReLU + ``MaxPool3d`` stages are followed by two fully
    connected layers and a sigmoid, so the output is one probability per
    sample.
    """

    def __init__(self, input_shape=(20, 75, 75)):
        """
        Args:
            input_shape (tuple of int, optional): ``(depth, height, width)``
                of the input stacks, excluding batch and channel dimensions.
                Used to size the first fully connected layer.
        """
        super().__init__()
        # 3D Convolutional layers (kernel 3, padding 1: spatial size preserved)
        self.conv1 = nn.Conv3d(
            in_channels=1, out_channels=16, kernel_size=(3, 3, 3), stride=1, padding=1
        )
        self.conv2 = nn.Conv3d(
            in_channels=16, out_channels=32, kernel_size=(3, 3, 3), stride=1, padding=1
        )
        self.conv3 = nn.Conv3d(
            in_channels=32, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=1
        )

        # Pooling layer (halves each dimension, rounding down)
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2, padding=0)

        # The pool is applied three times, so each dimension is floor-halved
        # three times; for the default shape this gives 64 * 2 * 9 * 9.
        depth, height, width = (dim // 2 // 2 // 2 for dim in input_shape)
        self.num_features = 64 * depth * height * width

        # Fully connected layers for classification. They are created once
        # here so the optimizer tracks their parameters for the whole run.
        self.fc1 = nn.Linear(in_features=self.num_features, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``.

        Args:
            x (torch.Tensor): Input of shape ``(batch, 1, depth, height, width)``
                matching the ``input_shape`` given at construction.
        """
        # Apply 3D convolutions followed by pooling
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Flatten everything but the batch dimension
        x = x.flatten(start_dim=1)

        # Apply fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        # Sigmoid turns the single logit into a probability
        return torch.sigmoid(x)


# Initialize the model
# Builds one network instance with freshly initialised weights and binds it
# to the module-level name `model`.
model = Simple3DCNN()

# training

In [None]:
# Load the Colab-specific settings from a YAML file (path is relative to the
# current working directory).
config = load_config("config/colab.yaml")
# Start from a fresh network instance.
model = Simple3DCNN()
# Run training; `model` is rebound to whatever train_model returns.
model = train_model(config, model)