In [41]:
import torch
from torch.utils.data import Dataset, DataLoader, Subset, BatchSampler, WeightedRandomSampler
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision import models
from torchvision.models import VGG16_Weights
import torch.optim as optim
from PIL import Image
import os
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import time
from collections import defaultdict

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


# DATA PREPARATION

In [43]:
class MultiInputDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform if transform else get_default_transform()
        self.samples = []
        self.labels = []
        self.label_map = {'KCN':0, 'NOR':1, 'SUSP':2}

        #class loop
        for class_name in os.listdir(root_dir):
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            for case_name in os.listdir(class_path):
                case_path = os.path.join(class_path, case_name)
                if os.path.isdir(case_path):
                    #pick prefix
                    sample_files = os.listdir(case_path)
                    if sample_files:
                        prefix = sample_files[0].split('_')[0] 
                        case_number = sample_files[0].split('_')[1]
                        case_prefix = f"{prefix}_{case_number}"
                        self.samples.append((case_path, case_prefix, prefix))
                        self.labels.append(self.label_map[prefix])
        self.labels = np.array(self.labels)
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        case_path, case_prefix, prefix = self.samples[idx]

        #akhiran nama file
        feature_suffixes = [
            "_CT_A.jpg", "_EC_A.jpg", "_EC_P.jpg", "_Elv_A.jpg","_Elv_P.jpg", "_Sag_A.jpg", "_Sag_P.jpg"
        ]

        feature_images = []
        for suffix in feature_suffixes :
            filename = f"{case_prefix}{suffix}"
            img_path = os.path.join(case_path, filename)
            img = Image.open(img_path).convert("RGB")
            img = self.transform(img)
            feature_images.append(img)

        #shape : (7, 3, H, W)
        stacked = torch.stack(feature_images, dim=0)
        label = self.label_map[prefix]
        return stacked, label

In [44]:
def balance_batches(labels, batch_size):
    try:
        num_classes = len(np.unique(labels))
        class_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            class_indices[label].append(idx)
        min_class_len = min(len(idxs) for idxs in class_indices.values())
        batches = []
        for i in range (0, min_class_len, batch_size // num_classes) :
            batch = []
            for c in range(num_classes) : 
                batch.extend(class_indices[c][i:i + batch_size // num_classes])
            if len (batch) == batch_size:
                batches.append(batch)
        return batches
    except Exception as e :
        print(f'Failed to create balance batches : {str(e)}')
        raise

In [45]:
def get_default_transform():
    try:
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
    except Exception as e :
        print(f'Error in transform dataset : {str(e)}')
        raise

In [46]:
def load_dataset(data_dir, transform=None):
    #create and return teh dataset instance
    try :
        dataset = MultiInputDataset(data_dir, transform)
        print(f'Dataset loaded succesfully with {len(dataset)} samples')
        return dataset
        
    except Exception as e:
        print(f'Error loading dataset : {str(e)}')
        raise

In [47]:
def create_train_val_split(dataset, val_size=0.2, random_state=42):
    #splitting dataset with statification
    try :
        train_idx, val_idx = train_test_split(
            range(len(dataset)),
            test_size = val_size,
            stratify = dataset.labels,
            random_state = random_state
        )
        print(f'Split created : {len(train_idx)} training samples ; {len(val_idx)} validation samples')
        return train_idx, val_idx
        
    except Exception as e:
        print(f'Error Splitting : {str(e)}')
        raise

In [48]:
def create_dataloaders(train_dataset, val_dataset, batch_size=8, num_workers=2):
    #create and return train and val dataloaders
    try :
        train_loader = DataLoader(
            train_dataset,
            batch_size = batch_size,
            shuffle = True,
            num_workers = num_workers,
            pin_memory = True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size = batch_size,
            shuffle = False,
            num_workers = num_workers,
            pin_memory = True
        )
        print(f'DataLoaders created with batch size {batch_size}')
        return train_loader, val_loader
    except Exception as e:
        print(f'Error creating DataLoaders : {str(e)}')
        raise

In [49]:
# def prepare_dataloaders(data_dir, batch_size=8, test_size=0.2, random_state=42):

#     print(f'Loading dara from : {os.path.abspath(data_dir)}')
#     transform = get_default_transform()
#     full_dataset = MultiInputDataset(root_dir = data_dir, transform=transform)
#     print(f"Dataset created with {len(full_dataset)} samples")
#     if len(full_dataset) == 0:
#         print(f'WARNING : Dataset is empty!')
#         return None, None

#     train_idx, val_idx = train_test_split(
#         list(range(len(full_dataset))),
#         test_size=test_size,
#         random_state=random_state,
#         shuffle=True
#     )

#     train_dataset = Subset(full_dataset, train_idx)
#     val_dataset = Subset(full_dataset, val_idx)

#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
#     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

#     return train_loader, val_loader

In [50]:
def test_sample_batch(loader):
    try :
        images, labels = next(iter(loader))
        print("Batch image shape : ", images.shape)
        print("Labels : ", labels)
        
    except Exception as e :
        print(f'Error in testing sample batch : {str(e)}')
        raise

 # MODEL SETUP

In [51]:
def get_model_vgg16(num_classes = 3, device = None):
    try :
        print(f'Preparing VGG16 for {num_classes} classes...')
        vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        for param in vgg16.parameters():
            param.requires_grad = False
        vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)
        for param in vgg16.classifier.parameters():
            param.requires_grad = True
        if device is not None:
            vgg16 = vgg16.to(device)
        print(f'Model ready!')
    
        return vgg16
    except Exception as e:
        print(f'Error in model setup : {str(e)}')
        raise

In [52]:
def main() : 
    data_dir = "/kaggle/input/cornealtopography/Train_Validation sets/Train_Validation sets"
    try :
        print("Folders : ", os.listdir(data_dir))
        for class_name in os.listdir(data_dir):
            class_path = os.path.join(data_dir, class_name)
            print(f"{class_name} -> {len(os.listdir(class_path))} case")

        #data preparation
        dataset = load_dataset(data_dir)
        train_idx, val_idx = create_train_val_split(dataset)

        train_dataset = Subset(dataset, train_idx)
        val_dataset = Subset(dataset, val_idx)

        #balance batches
        # train_loader, val_loader = create_dataloaders(train_dataset, val_dataset)
        batches = balance_batches(dataset.labels[train_idx], batch_size=8)
        train_loader =  DataLoader(Subset(dataset, sum(batches, [])), batch_size=8, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
        test_sample_batch (train_loader)

        #model setup
        model = get_model_vgg16(num_classes = 3, device=device)
        
        return train_loader, val_loader, model

    except Exception as e :
        print(f'Error in main workflow : {str(e)}')
        return None, None, None

In [53]:
if __name__ == "__main__" :
    traind_loader, val_loader, model = main()

Folders :  ['Keratoconus', 'Normal', 'Suspect']
Keratoconus -> 150 case
Normal -> 150 case
Suspect -> 123 case
Dataset loaded succesfully with 423 samples
Split created : 338 training samples ; 85 validation samples
Error in testing sample batch : 
Error in main workflow : 
