In [6]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import os
from helpers.helper import Helper
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

import json
# -----------------------
# Model creation
# -----------------------
def create_model(num_classes, use_custom_head=True):
    if use_custom_head:
        model = timm.create_model('resnet101', pretrained=True, num_classes=0)  # no classifier
        in_features = model.num_features
        model.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
    else:
        # Simple version (just a linear classifier)
        model = timm.create_model('resnet101', pretrained=True, num_classes=num_classes)
    return model


# -----------------------
# Phase 1: Train only head
# -----------------------
def train_head_only(model, train_loader, val_loader, device, num_epochs=5, lr=1e-3):
    # Freeze backbone
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc.parameters():
        param.requires_grad = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.fc.parameters(), lr=lr)

    best_acc = 0.0
    for epoch in range(num_epochs):
        # Train
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validate
        model.eval()
        correct, total, val_loss = 0, 0, 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, preds = outputs.max(1)
                total += labels.size(0)
                correct += preds.eq(labels).sum().item()

        val_acc = 100 * correct / total
        print(f"[Head Only] Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {running_loss/len(train_loader):.4f} | "
              f"Val Loss: {val_loss/len(val_loader):.4f} | "
              f"Val Acc: {val_acc:.2f}%")

        best_acc = max(best_acc, val_acc)

    print(f"Best accuracy after head-only training: {best_acc:.2f}%")
    return model


# -----------------------
# Phase 2: Fine-tune full model
# -----------------------
def fine_tune_full(model, train_loader, val_loader, device, num_epochs=50, lr=1e-4, save_dir="checkpoints"):
    # Unfreeze backbone
    for param in model.parameters():
        param.requires_grad = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_acc = 0.0
    os.makedirs(save_dir, exist_ok=True)  # make sure checkpoint dir exists

    for epoch in range(num_epochs):
        # Train
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validate
        model.eval()
        correct, total, val_loss = 0, 0, 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, preds = outputs.max(1)
                total += labels.size(0)
                correct += preds.eq(labels).sum().item()

        val_acc = 100 * correct / total
        print(f"[Fine-Tune] Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {running_loss/len(train_loader):.4f} | "
              f"Val Loss: {val_loss/len(val_loader):.4f} | "
              f"Val Acc: {val_acc:.2f}%")

        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), os.path.join(save_dir, "best_resnet101.pth"))
            print(f"âœ… New best model saved (Acc: {best_acc:.2f}%)")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            ckpt_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss/len(val_loader),
            }, ckpt_path)
            print(f"ðŸ’¾ Checkpoint saved: {ckpt_path}")

        scheduler.step()

    print(f"Best accuracy after fine-tuning: {best_acc:.2f}%")
    return model
def get_transforms(image_size=224):
    train_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomRotation(30),
        #transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    
    val_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return train_transform, val_transform


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model(num_classes=66, use_custom_head=False).to(device)

train_dir = 'C:\\Identification dataset\\KNP_identification_dataset\\Training'
val_dir = 'C:\\Identification dataset\\KNP_identification_dataset\\Validation' 

savePath = Helper.increment_path(Path("identification_models_KNP\\ResNet101_August_v"),mkdir=True)
os.makedirs(savePath, exist_ok=True)
# Create transforms and datasets
train_transform, val_transform = get_transforms()
train_dataset = ImageFolder(train_dir, transform=train_transform)
val_dataset = ImageFolder(val_dir, transform=val_transform)

# Save class mapping
class_mapping = {v: k for k, v in train_dataset.class_to_idx.items()}
import json
with open(f'{savePath}\\class_mapping.json', 'w') as f:
    json.dump(class_mapping, f, indent=4)
batch_size = 16

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16,pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16,pin_memory=True)
# Phase 1
checkpoint_dir=(str(savePath)+"\\checkpoints")
model = train_head_only(model, train_loader, val_loader, device, num_epochs=5, lr=1e-3)

# Phase 2
model = fine_tune_full(model, train_loader, val_loader, device, num_epochs=120, lr=1e-4, save_dir=checkpoint_dir)

torch.save(model.state_dict(), f"{savePath}\\best_model.pth")
print("ðŸŽ¯ Training complete. Model saved!")


[Head Only] Epoch 1/5 | Train Loss: 3.9304 | Val Loss: 3.7766 | Val Acc: 21.37%
[Head Only] Epoch 2/5 | Train Loss: 3.5249 | Val Loss: 3.4329 | Val Acc: 23.04%
[Head Only] Epoch 3/5 | Train Loss: 3.2590 | Val Loss: 3.2517 | Val Acc: 33.00%
[Head Only] Epoch 4/5 | Train Loss: 3.0611 | Val Loss: 3.1609 | Val Acc: 35.36%
[Head Only] Epoch 5/5 | Train Loss: 2.8906 | Val Loss: 2.9034 | Val Acc: 37.60%
Best accuracy after head-only training: 37.60%
[Fine-Tune] Epoch 1/120 | Train Loss: 1.4770 | Val Loss: 0.7936 | Val Acc: 76.54%
âœ… New best model saved (Acc: 76.54%)
[Fine-Tune] Epoch 2/120 | Train Loss: 0.6120 | Val Loss: 0.2963 | Val Acc: 91.13%
âœ… New best model saved (Acc: 91.13%)
[Fine-Tune] Epoch 3/120 | Train Loss: 0.3045 | Val Loss: 0.1639 | Val Acc: 94.90%
âœ… New best model saved (Acc: 94.90%)
[Fine-Tune] Epoch 4/120 | Train Loss: 0.1707 | Val Loss: 0.1070 | Val Acc: 97.04%
âœ… New best model saved (Acc: 97.04%)
[Fine-Tune] Epoch 5/120 | Train Loss: 0.1111 | Val Loss: 0.0593 | Val

In [5]:
torch.save(model.state_dict(), f"{savePath}\\resnet_best.pth")
print("ðŸŽ¯ Training complete. Model saved!")

ðŸŽ¯ Training complete. Model saved!


In [7]:
import argparse
import torch
import os

def convert_checkpoint_to_weight(checkpoint_path, output_dir):
    # Check if the checkpoint file exists
    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint file {checkpoint_path} does not exist.")
        return

    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Extract the model state dict
    if 'model_state_dict' in checkpoint:
        model_state_dict = checkpoint['model_state_dict']
    else:
        print("Error: The checkpoint does not contain a 'model_state_dict'.")
        return

    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Generate the output file name
    checkpoint_filename = os.path.basename(checkpoint_path)
    weight_filename = checkpoint_filename.replace('checkpoint', 'weight').replace('.pth', '_weights.pth')
    output_path = os.path.join(output_dir, weight_filename)

    # Save the model state dict
    torch.save(model_state_dict, output_path)

    print(f"Weight file saved to: {output_path}")

    # Print additional information from the checkpoint
    if 'epoch' in checkpoint:
        print(f"Epoch: {checkpoint['epoch']}")
    if 'val_acc' in checkpoint:
        print(f"Validation Accuracy: {checkpoint['val_acc']:.2f}%")

if __name__ == "__main__":
    weight_path = 'identification_models_KNP\ResNet101_August_v9\checkpoints\checkpoint_epoch_120.pth'
    output_path  = 'identification_models_KNP\ResNet101_August_v9\epoch_120.pth'

    convert_checkpoint_to_weight(weight_path, output_path)

Weight file saved to: identification_models_KNP\ResNet101_August_v9\epoch_120.pth\weight_epoch_120_weights.pth
Epoch: 120
Validation Accuracy: 99.53%
