# SSM Driver
So what are we trying to do here? We're basically seeing if we can convince an ANN to predict the SSM components needed to match the measurements given as input. That way we can take 1D measurements and convert them into semi-realistic 3D data, which would be nice. How exactly are we going to manage that? I honestly don't know. I guess the things we need are:
1. **Dataloader** - a means of generating 3D models along with their components and their measurements. That means I actually need to nail down how we define these measurements in the first place, which isn't entirely obvious. I guess I'll just go with something that looks reasonable for now, and we can refine exactly where the measurements are taken at a later date. Either way, we're going to give it 4 circumferential measures, 2 widths and a length and see what comes out of it.
2. **Loss function** - How exactly are we defining this loss? Probably the easiest option is MSE between the 2 sets of measurements, but it might be worth normalising them against the reference measure so that everything is of the same magnitude. It's also probably worth creating a class or something that can just be passed the verts and output a set of measurements. This should probably be an `nn.Module` so it gets those nice properties? I don't know, maybe that doesn't matter too much.
3. **Model** - This is probably the simplest part of the whole shebang. We can just start with a really simple dense network and see what happens. Maybe throw in some normalisation, but really this should be as simple as possible.
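A minimal sketch of the normalised loss from item 2, assuming the predicted and reference measurements arrive as `(batch, n_measures)` arrays. NumPy stands in for torch so the idea can be checked in isolation, and `relative_mse` is a hypothetical name:

```python
import numpy as np


def relative_mse(pred: np.ndarray, true: np.ndarray, eps: float = 1e-12) -> float:
    """MSE where each error is divided by its reference measurement.

    pred, true: arrays of shape (batch, n_measures), e.g. 4 circumferences,
    1 length and 2 widths per limb. `eps` keeps the division finite.
    """
    # An error of 30 on a circumference of 300 and an error of 10 on a width
    # of 100 both become a relative error of 0.1, so neither term dominates.
    rel = (pred - true) / (np.abs(true) + eps)
    return float(np.mean(rel ** 2))


true = np.array([[300.0, 100.0]])
pred = np.array([[330.0, 110.0]])
print(relative_mse(pred, true))
```

Both entries are off by 10%, so the loss comes out at roughly 0.01. For positive measurements this is the same quantity as dividing the squared error by the squared reference value.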


Something I haven't really thought too much about is that I need to create these limbs based on my components.
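One way to do that, mirroring what `LegMeasurementDataset.get_verts` does further down, is to treat the SSM as linear: the vertices are the mean shape plus a weighted sum of the modes, with the last predicted value acting as a global scale. This NumPy sketch assumes each mode is stored flattened as a `(V*3,)` row and leaves out the `vert_mapping` reindexing; `reconstruct_verts` is a hypothetical name:

```python
import numpy as np


def reconstruct_verts(weights: np.ndarray, mean_verts: np.ndarray,
                      components: np.ndarray) -> np.ndarray:
    """Rebuild one limb from SSM weights.

    weights:    (K + 1,)  K mode weights followed by a global scale factor
    mean_verts: (V, 3)    mean shape of the SSM
    components: (K, V*3)  one flattened mode of variation per row
    """
    modes, scale = weights[:-1], weights[-1]
    # Weighted sum of the modes gives a flattened offset from the mean shape
    offset = modes @ components
    verts = mean_verts + offset.reshape(mean_verts.shape)
    # The extra output scales the whole limb uniformly
    return verts * scale


mean_verts = np.zeros((2, 3))
components = np.array([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0]])  # a single toy mode
print(reconstruct_verts(np.array([2.0, 3.0]), mean_verts, components))
```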

In [83]:
import torch
from torch import nn
import subprocess
import os
import numpy as np
import igl
import wandb

from functools import partial
import measure_limbs

In [84]:
# selected_verts = []
# with open("verts.txt", "r") as f:
#     for line in f:
#         if line[0] == "#":
#             continue
#         else:
#             selected_verts.append([float(x) for x in line.strip().split(", ")])

# selected_verts = torch.tensor(selected_verts)

# verts, face2vert = igl.read_triangle_mesh("./meshes/limb_00000.stl")
# verts = torch.tensor(verts[torch.load("./data_components/vert_mapping.pt")])

# verts.shape, selected_verts.shape

# vert_idxs = torch.argmin(torch.linalg.norm(verts[None] - selected_verts[:, None], dim=-1), dim=-1)

# # with open("test_verts.obj", "w") as f:
# #     for v in vert_idxs:
# #         f.write(f"v {verts[v][0]} {verts[v][1]} {verts[v][2]}\n") 

# torch.save(vert_idxs, "data_components/selected_verts.pt")
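The commented-out cell above snaps hand-picked landmark coordinates to their nearest mesh vertices with a broadcast distance matrix and `argmin`. The same idea as a small NumPy sketch (`nearest_vertex_indices` is a hypothetical helper); for large meshes a KD-tree would avoid building the full `(P, V)` matrix:

```python
import numpy as np


def nearest_vertex_indices(verts: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Index of the closest row of `verts` (V, 3) for each row of `points` (P, 3)."""
    # (P, 1, 3) - (1, V, 3) -> (P, V, 3); the norm over xyz gives a (P, V) distance matrix
    dists = np.linalg.norm(points[:, None] - verts[None], axis=-1)
    return np.argmin(dists, axis=-1)


verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 5.0, 0.0]])
landmarks = np.array([[0.9, 0.1, 0.0], [0.0, 4.0, 0.0]])
print(nearest_vertex_indices(verts, landmarks))
```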

In [85]:
config = {
    "lr": 1e-2,
    "eta_min": 0.00001,
    "batch_size": 256,
    "log": False,
    "seed": 42,
    "epochs":500,
}

torch.manual_seed(config["seed"])

if config["log"]:
    wandb.init(project="Open Limb SSM", config=config)

In [86]:
#| export
class Measurements(nn.Module):
    def __init__(self, edge2vert, face2edge, details, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.edge2vert = edge2vert
        self.face2edge = face2edge
        self.details = details
        self.measures = []

        for detail in details:
            match detail["type"]:
                case "width":
                    measure = partial(measure_limbs.measure_width,
                        edge2vert=self.edge2vert,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                        plane_direction=detail["plane_direction"],
                    )
                case "length":
                    measure = partial(measure_limbs.measure_length,
                        v1=detail["v1"], v2=detail["v2"], direction=detail["direction"]
                    )
                case "circumference":
                    measure = partial(measure_limbs.measure_planar_circumference,
                        edge2vert=self.edge2vert,
                        face2edge=self.face2edge,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                    )
                case _:
                    raise ValueError(f"Unknown measurement type: {detail['type']!r}")
            self.measures.append(measure)

    def forward(self, x, verbose=False):
        measures = [measure(x) for measure in self.measures]
        if verbose:
            for measure, detail in zip(measures, self.details):
                print(f"{detail['name'].ljust(20)}:\t\t{measure}")
        return torch.stack(measures, dim=-1)
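The class above pre-binds each measurement's geometry with `functools.partial`, so `forward` only has to pass the vertices. The real measurement functions live in `measure_limbs`, which isn't shown here; purely as an illustration, here is a hypothetical `length_along` that projects the vector between two landmark vertices onto a direction, bound the same way:

```python
from functools import partial

import numpy as np


def length_along(verts, v1, v2, direction):
    """Hypothetical stand-in for a length measure (not measure_limbs' implementation).

    verts: (V, 3) vertices; v1, v2: landmark vertex indices; direction: (3,) vector.
    """
    unit = direction / np.linalg.norm(direction)
    # Project the landmark-to-landmark vector onto the measuring direction
    return abs(float((verts[v2] - verts[v1]) @ unit))


# Bind everything except the vertices, as Measurements.__init__ does
length_1 = partial(length_along, v1=0, v2=1, direction=np.array([0.0, 0.0, 1.0]))

verts = np.array([[0.0, 0.0, 0.0], [3.0, 4.0, 400.0]])
print(length_1(verts))
```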


In [87]:
#| export
class LegMeasurementDataset(torch.utils.data.Dataset):
    def __init__(self, measure, batch_size=64, path="./stls", dtype=torch.float64, device="cpu"):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.measure = measure
        self.batch_size = batch_size
        self.path = path
        self.loaded_components = {}
        self.raw_components = torch.load("./data_components/vert_components.pt").to(dtype).to(device)
        self.mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype).to(device)
        self.face2vert = torch.load("./data_components/face2vert.pt").to(device)
        self.vert_mapping = torch.load("./data_components/vert_mapping.pt").to(device)
        self.component_transforms = torch.load("./data_components/scaled_component_transforms.pt").to(dtype).to(device)
        self.measurement_transforms = torch.load("./data_components/scaled_measurement_transforms.pt").to(dtype).to(device)

        # Remove all .npy files from the specified path
        for file in os.listdir(self.path):
            if file.endswith(".npy"):
                os.remove(os.path.join(self.path, file))

        self.generate_data(0)
        self.generate_data(self.batch_size)

    def __len__(self):
        return 100_000

    def __getitem__(self, index):
        if index % self.batch_size == 0:
            self.generate_data(index + self.batch_size*2)
            ith_dataset = index // self.batch_size
            if ith_dataset >= 2:
                self.delete_data((ith_dataset - 2)*self.batch_size)

        try:
            components = self.loaded_components[(index // self.batch_size)*self.batch_size][
                index % self.batch_size
            ]
        except KeyError:
            self.generate_data(index)
            components = self.loaded_components[(index // self.batch_size)*self.batch_size][
                index % self.batch_size
            ]

        verts = self.get_verts(components)

        measurements = self.get_measures(verts=verts, normalise=False).squeeze()

        return measurements, components
        # return verts, self.face2vert, measurements, components
    

    def get_verts(self, components):
        # The last entry of each component vector is a global scale factor
        if len(components.shape) == 1:
            components, scale = components[:-1], components[-1:]
        else:
            components, scale = components[:,:-1], components[:,-1:]
        
        # Mean shape plus the weighted sum of the SSM modes, reshaped to (batch, n_verts, 3)
        total = torch.sum(self.raw_components[None] * components[..., None], dim=1)
        verts = self.mean_verts[None] + total.reshape((total.shape[0], self.mean_verts.shape[0], self.mean_verts.shape[1]))
        # Reindex the vertices with vert_mapping, then apply the scale
        verts = verts[:, self.vert_mapping]*scale[..., None]

        return verts.squeeze()
    
    def get_measures(self, components=None, verts=None, verbose=False, normalise=False):
        if components is not None:
            verts = self.get_verts(components)

        measurements = self.measure.forward(verts, verbose)

        if normalise:
            measurements = self.normalise_measures(measurements)

        return measurements
    
    def normalise_measures(self, measurements):
        return (measurements - self.measurement_transforms[:1]) / self.measurement_transforms[1:]

    def generate_data(self, start):
        cmd = [
            "./scripts/generate_limbs.sh",
            "--num_limbs",
            f"{self.batch_size}",
            "--path",
            self.path,
            "--start",
            f"{start}",
            "--save_mesh",
            "0",
            "--scale",
            "1",
            "--seed",
            f"{torch.randint(0, 100000, (1,))[0]}"
        ]
        if os.name == "nt":  # Windows
            cmd = ["wsl", "-e"] + cmd

        subprocess.run(cmd, check=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE)

        while True:
            try:
                components = np.load(f"{self.path}/components_{start:08d}.npy")
            except FileNotFoundError:
                continue
            else:
                self.loaded_components[start] = torch.tensor(components, dtype=self.dtype, device=self.device)
                break

    def delete_data(self, start):
        # for i in range(start, start + self.batch_size):
            # try:
            #     os.remove(f"{self.path}/limb_{i:05d}.npy")
            # except FileNotFoundError:
            #     pass

        os.remove(f"{self.path}/components_{start:08d}.npy")
        self.loaded_components.pop(start)
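`__getitem__` keeps a rolling window of batches on disk: whenever an index lands on a batch boundary it asks the shell script for the batch two ahead and deletes the one two behind. The pure-Python simulation below (a hypothetical helper, not part of the notebook) replays that bookkeeping to show which batch start indices are loaded after each boundary. It also makes an assumption explicit: the scheme relies on indices arriving in order, i.e. `shuffle=False` and `num_workers=0`, since each DataLoader worker process gets its own copy of the dataset.

```python
def simulate_window(n_batches: int, batch_size: int) -> list[list[int]]:
    """Replay the generate/delete bookkeeping of LegMeasurementDataset."""
    loaded = {0, batch_size}  # __init__ generates the first two batches
    history = []
    for k in range(n_batches):
        index = k * batch_size  # first index of batch k
        loaded.add(index + 2 * batch_size)  # prefetch two batches ahead
        if k >= 2:
            loaded.discard((k - 2) * batch_size)  # drop the batch two behind
        history.append(sorted(loaded))
    return history


for step in simulate_window(4, batch_size=256):
    print(step)
```

After the first boundary, four batches are held at once: the previous one, the current one and two ahead.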

In [88]:
#| export_section
dtype = torch.float64

mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype)
face2vert = torch.load("./data_components/face2vert.pt")

edge2vert, face2edge, edge2face = igl.edge_topology(
    mean_verts.numpy(), face2vert.numpy()
)

edge2vert = torch.from_numpy(edge2vert)
face2edge = torch.from_numpy(face2edge)
edge2face = torch.from_numpy(edge2face)

vert_idxs = torch.load("./data_components/selected_verts.pt")

# Order of the landmark vertices in vert_idxs:
# Mid patella tendon
# Distal tibia
# Knee widest
# Knee above? This one feels off
# Over fib head
# Fib head
# Circ 3
# Circ 4

measurement_details = (
    {
        # Circ one
        "type": "circumference",
        "plane_point": vert_idxs[4],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 1",
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[5],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 2",
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[6],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 3",
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[7],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 4",
    },
    {
        "type": "length", 
        "v1": vert_idxs[0], 
        "v2": vert_idxs[1], 
        "direction": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Length 1",
    },
    {
        "type": "width",
        "plane_point": vert_idxs[2],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "plane_direction": torch.tensor([[1, 0, 0]], dtype=dtype),
        "name":"Width 1",
    },
    {
        "type": "width",
        "plane_point": vert_idxs[3],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "plane_direction": torch.tensor([[1, 0, 0]], dtype=dtype),
        "name":"Width 2",
    },
)

measure = Measurements(
    edge2vert,
    face2edge,
    measurement_details,
)

#| end_section

In [89]:
class Model(nn.Module):
    def __init__(self, input_dims, output_dims, transforms, dtype, activation=nn.ReLU(inplace=True)):
        super().__init__()

        # Row 1 scales and row 0 offsets the outputs. As a buffer it follows .to(device)
        # and is saved with the model, but is never trained.
        self.register_buffer("output_transform", transforms.detach().clone())

        # nn.Sequential registers every layer as a submodule (ParameterList is meant for parameters)
        self.layers = nn.Sequential(
            nn.Linear(input_dims, 256, dtype=dtype),
            nn.BatchNorm1d(num_features=256, dtype=dtype),
            activation,
            nn.Linear(256, 512, dtype=dtype),
            nn.BatchNorm1d(num_features=512, dtype=dtype),
            activation,
            nn.Linear(512, 1024, dtype=dtype),
            nn.BatchNorm1d(num_features=1024, dtype=dtype),
            activation,
            nn.Linear(1024, 128, dtype=dtype),
            nn.BatchNorm1d(num_features=128, dtype=dtype),
            activation,
            nn.Dropout(0.3),
            nn.Linear(128, output_dims, dtype=dtype),
        )

    def forward(self, x):
        x = self.layers(x)
        # Map the normalised outputs back to raw SSM component weights (plus scale)
        return x * self.output_transform[1:] + self.output_transform[:1]
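The model predicts normalised components and `forward` maps them back with `x * transforms[1:] + transforms[:1]`, the inverse of `normalise_measures` in the dataset. Assuming both transform files share that layout (row 0 an offset such as the mean, row 1 a scale such as the standard deviation), this NumPy sketch checks that the two operations round-trip; the helper names and statistics are hypothetical:

```python
import numpy as np


def normalise(x: np.ndarray, transforms: np.ndarray) -> np.ndarray:
    # transforms[:1] is the offset row and transforms[1:] the scale row; slicing keeps
    # them (1, D) so they broadcast over a (batch, D) input
    return (x - transforms[:1]) / transforms[1:]


def denormalise(z: np.ndarray, transforms: np.ndarray) -> np.ndarray:
    return z * transforms[1:] + transforms[:1]


# Hypothetical statistics for three outputs: row 0 = mean, row 1 = std
transforms = np.array([[10.0, -2.0, 1.0],
                       [5.0, 0.5, 0.1]])
x = np.array([[15.0, -1.5, 1.2],
              [5.0, -2.0, 0.9]])
z = normalise(x, transforms)
print(z)
print(np.allclose(denormalise(z, transforms), x))
```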

In [90]:
class MeasurementLoss(nn.Module):
    def __init__(self, measures: LegMeasurementDataset):
        super().__init__()
        self.measures = measures

    def forward(self, components, true_components):
        device = components.device
        pred_measures = self.measures.get_measures(components.to(self.measures.device), verbose=False)
        true_measures = self.measures.get_measures(true_components.to(self.measures.device), verbose=False)
        # print(f"{torch.isnan(components).sum()=}    {torch.isnan(measurements).sum()=}    {pred_measures.shape=}")
        # Relative squared error; eps belongs in the denominator to guard against division by zero
        return torch.mean((pred_measures - true_measures)**2 / (true_measures**2 + 1e-12)).to(device)

In [91]:
class PointwiseLoss(nn.Module):
    def __init__(self, measures):
        super().__init__()
        self.measures = measures

    def forward(self, pred_components, true_components):
        device = pred_components.device
        pred_verts = self.measures.get_verts(pred_components.to(self.measures.device))
        true_verts = self.measures.get_verts(true_components.to(self.measures.device))
        return torch.mean((pred_verts - true_verts)**2).to(device)

In [92]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = LegMeasurementDataset(measure, batch_size=config["batch_size"], dtype=dtype, device="cpu")
dataloader = torch.utils.data.DataLoader(dataset, config["batch_size"], shuffle=False)
# loss_func = nn.MSELoss().to(device)
# loss_func = MeasurementLoss(dataset).to(device)
loss_func = PointwiseLoss(dataset).to(device)
component_loss = nn.MSELoss()

component_transforms = torch.load("./data_components/scaled_component_transforms.pt").to(dtype).to(device)
# Add one output for the scale factor
model = Model(len(measurement_details), dataset.raw_components.shape[0] + 1, component_transforms, dtype).to(device)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"], eta_min=config["eta_min"])
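With no restarts, `CosineAnnealingLR` follows the closed form given in the PyTorch docs, η_t = η_min + (η_max - η_min)(1 + cos(π t / T_max)) / 2, where η_max is the initial `lr` and `eta_min` defaults to 0 unless it is passed explicitly. A small sketch of what `config["lr"]`, `config["eta_min"]` and `config["epochs"]` imply; here `t` counts `scheduler.step()` calls, which in the training loop is once per batch:

```python
import math


def cosine_lr(step: int, lr_max: float, eta_min: float, t_max: int) -> float:
    """Closed-form cosine annealing (no restarts), as documented for CosineAnnealingLR."""
    return eta_min + (lr_max - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


for step in (0, 250, 500):
    print(step, cosine_lr(step, lr_max=1e-2, eta_min=1e-5, t_max=500))
```

In the closed form the cosine keeps going past `T_max`, so stepping beyond it makes the learning rate climb again.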

In [93]:
i = 0
alpha = 10
decay = 0.9
best_loss = float("inf")
for measures, components in dataloader:
    measures = measures.to(device)
    components = components.to(device)
    
    preds = model(measures)
    loss = loss_func(preds, components)

    # Stop before a NaN loss can be backpropagated into the weights
    if torch.isnan(loss):
        print(components)
        print(preds)
        break

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    with torch.no_grad():
        pred_measures = dataset.get_measures(preds.to("cpu"), None)
        measures = measures.to("cpu")
        # Relative error of each predicted measurement against the input measurement
        rel_diff = (pred_measures - measures).abs() / measures

        metrics = {
            "loss": loss,
            "component loss": component_loss(preds, components),
            "measurement difference": torch.mean(rel_diff),
            "maximum measurement difference": torch.max(rel_diff),
            "minimum difference": torch.min(rel_diff),
        }

        if config["log"]:
            metrics.update({
                "scale": torch.mean(preds[:, -1]),
                "scale min": torch.min(preds[:, -1]),
                "scale max": torch.max(preds[:, -1]),
                "scale std": torch.std(preds[:, -1]),
            })
            wandb.log(metrics)
        else:
            print(metrics)

        # Compare and store a plain float so this step's autograd graph isn't kept alive
        if loss.item() < best_loss:
            torch.save(model, "models/best.pt")
            best_loss = loss.item()
    alpha *= decay
    if i > config["epochs"]:
        break
    else:
        print(i)
    
    scheduler.step()
    i += 1


{'loss': tensor(331.7358, dtype=torch.float64, grad_fn=<MeanBackward0>), 'component loss': tensor(118.7557, dtype=torch.float64), 'measurement difference': tensor(0.1206, dtype=torch.float64), 'maximum measurement difference': tensor(0.9560, dtype=torch.float64), 'minimum difference': tensor(5.1713e-05, dtype=torch.float64)}
0
{'loss': tensor(244.0327, dtype=torch.float64, grad_fn=<MeanBackward0>), 'component loss': tensor(104.8347, dtype=torch.float64), 'measurement difference': tensor(0.0841, dtype=torch.float64), 'maximum measurement difference': tensor(1.3503, dtype=torch.float64), 'minimum difference': tensor(2.2641e-05, dtype=torch.float64)}
1
{'loss': tensor(830.5244, dtype=torch.float64, grad_fn=<MeanBackward0>), 'component loss': tensor(200.9148, dtype=torch.float64), 'measurement difference': tensor(0.1300, dtype=torch.float64), 'maximum measurement difference': tensor(1.8014, dtype=torch.float64), 'minimum difference': tensor(0.0001, dtype=torch.float64)}
2
{'loss': tensor(2

KeyboardInterrupt: 

In [None]:
torch.save(model, "models/best.pt")
dataloader = torch.utils.data.DataLoader(dataset, config["batch_size"], shuffle=False)
model.eval()

for measures, components in dataloader:
    with torch.no_grad():
        # The model sits on `device`, while the dataset helpers work on the CPU
        preds = model(measures.to(device)).cpu()
        # Get predicted and true vertices
        pred_verts = dataset.get_verts(preds)
        true_verts = dataset.get_verts(components)

        # Measure the predicted limbs (not the true components)
        pred_measures = dataset.get_measures(preds)

        print(f"{true_verts.shape=}    {preds.shape=}")
        print(f"{torch.any(pred_measures < 0, dim=0)=}")
        print(f"{[x['type'] for x in measurement_details]}")

        for i, (pv, tv) in enumerate(zip(pred_verts, true_verts[:5])):
            print(f"{pv.shape=}    {tv.shape=}")
            # Save the predicted and true meshes as OBJ files (OBJ face indices are 1-based)
            for prefix, mesh_verts in (("predicted", pv), ("true", tv)):
                with open(f"meshes/{prefix}_{i:04d}.obj", "w") as f:
                    for v in mesh_verts:
                        f.write(f"v {v[0]} {v[1]} {v[2]}\n")
                    for face in dataset.face2vert:
                        f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")

    break

['wsl', '-e', './scripts/generate_limbs.sh', '--num_limbs', '16', '--path', './stls', '--start', '32', '--save_mesh', '0', '--scale', '1', '--seed', '21302']
true_verts.shape=torch.Size([16, 7732, 3])    preds.shape=torch.Size([16, 11])
torch.any(pred_measures < 0, dim=0)=tensor([False, False, False, False, False, False, False])
['circumference', 'circumference', 'circumference', 'circumference', 'length', 'width', 'width']
pv.shape=torch.Size([7732, 3])    tv.shape=torch.Size([7732, 3])
pv.shape=torch.Size([7732, 3])    tv.shape=torch.Size([7732, 3])
pv.shape=torch.Size([7732, 3])    tv.shape=torch.Size([7732, 3])
pv.shape=torch.Size([7732, 3])    tv.shape=torch.Size([7732, 3])
pv.shape=torch.Size([7732, 3])    tv.shape=torch.Size([7732, 3])
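The evaluation cell writes meshes as Wavefront OBJ by hand: one `v x y z` line per vertex followed by one `f a b c` line per triangle, with face indices shifted by one because OBJ indexing is 1-based. The same logic as a small reusable function (`write_obj` is a hypothetical name), taking plain nested sequences so it works for tensors or arrays once they are converted with `.tolist()`:

```python
import os
import tempfile


def write_obj(path: str, verts, faces) -> None:
    """Write a triangle mesh to a Wavefront OBJ file.

    verts: sequence of (x, y, z) floats; faces: sequence of 0-based (a, b, c) indices.
    """
    with open(path, "w") as f:
        for x, y, z in verts:
            f.write(f"v {x} {y} {z}\n")
        for a, b, c in faces:
            # OBJ vertex indices start at 1
            f.write(f"f {a + 1} {b + 1} {c + 1}\n")


path = os.path.join(tempfile.mkdtemp(), "triangle.obj")
write_obj(path, [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)], [(0, 1, 2)])
with open(path) as f:
    print(f.read())
```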


In [None]:
from nb_exporter import export_notebook
export_notebook("SSM_Driver.ipynb")