In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import time

# --------------------------
# 1. Device Setup
# --------------------------
# Pick the fastest backend that is actually usable: a CUDA GPU first, then
# Apple-Silicon MPS, and finally the CPU.  getattr() guards against PyTorch
# builds that do not ship the torch.backends.mps module at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --------------------------
# 2. Dataset Paths & Augmentation
# --------------------------
# Root folder handed to ImageFolder (expected to contain the 20 class folders).
data_dir = "dataset"

# Random, label-preserving perturbations applied after the resize.
_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
]
# Tensor conversion plus per-channel normalisation (presumably the ImageNet
# statistics that go with the pretrained weights -- confirm if they change).
_to_normalised_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(
    [transforms.Resize((224, 224)), *_augmentations, *_to_normalised_tensor]
)

# NOTE: every sample read through full_dataset (or any subset of it) goes
# through the random augmentations above.
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# --------------------------
# 3. Print dataset info
# --------------------------
from collections import Counter  # only needed for this summary

# Tally every label in one pass.  Re-scanning full_dataset.targets once per
# class costs (number of classes) x (number of images) comparisons, which is
# needlessly slow on a large dataset.
images_per_class = Counter(full_dataset.targets)

print(f"Found classes: {full_dataset.classes}")
for idx, class_name in enumerate(full_dataset.classes):
    # A Counter yields 0 for a label that never occurs, so empty classes
    # are still reported.
    class_count = images_per_class[idx]
    print(f"Class '{class_name}' has {class_count} images")
print(f"Total images: {len(full_dataset)}")

# --------------------------
# 4. Train/Validation Split
# --------------------------
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Seed the split so the same images land in train/validation on every run;
# with an unseeded split, a checkpoint reloaded in a later session could be
# "validated" on images it was trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, _val_split = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# full_dataset applies random flips / rotations / colour jitter to every
# sample it returns, so using its validation subset directly would augment
# the validation images too and make the reported metrics noisy.  Instead,
# read the held-out indices through a second ImageFolder over the same
# directory whose transform only resizes and normalises.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
eval_dataset = datasets.ImageFolder(root=data_dir, transform=eval_transform)
# The indices are only meaningful if both scans produced the same file list.
if eval_dataset.samples != full_dataset.samples:
    raise RuntimeError(
        f"Contents of '{data_dir}' changed between dataset scans; "
        "cannot build a consistent validation subset"
    )
val_dataset = torch.utils.data.Subset(eval_dataset, _val_split.indices)
print(f"Train images: {train_size}, Validation images: {val_size}")

# --------------------------
# 5. DataLoaders
# --------------------------
# Options shared by both loaders: batches of 32, loaded with no worker processes.
_loader_options = {"batch_size": 32, "num_workers": 0}
# Only the training data is shuffled; validation keeps its order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

# Smoke test: pull a single training batch and report the tensor shapes.
_first_batch = next(iter(train_loader))
images, labels = _first_batch
print(f"Sample batch images shape: {images.shape}, labels shape: {labels.shape}")

# --------------------------
# 6. Model (EfficientNet-B0)
# --------------------------
# Start from the IMAGENET1K_V1 EfficientNet-B0 checkpoint.
_pretrained_weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
model = models.efficientnet_b0(weights=_pretrained_weights)

# Replace the final linear layer with one that has an output per dataset
# class, keeping the input width of the layer being replaced.
_old_head = model.classifier[1]
num_ftrs = _old_head.in_features
_num_classes = len(full_dataset.classes)  # 20 classes
model.classifier[1] = nn.Linear(num_ftrs, _num_classes)

# Move the whole network onto the selected device.
model = model.to(device)

# --------------------------
# 7. Loss & Optimizer
# --------------------------
# Cross-entropy computed on the raw model outputs and integer class labels.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the model (nothing is frozen) at lr = 1e-3.
_all_params = model.parameters()
optimizer = optim.Adam(_all_params, lr=1e-3)

# --------------------------
# 8. Training Loop
# --------------------------
# Each epoch makes one pass over train_loader (with weight updates) followed
# by one pass over val_loader (no gradients).  Whenever validation accuracy
# improves, the model weights are written to "best_model.pth".
#
# Running statistics are kept as plain Python numbers, not tensors: calling
# .double() on an on-device counter raises on MPS (the backend has no
# float64), and plain floats format directly in the f-strings below.
epochs = 10
best_acc = 0.0
start_time = time.time()

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    print("-" * 30)

    # --- Training ---
    model.train()
    running_loss, running_corrects, train_seen = 0.0, 0, 0
    train_bar = tqdm(train_loader, desc="Training", leave=False)
    for inputs, labels in train_bar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        n_batch = inputs.size(0)
        # criterion returns the batch mean; weight it by the batch size so a
        # short final batch does not skew the epoch average.
        running_loss += loss.item() * n_batch
        running_corrects += (preds == labels).sum().item()
        train_seen += n_batch

        # Per-sample running averages.  Dividing by the batch counter instead
        # would inflate both figures by roughly the batch size.
        train_bar.set_postfix(loss=running_loss / train_seen,
                              acc=running_corrects / train_seen)

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects / len(train_dataset)
    print(f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}")

    # --- Validation ---
    model.eval()
    val_loss, val_corrects, val_seen = 0.0, 0, 0
    val_bar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for inputs, labels in val_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            preds = outputs.argmax(dim=1)
            n_batch = inputs.size(0)
            val_loss += loss.item() * n_batch
            val_corrects += (preds == labels).sum().item()
            val_seen += n_batch

            val_bar.set_postfix(loss=val_loss / val_seen,
                                acc=val_corrects / val_seen)

    val_loss = val_loss / len(val_dataset)
    val_acc = val_corrects / len(val_dataset)
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save best model (strict improvement only, so ties keep the earlier epoch)
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        print(f"Saved new best model with Val Acc: {best_acc:.4f}")

total_time = time.time() - start_time
print(f"\nTraining complete in {total_time/60:.2f} minutes")
print(f"Best Validation Accuracy: {best_acc:.4f}")


Using device: mps
Found classes: ['Acne And Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma And Other Malignant Lesions', 'Atopic Dermatitis Photos', 'Ba  Cellulitis', 'Ba Impetigo', 'Benign', 'Bullous Disease Photos', 'Cellulitis Impetigo And Other Bacterial Infections', 'Eczema Photos', 'Exanthems And Drug Eruptions', 'Fu Athlete Foot', 'Fu Nail Fungus', 'Fu Ringworm', 'Hair Loss Photos Alopecia And Other Hair Diseases', 'Heathy', 'Herpes Hpv And Other Stds Photos', 'Light Diseases And Disorders Of Pigmentation', 'Lupus And Other Connective Tissue Diseases', 'Malignant', 'Melanoma Skin Cancer Nevi And Moles']
Class 'Acne And Rosacea Photos' has 6837 images
Class 'Actinic Keratosis Basal Cell Carcinoma And Other Malignant Lesions' has 6821 images
Class 'Atopic Dermatitis Photos' has 7645 images
Class 'Ba  Cellulitis' has 8079 images
Class 'Ba Impetigo' has 8148 images
Class 'Benign' has 6459 images
Class 'Bullous Disease Photos' has 7695 images
Class 'Cellulitis Impetigo And 

                                                                                

KeyboardInterrupt: 