In [None]:
import os
import time
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.models import resnet18, ResNet18_Weights
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm
from PIL import Image
import requests
from io import BytesIO

In [None]:
# ----------------------------
# Configuration
# ----------------------------
# Where the Kaggle dataset is mounted, and which class folders to train on.
# The order of SELECTED_CLASSES defines the model's output indices.
DATA_ROOT = "/kaggle/input/skindiseasedataset/SkinDisease/SkinDisease"
SELECTED_CLASSES = [
    'Acne',
    'Eczema',
    'Psoriasis',
    'Warts',
    'SkinCancer',
    'Unknown_Normal',
]

# Training hyperparameters
NUM_EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_WORKERS = 2  # DataLoader worker processes

# Artifacts (weights + label list) are written below OUTPUT_DIR, created if missing.
OUTPUT_DIR = "./outputs"
MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "skin_disease_resnet18.pth")
LABELS_SAVE_PATH = os.path.join(OUTPUT_DIR, "labels.json")
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [None]:
# ----------------------------
# Reproducibility (optional)
# ----------------------------
SEED = 42
# Seed each RNG this notebook draws from: Python's `random`,
# NumPy's global generator and torch's generator.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

In [None]:
# Device
# ----------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


# ============================================================
# Data transforms
# ============================================================
# Standard ImageNet channel statistics, matching the ImageNet-pretrained
# backbone; shared by both pipelines so they cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: fixed 224x224 resize plus light augmentation (h-flip, +/-10 deg rotation).
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Evaluation / inference: same resize and normalization, no randomness.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)


In [None]:
# Load dataset + filter selected classes
# ============================================================
train_dir = os.path.join(DATA_ROOT, "train")
test_dir = os.path.join(DATA_ROOT, "test")

train_data = datasets.ImageFolder(train_dir, transform=train_transform)
test_data = datasets.ImageFolder(test_dir, transform=test_transform)

original_class_to_idx = train_data.class_to_idx

# Keep only selected classes that exist in dataset mapping
selected_idx = {
    original_class_to_idx[cls]: cls
    for cls in SELECTED_CLASSES
    if cls in original_class_to_idx
}

assert len(selected_idx) > 0, "No selected classes found in dataset!"

missing_classes = [cls for cls in SELECTED_CLASSES if cls not in original_class_to_idx]
if missing_classes:
    print("WARNING: selected classes missing from the train split:", missing_classes)

# Final label of every selected class = its position in SELECTED_CLASSES.
new_class_to_idx = {cls: i for i, cls in enumerate(SELECTED_CLASSES)}


def remap_samples(samples, class_to_idx=None):
    """Relabel ImageFolder ``(path, label)`` samples to SELECTED_CLASSES order.

    ImageFolder assigns indices from the sorted sub-folder names of *each* split,
    so train and test can disagree on what an index means when their folder sets
    differ. Remapping therefore goes through the class *name*, using the mapping
    of the split the samples came from.

    Args:
        samples: list of ``(path, original_label)`` pairs from an ImageFolder.
        class_to_idx: that ImageFolder's ``class_to_idx``. Defaults to the train
            split's mapping (the previous behaviour).

    Returns:
        A new list of ``(path, new_label)`` pairs containing only samples whose
        class is in SELECTED_CLASSES; ``new_label`` indexes SELECTED_CLASSES.
    """
    if class_to_idx is None:
        class_to_idx = original_class_to_idx
    idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}

    remapped = []
    for path, label in samples:
        cls_name = idx_to_class.get(label)
        if cls_name in new_class_to_idx:
            remapped.append((path, new_class_to_idx[cls_name]))
    return remapped


# Remap each split with ITS OWN class_to_idx (read here, before it is overwritten below).
train_data.samples = remap_samples(train_data.samples, train_data.class_to_idx)
test_data.samples = remap_samples(test_data.samples, test_data.class_to_idx)
# ImageFolder exposes `imgs` as an alias of `samples`; keep them in sync.
train_data.imgs = train_data.samples
test_data.imgs = test_data.samples

train_data.targets = [label for _, label in train_data.samples]
test_data.targets = [label for _, label in test_data.samples]

