In [1]:
# dataset.py
import os
import json
from PIL import Image
from torch.utils.data import Dataset

class EyeDiseaseDataset(Dataset):
    def __init__(self, image_dir, labels_path, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        with open(labels_path, 'r') as f:
            self.labels = json.load(f)

        self.image_names = list(self.labels.keys())

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        label = self.labels[image_name]
        image_path = os.path.join(self.image_dir, image_name)

        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label


In [3]:
# train.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from dataset import EyeDiseaseDataset  # 👈 Import from separate file

if __name__ == '__main__':  # 👈 Required for multiprocessing
    dataset_path = '/Users/balmukundmishra/Desktop/2025-Learning/Eye_Disease_Detection_MTL/Preprocessed_Data/classification'
    labels_path = os.path.join(dataset_path, 'labels.json')

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    dataset = EyeDiseaseDataset(dataset_path, labels_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = resnet18(pretrained=True)
    num_classes = max(dataset.labels.values()) + 1
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        model.train()
        total_loss = 0

        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")


ModuleNotFoundError: No module named 'dataset'