In [None]:
from __future__ import annotations

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import matplotlib.pyplot as plt
import numpy as np

from coin_ai.alignment.data import HPairDataset, HomographyBatch

In [None]:
def draw_corners(homography_batch: HomographyBatch):
    """Show both images of a single pair side by side with their corners marked.

    The base corners are drawn on the left image; the same corners pushed
    through ``H_12`` (followed by perspective division) are drawn on the right.

    Args:
        homography_batch: a batch holding exactly one pair (``B == 1``).

    Returns:
        The matplotlib figure containing the two panels.
    """
    assert homography_batch.B == 1

    fig, (ax_base, ax_warp) = plt.subplots(1, 2)

    # Squeeze dim 1 (presumably the batch axis, given B == 1) and move channels last for imshow.
    image_pair = homography_batch.images.squeeze(1).permute(0, 2, 3, 1).cpu().numpy()
    img_base, img_warp = image_pair
    ax_base.imshow(img_base)
    ax_warp.imshow(img_warp)

    xy = homography_batch.corners.squeeze(0).cpu().numpy()
    # NOTE(review): this rescale looks tied to one specific image size / corner convention - confirm.
    xy = xy / 2 + 128

    # Homogeneous coordinates: append a column of ones to the four corners.
    corners_base = np.concatenate([xy, np.ones((4, 1))], axis=1)
    homography = homography_batch.H_12.squeeze(0).cpu().numpy()
    corners_warp = corners_base @ homography.T
    # Perspective division brings the warped points back to the image plane.
    corners_warp = corners_warp / corners_warp[:, 2:]

    ax_base.plot(corners_base[:, 0], corners_base[:, 1], 'ro')
    ax_warp.plot(corners_warp[:, 0], corners_warp[:, 1], 'ro')

    return fig

In [None]:
from torch.utils.data import ConcatDataset, DataLoader
import glob

def assemble_datasets(path: str) -> ConcatDataset:
    """Concatenate one ``HPairDataset`` per ``homographies.csv`` found under ``path``.

    The directory tree rooted at ``path`` is searched recursively.

    Args:
        path: root directory to search.

    Returns:
        A ``ConcatDataset`` over all discovered datasets, ordered by CSV path.

    Raises:
        FileNotFoundError: if no ``homographies.csv`` exists anywhere under ``path``.
    """
    # glob returns matches in arbitrary (filesystem-dependent) order; sorting keeps
    # the concatenated dataset's indexing stable across runs and machines.
    csv_paths = sorted(glob.glob(f"{path}/**/homographies.csv", recursive=True))

    # Fail with an actionable message instead of the bare assertion that
    # ConcatDataset raises when given an empty list.
    if not csv_paths:
        raise FileNotFoundError(f"no homographies.csv found under {path!r}")

    return ConcatDataset([HPairDataset(csv_path) for csv_path in csv_paths])

In [None]:
# Dataset root. Set the COIN_HOMOGRAPHY_DIR environment variable to run this
# notebook on another machine; the default is the original local path.
data_root = os.environ.get(
    'COIN_HOMOGRAPHY_DIR',
    '/Users/jatentaki/Data/archeo/coins/krzywousty-homographies',
)
concat_ds = assemble_datasets(data_root)
dataloader = DataLoader(concat_ds, batch_size=8, shuffle=True, collate_fn=HPairDataset.collate_fn)
print(len(concat_ds))
# Grab one batch for the visual sanity checks in the following cells.
batch = next(iter(dataloader))

In [None]:
# Apply the transform returned by batch.get_alignment_transform() to the batch.
# Variants previously tried in this cell (currently disabled):
#   - batch.build_augmentation().resize((256, 256)).build()
#   - chaining .apply(builder.random_h_4_point(scale=0.025)) before .build()
#   - no augmentation at all (with_aug = batch)
builder = batch.build_augmentation()
alignment = batch.get_alignment_transform()
with_aug = builder.apply(alignment).build()

In [None]:
# draw_corners requires B == 1, so feed it one-element slices of the batch.
for pair_idx in range(with_aug.B):
    single_pair = with_aug.slice(pair_idx, pair_idx + 1)
    fig = draw_corners(single_pair)
    # fig.suptitle(f"{pair_idx}")  # enable to label each figure with its index

plt.show()

In [None]:
from dataclasses import dataclass
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import nn, Tensor

from coin_ai.alignment.infra import DenseDino

@dataclass
class HCorrespondences:
    """Two sets of four corresponding corner points, one set per image of a pair."""

    # Corner coordinates in the first image; B x 4 x 2
    corners_a: Tensor
    # Coordinates of the same corners in the second image; B x 4 x 2
    corners_b: Tensor