train_data.classes = SELECTED_CLASSES
test_data.classes = SELECTED_CLASSES

train_data.class_to_idx = dict(new_class_to_idx)
test_data.class_to_idx = dict(new_class_to_idx)

assert len(train_data) > 0, "No training samples left after filtering!"

print("Selected classes:", SELECTED_CLASSES)
print("Train samples:", len(train_data))
print("Test samples:", len(test_data))


In [None]:
# DataLoaders
# ============================================================
# Pinned host memory only speeds up host -> GPU copies, so enable it only
# when training on CUDA (on a CPU-only runtime it is pure overhead).
PIN_MEMORY = device.type == "cuda"

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,  # new sample order every epoch
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order for evaluation
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)



In [None]:
# Model: ResNet18 (ImageNet pretrained)
# ============================================================
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)

# Replace the original classification head with one sized to our classes.
# No layers are frozen: the optimizer below receives every parameter.
num_features = model.fc.in_features
model.fc = nn.Linear(in_features=num_features, out_features=len(SELECTED_CLASSES))
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


In [None]:
# Training
# ============================================================
# Per-epoch history (mean per-sample loss, accuracy in %).
train_losses = []
train_accs = []

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    epoch_start = time.time()

    print(f"\nEpoch [{epoch + 1}/{NUM_EPOCHS}]")
    for images, labels in tqdm(train_loader, desc=f"Training {epoch + 1}/{NUM_EPOCHS}", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_in_batch = labels.size(0)
        # The criterion (default reduction) returns the batch *mean*; weight it by the
        # batch size so the epoch average is a true per-sample mean even when the
        # final batch is smaller than BATCH_SIZE.
        running_loss += loss.item() * n_in_batch

        _, predicted = torch.max(outputs, 1)
        total += n_in_batch
        correct += (predicted == labels).sum().item()

    avg_loss = running_loss / max(1, total)
    acc = 100.0 * correct / max(1, total)

    train_losses.append(avg_loss)
    train_accs.append(acc)

    print(f"Train Loss: {avg_loss:.4f} | Train Accuracy: {acc:.2f}%")
    print(f"Epoch time: {time.time() - epoch_start:.2f} sec")

print(f"\nTotal Training Time: {time.time() - start_time:.2f} sec")


In [None]:
# Save model + labels (GitHub ready)
# ============================================================
# Only the state_dict (weights) is stored; the class-name list is written
# next to it so output index -> label can be recovered at inference time.
model_state = model.state_dict()
torch.save(model_state, MODEL_SAVE_PATH)
print(f"\n✅ Model weights saved to: {MODEL_SAVE_PATH}")

labels_json = json.dumps(SELECTED_CLASSES, indent=2)
with open(LABELS_SAVE_PATH, "w") as labels_file:
    labels_file.write(labels_json)

print(f"✅ Labels saved to: {LABELS_SAVE_PATH}")


# ============================================================
# Evaluation
# ============================================================
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating"):
        images = images.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        all_preds.extend(predicted.cpu().numpy())
        # Labels are only compared on the host, so they never need to visit the device.
        all_labels.extend(labels.numpy())

# Report over the FULL class list. Without `labels=`, sklearn derives the label set
# from the data, and if some class is neither present nor predicted, its length no
# longer matches target_names and classification_report raises ValueError.
# zero_division=0 avoids undefined-metric warnings for such classes.
class_indices = list(range(len(SELECTED_CLASSES)))
print("\nClassification Report:")
print(classification_report(
    all_labels,
    all_preds,
    labels=class_indices,
    target_names=SELECTED_CLASSES,
    zero_division=0,
))


# ============================================================
# Confusion Matrix (Matplotlib only - no seaborn needed)
# ============================================================
# Pin the label set so the matrix is always N x N (N = len(SELECTED_CLASSES)) and
# lines up with the tick labels, even if a class never occurs in the test data.
tick_marks = np.arange(len(SELECTED_CLASSES))
cm = confusion_matrix(all_labels, all_preds, labels=tick_marks)

fig, ax = plt.subplots(figsize=(10, 8))
# "Blues" maps high counts to dark cells, which is what the white/black text
# rule below assumes (the default viridis map is *bright* for high values).
im = ax.imshow(cm, interpolation="nearest", cmap="Blues")
ax.set_title("Confusion Matrix")
fig.colorbar(im, ax=ax)
ax.set_xticks(tick_marks)
ax.set_xticklabels(SELECTED_CLASSES, rotation=45, ha="right")
ax.set_yticks(tick_marks)
ax.set_yticklabels(SELECTED_CLASSES)

# Write counts inside cells: white text on dark (high-count) cells, black otherwise.
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], "d"),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

