# SSM Driver
So what are we trying to do here? We're basically seeing if we can convince an ANN to predict the components of the SSM needed to match the measurements given as input. That way we can take 1D measurements and convert them into semi-realistic 3D data which would be nice. How we exactly are going to manage that? I honestly don't know. I guess the things we need are:
1. **Dataloader** - a means of generating pairs of 3D models, their components and their measurements. That means I actually need to nail down how we define these measurements in the first place which isnt entirely obvious. I guess I just go with something that looks reasonable for now and we can refine exactly where the measurements are taken at a later date. Either way we're going to be giving it 4 circumfrential measures, 2 widths and a length and we'll see what comes out of it. 
2. **Loss function** - How exactly are we defining this loss - probably easiest by just using MSE between the 2 sets of measurements but it might be worth normalising them against the reference measure so that everything is of the same magnitude. Also probably worth creating a class or something that can just be passed the verts and output a set of measurements - this probably should be a `nn.Module` shouldn't it so it gets those nice properties? Idk maybe that doesnt matter too much?
3. **Model** - This is probably the simplest part of the whole shabang. We can just start with a really simple dense network and see what happens. Maybe throw in some normalisation but really this should be as simple as possible.


Something I havent really thought too much about is that I need to create these limbs based on my components 

In [1]:
import torch
from torch import nn
import subprocess
import os
import numpy as np
import igl
import wandb

from functools import partial
import measure_limbs

In [2]:
selected_verts = []
with open("verts.txt", "r") as f:
    for line in f:
        if line[0] == "#":
            continue
        else:
            selected_verts.append([float(x) for x in line.strip().split(", ")])

selected_verts = torch.tensor(selected_verts)

verts, face2vert = igl.read_triangle_mesh("limb_00000.stl")
verts = torch.tensor(verts[torch.load("./data_components/vert_mapping.pt")])

verts.shape, selected_verts.shape

vert_idxs = torch.argmin(torch.linalg.norm(verts[None] - selected_verts[:, None], dim=-1), dim=-1)

with open("test_verts.obj", "w") as f:
    for v in vert_idxs:
        f.write(f"v {verts[v][0]} {verts[v][1]} {verts[v][2]}\n") 

torch.save(vert_idxs, "data_components/selected_verts.pt")

In [3]:
config = {
    "lr": 1e-3,
    "eta_min": 0.00001,
    "batch_size": 8,
    "log": False,
    "seed": 42,
}

torch.manual_seed(config["seed"])

if config["log"]:
    wandb.init(project="Open Limb SSM", config=config)

