## Optimize morphological transformation vectors that map 12-week-old mice onto mice of all other ages

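Run this notebook on a GPU for faster training (note: nothing in the cells below explicitly moves tensors to a GPU, so check which device the tensors actually end up on).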

In [1]:
import torch
import joblib
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from itertools import product
from collections import defaultdict
from kornia.geometry.transform import scale
from aging.plotting import figure, format_plots
from aging.organization.paths import TrainingPaths
from toolz import keyfilter, concatv, take, partition_all
from aging.size_norm.data import make_grid, pad_vector
from kornia.geometry.transform import warp_image_tps, get_tps_transform, resize

In [2]:
# Apply the shared plot styling from aging.plotting (presumably sets matplotlib defaults — TODO confirm).
format_plots()

In [3]:
# Resolve the project's training paths, make sure the folder that will hold the
# TPS fits exists, then load the pose data the fits are computed from.
training_paths = TrainingPaths()
fits_dir = training_paths.tps_fits.parent
fits_dir.mkdir(parents=True, exist_ok=True)
poses = joblib.load(training_paths.tps_training_data)

In [4]:
def initialize_params(grid_size=6, batch_size=1, height_shape=(5, 8), device=None):
    """Create the trainable parameters optimized by ``fit`` for one batch.

    Args:
        grid_size: Control points per side of the TPS grid. The movement vector
            holds ``grid_size - 2`` entries per side (presumably only interior
            control points move and ``pad_vector`` adds a fixed border — TODO
            confirm).
        batch_size: Number of independent fits optimized together.
        height_shape: ``(rows, cols)`` of the coarse height-gain map. It is
            resized to the target resolution inside ``fit``. Defaults to the
            previously hard-coded ``(5, 8)``.
        device: Torch device for the tensors. ``None`` uses PyTorch's default
            device, which matches the original behavior.

    Returns:
        dict of leaf tensors, all with ``requires_grad=True``:
            ``movement_vector``: zeros, shape ``(batch_size, 2, grid_size - 2, grid_size - 2)``.
            ``scale_tensor``: ones, shape ``(batch_size, 2)`` (identity scale).
            ``height_mtx``: ones, shape ``(batch_size, 1, *height_shape)``
            (identity multiplicative gain).
    """
    movement_vector = torch.zeros(
        batch_size, 2, grid_size - 2, grid_size - 2, device=device, requires_grad=True
    )
    scale_tensor = torch.ones(batch_size, 2, dtype=torch.float, device=device, requires_grad=True)
    height_mtx = torch.ones(batch_size, 1, *height_shape, device=device, requires_grad=True)
    # Key order matters only for the optimizer's param-group order; keep it stable.
    return {
        "movement_vector": movement_vector,
        "scale_tensor": scale_tensor,
        "height_mtx": height_mtx,
    }


def fit(template, target, grid_size=6, num_iters=1000, lr=1e-2, **tqdm_kwargs):
    """Optimize a scale + TPS warp + height gain that maps ``template`` onto ``target``.

    For each iteration:
      1. The template is scaled by ``params["scale_tensor"]``.
      2. It is warped with a thin-plate spline fitted between the ``make_grid``
         control points and those points displaced by the padded
         ``params["movement_vector"]``.
      3. It is multiplied by ``params["height_mtx"]`` resized to the target's
         resolution.
    All parameters are optimized jointly with AdamW (weight decay 1e-3) on the
    MSE against ``target``.

    NOTE(review): no tensor here is moved with ``.to()``/``.cuda()``. The device
    is whatever ``initialize_params``/``make_grid`` produce.

    Args:
        template: numpy array of template images. The first axis is the batch
            axis and the last two are (H, W).
        target: numpy array of target images with the same batch size as
            ``template``.
        grid_size: Control points per side, forwarded to ``make_grid`` and
            ``initialize_params``.
        num_iters: Number of optimizer steps. Must be at least 1.
        lr: AdamW learning rate.
        **tqdm_kwargs: Forwarded to ``tqdm`` (e.g. ``disable=True``).

    Returns:
        dict with:
            ``params``: numpy copies of the optimized parameters.
            ``losses``: one batch-mean MSE per iteration.
            ``optimized_template``, ``template``, ``target``: uint8 images,
            clipped to [0, 255].
            ``grid``, ``transformed_grid``: control points before and after
            displacement.
        Every array is ``squeeze()``-d, so singleton axes are dropped. That
        includes the batch axis when ``len(template) == 1``.

    Raises:
        ValueError: if ``num_iters < 1``; no optimized output would exist.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    batch_size = len(template)
    template = torch.tensor(template, dtype=torch.float).view(batch_size, 1, *template.shape[-2:])
    target = torch.tensor(target, dtype=torch.float).view(batch_size, 1, *target.shape[-2:])

    grid = make_grid(grid_size, batch_size)
    params = initialize_params(grid_size, batch_size)
    optimizer = torch.optim.AdamW(params.values(), lr=lr, weight_decay=1e-3)

    def _displaced_grid():
        # Pad the movement vector to the full grid, move the xy axis last, and
        # flatten to (batch, n_points, 2) so it lines up with ``grid``.
        padded = torch.transpose(pad_vector(params["movement_vector"]), 1, 3)
        return grid + padded.reshape(batch_size, -1, 2)

    losses = []
    for _ in tqdm(range(num_iters), **tqdm_kwargs):
        optimizer.zero_grad()
        # TODO: try elastic transform instead
        kernel, affine = get_tps_transform(_displaced_grid(), grid)
        scaled_template = scale(template, params["scale_tensor"])
        transformed_template = warp_image_tps(scaled_template, grid, kernel, affine)
        height_intermediate = resize(params["height_mtx"], target.shape[-2:])

        out = height_intermediate * transformed_template

        loss = torch.nn.functional.mse_loss(out, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    def _to_numpy(t):
        # .cpu() makes this safe for GPU tensors (a no-op on CPU).
        return t.detach().cpu().squeeze().numpy()

    def _to_uint8(t):
        # NumPy float -> uint8 casts do not saturate: out-of-range values
        # (e.g. > 255 when height_mtx > 1) wrap or become platform-dependent
        # garbage, so clip into the valid range first.
        return np.clip(_to_numpy(t), 0, 255).astype("uint8")

    with torch.no_grad():
        transformed_grid = _displaced_grid()

    return {
        "params": {k: _to_numpy(v) for k, v in params.items()},
        "losses": losses,
        "optimized_template": _to_uint8(out),
        "template": _to_uint8(template),
        "target": _to_uint8(target),
        "grid": _to_numpy(grid),
        "transformed_grid": _to_numpy(transformed_grid),
    }


def flatten_dict_recursive(input_dict, top_key=None):
    """Flatten a nested dict into a list of ``(top_key, key, value)`` triples.

    Nested dicts are descended into. The outermost key encountered on the way
    down is carried through as ``top_key``; intermediate keys are not kept.
    A list value yields one triple per element. Any other value (including
    tuples and arrays) yields a single triple.

    Args:
        input_dict: Mapping to flatten.
        top_key: Value reported as the first tuple element. Leave it as
            ``None`` for the outermost call; recursive calls receive the
            first-level key.

    Returns:
        List of ``(top_key, key, value)`` tuples in iteration order.
    """
    triples = []
    for name, entry in input_dict.items():
        if isinstance(entry, dict):
            # Only the first-level key propagates to deeper levels.
            root = name if top_key is None else top_key
            triples += flatten_dict_recursive(entry, top_key=root)
        elif isinstance(entry, list):
            triples += [(top_key, name, element) for element in entry]
        else:
            triples.append((top_key, name, entry))
    return triples

In [7]:
# Fixed seed so the template subset and the pair order are reproducible.
random.seed(0)

# Targets: every (age, path, frame) triple for all ages except the 12-week templates.
pose_list = flatten_dict_recursive(keyfilter(lambda k: k != 12, poses))
# Templates: the 12-week (path, frame) entries, shuffled so that the first
# ``n_animals`` entries are a random subset.
template_pose_list = list(poses[12].items())
random.shuffle(template_pose_list)
n_animals = 13
# Pair every target pose with each selected template, then shuffle the pairs.
full_list = list(product(pose_list, template_pose_list[:n_animals]))
random.shuffle(full_list)

In [6]:
# Fit every (target pose, 12-week template) pair in batches and periodically
# checkpoint to ``training_paths.tps_fits``.
# Layout: output[age][(target_path, template_path)] -> per-pair fit results.
output = defaultdict(dict)
grid_size = 10
batch_size = 8


def _take(batched, j, n_in_batch):
    """Return sample ``j`` from an array returned by ``fit``.

    ``fit`` ``squeeze()``s its outputs, which also removes the batch axis when a
    batch holds a single pair. Restore that axis so ``[j]`` always selects a
    sample.
    """
    if n_in_batch == 1:
        batched = np.expand_dims(batched, 0)
    return batched[j]


for i, grp in enumerate(partition_all(batch_size, tqdm(full_list))):
    # Each pair is ((age, path, target_frame), (template_path, template_frame)).
    ages = [pose[0] for pose, _ in grp]
    paths = [pose[1] for pose, _ in grp]
    targets = np.array([pose[2] for pose, _ in grp])
    template_paths = [tmpl[0] for _, tmpl in grp]
    templates = np.array([tmpl[1] for _, tmpl in grp])
    out = fit(templates, targets, grid_size=grid_size, num_iters=150, lr=1e-2, disable=True)
    # partition_all's last batch can be shorter than batch_size, so iterate
    # over the actual group size. range(batch_size) raises IndexError there.
    n_in_batch = len(grp)
    for j in range(n_in_batch):
        output[ages[j]][(paths[j], template_paths[j])] = {
            'params': {
                'movement_vector': _take(out['params']['movement_vector'], j, n_in_batch),
                'scale_tensor': _take(out['params']['scale_tensor'], j, n_in_batch),
                'height_mtx': _take(out['params']['height_mtx'], j, n_in_batch),
            },
            # Batch-mean loss curve, shared by every pair in this batch.
            'losses': out['losses'],
            'optimized_template': _take(out['optimized_template'], j, n_in_batch),
            'template': _take(out['template'], j, n_in_batch),
            'target': _take(out['target'], j, n_in_batch),
            'grid': _take(out['grid'], j, n_in_batch),
            'transformed_grid': _take(out['transformed_grid'], j, n_in_batch),
        }
    if (i + 1) % 100 == 0:
        joblib.dump(output, training_paths.tps_fits, compress=3)
joblib.dump(output, training_paths.tps_fits, compress=3)

  0%|          | 0/9324 [00:00<?, ?it/s]

KeyboardInterrupt: 