In [1]:
from typing import Literal, Union

from dr_grading.datasets import STL10DataModule, GenericImageDataModule
from dr_grading.models import Swin_V2_S, LightningModelWrapper, load_state_from_ckpt

import numpy as np
import pandas as pd
import lightning as L
import torch
import h5py
from torchvision.transforms import v2

from tqdm import tqdm


# Select the "medium" float32 matmul precision mode: PyTorch may use faster,
# lower-precision internal arithmetic for float32 matrix multiplications.
torch.set_float32_matmul_precision("medium")
# Names of the dataloader stages accepted by the helper functions in this notebook.
StageType = Literal["train", "val", "test"]

## Functions

In [2]:
def save_dataset_to_hdf5(features, labels, filename):
    """Write a feature array and its label array into a new HDF5 file.

    Args:
        features: Array-like stored under the ``"features"`` dataset.
        labels: Array-like stored under the ``"labels"`` dataset.
        filename: Destination path. The file is opened in ``"w"`` mode, so an
            existing file at that path is replaced.
    """
    named_arrays = (("features", features), ("labels", labels))
    with h5py.File(filename, "w") as h5_file:
        for dataset_name, array in named_arrays:
            h5_file.create_dataset(dataset_name, data=array)


def load_dataset_from_hdf5(filename):
    """Read the ``"features"`` and ``"labels"`` datasets from an HDF5 file.

    Args:
        filename: Path of the HDF5 file to open read-only.

    Returns:
        A ``(features, labels)`` tuple. Both datasets are fully read with
        ``[:]`` before the file is closed.
    """
    with h5py.File(filename, "r") as h5_file:
        # Slice inside the ``with`` so the data is materialised while the file is open.
        return h5_file["features"][:], h5_file["labels"][:]

In [3]:
def get_transform(stage: StageType, image_size: int = 224) -> v2.Compose:
    """Build the torchvision v2 preprocessing pipeline for a dataloader stage.

    Every pipeline converts the input to an image tensor, resizes it to a
    square of ``image_size`` and finishes with a float32 conversion
    (``scale=True``). Only ``stage == "train"`` inserts random augmentation
    between the resize and the dtype conversion; any other value yields the
    plain pipeline.

    Args:
        stage: Dataloader stage the transform is meant for.
        image_size: Height and width (in pixels) passed to ``v2.Resize``.

    Returns:
        The composed transform.
    """
    target_size = (image_size, image_size)
    steps = [v2.ToImage(), v2.Resize(target_size)]

    if stage == "train":
        # Colour jitter and blur are wrapped together, so RandomApply either
        # runs both (p=0.3) or neither.
        photometric = torch.nn.ModuleList(
            [
                v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                v2.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
            ]
        )
        steps += [
            v2.RandomHorizontalFlip(),
            v2.RandomApply(photometric, p=0.3),
            v2.RandomAdjustSharpness(2, p=0.4),
            v2.RandomAutocontrast(p=0.4),
        ]
    # Disabled experiment kept for reference (non-train pipeline only):
    # FourierTransform(shift=True, return_abs=True),

    # Converts and normalizes to [0, 1]
    steps.append(v2.ToDtype(torch.float32, scale=True))
    return v2.Compose(steps)

In [4]:
def get_device(accelerator: Literal["auto", "gpu", "mps"] = "auto") -> torch.device:
    """Resolve an accelerator name to a ``torch.device``.

    Args:
        accelerator: ``"auto"`` prefers CUDA, then MPS, then CPU. ``"gpu"`` and
            ``"mps"`` select that backend only when it is available.

    Returns:
        The selected device. Any request that cannot be satisfied (backend
        unavailable, or an unrecognised name) falls back to the CPU.
    """
    if accelerator == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    if accelerator == "gpu" and torch.cuda.is_available():
        return torch.device("cuda")
    if accelerator == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")

    # Requested backend missing or name not recognised: use the CPU.
    return torch.device("cpu")


