In [None]:
from __future__ import annotations

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import matplotlib.pyplot as plt
import numpy as np

from coin_ai.alignment.data import HPairDataset, HomographyBatch

In [None]:
def draw_corners(homography_batch: HomographyBatch):
    """Show the two images of a single pair with the corner quadrilateral overlaid.

    The left panel gets the (rescaled) base corners; the right panel gets the
    same corners mapped through ``H_12`` and normalised by their homogeneous
    coordinate.

    Args:
        homography_batch: a batch containing exactly one pair (``B == 1``).

    Returns:
        The matplotlib figure holding both panels.
    """
    assert homography_batch.B == 1

    figure, (ax_src, ax_dst) = plt.subplots(1, 2)

    # Dropping the singleton batch axis leaves the image-pair axis in front;
    # the permute moves channels last, which is what ``imshow`` expects.
    pair_hwc = homography_batch.images.squeeze(1).permute(0, 2, 3, 1).cpu().numpy()
    img_src, img_dst = pair_hwc
    ax_src.imshow(img_src)
    ax_dst.imshow(img_dst)

    xy = homography_batch.corners.squeeze(0).cpu().numpy()
    # NOTE(review): the halve-and-shift-by-128 presumably maps the stored corner
    # coordinates into the pixel frame of the displayed images — confirm.
    xy = xy / 2 + 128
    xy_h = np.concatenate([xy, np.ones((4, 1))], axis=1)

    H = homography_batch.H_12.squeeze(0).cpu().numpy()
    warped = xy_h @ H.T
    # Back from homogeneous to Cartesian coordinates.
    warped /= warped[:, 2:]

    # Repeat the first vertex so the polygon is drawn closed.
    closed = [0, 1, 2, 3, 0]
    ax_dst.plot(warped[:, 0][closed], warped[:, 1][closed], 'r--')
    ax_src.plot(xy_h[:, 0][closed], xy_h[:, 1][closed], 'r--')

    return figure

In [None]:
from scipy.spatial.transform import Rotation as R

# Fixed test rotation: 10 degrees about the z axis, no rotation about x or y.
# ``r`` is read by the ``augmentation`` function defined further down.
euler_angles_deg = [0, 0, 10]
r = R.from_euler('xyz', euler_angles_deg, degrees=True)
r.as_matrix()

In [None]:
from typing import Callable
from torch.utils.data import ConcatDataset, DataLoader
import glob

import torch

def augmentation(batch: HomographyBatch) -> HomographyBatch:
    """Augment a batch with the notebook-level rotation ``r``.

    Builds a ``2 x 1 x 3 x 3`` stack of matrices — identity in slot 0 and the
    rotation matrix of ``r`` in slot 1 — and applies it through the batch's
    augmentation builder.

    NOTE(review): the two slots presumably line up with the leading (image-pair)
    axis of ``batch.images``, i.e. only the second image is rotated — confirm
    against the builder's ``apply`` contract.

    An alternative pipeline that first applies ``batch.get_alignment_transform()``
    was tried here and is currently disabled.
    """
    rotation = torch.from_numpy(r.as_matrix()).to(torch.float32)
    identity = torch.eye(3)

    # The singleton axis added at position 1 has size 1, unlike the batch axis
    # of the ``2 x 8 x 3 x 3`` stack built later in this notebook.
    per_slot = torch.stack([identity, rotation]).unsqueeze(1)

    builder = batch.build_augmentation()
    return builder.apply(per_slot).build()

def no_aug(batch: HomographyBatch) -> HomographyBatch:
    """Baseline "augmentation": run the batch through its builder with no transform applied.

    The variant that applies ``batch.get_alignment_transform()`` first is
    currently disabled.
    """
    builder = batch.build_augmentation()
    return builder.build()

def assemble_datasets(path: str, augmentation: Callable[[HomographyBatch, ], HomographyBatch]) -> ConcatDataset:
    """Build one concatenated dataset from every ``homographies.csv`` found under ``path``.

    Args:
        path: root directory that is searched recursively.
        augmentation: callable handed to every ``HPairDataset``.

    Returns:
        A ``ConcatDataset`` with one ``HPairDataset`` per CSV file, in sorted
        path order.

    Raises:
        FileNotFoundError: if no ``homographies.csv`` exists below ``path``.
    """
    # glob returns matches in arbitrary (filesystem-dependent) order; sorting
    # makes un-shuffled loaders yield the same samples on every machine and run.
    csv_paths = sorted(glob.glob(f"{path}/**/homographies.csv", recursive=True))

    if not csv_paths:
        # Fail early with an actionable message instead of ConcatDataset's
        # bare assertion on an empty list.
        raise FileNotFoundError(f"no homographies.csv found under {path!r}")

    datasets = [
        HPairDataset(csv_path, augmentation=augmentation, skip_identity=True, infer=False)
        for csv_path in csv_paths
    ]
    return ConcatDataset(datasets)

