In [None]:
from __future__ import annotations

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import glob
from typing import Callable

import torch
import kornia.geometry as KG
import matplotlib.pyplot as plt
import numpy as np
from torch import nn, Tensor
from torch.utils.data import ConcatDataset, DataLoader

from coin_ai.alignment.data import HPairDataset, HomographyBatch
from coin_ai.alignment.hformer import HFormer, HCorrespondences

In [None]:
def assemble_datasets(
    path: str,
    augmentation: Callable[[HomographyBatch, ], HomographyBatch],
    infer: bool = True,
    skip_identity: bool = False,
) -> ConcatDataset:
    """Build one dataset out of every ``homographies.csv`` found under ``path``.

    Args:
        path: Root directory searched recursively (including ``path`` itself)
            for files named ``homographies.csv``; each match becomes one
            ``HPairDataset``.
        augmentation: Callable forwarded to every ``HPairDataset``.
        infer: Forwarded unchanged to ``HPairDataset``.
        skip_identity: Forwarded unchanged to ``HPairDataset``.

    Returns:
        A ``ConcatDataset`` over all discovered datasets, in sorted path order.

    Raises:
        FileNotFoundError: If no ``homographies.csv`` exists under ``path``.
    """
    # glob() yields files in filesystem order, which differs between machines;
    # sorting keeps dataset indices (and therefore training runs) reproducible.
    csv_paths = sorted(
        glob.glob(os.path.join(path, "**", "homographies.csv"), recursive=True)
    )
    if not csv_paths:
        # ConcatDataset would otherwise fail with an opaque AssertionError.
        raise FileNotFoundError(f"no homographies.csv found under {path!r}")

    # Distinct loop name so the `path` argument is not shadowed.
    datasets = [
        HPairDataset(
            csv_path,
            augmentation=augmentation,
            skip_identity=skip_identity,
            infer=infer,
        )
        for csv_path in csv_paths
    ]
    return ConcatDataset(datasets)

In [None]:
import kornia.color as KC

def train_augmentation(batch: HomographyBatch) -> HomographyBatch:
    """Training-time augmentation: align, randomly perturb, then drop colour.

    Applies the batch's own alignment transform, then a random 4-point
    homography perturbation (``scale=0.05``). It then converts the images to
    grayscale and replicates that single channel three times. The tensor
    therefore keeps a 3-channel layout while carrying no colour information.

    Args:
        batch: Input pair batch.

    Returns:
        The augmented batch with its ``images`` replaced by the 3-channel
        grayscale version.
    """
    builder = batch.build_augmentation()
    align = batch.get_alignment_transform()
    jitter = builder.random_h_4_point(scale=0.05)
    augmented = builder.apply(align).apply(jitter).build()

    # repeat(1, 1, 3, 1, 1) tiles the channel axis only; this assumes images
    # are laid out as (pair, batch, channel, H, W) — TODO confirm.
    gray = KC.rgb_to_grayscale(augmented.images)
    gray_three_channel = gray.repeat(1, 1, 3, 1, 1)
    return augmented._replace(images=gray_three_channel)

In [None]:
def homography_loss(homography_batch: HomographyBatch, model: Callable[[Tensor, ], HCorrespondences]) -> Tensor:
    """Mean squared error between predicted and ground-truth corner positions in image B.

    The model is run on ``homography_batch.images``. Its ``corners_a`` are
    mapped through the ground-truth homography ``H_12`` to form the target,
    which is compared against the model's own ``corners_b``. Both sides are
    divided by ``images.shape[-1]`` (presumably the image width — TODO
    confirm), keeping the loss on a resolution-independent scale.

    Args:
        homography_batch: Batch providing ``images`` and ``H_12``.
        model: Callable mapping the images tensor to ``HCorrespondences``.

    Returns:
        Scalar MSE loss tensor.
    """
    out = model(homography_batch.images)
    # NOTE(review): if corners_a carries gradients, so does this target;
    # detach it if corners_a is meant to be a fixed query set.
    target_b = KG.linalg.transform_points(homography_batch.H_12, out.corners_a)

    norm = homography_batch.images.shape[-1]
    return nn.functional.mse_loss(out.corners_b / norm, target_b / norm)

