In [1]:
from typing import cast

from mads_datasets.settings import ImgDatasetSettings, FileTypes
from mads_datasets.factories.torchfactories import ImgDataset
from mads_datasets.base import AbstractDatasetFactory, DatasetProtocol
from mads_datasets.datatools import iter_valid_paths
from pydantic import HttpUrl
from pathlib import Path

from torch import nn

from torchvision import transforms
import torch

import hashlib
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
    Any,
    Callable,
    Generic,
    Iterator,
    List,
    Mapping,
    Optional,
    Protocol,
    Sequence,
    Tuple,
    TypeVar,
)

import numpy as np
from loguru import logger

from mads_datasets.datatools import create_headers, get_file
from mads_datasets.settings import DatasetSettings, SecureDatasetSettings

if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"
    logger.warning("This model will take 15-20 minutes on CPU. Consider using accelaration, eg with google colab (see button on top of the page)")
logger.info(f"Using {device}")

eurosatsettings = ImgDatasetSettings(
    dataset_url=cast(
        HttpUrl,
        "https://huggingface.co/datasets/torchgeo/eurosat/resolve/c877bcd43f099cd0196738f714544e355477f3fd/EuroSAT.zip",
    ),
    filename=Path("EuroSAT_RGB.zip"),
    name="EuroSAT_RGB",
    unzip=True,
    formats=[FileTypes.JPG],
    trainfrac=0.8,
    img_size=(64, 64),
    digest="c8fa014336c82ac7804f0398fcb19387",
)