In [None]:
# Dataset root. Defaults to the original author's local path but can be
# redirected with the COIN_AI_HOMOGRAPHY_DATA environment variable, so the
# notebook runs on other machines without editing this cell.
DATA_ROOT = os.environ.get(
    'COIN_AI_HOMOGRAPHY_DATA',
    '/Users/jatentaki/Data/archeo/coins/krzywousty-homographies',
)
BATCH_SIZE = 8

# Same data twice: once with the rotation augmentation, once without, so the
# two pipelines can be compared side by side.
dataset_aug = assemble_datasets(DATA_ROOT, augmentation=augmentation)
dataset_base = assemble_datasets(DATA_ROOT, augmentation=no_aug)
dataloader_aug = DataLoader(dataset_aug, batch_size=BATCH_SIZE, shuffle=False, collate_fn=HPairDataset.collate_fn)
dataloader_base = DataLoader(dataset_base, batch_size=BATCH_SIZE, shuffle=False, collate_fn=HPairDataset.collate_fn)

# Only the first (un-shuffled) batch of each loader is used below.
batch_aug = next(iter(dataloader_aug))
batch_base = next(iter(dataloader_base))

In [None]:
# for i in range(batch.B):
#     fig = draw_corners(batch.slice(i, i + 1))
# plt.show()

In [None]:
from dataclasses import dataclass

import torch
import kornia.geometry as KG
from einops import rearrange, repeat
from torch import nn, Tensor

from coin_ai.alignment.infra import DenseDino

@dataclass
class HCorrespondences:
    """Four point correspondences per sample between the two images of a pair."""

    corners_a: Tensor # B x 4 x 2
    corners_b: Tensor # B x 4 x 2

    def implied_homography(self) -> Tensor:
        """Return the perspective transform taking ``corners_a`` to ``corners_b``.

        Both corner sets are moved to the CPU before being passed to kornia's
        ``get_perspective_transform``.
        """
        src_points = self.corners_a.cpu()
        dst_points = self.corners_b.cpu()
        return KG.get_perspective_transform(src_points, dst_points)
    
def compute_gt_corners(H_12: Tensor, corners: Tensor) -> Tensor:
    """Map 2-D points through a batch of homographies.

    Args:
        H_12: ``B x 3 x 3`` homography matrices.
        corners: ``B x N x 2`` points (``N`` is 4 everywhere in this notebook);
            a leading size of 1 broadcasts against ``H_12``.

    Returns:
        ``B x N x 2`` Cartesian coordinates of the transformed points.
    """
    # Lift to homogeneous coordinates. The ones column is built from `corners`
    # so it follows that tensor's device and point count.
    ones = torch.ones_like(corners[..., :1], dtype=H_12.dtype)
    corners_h = torch.cat([corners, ones], dim=-1)

    warped = corners_h @ H_12.mT

    # A homography is only defined up to the homogeneous coordinate: divide it
    # out. For affine matrices (last row 0, 0, 1) this is a no-op, but skipping
    # it yields wrong points for any genuinely projective H_12.
    return warped[..., :2] / warped[..., 2:]

def homography_loss(homography_batch: HomographyBatch, model: Callable[[Tensor, ], HCorrespondences]) -> Tensor:
    """Mean-squared error between predicted and ground-truth target corners.

    The model is run on ``homography_batch.images``; its ``corners_a`` are
    mapped through the batch's ``H_12`` with ``compute_gt_corners`` to obtain
    the regression target for its ``corners_b``.

    Returns:
        The MSE loss tensor.
    """
    pred = model(homography_batch.images)
    target_corners = compute_gt_corners(homography_batch.H_12, pred.corners_a)
    loss = nn.functional.mse_loss(pred.corners_b, target_corners)
    return loss

