In [15]:
import numpy as np
import torch
import os
import subprocess
from pathlib import Path

In [16]:
def generate_data(batch_size, path="./stls", start=0, scale=1, timeout=60.0, poll_interval=0.05):
    """Run the limb generator script and load the component array it writes.

    Invokes ``./scripts/generate_limbs.sh`` (through ``wsl -e`` on Windows) for
    ``batch_size`` limbs starting at index ``start``, then loads
    ``{path}/components_{start:08d}.npy``.

    Args:
        batch_size: Number of limbs to generate (``--num_limbs``).
        path: Output directory passed to the script (``--path``).
        start: Index of the first limb (``--start``); also names the output file.
        scale: Value forwarded to the script's ``--scale`` flag.
        timeout: Seconds to keep waiting for the ``.npy`` file after the script exits.
        poll_interval: Seconds to sleep between checks for the file.

    Returns:
        torch.Tensor built from the loaded NumPy array.

    Raises:
        RuntimeError: If the generator script exits with a non-zero status.
        TimeoutError: If the components file does not appear within ``timeout`` seconds.
    """
    import time  # local import keeps this cell self-contained

    cmd = [
        "./scripts/generate_limbs.sh",
        "--num_limbs", f"{batch_size}",
        "--path", str(path),
        "--start", f"{start}",
        "--save_mesh", "0",
        "--scale", f"{scale}",
    ]
    if os.name == "nt":  # Windows: the script has to run inside WSL
        cmd = ["wsl", "-e"] + cmd

    result = subprocess.run(cmd, check=False, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
    # A failed run never writes the file, so without this check the poll loop
    # below would wait forever; surface the script's stderr instead.
    if result.returncode != 0:
        stderr_text = (result.stderr or b"").decode(errors="replace")
        raise RuntimeError(f"{cmd[0]} exited with status {result.returncode}:\n{stderr_text}")

    components_file = f"{path}/components_{start:08d}.npy"
    # Keep polling in case the file shows up after the script exits (presumably a
    # WSL filesystem delay), but sleep between tries and give up at a deadline.
    deadline = time.monotonic() + timeout
    while True:
        try:
            components = np.load(components_file)
        except FileNotFoundError:
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"{components_file} did not appear within {timeout} s"
                ) from None
            time.sleep(poll_interval)
        else:
            return torch.tensor(components)

batch_size = 1024
# `scale` is forwarded to the generator's --scale flag. The distinct start
# offsets give the two batches distinct components_{start:08d}.npy files.
unscaled = generate_data(batch_size, start=0, scale=0)
scaled = generate_data(batch_size, start=batch_size, scale=1)

# A bare mid-cell expression is never displayed by Jupyter; check shapes explicitly.
assert scaled.shape == unscaled.shape and scaled.shape[0] == batch_size, (scaled.shape, unscaled.shape)
# Calculate summary statistics for each component in unscaled and scaled tensors
# Calculate summary statistics for each component in unscaled and scaled tensors
def summarize_tensor(tensor, name):
    """Print column-wise mean, std (unbiased), min and max of a 2-D tensor.

    One line is printed per column (``tensor.shape[1]`` lines), each value to
    four decimals, preceded by a header naming ``name`` and followed by a blank line.
    Nothing is returned.
    """
    column_means = tensor.mean(dim=0)
    column_stds = tensor.std(dim=0)
    column_mins = tensor.min(dim=0).values
    column_maxs = tensor.max(dim=0).values

    print(f"Summary statistics for {name}:")
    rows = zip(range(tensor.shape[1]), column_means, column_stds, column_mins, column_maxs)
    for col, mu, sigma, lo, hi in rows:
        print(f"Component {col}: mean={mu:.4f}, std={sigma:.4f}, min={lo:.4f}, max={hi:.4f}")
    print()

# Print per-component statistics for both generated batches, unscaled first.
for summarized, label in ((unscaled, "unscaled"), (scaled, "scaled")):
    summarize_tensor(summarized, label)

Summary statistics for unscaled:
Component 0: mean=6.1444, std=14.2853, min=-19.9567, max=33.8261
Component 1: mean=2.6163, std=4.9341, min=-6.3180, max=12.6947
Component 2: mean=0.5011, std=3.1790, min=-6.7520, max=8.9552
Component 3: mean=0.5367, std=2.7262, min=-5.9810, max=7.9328
Component 4: mean=0.9793, std=2.4555, min=-4.8135, max=6.8879
Component 5: mean=-0.2139, std=1.7827, min=-4.7284, max=4.3558
Component 6: mean=0.2772, std=1.7266, min=-3.5803, max=4.5307
Component 7: mean=-0.0566, std=1.2382, min=-2.9769, max=2.7980
Component 8: mean=0.0421, std=0.8715, min=-2.2610, max=2.1669
Component 9: mean=0.2130, std=1.0338, min=-2.0038, max=2.5253
Component 10: mean=1.0000, std=0.0000, min=1.0000, max=1.0000

