In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import timm  # Make sure to install via: pip install timm

In [3]:
# 1. Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Data Transformations & Loading
# Custom image size and transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT expects 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Use ImageFolder pointing to your own data
train_dataset = ImageFolder(root='./dataset', transform=transform)
test_dataset = ImageFolder(root='./dataset', transform=transform)

# DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=32
)
test_loader = DataLoader(
    train_dataset,
    batch_size=32
)
# Class names can be accessed with:
print("Class labels:", train_dataset.classes)

Class labels: ['Healthy', 'Retinitis Pigmentosa']


In [4]:
print("Number of classes:", len(train_loader.dataset.classes))
print("Number of training samples:", len(train_loader.dataset))
print("Number of test samples:", len(test_loader.dataset))

Number of classes: 2
Number of training samples: 1637
Number of test samples: 1637


In [5]:
# 3. Load Pretrained Vision Transformer
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=2)
model = model.to(device)

# 4. Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)


In [6]:
# 5. Training Loop
def train_model(num_epochs=15):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")

In [7]:
# 6. Evaluation
def evaluate_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total:.2f}%")

In [8]:
# Run
train_model(num_epochs=5)
evaluate_model()

  x = F.scaled_dot_product_attention(
100%|██████████| 52/52 [01:08<00:00,  1.31s/it]


Epoch [1/5], Loss: 24.5361


100%|██████████| 52/52 [01:07<00:00,  1.29s/it]


Epoch [2/5], Loss: 106.6691


100%|██████████| 52/52 [01:07<00:00,  1.29s/it]


Epoch [3/5], Loss: 101.7175


100%|██████████| 52/52 [01:07<00:00,  1.30s/it]


Epoch [4/5], Loss: 139.3501


100%|██████████| 52/52 [01:07<00:00,  1.30s/it]


Epoch [5/5], Loss: 171.2874


100%|██████████| 52/52 [00:46<00:00,  1.12it/s]

Test Accuracy: 50.89%





In [9]:
# Save the trained model
torch.save(model.state_dict(), 'vit_model.pth')
print("Model saved as vit_model.pth ✅")

Model saved as vit_model.pth ✅


In [10]:
import torch
from torchvision import transforms
from PIL import Image
import timm

# Load the model (adjust path and class count if needed)
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=2)
model.load_state_dict(torch.load("vit_model.pth", map_location=torch.device('cpu')))  # or 'cuda'
model.eval()

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Transform for single image
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Inference function
def predict_image(image_path, class_names):
    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension

    # Forward pass
    with torch.no_grad():
        outputs = model(image_tensor)
        _, predicted = torch.max(outputs, 1)
        predicted_class = class_names[predicted.item()]

    return predicted_class

# Example usage
class_names = ['Healthy', 'Retinitis Pigmentosa']  # Adjust as per your classes
image_path = 'dataset\Retinitis Pigmentosa\Retinitis Pigmentosa2.jpg'

result = predict_image(image_path, class_names)
print(f"Predicted Class: {result}")


Predicted Class: Retinitis Pigmentosa


In [11]:
image_path = 'dataset\Healthy\Healthy6.jpg'

result = predict_image(image_path, class_names)
print(f"Predicted Class: {result}")

Predicted Class: Retinitis Pigmentosa
