In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, Subset, BatchSampler, WeightedRandomSampler
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision import models
from torchvision.models import VGG16_Weights
import torch.optim as optim
from PIL import Image
import os
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import time

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


# DATA PREPARATION

In [3]:
class MultiInputDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.labels = []
        self.label_map = {'KCN':0, 'NOR':1, 'SUSP':2}

        #class loop
        for class_name in os.listdir(root_dir):
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            for case_name in os.listdir(class_path):
                case_path = os.path.join(class_path, case_name)
                if os.path.isdir(case_path):
                    #pick prefix
                    sample_files = os.listdir(case_path)
                    if sample_files:
                        prefix = sample_files[0].split('_')[0] 
                        case_number = sample_files[0].split('_')[1]
                        case_prefix = f"{prefix}_{case_number}"
                        self.samples.append((case_path, case_prefix, prefix))
                        self.labels.append(self.label_map[prefix])
        self.labels = np.array(self.labels)
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        case_path, case_prefix, prefix = self.samples[idx]

        #akhiran nama file
        feature_suffixes = [
            "_CT_A.jpg", "_EC_A.jpg", "_EC_P.jpg", "_Elv_A.jpg","_Elv_P.jpg", "_Sag_A.jpg", "_Sag_P.jpg"
        ]

        feature_images = []
        for suffix in feature_suffixes :
            filename = f"{case_prefix}{suffix}"
            img_path = os.path.join(case_path, filename)
            img = Image.open(img_path).convert("RGB")
            if self.transform :
                img = self.transform(img)
            feature_images.append(img)

        #shape : (7, 3, H, W)
        stacked = torch.stack(feature_images, dim=0)
        label = self.label_map[prefix]
        return stacked, label

In [4]:
def get_default_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [5]:
def load_dataset(data_dir, transform=None):
    #create and return teh dataset instance
    try :
        dataset = MultiInputDataset(data_dir, transform)
        print(f'Dataset loaded succesfully with {len(dataset)} samples')
        return dataset
    except Exception as e:
        print(f'Error loading dataset : {str(e)}')
        raise

In [6]:
def create_train_val_split(dataset, val_size=0.2, random_state=42):
    #splitting dataset with statification
    try :
        train_idx, val_idx = train_test_split(
            range(len(dataset)),
            test_size = val_size,
            stratify = dataset.labels,
            random_state = random_state
        )
        print(f'Split created : {len(train_idx)} training samples ; {len(val_idx)} validation samples')
        return train_idx, val_idx
    except Exception as e:
        print(f'Error Splitting : {str(e)}')
        raise

In [7]:
def create_dataloaders(train_dataset, val_dataset, batch_size=8, num_workers=2):
    #create and return train and val dataloaders
    try :
        train_loader = DataLoader(
            train_dataset,
            batch_size = batch_size,
            shuffle = True,
            num_workers = num_workers,
            pin_memory = True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size = batch_size,
            shuffle = False,
            num_workers = num_workers,
            pin_memory = True
        )
        print(f'DataLoaders created with batch size {batch_size}')
        return train_loader, val_loader
    except Exception as e:
        print(f'Error creating DataLoaders : {str(e)}')
        raise

In [8]:
# def prepare_dataloaders(data_dir, batch_size=8, test_size=0.2, random_state=42):

#     print(f'Loading dara from : {os.path.abspath(data_dir)}')
#     transform = get_default_transform()
#     full_dataset = MultiInputDataset(root_dir = data_dir, transform=transform)
#     print(f"Dataset created with {len(full_dataset)} samples")
#     if len(full_dataset) == 0:
#         print(f'WARNING : Dataset is empty!')
#         return None, None

#     train_idx, val_idx = train_test_split(
#         list(range(len(full_dataset))),
#         test_size=test_size,
#         random_state=random_state,
#         shuffle=True
#     )

#     train_dataset = Subset(full_dataset, train_idx)
#     val_dataset = Subset(full_dataset, val_idx)

#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
#     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

#     return train_loader, val_loader

In [9]:
def test_sample_batch(loader):
    images, labels = next(iter(loader))
    print("Batch image shape : ", images.shape)
    print("Labels : ", labels)

In [10]:
# data_dir = "/kaggle/input/cornealtopography/Train_Validation sets/Train_Validation sets"

# print("Folders : ", os.listdir(data_dir))
# for class_name in os.listdir(data_dir):
#     class_path = os.path.join(data_dir, class_name)
#     print(f"{class_name} -> {len(os.listdir(class_path))} case")

# train_loader, val_loader = prepare_dataloaders(data_dir, batch_size=8)
# test_sample_batch(train_loader)

In [11]:
def main() : 
    data_dir = "/kaggle/input/cornealtopography/Train_Validation sets/Train_Validation sets"
    try :
        print("Folders : ", os.listdir(data_dir))
        for class_name in os.listdir(data_dir):
            class_path = os.path.join(data_dir, class_name)
            print(f"{class_name} -> {len(os.listdir(class_path))} case")

        dataset = load_dataset(data_dir)
        train_idx, val_idx = create_train_val_split(dataset)

        train_dataset = Subset(dataset, train_idx)
        val_dataset = Subset(dataset, val_idx)

        train_loader, val_loader = create_dataloaders(train_dataset, val_dataset)

        test_sample_batch (train_loader)
        return train_loader, val_loader

    except Exception as e :
        print(f'Error in data preparation : {str(e)}')
        return None, None

In [12]:
if __name__ == "__main__" :
    traind_loader, val_loader = main()

Folders :  ['Keratoconus', 'Normal', 'Suspect']
Keratoconus -> 150 case
Normal -> 150 case
Suspect -> 123 case
Dataset loaded succesfully with 423 samples
Split created : 338 training samples ; 85 validation samples
DataLoaders created with batch size 8
Error in data preparation : Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 50, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataset.py", line 420, in __getitems__
    return [self.dataset[self.indices[idx]] for idx in indices]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataset.py", line 420, in <listcomp>
    return [sel