In [None]:
from __future__ import annotations

import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
import lightning as pl
import kornia.geometry as KG
import matplotlib.pyplot as plt
from torch import Tensor

from coin_ai.alignment.data import HomographyBatch, CoinDataModule, AugmentationBuilder
from coin_ai.alignment.hformer import HFormer, HCorrespondences

In [None]:
def plot_predictions(
    batch: HomographyBatch, correspondences: HCorrespondences
) -> list[plt.Figure]:
    """Render predicted vs. ground-truth corner quadrilaterals, one figure per sample.

    For each sample ``s`` a figure with two panels is produced:

    * left: ``batch.images[s, 0]`` with ``correspondences.corners_a`` (red, dashed);
    * right: ``batch.images[s, 1]`` with the predicted ``correspondences.corners_b``
      (red, dashed) and ``corners_a`` mapped through the ground-truth homography
      ``batch.H_12`` (green, dashed).

    Args:
        batch: batch providing ``images``, ``H_12`` and the batch size ``B``.
            ``images[s, k]`` is permuted (1, 2, 0) before display, so it is
            presumably channel-first (C, H, W) — TODO confirm.
        correspondences: model output providing ``corners_a`` / ``corners_b``,
            indexed here as (B, 4, 2) with (x, y) in the last dim — TODO confirm.

    Returns:
        One matplotlib figure per sample. The caller owns the figures and should
        close them (e.g. ``plt.close(fig)``) once they are rendered or logged.
    """
    # detach() so this also works outside torch.no_grad(): converting a tensor
    # that requires grad to numpy (which matplotlib does) raises a RuntimeError.
    corners_a = correspondences.corners_a.detach().cpu()
    corners_b = correspondences.corners_b.detach().cpu()
    corners_b_gt = KG.linalg.transform_points(batch.H_12.detach().cpu(), corners_a)

    # Repeat the first corner so each quadrilateral is drawn as a closed loop.
    rep = [0, 1, 2, 3, 0]
    closed_a = corners_a[:, rep]
    closed_b = corners_b[:, rep]
    closed_b_gt = corners_b_gt[:, rep]

    figures = []
    for s in range(batch.B):
        fig, (a1, a2) = plt.subplots(1, 2)
        a1.imshow(batch.images[s, 0].permute(1, 2, 0).detach().cpu().numpy())
        a2.imshow(batch.images[s, 1].permute(1, 2, 0).detach().cpu().numpy())
        a1.plot(closed_a[s, :, 0], closed_a[s, :, 1], "r--")
        a2.plot(closed_b[s, :, 0], closed_b[s, :, 1], "r--")
        a2.plot(closed_b_gt[s, :, 0], closed_b_gt[s, :, 1], "g--")
        a1.set_title("image A")
        a2.set_title("image B (red: pred, green: GT)")
        a1.grid(True)
        a2.grid(True)

        figures.append(fig)

    return figures

In [None]:
class CoinLearner(pl.LightningModule):
    """Lightning module that trains an :class:`HFormer` homography estimator.

    Dataloader items are :class:`AugmentationBuilder` objects. They are turned
    into :class:`HomographyBatch` instances by :meth:`build_batch` before the
    loss is computed.
    """

    def __init__(self, model: HFormer, lr: float = 1e-4):
        """
        Args:
            model: the network being trained; it must provide ``loss(batch)``.
            lr: learning rate for the Adam optimizer.
        """
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, batch: HomographyBatch) -> HCorrespondences:
        """Run the model on the image pairs of ``batch`` and return its correspondences."""
        return self.model(batch.images)

    def build_batch(self, batch: AugmentationBuilder) -> HomographyBatch:
        """Materialise ``batch`` on CPU, then move the result to this module's device.

        NOTE(review): building on CPU is presumably deliberate (e.g. ops missing on
        MPS, cf. PYTORCH_ENABLE_MPS_FALLBACK) — confirm before changing.
        """
        return batch.to("cpu").build().to(self.device)

    def training_step(self, raw_batch: AugmentationBuilder, batch_idx: int) -> Tensor:
        """Compute and log the training loss for one batch."""
        batch = self.build_batch(raw_batch)
        loss = self.model.loss(batch)
        self.log("train_loss", loss, batch_size=batch.B)
        return loss

    def validation_step(self, raw_batch: AugmentationBuilder, batch_idx: int) -> Tensor:
        """Compute and log the validation loss, plus prediction figures when possible."""
        batch = self.build_batch(raw_batch)
        loss = self.model.loss(batch)
        self.log("val_loss", loss, batch_size=batch.B)

        self._log_prediction_figures(batch, batch_idx)

        return loss

    def _log_prediction_figures(self, batch: HomographyBatch, batch_idx: int) -> None:
        """Plot predictions for ``batch`` and send them to the logger's ``add_figure``.

        Does nothing when the attached logger exposes no ``experiment.add_figure``
        (e.g. the trainer was created with ``logger=False``). This avoids both an
        AttributeError and a forward pass whose output would be discarded.
        """
        experiment = getattr(self.logger, "experiment", None)
        add_figure = getattr(experiment, "add_figure", None)
        if add_figure is None:
            return

        correspondences = self(batch)
        for i, fig in enumerate(plot_predictions(batch, correspondences)):
            add_figure(f"val_{batch_idx}_{i}", fig, self.current_epoch)
            # Release the figure explicitly; closing an already-closed figure is
            # a no-op, so this is safe whether or not the logger closed it.
            plt.close(fig)

    def configure_optimizers(self):
        """Optimise all parameters of the wrapped model with Adam at ``self.lr``."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [None]:
SEED = 0

# Root of the homography dataset. Defaults to the original author's local copy;
# set the COIN_DATA_ROOT environment variable to use another location.
path = os.environ.get(
    "COIN_DATA_ROOT", "/Users/jatentaki/Data/archeo/coins/krzywousty-homographies"
)
# A single die subset, used for both training and validation below.
die_dir = f"{path}/split/train/Awers - stempel a.07"

# Seed before building the model so weight init and augmentations are reproducible.
pl.seed_everything(SEED)

hformer = HFormer(d_target=128)
learner = CoinLearner(hformer)

# NOTE(review): val_root points at the *training* directory, so val_loss measures
# fit on seen data rather than generalisation. Confirm this is intended (e.g. an
# overfitting sanity check) before reading val_loss as a held-out metric.
data_module = CoinDataModule(
    train_root=die_dir,
    val_root=die_dir,
    train_replicate=64,
    batch_size=8,
    num_workers=0,
)

In [None]:
from lightning.pytorch.loggers import TensorBoardLogger

# TensorBoard receives both the scalar losses and the validation figures that
# CoinLearner logs via `experiment.add_figure`.
logger = TensorBoardLogger("tb_logs", name="my_model")

# "auto" still picks MPS on Apple Silicon (what was previously hard-coded), but
# falls back to CUDA/CPU elsewhere instead of failing on non-Apple machines.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    logger=logger,
)

trainer.fit(learner, data_module)