In [1]:
import socket
import json
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.interpolate import RBFInterpolator
import joblib
import csv
import os

plt.ion()
plt.switch_backend('TkAgg')

# Network endpoint this server binds to. The defaults are the original
# hard-coded values; the environment variables allow retargeting the server
# without editing the notebook.
HOST = os.environ.get("SENSOR_SERVER_HOST", "172.26.128.172")
PORT = int(os.environ.get("SENSOR_SERVER_PORT", "65432"))
FORCE_CSV = "predicted_force.csv"

# Start a fresh log for this session: mode 'w' truncates any previous file and
# writes the header row. Data rows are appended later by append_force_to_csv.
with open(FORCE_CSV, mode='w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['estimated_force', 'peak_location_x'])

# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only load model files that come from a trusted source.
model_data = joblib.load("Interpolation_model.pkl")

# Fitted regressor and the scaler that must be applied to its inputs.
model, X_scaler = model_data["model"], model_data["X_scaler"]

# Coordinates of the predicted-field nodes; columns 0/1 are used as x/z later.
slice_coords = np.asarray(model_data["slice_coords"])

# Stored in reversed order. predict_and_plot reverses the incoming readings the
# same way, so each sensor point stays paired with its own reading.
sensor_points = model_data["sensor_points"][::-1]

# Create all figures ONCE; later code redraws them in place for every sample.
fig_map, ax_map = plt.subplots(figsize=(6.4, 5))
fig_curve, ax_curve = plt.subplots(figsize=(6.4, 4.5))
fig_loc, ax_loc = plt.subplots(figsize=(6, 6))

# (figure, window title, Tk geometry string "+x+y") -- the three windows are
# tiled left-to-right along the top of the screen.
_WINDOW_LAYOUT = [
    (fig_map, "σ_zz Field", "+0+0"),
    (fig_curve, "Contact Pressure Curve", "+640+0"),
    (fig_loc, "True vs Predicted Location", "+1280+0"),
]

for _fig, _title, _geometry in _WINDOW_LAYOUT:
    _manager = _fig.canvas.manager
    _manager.set_window_title(_title)
    try:
        # Tk-specific: remove window decorations and pin the window position.
        # Best-effort -- other backends have no `.window` with these methods,
        # and the plots still work without this step.
        _manager.window.overrideredirect(True)
        _manager.window.wm_geometry(_geometry)
    except Exception as e:
        print(f"Could not configure window '{_title}': {e}")

plt.show(block=False)

# History for the true-vs-predicted location scatter, appended by
# update_location_plot (true values as received, predictions in mm).
real_locs = []
est_locs = []

def append_force_to_csv(force, x_peak):
    """Append one ``(estimated_force, peak_location_x)`` row to ``FORCE_CSV``.

    Both values are formatted with six decimal places. Any error raised while
    opening, formatting or writing is reported on stdout rather than raised,
    so a logging failure never interrupts the processing loop.
    """
    try:
        with open(FORCE_CSV, mode='a', newline='') as csv_file:
            csv.writer(csv_file).writerow([f"{force:.6f}", f"{x_peak:.6f}"])
    except Exception as e:
        print(f"Error writing to predicted force CSV: {e}")

def update_location_plot(real_x, est_x):
    """Store one (true, predicted) location pair and redraw ``ax_loc``.

    Parameters
    ----------
    real_x : float
        True x-location, appended to ``real_locs`` unchanged (treated as mm).
    est_x : float
        Predicted x-location in metres; appended to ``est_locs`` in mm.

    The full history is re-plotted on every call, together with a dashed
    y = x reference line spanning the range of all stored values.
    """
    real_locs.append(real_x)
    est_locs.append(est_x * 1000)  # metres -> millimetres

    ax = ax_loc
    ax.clear()
    ax.scatter(real_locs, est_locs, color='darkgreen', s=30, label='Predictions')

    # Reference line: a perfect prediction lies on y = x.
    combined = real_locs + est_locs
    lo, hi = min(combined), max(combined)
    ax.plot([lo, hi], [lo, hi], color='gray', linestyle='--', label='Ground Truth')

    ax.set_xlabel("True x-location [mm]")
    ax.set_ylabel("Predicted x-location [mm]")
    ax.set_title("True vs Predicted Contact Location")
    ax.set_aspect('equal', 'box')
    ax.grid(True, linestyle='--', alpha=0.4)
    ax.legend()

    fig_loc.tight_layout()
    canvas = fig_loc.canvas
    canvas.draw()
    canvas.flush_events()

# Post-processing parameters used by predict_and_plot (values identical to the
# literals that were previously inlined in the function).
_CONTACT_DIAMETER_M = 0.006   # contact-patch diameter [m]; NOTE(review): presumably the indenter size -- confirm
_DENSE_SAMPLES = 500          # number of points in the resampled top-surface profile
_SMOOTHING_SIGMA = 8          # gaussian_filter1d width, in dense-sample units
_TOP_SURFACE_TOL = 1e-6       # |z - z_max| tolerance for top-surface nodes (slice_coords units)
_CURVE_XLIM_MM = (0, 54.2)    # fixed x-range of the pressure-curve plot [mm]


def _build_scaled_input(sensor_only):
    """Build the scaled model input row from the (reordered) sensor readings.

    Each entry of ``sensor_points`` contributes ``(pt[0], pt[2], reading)``.
    The readings are divided by their L2 norm (plus 1e-8 to avoid division by
    zero) before ``X_scaler`` is applied, so the model sees a unit-norm pattern.

    Returns
    -------
    X_scaled
        ``X_scaler.transform`` of the single (1, 3 * n_sensors) feature row.
    norm : ndarray, shape (1, 1)
        L2 norm of the raw readings, used to rescale the model prediction.
    """
    features = []
    for pt, reading in zip(sensor_points, sensor_only):
        features.extend([pt[0], pt[2], float(reading)])
    sensor_input = np.array(features).reshape(1, -1)

    vals = sensor_input[:, 2::3]
    norm = np.linalg.norm(vals, axis=1, keepdims=True)

    X_input_normed = sensor_input.copy()  # position columns are kept as-is
    X_input_normed[:, 2::3] = vals / (norm + 1e-8)
    return X_scaler.transform(X_input_normed), norm


def _plot_field_map(strain_slice_pred):
    """Redraw the full predicted σ_zz contour map on ``fig_map``.

    The figure is cleared and a new axes is created on each call, so the
    previous colorbar does not accumulate. The module-level ``ax_map`` is
    rebound to the new axes. Exceptions propagate to the caller.
    """
    global ax_map
    fig_map.clear()
    ax_map = fig_map.add_subplot(111)
    x_mm = slice_coords[:, 0] * 1000
    z_mm = slice_coords[:, 1] * 1000
    tri = ax_map.tricontourf(x_mm, z_mm, strain_slice_pred, levels=14, cmap='viridis')
    ax_map.set_title("Full Predicted σ_zz Field")
    ax_map.set_xlabel("x [mm]")
    ax_map.set_ylabel("z [mm]")
    ax_map.axis("scaled")
    fig_map.colorbar(tri, ax=ax_map, label="σ_zz")
    fig_map.tight_layout()
    fig_map.canvas.draw()
    fig_map.canvas.flush_events()


def _top_surface_profile(strain_slice_pred):
    """Return ``(x, sigma)`` for the top-surface nodes, sorted by x.

    Top-surface nodes are those whose z (column 1 of ``slice_coords``) lies
    within ``_TOP_SURFACE_TOL`` of the maximum z. Both returned arrays are
    empty if no node qualifies.
    """
    z_vals = slice_coords[:, 1]
    top_mask = np.abs(z_vals - np.max(z_vals)) < _TOP_SURFACE_TOL
    x_top = slice_coords[top_mask][:, 0]
    sigma_top = strain_slice_pred[top_mask]
    order = np.argsort(x_top)
    return x_top[order], sigma_top[order]


def _find_pressure_peak(x_sorted, sigma_sorted):
    """Resample the top-surface profile densely and locate its minimum.

    A thin-plate-spline RBF (no smoothing) is evaluated on ``_DENSE_SAMPLES``
    evenly spaced x values and then Gaussian-smoothed. The peak is taken as
    the most negative smoothed value (``argmin``).

    Returns ``(x_dense, sigma_smooth, x_peak, peak_value)``, where ``x_dense``
    is 1-D.
    """
    x_dense = np.linspace(x_sorted.min(), x_sorted.max(), _DENSE_SAMPLES)
    rbf = RBFInterpolator(x_sorted.reshape(-1, 1), sigma_sorted,
                          smoothing=0.0, kernel='thin_plate_spline')
    sigma_smooth = gaussian_filter1d(rbf(x_dense.reshape(-1, 1)), sigma=_SMOOTHING_SIGMA)
    peak_idx = np.argmin(sigma_smooth)
    return x_dense, sigma_smooth, x_dense[peak_idx], sigma_smooth[peak_idx]


def _contact_force(x_dense, sigma_smooth, x_peak):
    """Estimate the contact force from the smoothed profile around the peak.

    The profile is averaged over ``x_peak ± r`` (``r`` is half of
    ``_CONTACT_DIAMETER_M``) and multiplied by the disc area ``π r²``. The
    result therefore carries the sign of the averaged σ_zz values.

    Returns ``(x_contact, sigma_contact, estimated_force)``. The force is 0.0
    when the window contains no samples.
    """
    radius = _CONTACT_DIAMETER_M / 2
    in_contact = (x_dense >= x_peak - radius) & (x_dense <= x_peak + radius)
    x_contact = x_dense[in_contact]
    sigma_contact = sigma_smooth[in_contact]
    if len(x_contact) > 0:
        estimated_force = np.mean(sigma_contact) * (np.pi * radius**2)
    else:
        estimated_force = 0.0
    return x_contact, sigma_contact, estimated_force


def _plot_contact_curve(x_contact, sigma_contact, peak_value):
    """Redraw the pressure curve inside the contact window on ``ax_curve``.

    The y-range is pinned to ``[1.5 * peak_value, 0]`` only when the peak is
    negative. Otherwise autoscaling is kept, because those limits would be
    identical (peak == 0) or inverted (peak > 0).
    """
    ax_curve.clear()
    if len(x_contact) > 0:
        x_mm = x_contact * 1000
        ax_curve.plot(x_mm, sigma_contact, color='firebrick', linewidth=2)
        ax_curve.fill_between(x_mm, sigma_contact, 0, color='blue', alpha=0.5)

    ax_curve.axhline(0, color='gray', linestyle='--', linewidth=1)
    ax_curve.set_title("Contact Pressure Curve")
    ax_curve.set_xlabel("x-position [mm]")
    ax_curve.set_ylabel("Contact Pressure")
    ax_curve.set_xlim(*_CURVE_XLIM_MM)
    if peak_value < 0:
        ax_curve.set_ylim(1.5 * peak_value, 0)
    ax_curve.grid(True, linestyle="--", alpha=0.4)
    fig_curve.tight_layout()
    fig_curve.canvas.draw()
    fig_curve.canvas.flush_events()


def predict_and_plot(sensor_values, real_x):
    """Predict the σ_zz field from one sensor sample, estimate the contact, and plot.

    Parameters
    ----------
    sensor_values : sequence
        One reading per entry in ``sensor_points``, followed by the
        ground-truth force as the last element.
    real_x : float
        True contact x-location, forwarded to ``update_location_plot``
        (the caller treats it as mm).

    Pipeline: build the model input -> predict the field -> plot the field
    map -> extract the top-surface profile -> find the pressure peak ->
    estimate the force -> log it to CSV -> plot the curve and the location
    scatter. Plotting and curve-processing errors are printed and abort the
    remaining steps. Errors while building the input or predicting propagate
    to the caller.
    """
    if len(sensor_values) != len(sensor_points) + 1:
        print(f"Expected {len(sensor_points)} sensor values + 1 force, got {len(sensor_values)}")
        return

    # Reverse the readings to match the reversed order of ``sensor_points``.
    sensor_only = sensor_values[:-1][::-1]
    force_gt = float(sensor_values[-1])

    X_scaled, norm = _build_scaled_input(sensor_only)
    y_unit_pred = model.predict(X_scaled)
    # Undo the unit-norm scaling of the inputs.
    strain_slice_pred = (y_unit_pred * norm).flatten()

    try:
        _plot_field_map(strain_slice_pred)
    except Exception as e:
        print(f"Error plotting field map: {e}")
        return

    x_sorted, sigma_sorted = _top_surface_profile(strain_slice_pred)
    if len(x_sorted) == 0:
        print("No top surface points found")
        return

    try:
        x_dense, sigma_smooth, x_peak, peak_value = _find_pressure_peak(x_sorted, sigma_sorted)
        print(f"Peak contact pressure: {peak_value:.4f} at x = {x_peak:.5f} m")

        x_contact, sigma_contact, estimated_force = _contact_force(x_dense, sigma_smooth, x_peak)
        print(f"Estimated force: {estimated_force:.4f} N vs Ground truth: {force_gt:.4f} N")

        append_force_to_csv(estimated_force, x_peak)
        _plot_contact_curve(x_contact, sigma_contact, peak_value)
        update_location_plot(real_x, x_peak)
    except Exception as e:
        print(f"Error in curve processing/plotting: {e}")

def _handle_line(raw_line):
    """Decode one newline-delimited message (bytes) and dispatch it.

    Blank lines are ignored. A message with a ``readings`` field is passed to
    ``predict_and_plot`` together with its ``x`` value (default 0). Decode,
    JSON and processing errors are reported and swallowed, so one bad message
    does not terminate the connection.
    """
    try:
        line = raw_line.decode('utf-8').strip()
        if not line:
            return
        payload = json.loads(line)
        print(f"Received X={payload.get('x', 'N/A')} Z={payload.get('z', 'N/A')}")
        if 'readings' in payload:
            x_real = float(payload.get("x", 0))  # assumed already in mm
            predict_and_plot(payload["readings"], x_real)
        else:
            print("No 'readings' field in payload")
    except json.JSONDecodeError as e:
        print(f"JSON decode error: {e}")
    except Exception as e:
        print(f"Error processing data: {e}")


# Serve a single client connection: read newline-delimited JSON messages until
# the client disconnects, a socket error occurs, or the user interrupts.
try:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((HOST, PORT))
        s.listen()
        print("Listening for sensor data from Windows...")

        conn, addr = s.accept()
        with conn:
            print(f"Connection established from {addr}")
            # Buffer raw bytes and decode only complete lines. Decoding each
            # recv() chunk on its own fails whenever a multi-byte UTF-8
            # character straddles two chunks.
            buffer = b""
            while True:
                try:
                    data = conn.recv(4096)
                    if not data:
                        print("Client disconnected")
                        break

                    buffer += data
                    while b'\n' in buffer:
                        raw_line, buffer = buffer.split(b'\n', 1)
                        _handle_line(raw_line)
                except socket.error as e:
                    print(f"Socket error: {e}")
                    break
                except KeyboardInterrupt:
                    print("\nInterrupted by user")
                    break
                except Exception as e:
                    print(f"Unexpected error: {e}")
                    break

except KeyboardInterrupt:
    # Interrupt while blocked in accept() (outside the receive loop's handler).
    print("\nInterrupted by user")
except Exception as e:
    print(f"Server setup error: {e}")
finally:
    plt.ioff()
    print("Server stopped")


Listening for sensor data from Windows...
Connection established from ('172.26.128.1', 52138)
Received X=0 Z=0
Peak contact pressure: -5338.0154 at x = 0.02056 m
Estimated force: -0.1022 N vs Ground truth: 0.0000 N
Received X=0 Z=0
Peak contact pressure: -2478.7958 at x = 0.04226 m
Estimated force: -0.0565 N vs Ground truth: 0.0000 N
Received X=0 Z=0
Peak contact pressure: -6344.5995 at x = 0.01896 m
Estimated force: -0.1531 N vs Ground truth: 0.0000 N
Received X=0 Z=0
Peak contact pressure: -6315.1867 at x = 0.01237 m
Estimated force: -0.1251 N vs Ground truth: 0.0000 N
Received X=0 Z=0
Peak contact pressure: -2649.1689 at x = 0.04704 m
Estimated force: 0.0065 N vs Ground truth: 0.0000 N
Received X=0 Z=0
Peak contact pressure: -4155.6548 at x = 0.02822 m
Estimated force: -0.0776 N vs Ground truth: 0.0000 N
Received X=0 Z=0
Peak contact pressure: -5500.1118 at x = 0.01194 m
Estimated force: -0.0943 N vs Ground truth: 0.0000 N
Received X=0 Z=0
Peak contact pressure: -4513.0830 at x = 0.