In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from collections import Counter
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np

# --------------------------
# Paths
# --------------------------
# Each directory is read by ImageFolder below, which expects one sub-folder
# per class.
train_dir = r'D:\System Variables\Ongoing\White Blood cells\Train\Train'
val_dir = r'D:\System Variables\Ongoing\White Blood cells\val'
# NOTE(review): the test split lives under the "Train" folder -- confirm it is
# disjoint from train_dir.
test_dir = r'D:\System Variables\Ongoing\White Blood cells\Train\Test'

# Model save directory (created on first run if it does not exist yet)
save_dir = r'D:\System Variables\Ongoing\White Blood cells\Save'
os.makedirs(save_dir, exist_ok=True)
model_path = os.path.join(save_dir, 'best_model_sir_jamal.pth')

# --------------------------
# Hyperparameters
# --------------------------
batch_size = 32
num_epochs = 200  # upper bound; early stopping may end training sooner
patience = 10     # epochs without validation-loss improvement before stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------
# Transforms
# --------------------------
# The ImageNet-pretrained ResNet-50 below was trained on inputs normalized
# with these per-channel statistics, so the same normalization is applied
# here; without it the pretrained features see out-of-distribution inputs.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# --------------------------
# Datasets and Loaders
# --------------------------
# The same deterministic transform is shared by all three splits.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)
test_dataset = datasets.ImageFolder(test_dir, transform=transform)

# Only the training loader shuffles; validation/test order is irrelevant.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# --------------------------
# Model
# --------------------------
num_classes = len(train_dataset.classes)
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=models.ResNet50_Weights.IMAGENET1K_V1`; switch once the
# installed torchvision version is confirmed to support it.
model = models.resnet50(pretrained=True)
# Replace the 1000-way ImageNet head with one sized for this dataset.
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

# All layers are fine-tuned (no parameters are frozen).
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# --------------------------
# Early stopping setup
# --------------------------
best_val_loss = float('inf')
early_stopping_counter = 0
final_weights_tensor = None  # class weights captured when the best model is saved

# Per-epoch history used for the curves plotted after training.
train_losses, val_losses = [], []
train_accs, val_accs = [], []

def evaluate(model, dataloader, criterion):
    """Score `model` on every batch of `dataloader` without tracking gradients.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Network called as ``model(images)`` to produce class scores.
        dataloader: Iterable of ``(images, labels)`` batches that also exposes
            a ``dataset`` attribute supporting ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.

    Returns:
        A tuple ``(avg_loss, acc, all_labels, all_preds)``: the batch-size
        weighted loss sum divided by ``len(dataloader.dataset)``, the accuracy
        reported by ``accuracy_score``, and the flat lists of true and
        predicted class indices.
    """
    model.eval()
    loss_sum = 0.0
    predicted, expected = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            # criterion returns a per-batch mean, so scale by the batch size
            # before accumulating.
            loss_sum += batch_loss.item() * batch_images.size(0)

            # Index of the highest score along the class dimension.
            batch_preds = torch.max(logits, 1)[1]
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(dataloader.dataset)
    accuracy = accuracy_score(expected, predicted)
    return mean_loss, accuracy, expected, predicted

# --------------------------
# Training loop
# --------------------------
# Class weights depend only on the training set's label distribution, so they
# are computed once from the full label list instead of being re-derived on
# every batch from the labels seen so far (which gave early batches weights
# based on partial counts and recounted a growing list at each step).
class_counts = Counter(train_dataset.targets)
total_samples = len(train_dataset.targets)
# Inverse-frequency weights; a class with no samples falls back to a count
# of 1 so the division is always defined.
raw_weights = [
    total_samples / (num_classes * class_counts.get(c, 1))
    for c in range(num_classes)
]
weight_sum = sum(raw_weights)
normalized_weights = [w / weight_sum for w in raw_weights]  # sums to 1
weights_tensor = torch.tensor(normalized_weights, dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=weights_tensor)

for epoch in range(num_epochs):
    model.train()  # evaluate() leaves the model in eval mode
    running_loss = 0.0
    correct_train, total_train = 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss is a per-batch mean; scale by batch size before accumulating.
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = correct_train / total_train
    train_losses.append(epoch_loss)
    train_accs.append(epoch_acc)

    # Validation (uses the same weighted criterion as training)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Early stopping and save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), model_path)
        # Remember the weights in effect when the best checkpoint was written.
        final_weights_tensor = weights_tensor.clone()
        print(f"🎉 New best model saved at epoch {epoch+1} with Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print("⏹️ Early stopping triggered.")
            break

# --------------------------
# Load best model
# --------------------------
# map_location remaps the stored tensors onto the device in use now, so the
# checkpoint still loads if it was written on a different device
# (e.g. saved on GPU, reloaded on a CPU-only machine).
model.load_state_dict(torch.load(model_path, map_location=device))

# --------------------------
# Final evaluation
# --------------------------
# Both splits are scored with the best checkpoint, not the last-epoch weights.
train_loss, train_acc, train_labels, train_preds = evaluate(model, train_loader, criterion)
test_loss, test_acc, test_labels, test_preds = evaluate(model, test_loader, criterion)

# Confusion matrices
train_cm = confusion_matrix(train_labels, train_preds)
test_cm = confusion_matrix(test_labels, test_preds)

# --------------------------
# Plot confusion matrices and curves
# --------------------------
# One row, three panels: train confusion matrix, test confusion matrix,
# and the per-epoch loss/accuracy history.
fig, axs = plt.subplots(1, 3, figsize=(20, 6))

ConfusionMatrixDisplay(train_cm, display_labels=train_dataset.classes).plot(ax=axs[0], xticks_rotation=45)
axs[0].set_title("Train Confusion Matrix")

ConfusionMatrixDisplay(test_cm, display_labels=test_dataset.classes).plot(ax=axs[1], xticks_rotation=45)
axs[1].set_title("Test Confusion Matrix")

# Loss and accuracy share one y-axis here.
axs[2].plot(train_losses, label='Train Loss')
axs[2].plot(val_losses, label='Validation Loss')
axs[2].plot(train_accs, label='Train Acc')
axs[2].plot(val_accs, label='Validation Acc')
axs[2].set_title("Loss & Accuracy Curves")
axs[2].set_xlabel("Epoch")
axs[2].set_ylabel("Value")
axs[2].legend()

plt.tight_layout()
plt.show()

# --------------------------
# Print final results
# --------------------------
print(f"\n✅ Final Train Accuracy: {train_acc*100:.2f}%")
print(f"✅ Final Test Accuracy: {test_acc*100:.2f}%")
# final_weights_tensor is only set when an epoch improved the validation loss;
# it is still None if that never happened, so guard instead of raising
# AttributeError.
if final_weights_tensor is None:
    print("\nFinal Normalized Class Weights: unavailable (no best model was saved during this run).")
else:
    print(f"\nFinal Normalized Class Weights (sum = {final_weights_tensor.sum().item():.2f}):")
    for idx, w in enumerate(final_weights_tensor.tolist()):
        print(f"Class '{train_dataset.classes[idx]}': {w:.4f}")
