In [None]:
import os
# Set before `import torch` so it is in place when torch initialises; keep this
# ordering. With '1', PyTorch uses the CPU as a fallback for operators that have
# no MPS (Apple Silicon) implementation instead of raising.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import lightning.pytorch as pl
from kornia import augmentation as aug
from torchvision.transforms.v2 import ToDtype

# Project-local modules (coin_ai package).
from coin_ai.models.readout import AttentionReadout, DinoWithHead
from coin_ai.data.memory_data import InMemoryCoinDataset, build_coin_types, DataloaderAdapter, FlipAdapter, MemorySlab
from coin_ai.learner import LightningLearner
from coin_ai.losses import MarginLoss
from coin_ai.metrics import AccuracyMetric

In [None]:
# Augmentation pipelines. Both start by converting images to float32 with value
# scaling (presumably uint8 -> [0, 1]) and end with an always-on grayscale step
# (p=1.0); only the steps in between differ between train and val.
_CROP_SIZE = (224, 224)


def _coin_augmentation(*middle_steps):
    """Wrap `middle_steps` between the shared dtype-conversion and grayscale steps."""
    return aug.container.AugmentationSequential(
        ToDtype(torch.float32, scale=True),
        *middle_steps,
        aug.RandomGrayscale(p=1.0),
    )


# Training: strong perspective warp, crop keeping 75-100% of the area,
# mild brightness/contrast/saturation jitter.
train_augmentation = _coin_augmentation(
    aug.RandomPerspective(distortion_scale=0.75),
    aug.RandomResizedCrop(_CROP_SIZE, same_on_batch=False, scale=(0.75, 1.0)),
    aug.ColorJiggle(0.2, 0.2, 0.2),
)

# Validation: near-full-frame crop only.
# NOTE(review): this crop is still random (area scale and kornia's default
# aspect-ratio range), so validation inputs vary unless the RNG is seeded.
val_augmentation = _coin_augmentation(
    aug.RandomResizedCrop(_CROP_SIZE, same_on_batch=False, scale=(0.98, 1.0)),
)

In [None]:
# Runtime configuration for data loading.
SEED = 0
BATCH_SIZE = 32

# Seed python/numpy/torch RNGs so augmentation sampling and model initialisation
# in the following cells are reproducible under Restart & Run All.
pl.seed_everything(SEED)

# Prefer Apple-Silicon MPS (what this notebook targets, cf. the
# PYTORCH_ENABLE_MPS_FALLBACK setting in the import cell), then CUDA, then CPU,
# instead of hard-coding 'mps' and failing on other machines. The same string is
# handed to DataloaderAdapter and to the Lightning Trainer as `accelerator`.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Root directory of the cropped coin images. Set the COIN_DATA_ROOT environment
# variable to point elsewhere; the default is the original author's local path.
data_root = os.environ.get('COIN_DATA_ROOT', '/Users/jatentaki/Data/archeo/coins/Cropped')
if not os.path.isdir(data_root):
    raise FileNotFoundError(
        f'coin data directory not found: {data_root!r} (set COIN_DATA_ROOT to override)'
    )

coin_types = build_coin_types(data_root)
# One MemorySlab instance is shared by the train and val datasets.
memory_slab = MemorySlab(coin_types)
train_dataset = FlipAdapter(
    InMemoryCoinDataset(
        coin_types, batch_size=BATCH_SIZE, augmentation=train_augmentation, memory_slab=memory_slab
    )
)
val_dataset = InMemoryCoinDataset(
    coin_types, batch_size=BATCH_SIZE, augmentation=val_augmentation, memory_slab=memory_slab
)
val_dataset_flip = FlipAdapter(val_dataset)

# NOTE(review): the meaning of DataloaderAdapter's positional arguments
# (0/1 and 10/1 below) is defined in coin_ai.data.memory_data and is opaque
# from here; consider passing them by keyword. Only the training loader uses
# reseed=True, so validation batches are presumably fixed across epochs; confirm.
train_dataloader = DataloaderAdapter(train_dataset, 0, device, 10, reseed=True)
val_dataloader = DataloaderAdapter(val_dataset, 1, device, 1, reseed=False)
val_dataloader_flip = DataloaderAdapter(val_dataset_flip, 1, device, 1, reseed=False)

In [None]:
# Model: `DinoWithHead` wraps an `AttentionReadout` head (backbone details live
# in coin_ai.models.readout).
# NOTE(review): these constructors may draw from the global RNG (weight init),
# so keep the statement order if runs must stay reproducible.
head = AttentionReadout()
model = DinoWithHead(head)
# The metric is handed the loss's own `similarity`, presumably so validation
# accuracy ranks pairs with the same similarity measure the loss trains on.
loss_fn = MarginLoss()
metric_fn = AccuracyMetric(similarity=loss_fn.similarity)
# Bundle model, loss and metric into the LightningModule passed to `trainer.fit`.
learner = LightningLearner(model, loss_fn, metric_fn)

In [None]:
# Keep the best checkpoint (highest acc@1) for each validation dataloader.
# With several val dataloaders, Lightning suffixes logged metric names with
# `/dataloader_idx_<i>`, where <i> is the position in the `val_dataloaders`
# passed to `fit` below: 0 = plain val set, 1 = flipped val set.
_checkpoint_specs = (
    ('val/acc_at_1/dataloader_idx_0', 'val_0_acc_at_1'),
    ('val/acc_at_1/dataloader_idx_1', 'val_1_acc_at_1'),
)
save_callback_0, save_callback_1 = (
    pl.callbacks.ModelCheckpoint(monitor=metric_name, mode='max', filename=ckpt_name)
    for metric_name, ckpt_name in _checkpoint_specs
)

trainer = pl.Trainer(
    accelerator=device,
    max_epochs=10,
    callbacks=[save_callback_0, save_callback_1],
)
# The order of this tuple must match the dataloader_idx_* monitors above.
trainer.fit(learner, train_dataloader, (val_dataloader, val_dataloader_flip))