LOCALLY TRAINED NN

Implementation of a localized training approach for neural network interpolation.

Specifically, the model is trained on triplets of neighboring data points (x1, x2, x3): the inputs of the two outer points, located `gap` samples on either side of the center, are concatenated to form the network input, and the target is the output at the central point x2. Training on such triplets confines learning to a local scale: it isolates the network's ability to learn smooth mappings between closely related inputs and outputs, avoiding global patterns and emphasizing local generalization.

### PREPARE TRIPLETS AND MODEL

In [None]:
def inv_log_scaled(values, scaler):
    """Map scaled log10 values back to physical units.

    ``values`` (a scalar or array-like) is copied into a single column,
    passed through ``scaler.inverse_transform`` and then exponentiated in
    base 10.

    Returns:
        1-D numpy array of physical values, one per input element.
    """
    column = np.array(values).reshape(-1, 1)
    log_values = scaler.inverse_transform(column).ravel()
    return np.power(10, log_values)

def get_interpolation_indices(center_index, data_length, gap):
    """Return the (left, center, right) indices of a triplet, or ``None``.

    The outer points sit ``gap`` positions on either side of ``center_index``.
    ``None`` is returned when either of them falls outside ``[0, data_length)``.
    """
    left, right = center_index - gap, center_index + gap
    in_bounds = left >= 0 and right < data_length
    return (left, center_index, right) if in_bounds else None

def load_activation_fn(activation_str):
    """Resolve an activation-function description to its ``torch.nn`` class.

    Matching is by substring, so descriptions such as
    ``"LeakyReLU(negative_slope=0.01)"`` resolve as well as bare names.

    Raises:
        ValueError: if none of "LeakyReLU", "ReLU" or "SiLU" occurs in
            ``activation_str``.
    """
    # Order matters: "ReLU" is a substring of "LeakyReLU", so the more
    # specific name has to be tested first.
    candidates = (
        ("LeakyReLU", nn.LeakyReLU),
        ("ReLU", nn.ReLU),
        ("SiLU", nn.SiLU),
    )
    for name, activation_cls in candidates:
        if name in activation_str:
            return activation_cls
    raise ValueError(f"Activation function '{activation_str}' not recognized.")

In [None]:
######### We upload the model just as in global interpolation but considerinh input_size=4 because we've concate inputs ######
top_path = # path to save metrics and plots
output_dir  = os.path.join(top_path, "Triplet")
os.makedirs(save_dir, exist_ok=True)

config_path =  #path were config.json is saved
def load_model_from_config(config_path, in_size, out_size):
    """Instantiate an untrained ``Interpolation`` model from a JSON config.

    Only architecture hyperparameters are read from the file
    ("hidden_layer_sizes", "dropout", "activation_fn", each via ``dict.get``);
    no weights are loaded.

    Args:
        config_path: path to the JSON configuration file.
        in_size: number of input features of the network.
        out_size: number of output features of the network.

    Returns:
        A freshly constructed ``Interpolation`` instance.
    """
    with open(config_path, "r") as handle:
        cfg = json.load(handle)

    activation = load_activation_fn(cfg.get("activation_fn"))
    return Interpolation(
        in_size,
        out_size,
        cfg.get("hidden_layer_sizes"),
        cfg.get("dropout"),
        activation,
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_triplet = 4  # concatenated x1, x3 = [T1, nb1, T3, nb3]
output_triplet = 2  # P, S

# Instantiate the model from the hyperparameters only -- no previous weights are loaded.
triplet_model = load_model_from_config(config_path=config_path, in_size=input_triplet, out_size=output_triplet)
# Move the triplet network itself to the device; training later moves its data there too.
triplet_model.to(device)
criterion = nn.MSELoss()
# NOTE(review): `lr` is not defined in this section; presumably set in an earlier cell.
optimizer = torch.optim.Adam(triplet_model.parameters(), lr)

# 5 random test indices, shared by the TripletNet and B-spline comparisons
max_index = len(in_test)
indices = random.sample(range(max_index), 5)
print(indices)

### LOCALLY TRAIN MODEL

In [None]:
def _build_triplets(inputs, targets, gap):
    """Build (neighbor-pair input, center target) tensors for triplet training.

    For every center index ``i`` in ``range(gap, len(inputs) - gap)`` the input
    row is ``torch.cat((inputs[i - gap], inputs[i + gap]))`` and the target is
    ``targets[i]``. The slicing below yields exactly these rows without a
    Python loop.

    Raises:
        ValueError: if ``gap`` is negative, if ``inputs`` and ``targets`` differ
            in length, or if there are too few samples to form one triplet.
    """
    if gap < 0:
        raise ValueError(f"gap must be non-negative, got {gap}")
    if len(inputs) != len(targets):
        raise ValueError(
            f"inputs ({len(inputs)}) and targets ({len(targets)}) must have the same length"
        )
    n_centers = len(inputs) - 2 * gap
    if n_centers <= 0:
        raise ValueError(
            f"need more than {2 * gap} samples to build triplets with gap={gap}, got {len(inputs)}"
        )
    # Row k pairs the left neighbor inputs[k] with the right neighbor inputs[k + 2*gap];
    # samples run along dim 0, so dim=1 concatenates their features.
    triplet_inputs = torch.cat((inputs[:n_centers], inputs[2 * gap:]), dim=1)
    center_targets = targets[gap:gap + n_centers]
    return triplet_inputs, center_targets


gap = 1
# Training triplets
in_train_triplets, out_train_targets = _build_triplets(in_train_tensor, out_train_tensor, gap)

# Validation triplets
in_val_triplets, out_val_targets = _build_triplets(in_val_tensor, out_val_tensor, gap)

In [None]:
def train_triplet_model(model, in_train, out_train, in_val, out_val, criterion, optimizer, device, epochs, save_path):
    """Full-batch training loop that checkpoints the best validation model.

    Each epoch performs one optimizer step on the whole training set, then
    evaluates the validation loss under ``torch.no_grad``. Whenever the
    validation loss is strictly lower than the best seen so far, the model's
    ``state_dict`` is written to ``save_path``.

    The model is expected to already be on ``device``; only the data tensors
    are moved here.

    Args:
        model: network to train.
        in_train, out_train: training inputs and targets.
        in_val, out_val: validation inputs and targets.
        criterion: loss function called as ``criterion(pred, target)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the data is moved to.
        epochs: number of epochs (one optimizer step each).
        save_path: file the best ``state_dict`` is saved to.

    Returns:
        ``(train_losses, val_losses)``: per-epoch loss values. The training
        loss of an epoch is the one computed before that epoch's step.
    """
    # Move the data once instead of on every epoch.
    in_train, out_train = in_train.to(device), out_train.to(device)
    in_val, out_val = in_val.to(device), out_val.to(device)

    best_val_loss = float("inf")
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss_train = criterion(model(in_train), out_train)
        loss_train.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            loss_val = criterion(model(in_val), out_val)

        train_loss = loss_train.item()
        val_loss = loss_val.item()
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.5e} | Val Loss: {val_loss:.5e}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path)
            print(f"  Best model saved (val_loss = {best_val_loss:.5e})")
    return train_losses, val_losses

