In [None]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import lightning.pytorch as pl
from kornia import augmentation as aug
from torchvision.transforms.v2 import ToDtype

from coin_ai.models.readout import AttentionReadout, DinoWithHead
from coin_ai.data.memory_data import InMemoryCoinDataset, build_coin_types
from coin_ai.losses import MarginLoss
from coin_ai.metrics import AccuracyMetric

In [None]:
# Both pipelines share the same bookends: convert to float32 first (ToDtype with
# scale=True rescales integer inputs to [0, 1]) and always force grayscale last.
# Only the augmentations in between differ.
def _gray_pipeline(*middle):
    """Wrap the ``middle`` augmentations between dtype conversion and forced grayscale."""
    return aug.container.AugmentationSequential(
        ToDtype(torch.float32, scale=True),
        *middle,
        aug.RandomGrayscale(p=1.0),
    )


train_augmentation = _gray_pipeline(
    aug.RandomPerspective(distortion_scale=0.75),
    aug.RandomResizedCrop((224, 224), same_on_batch=False, scale=(0.75, 1.0)),
    aug.ColorJiggle(0.2, 0.2, 0.2),
)

val_augmentation = _gray_pipeline(
    # NOTE(review): this crop is still random (scale and aspect ratio are sampled),
    # so validation inputs vary between passes; confirm that is intended.
    aug.RandomResizedCrop((224, 224), same_on_batch=False, scale=(0.98, 1.0)),
)

