In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import kagglehub
import os
from sklearn.metrics import classification_report

dataset_path = kagglehub.dataset_download("kondwani/eye-disease-dataset")
BASE_DIR = os.path.join(dataset_path, "Eye_diseases")
print("Dataset path:", BASE_DIR)

Dataset path: /kaggle/input/eye-disease-dataset/Eye_diseases


In [None]:
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=3,
            stride=stride, padding=1, groups=in_channels, bias=False
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.depthwise(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pointwise(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x
class MobileNetV1(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.layers = nn.Sequential(
            DepthwiseSeparableConv(32, 64, stride=1),
            DepthwiseSeparableConv(64, 128, stride=2),
            DepthwiseSeparableConv(128, 128, stride=1),
            DepthwiseSeparableConv(128, 256, stride=2),
            DepthwiseSeparableConv(256, 256, stride=1),
            DepthwiseSeparableConv(256, 512, stride=2),
            *[DepthwiseSeparableConv(512, 512) for _ in range(5)],
            DepthwiseSeparableConv(512, 1024, stride=2),
            DepthwiseSeparableConv(1024, 1024, stride=1),
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(1024, num_classes)
    def forward(self, x):
        x = self.stem(x)
        x = self.layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

full_dataset = datasets.ImageFolder(BASE_DIR, transform=transform)
loader = DataLoader(full_dataset, batch_size=32, shuffle=True, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MobileNetV1(num_classes=len(full_dataset.classes)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

best_loss = float('inf')
model_save_path = "best_model_eye.pth"
num_epochs = 10
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(full_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f}")
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), model_save_path)


Epoch [1/10] Loss: 1.4984
Epoch [2/10] Loss: 1.2959
Epoch [3/10] Loss: 1.2204
Epoch [4/10] Loss: 1.0593
Epoch [5/10] Loss: 0.9286
Epoch [6/10] Loss: 0.6839
Epoch [7/10] Loss: 0.5743
Epoch [8/10] Loss: 0.6165
Epoch [9/10] Loss: 0.4266
Epoch [10/10] Loss: 0.2724


In [None]:
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in loader:
        images = images.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=full_dataset.classes))


Classification Report:
              precision    recall  f1-score   support

Bulging_Eyes       0.57      1.00      0.72        30
   Cataracts       0.98      0.87      0.92        47
Crossed_Eyes       0.98      0.91      0.95       174
    Glaucoma       0.94      0.93      0.93        82
     Uveitis       0.96      0.86      0.91        50

    accuracy                           0.91       383
   macro avg       0.88      0.91      0.89       383
weighted avg       0.94      0.91      0.92       383