In [None]:
save_path = os.path.join(output_dir, "best_triplet_model.pth")

# Train the triplet model and log losses (keyword names must match train_triplet_model's signature)
train_losses, val_losses = train_triplet_model(
    model=triplet_model,
    in_train=in_train_triplets,
    out_train=out_train_targets,
    in_val=in_val_triplets,
    out_val=out_val_targets,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=150,
    save_path=save_path,
)

losses_df = pd.DataFrame({"epoch": list(range(1, len(train_losses) + 1)), "train_loss": train_losses, "val_loss": val_losses})
losses_csv_path = os.path.join(output_dir, "triplet_model_losses.csv")
losses_df.to_csv(losses_csv_path, index=False)

# Plot losses
plt.figure(figsize=(10, 6))
plt.plot(losses_df["epoch"], losses_df["train_loss"], label="Train Loss", linewidth=2)
plt.plot(losses_df["epoch"], losses_df["val_loss"], label="Validation Loss", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
plt.title("TripletNet Training and Validation Loss")
plt.yscale("log")
plt.grid(True, which='both', linestyle='--', alpha=0.5)
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "triplet_model_loss_curve.png"), dpi=300)
plt.show()

### INTERPOLATION

In [None]:
# #######################################################
# LOCAL NN TRIPLET INTERPOLATION
# #######################################################
def _plot_triplet_prediction(P_true, S_true, P2_pred, S2_pred, x_labels, title, save_path):
    """Plot true P and S at the three triplet points plus the predicted center values, and save the figure."""
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    x_ticks = np.arange(3)
    panels = (
        (axs[0], P_true, P2_pred, 'True P', 'Predicted P', 'green', 'red', 'Pressure (P) [MeV/fm³]'),
        (axs[1], S_true, S2_pred, 'True S', 'Predicted S', 'blue', 'orange', 'Entropy (S) [1/fm³]'),
    )
    for ax, true_vals, pred_val, true_label, pred_label, true_color, pred_color, panel_title in panels:
        ax.plot(x_ticks, true_vals, label=true_label, marker='o', color=true_color)
        ax.scatter(x_ticks[1], pred_val, label=pred_label, marker='x', s=100, color=pred_color)
        ax.set_title(panel_title, fontsize=14)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_labels, fontsize=14)
        ax.tick_params(axis='y', labelsize=14)
        ax.grid(True)
        ax.legend(fontsize=14)
        ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))

    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)