def get_dataloader(datamodule: L.LightningDataModule, stage: StageType):
    """Return the datamodule's dataloader for the requested stage.

    Args:
        datamodule: Object exposing ``train_dataloader``, ``val_dataloader``
            and ``test_dataloader``.
        stage: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        Whatever the matching ``*_dataloader()`` method returns.

    Raises:
        ValueError: If ``stage`` is not one of the three known names.
    """
    stage_to_factory = (
        ("train", "train_dataloader"),
        ("val", "val_dataloader"),
        ("test", "test_dataloader"),
    )
    for known_stage, factory_name in stage_to_factory:
        if stage == known_stage:
            # Only the matching factory is looked up and called.
            return getattr(datamodule, factory_name)()
    raise ValueError(f"Unknown stage: {stage}")


def get_latent(
    model: torch.nn.Module,
    datamodule: L.LightningDataModule,
    stage: StageType,
    epochs: int = 10,
    device: Literal["auto", "gpu", "mps"] = "auto",
) -> "tuple[np.ndarray, np.ndarray]":
    """
    Run the model over a dataloader and collect its outputs with their labels.

    The dataloader for ``stage`` is iterated ``epochs`` times with gradients
    disabled. Every batch's model output and labels are accumulated, so the
    result contains one row per sample per epoch. The model is moved to the
    selected device and switched to eval mode; neither is undone afterwards.

    Args:
        model (torch.nn.Module): The model to use for feature extraction.
        datamodule (L.LightningDataModule): The datamodule containing the dataset.
        stage (StageType): The stage of the dataloader ('train', 'val', 'test').
        epochs (int, optional): Number of passes over the dataloader. Defaults to 10.
        device (str, optional): The device to use ('auto', 'gpu', 'mps'). Defaults to 'auto'.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(latent_features, labels)``, each the
        concatenation along axis 0 of all collected batches.

    Raises:
        ValueError: If ``stage`` is unknown, or if no batch was processed
            (``epochs`` < 1 or an empty dataloader).
    """
    target_device = get_device(device)
    model.to(target_device)

    # Load state from checkpoint if needed
    # load_state_from_ckpt(model, ckpt_path)

    # Get dataloader
    dataloader = get_dataloader(datamodule, stage)

    # Set model to evaluation mode and extract features
    model.eval()
    feature_list = []
    label_list = []

    with torch.no_grad():
        for epoch in range(epochs):
            for batch in tqdm(dataloader, desc=f"Epochs: {epoch}", leave=False):
                images, labels = batch
                outputs = model(images.to(target_device))
                # Keep everything on the CPU so the final .numpy() conversion
                # works regardless of where the batch tensors live.
                feature_list.append(outputs.cpu())
                label_list.append(labels.cpu())

    if not feature_list:
        # torch.cat rejects an empty list with an unhelpful error; fail clearly instead.
        raise ValueError(
            f"No batches were processed for stage {stage!r} "
            f"(epochs={epochs}); nothing to concatenate."
        )

    latent_features = torch.cat(feature_list, dim=0)
    all_labels = torch.cat(label_list, dim=0)

    return (latent_features.numpy(), all_labels.numpy())

## Extract the APTOS2019 data

In [5]:
# Number of output classes handed to the model constructor below.
num_classes = 5
# Get transform
# Both pipelines resize to 512x512; only the "train" one adds random augmentation.
train_transform = get_transform("train", image_size=512)
test_transform = get_transform("test", image_size=512)

# Instantiate datamodule
# NOTE(review): absolute, machine-specific Windows path — the notebook will not
# run elsewhere without editing it; consider a configurable DATA_DIR.
data_dir = r"D:\Aj_Aof_Work\OCT_Disease\DATASET\APTOS2019_V4"
datamodule = GenericImageDataModule(
    data_dir=data_dir, 
    batch_size=8, 
    num_workers=8,
    train_transform=train_transform,
    test_transform=test_transform,
    )
# setup() is called without a stage argument.
# NOTE(review): presumably this prepares the train/val/test splits — confirm
# against dr_grading.datasets.GenericImageDataModule.
datamodule.setup()

