# GPA
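Generalized Procrustes Analysis (GPA): iteratively aligns a set of shapes to their mean shape after removing translation, scale and rotation.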

In [1]:
import torch


def center_shape(x):
    """Translate the shape so that its centroid (mean over dim 0) sits at the origin."""
    centroid = torch.mean(x, dim=0, keepdim=True)
    return x - centroid


def normalize_shape(x):
    """Divide the shape by its overall norm so that the result has unit size."""
    size = torch.norm(x)
    return x / size


def procrustes_align(x, y, only_matrix=False):
    """
    Rotate shape `x` onto shape `y` (orthogonal Procrustes, rotations only).

    Both shapes are expected to be centered and normalized beforehand.

    Args:
        x: shape to be rotated, one point per row.
        y: reference shape with the same layout as `x`.
        only_matrix: when True, return the rotation matrix instead of the
            rotated shape.

    Returns:
        The rotation matrix `r` if `only_matrix` is True, otherwise `x @ r`.
    """
    # The SVD of the cross-covariance matrix yields the optimal orthogonal map.
    left, _, right_t = torch.linalg.svd(x.T @ y)
    rotation = left @ right_t

    if torch.det(rotation) < 0:
        # A negative determinant means a reflection; negating the last column
        # of the left singular vectors turns it into a proper rotation.
        left[:, -1].neg_()
        rotation = left @ right_t

    return rotation if only_matrix else x @ rotation


def generalized_procrustes_analysis(shapes, tol=1e-6, max_iter=100, device=torch.device('cpu')):
    """
    Run Generalized Procrustes Analysis (GPA) over a collection of 2D shapes.

    Every shape is centered and scaled to unit norm. The mean shape is then
    refined iteratively: all shapes are rotated onto the current mean, a new
    normalized mean is computed, and the loop stops once the mean moves by
    less than `tol` or after `max_iter` rounds. Progress is printed each round.

    Args:
        shapes (torch.Tensor): Tensor of shape (N, n_points, 2) containing N shapes.
        tol (float): Convergence threshold on the norm of the mean-shape update.
        max_iter (int): Upper bound on the number of refinement rounds.
        device (torch.device): Device on which the computation is carried out.

    Returns:
        mean_shape (torch.Tensor): The resulting mean shape of shape (n_points, 2).
    """
    shapes = shapes.to(device, dtype=torch.float32)
    # Remove translation and scale from every shape up front.
    shapes = torch.stack([normalize_shape(center_shape(shape)) for shape in shapes])
    mean_shape = normalize_shape(shapes.mean(dim=0))

    for i in range(max_iter):
        # Rotate each shape onto the current estimate of the mean.
        aligned = torch.stack([procrustes_align(shape, mean_shape) for shape in shapes])

        updated_mean = normalize_shape(aligned.mean(dim=0))
        diff = torch.norm(mean_shape - updated_mean)
        mean_shape = updated_mean

        print(f"Iteration {i} diff: {diff}")

        if diff < tol:
            print(f"Convergence reached after {i + 1} iterations.")
            break

    return mean_shape




# Dataset

In [2]:
import torch

from wings.config import PROCESSED_DATA_DIR

countries = ['AT', 'GR', 'HR', 'HU', 'MD', 'PL', 'RO', 'SI']

# NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
train_dataset_path = PROCESSED_DATA_DIR / "mask_datasets" / 'rectangle' / "train_mask_dataset.pth"
train_dataset = torch.load(train_dataset_path, weights_only=False)

# Dataset size.
max_n = len(train_dataset)
print(max_n)

# Inspect the label tensor (third item) of the first sample.
_, _, orig_labels, _ = train_dataset[0]
print(orig_labels.shape)


[32m2025-10-30 00:32:12.529[0m | [1mINFO    [0m | [36mwings.config[0m:[36m<module>[0m:[36m40[0m - [1mPROJ_ROOT path is: /home/mkrajew/bees[0m
[32m2025-10-30 00:32:12.601[0m | [1mINFO    [0m | [36mwings.config[0m:[36m<module>[0m:[36m62[0m - [1mtorch.cuda.get_device_name()='NVIDIA RTX A3000 12GB Laptop GPU'[0m


15206
torch.Size([38])


# Calculate GPA for Dataset

## Create train coordinates array

In [3]:
from tqdm import tqdm
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=64, num_workers=8)

# Reshape every batch of flat labels into (batch, n_points, 2) coordinate pairs,
# e.g. 38 values per sample become 19 (x, y) points.
orig_coords_list = [
    labels.view(labels.size(0), -1, 2)
    for _, _, labels, _ in tqdm(loader, desc="Loading orig_labels", unit="batch")
]
orig_coords = torch.cat(orig_coords_list)

print(orig_coords.shape)

# Mean shape of the whole training set.
mean_coords = generalized_procrustes_analysis(orig_coords)
print(mean_coords.shape)


Loading orig_labels: 100%|██████████| 238/238 [00:21<00:00, 11.27batch/s]


torch.Size([15206, 19, 2])
Iteration 0 diff: 0.0001964124385267496
Iteration 1 diff: 7.297820303620028e-08
Convergence reached after 2 iterations.
torch.Size([19, 2])


## Save mean_shape

In [4]:
# torch.save(mean_coords, PROCESSED_DATA_DIR / "mask_datasets" / 'rectangle' / 'mean_shape2.pth')

## Permute coordinates

In [5]:
from scipy.spatial.distance import cdist
import torch
import numpy as np
from scipy.optimize import linear_sum_assignment


def solve_assignment(cost_matrix):
    """
    Solve the linear assignment problem for a mean-vs-shape cost matrix.

    Args:
        cost_matrix: 2D array of shape (n_mean, n_shape); entry [j, k] is the
            cost of matching mean point j to shape point k.

    Returns:
        idx: integer array of length n_mean, where idx[j] is the index of the
            shape point matched to mean point j.

    Raises:
        ValueError: if `cost_matrix` is not 2D, or if it has more rows than
            columns. In the latter case some mean points would stay unmatched
            and their entries in `idx` would be undefined.
    """
    cost_matrix = np.asarray(cost_matrix)
    if cost_matrix.ndim != 2:
        raise ValueError(f"cost_matrix must be 2D, got {cost_matrix.ndim}D")

    n_rows, n_cols = cost_matrix.shape
    if n_rows > n_cols:
        # linear_sum_assignment only matches min(n_rows, n_cols) rows, which
        # would leave part of the result uninitialised.
        raise ValueError(
            f"cost_matrix has more rows ({n_rows}) than columns ({n_cols}); "
            "every row must be assignable to a distinct column"
        )

    rows, cols = linear_sum_assignment(cost_matrix)
    idx = np.empty(n_rows, dtype=int)
    idx[rows] = cols
    return idx


def recover_order(mean_shape, unordered_shape, max_iter=5, device=torch.device('cpu')):
    """
    Reorder the points of `unordered_shape` to follow the point order of `mean_shape`.

    Both shapes are centered and normalized. An initial point matching is
    found by solving an assignment problem on pairwise distances, then refined
    by alternating between (1) estimating the rotation that aligns the
    currently matched points to the mean and (2) re-solving the assignment on
    the rotated shape. The loop ends when the matching stops changing or after
    `max_iter` rounds.

    Args:
        mean_shape: torch tensor (n_points, 2), the reference shape.
        unordered_shape: torch tensor (n_points, 2), the same landmarks in
            arbitrary order and possibly under a different transform.
        max_iter: maximum number of refinement rounds.
        device: device used for the torch computations.

    Returns:
        reordered_shape: torch tensor (n_points, 2) = unordered_shape[perm_idx]
    """
    reference = mean_shape.to(device).float()
    candidate = unordered_shape.to(device).float()
    assert candidate.shape[0] == reference.shape[0] and reference.shape[1] == 2

    reference_t = normalize_shape(center_shape(reference))
    candidate_t = normalize_shape(center_shape(candidate))

    # NumPy copies for scipy's cdist and the assignment solver.
    reference_np = reference_t.cpu().numpy()
    candidate_np = candidate_t.cpu().numpy()

    # Initial matching from raw pairwise distances: index[row] = col.
    index = solve_assignment(cdist(reference_np, candidate_np))

    for _ in range(max_iter):
        perm = torch.tensor(index, dtype=torch.long, device=device)
        # Rotation estimated from the points as currently matched to the mean.
        rotation = procrustes_align(candidate_t[perm], reference_t, only_matrix=True).cpu().numpy()
        # Apply it to the shape in its original order, then match again.
        refined = solve_assignment(cdist(reference_np, candidate_np @ rotation))

        if np.array_equal(refined, index):
            break
        index = refined

    return unordered_shape[index]


# Test ordering coords

## Load model and test dataset

In [6]:
from wings.modeling.loss import DiceLoss
from wings.config import MODELS_DIR
from wings.modeling.litnet import LitNet

checkpoint_path = MODELS_DIR / 'unet-rectangle-epoch=08-val_loss=0.14-unet-training-rectangle_1.ckpt'
num_epochs = 60

# Untrained U-Net backbone; the weights come from the checkpoint below.
unet_model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=3,
    out_channels=1,
    init_features=32,
    pretrained=False,
)
model = LitNet.load_from_checkpoint(
    checkpoint_path,
    model=unet_model,
    num_epochs=num_epochs,
    criterion=DiceLoss(),
)
model.eval()

# NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
test_dataset_path = PROCESSED_DATA_DIR / "mask_datasets" / 'rectangle' / "test_mask_dataset.pth"
test_dataset = torch.load(test_dataset_path, weights_only=False)
max_n = len(test_dataset)


Using cache found in /home/mkrajew/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


## Test

In [9]:
from wings.visualizing.image_preprocess import mask_to_coords, unet_reverse_padding
from sklearn.metrics import mean_squared_error

# Pick one random test sample and compare recovered landmark order to ground truth.
n = np.random.randint(0, max_n)
print(n)
image, _, orig_labels, orig_size = test_dataset[n]

total_mse = 0
num_samples = 0
bad_masks = 0

# Inference only: skip building the autograd graph for the forward pass.
with torch.no_grad():
    output = model(image.cuda().unsqueeze(0))
mask = torch.round(output).squeeze().detach().cpu().numpy()

# The number of expected landmarks is taken from the mean shape rather than hard-coded.
expected_points = mean_coords.shape[0]

mask_coords = mask_to_coords(mask)
if len(mask_coords) == expected_points:
    mask_height, mask_width = mask.shape
    orig_width, orig_height = orig_size

    pad_left, pad_top, pad_right, pad_bottom = unet_reverse_padding(mask, orig_width, orig_height)

    # Shift coordinates so they are relative to the unpadded image area.
    # NOTE(review): y is shifted by pad_bottom; if y is measured downward from
    # the top edge this should probably be pad_top. Confirm against
    # mask_to_coords / unet_reverse_padding.
    mask_coords = [(x - pad_left, y - pad_bottom) for x, y in mask_coords]

    # Rescale from the unpadded mask resolution to the original image size.
    scale_x = orig_width / (mask_width - pad_right - pad_left)
    scale_y = orig_height / (mask_height - pad_top - pad_bottom)
    mask_coords = torch.tensor([(x * scale_x, y * scale_y) for x, y in mask_coords])

    # print(f"mask coordinates: {mask_coords}")
    # print(f"original coordinates: {orig_labels}")

    reordered = recover_order(mean_coords, mask_coords)
    orig = orig_labels.view(-1, 2)
    print(orig.shape)
    print(reordered.shape)
    # Per-axis MSE (x, y) between ground truth and the reordered prediction.
    print(mean_squared_error(orig, reordered, multioutput='raw_values'))
    for i, (predicted, expected) in enumerate(zip(reordered, orig), start=1):
        print(f"{i}:\t{predicted}\t{expected}")
else:
    print(f"Found {len(mask_coords)} spots in mask.")

1130
torch.Size([19, 2])
torch.Size([19, 2])
[2.71185412 1.65336793]
1:	tensor([200.3686, 179.6421])	tensor([203., 180.])
2:	tensor([220.7451, 176.2526])	tensor([222., 177.])
3:	tensor([264.8941, 257.6000])	tensor([266., 258.])
4:	tensor([275.0824, 206.7579])	tensor([275., 206.])
5:	tensor([278.4784, 122.0210])	tensor([282., 123.])
6:	tensor([346.4000, 264.3789])	tensor([347., 265.])
7:	tensor([400.7372, 294.8842])	tensor([400., 297.])
8:	tensor([380.3608, 271.1579])	tensor([381., 273.])
9:	tensor([417.7177, 240.6526])	tensor([420., 239.])
10:	tensor([393.9451, 220.3158])	tensor([394., 219.])
11:	tensor([434.6980, 186.4211])	tensor([435., 188.])
12:	tensor([438.0941, 145.7474])	tensor([439., 147.])
13:	tensor([451.6784, 115.2421])	tensor([455., 115.])
14:	tensor([465.2628, 288.1053])	tensor([463., 289.])
15:	tensor([506.0157, 250.8211])	tensor([507., 250.])
16:	tensor([587.5215, 210.1474])	tensor([586., 212.])
17:	tensor([614.6902, 206.7579])	tensor([615., 208.])
18:	tensor([624.8784, 