# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import tifffile as tiff
import pandas as pd
import numpy as np
import os

# Define dataset: Tiff stack with two labels
class TiffDataset(Dataset):
    """Dataset of TIFF stacks, each paired with a (rotation, angle) label.

    The annotations CSV is loaded with ``pandas.read_csv`` (default header
    inference, so the first row is treated as column names). For each row:
    column 0 is the image path relative to ``root_dir``, column 1 is the
    rotation direction (the exact string ``'clockwise'`` maps to 1, anything
    else to 0), and column 2 is converted with ``int()`` to an angle class.

    Args:
        csv_file: Path (or file-like object) of the annotations CSV.
        root_dir: Directory the image paths in column 0 are relative to.
        transform: Optional callable applied to the image tensor after it has
            been converted to float32, normalised and given a channel dim.
        scale: Divisor used to normalise raw pixel values. The default of
            255.0 matches the previous behaviour; pass e.g. 65535.0 for
            16-bit stacks.
    """

    def __init__(self, csv_file, root_dir, transform=None, scale=255.0):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.scale = scale

    def __len__(self):
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(image, label)``. ``image`` is a float32 tensor with a leading
            channel dimension of size 1 followed by the dimensions of the
            stored array (unless ``transform`` changes its shape). ``label``
            is an int64 tensor ``[rotation_class, angle_class]``.
        """
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        image = self._to_float_tensor(tiff.imread(img_path), self.scale)

        # Applied after conversion, so the transform always receives a tensor.
        if self.transform is not None:
            image = self.transform(image)
            # A transform may change the dtype; keep the float32 contract.
            if torch.is_tensor(image) and image.dtype != torch.float32:
                image = image.to(torch.float32)

        return image, self._labels(index)

    @staticmethod
    def _to_float_tensor(array, scale=255.0):
        """Convert a raw pixel array to a float32 tensor divided by ``scale``, with a channel dim prepended."""
        # Cast in NumPy first: torch.from_numpy cannot wrap every dtype a TIFF
        # may hold (e.g. big-endian data) nor negative-stride arrays, while a
        # contiguous native float32 array is always accepted.
        data = np.ascontiguousarray(array, dtype=np.float32)
        return (torch.from_numpy(data) / scale).unsqueeze(0)

    def _labels(self, index):
        """Build the int64 label tensor ``[rotation_class, angle_class]`` for row ``index``."""
        rotation_class = 1 if self.annotations.iloc[index, 1] == 'clockwise' else 0
        angle_class = int(self.annotations.iloc[index, 2])
        # Two integer class targets stored side by side (not one-hot encoded).
        return torch.tensor([rotation_class, angle_class], dtype=torch.long)

# No transform is passed: TiffDataset already converts each stack to a float
# tensor with a leading channel dimension.
dataset = TiffDataset(csv_file='dataset/slice/labels_slice.csv', root_dir='dataset/')

TRAIN_FRACTION = 0.6  # share of samples used for training
SPLIT_SEED = 42       # fixes which samples land in the train/test sets
BATCH_SIZE = 360

# The test set takes the remainder so the two sizes always sum to len(dataset).
train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size

# Seed the split so re-running this cell yields the same partition; an
# unseeded random_split reshuffles on every run, which can put samples a
# model was already trained on into the new test set.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Shuffle training batches; keep test order fixed for repeatable evaluation.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)