# SSM Driver
So what are we trying to do here? We're basically seeing if we can convince an ANN to predict the components of the SSM needed to match the measurements given as input. That way we can take 1D measurements and convert them into semi-realistic 3D data which would be nice. How we exactly are going to manage that? I honestly don't know. I guess the things we need are:
1. **Dataloader** - a means of generating pairs of 3D models, their components and their measurements. That means I actually need to nail down how we define these measurements in the first place which isnt entirely obvious. I guess I just go with something that looks reasonable for now and we can refine exactly where the measurements are taken at a later date. Either way we're going to be giving it 4 circumfrential measures, 2 widths and a length and we'll see what comes out of it. 
2. **Loss function** - How exactly are we defining this loss - probably easiest by just using MSE between the 2 sets of measurements but it might be worth normalising them against the reference measure so that everything is of the same magnitude. Also probably worth creating a class or something that can just be passed the verts and output a set of measurements - this probably should be a `nn.Module` shouldn't it so it gets those nice properties? Idk maybe that doesnt matter too much?
3. **Model** - This is probably the simplest part of the whole shabang. We can just start with a really simple dense network and see what happens. Maybe throw in some normalisation but really this should be as simple as possible.


Something I havent really thought too much about is that I need to create these limbs based on my components 

In [16]:
import torch
from torch import nn
import measure_limbs
from functools import partial
import subprocess
import os
import numpy as np
import igl

In [17]:
class Measurements(nn.Module):
    def __init__(self, edge2vert, face2edge, details, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.edge2vert = edge2vert
        self.face2edge = face2edge
        self.details = details
        self.measures = []

        for detail in details:
            match detail["type"]:
                case "width":
                    measure = partial(measure_limbs.measure_width,
                        edge2vert=self.edge2vert,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                        plane_direction=detail["plane_direction"],
                    )
                case "length":
                    measure = partial(measure_limbs.measure_length,
                        v1=detail["v1"], v2=detail["v2"], direction=detail["direction"]
                    )
                case "circumference":
                    measure = partial(measure_limbs.measure_planar_circumference,
                        edge2vert=self.edge2vert,
                        face2edge=self.face2edge,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                    )
            self.measures.append(measure)

    def forward(self, x):
        return torch.tensor([measure(x) for measure in self.measures])


In [18]:
class LegMeasurementDataset(torch.utils.data.Dataset):
    def __init__(self, measure, batch_size=64, path="./stls", dtype=torch.float64):
        super().__init__()
        self.dtype = dtype
        self.measure = measure
        self.batch_size = batch_size
        self.path = path
        self.loaded_components = {}
        self.raw_components = torch.load("./data_components/vert_components.pt").to(dtype)
        self.mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype)
        self.face2vert = torch.load("./data_components/face2vert.pt")
        self.vert_mapping = torch.load("./data_components/vert_mapping.pt")

        self.generate_data(0)
        self.generate_data(self.batch_size)

    def __len__(self):
        return 100_000

    def __getitem__(self, index):
        if index % self.batch_size == 0:
            self.generate_data(index + self.batch_size*2)
            ith_dataset = index // self.batch_size
            if ith_dataset >= 2:
                self.delete_data((ith_dataset - 2)*self.batch_size)

        components = self.loaded_components[(index // self.batch_size)*self.batch_size][
            index % self.batch_size
        ]

        verts = self.mean_verts + torch.sum(self.raw_components * components[:, None], dim=0).reshape_as(self.mean_verts)
        verts = verts[self.vert_mapping]

        measurements = self.measure.forward(verts)

        return verts, self.face2vert, measurements

    def generate_data(self, start):
        cmd = [
            "./generate_limbs.sh",
            "--num_limbs",
            f"{self.batch_size}",
            "--path",
            self.path,
            "--start",
            f"{start}",
            "--save_mesh",
            "0",
        ]
        if os.name == "nt":  # Windows
            cmd = ["wsl", "-e"] + cmd

        subprocess.run(cmd, check=True)

        while True:
            try:
                components = np.load(f"{self.path}/components_{start:05d}.npy")
            except FileNotFoundError:
                continue
            else:
                self.loaded_components[start] = torch.tensor(components, dtype=self.dtype)
                break

    def delete_data(self, start):
        # for i in range(start, start + self.batch_size):
            # try:
            #     os.remove(f"{self.path}/limb_{i:05d}.npy")
            # except FileNotFoundError:
            #     pass

        os.remove(f"{self.path}/components_{start:05d}.npy")
        self.loaded_components.pop(start)

In [19]:
dtype = torch.float64

mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype)
face2vert = torch.load("./data_components/face2vert.pt")

edge2vert, face2edge, edge2face = igl.edge_topology(
    mean_verts.numpy(), face2vert.numpy()
)

measure = Measurements(
    edge2vert,
    face2edge,
    (
        {
            "type": "circumference",
            "plane_point": torch.tensor([[0, 0, 0]], dtype=dtype),
            "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        },
    ),
)

dataset = LegMeasurementDataset(measure, batch_size=8, dtype=dtype)

# for i, x in enumerate(dataset):
#     print(i, x[0].shape)
#     break

In [20]:
dataloader = torch.utils.data.DataLoader(dataset, 8, shuffle=False)

for data in dataloader:
    print(data)

[tensor([[[-1.6342e-01, -3.2495e-03,  7.1833e-02],
         [-1.6370e-01, -4.4503e-05,  7.9470e-02],
         [-1.6379e-01, -2.9753e-03,  8.5426e-02],
         ...,
         [ 2.0199e-01,  8.6354e-02,  1.6419e-01],
         [ 2.0595e-01,  6.8690e-02,  1.6474e-01],
         [ 2.0454e-01,  7.7512e-02,  1.6476e-01]],

        [[-1.4629e-01, -1.3873e-02,  6.6298e-02],
         [-1.4723e-01, -1.1991e-02,  7.1771e-02],
         [-1.4714e-01, -1.5334e-02,  7.6014e-02],
         ...,
         [ 1.7325e-01,  8.5600e-03,  1.3572e-01],
         [ 1.7167e-01, -8.0546e-03,  1.3612e-01],
         [ 1.7250e-01,  7.6077e-05,  1.3619e-01]],

        [[-1.5972e-01,  7.7069e-03,  6.7646e-02],
         [-1.5986e-01,  1.0568e-02,  7.5863e-02],
         [-1.5963e-01,  8.0793e-03,  8.2182e-02],
         ...,
         [ 2.2795e-01,  7.5656e-02,  1.6867e-01],
         [ 2.2968e-01,  5.9347e-02,  1.6918e-01],
         [ 2.2965e-01,  6.7532e-02,  1.6926e-01]],

        ...,

        [[-1.4225e-01,  2.9018e-03,  

KeyboardInterrupt: 