In [None]:
import os
from dataset import EuroSat
from datasets import load_dataset
from train import batch_train
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

Setting up the device, random seed, and hyperparameters

In [None]:
# Select the compute backend (CUDA first, then Apple MPS, else CPU) and fix the RNG seed
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

torch.manual_seed(1)
if device == "cuda":
    # Seed every visible GPU as well, not only the current one
    torch.cuda.manual_seed_all(1)
print(f"Device: {device}")

In [None]:
# Hyperparameters, grouped by the component that consumes them

# Run length and batching
num_epochs = 20
batch_size = 64
es_patience = 5  # early-stopping patience, counted in epochs

# Optimizer settings
learning_rate = 3e-5
betas = (0.9, 0.999)
eps = 1e-08
weight_decay = 0.01

# Learning-rate schedule: exponent of the polynomial decay
power = 1.0

Getting the EuroSAT data from Hugging Face

In [None]:
# Data: fetch each split from the Hugging Face hub and wrap it in the project's EuroSat dataset
train_split = load_dataset("cm93/eurosat", split='train')
train_data = EuroSat(train_split)
val_split = load_dataset("cm93/eurosat", split='validation')
val_data = EuroSat(val_split)

# Data loaders: identical settings for both splits, except that only training is shuffled
loader_options = dict(batch_size=batch_size, num_workers=2)
train_dataloader = DataLoader(train_data, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_data, shuffle=False, **loader_options)

Model Fine Tuning

In [None]:
# Checkpoints and TensorBoard runs are written under ./models, so make sure it exists
os.makedirs("models", exist_ok=True)

# Pretrained ResNet-18 from timm with a 10-class head, placed on the selected device
model = timm.create_model('resnet18', pretrained=True, num_classes=10).to(device)

# Loss, optimizer and learning-rate schedule
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=betas,
    eps=eps,
    weight_decay=weight_decay,
)
# NOTE(review): total_iters=num_epochs assumes the scheduler is stepped once per epoch
# inside batch_train — confirm against train.py.
scheduler = optim.lr_scheduler.PolynomialLR(optimizer, total_iters=num_epochs, power=power)

In [None]:
# Fine tuning the model.
#
# Each iteration calls batch_train once and logs the four values it returns
# (training/validation loss and accuracy) to TensorBoard. A checkpoint is written
# whenever the validation loss improves on the best value seen so far; training
# stops early after `es_patience` consecutive epochs without improvement.
patience_counter = 0
best_vloss = float('inf')
writer = SummaryWriter('models/runs/ResNet_18')

# try/finally guarantees the TensorBoard event file is closed even when training
# raises or the cell is interrupted; otherwise every (re-)run leaks an open writer.
try:
    for epoch in range(num_epochs):
        print(f"Epoch: {epoch+1}")
        # NOTE(review): presumably one full training pass plus one validation pass,
        # with the scheduler stepped inside — confirm against train.batch_train.
        avg_loss, avg_vloss, acc, vacc = batch_train(
            model = model,
            device = device,
            train_loader = train_dataloader,
            val_loader = val_dataloader,
            optimizer = optimizer,
            criterion = criterion,
            scheduler = scheduler
        )

        # Scalars are logged against 1-based epoch numbers
        writer.add_scalars('Loss', {'Training' : avg_loss, 'Validation' : avg_vloss}, epoch+1)
        writer.add_scalars('Accuracy', {'Training' : acc, 'Validation': vacc}, epoch+1)
        writer.flush()

        # Model Saving: keep a checkpoint only when validation loss improves
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            patience_counter = 0
            model_path = os.path.join('models', f'resnet_18_ft_{epoch+1}.pth')
            torch.save({
                # Stored 0-based, whereas the filename above uses the 1-based epoch number
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                }, model_path)
        else:
            patience_counter += 1

        if patience_counter >= es_patience:
            print("Early stopping triggered")
            break
finally:
    writer.close()

torch.cuda.empty_cache()