# SSM Driver
So what are we trying to do here? We want to see whether we can convince an ANN to predict the components of the SSM (statistical shape model) needed to match the measurements given as input. That way we can take 1D measurements and convert them into semi-realistic 3D data, which would be nice. How exactly are we going to manage that? I honestly don't know yet. I think the things we need are:
1. **Dataloader** - a means of generating pairs of 3D models, their components and their measurements. That means I need to nail down how these measurements are defined in the first place, which isn't entirely obvious. For now I'll go with something that looks reasonable, and we can refine exactly where the measurements are taken at a later date. Either way, we're going to give it 4 circumferential measures, 2 widths and a length and see what comes out of it.
2. **Loss function** - How exactly do we define the loss? Probably easiest is MSE between the two sets of measurements, but it might be worth normalising them against the reference measure so that everything is of the same magnitude. It's also probably worth creating a class that can be passed the verts and output a set of measurements. Should this be an `nn.Module` so it gets those nice properties? Maybe that doesn't matter too much.
3. **Model** - This is probably the simplest part of the whole shebang. We can start with a really simple dense network and see what happens. Maybe throw in some normalisation, but really this should be as simple as possible.
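
As a rough sketch of the normalisation idea in item 2: dividing each squared error by the squared target means a 10% error on a large circumference counts the same as a 10% error on a small width. The plain-Python function below only shows the arithmetic; `relative_mse`, the `eps` guard and the numbers are illustrative assumptions, and a real loss would operate on tensors.

```python
def relative_mse(pred, true, eps=1e-12):
    """Mean of squared errors, each scaled by the squared target value."""
    if len(pred) != len(true):
        raise ValueError("pred and true must have the same length")
    # eps sits in the denominator so a zero target cannot divide by zero
    terms = [(p - t) ** 2 / (t ** 2 + eps) for p, t in zip(pred, true)]
    return sum(terms) / len(terms)


# Made-up measurements of very different magnitude, both predicted 10% too large.
true = [400.0, 90.0]
pred = [440.0, 99.0]
loss = relative_mse(pred, true)  # each term is about 0.1 ** 2 = 0.01
```

Without the normalisation, the first measurement would contribute 1600 to the sum and the second only 81, so the larger measurement would dominate the gradient.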


Something I haven't really thought much about yet is that I need to create these limbs based on my components.
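
One way to write down how a limb is built from its components, in notation chosen here purely for illustration (the structure follows what `get_verts` in the dataset class computes; the symbols themselves are assumptions):

$$
\mathbf{V}(\mathbf{c}, s) = s \left( \bar{\mathbf{V}} + \sum_{k=1}^{K} c_k \, \mathbf{P}_k \right)
$$

Here $\bar{\mathbf{V}}$ is the mean vertex array (`mean_verts`), $\mathbf{P}_k$ is the $k$-th row of `raw_components` reshaped to the same shape as $\bar{\mathbf{V}}$, $c_k$ is the matching coefficient, and $s$ is the scale factor stored as the last entry of the component vector. The code also re-indexes the vertices with `vert_mapping` before applying the scale, which does not change the formula.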

In [12]:
import torch
from torch import nn
import subprocess
import os
import numpy as np
import igl
import wandb

from functools import partial
import measure_limbs

In [13]:
selected_verts = []
with open("verts.txt", "r") as f:
    for line in f:
        # Skip comment lines and blank lines
        if line.startswith("#") or not line.strip():
            continue
        selected_verts.append([float(x) for x in line.strip().split(", ")])

selected_verts = torch.tensor(selected_verts)

verts, face2vert = igl.read_triangle_mesh("limb_00000.stl")
verts = torch.tensor(verts[torch.load("./data_components/vert_mapping.pt")])

verts.shape, selected_verts.shape

vert_idxs = torch.argmin(torch.linalg.norm(verts[None] - selected_verts[:, None], dim=-1), dim=-1)

with open("test_verts.obj", "w") as f:
    for v in vert_idxs:
        f.write(f"v {verts[v][0]} {verts[v][1]} {verts[v][2]}\n") 

torch.save(vert_idxs, "data_components/selected_verts.pt")
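
The `argmin` line above snaps each hand-picked landmark to the index of its closest mesh vertex. A tiny plain-Python sketch of the same idea follows; the coordinates are made up and `nearest_vertex_indices` is a hypothetical helper, not something from this notebook.

```python
import math


def nearest_vertex_indices(mesh_verts, landmarks):
    """For each landmark, return the index of the closest mesh vertex."""
    return [
        min(range(len(mesh_verts)), key=lambda i: math.dist(mesh_verts[i], point))
        for point in landmarks
    ]


# Made-up data: three mesh vertices and two landmarks picked near them.
mesh_verts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
landmarks = [(0.9, 0.1, 0.0), (0.1, 1.8, 0.0)]
idxs = nearest_vertex_indices(mesh_verts, landmarks)  # [1, 2]
```

Storing indices rather than coordinates is what lets the same landmarks be looked up on every generated limb, since all limbs share the mesh topology.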

In [14]:
config = {
    "lr": 1e-2,
    "eta_min": 0.00001,
    "batch_size": 256,
    "log": True,
    "seed": 42,
}

torch.manual_seed(config["seed"])

if config["log"]:
    wandb.init(project="Open Limb SSM", config=config)

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Run history:
component loss,█▇▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
maximum measurement difference,▇▆▇▇█▇██▇▇▇▇▆▇▂▇▇█▇▅▅▆▆▅██▇█▆█▁█▇▇▃▇▃▆▆█
measurement difference,█▇▅▄▃▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
minimum difference,█▄▃▁▃▁▃▃▂▃▂▂▃▁▂▂▃▃▂▃▃▃▃▂▃▃▃▂▂▂▂▄▂▂▂▁▃▃▂▁

Run summary:
component loss,13840.52995
loss,0.99069
maximum measurement difference,0.99991
measurement difference,0.99529
minimum difference,0.98828


In [15]:
class Measurements(nn.Module):
    def __init__(self, edge2vert, face2edge, details, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.edge2vert = edge2vert
        self.face2edge = face2edge
        self.details = details
        self.measures = []

        for detail in details:
            match detail["type"]:
                case "width":
                    measure = partial(measure_limbs.measure_width,
                        edge2vert=self.edge2vert,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                        plane_direction=detail["plane_direction"],
                    )
                case "length":
                    measure = partial(measure_limbs.measure_length,
                        v1=detail["v1"], v2=detail["v2"], direction=detail["direction"]
                    )
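                case "circumference":
                    measure = partial(measure_limbs.measure_planar_circumference,
                        edge2vert=self.edge2vert,
                        face2edge=self.face2edge,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                    )
                case _:
                    # Fail loudly instead of silently reusing the previous measure
                    raise ValueError(f"Unknown measurement type: {detail['type']!r}")
            self.measures.append(measure)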

    def forward(self, x, verbose=False):
        measures = [measure(x) for measure in self.measures]
        if verbose:
            for measure, detail in zip(measures, self.details):
                print(f"{detail['name'].ljust(20)}:\t\t{measure}")
        return torch.stack(measures, dim=-1)


In [16]:
class LegMeasurementDataset(torch.utils.data.Dataset):
    def __init__(self, measure, batch_size=64, path="./stls", dtype=torch.float64, device="cpu"):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.measure = measure
        self.batch_size = batch_size
        self.path = path
        self.loaded_components = {}
        self.raw_components = torch.load("./data_components/vert_components.pt").to(dtype).to(device)
        self.mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype).to(device)
        self.face2vert = torch.load("./data_components/face2vert.pt").to(device)
        self.vert_mapping = torch.load("./data_components/vert_mapping.pt").to(device)

        # Remove all .npy files from the specified path
        for file in os.listdir(self.path):
            if file.endswith(".npy"):
                os.remove(os.path.join(self.path, file))

        self.generate_data(0)
        self.generate_data(self.batch_size)

    def __len__(self):
        return 100_000

    def __getitem__(self, index):
        if index % self.batch_size == 0:
            self.generate_data(index + self.batch_size*2)
            ith_dataset = index // self.batch_size
            if ith_dataset >= 2:
                self.delete_data((ith_dataset - 2)*self.batch_size)

        components = self.loaded_components[(index // self.batch_size)*self.batch_size][
            index % self.batch_size
        ]

        verts = self.get_verts(components)

        # Call the module itself rather than .forward() so that hooks still run
        measurements = self.measure(verts).squeeze()

        return measurements, components
        # return verts, self.face2vert, measurements, components
    

    def get_verts(self, components):
        if len(components.shape) == 1:
            components, scale = components[:-1], components[-1:]
        else:
            components, scale = components[:,:-1], components[:,-1:]
        
        total = torch.sum(self.raw_components[None] * components[..., None], dim=1)
        verts = self.mean_verts[None] + total.reshape((total.shape[0], self.mean_verts.shape[0], self.mean_verts.shape[1]))
        verts = verts[:, self.vert_mapping]*scale[..., None]

        return verts.squeeze()
    
    def get_measures(self, components=None, verts=None, verbose=False):
        if components is not None:
            verts = self.get_verts(components)
        
        return self.measure(verts, verbose=verbose)

    def generate_data(self, start):
        cmd = [
            "./generate_limbs.sh",
            "--num_limbs",
            f"{self.batch_size}",
            "--path",
            self.path,
            "--start",
            f"{start}",
            "--save_mesh",
            "0",
            "--scale",
            "1"
        ]
        if os.name == "nt":  # Windows
            cmd = ["wsl", "-e"] + cmd

        subprocess.run(cmd, check=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE)

        while True:
            try:
                components = np.load(f"{self.path}/components_{start:08d}.npy")
            except FileNotFoundError:
                continue
            else:
                self.loaded_components[start] = torch.tensor(components, dtype=self.dtype, device=self.device)
                break

    def delete_data(self, start):
        # for i in range(start, start + self.batch_size):
            # try:
            #     os.remove(f"{self.path}/limb_{i:05d}.npy")
            # except FileNotFoundError:
            #     pass

        os.remove(f"{self.path}/components_{start:08d}.npy")
        self.loaded_components.pop(start)
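
The dataset keeps a rolling window of component batches: `__init__` generates the first two, and each time `__getitem__` hits a batch boundary it generates the batch two ahead and deletes the one two behind. The sketch below replays only that bookkeeping in plain Python to show which start offsets are cached at each boundary. `simulate_cache` is an illustrative helper, and it assumes indices arrive in order, as they do with `shuffle=False`.

```python
def simulate_cache(num_batches, batch_size):
    """Replay the generate/delete pattern and record the cached batch starts."""
    cache = {0, batch_size}  # __init__ generates the first two batches
    history = []
    for index in range(0, num_batches * batch_size, batch_size):
        cache.add(index + 2 * batch_size)  # generate the batch two ahead
        ith_dataset = index // batch_size
        if ith_dataset >= 2:
            cache.discard((ith_dataset - 2) * batch_size)  # drop the batch two behind
        history.append(sorted(cache))
    return history


history = simulate_cache(num_batches=4, batch_size=256)
for step, cached in enumerate(history):
    print(step, cached)
```

Under these assumptions the cache settles at four batches. With shuffled indices the lookup in `__getitem__` would hit offsets that were never generated or were already deleted, so the scheme relies on sequential access.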

In [17]:
dtype = torch.float64

mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype)
face2vert = torch.load("./data_components/face2vert.pt")

edge2vert, face2edge, edge2face = igl.edge_topology(
    mean_verts.numpy(), face2vert.numpy()
)

edge2vert = torch.from_numpy(edge2vert)
face2edge = torch.from_numpy(face2edge)
edge2face = torch.from_numpy(edge2face)


vert_idxs = torch.load("./data_components/selected_verts.pt")

# Order
# Mid patella tendon
# Distal tibia
# Knee widest
# Knee above? This one feels off
# Over fib head
# Fib head
# Circ 3
# Circ 4

measurement_details = (
    {
        # Circ one
        "type": "circumference",
        "plane_point": vert_idxs[4],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 1",
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[5],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 2",
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[6],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 3",
    },
    {
        "type": "circumference",
        "plane_point": vert_idxs[7],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Circumference 4",
    },
    {
        "type": "length", 
        "v1": vert_idxs[0], 
        "v2": vert_idxs[1], 
        "direction": torch.tensor([[0, 0, 1]], dtype=dtype),
        "name":"Length 1",
    },
    {
        "type": "width",
        "plane_point": vert_idxs[2],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "plane_direction": torch.tensor([[1, 0, 0]], dtype=dtype),
        "name":"Width 1",
    },
    {
        "type": "width",
        "plane_point": vert_idxs[3],
        "plane_normal": torch.tensor([[0, 0, 1]], dtype=dtype),
        "plane_direction": torch.tensor([[1, 0, 0]], dtype=dtype),
        "name":"Width 2",
    },
)

measure = Measurements(
    edge2vert,
    face2edge,
    measurement_details,
)

In [18]:
class Model(nn.Module):
    def __init__(self, input_dims, output_dims, dtype, activation=nn.ReLU(inplace=True)):
        super().__init__()

        # ModuleList is the container meant for sub-modules (ParameterList is for parameters)
        self.layers = nn.ModuleList([
            nn.BatchNorm1d(num_features=input_dims, dtype=dtype),
            nn.Linear(input_dims, 256, dtype=dtype),
            nn.BatchNorm1d(num_features=256, dtype=dtype),
            activation,
            nn.Linear(256, 1024, dtype=dtype),
            nn.BatchNorm1d(num_features=1024, dtype=dtype),
            activation,
            nn.Linear(1024, 2048, dtype=dtype),
            nn.BatchNorm1d(num_features=2048, dtype=dtype),
            activation,
            nn.Linear(2048, 128, dtype=dtype),
            nn.BatchNorm1d(num_features=128, dtype=dtype),
            activation,
            nn.Dropout(0.3),
            nn.Linear(128, output_dims, dtype=dtype)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

In [19]:
class MeasurementLoss(nn.Module):
    def __init__(self, measures: LegMeasurementDataset):
        super().__init__()
        self.measures = measures

    def forward(self, components, true_components):
        pred_measures = self.measures.get_measures(components, verbose=False)
        true_measures = self.measures.get_measures(true_components, verbose=False)
        # Relative squared error; the epsilon in the denominator guards against division by zero
        return torch.mean((pred_measures - true_measures)**2 / (true_measures**2 + 1E-12))

In [20]:
class PointwiseLoss(nn.Module):
    def __init__(self, measures):
        super().__init__()
        self.measures = measures

    def forward(self, pred_components, true_components):
        pred_verts = self.measures.get_verts(pred_components)
        true_verts = self.measures.get_verts(true_components)
        return torch.mean((pred_verts - true_verts)**2)

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = LegMeasurementDataset(measure, batch_size=config["batch_size"], dtype=dtype, device="cpu")
dataloader = torch.utils.data.DataLoader(dataset, config["batch_size"], shuffle=False)
# loss_func = nn.MSELoss().to(device)
loss_func = MeasurementLoss(dataset).to(device)
# loss_func = PointwiseLoss(dataset).to(device)
component_loss = nn.MSELoss()

# Add one for the scale factor 
model = Model(len(measurement_details), dataset.raw_components.shape[0] + 1, dtype).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-4)

In [22]:
i = 0
alpha = 10
decay = 0.9
for measures, components in dataloader:
    measures = measures.to(device)
    components = components.to(device)
    
    preds = model(measures)

    loss = loss_func(preds, components) + component_loss(preds, components)*alpha
    # loss = loss_func(preds, components)
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    with torch.no_grad():
        pred_measures = dataset.get_measures(preds.to("cpu"), None)
        measures = measures.to("cpu")

        if torch.isnan(loss):
            print(components)
            print(preds)
            break

        # Relative error of each predicted measurement against its target
        rel_diff = (pred_measures - measures).abs() / measures
        metrics = {
            "loss": loss.item(),
            "component loss": component_loss(preds, components).item(),
            "measurement difference": rel_diff.mean().item(),
            "maximum measurement difference": rel_diff.max().item(),
            "minimum difference": rel_diff.min().item(),
        }

        if config["log"]:
            wandb.log(metrics)
        else:
            print(metrics)
    alpha *= decay
    if i > 200:
        break
    else:
        print(i)
    
    i += 1


0
1
2
...
199
200
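
The loop blends two objectives: the measurement loss plus `alpha` times the MSE on the components, with `alpha` multiplied by `decay` after every step. A quick sketch of how fast that weight falls off, using the values from the cell (`alpha_schedule` is just an illustrative helper, and `steps=201` covers the counters 0 to 200 printed above):

```python
def alpha_schedule(alpha0, decay, steps):
    """Weight applied to the component loss at steps 0..steps-1."""
    return [alpha0 * decay ** step for step in range(steps)]


alphas = alpha_schedule(alpha0=10, decay=0.9, steps=201)

# First step at which the weight is below 1% of its starting value.
first_small = next(i for i, a in enumerate(alphas) if a < 0.01 * alphas[0])
print(first_small, alphas[first_small], alphas[-1])
```

With these settings the weight drops below 1% of its starting value at step 44, so the component term only acts as a warm-up signal early in training. A slower decay would keep it in play for longer if that turns out to be useful.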


In [23]:
from pathlib import Path

path = Path(r"C:\Users\chris\Documents\Core\Daily Notes")

files = sorted(path.glob("*"))
# output_path = "./Journalling.md"
output_path = path / "Journalling.md"
with open(output_path, "w") as outfile:
    for file in files:
        print(file.stem)
        date_str = file.stem[:10]
        if date_str >= "2025-05-01" and date_str != "Journallin":
            outfile.write(f"# {file.stem}\n")
            with open(file, "r") as infile:
                outfile.write(infile.read())
            outfile.write("\n\n---\n\n")

2023-08-18
2023-09-06
2023-09-12
2023-09-14
2023-09-25
2023-09-29
2023-10-19
2023-10-21
2023-10-22
2023-10-24
2023-10-25
2023-10-27
2023-10-28
2023-10-29
2023-10-30
2023-10-31
2023-11-02
2023-11-03
2023-11-04
2023-11-05
2023-11-10
2023-11-17
2023-11-30
2023-12-11
2023-12-13
2023-12-15
2023-12-16
2023-12-19
2023-12-20
2023-12-22
2023-12-25
2024-01-22
2024-01-30
2024-03-22
2024-04-18
2024-04-19
2024-04-24
2024-04-26
2024-04-29
2024-05-12
2024-06-04
2024-06-06
2024-06-19
2024-07-20
2024-09-19
2024-10-07
2024-10-11
2024-11-04
2025-01-22
2025-01-27
2025-01-31
2025-02-04
2025-02-15
2025-03-02
2025-06-04
2025-06-07
2025-06-13
2025-07-02 Sentence Completion
2025-07-05
2025-07-06
2025-07-07
2025-07-08
2025-07-09
2025-07-10
2025-07-11
2025-07-14
2025-07-15
2025-07-16
2025-07-17
2025-07-18
2025-07-19
2025-07-20
2025-07-21
2025-07-22
2025-07-23
2025-07-24
2025-07-25
2025-07-26
Journalling