In [None]:
# Root directory searched for homographies.csv files. Set COIN_HOMOGRAPHY_DATA to
# point elsewhere; the default is the original author's local path.
path = os.environ.get('COIN_HOMOGRAPHY_DATA', '/Users/jatentaki/Data/archeo/coins/krzywousty-homographies')

# Seed torch's global RNG so model initialisation, DataLoader shuffling and any
# torch-RNG-based augmentation are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Prefer MPS (the original target), then CUDA, then CPU, so the notebook also
# runs on machines without Apple MPS instead of failing at the first .to(device).
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

hformer = HFormer(d_target=128).to(device)
optim = torch.optim.AdamW(hformer.parameters(), lr=1e-4, weight_decay=1e-3)

dataset = assemble_datasets(path, augmentation=train_augmentation, skip_identity=True)
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=HPairDataset.collate_fn)

# Number of trainable parameters.
n = sum(p.numel() for p in hformer.parameters() if p.requires_grad)
print(f"{n:,} parameters")

In [None]:
from tqdm.auto import tqdm

N_EPOCHS = 5

losses = []  # mean training loss per epoch
# Make sure dropout/normalisation layers are in training mode, even if an
# earlier cell left the model in eval mode.
hformer.train()
for epoch in range(N_EPOCHS):
    epoch_losses = []
    for batch in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{N_EPOCHS}"):
        optim.zero_grad()
        batch = batch.to(device)
        loss = homography_loss(batch, hformer)
        loss.backward()
        optim.step()
        epoch_losses.append(loss.item())
    losses.append(np.mean(epoch_losses))
    print(f"epoch {epoch + 1}: mean loss {losses[-1]:.6f}")

fig, ax = plt.subplots()
ax.plot(losses, marker='o')
ax.set(xlabel='epoch', ylabel='mean training loss', title='HFormer training loss')
plt.show()

In [None]:
# One batch drawn from the (augmented) training loader, used for visual inspection.
test_batch = next(iter(train_dataloader))

# Run the forward pass in eval mode so dropout/normalisation layers behave
# deterministically, then restore the previous mode so re-running the training
# cell afterwards still trains in training mode.
was_training = hformer.training
hformer.eval()
try:
    with torch.no_grad():
        prediction_test = hformer(test_batch.images.to(device))
finally:
    hformer.train(was_training)

In [None]:
def plot(batch: HomographyBatch, prediction: HCorrespondences) -> None:
    """Draw predicted and ground-truth corner quadrilaterals for each sample in ``batch``.

    One figure with two panels is created per sample ``s``:

    - Left panel: the first image of the pair (``batch.images[0, s]``), with
      ``prediction.corners_a`` in red.
    - Right panel: the second image (``batch.images[1, s]``), with the
      predicted ``corners_b`` in red and ``corners_a`` mapped through
      ``batch.H_12`` (the ground truth) in green.

    Args:
        batch: Batch providing ``images``, ``H_12`` and ``B``.
        prediction: Model output providing ``corners_a`` and ``corners_b``.
    """
    # detach() so this also works on outputs computed with autograd enabled;
    # matplotlib cannot convert tensors that require grad to numpy.
    corners_a = prediction.corners_a.detach().cpu()
    corners_b = prediction.corners_b.detach().cpu()
    corners_b_gt = KG.linalg.transform_points(batch.H_12.to('cpu'), corners_a)

    # Corner index order that closes each 4-point polygon.
    rep = [0, 1, 2, 3, 0]
    for s in range(batch.B):
        fig, (a1, a2) = plt.subplots(1, 2)
        a1.imshow(batch.images[0, s].permute(1, 2, 0).cpu().numpy())
        a2.imshow(batch.images[1, s].permute(1, 2, 0).cpu().numpy())
        a1.plot(corners_a[s, :, 0][rep], corners_a[s, :, 1][rep], 'r--', label='corners_a')
        a2.plot(corners_b[s, :, 0][rep], corners_b[s, :, 1][rep], 'r--', label='predicted')
        a2.plot(corners_b_gt[s, :, 0][rep], corners_b_gt[s, :, 1][rep], 'g--', label='ground truth')
        a1.set_title(f'sample {s}: image A')
        a2.set_title(f'sample {s}: image B')
        a2.legend(loc='upper right', fontsize='small')

plot(test_batch, prediction_test)