In [1]:
import json
import pandas as pd

In [3]:
# Load the full TSFresh feature table (a merged feature frame can be substituted here).
df_all = pd.read_parquet("../data/features/tsfresh_efficient_set1-13.parquet")

# Load the names of the sensor features selected for fusion.
# JSON is UTF-8 by specification; pin the encoding so the read does not depend on the OS locale.
with open("../data/features/sensor_features_for_fusion.json", "r", encoding="utf-8") as f:
    sensor_features = json.load(f)

# Later cells build column lists with `[id_col] + sensor_features`, so this must be a list of names.
assert isinstance(sensor_features, list) and all(isinstance(name, str) for name in sensor_features), \
    "sensor_features_for_fusion.json must contain a JSON list of column names"

print("Number of selected sensor features:", len(sensor_features))


Number of selected sensor features: 157


In [4]:
id_col = "image_name"
target_col = "wear_level"

# Validate columns BEFORE indexing: otherwise `df_all[...]` below raises a bare KeyError and
# the explicit missing-feature check in the following cell never gets a chance to run.
required_cols = [id_col, target_col] + sensor_features
missing_cols = [c for c in required_cols if c not in df_all.columns]
if missing_cols:
    raise KeyError(f"Columns missing from feature table: {missing_cols}")

# Row key + selected sensor features, copied so later edits don't touch df_all.
X_sensor = df_all[[id_col] + sensor_features].copy()
y = df_all[target_col].to_numpy()

print("X_sensor shape:", X_sensor.shape)


X_sensor shape: (1081, 158)


In [5]:
# Ensure every selected feature exists in the table; fail loudly instead of only printing.
missing = [f for f in sensor_features if f not in df_all.columns]
print("Missing features:", missing)
assert not missing, f"{len(missing)} selected features absent from df_all: {missing}"

# Duplicate names would silently produce duplicate columns in X_sensor.
assert len(set(sensor_features)) == len(sensor_features), "duplicate names in sensor_features"

# Check the row alignment key: each image_name must identify exactly one row.
assert X_sensor[id_col].is_unique


Missing features: []