ax.set_ylabel("True label")
ax.set_xlabel("Predicted label")
fig.tight_layout()
plt.show()


In [None]:
# Utility: Predict from Image URL
# ============================================================
def predict_from_url(image_url, model, transform, class_names):
    """Download an image, classify it, and display it with the prediction.

    Args:
        image_url: HTTP(S) URL of the image to classify.
        model: classifier producing one logit per entry of ``class_names``.
        transform: preprocessing from a PIL RGB image to an unbatched tensor
            (e.g. ``test_transform``); a batch dimension is added here.
        class_names: sequence mapping output index -> class name.

    Returns:
        ``(predicted_class, confidence)``, where ``confidence`` is the softmax
        probability of the predicted class as a Python float.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on other network failures or timeouts.
    """
    model.eval()

    response = requests.get(image_url, timeout=20)
    response.raise_for_status()

    image = Image.open(BytesIO(response.content)).convert("RGB")

    # Send the input to wherever the model's weights live, rather than relying
    # on a global device variable (falls back to CPU for a parameter-less model).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    img_tensor = transform(image).unsqueeze(0).to(model_device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, predicted = torch.max(probs, 1)

    predicted_class = class_names[predicted.item()]
    confidence = conf.item()

    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.title(f"Predicted: {predicted_class} ({confidence*100:.2f}%)")
    plt.axis("off")
    plt.show()

    return predicted_class, confidence


# Example URL prediction (needs internet access, which may be disabled in the kernel)
image_url = "https://images.ctfassets.net/4f3rgqwzdznj/6V6F0gVHYPSTmNARliytAY/4fee36cfb2f433f8b43d85ecef211345/eczema-severe-eczema-mistakes.png"
try:
    predicted, conf = predict_from_url(image_url, model, test_transform, SELECTED_CLASSES)
    print("Predicted Disease:", predicted, "| Confidence:", conf)
except requests.RequestException as err:
    # A failed download must not abort the rest of this cell: the local-image
    # helper further down is defined in the same cell.
    print(f"Could not fetch example image: {err}")


# ============================================================
# Utility: Predict from local image path
# ============================================================
def predict_image(image_path, model, transform, class_names):
    """Classify a local image file and display it with the prediction.

    Args:
        image_path: path to an image file readable by PIL.
        model: classifier producing one logit per entry of ``class_names``.
        transform: preprocessing from a PIL RGB image to an unbatched tensor
            (e.g. ``test_transform``); a batch dimension is added here.
        class_names: sequence mapping output index -> class name.

    Returns:
        ``(predicted_class, confidence)``, where ``confidence`` is the softmax
        probability of the predicted class as a Python float.
    """
    model.eval()

    # The context manager closes the file handle; convert() loads the pixel data
    # into a new image first, so `image` stays usable after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Send the input to wherever the model's weights live, rather than relying
    # on a global device variable (falls back to CPU for a parameter-less model).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    img_tensor = transform(image).unsqueeze(0).to(model_device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, predicted = torch.max(probs, 1)

    predicted_class = class_names[predicted.item()]
    confidence = conf.item()

    plt.figure(figsize=(5, 5))
    plt.imshow(image)
    plt.title(f"Predicted: {predicted_class} ({confidence*100:.2f}%)")
    plt.axis("off")
    plt.show()

    return predicted_class, confidence

# Example:
# predicted, conf = predict_image("/kaggle/input/....jpg", model, test_transform, SELECTED_CLASSES)
# print(predicted, conf)