def Interpolation_Triplet_test(model, inputs_set, outputs_set, scaler_P, scaler_S, device, indices, save_dir, gap):
    """Evaluate a triplet network on selected test centers and plot each prediction.

    For every ``center_index`` in ``indices``, the inputs ``gap`` positions to
    the left and right are concatenated and fed to ``model``. Its two outputs
    are mapped back to physical P and S with ``inv_log_scaled`` and compared
    with the true center values. Centers whose triplet would leave
    ``inputs_set`` are skipped.

    Args:
        model: trained triplet network.
        inputs_set, outputs_set: test inputs and outputs, indexed by sample.
        scaler_P, scaler_S: scalers used to undo the scaling of the P and S predictions.
        device: device the model and its input are placed on.
        indices: center indices to evaluate.
        save_dir: directory for the per-triplet figures (created if missing).
        gap: distance between the center and each outer point.

    Returns:
        List of dicts, one per evaluated center, with keys "method", "index",
        "gap", "mse_P", "mae_P", "mse_S", "mae_S" and "elapsed_sec". The errors
        are for a single point (squared and absolute error). "elapsed_sec"
        covers that center's prediction and metrics, excluding plotting.
    """
    os.makedirs(save_dir, exist_ok=True)
    model = model.to(device)
    model.eval()
    metrics = []

    for i, center_index in enumerate(indices):
        start_time = time.time()
        triplet = get_interpolation_indices(center_index, len(inputs_set), gap)
        if triplet is None:
            # Message: "Skipping index ... (gap=...) because of the bounds."
            print(f"Saltando índice {center_index} (gap={gap}) debido a límites.")
            continue
        idx1, idx2, idx3 = triplet

        x1, x3 = inputs_set[idx1], inputs_set[idx3]
        y1, y2_true, y3 = outputs_set[idx1], outputs_set[idx2], outputs_set[idx3]

        # Concatenate x1 and x3 into one (1, 2 * n_features) input row
        triplet_input_tensor = torch.cat(
            [torch.as_tensor(x1, dtype=torch.float32).unsqueeze(0),
             torch.as_tensor(x3, dtype=torch.float32).unsqueeze(0)],
            dim=1,
        ).to(device)

        # Predictions
        with torch.no_grad():
            y2_pred_scaled = model(triplet_input_tensor).cpu().numpy()[0]

        # Get physical outputs.
        # NOTE(review): true values are only exponentiated (10**y) while predictions
        # are also inverse-scaled, so outputs_set is presumably unscaled log10 data -- confirm.
        P_true = 10 ** np.array([y1[0], y2_true[0], y3[0]])
        S_true = 10 ** np.array([y1[1], y2_true[1], y3[1]])
        # inv_log_scaled returns a 1-element array for a scalar; keep the scalar.
        P2_pred = inv_log_scaled(y2_pred_scaled[0], scaler_P)[0]
        S2_pred = inv_log_scaled(y2_pred_scaled[1], scaler_S)[0]

        # Metrics
        mse_P = mean_squared_error([P_true[1]], [P2_pred])
        mae_P = mean_absolute_error([P_true[1]], [P2_pred])
        mse_S = mean_squared_error([S_true[1]], [S2_pred])
        mae_S = mean_absolute_error([S_true[1]], [S2_pred])
        elapsed = time.time() - start_time
        # Timing is recorded on this row only, so earlier rows keep their own values.
        metrics.append({
            "method": "TripletNet", "index": int(center_index), "gap": gap,
            "mse_P": mse_P, "mae_P": mae_P, "mse_S": mse_S, "mae_S": mae_S,
            "elapsed_sec": elapsed,
        })

        # Plot
        _plot_triplet_prediction(
            P_true, S_true, P2_pred, S2_pred,
            x_labels=[f'Idx {idx1}', f'Idx {idx2}', f'Idx {idx3}'],
            title=f"NN Local Interpolation Test #{i+1} (Index {center_index})  (Gap {gap})",
            save_path=os.path.join(save_dir, f"TripletNet_triplet_{i+1}_index_{center_index}.png"),
        )

    return metrics

In [None]:
# Upload top1 model hyperparameteres
output_dir  = # path where top1 model is saved
best_model_path = os.path.join(output_dir , "best_triplet_model.pth")
triplet_model.load_state_dict(torch.load(best_model_path))
triplet_model.eval()

# Define saving directory
triplet_nn_save_base_dir = output_dir
os.makedirs(triplet_nn_save_base_dir, exist_ok=True)

# Interpolation schemes
for current_gap in [1, 5]:
    print(f"\n TripletNet and B-Spline with gap={current_gap}...")

    save_dir_gap = os.path.join(triplet_nn_save_base_dir, f"gap_{current_gap}")
    os.makedirs(save_dir_gap, exist_ok=True)

    # 1. TripletNet
    metrics_triplet_nn_gap = Interpolation_Triplet_test(model=triplet_model, inputs_set=in_test, outputs_set=out_test, scaler_P=scaler_P, scaler_S=scaler_S, device=device, indices=indices, save_dir=save_dir_gap, gap=current_gap)
    df_nn = pd.DataFrame(metrics_triplet_nn_gap)

    # 2. B-Spline
    metrics_spline_gap = interpolation_test_bspline(in_train=in_train, out_train=out_train, in_test=in_test, out_test=out_test, indices=indices, save_dir=save_dir_gap, gap=current_gap)
    df_spline = pd.DataFrame(metrics_spline_gap)

    # 3. Save metrics
    df_all = pd.concat([df_nn, df_spline], ignore_index=True)
    csv_path = os.path.join(save_dir_gap, f"metrics_triplet_gap_{current_gap}.csv")
    df_all.to_csv(csv_path, index=False)