def homography_loss(homography_batch: HomographyBatch, model: Callable[[Tensor, ], HCorrespondences]) -> Tensor:
    """Mean squared error between predicted and ground-truth warped corners.

    The model is run on ``homography_batch.images``. Its ``corners_a`` are mapped
    through ``homography_batch.H_12`` (whose first two output rows are divided by
    image width and height), de-homogenised, and compared with its ``corners_b``.

    Args:
        homography_batch: provides ``images`` (5-D, with height and width as the
            last two axes) and ``H_12`` (3 x 3 homographies with a leading batch axis).
        model: callable mapping the image tensor to an ``HCorrespondences``.

    Returns:
        Scalar loss tensor.
    """
    images = homography_batch.images
    *_, h, w = images.shape

    # Dividing by a (1, 3, 1) tensor rescales the *rows* of H, i.e. the output x by w and y by h.
    # NOTE(review): only the output side is rescaled; if H_12 maps pixels to pixels while
    # corners_a are in normalised [0, 1] units, the input side needs scaling too - confirm
    # against the convention used by HomographyBatch.
    norm = torch.tensor([w, h, 1], dtype=torch.float32, device=images.device).view(1, 3, 1)
    H_12_norm = homography_batch.H_12 / norm

    predictions = model(images)
    corners_a = predictions.corners_a

    # ones_like follows whatever leading shape corners_a has (1 x 4 x 2 or B x 4 x 2);
    # a fixed (1, 4, 1) tensor cannot be concatenated with a batched corners_a.
    ones = torch.ones_like(corners_a[..., :1])
    corners_a_h = torch.cat([corners_a, ones], dim=-1)

    warped_h = corners_a_h @ H_12_norm.mT
    # Perspective division: a homography yields homogeneous points, so x and y must be
    # divided by the third coordinate (a no-op only for purely affine transforms).
    corners_b_gt = warped_h[..., :2] / warped_h[..., 2:]

    return nn.functional.mse_loss(predictions.corners_b, corners_b_gt)

In [None]:
class QueryInit(nn.Module):
    """Builds four initial query vectors by attention-pooling the source features.

    Learned per-head queries attend over the raw source features (used directly
    as keys) and pool a linear projection of those features as values.
    """

    def __init__(self, d_memory: int, d_target: int):
        super().__init__()
        self.n_heads = 8
        # Learned queries: one d_memory-wide vector per (head, output slot), 4 slots.
        self.q = nn.Parameter(torch.randn(self.n_heads, 4, d_memory))
        self.v = nn.Linear(d_memory, d_target)
        self.o = nn.Linear(d_target, d_target)

    def forward(self, src_features: Tensor) -> Tensor:
        """Map ``src_features`` (b x n x c) to four queries per batch element."""
        batch_size, _n_tokens, _width = src_features.shape

        learned_q = repeat(self.q, 'h q c -> b h q c', b=batch_size)
        # Keys get a singleton head axis: every head attends over the full feature width.
        shared_k = rearrange(src_features, 'b n c -> b 1 n c')
        projected = self.v(src_features)
        per_head_v = rearrange(projected, 'b n (h c) -> b h n c', h=self.n_heads)

        pooled = nn.functional.scaled_dot_product_attention(learned_q, shared_k, per_head_v)
        merged = rearrange(pooled, 'b h n c -> b n (h c)')
        return self.o(merged)

