In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from timm.data.mixup import Mixup
from torch.cuda.amp import autocast, GradScaler  # Enable mixed precision training

# Mount Google Drive at /content/drive so the files referenced below are reachable
from google.colab import drive
drive.mount('/content/drive')

# The folder holding ghostnetN1.py must be on sys.path before GhostNet_N can be imported
import sys
ghostnet_src_dir = '/content/drive/MyDrive/ColabNotebooks/GhostNet'
sys.path.append(ghostnet_src_dir)
from ghostnetN1 import GhostNet_N  # Model class used by the rest of this cell

# Checkpoint file written during training and read back when resuming
checkpoint_path = "/content/drive/MyDrive/ColabNotebooks/GhostNet/ghostnet_cbamN1.pth"
# Pretrained weights copied into the model before fine-tuning
pretrained_path = "/content/drive/MyDrive/ColabNotebooks/GhostNet/models/state_dict_73.98.pth"

# Architecture table passed to GhostNet_N: a list of stages, each stage a list of
# six-value block rows.
# NOTE(review): the meaning of the six columns is defined in ghostnetN1.py, not here.
# It resembles the GhostNet layout (kernel, expansion, out channels, SE ratio, stride)
# plus one extra trailing field — confirm against the model source before editing.
cfgs = [
    # stage 0
    [
        [3, 16, 16, 0, 1, 0],
    ],
    # stage 1
    [
        [3, 48, 24, 0, 2, 0],
    ],
    # stage 2
    [
        [3, 72, 24, 0, 1, 0],
    ],
    # stage 3
    [
        [5, 72, 40, 0.25, 2, 5],
    ],
    # stage 4
    [
        [5, 120, 40, 0.25, 1, 7],
    ],
    # stage 5
    [
        [3, 240, 80, 0, 2, 0],
    ],
    # stage 6
    [
        [3, 200, 80, 0, 1, 0],
        [3, 184, 80, 0, 1, 0],
        [3, 184, 80, 0, 1, 0],
        [3, 480, 112, 0.25, 1, 5],
        [3, 672, 112, 0.25, 1, 5],
    ],
    # stage 7
    [
        [5, 672, 160, 0.25, 2, 3],
    ],
    # stage 8
    [
        [5, 960, 160, 0, 1, 0],
        [5, 960, 160, 0.25, 1, 3],
        [5, 960, 160, 0, 1, 0],
        [5, 960, 160, 0.25, 1, 3],
    ],
]

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network with its default initialisation and place it on the chosen device;
# the pretrained weights are copied in afterwards.
model = GhostNet_N(cfgs, num_classes=1000, width=1.0, dropout=0.2).to(device)

# Load pretrained weights from 'state_dict_73.98.pth' (backbone weights)
print("🚀 Loading pretrained model weights for fine-tuning...")

# weights_only=True restricts unpickling to tensors and plain containers
pretrained_dict = torch.load(pretrained_path, map_location=device, weights_only=True)

# Parameters and buffers of the freshly built model, keyed by name
model_dict = model.state_dict()

# Keep a pretrained entry only when the model has a tensor with the same name AND the
# same shape. Matching on the name alone lets a same-named tensor of a different shape
# through, and load_state_dict() then fails with a size-mismatch error.
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# Report how much was actually transferred: if the key names do not line up, the filter
# above drops everything and the run would otherwise silently start from random weights.
print(f"Matched {len(pretrained_dict)}/{len(model_dict)} model tensors from the pretrained file")
if not pretrained_dict:
    print("⚠️ No pretrained tensors matched the model; continuing with randomly initialised weights.")

# Overlay the matching pretrained tensors on the model's own state and load the result
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

# Gradient scaler used by the mixed-precision training step
scaler = torch.amp.GradScaler()

# Mixup / CutMix batch augmentation from timm, configured for 1000 classes
mixup_fn = Mixup(
    mixup_alpha=0.2,
    cutmix_alpha=1.0,
    label_smoothing=0.1,
    num_classes=1000,
)

# AdamW over all model parameters.
# NOTE(review): 0.01 is a large step size for AdamW fine-tuning of pretrained weights —
# confirm it is intentional.
initial_lr = 0.01
optimizer = optim.AdamW(
    model.parameters(),
    lr=initial_lr,
    weight_decay=1e-4,
)

