# Interactive Airfoil VAE Explorer

This notebook loads a trained `AirfoilVAE` model and provides a user interface to explore the latent space.

### Features:
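* **Physical Sliders:** Control the latent variables aligned with physical parameters (maximum thickness and its chordwise position, leading-edge radius, maximum camber and its chordwise position).
* **Free Sliders:** Explore the remaining free latent variables, which are not tied to a physical parameter and may be entangled.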
* **Real-time Geometry:** Visualizes the resulting airfoil shape using the B-Spline decoder.

### Instructions:
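1.  Ensure your trained model weights (`.pth` file) are available. If the file is not found, the notebook prints a warning and continues with randomly initialized weights, so the generated shapes will not be meaningful.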
2.  Update the `MODEL_PATH` variable in the configuration cell if necessary.
3.  Run all cells.

In [1]:
%matplotlib widget
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interactive_output
from IPython.display import display, clear_output

# Pick the compute device once: CUDA when a GPU is visible, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Running on: {DEVICE}")

Running on: cpu


## 1. Model Definition
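We redefine the model architecture here so the notebook is self-contained. It must match the training code exactly (layer order and sizes); otherwise the saved weights will not load.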

In [2]:
# ==========================================
# MODEL ARCHITECTURE (Copied from Source)
# ==========================================

class BSplineLayer(nn.Module):
    """Fixed (non-trainable) B-spline curve evaluator.

    Precomputes a basis matrix of shape (num_eval_points, num_control_points)
    at cosine-spaced parameters u in [0, 1] (denser near both ends), so that
    ``forward`` maps control points to curve samples with one matrix product.

    NOTE(review): the Cox-de Boor loop in ``_precompute_basis`` iterates
    ``range(m - d - 1)``, one short of the ``m - d`` basis functions at each
    degree. For degree >= 1 this leaves the last basis column identically zero,
    so the last control point never affects the curve. The sample at u == 1
    also evaluates to 0, because its degree-0 indicator lives in the degenerate
    final span, which the recursion never reads. Trained weights were presumably
    fitted against this exact basis (the architecture is copied from the
    training source), so it is kept verbatim. Changing it would alter decoded
    shapes; confirm against the training code before "fixing".
    """

    def __init__(self, num_control_points, degree=3, num_eval_points=100, device='cpu'):
        """
        Args:
            num_control_points: number of control points per curve.
            degree: spline degree.
            num_eval_points: number of parameter values at which the curve is sampled.
            device: device on which the basis matrix is initially built.
        """
        super().__init__()
        self.num_control_points = num_control_points
        self.degree = degree
        # Cosine spacing: u = (1 - cos(theta)) / 2 clusters samples near u=0 and u=1.
        theta = torch.linspace(0, np.pi, num_eval_points, device=device)
        u = 0.5 * (1 - torch.cos(theta))
        # Non-persistent buffer: it follows module.to(...)/.cuda()/.double() like
        # the parameters do (a plain attribute would be left behind and cause a
        # device/dtype mismatch in forward). It is also NOT part of state_dict(),
        # so existing checkpoints load unchanged with strict=True.
        self.register_buffer('basis_matrix', self._precompute_basis(u).to(device), persistent=False)

    def _precompute_basis(self, u):
        """Evaluate the B-spline basis functions at parameters ``u``.

        Uses a clamped knot vector with uniformly spaced interior knots on [0, 1]
        and the Cox-de Boor recursion.

        Args:
            u: 1-D tensor of parameter values.

        Returns:
            Tensor of shape (len(u), num_control_points).
        """
        n, p = self.num_control_points - 1, self.degree
        m = n + p + 1
        # Clamped knot vector of m+1 knots: p+1 zeros, uniform interior knots, p+1 ones.
        kv = torch.zeros(m + 1, device=u.device)
        start, end = p + 1, m - p
        if end > start:
            kv[start:end] = torch.linspace(0, 1, end - start + 2, device=u.device)[1:-1]
        kv[end:] = 1.0
        # Degree-0 basis: indicator of the half-open knot span [kv[i], kv[i+1]).
        # The final (degenerate) span is closed so u == 1 is marked at degree 0.
        N = torch.zeros(u.shape[0], m, device=u.device)
        for i in range(m):
            mask = (u >= kv[i]) & (u < kv[i + 1])
            if i == m - 1: mask = mask | (u == kv[i + 1])
            N[:, i] = mask.float()
        # Cox-de Boor recursion; terms with a zero-length knot span are taken as 0.
        # NOTE(review): range(m - d - 1) skips the last function at each degree;
        # kept verbatim on purpose (see class docstring).
        for d in range(1, p + 1):
            N_new = torch.zeros(u.shape[0], m - d, device=u.device)
            for i in range(m - d - 1):
                d1 = kv[i + d] - kv[i]
                d2 = kv[i + d + 1] - kv[i + 1]
                t1 = ((u - kv[i]) / d1) * N[:, i] if d1 > 1e-6 else 0.0
                t2 = ((kv[i + d + 1] - u) / d2) * N[:, i + 1] if d2 > 1e-6 else 0.0
                N_new[:, i] = t1 + t2
            N = N_new
        return N[:, :self.num_control_points]

    def forward(self, cp):
        """Map control points (..., num_control_points) to curve samples (..., num_eval_points)."""
        return torch.matmul(cp, self.basis_matrix.T)

class EncoderBlock(nn.Module):
    """1-D convolutional encoder that outputs a mean and a log-variance vector.

    Three Conv1d + GELU stages (the first two followed by 2x max-pooling) are
    flattened and fed to two linear heads of size ``latent_dim``.
    """

    def __init__(self, input_len, latent_dim, filters=64, kernel=3):
        super().__init__()
        pad = kernel // 2
        widths = [1, filters, filters * 2, filters * 4]
        n_stages = len(widths) - 1
        stages = []
        for stage_idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages += [nn.Conv1d(c_in, c_out, kernel, padding=pad), nn.GELU()]
            # Pool between stages; the final stage is flattened instead.
            stages.append(nn.Flatten() if stage_idx == n_stages - 1 else nn.MaxPool1d(2))
        self.net = nn.Sequential(*stages)
        # Probe once with a zero signal to size the linear heads.
        with torch.no_grad():
            n_flat = self.net(torch.zeros(1, 1, input_len)).shape[1]
        self.fc_mu = nn.Linear(n_flat, latent_dim)
        self.fc_lv = nn.Linear(n_flat, latent_dim)

    def forward(self, x):
        """Return (mu, logvar); x is presumably (batch, input_len) — TODO confirm."""
        features = self.net(x.unsqueeze(1))
        mu = self.fc_mu(features)
        log_var = self.fc_lv(features)
        return mu, log_var

class DecoderBlock(nn.Module):
    """MLP: ``layers`` hidden Linear + GELU blocks of width ``nodes``, then a linear output of size ``output_dim``."""

    def __init__(self, latent_dim, output_dim, layers=2, nodes=32):
        super().__init__()
        widths = [latent_dim] + [nodes] * layers
        hidden = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            hidden.extend((nn.Linear(d_in, d_out), nn.GELU()))
        self.net = nn.Sequential(*hidden, nn.Linear(widths[-1], output_dim))

    def forward(self, z):
        """Decode latent vectors z of size latent_dim into output_dim values."""
        return self.net(z)

class AirfoilVAE(nn.Module):
    """Two-branch airfoil VAE with separate thickness and camber encoders/decoders.

    Each branch decodes its latent vector to B-spline control points, which a
    shared ``BSplineLayer`` turns into a sampled curve. This notebook only uses
    ``decode_from_latent``; ``forward`` is intentionally not implemented here.
    """

    def __init__(self, seq_len=200, num_cp=15, device='cpu'):
        """
        Args:
            seq_len: curve length (encoder input length and number of B-spline samples).
            num_cp: number of B-spline control points per curve.
            device: device used for the B-spline basis and ``log_prior`` at construction.
        """
        super().__init__()
        self.device = device
        self.num_cp = num_cp

        # Architecture constants: must match the checkpoint exactly, since they
        # fix parameter shapes. Modules are created in the original order so
        # random initialisation is unchanged too.
        self.dim_t = 3 + 3 # LATENT_PHYS_THICK + LATENT_FREE_THICK
        self.dim_c = 2 + 2 # LATENT_PHYS_CAMBER + LATENT_FREE_CAMBER
        ENC_FILTERS = 64
        ENC_KERNEL = 3
        DEC_LAYERS = 2
        DEC_NODES = 64

        # -- Encoders --
        self.enc_thick = EncoderBlock(seq_len, self.dim_t, ENC_FILTERS, ENC_KERNEL)
        self.enc_camber = EncoderBlock(seq_len, self.dim_c, ENC_FILTERS, ENC_KERNEL)

        # -- Decoders --
        # Thickness decodes only the num_cp - 2 interior control points; the two
        # end points are pinned to zero in decode_from_latent.
        self.dec_thick = DecoderBlock(self.dim_t, num_cp - 2, DEC_LAYERS, DEC_NODES)
        self.dec_camber = DecoderBlock(self.dim_c, num_cp, DEC_LAYERS, DEC_NODES)

        # -- B-Spline --
        self.bspline = BSplineLayer(num_cp, degree=3, num_eval_points=seq_len, device=device)
        self.log_prior = nn.Parameter(torch.tensor([-2.0], device=device))

    def decode_from_latent(self, z_t, z_c):
        """Decode thickness and camber latents into sampled curves.

        Args:
            z_t: thickness latents, shape (batch, dim_t).
            z_c: camber latents, shape (batch, dim_c).

        Returns:
            (t_out, c_out): thickness and camber curves as produced by ``self.bspline``.
        """
        # 1. Decode control points.
        cp_t_raw = self.dec_thick(z_t)
        cp_c = self.dec_camber(z_c)

        # 2. Geometric constraints: softplus keeps interior thickness control
        # points positive, and zero control points are padded at both ends.
        # The pad is built from the decoded tensor so it always matches its
        # device and dtype, even if the module was moved after construction
        # (self.device would then be stale).
        cp_t_pos = F.softplus(cp_t_raw)
        edge = cp_t_pos.new_zeros((cp_t_pos.shape[0], 1))
        cp_t = torch.cat([edge, cp_t_pos, edge], dim=1)

        # 3. Curve generation.
        t_out = self.bspline(cp_t)
        c_out = self.bspline(cp_c)

        return t_out, c_out

    def forward(self, x):
        """Not implemented in this notebook; fail loudly instead of returning None."""
        raise NotImplementedError(
            "AirfoilVAE.forward is not available in this notebook; "
            "use decode_from_latent(z_t, z_c) instead."
        )

## 2. Load Model Weights

In [3]:
# CONFIGURATION
MODEL_PATH = './results/model/airfoil_vae.pth' # Update this path if needed
SEQ_LEN = 200
NUM_CP = 15

def load_model(path, device, seq_len=SEQ_LEN, num_cp=NUM_CP):
    """Build an ``AirfoilVAE`` and load a saved state dict into it.

    Args:
        path: path to the saved state dict.
        device: device to build the model on and map the checkpoint to.
        seq_len: curve length the model was trained with (default: SEQ_LEN).
        num_cp: number of B-spline control points (default: NUM_CP).

    Returns:
        The model in eval mode. If ``path`` does not exist, a warning is
        printed and the randomly initialised model is returned instead.
        Errors from ``load_state_dict`` (e.g. key or shape mismatches) are not
        caught and propagate to the caller.
    """
    vae = AirfoilVAE(seq_len=seq_len, num_cp=num_cp, device=device).to(device)
    try:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code; only load checkpoints you trust. Newer PyTorch versions accept
        # weights_only=True for plain state dicts.
        state_dict = torch.load(path, map_location=device)
    except FileNotFoundError:
        print(f"WARNING: Model file not found at {path}. Initializing with random weights for demo purposes.")
    else:
        vae.load_state_dict(state_dict)
        print(f"Successfully loaded model from {path}")
    vae.eval()
    return vae

model = load_model(MODEL_PATH, DEVICE)

Successfully loaded model from ./results/model/airfoil_vae.pth


[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.


## 3. Interactive Geometry Explorer

Use the sliders below to manipulate the latent space; the plot refreshes when a slider is released.
The variables are grouped into **Thickness** and **Camber** branches, with physically aligned variables separated from free variables.

In [4]:
# Setup Sliders
style = {'description_width': 'initial'}
layout = widgets.Layout(width='300px')

# Slider labels, in the order the decoders consume them: within each branch
# the physically aligned variables come first, then the free ones.
# NOTE(review): this ordering is assumed to match the training code — confirm.
THICK_PHYS_KEYS = ["Max Thickness", "Pos Max Thickness", "LE Radius"]
THICK_FREE_KEYS = [f"Free Thick {i+1}" for i in range(3)]
CAMBER_PHYS_KEYS = ["Max Camber", "Pos Max Camber"]
CAMBER_FREE_KEYS = [f"Free Camber {i+1}" for i in range(2)]
THICK_KEYS = THICK_PHYS_KEYS + THICK_FREE_KEYS
CAMBER_KEYS = CAMBER_PHYS_KEYS + CAMBER_FREE_KEYS

# Fail fast if the slider set drifts out of sync with the model's latent sizes.
assert len(THICK_KEYS) == model.dim_t, (len(THICK_KEYS), model.dim_t)
assert len(CAMBER_KEYS) == model.dim_c, (len(CAMBER_KEYS), model.dim_c)

# One labelled pyplot figure, reused by every plot update.
FIG_LABEL = "Airfoil Explorer"

def create_slider(desc, min_v=-4.0, max_v=4.0):
    """Return a latent-value slider that only fires when released (continuous_update=False)."""
    return widgets.FloatSlider(
        value=0.0, min=min_v, max=max_v, step=0.1,
        description=desc, style=style, layout=layout,
        continuous_update=False
    )

# --- Thickness Branch (6 Latents) ---
t_phys_sliders = [create_slider(key) for key in THICK_PHYS_KEYS]
t_free_sliders = [create_slider(key) for key in THICK_FREE_KEYS]

# --- Camber Branch (4 Latents) ---
c_phys_sliders = [create_slider(key) for key in CAMBER_PHYS_KEYS]
c_free_sliders = [create_slider(key) for key in CAMBER_FREE_KEYS]

# Plotting Function
def plot_airfoil(**kwargs):
    """Decode the current slider values and draw the resulting airfoil.

    Args:
        **kwargs: one float per label in THICK_KEYS and CAMBER_KEYS (supplied by
            ``interactive_output`` from ``slider_dict``).

    Errors are printed with a traceback rather than raised, so a failed update
    leaves a message in the output area.
    """
    try:
        z_t = torch.tensor([[kwargs[key] for key in THICK_KEYS]], dtype=torch.float32, device=DEVICE)
        z_c = torch.tensor([[kwargs[key] for key in CAMBER_KEYS]], dtype=torch.float32, device=DEVICE)

        # Decode
        with torch.no_grad():
            t_out, c_out = model.decode_from_latent(z_t, z_c)
        t_dist = t_out.cpu().numpy()[0]
        c_dist = c_out.cpu().numpy()[0]

        # Same cosine spacing as the B-spline evaluation points.
        x_val = 0.5 * (1 - np.cos(np.linspace(0, np.pi, len(t_dist))))
        # Thickness is split symmetrically about the camber line.
        yu = c_dist + t_dist / 2
        yl = c_dist - t_dist / 2

        # Reuse (and clear) a single labelled figure instead of opening a new
        # one per update: pyplot keeps every figure it creates until it is
        # closed, so with the interactive widget backend a fresh figure on
        # each slider move accumulates (memory growth, "too many figures" warning).
        fig, ax = plt.subplots(num=FIG_LABEL, figsize=(10, 4), clear=True)
        ax.plot(x_val, yu, 'b-', linewidth=2, label='Upper Surface')
        ax.plot(x_val, yl, 'r-', linewidth=2, label='Lower Surface')
        ax.plot(x_val, c_dist, 'k--', alpha=0.5, label='Camber Line')
        ax.fill_between(x_val, yl, yu, color='gray', alpha=0.1)

        ax.set_title("Generated Airfoil Geometry")
        ax.set_xlabel("x/c")
        ax.set_ylabel("y/c")
        ax.axis('equal')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right')
        ax.set_ylim(-0.5, 0.5)
        fig.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error updating plot: {e}")
        import traceback
        traceback.print_exc()

# Map each label (= plot_airfoil keyword) to its slider.
slider_dict = dict(zip(
    THICK_KEYS + CAMBER_KEYS,
    t_phys_sliders + t_free_sliders + c_phys_sliders + c_free_sliders,
))

# Layout Construction
thick_phys_box = widgets.VBox([widgets.HTML(value="<b>Thickness Physics</b>"), *t_phys_sliders])
thick_free_box = widgets.VBox([widgets.HTML(value="<b>Thickness Free</b>"), *t_free_sliders])
camber_phys_box = widgets.VBox([widgets.HTML(value="<b>Camber Physics</b>"), *c_phys_sliders])
camber_free_box = widgets.VBox([widgets.HTML(value="<b>Camber Free</b>"), *c_free_sliders])

controls = widgets.HBox([
    widgets.VBox([thick_phys_box, thick_free_box], layout=widgets.Layout(margin='0px 20px 0px 0px')),
    widgets.VBox([camber_phys_box, camber_free_box])
])

# Create interactive output
out = interactive_output(plot_airfoil, slider_dict)

# Display
display(controls, out)

HBox(children=(VBox(children=(VBox(children=(HTML(value='<b>Thickness Physics</b>'), FloatSlider(value=0.0, co…

Output()