In [None]:
class QueryInit(nn.Module):
    """Produce the four initial query vectors by attention-pooling source features.

    Each of the 8 heads owns 4 learned query vectors of width ``d_memory``.
    They attend over the raw source features (used directly as keys, shared by
    all heads) and read from a linear projection of those features as values.
    The heads are then merged and passed through an output projection.
    """

    def __init__(self, d_memory: int, d_target: int):
        super().__init__()
        self.n_heads = 8
        # Learned queries: heads x 4 x d_memory.
        self.q = nn.Parameter(torch.randn(self.n_heads, 4, d_memory))
        self.v = nn.Linear(d_memory, d_target)
        self.o = nn.Linear(d_target, d_target)

    def forward(self, src_features: Tensor) -> Tensor:
        """Map ``B x N x d_memory`` features to ``B x 4 x d_target`` queries."""
        batch_size, _n_tokens, _n_channels = src_features.shape

        # Keys get a singleton head axis so every head attends over the same features.
        keys = rearrange(src_features, 'b n c -> b 1 n c') # full attention
        values = rearrange(self.v(src_features), 'b n (h c) -> b h n c', h=self.n_heads)
        learned_queries = repeat(self.q, 'h q c -> b h q c', b=batch_size)

        pooled = nn.functional.scaled_dot_product_attention(learned_queries, keys, values)
        merged_heads = rearrange(pooled, 'b h n c -> b n (h c)')
        return self.o(merged_heads)

class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU in between.

    Args:
        d_io: width of both the input and the output.
        d_ff: hidden width; defaults to ``4 * d_io`` when omitted.
    """

    def __init__(self, d_io: int, d_ff: int | None = None):
        super().__init__()

        hidden = 4 * d_io if d_ff is None else d_ff

        self.fc1 = nn.Linear(d_io, hidden)
        self.fc2 = nn.Linear(hidden, d_io)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``fc1 -> GELU -> fc2`` to the last axis of ``x``."""
        hidden_act = nn.functional.gelu(self.fc1(x))
        return self.fc2(hidden_act)
    
