In [None]:
import src

from src.runner import Runner
import torch


In [None]:
# Shared configuration reused by every training run in this notebook.

# Data-loading settings passed to each Runner as `dataloader_cfg`.
loading_cfg = dict(
    batch_size=512,  # Batch size for training
    num_workers=4,  # Number of workers for data loading
)

# Optimizer settings passed to each Runner as `optim`.
optim_cfg = dict(
    type='Adam',  # Optimizer type
    lr=0.001,  # Learning rate for the optimizer
    weight_decay=1e-4,  # Weight decay for regularization
)

# Backbone settings shared by all model configs in this notebook.
# NOTE(review): every head config in this notebook uses idims=64, the same value
# as `odims` here -- keep the two in sync if they are meant to match.
backbone_cfg = dict(
    type='ResNet',
    idims=3,  # Input dimensions (e.g., RGB image)
    odims=64,  # Output dimensions (e.g., feature size)
    base_dims=12,  # Base dimensions for the ResNet architecture
    arch=[2, 2, 2, 2],  # Number of blocks in each ResNet layer
    dropout=0.2,  # Dropout rate for regularization
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# The availability check lives in the `torch.cuda` submodule; there is no
# top-level `torch.is_cuda_available`, so calling it raises AttributeError.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Train Vanilla EuroSAT

In [None]:
# EuroSAT dataset config (if not available in Kagglehub cache, it will be downloaded).
# Transforms are listed in order: Resize to (128, 128), then ToTensor.
eurosat_cfg = {
    'type': 'EuroSATDataset',
    'transform': [
        {'type': 'Resize', 'size': (128, 128)},
        {'type': 'ToTensor'},
    ],
}

# Model config: the shared backbone plus an FFN head sized for EuroSAT.
eurosat_model_cfg = {
    'type': 'EuroSATModel',
    'backbone_cfg': backbone_cfg,
    'head_cfg': {
        'type': 'FFN',
        'idims': 64,
        'odims': 10,  # EuroSAT has 10 classes
        'hidden_dims': 1024,
        'nlayers': 6,
        'dropout': 0.2,
    },
}

# Build the runner for the baseline (no pre-training) EuroSAT experiment and
# train it; results go under 'results/EuroSAT'.
eurosat = Runner(
    model=eurosat_model_cfg,
    dataloader_cfg=loading_cfg,
    dataset=eurosat_cfg,
    optim=optim_cfg,
    device=device,
    work_dir='results/EuroSAT',
)
eurosat.run(
    mode='train',
    val_interval=1,
    log_interval=1,
    epochs=100,
    start_epoch=1,
)

# Train Tiny ImageNet

In [None]:
# Tiny ImageNet dataset config (if not available in Kagglehub cache, it will be downloaded).
# Transforms are listed in order: Resize to (128, 128), then ToTensor.
imagenet_cfg = {
    'type': 'ImageNetDataset',
    'transform': [
        {'type': 'Resize', 'size': (128, 128)},
        {'type': 'ToTensor'},
    ],
}

# Model config: the shared backbone plus an FFN head sized for Tiny ImageNet.
# NOTE(review): the model type is 'EuroSATModel' even though the dataset is
# Tiny ImageNet -- presumably a generic backbone+head wrapper; confirm.
imagenet_model_cfg = {
    'type': 'EuroSATModel',
    'backbone_cfg': backbone_cfg,
    'head_cfg': {
        'type': 'FFN',
        'idims': 64,
        'odims': 200,  # Tiny ImageNet has 200 classes
        'hidden_dims': 1024,
        'nlayers': 6,
        'dropout': 0.2,
    },
}

# Build the runner for the Tiny ImageNet experiment and train it; results go
# under 'results/TinyImageNet'.
imagenet = Runner(
    model=imagenet_model_cfg,
    dataloader_cfg=loading_cfg,
    dataset=imagenet_cfg,
    optim=optim_cfg,
    device=device,
    work_dir='results/TinyImageNet',
)
imagenet.run(
    mode='train',
    val_interval=1,
    log_interval=1,
    epochs=100,
    start_epoch=1,
)

# Train EuroSAT with transfer learning from the Tiny ImageNet backbone

In [None]:
# EuroSAT dataset config (if not available in Kagglehub cache, it will be downloaded).
# Transforms are listed in order: Resize to (128, 128), then ToTensor.
eurosat_cfg = {
    'type': 'EuroSATDataset',
    'transform': [
        {'type': 'Resize', 'size': (128, 128)},
        {'type': 'ToTensor'},
    ],
}

# Model config for the transfer-learning run: same backbone and EuroSAT-sized
# head as the baseline, plus a `ckpt` entry pointing at a trained checkpoint.
tfl_model_cfg = {
    'type': 'EuroSATModel',
    'backbone_cfg': backbone_cfg,
    'head_cfg': {
        'type': 'FFN',
        'idims': 64,
        'odims': 10,  # EuroSAT has 10 classes
        'hidden_dims': 1024,
        'nlayers': 6,
        'dropout': 0.2,
    },
    'ckpt': {
        # Reads the best checkpoint path from the `imagenet` runner, so that
        # runner must exist in the current kernel.
        # Can be replaced with a path to a pre-trained model.
        'path': imagenet.best_model_path,
        'load_head': False,
        'load_backbone': True,
        # NOTE(review): confirm how `strict=True` is interpreted when only the
        # backbone is requested (load_head=False).
        'strict': True,
    },
}

# Build the runner for the transfer-learning experiment and train it; results
# go under 'results/tfl_eurosat'.
tfl_eurosat = Runner(
    model=tfl_model_cfg,
    dataloader_cfg=loading_cfg,
    dataset=eurosat_cfg,
    optim=optim_cfg,
    device=device,
    work_dir='results/tfl_eurosat',
)
tfl_eurosat.run(
    mode='train',
    val_interval=1,
    log_interval=1,
    epochs=100,
    start_epoch=1,
)