# Model training

This notebook trains the Stretcher (a small descriptor-adaptation network) on paired SuperPoint descriptors and their transformed counterparts.

High-level steps:
1) Load a prepared dataset of descriptor pairs and the corresponding affine/strain parameters.
2) Build PyTorch DataLoaders for train/val/test.
3) Define a small MLP-based model (TripleNet) that fuses descriptors with transformation parameters.
4) Train with a cosine-based loss and save the trained weights.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import numpy as np

### Load prepared dataset

This notebook assumes you have a precomputed .pth dataset containing: `descriptors` (base descriptors), `deformed_descriptors` (target descriptors after applying an image deformation), and `transformations` (the affine/strain parameters used to transform the keypoint).

If you don't have a dataset yet, see `dataset_creation.ipynb` in this repo for an example pipeline that builds these tensors from SuperPoint outputs.

In [2]:
# Load the dataset
file_path = "data/SuperPoint_Descriptors_Dataset_Test.pth"
data = torch.load(file_path)

base_descriptors = data['descriptors']
transformed_descriptors = data['deformed_descriptors']
parameters = data['transformations']
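
Before building the loaders, it can help to confirm that the three tensors line up. The following is a minimal sketch rather than part of the original pipeline: `check_descriptor_dataset` is a hypothetical helper, and the synthetic tensors (256-dimensional descriptors, 3 parameters per sample) only stand in for the real file.

```python
import torch

def check_descriptor_dataset(data):
    """Hypothetical helper: verify keys and matching shapes, return the sample count."""
    required = ("descriptors", "deformed_descriptors", "transformations")
    missing = [key for key in required if key not in data]
    if missing:
        raise KeyError(f"dataset is missing keys: {missing}")

    base = data["descriptors"]
    deformed = data["deformed_descriptors"]
    params = data["transformations"]

    # Base and deformed descriptors are paired one-to-one and share a width
    if base.shape != deformed.shape:
        raise ValueError(f"descriptor shapes differ: {base.shape} vs {deformed.shape}")
    # One row of transformation parameters per descriptor pair
    if params.shape[0] != base.shape[0]:
        raise ValueError("number of parameter rows does not match number of descriptors")
    return base.shape[0]

# Synthetic stand-in for the loaded file (assumed sizes, for illustration only)
example = {
    "descriptors": torch.randn(10, 256),
    "deformed_descriptors": torch.randn(10, 256),
    "transformations": torch.randn(10, 3),
}
num_samples = check_descriptor_dataset(example)
print(num_samples)
```

With the real file, the same call would be `check_descriptor_dataset(data)` right after `torch.load`.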

### Create dataset and dataloaders

Wrap the tensors in a `TensorDataset` and split into train/validation/test sets.

In [9]:
# Create the dataset and split it 70/15/15 into train/validation/test
dataset = TensorDataset(base_descriptors, transformed_descriptors, parameters)
train_size = 0.7
val_size = 0.15
test_size = 0.15

# Fractional lengths require PyTorch 1.13 or newer
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last = False, num_workers=32,pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last = False)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last = False)
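
`random_split` draws a new random partition each time the cell runs, so the validation and test sets change between runs. One way to make the split repeatable, sketched here under the assumption that a fixed seed is acceptable for your experiments, is to pass a seeded `torch.Generator`. The seed value, the synthetic tensors, and the integer length computation are illustrative choices, not part of the original notebook.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Synthetic stand-in for the real descriptor tensors (assumed sizes)
dataset = TensorDataset(torch.randn(100, 256), torch.randn(100, 256), torch.randn(100, 3))

# Integer lengths; the test split takes the remainder so the sizes always sum correctly
n_total = len(dataset)
n_train = int(0.7 * n_total)
n_val = int(0.15 * n_total)
n_test = n_total - n_train - n_val

# A seeded generator makes the partition identical across runs
generator = torch.Generator().manual_seed(0)
train_ds, val_ds, test_ds = random_split(dataset, [n_train, n_val, n_test], generator=generator)
print(len(train_ds), len(val_ds), len(test_ds))
```

Keeping the split fixed matters most when comparing runs with different hyperparameters, since otherwise samples can move between the training and validation sets.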

### Model architecture

Define a small MLP helper and `TripleNet`, which fuses descriptor vectors with the transformation parameters. The network is residual: each of its `num_nets` MLPs receives the concatenated parameters and input descriptor, and their outputs are added to the input descriptor as learned corrections.

In [4]:
class MLP(nn.Module):
    def __init__(self, input_dim,output_dim,hidden_dim,num_layers):
        super(MLP, self).__init__()

        if num_layers == 0:
            self.model = nn.Linear(input_dim,output_dim)
        else:
            layers = []
            layers.append(nn.Linear(input_dim,hidden_dim))
            layers.append(nn.ReLU())

            for _ in range(num_layers - 1):
                layers.append(nn.Linear(hidden_dim,hidden_dim))
                layers.append(nn.ReLU())
            
            layers.append(nn.Linear(hidden_dim,output_dim))
            self.model = nn.Sequential(*layers)
    
    def forward(self,x):
        return self.model(x)

class TripleNet(nn.Module):
    def __init__(self, descriptor_dim=256, parameter_dim=3, hidden_dim=256, num_layers=2, num_nets=3):
        super(TripleNet, self).__init__()
        
        self.p_scale = nn.Parameter(torch.ones(1))  # Learnable scale parameter
        
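        # One MLP per correction term; each maps [parameters, descriptor] to a descriptor-sized correction
        self.mlp_list = nn.ModuleList([MLP(descriptor_dim + parameter_dim, descriptor_dim, hidden_dim, num_layers) 
                                      for _ in range(num_nets)])
        
    def forward(self, x, p):
        
        # With no deformation there is nothing to correct
        if torch.all(p==0):
            return x
        
        # Scale the transformation parameters by a learnable factor
        scaled_p = p * self.p_scale
        
        # Concatenate transformation parameters and descriptor
        combined = torch.cat([scaled_p, x], dim=1)
        
        # Add each MLP's correction out of place so the caller's tensor is not modified.
        # The MLPs are registered in a ModuleList, so model.to(device) already moves them.
        for mlp in self.mlp_list:
            x = x + mlp(combined)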
        
        return x
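
The fusion step can be looked at in isolation. The sketch below is a simplified illustration, not the model itself: a single `nn.Linear` stands in for one of the MLPs, and the batch size, the 256-dimensional descriptors, and the 3 transformation parameters are assumed values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, descriptor_dim, parameter_dim = 4, 256, 3  # assumed sizes

x = torch.randn(batch, descriptor_dim)   # input descriptors
p = torch.randn(batch, parameter_dim)    # transformation parameters
p_scale = torch.ones(1)                  # learnable in the real model

# Parameters first, then the descriptor, along the feature axis
combined = torch.cat([p * p_scale, x], dim=1)

# Stand-in for one MLP: maps the fused vector back to descriptor size
correction_net = nn.Linear(descriptor_dim + parameter_dim, descriptor_dim)

# Residual update: the network only has to learn the change to the descriptor
adapted = x + correction_net(combined)
print(combined.shape, adapted.shape)
```

Because the output is the input plus a correction, a network whose corrections are near zero approximately reproduces the input descriptor.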

### Instantiate model and device

Move the model to GPU when available and set input/output dimensions from the loaded tensors.


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
input_dim = base_descriptors.shape[1]
output_dim = transformed_descriptors.shape[1]
parameter_dim = parameters.shape[1]
model = TripleNet(descriptor_dim=input_dim, parameter_dim=parameter_dim, hidden_dim=2048, num_layers=2).to(device).double()

cpu
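
To get a feel for the size of this configuration, the parameter count can be worked out directly. This is a rough sketch under stated assumptions: `build_mlp` is a hypothetical helper that mirrors the `MLP` layer layout for `num_layers >= 1`, and the descriptor and parameter widths (256 and 3) are taken from the `TripleNet` defaults, since the real values come from the loaded tensors.

```python
import torch.nn as nn

def build_mlp(input_dim, output_dim, hidden_dim, num_layers):
    """Hypothetical helper with the same layer layout as MLP, for num_layers >= 1."""
    layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
    for _ in range(num_layers - 1):
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
    layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)

descriptor_dim, parameter_dim = 256, 3  # assumed, from the TripleNet defaults
mlp = build_mlp(descriptor_dim + parameter_dim, descriptor_dim, hidden_dim=2048, num_layers=2)

# Linear(259, 2048) + Linear(2048, 2048) + Linear(2048, 256), weights and biases
params_per_mlp = sum(param.numel() for param in mlp.parameters())

# Three MLPs (num_nets=3) plus the scalar p_scale
total_params = 3 * params_per_mlp + 1
print(params_per_mlp, total_params)
# → 5253376 15760129
```

At roughly 15.8 million parameters, the float64 model takes about 126 MB for its weights alone, which halves once the weights are stored as float32.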


### Loss and optimizer

We use a cosine-based loss (1 - cosine_similarity) to directly encourage aligned descriptor directions; it ignores descriptor magnitude. The optimizer is Adam with a learning rate of 1e-4. No learning-rate scheduler is set up here, so add one if your experiments need it.


In [6]:
# Define loss function and optimizer
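def cosine_loss(output,target,reduction = 'mean'):
    # Per-sample loss in [0, 2]: 0 for aligned descriptors, 2 for opposite ones
    loss = 1 - torch.nn.functional.cosine_similarity(output,target)

    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'none':
        return loss
    else:
        raise ValueError(f"Unsupported reduction: {reduction!r}")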

criterion = cosine_loss
optimizer = optim.Adam(model.parameters(), lr=0.0001)
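
A few hand-picked vectors show how the loss behaves. This is an illustrative sketch with made-up 2-D vectors rather than real descriptors: the loss is about 0 when directions match (regardless of length), 1 when they are orthogonal, and 2 when they are opposite.

```python
import torch
import torch.nn.functional as F

target = torch.tensor([[1.0, 0.0]])

cases = {
    "same direction, longer": torch.tensor([[3.0, 0.0]]),
    "orthogonal": torch.tensor([[0.0, 2.0]]),
    "opposite": torch.tensor([[-1.0, 0.0]]),
}

results = {}
for name, output in cases.items():
    # Same formula as cosine_loss with reduction='none'
    loss = 1 - F.cosine_similarity(output, target)
    results[name] = loss.item()
    print(f"{name}: {results[name]:.3f}")
```

The first case is why this loss suits descriptor matching: rescaling a descriptor does not change its loss, only rotating it does.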

### Training loop

The training loop iterates over epochs, computes per-sample cosine losses, backpropagates their batch mean, and evaluates on the validation set after each epoch. Average losses are printed every fifth epoch.

In [7]:
# Training loop
epochs = 11
for epoch in range(epochs):
    model.train()
    train_losses = []
    for base, transformed, param in train_dataloader:
        base, transformed, param = base.to(device), transformed.to(device), param.to(device)
        
        optimizer.zero_grad()
        output = model(base, param)
        losses = cosine_loss(output,transformed,'none')
        loss = losses.mean()
        loss.backward()
        optimizer.step()
        
        train_losses.extend(losses.cpu().detach().numpy())

    # Validation pass: no gradients, per-sample losses collected for the epoch average
    model.eval()
    val_losses = []
    with torch.no_grad():
        for base, transformed, param in val_dataloader:
            base, transformed, param = base.to(device), transformed.to(device), param.to(device)
            output = model(base, param)
            losses = cosine_loss(output,transformed,'none')
            val_losses.extend(losses.cpu().numpy())
            
    avg_train_loss = np.mean(train_losses)
    avg_val_loss = np.mean(val_losses)
    
    # Report every fifth epoch (epochs 1, 6, 11, ...)
    if epoch % 5 == 0:
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")



Epoch 1/11, Train Loss: 0.398795, Val Loss: 0.362467
Epoch 6/11, Train Loss: 0.301016, Val Loss: 0.342721
Epoch 11/11, Train Loss: 0.253727, Val Loss: 0.330493
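
The test split created earlier is not used by the loop above. A held-out evaluation could follow the same pattern as the validation pass. The sketch below is one possible way to do it, not the notebook's own procedure: `evaluate` and `IdentityModel` are hypothetical names, and the synthetic loader only demonstrates the call.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, dataloader, device):
    """Hypothetical helper: mean and std of the per-sample cosine loss over a loader."""
    model.eval()
    losses = []
    with torch.no_grad():
        for base, transformed, param in dataloader:
            base, transformed, param = base.to(device), transformed.to(device), param.to(device)
            output = model(base, param)
            loss = 1 - F.cosine_similarity(output, transformed)
            losses.extend(loss.cpu().numpy())
    return float(np.mean(losses)), float(np.std(losses))

class IdentityModel(nn.Module):
    """Stand-in with the same call signature as TripleNet; returns its input."""
    def forward(self, x, p):
        return x

# Synthetic data where the target equals the input, so the loss should be near zero
base = torch.randn(32, 256)
loader = DataLoader(TensorDataset(base, base.clone(), torch.zeros(32, 3)), batch_size=8)
mean_loss, std_loss = evaluate(IdentityModel(), loader, torch.device("cpu"))
print(f"Test loss: {mean_loss:.6f} +/- {std_loss:.6f}")
```

In the notebook, the equivalent call would be `evaluate(model, test_dataloader, device)` after training finishes.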


### Save trained weights

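At the end we convert the model to float32 and save its state dict to `models/stretcher.pth`. The `models/` directory must already exist, and `model.to(torch.float32)` also converts the in-memory model, so it is no longer float64 after this cell.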


In [8]:
torch.save(model.to(torch.float32).state_dict(), "models/stretcher.pth")
print("Model training complete and saved.")

Model training complete and saved.
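
To use the saved weights later, the state dict has to be loaded into a model built with the same architecture arguments. The following is a minimal round-trip sketch under stated assumptions: an `nn.Linear` layer stands in for `TripleNet`, and a temporary directory replaces `models/` so the example does not touch real files.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained model; in the notebook this would be the TripleNet instance
model = nn.Linear(259, 256).double()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "stretcher.pth")

    # Same save call as above: convert to float32, then store the state dict
    torch.save(model.to(torch.float32).state_dict(), path)

    # Rebuild the architecture (float32 by default) and load the weights on CPU
    restored = nn.Linear(259, 256)
    state_dict = torch.load(path, map_location="cpu")
    restored.load_state_dict(state_dict)
    restored.eval()

weights_match = torch.equal(restored.weight, model.weight)
print(weights_match)
```

For `TripleNet`, the restored instance needs the same `descriptor_dim`, `parameter_dim`, `hidden_dim`, `num_layers`, and `num_nets` as the trained one, otherwise `load_state_dict` raises an error about mismatched keys or shapes.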