Summary statistics for scaled:
Component 0: mean=6.1444, std=14.2853, min=-19.9567, max=33.8261
Component 1: mean=2.6163, std=4.9341, min=-6.3180, max=12.6947
Component 2: mean=0.5011, std=3.1790, min=-6.7520, max=8.9552
Component 3: mean=0.5367, std=2.7262, min=-5.9810, max=7.

In [17]:
# Per-component normalisation statistics: row 0 = mean, row 1 = unbiased std.
# NOTE(review): the unscaled batch's last component is constant (std=0.0000 in its
# summary output), so any consumer that divides by row 1 gets inf/nan in that column.
scaled_component_transforms, unscaled_component_transforms = (
    torch.vstack([sample_batch.mean(dim=0, keepdim=True), sample_batch.std(dim=0, keepdim=True)])
    for sample_batch in (scaled, unscaled)
)

torch.save(scaled_component_transforms, "./data_components/scaled_component_transforms.pt")
torch.save(unscaled_component_transforms, "./data_components/unscaled_component_transforms.pt")

scaled_component_transforms.shape

torch.Size([2, 11])

In [18]:
import torch
from torch import nn
import subprocess
import os
import numpy as np
import igl
import wandb

from functools import partial
import measure_limbs

In [19]:
class Measurements(nn.Module):
    def __init__(self, edge2vert, face2edge, details, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.edge2vert = edge2vert
        self.face2edge = face2edge
        self.details = details
        self.measures = []

        for detail in details:
            match detail["type"]:
                case "width":
                    measure = partial(measure_limbs.measure_width,
                        edge2vert=self.edge2vert,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                        plane_direction=detail["plane_direction"],
                    )
                case "length":
                    measure = partial(measure_limbs.measure_length,
                        v1=detail["v1"], v2=detail["v2"], direction=detail["direction"]
                    )
                case "circumference":
                    measure = partial(measure_limbs.measure_planar_circumference,
                        edge2vert=self.edge2vert,
                        face2edge=self.face2edge,
                        plane_point=detail["plane_point"],
                        plane_normal=detail["plane_normal"],
                    )
            self.measures.append(measure)

    def forward(self, x, verbose=False):
        measures = [measure(x) for measure in self.measures]
        if verbose:
            for measure, detail in zip(measures, self.details):
                print(f"{detail['name'].ljust(20)}:\t\t{measure}")
        return torch.stack(measures, dim=-1)

In [20]:
class LegMeasurementDataset(torch.utils.data.Dataset):
    """Stream (measurements, components) pairs from limbs generated on the fly.

    Components are produced in batches of ``batch_size`` by
    ``./scripts/generate_limbs.sh`` and cached in ``self.loaded_components``,
    keyed by each batch's start index. When iteration reaches the first index
    of a batch, the batch two ahead is prefetched and the batch two behind is
    dropped, so the prefetching targets in-order access (e.g. a DataLoader with
    ``shuffle=False``). Out-of-order access falls back to generating the
    missing batch on demand.
    """

    # Seconds to wait for the generator's .npy output, and the polling interval.
    load_timeout = 60.0
    poll_interval = 0.05

    def __init__(self, measure, batch_size=64, path="./stls", dtype=torch.float64, device="cpu", length=100_000):
        """
        Args:
            measure: Object whose ``forward(verts[, verbose])`` returns measurements
                (a ``Measurements`` module in this notebook).
            batch_size: Limbs generated per script invocation.
            path: Directory the generator writes ``components_*.npy`` into.
                Every existing ``.npy`` file in it is deleted.
            dtype: Dtype for loaded component / vertex tensors.
            device: Device for loaded tensors.
            length: Value reported by ``__len__``.
        """
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.measure = measure
        self.batch_size = batch_size
        self.path = path
        self.length = length
        self.loaded_components = {}
        self.raw_components = torch.load("./data_components/vert_components.pt").to(dtype).to(device)
        self.mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype).to(device)
        self.face2vert = torch.load("./data_components/face2vert.pt").to(device)
        self.vert_mapping = torch.load("./data_components/vert_mapping.pt").to(device)
        self.component_transforms = torch.load("./data_components/scaled_component_transforms.pt").to(dtype).to(device)

        # Remove all .npy files from the specified path (creating it if needed so
        # os.listdir does not fail on a fresh checkout).
        os.makedirs(self.path, exist_ok=True)
        for file in os.listdir(self.path):
            if file.endswith(".npy"):
                os.remove(os.path.join(self.path, file))

        self.generate_data(0)
        self.generate_data(self.batch_size)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(measurements, components)`` for sample ``index``."""
        batch_start = (index // self.batch_size) * self.batch_size
        offset = index % self.batch_size

        if offset == 0:
            # Entering a new batch: prefetch the one two ahead, drop the one two behind.
            prefetch_start = index + self.batch_size * 2
            if prefetch_start < len(self):
                self.generate_data(prefetch_start)
            ith_dataset = index // self.batch_size
            if ith_dataset >= 2:
                self.delete_data((ith_dataset - 2) * self.batch_size)

        if batch_start not in self.loaded_components:
            # Out-of-order access: generate the whole batch containing `index`
            # under its batch-start key (the same key the lookup below uses).
            self.generate_data(batch_start)
        components = self.loaded_components[batch_start][offset]

        verts = self.get_verts(components)
        measurements = self.measure.forward(verts).squeeze()

        return measurements, components

    def get_verts(self, components):
        """Reconstruct scaled vertices from component weights.

        The last entry of ``components`` (of each row, if 2-D) is a scale factor.
        The remaining entries weight the rows of ``raw_components``, which are added
        to ``mean_verts``. The result is reindexed by ``vert_mapping``, multiplied by
        the scale, and squeezed, so a single 1-D input loses its batch dim.
        """
        if components.dim() == 1:
            weights, scale = components[:-1], components[-1:]
        else:
            weights, scale = components[:, :-1], components[:, -1:]

        # (B, n) @ (n, D) yields the same sum as broadcasting to a (B, n, D)
        # temporary, without that temporary. Assumes raw_components is 2-D
        # (n_components, D), as the broadcast-and-sum effectively required for
        # batched input. The cast is needed because matmul, unlike elementwise
        # ops, does not promote mixed dtypes.
        flat_weights = weights.reshape(-1, weights.shape[-1]).to(self.raw_components.dtype)
        total = flat_weights @ self.raw_components
        verts = self.mean_verts[None] + total.reshape(
            (total.shape[0], self.mean_verts.shape[0], self.mean_verts.shape[1])
        )
        verts = verts[:, self.vert_mapping] * scale[..., None]

        return verts.squeeze()

    def get_measures(self, components=None, verts=None, verbose=False):
        """Measure ``components`` (converted via ``get_verts``) or, failing that, ``verts``.

        Raises:
            ValueError: If neither ``components`` nor ``verts`` is given.
        """
        if components is not None:
            verts = self.get_verts(components)
        elif verts is None:
            raise ValueError("get_measures needs either `components` or `verts`")

        return self.measure.forward(verts, verbose)

    def generate_data(self, start):
        """Generate ``batch_size`` limbs starting at ``start`` and cache their components.

        Runs ``./scripts/generate_limbs.sh`` (via ``wsl -e`` on Windows) with
        ``--scale 1``, then loads ``{path}/components_{start:08d}.npy`` into
        ``self.loaded_components[start]``.

        Raises:
            subprocess.CalledProcessError: If the script exits with a non-zero status.
            TimeoutError: If the output file does not appear within ``load_timeout`` seconds.
        """
        import time  # local import keeps this cell self-contained

        cmd = [
            "./scripts/generate_limbs.sh",
            "--num_limbs", f"{self.batch_size}",
            "--path", str(self.path),
            "--start", f"{start}",
            "--save_mesh", "0",
            "--scale", "1",
        ]
        if os.name == "nt":  # Windows: run the script inside WSL
            cmd = ["wsl", "-e"] + cmd

        subprocess.run(cmd, check=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE)

        components_file = f"{self.path}/components_{start:08d}.npy"
        # Poll with a sleep and a deadline instead of spinning forever.
        deadline = time.monotonic() + self.load_timeout
        while True:
            try:
                components = np.load(components_file)
            except FileNotFoundError:
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        f"{components_file} did not appear within {self.load_timeout} s"
                    ) from None
                time.sleep(self.poll_interval)
            else:
                self.loaded_components[start] = torch.tensor(components, dtype=self.dtype, device=self.device)
                return

    def delete_data(self, start):
        """Drop the cached batch at ``start`` and its ``.npy`` file.

        A batch that was never generated (possible after out-of-order access) is ignored.
        """
        try:
            os.remove(f"{self.path}/components_{start:08d}.npy")
        except FileNotFoundError:
            pass
        self.loaded_components.pop(start, None)

In [21]:
dtype = torch.float64

mean_verts = torch.load("./data_components/mean_verts.pt").to(dtype)
face2vert = torch.load("./data_components/face2vert.pt")

edge2vert, face2edge, edge2face = igl.edge_topology(
    mean_verts.numpy(), face2vert.numpy()
)

edge2vert = torch.from_numpy(edge2vert)
face2edge = torch.from_numpy(face2edge)
edge2face = torch.from_numpy(edge2face)

# Landmark vertices (the original cell loaded this same file twice in a row).
# Order of entries in vert_idxs, per the original notes:
#   0: Mid patella tendon
#   1: Distal tibia
#   2: Knee widest
#   3: Knee above? This one feels off
#   4: Over fib head
#   5: Fib head
#   6: Circ 3
#   7: Circ 4
vert_idxs = torch.load("./data_components/selected_verts.pt")


def _axis(x, y, z):
    """Return a new (1, 3) direction tensor in `dtype` (fresh per call, never shared)."""
    return torch.tensor([[x, y, z]], dtype=dtype)


# Circumferences 1-4: planes with normal +z through landmarks 4-7.
circumference_details = tuple(
    {
        "type": "circumference",
        "plane_point": vert_idxs[slot],
        "plane_normal": _axis(0, 0, 1),
        "name": f"Circumference {number}",
    }
    for number, slot in enumerate(range(4, 8), start=1)
)

measurement_details = circumference_details + (
    {
        # Landmark 0 -> landmark 1 along z.
        "type": "length",
        "v1": vert_idxs[0],
        "v2": vert_idxs[1],
        "direction": _axis(0, 0, 1),
        "name": "Length 1",
    },
    {
        # Width along x in the z-normal plane through landmark 2.
        "type": "width",
        "plane_point": vert_idxs[2],
        "plane_normal": _axis(0, 0, 1),
        "plane_direction": _axis(1, 0, 0),
        "name": "Width 1",
    },
    {
        # Width along x in the z-normal plane through landmark 3.
        "type": "width",
        "plane_point": vert_idxs[3],
        "plane_normal": _axis(0, 0, 1),
        "plane_direction": _axis(1, 0, 0),
        "name": "Width 2",
    },
)

measure = Measurements(
    edge2vert,
    face2edge,
    measurement_details,
)

In [22]:
print(f"{scaled.shape=}")
# NOTE: building the dataset deletes every .npy file under its path ("./stls" by
# default) and runs the generator for two batches before any measuring happens.
dataset = LegMeasurementDataset(measure, batch_size=batch_size)
scaled_measures, unscaled_measures = (
    dataset.get_measures(component_batch) for component_batch in (scaled, unscaled)
)

scaled.shape=torch.Size([1024, 11])


In [27]:
# Compare measurement statistics with the component statistics that produced them.
for tensor_to_summarize, title in (
    (unscaled_measures, "Unscaled Measures"),
    (unscaled, "Unscaled Components"),
    (scaled_measures, "Scaled Measures"),
    (scaled, "Scaled Components"),
):
    summarize_tensor(tensor_to_summarize, title)

Summary statistics for Unscaled Measures:
Component 0: mean=0.8744, std=0.0947, min=0.6319, max=1.1155
Component 1: mean=0.8454, std=0.0985, min=0.6096, max=1.0961
Component 2: mean=0.8093, std=0.0948, min=0.5900, max=1.0624
Component 3: mean=0.7672, std=0.0911, min=0.5486, max=1.0472
Component 4: mean=0.3432, std=0.0723, min=0.1987, max=0.4981
Component 5: mean=0.3221, std=0.0320, min=0.2436, max=0.4111
Component 6: mean=0.3374, std=0.0326, min=0.2553, max=0.4262

Summary statistics for Unscaled Components:
Component 0: mean=6.1444, std=14.2853, min=-19.9567, max=33.8261
Component 1: mean=2.6163, std=4.9341, min=-6.3180, max=12.6947
Component 2: mean=0.5011, std=3.1790, min=-6.7520, max=8.9552
Component 3: mean=0.5367, std=2.7262, min=-5.9810, max=7.9328
Component 4: mean=0.9793, std=2.4555, min=-4.8135, max=6.8879
Component 5: mean=-0.2139, std=1.7827, min=-4.7284, max=4.3558
Component 6: mean=0.2772, std=1.7266, min=-3.5803, max=4.5307
Component 7: mean=-0.0566, std=1.2382, min=-2.9

In [28]:
# Per-measurement normalisation statistics: row 0 = mean, row 1 = unbiased std.
scaled_measurement_transforms, unscaled_measurement_transforms = (
    torch.vstack([measured.mean(dim=0, keepdim=True), measured.std(dim=0, keepdim=True)])
    for measured in (scaled_measures, unscaled_measures)
)

torch.save(scaled_measurement_transforms, "./data_components/scaled_measurement_transforms.pt")
torch.save(unscaled_measurement_transforms, "./data_components/unscaled_measurement_transforms.pt")