In [1]:
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import timm
from torch import nn, optim
from torchinfo import summary
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every random-number generator used below so that a fresh-kernel
# "Restart & Run All" draws the same random numbers (e.g. head initialisation,
# shuffle order, dropout masks).
seed_value = 1000

np.random.seed(seed_value)

torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)       # current GPU
torch.cuda.manual_seed_all(seed_value)   # every visible GPU

# Trade speed for run-to-run determinism in cuDNN.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# NOTE(review): disabling cuDNN entirely makes GPU convolutions much slower;
# deterministic=True together with benchmark=False is normally sufficient --
# confirm this is intended.
torch.backends.cudnn.enabled = False

In [3]:
# ---- Run configuration ----
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data / model shape
image_size = 224          # images are resized to image_size x image_size pixels
num_classes = 2
class_name = ['L', 'NL']  # class sub-folders on disk must carry exactly these names

# Batching and schedule
batch_size = 64           # training mini-batch size
val_batch_size = 32       # mini-batch size intended for validation
num_epochs = 50           # upper bound; early stopping may end training sooner

# Early-stopping bookkeeping
patience = 5              # epochs without val-loss improvement before stopping
counter = 0
best_val_loss = float('inf')

In [4]:
# File the training loop writes whenever validation loss improves.
model_name = 'best_model.pt'
# Directory for saved weights, rooted at the current working directory.
model_dir = os.path.join(os.getcwd(), "saved_models")

In [5]:
# Resize every image to a square image_size x image_size and convert it to a
# float tensor. NOTE(review): no mean/std normalisation is applied -- confirm
# this matches the input scaling the pretrained encoder expects.
transform = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
    ]
)

In [6]:
# Split roots; each is expected to hold one sub-folder per entry of class_name.
train_dir, val_dir, test_dir = (
    'images/training/',
    'images/validation/',
    'images/testing/',
)

In [7]:
# ImageFolder derives each split's class list from its sub-folder names.
train_data = datasets.ImageFolder(train_dir, transform=transform)
val_data = datasets.ImageFolder(val_dir, transform=transform)
test_data = datasets.ImageFolder(test_dir, transform=transform)

# Fail fast if the on-disk class folders disagree with the configured labels;
# otherwise label index i would silently mean something other than class_name[i].
for split_name, split in (('train', train_data), ('val', val_data), ('test', test_data)):
    assert split.classes == class_name, (
        f"{split_name} classes {split.classes} != expected {class_name}"
    )

# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Use the dedicated validation batch size (previously batch_size was used here,
# which disagreed with val_steps computed from val_batch_size).
val_loader = DataLoader(val_data, batch_size=val_batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

In [8]:
# Mini-batches per epoch for each split; a final partial batch still counts.
sample_size, sample_size_val = len(train_data), len(val_data)
steps = math.ceil(sample_size / batch_size)
val_steps = math.ceil(sample_size_val / val_batch_size)

In [None]:
# Pretrained Zoobot ConvNeXt-nano backbone. features_only=True makes the
# forward pass return a list of per-stage feature maps instead of logits.
encoder = timm.create_model(
    'hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    features_only=True,
    pretrained=True,
)

# Channel count of the deepest feature map -- the width the head consumes.
in_channels = encoder.feature_info[-1]['num_chs']
print("Last feature map channels:", in_channels)

# Define custom classifier head
class CustomHead(nn.Module):
    """Classifier head: global average pool -> Linear -> Dropout -> ReLU -> Linear.

    Args:
        in_channels: number of channels C of the incoming feature map.
        num_classes: number of output logits per sample.
        hidden_dim: width of the hidden layer (default 128).
        dropout: dropout probability after the first linear layer (default 0.2).

    ``forward(x)`` takes a 4-D feature map ``(B, C, H, W)`` and returns logits
    of shape ``(B, num_classes)``.
    """

    def __init__(self, in_channels, num_classes, hidden_dim=128, dropout=0.2):
        super().__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # [B, C, H, W] -> [B, C, 1, 1]
        self.fc1 = nn.Linear(in_channels, hidden_dim)
        self.drop1 = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.global_avg_pool(x)
        # nn.Linear operates on the LAST dimension, so the trailing 1x1 spatial
        # dims must be removed first: [B, C, 1, 1] -> [B, C]. Without this the
        # first linear layer sees a last dim of size 1 and raises a shape error.
        x = x.flatten(1)
        x = self.fc1(x)
        x = self.drop1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Combine encoder and head
class Classifier(nn.Module):
    """Backbone + CustomHead image classifier.

    Args:
        encoder: feature extractor whose forward returns a list of feature
            maps; only the last (deepest) one is fed to the head.
        num_classes: number of output logits.
        in_channels: channel count of the encoder's last feature map. When
            omitted it is read from ``encoder.feature_info[-1]['num_chs']``
            (as exposed by timm ``features_only`` models).
    """

    def __init__(self, encoder, num_classes, in_channels=None):
        super().__init__()
        if in_channels is None:
            # Derive the width from the encoder itself rather than relying on a
            # notebook-global variable that may be stale or undefined.
            in_channels = encoder.feature_info[-1]['num_chs']
        self.encoder = encoder
        self.head = CustomHead(in_channels, num_classes)

    def forward(self, x):
        # encoder returns a list of features -- take the last feature map
        features = self.encoder(x)
        return self.head(features[-1])

# Assemble backbone + head and move it to the device chosen in the config cell
# (device is not re-defined here so there is a single source of truth).
model = Classifier(encoder, num_classes).to(device)

In [None]:
# Freeze the whole backbone first ...
model.encoder.requires_grad_(False)

# model.encoder is the timm FeatureListNet itself (no separate inner model);
# its ConvNeXt stages are exposed as flattened attributes such as stages_3.
base_model = model.encoder

# ... then re-enable the two deepest stages and the classifier head.
for trainable in (base_model.stages_2, base_model.stages_3, model.head):
    trainable.requires_grad_(True)


In [None]:
# Layer-wise learning-rate decay: the head gets the base rate, and each
# successively shallower unfrozen stage gets layer_decay times the previous rate
# (head: lr, stages_3: lr*0.5, stages_2: lr*0.25).
lr = 1e-5
layer_decay = 0.5

# Each group's parameters are materialised as a list. A bare .parameters()
# generator is exhausted by the first optimizer built from it, so re-running
# the optimizer cell would otherwise silently create groups with no parameters.
params = [
    {'params': list(module.parameters()), 'lr': lr * layer_decay ** depth}
    for depth, module in enumerate((model.head, base_model.stages_3, base_model.stages_2))
]

In [None]:
# Multi-class loss on raw logits with integer class labels.
criterion = nn.CrossEntropyLoss()
# NOTE(review): Adam's weight_decay is an L2 term added to the gradient
# (coupled); 0.05 is a value more commonly paired with decoupled AdamW --
# confirm the intent.
optimizer = optim.Adam(params, weight_decay=0.05)


In [None]:
# Layer-by-layer model overview via torchinfo.
summary(model, input_size=(1, 3, image_size, image_size))

# Smoke-test a single forward pass and check the logits shape. It runs in eval
# mode without autograd (no graph is built, dropout is inactive); the previous
# train/eval mode is restored afterwards.
dummy_input = torch.randn(1, 3, image_size, image_size).to(device)
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(dummy_input)
model.train(was_training)
assert output.shape == (1, num_classes), f"unexpected output shape {tuple(output.shape)}"

In [None]:
# Fine-tune with per-epoch validation, best-model checkpointing and early stopping.
train_losses, val_losses, val_accuracies, train_accuracies = [], [], [], []

# Reset early-stopping state inside this cell so that re-running it starts a
# fresh run instead of inheriting best_val_loss / counter from a previous run.
best_val_loss = float('inf')
counter = 0
best_state = None

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    running_loss, correct_train, total_train = 0.0, 0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # With the default reduction='mean' the loss is a per-batch mean;
        # weight it by batch size so the epoch loss is a true per-sample mean
        # even when the last batch is smaller.
        running_loss += loss.item() * labels.size(0)
        _, preds = torch.max(outputs, 1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    train_loss = running_loss / total_train
    train_acc = 100 * correct_train / total_train
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    correct, total = 0, 0
    all_preds, all_labels = [], []  # predictions/labels of the latest epoch
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss /= total
    val_acc = 100 * correct / total
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch+1}/{num_epochs} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | ")

    # ---- Early stopping & checkpoint ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Keep an in-memory copy of the best weights (in addition to the file)
        # so they can be restored once training stops.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(model.state_dict(), model_name)
        print("✅ Saved best model")
        counter = 0
    else:
        counter += 1
        print(f"⏳ Early stopping patience: {counter}/{patience}")
        if counter >= patience:
            print("Early stopping triggered")
            break

# When training stops, the model holds the weights of the LAST epoch, which may
# be up to `patience` epochs past the best one. Restore the best weights so
# that later cells evaluate/export them.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights (val loss {best_val_loss:.4f})")


In [None]:
# Export the weights currently held by `model` to ./saved_model/finetuned_zoobot.pt.
# Dedicated names are used so this cell does not overwrite the config values
# model_dir / model_name. model_name is the checkpoint file the training loop
# writes, so clobbering it would redirect checkpoints on a re-run.
export_dir = './saved_model'
export_name = 'finetuned_zoobot.pt'

os.makedirs(export_dir, exist_ok=True)  # no-op when the directory already exists

model_path = os.path.join(export_dir, export_name)
torch.save(model.state_dict(), model_path)