In [None]:
# Model: FusionPINN (image + sensor fusion with a Taylor tool-life penalty) and its training step

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionPINN(nn.Module):
    """Image + sensor fusion regressor with a Taylor tool-life physics penalty.

    The fused representation is ``[img_emb, sensor_embedding, has_sensor]``. Two heads use it:

    * ``T_head`` predicts a strictly positive tool life ``T_hat``.
    * ``wear_head`` predicts wear from the fused features plus the normalised age
      ``tau = t / T_hat``.

    The physics term penalises deviation from Taylor's tool-life equation
    ``Vc * T**n = C`` in log space. ``n > 0`` and ``log C`` are global learnable parameters.

    Args:
        d_img: width of the image embedding.
        d_sensor: number of sensor features per sample.
        hidden: width of the hidden layers in the encoder and both heads.
        n_init: initial value of the *effective* (post-softplus) Taylor exponent; must be > 0.
        sensor_emb: width of the sensor embedding produced by ``sensor_enc``.

    Raises:
        ValueError: if ``n_init <= 0``.
    """

    def __init__(self, d_img: int, d_sensor: int, hidden: int = 256,
                 n_init: float = 0.2, sensor_emb: int = 128):
        super().__init__()
        if n_init <= 0:
            raise ValueError(f"n_init must be > 0, got {n_init}")

        # Learnable Taylor params (global).
        # `n` holds the *raw* exponent; forward() maps it through softplus to keep it positive.
        # It is therefore initialised with the inverse softplus, log(expm1(n_init)), so the
        # effective exponent starts at n_init. A raw value of 0.2 would start at softplus(0.2) ~= 0.80.
        self.n = nn.Parameter(torch.log(torch.expm1(torch.tensor(float(n_init)))))
        self.logC = nn.Parameter(torch.tensor(0.0))      # log(C)

        # Small sensor encoder: d_sensor -> sensor_emb.
        self.sensor_enc = nn.Sequential(
            nn.Linear(d_sensor, hidden),
            nn.ReLU(),
            nn.Linear(hidden, sensor_emb),
            nn.ReLU(),
        )

        fused = d_img + sensor_emb + 1   # +1 for the has_sensor flag

        # Tool-life head (predicts T in the same "unit" as t, e.g. samples-to-failure).
        self.T_head = nn.Sequential(
            nn.Linear(fused, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

        # Wear head: fused features plus tau (+1).
        self.wear_head = nn.Sequential(
            nn.Linear(fused + 1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, img_emb, sensor_feat, has_sensor, t, Vc):
        """Predict wear and tool life, and compute the Taylor penalty for one batch.

        Args:
            img_emb:     (B, d_img)
            sensor_feat: (B, d_sensor)  (use zeros for missing sensors)
            has_sensor:  (B, 1)         (0/1 mask)
            t:           (B, 1)         elapsed usage proxy (e.g. sample_index within tool)
            Vc:          (B, 1)         cutting speed

        Returns:
            ``(w_hat, T_hat, tau, L_taylor)``:

            * ``w_hat``, ``T_hat`` and ``tau`` have shape (B, 1).
            * ``L_taylor`` is the scalar mean squared log-space Taylor residual.
        """
        s_emb = self.sensor_enc(sensor_feat)
        base = torch.cat([img_emb, s_emb, has_sensor], dim=1)

        # Strictly positive tool life: softplus keeps it >= 0, the epsilon keeps it > 0.
        T_hat = F.softplus(self.T_head(base)) + 1e-6             # (B, 1)

        # Normalised age: elapsed usage divided by predicted tool life.
        tau = t / (T_hat + 1e-6)                                 # (B, 1)

        w_hat = self.wear_head(torch.cat([base, tau], dim=1))    # (B, 1)

        # Taylor residual in log space: log Vc + n log T - log C (zero when Vc * T^n = C).
        n_pos = F.softplus(self.n) + 1e-6                        # enforce n > 0
        L_taylor = ((torch.log(Vc + 1e-6) + n_pos * torch.log(T_hat) - self.logC) ** 2).mean()

        return w_hat, T_hat, tau, L_taylor


def training_step(model, batch, lam=0.1):
    """Compute the physics-informed training loss for one batch.

    Args:
        model: callable with ``FusionPINN.forward``'s signature that returns
            ``(w_hat, T_hat, tau, L_taylor)``.
        batch: mapping with keys ``img_emb``, ``sensor_feat``, ``has_sensor``, ``t``,
            ``Vc`` and ``wear``. ``wear`` is the wear target, with shape ``(B,)`` or
            matching ``w_hat``.
        lam: weight of the Taylor physics penalty.

    Returns:
        ``(loss, logs)``. ``loss = L_wear + lam * L_taylor`` is the tensor to call
        ``backward()`` on. ``logs`` holds both components as Python floats.
    """
    img_emb     = batch["img_emb"]
    sensor_feat = batch["sensor_feat"]
    has_sensor  = batch["has_sensor"]
    t           = batch["t"]
    Vc          = batch["Vc"]
    wear_true   = batch["wear"]

    w_hat, T_hat, tau, L_taylor = model(img_emb, sensor_feat, has_sensor, t, Vc)

    # Align the target with the prediction. A (B,) target against a (B, 1) prediction would
    # be broadcast into a (B, B) difference matrix, silently computing the wrong loss.
    # The dtype cast guards against float64 targets (e.g. from numpy) meeting float32 outputs.
    wear_true = wear_true.reshape(w_hat.shape).to(w_hat.dtype)

    # Data loss (change to MAE if you prefer)
    L_wear = F.mse_loss(w_hat, wear_true)

    # Total loss
    loss = L_wear + lam * L_taylor
    return loss, {"L_wear": L_wear.item(), "L_taylor": L_taylor.item()}


In [None]:
SEED = 42
TAYLOR_LAMBDA = 0.1   # weight of the Taylor physics penalty in the total loss
torch.manual_seed(SEED)

# d_sensor must equal the number of selected sensor columns (157 per the load cell), not a
# hard-coded 200. Otherwise the first Linear layer of sensor_enc fails with a shape
# mismatch on batches built from those features.
# NOTE(review): d_img=512 is assumed to match the image encoder's embedding width — confirm.
model = FusionPINN(d_img=512, d_sensor=len(sensor_features))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): `loader` is not defined anywhere in this notebook. It must be created in an
# earlier cell and yield dicts with keys img_emb, sensor_feat, has_sensor, t, Vc, wear.
model.train()
for batch in loader:
    optimizer.zero_grad()

    loss, logs = training_step(model, batch, lam=TAYLOR_LAMBDA)

    loss.backward()
    optimizer.step()