class MLP(nn.Module):
    def __init__(self, d_io: int, d_ff: int | None = None):
        super().__init__()

        if d_ff is None:
            d_ff = 4 * d_io

        self.fc1 = nn.Linear(d_io, d_ff)
        self.fc2 = nn.Linear(d_ff, d_io)
    
    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(nn.functional.gelu(self.fc1(x)))
    
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``target``, keys and values from ``memory``.

    Args:
        d_src: feature width of ``memory``.
        d_tgt: feature width of ``target`` and of the output.
        n_heads: number of attention heads the ``d_tgt`` channels are split into.
    """

    def __init__(self, d_src: int, d_tgt: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.q = nn.Linear(d_tgt, d_tgt)
        # Keys and values come out of one joint projection and are split in forward().
        self.kv = nn.Linear(d_src, 2 * d_tgt)
        self.o = nn.Linear(d_tgt, d_tgt)

    def forward(self, target: Tensor, memory: Tensor) -> Tensor:
        """Attend from ``target`` (b x q x c) over ``memory`` (b x n x d_src)."""
        _batch, _n_queries, _width = target.shape  # target must be rank 3

        heads = self.n_heads
        queries = rearrange(self.q(target), 'b q (h c) -> b h q c', h=heads)
        keys_values = rearrange(self.kv(memory), 'b n (t h c) -> t b h n c', h=heads, t=2)
        keys, values = keys_values

        attended = nn.functional.scaled_dot_product_attention(queries, keys, values)
        merged = rearrange(attended, 'b h n c -> b n (h c)')
        return self.o(merged)

class CrossAttentionBlock(nn.Module):
    """Residual block: cross-attention over ``memory``, then an MLP.

    LayerNorm is applied to the queries before each sub-layer; ``memory`` is
    passed to the attention unnormalised.
    """

    def __init__(self, d_memory: int, d_target: int, n_heads: int, d_ff: int | None = None):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_target)
        self.cross_attention = CrossAttention(d_src=d_memory, d_tgt=d_target, n_heads=n_heads)
        self.norm2 = nn.LayerNorm(d_target)
        self.mlp = MLP(d_target, d_ff)

    def forward(self, queries: Tensor, memory: Tensor) -> Tensor:
        """Update ``queries`` with information attended from ``memory``."""
        after_attention = queries + self.cross_attention(self.norm1(queries), memory)
        return after_attention + self.mlp(self.norm2(after_attention))

class HFormer(nn.Module):
    """Predicts where the unit-square corners of the first image land in the second.

    Features of both images are extracted with a frozen ``DenseDino``. Four queries,
    initialised from the first image's features, are refined by cross-attention
    blocks that alternate between the two images' features, then decoded into
    bounded 2-D offsets from the fixed corners.

    Args:
        n_layers: controls depth; ``2 * n_layers - 1`` cross-attention blocks are built.
        d_target: width of the query vectors.
        deformation_scale: maximum absolute offset per coordinate (multiplies a tanh).
    """

    def __init__(self, n_layers: int = 3, d_target: int = 128, deformation_scale: float = 0.5):
        super().__init__()

        # Frozen backbone: its parameters never receive gradients.
        self.dino = DenseDino()
        self.dino.requires_grad_(False)

        # Width of the memory features - presumably DenseDino's channel count; confirm.
        d_memory = 384
        self.n_heads = 8
        self.query_init = QueryInit(d_memory=d_memory, d_target=d_target)

        n_blocks = 2 * n_layers - 1
        blocks = [
            CrossAttentionBlock(d_memory=d_memory, d_target=d_target, n_heads=self.n_heads)
            for _ in range(n_blocks)
        ]
        self.attn_blocks = nn.ModuleList(blocks)

        self.final_norm = nn.LayerNorm(d_target)
        self.xy_head = nn.Linear(d_target, 2, bias=False)

        # Fixed reference corners: the unit square, with a leading broadcast axis.
        unit_square = torch.tensor(
            [[0., 0.], [1., 0.], [1., 1.], [0., 1.]],
            dtype=torch.float32,
        )
        self.register_buffer('corners_a', unit_square.reshape(1, 4, 2))
        self.deformation_scale = deformation_scale

    def forward(self, images: Tensor) -> HCorrespondences:
        """Predict corner correspondences for ``images`` (2 x B x C x H x W).

        Returns:
            ``corners_a`` is the fixed 1 x 4 x 2 unit square; ``corners_b`` is
            B x 4 x 2, the unit square plus the predicted offsets.
        """
        src_feat, dst_feat = self._get_features(images)
        q = self.query_init(src_feat)

        # Even-indexed blocks look at the second image, odd-indexed ones at the first.
        for block_idx, block in enumerate(self.attn_blocks):
            memory = dst_feat if block_idx % 2 == 0 else src_feat
            q = block(q, memory)

        # tanh bounds each offset coordinate to (-1, 1) before scaling.
        normed = self.final_norm(q)
        offset = torch.tanh(self.xy_head(normed))

        corners_b = self.corners_a + self.deformation_scale * offset
        return HCorrespondences(corners_a=self.corners_a, corners_b=corners_b)

    def _get_features(self, images: Tensor) -> tuple[Tensor, Tensor]:
        """
        Run the backbone on both images of every pair, without tracking gradients.

        Images: 2 x B x C x H x W

        Returns:
            src_feat: B x N x C
            dst_feat: B x N x C
        """
        n_views, batch_size, channels, _height, _width = images.shape
        assert n_views == 2
        assert channels == 3

        # Fold the pair axis into the batch so the backbone runs once.
        stacked = rearrange(images, 't b c h w -> (t b) c h w')
        with torch.no_grad():
            stacked_features: Tensor = self.dino(stacked)

        # Unfold the pair axis again and flatten the spatial grid into tokens.
        per_view = rearrange(stacked_features, '(t b) h w c -> t b (h w) c', b=batch_size, t=2)
        src_feat, dst_feat = per_view
        return src_feat, dst_feat

    def loss_fn(self, homography_batch: HomographyBatch) -> Tensor:
        """Evaluate ``homography_loss`` for this model on ``homography_batch``."""
        return homography_loss(homography_batch, self)

In [None]:
# Prefer Apple's MPS backend, then CUDA, and fall back to CPU so the notebook
# also runs on machines where MPS is unavailable.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

hformer = HFormer().to(device)
optim = torch.optim.AdamW(hformer.parameters(), lr=1e-4, weight_decay=1e-3)

In [None]:
# Count only parameters that will be trained (requires_grad is set).
trainable_params = [p for p in hformer.parameters() if p.requires_grad]
n = sum(p.numel() for p in trainable_params)
print(f"{n:,} parameters")

In [None]:
# One pass over the dataloader; stop early once a batch's loss drops below 0.01.
for batch in dataloader:
    batch = batch.to(device)

    optim.zero_grad()
    loss = hformer.loss_fn(batch)
    loss.backward()
    optim.step()

    # Read the scalar once and reuse it for logging and the stopping check.
    loss_value = loss.item()
    print(loss_value)

    if loss_value < 0.01:
        break