# Halve the LR once the monitored metric (validation accuracy, hence mode='max')
# has not improved for `patience` epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=5,
)

# Dataset and transforms
# Training pipeline: random crop/flip/colour/rotation augmentation, then tensor conversion
# and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Validation pipeline: deterministic resize + centre crop with the same normalisation.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageFolder derives the class labels from the sub-directory names under each root
train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/ColabNotebooks/GhostNet/ImageNet/dir/train', transform=train_transform)
val_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/ColabNotebooks/GhostNet/ImageNet/dir/val', transform=val_transform)

# drop_last=True on the training loader: timm's Mixup asserts that every batch has an even
# number of samples, and the final partial batch of an epoch can be odd-sized, which would
# abort training at the end of the epoch. Dropping it is safe because the data is reshuffled
# every epoch, so different samples are left out each time.
train_loader = DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=2,  # Reduce workers to 2 for Colab
    pin_memory=True,
    drop_last=True,
)
# Validation keeps every sample (no drop_last) so the accuracy covers the whole split
val_loader = DataLoader(
    val_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=2,  # Reduce workers to 2 for Colab
    pin_memory=True,
)

# Resume from checkpoint if available
if os.path.exists(checkpoint_path):
    print("📥 Resuming from checkpoint...")
    # map_location=device lets a checkpoint saved on a GPU runtime be restored on a
    # CPU-only runtime (and vice versa) instead of failing while deserialising CUDA tensors.
    # weights_only=True restricts unpickling to tensors and plain containers, matching how
    # the pretrained file is loaded, and avoids torch.load's FutureWarning about the
    # implicit default.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Start from the epoch after the saved one
    best_accuracy = checkpoint['best_accuracy']
    print(f"Resuming from epoch {start_epoch}, best accuracy: {best_accuracy:.2f}%")
else:
    start_epoch = 0  # If no checkpoint, start from the first epoch
    best_accuracy = 0

# Early stopping parameters: stop after this many consecutive epochs without a new best
early_stopping_patience = 10
epochs_since_best = 0

# Total epochs
num_epochs = 200

# The loss module holds no per-batch state, so build it once instead of on every iteration.
# NOTE(review): after mixup_fn the targets are presumably soft (per-class probability)
# labels, which nn.CrossEntropyLoss accepts — confirm against the timm version in use.
criterion = nn.CrossEntropyLoss()

# Autocast follows the device actually selected rather than assuming CUDA; it is only
# enabled on CUDA, so a CPU-only runtime runs plain fp32 without autocast warnings.
use_amp = device.type == "cuda"

# Training loop
for epoch in range(start_epoch, num_epochs):
    print(f"\n🔄 Epoch {epoch}/{num_epochs}")
    model.train()
    running_loss = 0.0

    for inputs, targets in train_loader:
        # The loaders pin host memory, so the host-to-device copies can be asynchronous
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn is not None:
            inputs, targets = mixup_fn(inputs, targets)  # Apply Mixup & Cutmix

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):  # Mixed precision
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Scale the loss before backward; the scaler then steps the optimizer and
        # updates its own scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    # Mean of the per-batch losses over the epoch
    avg_loss = running_loss / len(train_loader)
    print(f"✅ Epoch {epoch} Finished - Avg Loss: {avg_loss:.4f}")

    # Validation: top-1 accuracy over the whole validation loader, without gradients
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100 * correct / total
    print(f"📊 Validation Accuracy: {accuracy:.2f}%")

    # Update scheduler
    scheduler.step(accuracy)  # Adjust LR if accuracy plateaus

    # Check for improvement: a checkpoint is written only when accuracy beats the best so far
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        epochs_since_best = 0  # Reset early stopping counter
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_accuracy': best_accuracy,  # Save best accuracy
        }, checkpoint_path)
        print(f"🎯 Best model saved at epoch {epoch} with accuracy: {accuracy:.2f}%")
    else:
        epochs_since_best += 1

    # Early stopping check
    if epochs_since_best >= early_stopping_patience:
        print("⏹️ Early stopping triggered! Training stopped.")
        break



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
🚀 Loading pretrained model weights for fine-tuning...
📥 Resuming from checkpoint...


  checkpoint = torch.load(checkpoint_path)


Resuming from epoch 27, best accuracy: 15.42%

🔄 Epoch 27/200


KeyboardInterrupt: 