# NOTE(review): presumably transfer=True loads pretrained weights and
# return_latent=True makes forward() return latent features instead of class
# logits — confirm against dr_grading.models.Swin_V2_S. No checkpoint is
# loaded in this notebook.
model = Swin_V2_S(num_classes=num_classes, transfer=True, return_latent=True)

In [None]:
# def FeatureExtraction(
#     model: LightningModelWrapper,
#     datamodule: L.LightningDataModule,
#     stage: StageType = "train",
#     epochs: int = 10,
#     device: Literal["auto", "gpu", "mps"] = "auto",
# ) -> None:
#     """
#     Extract features from the model.

#     Args:
#         model (LightningModelWrapper): The model to use for feature extraction.
#         datamodule (L.LightningDataModule): The datamodule containing the dataset.
#         stage (StageType, optional): The stage of the dataloader ('train', 'validate', 'test'). Defaults to 'train'.
#         epochs (int, optional): Number of epochs to run. Defaults to 10.
#         device (str, optional): The device to use ('auto', 'gpu', 'mps'). Defaults to 'auto'.
#     """
#     latent_features, labels = get_latent(model, datamodule, stage, epochs, device)
#     save_dataset_to_hdf5(latent_features, labels, f"{stage}_features.h5")

In [6]:
# Five passes over the train dataloader; each pass appends one output row per sample.
# NOTE(review): assumes the datamodule applies train_transform (random
# augmentation) to this split, so repeated passes give different features — confirm.
train_latent_feature, train_labels = get_latent(model, datamodule, stage="train", epochs=5)

                                                            

In [7]:
# Five passes over the val dataloader.
# NOTE(review): if the val pipeline is deterministic (no random augmentation),
# passes 2-5 only produce duplicate rows that drop_duplicates() removes later —
# confirm whether epochs=1 would be sufficient here.
val_latent_feature, val_labels = get_latent(model, datamodule, stage="val", epochs=5)

                                                          

In [8]:
# Five passes over the test dataloader.
# NOTE(review): if the test pipeline is deterministic (no random augmentation),
# passes 2-5 only produce duplicate rows that drop_duplicates() removes later —
# confirm whether epochs=1 would be sufficient here.
test_latent_feature, test_labels = get_latent(model, datamodule, stage="test", epochs=5)

                                                          

In [9]:
def post_process_latent_features(latent_features, labels):
    """Combine features and labels into one de-duplicated DataFrame.

    Args:
        latent_features: 2-D array-like; each row becomes one DataFrame row and
            the feature columns keep their default integer names.
        labels: One label per row, stored in a ``"labels"`` column.

    Returns:
        A DataFrame with exact duplicate rows (features and label identical)
        removed and a fresh 0..n-1 index.
    """
    feature_table = pd.DataFrame(latent_features)
    feature_table["labels"] = labels
    # Rows count as duplicates only when every feature AND the label match.
    deduplicated = feature_table.drop_duplicates()
    return deduplicated.reset_index(drop=True)

In [10]:
# Build one de-duplicated DataFrame per split (feature columns plus a "labels" column).
train_df = post_process_latent_features(train_latent_feature, train_labels)
val_df = post_process_latent_features(val_latent_feature, val_labels)
test_df = post_process_latent_features(test_latent_feature, test_labels)

In [11]:
def save_to_csv_hdf5(df, filename):
    """Persist a feature DataFrame as both ``<filename>.csv`` and ``<filename>.h5``.

    Args:
        df: DataFrame containing a ``"labels"`` column; every other column is
            treated as a feature.
        filename: Output path without extension.
    """
    csv_path = f"{filename}.csv"
    h5_path = f"{filename}.h5"

    # CSV keeps the whole table (features and labels) without the index column.
    df.to_csv(csv_path, index=False)

    # HDF5 stores labels and features as two separate arrays.
    label_column = df['labels'].values
    feature_matrix = df.drop(columns=['labels']).values
    save_dataset_to_hdf5(feature_matrix, label_column, h5_path)

In [12]:
# Write train/val/test .csv and .h5 files; the names are relative, so they land
# in the current working directory.
save_to_csv_hdf5(train_df, "train")
save_to_csv_hdf5(val_df, "val")
save_to_csv_hdf5(test_df, "test")