In [4]:
class Measurements(nn.Module):
    def __init__(self, edge2vert, face2edge, details, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.edge2vert = edge2vert
        self.face2edge = face2edge
        self.details = details
        self.measures = []

        for detail in details:
            match detail["type"]:
                case "width":
                    measure = partial(measure_limbs.measure_width,
                        edge2vert=self.edge2vert,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                        plane_direction=detail["plane_direction"],
                    )
                case "length":
                    measure = partial(measure_limbs.measure_length,
                        v1=detail["v1"], v2=detail["v2"], direction=detail["direction"]
                    )
                case "circumference":
                    measure = partial(measure_limbs.measure_planar_circumference,
                        edge2vert=self.edge2vert,
                        face2edge=self.face2edge,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                    )
            self.measures.append(measure)

    def forward(self, x):
        return torch.tensor([measure(x) for measure in self.measures])


In [5]:
class LegMeasurementDataset(torch.utils.data.Dataset):
    def __init__(self, measure, batch_size=64, path="./stls", dtype=torch.float64, device="cpu"):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.measure = measure
        self.batch_size = batch_size
        self.path = path
        self.loaded_components = {}
        self.raw_components = torch.load("./data_components/vert_components.pt").to(dtype).to(device)
        self.mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype).to(device)
        self.face2vert = torch.load("./data_components/face2vert.pt").to(device)
        self.vert_mapping = torch.load("./data_components/vert_mapping.pt").to(device)

        # Remove all .npy files from the specified path
        for file in os.listdir(self.path):
            if file.endswith(".npy"):
                os.remove(os.path.join(self.path, file))

        self.generate_data(0)
        self.generate_data(self.batch_size)

    def __len__(self):
        return 100_000

    def __getitem__(self, index):
        if index % self.batch_size == 0:
            self.generate_data(index + self.batch_size*2)
            ith_dataset = index // self.batch_size
            if ith_dataset >= 2:
                self.delete_data((ith_dataset - 2)*self.batch_size)

        components = self.loaded_components[(index // self.batch_size)*self.batch_size][
            index % self.batch_size
        ]

        verts = self.get_verts(components)

        measurements = self.measure.forward(verts)

        return measurements, components
        # return verts, self.face2vert, measurements, components
    

    def get_verts(self, components):
        sum = torch.sum(self.raw_components[None] * components[..., None], dim=1)
        verts = self.mean_verts[None] + sum.reshape((sum.shape[0], self.mean_verts.shape[0], self.mean_verts.shape[1]))
        verts = verts[:, self.vert_mapping]

        return verts.squeeze()
    
    def get_measures(self, components=None, verts=None):
        if components is not None:
            verts = self.get_verts(components)
        
        return self.measure.forward(verts)

    def generate_data(self, start):
        cmd = [
            "./generate_limbs.sh",
            "--num_limbs",
            f"{self.batch_size}",
            "--path",
            self.path,
            "--start",
            f"{start}",
            "--save_mesh",
            "0",
        ]
        if os.name == "nt":  # Windows
            cmd = ["wsl", "-e"] + cmd

        subprocess.run(cmd, check=True)

        while True:
            try:
                components = np.load(f"{self.path}/components_{start:08d}.npy")
            except FileNotFoundError:
                continue
            else:
                self.loaded_components[start] = torch.tensor(components, dtype=self.dtype, device=self.device)
                break

    def delete_data(self, start):
        # for i in range(start, start + self.batch_size):
            # try:
            #     os.remove(f"{self.path}/limb_{i:05d}.npy")
            # except FileNotFoundError:
            #     pass

        os.remove(f"{self.path}/components_{start:08d}.npy")
        self.loaded_components.pop(start)

In [6]:
dtype = torch.float32

mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype)
face2vert = torch.load("./data_components/face2vert.pt")

edge2vert, face2edge, edge2face = igl.edge_topology(
    mean_verts.numpy(), face2vert.numpy()
)

vert_idxs = torch.load("./data_components/selected_verts.pt")

# Order
# Mid patella tendon
# Distal tibia
# Knee widest
# Knee above? This one feels off
# Over fib head
# Fib head
# Circ 3
# Circ 4

measurement_details = (
    {
        # Circ one
        "type": "circumference",
        "plane_point": vert_idxs[4],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[5],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[6],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[7],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
    },
    {
        "type": "length", 
        "v1": vert_idxs[0], 
        "v2": vert_idxs[1], 
        "direction": torch.tensor([[0, 0, 1]], dtype=dtype)
    },
    {
        "type": "width",
        "plane_point": vert_idxs[2],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "plane_direction": torch.tensor([[1, 0, 0]], dtype=dtype),
    },
    {
        "type": "width",
        "plane_point": vert_idxs[3],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "plane_direction": torch.tensor([[1, 0, 0]], dtype=dtype),
    },
)

measure = Measurements(
    edge2vert,
    face2edge,
    measurement_details,
)

In [7]:
class Model(nn.Module):
    def __init__(self, input_dims, output_dims, activation=nn.ReLU(inplace=True)):
        super().__init__()

        self.layers = nn.ParameterList([
            nn.Linear(input_dims, 256),
            activation,
            nn.Linear(256, 1024),
            activation,
            nn.Linear(1024, 1024),
            activation,
            nn.Linear(1024, 128),
            activation,
            nn.Linear(128, output_dims)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = LegMeasurementDataset(measure, batch_size=config["batch_size"], dtype=dtype, device="cpu")
dataloader = torch.utils.data.DataLoader(dataset, config["batch_size"], shuffle=False)
loss_func = nn.MSELoss().to(device)


model = Model(len(measurement_details), dataset.raw_components.shape[0]).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-4)

In [9]:
for i, (measures, components) in enumerate(dataloader):
    measures = measures.to(device)
    components = components.to(device)
    
    preds = model(measures)

    loss = loss_func(preds, components)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    with torch.no_grad():
        measurements = dataset.get_measures(preds.to("cpu"), None)

    if config["log"]:
        wandb.log({
            "loss": loss
        })
    else: 
        print(loss)

verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=torch.Size([7732, 3])
verts.shape=to

IndexError: index 421 is out of bounds for dimension 0 with size 8