class CrossAttention(nn.Module):
    """Multi-head cross-attention from a target sequence onto a memory sequence.

    Queries are projected from the target (width ``d_tgt``); keys and values
    come from a single fused projection of the memory (width ``d_src`` to
    ``2 * d_tgt``). The merged heads go through an output projection.
    """

    def __init__(self, d_src: int, d_tgt: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.q = nn.Linear(d_tgt, d_tgt)
        self.kv = nn.Linear(d_src, 2 * d_tgt)
        self.o = nn.Linear(d_tgt, d_tgt)

    def forward(self, target: Tensor, memory: Tensor) -> Tensor:
        """Attend from ``target`` (B x Q x d_tgt) over ``memory`` (B x N x d_src)."""
        # Unpacking doubles as a check that `target` is rank 3.
        _batch, _n_queries, _width = target.shape

        queries = rearrange(self.q(target), 'b q (h c) -> b h q c', h=self.n_heads)
        # The fused projection is split along its leading `t` axis into keys and values.
        keys, values = rearrange(self.kv(memory), 'b n (t h c) -> t b h n c', h=self.n_heads, t=2)

        attended = nn.functional.scaled_dot_product_attention(queries, keys, values)
        merged_heads = rearrange(attended, 'b h n c -> b n (h c)')
        return self.o(merged_heads)

class CrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross-attention onto a memory, then an MLP.

    Only the queries are layer-normalised before each sub-layer; the memory is
    passed to the attention as given.
    """

    def __init__(self, d_memory: int, d_target: int, n_heads: int, d_ff: int | None = None):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_target)
        self.cross_attention = CrossAttention(d_src=d_memory, d_tgt=d_target, n_heads=n_heads)
        self.norm2 = nn.LayerNorm(d_target)
        self.mlp = MLP(d_target, d_ff)

    def forward(self, queries: Tensor, memory: Tensor) -> Tensor:
        """Return the queries after both residual updates."""
        after_attention = queries + self.cross_attention(self.norm1(queries), memory)
        after_mlp = after_attention + self.mlp(self.norm2(after_attention))
        return after_mlp

class HFormer(nn.Module):
    """Regress four corner correspondences between the two images of a pair.

    Dense features of both images come from a frozen ``DenseDino`` backbone.
    Four query vectors are initialised from the source features and refined by
    a stack of cross-attention blocks that alternate between attending to the
    destination and the source features. A linear head followed by ``tanh``
    turns each query into a 2-D offset, which (scaled by ``deformation_scale``)
    displaces a fixed reference square to give the predicted corners.

    Args:
        n_layers: controls the depth; ``2 * n_layers - 1`` blocks are created.
        d_target: width of the query vectors.
        deformation_scale: multiplier applied to the ``tanh`` offsets, in
            normalised image units.
    """

    def __init__(self, n_layers: int = 3, d_target: int = 128, deformation_scale: float = 0.5):
        super().__init__()

        # Backbone is frozen: no gradients for its parameters.
        self.dino = DenseDino()
        self.dino.requires_grad_(False)

        d_memory = 384
        self.n_heads = 8
        self.query_init = QueryInit(d_memory=d_memory, d_target=d_target)

        n_blocks = 2 * n_layers - 1
        self.attn_blocks = nn.ModuleList([
            CrossAttentionBlock(d_memory=d_memory, d_target=d_target, n_heads=self.n_heads)
            for _ in range(n_blocks)
        ])

        self.final_norm = nn.LayerNorm(d_target)
        self.xy_head = nn.Linear(d_target, 2, bias=False)

        # Reference quadrilateral in normalised coordinates: the unit square
        # shrunk by half and centred, i.e. corners at 0.25 / 0.75.
        unit_square = torch.tensor([
            [0., 0.],
            [1., 0.],
            [1., 1.],
            [0., 1.]
        ], dtype=torch.float32)
        self.register_buffer('corners_a', unit_square.reshape(1, 4, 2) * 0.5 + 0.25)
        self.deformation_scale = deformation_scale

    def forward(self, images: Tensor) -> HCorrespondences:
        """Predict correspondences for ``images`` of shape ``2 x B x C x H x W``.

        Returns:
            ``HCorrespondences`` whose corner tensors are ``B x 4 x 2`` and
            scaled by ``(W, H)`` of the input images.
        """
        batch_size = images.shape[1]

        src_feat, dst_feat = self._get_features(images)
        # Even-indexed blocks attend to the destination, odd-indexed to the source.
        memories = (dst_feat, src_feat)

        q = self.query_init(src_feat)
        for index, block in enumerate(self.attn_blocks):
            q = block(q, memories[index % 2])

        offset = torch.tanh(self.xy_head(self.final_norm(q)))

        height, width = images.shape[-2], images.shape[-1]
        scale = torch.tensor([width, height], dtype=torch.float32, device=images.device)

        reference = self.corners_a.repeat(batch_size, 1, 1) * scale
        displaced = (self.corners_a + self.deformation_scale * offset) * scale
        return HCorrespondences(corners_a=reference, corners_b=displaced)

    def _get_features(self, images: Tensor) -> tuple[Tensor, Tensor]:
        """
        Run the backbone on both images of every pair, without gradients.

        Images: 2 x B x C x H x W

        Returns:
            src_feat: B x N x C
            dst_feat: B x N x C
        """
        n_views, batch_size, n_channels, _height, _width = images.shape
        assert n_views == 2
        assert n_channels == 3

        # Fold the pair axis into the batch so the backbone runs once.
        stacked = rearrange(images, 't b c h w -> (t b) c h w')
        with torch.no_grad():
            dense: Tensor = self.dino(stacked)

        # Unfold the pair axis again and flatten the spatial grid into tokens.
        src_feat, dst_feat = rearrange(dense, '(t b) h w c -> t b (h w) c', b=batch_size, t=2)
        return src_feat, dst_feat

    def loss_fn(self, homography_batch: HomographyBatch) -> Tensor:
        """Training loss for one batch; delegates to ``homography_loss`` with this model."""
        return homography_loss(homography_batch, self)

In [None]:
# Pick the best available accelerator instead of assuming Apple-silicon MPS,
# so the notebook also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

hformer = HFormer().to(device)

# The DINO backbone is frozen inside HFormer; give the optimizer only the
# parameters that actually receive gradients.
trainable_params = [p for p in hformer.parameters() if p.requires_grad]
optim = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-3)

In [None]:
# n = sum(p.numel() for p in hformer.parameters() if p.requires_grad)
# print(f"{n:,} parameters")

In [None]:
# from tqdm.auto import tqdm
# for _ in range(1):
#     losses = []
#     for batch in tqdm(dataloader):
#         optim.zero_grad()
#         batch = batch.to(device)
#         loss = hformer.loss_fn(batch)
#         loss.backward()
#         optim.step()
#         losses.append(loss.item())
#     print(np.mean(losses))

In [None]:
# Forward passes only: gradients are not needed for the visualisations below.
with torch.no_grad():
    images_aug = batch_aug.images.to(device)
    prediction_aug = hformer(images_aug)

    images_base = batch_base.images.to(device)
    prediction_base = hformer(images_base)

In [None]:
def plot(batch, prediction):
    """Draw, for every pair in ``batch``, the reference corners and their ground-truth image.

    One two-panel figure is created per sample: the first image with
    ``prediction.corners_a`` on the left, the second image with those corners
    mapped through ``batch.H_12`` on the right. The predicted ``corners_b`` are
    not drawn.

    Args:
        batch: batch providing ``images`` (2 x B x C x H x W), ``H_12`` and ``B``.
        prediction: model output providing ``corners_a`` (B x 4 x 2).
    """
    corners_a = prediction.corners_a.cpu()
    corners_b_gt = compute_gt_corners(batch.H_12.to('cpu'), corners_a)

    # Repeat the first vertex so the quadrilateral is drawn closed.
    rep = [0, 1, 2, 3, 0]
    # Iterate over the batch that was passed in, not a notebook-level global.
    for s in range(batch.B):
        fig, (a1, a2) = plt.subplots(1, 2)
        a1.imshow(batch.images[0, s].permute(1, 2, 0).cpu().numpy())
        a2.imshow(batch.images[1, s].permute(1, 2, 0).cpu().numpy())
        # Column 0 is x and column 1 is y: the model scales corners by (W, H)
        # and draw_corners plots them the same way.
        a1.plot(corners_a[s, :, 0][rep], corners_a[s, :, 1][rep], 'r--')
        a2.plot(corners_b_gt[s, :, 0][rep], corners_b_gt[s, :, 1][rep], 'r--')
    
# Figures for the augmented batch first ...
plot(batch_aug, prediction_aug)
plt.show()
# ... then a separator line in the cell output, then the un-augmented baseline.
print('-' * 80)
plot(batch_base, prediction_base)

In [None]:
# Inspect the reference corners of the augmented prediction. HCorrespondences
# only defines `corners_a` and `corners_b`; `.corner` raised AttributeError.
prediction_aug.corners_a

In [None]:
import kornia.geometry as KG

# Iteratively "unwarp" the augmented batch: predict correspondences, turn them
# into a homography, apply its inverse to the second image, and repeat.
unwarped_batch = batch_aug.to('cpu')
s = (3, 4)  # slice bounds of the single sample that is visualised
for _ in range(5):
    fig = draw_corners(unwarped_batch.slice(*s))
    # NOTE: this deliberately overwrites the notebook-level `prediction_aug`;
    # later cells display the prediction of the last iteration.
    with torch.no_grad():
        prediction_aug = hformer(unwarped_batch.images.to(device))

    # HFormer.forward already returns B x 4 x 2 corners in pixel units, so no
    # extra batch repeat or rescaling is needed before fitting the homography.
    transform_12 = prediction_aug.implied_homography()

    # One matrix per image slot: identity for the first image, the inverse of
    # the predicted homography for the second.
    identity = torch.eye(3).reshape(1, 3, 3).repeat(unwarped_batch.B, 1, 1)
    transform = torch.stack([
        identity,
        torch.linalg.inv(transform_12),
    ], dim=0)

    unwarped_batch = unwarped_batch.build_augmentation().apply(transform).to('cpu').build()

fig = draw_corners(unwarped_batch.slice(*s))

In [None]:
# Display the predicted target corners; `prediction_aug` was last reassigned
# inside the unwarping loop above.
prediction_aug.corners_b

In [None]:
# Display the augmented batch object for inspection.
batch_aug

In [None]:
# Apply the most recent `transform` from the unwarping loop to the original
# augmented batch.
# NOTE(review): that transform was estimated on the batch as it stood in the
# loop's last iteration, not on `batch_aug` itself — confirm this is intended.
pred_builder = batch_aug.build_augmentation().apply(transform)
batch_pred = pred_builder.to('cpu').build()

In [None]:
# For every sample, draw the augmented pair and the pair after applying the
# predicted correction, one figure each.
for i in range(batch_pred.B):
    sample_bounds = (i, i + 1)
    fig1 = draw_corners(batch_aug.slice(*sample_bounds))
    fig2 = draw_corners(batch_pred.slice(*sample_bounds))
plt.show()