data_transforms = transforms.Compose([
    # transforms.RandomResizedCrop(224),
    # transforms.RandomHorizontalFlip(),
    # transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

class EurosatDatasetFactory(AbstractDatasetFactory[ImgDatasetSettings]):
    def create_dataset(
        self, *args: Any, **kwargs: Any
    ) -> Mapping[str, DatasetProtocol]:
        self.download_data()
        formats = self._settings.formats
        paths_, class_names = iter_valid_paths(
            self.subfolder / "2750", formats=formats 
        )
        paths = [*paths_]
        random.shuffle(paths)
        trainidx = int(len(paths) * self._settings.trainfrac)
        train = paths[:trainidx]
        valid = paths[trainidx:]
        traindataset = ImgDataset(train, class_names, img_size=self._settings.img_size)
        validdataset = ImgDataset(valid, class_names, img_size=self._settings.img_size)
        return {"train": traindataset, "valid": validdataset}

eurosatfactory = EurosatDatasetFactory(eurosatsettings, datadir=Path.home() / ".cache/mads_datasets")

class AugmentPreprocessor():
    def __init__(self, transform):
        self.transform = transform
    def __call__(self, batch: list[tuple]) -> tuple[torch.Tensor, torch.Tensor]:
        X, y = zip(*batch)
        X = [self.transform(x) for x in X]
        return torch.stack(X), torch.stack(y)
    
streamers = eurosatfactory.create_datastreamer(batchsize=32)

trainprocessor= AugmentPreprocessor(data_transforms)
# validprocessor = AugmentPreprocessor(data_transforms["val"])

train = streamers["train"]
valid = streamers["valid"]
train.preprocessor = trainprocessor
valid.preprocessor = trainprocessor
trainstreamer = train.stream()
validstreamer = valid.stream()

import torchvision
from torchvision.models import resnet18, ResNet18_Weights
resnet = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)

[32m2025-05-26 17:56:23.234[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m46[0m - [1mUsing cuda[0m
[32m2025-05-26 17:56:23.235[0m | [1mINFO    [0m | [36mmads_datasets.base[0m:[36mdownload_data[0m:[36m121[0m - [1mFolder already exists at /home/marnix/.cache/mads_datasets/EuroSAT_RGB[0m


In [2]:
image_datasets = {"train" : streamers["train"].dataset,
                    "val" : streamers["valid"].dataset}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = streamers["train"].dataset.class_names
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                                shuffle=True, num_workers=4)
                for x in ['train', 'val']}

In [3]:
from pathlib import Path
import requests
import zipfile
from loguru import logger

from torchvision import datasets

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

In [4]:
def train(model, dataloader, lossfn, optimizer, device):
    model.train()
    train_loss: float = 0.0
    train_acc: float = 0.0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        yhat = model(x)
        loss = lossfn(yhat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, acc = torch.max(yhat, 1)
        train_acc += torch.sum(acc == y.data)
    return train_loss, train_acc

def test(model, dataloader, lossfn, optimizer, scheduler, device):
    model.eval()
    test_loss: float = 0.0
    test_acc: float = 0.0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        yhat = model(x)
        loss = lossfn(yhat, y)
        test_loss += loss.item()
        _, acc = torch.max(yhat, 1)
        test_acc += torch.sum(acc == y.data)
    scheduler.step(test_loss)
    return test_loss, test_acc


def train_model(model, lossfn, optimizer, scheduler, num_epochs, dataloaders, dataset_sizes, device):
    with TemporaryDirectory() as tempdir:
        best_model_params_path = Path(tempdir) / 'best_model_params.pt'
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0 # we will only save the best model

        for epoch in range(num_epochs):
            logger.info(f'Epoch {epoch}/{num_epochs - 1}')
            train_loss, train_acc = train(model, dataloaders['train'], lossfn, optimizer, device)
            train_loss = train_loss / dataset_sizes['train']
            train_acc = train_acc / dataset_sizes['train']
            logger.info(f'Train Loss: {train_loss:.4f} Accuracy: {train_acc:.4f}')
            test_loss, test_acc = test(model, dataloaders['val'], lossfn, optimizer, scheduler, device)
            test_loss = test_loss / dataset_sizes['val']
            test_acc = test_acc / dataset_sizes['val']
            logger.info(f'Test Loss: {test_loss:.4f} Accuracy: {test_acc:.4f}')
            if test_acc > best_acc:
                    best_acc = test_acc
                    logger.info(f"New best accuracy: {best_acc:.4f}, saving model")
                    torch.save(model.state_dict(), best_model_params_path)
        model.load_state_dict(torch.load(best_model_params_path))
        return model
    
def visualize_model(model, num_images=6):
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                title = class_names[preds[j]]
                ax.set_title(f'predicted: {title}')
                imshow(inputs.cpu().data[j], title)

                if images_so_far == num_images:
                    return

In [21]:
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model.fc = nn.Linear(num_ftrs, 10)

model = model.to(device)
lossfn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs

settings = {'step_size' : 7, 'gamma' : 0.1}
scheduler = lr_scheduler.StepLR(optimizer, **settings)

In [22]:
dataloaders["train"]

<torch.utils.data.dataloader.DataLoader at 0x79b1e1f291f0>

In [23]:
layer_groups = [
    [],  # Custom head
    ['layer4'],  # Last ResNet block
    ['layer3'],  # Third ResNet block
    ['layer2'],  # Second ResNet block
    ['layer1', 'conv1', 'bn1']  # First layers
]

In [24]:
for param in model.parameters():
    param.requires_grad = False

In [25]:
for param in model.fc.parameters():
    param.requires_grad = True

In [26]:
import time
def time_convert(sec):
  mins = sec // 60
  sec = sec % 60
  hours = mins // 60
  mins = mins % 60
  return "{0}:{1}:{2}".format(int(hours),int(mins),sec)

In [27]:
for stage, layers_to_unfreeze in enumerate(layer_groups):
    print(f"\nStage {stage + 1}: Unfreezing {layers_to_unfreeze}")
    
    # Unfreeze specified layers
    for layer_name in layers_to_unfreeze:
        layer = getattr(model, layer_name)
    for param in layer.parameters():
        param.requires_grad = True
    
    # Update optimizer to include new parameters
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=1e-3 * (0.5 ** stage)  # Decay learning rate
    )

    start = time.time()
    epochs = 5
    model = train_model(
        model = model,
        lossfn = lossfn,
        optimizer = optimizer,
        scheduler = scheduler,
        num_epochs = epochs,
        dataloaders = dataloaders,
        dataset_sizes = dataset_sizes,
        device = device,
    )
    end = time.time()
    logger.success(f"Done! layer {layers_to_unfreeze} trained for {epochs} epochs in {time_convert(end - start)}")

[32m2025-05-26 18:04:42.211[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 0/4[0m



Stage 1: Unfreezing []


[32m2025-05-26 18:04:55.582[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.3280 Accuracy: 0.5521[0m
[32m2025-05-26 18:04:58.399[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.2118 Accuracy: 0.7522[0m
[32m2025-05-26 18:04:58.400[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew best accuracy: 0.7522, saving model[0m
[32m2025-05-26 18:04:58.431[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 1/4[0m
[32m2025-05-26 18:05:10.969[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.2725 Accuracy: 0.6286[0m
[32m2025-05-26 18:05:13.685[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.1966 Accuracy: 0.7702[0m
[32m2025-05-26 18:05:13.686[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew bes


Stage 2: Unfreezing ['layer4']


[32m2025-05-26 18:06:19.756[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.2208 Accuracy: 0.7028[0m
[32m2025-05-26 18:06:22.631[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.2029 Accuracy: 0.7622[0m
[32m2025-05-26 18:06:22.632[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew best accuracy: 0.7622, saving model[0m
[32m2025-05-26 18:06:22.664[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 1/4[0m
[32m2025-05-26 18:06:44.503[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.1795 Accuracy: 0.7584[0m
[32m2025-05-26 18:06:47.363[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.1833 Accuracy: 0.7698[0m
[32m2025-05-26 18:06:47.364[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew bes


Stage 3: Unfreezing ['layer3']


[32m2025-05-26 18:08:29.423[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.1343 Accuracy: 0.8226[0m
[32m2025-05-26 18:08:32.251[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.1151 Accuracy: 0.8661[0m
[32m2025-05-26 18:08:32.252[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew best accuracy: 0.8661, saving model[0m
[32m2025-05-26 18:08:32.282[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 1/4[0m
[32m2025-05-26 18:08:58.761[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.1241 Accuracy: 0.8410[0m
[32m2025-05-26 18:09:01.611[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.1153 Accuracy: 0.8528[0m
[32m2025-05-26 18:09:01.612[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 2


Stage 4: Unfreezing ['layer2']


[32m2025-05-26 18:10:59.684[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.1010 Accuracy: 0.8705[0m
[32m2025-05-26 18:11:02.552[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.0738 Accuracy: 0.9057[0m
[32m2025-05-26 18:11:02.553[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew best accuracy: 0.9057, saving model[0m
[32m2025-05-26 18:11:02.584[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 1/4[0m
[32m2025-05-26 18:11:30.528[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.0963 Accuracy: 0.8759[0m
[32m2025-05-26 18:11:33.646[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.0736 Accuracy: 0.9122[0m
[32m2025-05-26 18:11:33.647[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew bes


Stage 5: Unfreezing ['layer1', 'conv1', 'bn1']


[32m2025-05-26 18:13:36.496[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.0843 Accuracy: 0.8932[0m
[32m2025-05-26 18:13:39.518[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.0588 Accuracy: 0.9244[0m
[32m2025-05-26 18:13:39.519[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew best accuracy: 0.9244, saving model[0m
[32m2025-05-26 18:13:39.548[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 1/4[0m
[32m2025-05-26 18:14:09.222[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.0822 Accuracy: 0.8948[0m
[32m2025-05-26 18:14:12.383[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.0708 Accuracy: 0.9157[0m
[32m2025-05-26 18:14:12.384[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 2

In [31]:
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model.fc = nn.Linear(num_ftrs, 10)

model = model.to(device)
lossfn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs

settings = {'step_size' : 7, 'gamma' : 0.1}
scheduler = lr_scheduler.StepLR(optimizer, **settings)

In [32]:
for param in model.parameters():
    param.requires_grad = False

In [33]:
for param in model.fc.parameters():
    param.requires_grad = True

In [34]:
start = time.time()
epochs = 25
model = train_model(
    model = model,
    lossfn = lossfn,
    optimizer = optimizer,
    scheduler = scheduler,
    num_epochs = epochs,
    dataloaders = dataloaders,
    dataset_sizes = dataset_sizes,
    device = device,
)
end = time.time()
logger.success(f"Done! layer {layers_to_unfreeze} trained for {epochs} epochs in {time_convert(end - start)}")

[32m2025-05-26 18:30:21.619[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 0/24[0m
[32m2025-05-26 18:30:35.240[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.3678 Accuracy: 0.5613[0m
[32m2025-05-26 18:30:37.991[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.3418 Accuracy: 0.6972[0m
[32m2025-05-26 18:30:37.991[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew best accuracy: 0.6972, saving model[0m
[32m2025-05-26 18:30:38.022[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 1/24[0m
[32m2025-05-26 18:30:50.323[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.3397 Accuracy: 0.5957[0m
[32m2025-05-26 18:30:53.058[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.3850 Accuracy: 0.

In [35]:
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model.fc = nn.Linear(num_ftrs, 10)

model = model.to(device)
lossfn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs

settings = {'step_size' : 7, 'gamma' : 0.1}
scheduler = lr_scheduler.StepLR(optimizer, **settings)

In [36]:
for param in model.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True


In [37]:
start = time.time()
epochs = 25
model = train_model(
    model = model,
    lossfn = lossfn,
    optimizer = optimizer,
    scheduler = scheduler,
    num_epochs = epochs,
    dataloaders = dataloaders,
    dataset_sizes = dataset_sizes,
    device = device,
)
end = time.time()
logger.success(f"Done! layer {layers_to_unfreeze} trained for {epochs} epochs in {time_convert(end - start)}")

[32m2025-05-26 18:37:30.035[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 0/24[0m
[32m2025-05-26 18:38:11.339[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.2396 Accuracy: 0.7038[0m
[32m2025-05-26 18:38:14.824[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.0891 Accuracy: 0.8978[0m
[32m2025-05-26 18:38:14.825[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m52[0m - [1mNew best accuracy: 0.8978, saving model[0m
[32m2025-05-26 18:38:14.858[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m41[0m - [1mEpoch 1/24[0m
[32m2025-05-26 18:38:54.921[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m45[0m - [1mTrain Loss: 0.1343 Accuracy: 0.8315[0m
[32m2025-05-26 18:38:58.017[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_model[0m:[36m49[0m - [1mTest Loss: 0.0840 Accuracy: 0.