In [None]:
class LightningLearner(pl.LightningModule):
    """Lightning wrapper that trains an embedding model with a pluggable loss and metric.

    Args:
        model: module mapping a batch of images to embeddings. Only its trainable
            parameters are handed to the optimizer.
        loss_fn: callable ``(embeddings, labels) -> loss``, used in ``training_step``.
        metric_fn: callable ``(embeddings, labels) -> mapping``, used in
            ``validation_step``. Every entry is logged as ``val/<key>``.
        lr: AdamW learning rate (the default is the previously hard-coded value).
        weight_decay: weight decay for the decayed parameter group (the default is
            the previously hard-coded value).
    """

    def __init__(self, model, loss_fn, metric_fn, lr=1e-4, weight_decay=1e-2):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.metric_fn = metric_fn
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, x):
        """Return the wrapped model's embeddings for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one ``(images, labels)`` batch."""
        images, labels = batch
        embeddings = self(images)
        loss = self.loss_fn(embeddings, labels)
        self.log('train/loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log every metric returned by ``metric_fn`` under a ``val/`` prefix."""
        images, labels = batch
        embeddings = self(images)
        metrics = self.metric_fn(embeddings, labels)
        self.log_dict({f'val/{k}': v for k, v in metrics.items()})

    @staticmethod
    def _uses_weight_decay(name, param):
        """Return True if ``param`` belongs in the weight-decayed group.

        A parameter is decayed only if its name contains ``weight`` AND it has at
        least two dimensions. The dimension check keeps 1-D tensors out of the
        decayed group: norm-layer scales, and any bias whose module name happens
        to contain ``weight``. A plain substring test would let both in.
        """
        return 'weight' in name and param.ndim >= 2

    def configure_optimizers(self):
        """Build AdamW with separate decayed / non-decayed groups of trainable params."""
        decay_params, non_decay_params = [], []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue  # parameters with requires_grad=False are never optimised
            if self._uses_weight_decay(name, param):
                decay_params.append(param)
            else:
                non_decay_params.append(param)

        param_groups = [
            {'params': decay_params, 'weight_decay': self.weight_decay},
            {'params': non_decay_params, 'weight_decay': 0.0},
        ]
        # Skip empty groups. If nothing is trainable, AdamW then raises its own
        # "empty parameter list" error instead of silently optimising nothing.
        param_groups = [group for group in param_groups if group['params']]
        return torch.optim.AdamW(param_groups, lr=self.lr)

In [None]:
# Pick the best available accelerator. On the original Apple-silicon setup this is
# still 'mps', but the notebook now also runs on CUDA or CPU-only machines.
# The same string is used for the Lightning Trainer and for moving batches.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Set COIN_DATA_ROOT to point at the cropped coin images on another machine.
# The default is the original author's local path.
data_root = os.environ.get('COIN_DATA_ROOT', '/Users/jatentaki/Data/archeo/coins/Cropped')
if not os.path.exists(data_root):
    raise FileNotFoundError(f'coin data not found at {data_root!r}; set COIN_DATA_ROOT')

BATCH_SIZE = 32
coin_types = build_coin_types(data_root)
# NOTE(review): train and val are built from the same coin_types. Unless
# InMemoryCoinDataset splits internally (not visible here), validation scores the
# training data under different augmentation rather than a held-out set.
train_dataset = InMemoryCoinDataset(coin_types, batch_size=BATCH_SIZE, augmentation=train_augmentation)
val_dataset = InMemoryCoinDataset(coin_types, batch_size=BATCH_SIZE, augmentation=val_augmentation)

from dataclasses import dataclass, field


@dataclass
class DatasetAdapter:
    """Make an ``InMemoryCoinDataset`` look like a dataloader to Lightning.

    Lightning only needs ``iter()`` and ``len()``. Iteration is delegated to
    ``dataset.iterate(seed, n_batches, device)``, and the length is ``n_batches``.

    Fields:
        dataset: the dataset to draw batches from.
        init_seed: seed handed to ``dataset.iterate`` on the first ``iter()`` call.
        device: device passed through to ``dataset.iterate``. This notebook
            passes a string such as ``'mps'``.
        n_batches: number of batches per pass (one Lightning epoch).
        vary_seed_per_epoch: when True, the k-th (0-based) ``iter()`` call uses
            seed ``init_seed + k``. When False (the default), every call reuses
            ``init_seed``, as the original adapter did.
    """
    dataset: InMemoryCoinDataset
    init_seed: int
    device: 'str | torch.device'
    n_batches: int
    vary_seed_per_epoch: bool = False
    # Number of iter() calls so far; internal bookkeeping, not a constructor arg.
    _n_passes: int = field(default=0, init=False, repr=False, compare=False)

    def __iter__(self):
        if self.vary_seed_per_epoch:
            seed = self.init_seed + self._n_passes
        else:
            seed = self.init_seed
        self._n_passes += 1
        return self.dataset.iterate(seed, self.n_batches, self.device)

    def __len__(self):
        return self.n_batches


# NOTE(review): assuming `iterate` is deterministic in its seed, a fixed training
# seed replays identical batches every epoch, so the train adapter varies it per
# pass. Train seeds start at 2, so they never coincide with the fixed validation
# seed (1). The val seed stays fixed so val metrics are comparable across epochs.
train_dataloader = DatasetAdapter(train_dataset, 2, device, 100, vary_seed_per_epoch=True)
val_dataloader = DatasetAdapter(val_dataset, 1, device, 100)

In [None]:
# Seed the Python, NumPy and torch global RNGs before building the model. Any
# random initialisation in the head, and later global-RNG draws during training,
# then repeat across runs.
SEED = 0
pl.seed_everything(SEED)

head = AttentionReadout()
model = DinoWithHead(head)
loss_fn = MarginLoss()
# The metric is given the loss's own `similarity`, so both score embeddings with
# the same function.
metric_fn = AccuracyMetric(similarity=loss_fn.similarity)
learner = LightningLearner(model, loss_fn, metric_fn)

In [None]:
# Train for a fixed number of epochs on the selected accelerator. Each epoch draws
# `len(train_dataloader)` batches; validation runs with the val adapter.
trainer = pl.Trainer(max_epochs=10, accelerator=device)
trainer.fit(
    learner,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)