# ModelGen: PyTorch → LTspice Model Generator

This notebook is a utility for generating PyTorch models from simple switch-based definitions,
automatically converting them into standalone `.py` classes and LTspice subcircuit files (`.sp`), and then
running parity checks between PyTorch and LTspice outputs.

---

## Features

1. **Switch-based model definition**  
   - Models are defined with `nn.Sequential` presets using a simple `make_model()` function.  
   - Users only need to add a new `elif` block to register their own model.  
   - Supports:
     - Pure MLP (no recurrent cells)  
     - Models with `RNNCell`, `GRUCell`, `LSTMCell` at any layer  
     - Stacked recurrent cells  

2. **Automatic code generation**  
   - From a given `nn.Sequential`, the notebook generates a PyTorch class (`<NAME>`) with:
     - `step(x, state)` for stateful execution  
     - `forward(x)` for stateless execution  
   - The code is saved as a `.py` file in the environment directory (`env/`), next to the `.asc` environment files.

3. **LTspice subcircuit export**  
   - The generated model class (not the original `Sequential`) is exported to a `.sp` file.  
   - Input/output ports: `NNIN1..NNIN9`, `NNOUT1..NNOUT2` (the counts follow the selected environment in `ENV_CONFIG`, e.g. `env_buck_9x2`).  
   - Hidden ports (`HIN*`, `HOUT*`, `CIN*`, `COUT*`) are included automatically if recurrent cells exist.  
   - The `.py` and `.sp` files are always created as a **pair**.

4. **Parity check utilities**  
   - LTspice is run with a selected `.asc` environment (e.g., `env_buck_9x2.asc`).  
   - Observations from LTspice are passed to the PyTorch model.  
   - PyTorch outputs are compared with LTspice outputs (`NNOUT*`).  
   - MAE/MSE metrics can be printed and plotted.

---


## Change Log:

2025-09-21, Initial Version

2025-10-01,
- Python Code Generator: added `clone_state` method in generated classes for safe hidden state duplication.
- Python Code Generator: updated `forward` in generated classes to accept `h` as an alias for `state`.

---

In [None]:
import importlib
import importlib.util
import os
import shutil
import sys
import textwrap
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import plotly.graph_objects as go
import torch
from torch import nn

from PyLTSpice import LTspice, RawRead, SimRunner

sys.path.insert(0, os.path.join(os.getcwd(), '..', '..', 'PyTorch2LTspice'))
from PyTorch2LTspice import export_model_to_ltspice


---
## Configuration

In [None]:
# Environment configuration: NN input/output widths for each LTspice environment.
ENV_NAME = "env_buck_9x2"              # Environment to use (stem of the .asc file in ENVDIR)
ENV_CONFIG = {
    "env_buck_9x1": {"in": 9, "out": 1},
    "env_buck_9x2": {"in": 9, "out": 2},
    # Register further environments here.
}

# Model configuration: presets understood by make_model(); "clk_needed" controls
# whether the ctrlclk port is added to the LTspice instance line in step2().
MODEL_NAME = "mlp"                     # "mlp" / "rnn_linear" / "gru_linear" / "linear_lstm_linear"
MODEL_CONFIG = {
    "mlp": {"clk_needed": False},
    "rnn_linear": {"clk_needed": True},
    "gru_linear": {"clk_needed": True},
    "linear_lstm_linear": {"clk_needed": True},
    # Register further models here.
}
SUBCKT_NAME = MODEL_NAME               # Sub-circuit name used in LTspice


# LTspice simulation configuration
SIM_STEP = 200                         # Number of time steps to simulate in LTspice
SIM_TIMEOUT = 300                      # Simulation timeout in seconds

# Working directories for PyLTspice
NOTEBOOK_DIR = Path.cwd()
ENVDIR = NOTEBOOK_DIR.joinpath("env")  # .asc environment files live here
OUTDIR = ENVDIR                        # generated .sp files are saved next to the .asc files
WORKDIR = NOTEBOOK_DIR.joinpath("tmp") # scratch directory for simulation runs
WORKDIR.mkdir(exist_ok=True)

---
## Model Definition

In [None]:
def make_model(model_name: str) -> nn.Sequential:
    """Return a freshly initialised ``nn.Sequential`` for the preset *model_name*.

    Input and output widths are read from ``ENV_CONFIG[ENV_NAME]`` at call time,
    so a preset adapts to whichever environment is currently selected.
    To register a new model, add another ``elif`` branch that fills ``layers``.

    Raises:
        ValueError: if *model_name* is not one of the presets below.
    """
    def width(key):
        # Resolved lazily (only inside a matching branch) so an unknown preset
        # reports ValueError without touching ENV_CONFIG.
        return ENV_CONFIG[ENV_NAME][key]

    if model_name == "mlp":
        layers = [
            nn.Linear(width("in"), 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, width("out")),
            nn.Tanh(),
        ]
    elif model_name == "rnn_linear":
        layers = [
            nn.RNNCell(width("in"), 32),
            nn.Linear(32, width("out")),
            nn.Tanh(),
        ]
    elif model_name == "gru_linear":
        layers = [
            nn.GRUCell(width("in"), 32),
            nn.Linear(32, width("out")),
            nn.Tanh(),
        ]
    elif model_name == "linear_lstm_linear":
        layers = [
            nn.Linear(width("in"), 32),
            nn.Tanh(),
            nn.LSTMCell(32, 32),
            nn.Linear(32, width("out")),
            nn.Tanh(),
        ]
    # Example extension for stacked LSTM cells:
    # elif model_name == "stacked_lstm":
    #     layers = [
    #         nn.LSTMCell(width("in"), 32),
    #         nn.LSTMCell(32, 32),
    #         nn.Linear(32, width("out")),
    #         nn.Tanh(),
    #     ]
    else:
        raise ValueError(f"Unknown model preset: {model_name}")
    return nn.Sequential(*layers)


---
## PyTorch Code Generator

In [None]:
# Recurrent cell types; layers of these types carry hidden state in the generated class.
CELL_TYPES = (nn.RNNCell, nn.GRUCell, nn.LSTMCell)

# Layer types the code generator can emit, mapped to their "nn.<Name>" spelling.
SUPPORTED_LAYERS: Dict[type, str] = {
    layer_cls: f"nn.{layer_cls.__name__}"
    for layer_cls in (nn.Linear, nn.ReLU, nn.Sigmoid, nn.Tanh) + CELL_TYPES
}


def _layer_to_ctor_line(layer: nn.Module, idx: int) -> str:
    prefix = f"self.l{idx} = "
    if isinstance(layer, nn.Linear):
        bias_flag = layer.bias is not None
        return f"{prefix}nn.Linear({layer.in_features}, {layer.out_features}, bias={bias_flag})"
    if isinstance(layer, nn.ReLU):
        return f"{prefix}nn.ReLU(inplace={layer.inplace})"
    if isinstance(layer, nn.Sigmoid):
        return f"{prefix}nn.Sigmoid()"
    if isinstance(layer, nn.Tanh):
        return f"{prefix}nn.Tanh()"
    if isinstance(layer, nn.RNNCell):
        bias_flag = bool(layer.bias)
        return (
            f"{prefix}nn.RNNCell({layer.input_size}, {layer.hidden_size}, "
            f"nonlinearity={repr(layer.nonlinearity)}, bias={bias_flag})"
        )
    if isinstance(layer, nn.GRUCell):
        bias_flag = bool(layer.bias)
        return f"{prefix}nn.GRUCell({layer.input_size}, {layer.hidden_size}, bias={bias_flag})"
    if isinstance(layer, nn.LSTMCell):
        bias_flag = bool(layer.bias)
        return f"{prefix}nn.LSTMCell({layer.input_size}, {layer.hidden_size}, bias={bias_flag})"
    raise ValueError(f"Unsupported layer for code generation: {type(layer)}")


def generate_model_code_from_sequential(name: str, seq: nn.Sequential) -> Tuple[str, str]:
    """Generate Python source for a standalone ``nn.Module`` class mirroring *seq*.

    Each layer of *seq* becomes ``self.l<idx>`` in the generated class, and the same
    layers are also wrapped in ``self.model = nn.Sequential(...)``. The class provides:

    * ``step(x, state)``: one time step on a 2-D ``(B, D)`` tensor. It returns
      ``(output, next_state)``. ``next_state`` is a list with one entry per recurrent
      cell (``(h, c)`` for ``LSTMCell``, ``h`` otherwise), or ``None`` when the model
      has no recurrent cells.
    * ``forward(x, state=None, h=None)``: runs over rank-1/2/3 input. For recurrent
      models, rank 2 is iterated as a ``(T, D)`` sequence and rank 3 as ``(B, T, D)``.
    * ``clone_state(state)``: a detached deep copy of a state structure.

    Args:
        name: Requested class name. Any character that is not valid in a Python
            identifier is replaced with ``_``, and a leading digit gets a ``_`` prefix.
        seq: Source model. Every layer type must be a key of ``SUPPORTED_LAYERS``.

    Returns:
        ``(class_name, code)``: the sanitised class name and the generated source.

    Raises:
        TypeError: if *seq* contains an unsupported layer type.
        ValueError: if *name* cannot be turned into a valid class name.
    """
    import keyword
    import re

    # The class name is pasted verbatim into ``class <name>(nn.Module):``, so map
    # every non-identifier character (not only '-') to '_' to keep the module parseable.
    class_name = re.sub(r"\W", "_", f"{name}")
    if class_name[:1].isdigit():
        class_name = f"_{class_name}"
    if not class_name.isidentifier() or keyword.iskeyword(class_name):
        raise ValueError(f"Cannot derive a valid Python class name from {name!r}.")

    ctor_lines: List[str] = []
    model_lines: List[str] = ["                self.model = nn.Sequential("]
    cell_indices: List[int] = []

    for idx, layer in enumerate(seq):
        # Exact type match: subclasses are rejected because the emitted constructor
        # would instantiate the base class instead.
        if type(layer) not in SUPPORTED_LAYERS:
            raise TypeError(f"Layer type {type(layer)} is not supported.")
        ctor_lines.append(f"                {_layer_to_ctor_line(layer, idx)}")
        suffix = ',' if idx < len(seq) - 1 else ''
        model_lines.append(f"                    self.l{idx}{suffix}")
        if isinstance(layer, CELL_TYPES):
            cell_indices.append(idx)

    model_lines.append("                )")
    ctor_block = "\n".join(ctor_lines)
    model_block = "\n".join(model_lines)
    cell_idx_literal = ', '.join(str(i) for i in cell_indices)

    # The generated blocks above are pre-indented by 16 spaces so that, after
    # textwrap.dedent strips the template's common 8-space margin, they land at
    # the 8-space __init__ body indentation.
    code = textwrap.dedent(
        f"""
        import torch
        import torch.nn as nn
        from typing import Any, List, Optional, Tuple

        class {class_name}(nn.Module):
            def __init__(self):
                super().__init__()
{ctor_block}
{model_block}
                self._cells = [{cell_idx_literal}]
                self._num_layers = {len(seq)}

            def _prepare_state_list(self, state: Optional[List[Any]]) -> List[Any]:
                if not self._cells:
                    return []
                if state is None:
                    return [None] * len(self._cells)
                state_list = list(state)
                if len(state_list) != len(self._cells):
                    raise ValueError(f"Expected {{len(self._cells)}} state entries, got {{len(state_list)}}.")
                return state_list

            def clone_state(self, state: Optional[List[Any]]):
                if state is None:
                    return None
                def _clone(item):
                    if item is None:
                        return None
                    if isinstance(item, torch.Tensor):
                        return item.detach().clone()
                    if isinstance(item, (list, tuple)):
                        cloned = [_clone(x) for x in item]
                        return type(item)(cloned)
                    raise TypeError(f"Unsupported state element type: {{type(item)}}")
                return _clone(state)

            def step(self, x: torch.Tensor, state: Optional[List[Any]]):
                if x.dim() != 2:
                    raise ValueError("step expects a 2D tensor shaped (B, D).")
                current = x
                state_list = self._prepare_state_list(state)
                next_states: List[Any] = []
                cell_ptr = 0
                for layer_idx in range(self._num_layers):
                    layer = getattr(self, f"l{{layer_idx}}")
                    if layer_idx in self._cells:
                        prev = state_list[cell_ptr]
                        if isinstance(layer, nn.LSTMCell):
                            if prev is None:
                                h_prev = current.new_zeros((current.size(0), layer.hidden_size))
                                c_prev = current.new_zeros((current.size(0), layer.hidden_size))
                            else:
                                h_prev, c_prev = prev
                            h, c = layer(current, (h_prev, c_prev))
                            current = h
                            next_states.append((h, c))
                        else:
                            if prev is None:
                                h_prev = current.new_zeros((current.size(0), layer.hidden_size))
                            else:
                                h_prev = prev
                            h = layer(current, h_prev)
                            current = h
                            next_states.append(h)
                        cell_ptr += 1
                    else:
                        current = layer(current)
                return current, next_states if self._cells else None

            def forward(self, x: torch.Tensor, state: Optional[List[Any]] = None, h: Optional[List[Any]] = None):
                if h is not None:
                    if state is not None:
                        raise ValueError("Use either 'state' or 'h' to pass hidden state, not both.")
                    state = h

                if not self._cells:
                    if x.dim() == 1:
                        return self.model(x.unsqueeze(0)).squeeze(0)
                    if x.dim() == 2:
                        return self.model(x)
                    if x.dim() == 3:
                        b, t, f = x.shape
                        y = self.model(x.reshape(b * t, f))
                        return y.reshape(b, t, -1)
                    raise ValueError("MLP forward expects tensors with rank 1, 2, or 3.")

                if x.dim() == 1:
                    out, _ = self.step(x.unsqueeze(0), state)
                    return out.squeeze(0)

                if x.dim() == 2:
                    state_in = state
                    outputs: List[torch.Tensor] = []
                    for t in range(x.size(0)):
                        step_input = x[t].unsqueeze(0)
                        out, state_in = self.step(step_input, state_in)
                        outputs.append(out)
                    return torch.cat(outputs, dim=0)

                if x.dim() == 3:
                    state_in = state
                    outputs: List[torch.Tensor] = []
                    for t in range(x.size(1)):
                        step_input = x[:, t, :]
                        out, state_in = self.step(step_input, state_in)
                        outputs.append(out.unsqueeze(1))
                    return torch.cat(outputs, dim=1)

                raise ValueError("RNN forward expects tensors with rank 1, 2, or 3.")
        """
    ).strip()
    return class_name, code


def save_model_code(code: str, out_name: str) -> Path:
    """Store *code* as ``ENVDIR/<out_name>.py`` (UTF-8) and return that path.

    ENVDIR and any missing parents are created first.
    """
    target_dir = ENVDIR
    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir.joinpath(f"{out_name}.py")
    target_file.write_text(code, encoding='utf-8')
    return target_file


def import_or_reload_generated(py_path: Path, class_name: str) -> type:
    """Execute the generated file *py_path* as a fresh module and return ``class_name`` from it.

    The module is registered in ``sys.modules`` under ``py_path.stem``. It is always
    loaded directly from *py_path*, never through ``importlib.reload``. ``reload``
    re-resolves the module through ``sys.path``, so it fails with ModuleNotFoundError
    when a cell is re-run, because the generated file's directory is not on
    ``sys.path``. It could also pick up an unrelated module with the same name.

    Raises:
        ImportError: if no loader can be created for *py_path*.
        AttributeError: if the module does not define ``class_name``.
    """
    module_name = py_path.stem
    importlib.invalidate_caches()
    spec = importlib.util.spec_from_file_location(module_name, str(py_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create a module loader for {py_path}")
    module = importlib.util.module_from_spec(spec)

    # Register before executing (as the importlib "import a source file directly"
    # recipe does); restore the previous entry if execution fails.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except Exception:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return getattr(module, class_name)



---
## LTspice data extractor

In [None]:
# Helper to extract observation data from the LTspice dataframe
def extract_data(df, clk_col='V(ctrlclk)', threshold=0.5):
    """Return the rows of *df* taken at every falling edge of the clock column.

    A falling edge at index ``i`` means ``clk[i-1] > threshold`` and
    ``clk[i] <= threshold`` after a preceding rising edge. Row ``i``, the first
    sample after the edge, is kept. The result has a fresh 0..n-1 index.

    Raises:
        ValueError: if the clock starts above *threshold*.
    """
    clock = df[clk_col].values

    # Edge tracking assumes the clock starts low
    if clock[0] > threshold:
        raise ValueError("Clock started with Level Hi")

    falling_rows = []
    is_high = False
    for pos in range(1, len(clock)):
        before, now = clock[pos - 1], clock[pos]
        if not is_high:
            if before <= threshold and now > threshold:
                is_high = True  # rising edge
        elif before > threshold and now <= threshold:
            falling_rows.append(pos)  # falling edge
            is_high = False
    return df.iloc[falling_rows].reset_index(drop=True)

---
## Step0) Global variables for Step1-3

In [None]:
DEVICE = torch.device("cpu")  # device for the generated model and its input tensors in Step1-3

---
## Step1) Create Python code and LTspice sub-circuit

In [None]:
def step1() -> nn.Module:
    """Generate, save, import and export the model selected by ``MODEL_NAME``.

    Reads the globals ``MODEL_NAME``, ``PY_FILENAME``, ``SP_FILENAME``,
    ``SUBCKT_NAME``, ``OUTDIR`` and ``DEVICE``:

    1. builds the ``nn.Sequential`` preset via ``make_model``;
    2. generates the standalone class and writes it with ``save_model_code``;
    3. imports the generated class and instantiates it on ``DEVICE``;
    4. exports the instance's ``model`` attribute to ``OUTDIR/<SP_FILENAME>.sp``
       as sub-circuit ``SUBCKT_NAME``.

    Returns:
        The instantiated generated model. Its ``model`` attribute holds the same
        layer modules that were exported to LTspice.
    """
    # Select input sequential model
    seq = make_model(MODEL_NAME)

    # Generate & save python code
    class_name, code = generate_model_code_from_sequential(MODEL_NAME, seq)
    py_path = save_model_code(code, PY_FILENAME)

    # Import and instantiate the generated class
    GenClass = import_or_reload_generated(py_path, class_name)
    actor = GenClass().to(DEVICE)

    # Export LTspice sub-circuit (make sure the target directory exists first)
    OUTDIR.mkdir(parents=True, exist_ok=True)
    export_model_to_ltspice(actor.model, filename=f"{OUTDIR}/{SP_FILENAME}.sp", subckt_name=SUBCKT_NAME, verbose=False)

    return actor


## Step2) Run Simulation on LTspice and Python

In [None]:
def step2(module):
    """Simulate the selected environment in LTspice and replay its observations through *module*.

    Uses the globals ``ENV_NAME``, ``MODEL_NAME``, ``SUBCKT_NAME``, ``SP_FILENAME``,
    ``ENV_CONFIG``, ``MODEL_CONFIG``, ``SIM_STEP``, ``SIM_TIMEOUT``, ``ENVDIR``,
    ``OUTDIR``, ``WORKDIR`` and ``DEVICE``.

    Steps:
      0. Write ``ENVDIR/<ENV_NAME>_param.txt``. It holds ``.param STEPS``, the
         ``X99`` instance of ``SUBCKT_NAME`` (with ``ctrlclk`` when the model
         needs it) and an ``.include`` of the generated ``.sp`` file.
      1. Stage the ``.sp`` and parameter files in ``WORKDIR`` and build the netlist.
      2. Run LTspice and keep only the samples at falling edges of ``V(ctrlclk)``.
      3. Read observations ``V(nnin*)`` and LTspice actions ``V(nnout*)``,
         dropping the last sample.
      4. Remove temporary files, also when the simulation fails.
      5. Evaluate *module* on the observations.

    Args:
        module: Model to compare, typically the object returned by ``step1``.
            It is called once on a float32 tensor with one row per sample.

    Returns:
        ``(actions, actions_py)``: LTspice outputs and PyTorch outputs as NumPy
        arrays with one row per sample.

    Raises:
        RuntimeError: if the LTspice run produced no raw file.
    """
    n_in = ENV_CONFIG[ENV_NAME]["in"]
    n_out = ENV_CONFIG[ENV_NAME]["out"]
    param_path = f"{ENVDIR}/{ENV_NAME}_param.txt"

    # 0) Create parameter file
    nn_inputs = ' '.join(f'NNin{i+1}' for i in range(n_in))
    nn_outputs = ' '.join(f'NNout{i+1}' for i in range(n_out))
    clk_port = "ctrlclk" if MODEL_CONFIG[MODEL_NAME]["clk_needed"] else ""
    ports = " ".join(p for p in [nn_inputs, clk_port, nn_outputs] if p)
    with open(param_path, 'w', encoding='utf-8') as f:
        f.write(f".param STEPS={SIM_STEP}\n")
        f.write(f"X99 {ports} {SUBCKT_NAME}\n")
        f.write(f".include {SP_FILENAME}.sp\n")

    # 1) Create PyLTspice SimRunner instance at WORKDIR
    shutil.copy2(f"{OUTDIR}/{SP_FILENAME}.sp", f"{WORKDIR}/")
    shutil.copy2(param_path, f"{WORKDIR}/")
    runner = SimRunner(output_folder=WORKDIR, simulator=LTspice)
    try:
        netlist = runner.create_netlist(f"{ENVDIR}/{ENV_NAME}.asc")

        # 2) Run LTspice simulation and sample at clock falling edges
        raw, log = runner.run_now(netlist, timeout=SIM_TIMEOUT)
        if raw is None:
            raise RuntimeError(f"LTspice run of {ENV_NAME}.asc produced no raw file (log: {log}).")
        df = extract_data(RawRead(raw).to_dataframe())

        # 3) Extract states S[t] and actions A[t]; the final sample is dropped
        states = df[[f'V(nnin{i+1})' for i in range(n_in)]].values[:-1]
        actions = df[[f'V(nnout{i+1})' for i in range(n_out)]].values[:-1]
    finally:
        # 4) Clean PyLTspice files even if the simulation failed part-way
        runner.cleanup_files()
        for leftover in (
            f"{ENVDIR}/{ENV_NAME}.net",
            f"{WORKDIR}/{SP_FILENAME}.sp",
            f"{WORKDIR}/{ENV_NAME}_param.txt",
        ):
            Path(leftover).unlink(missing_ok=True)

    # 5) Calculate PyTorch output using the observations from LTspice.
    # Use the model that was passed in, not whatever global happens to be named `actor`.
    states_t = torch.tensor(states, dtype=torch.float32, device=DEVICE)
    with torch.no_grad():
        actions_py = module(states_t).cpu().numpy()
    if actions_py.ndim == 1:
        actions_py = actions_py.reshape(-1, n_out)

    return actions, actions_py


## Step3) Compare outputs from PyTorch and LTspice

In [None]:
def step3(actions, actions_py):
    """Plot LTspice vs. PyTorch outputs and print the per-output MAE/MSE.

    Args:
        actions: LTspice outputs, array-like of shape ``(samples, outputs)``.
            A 1-D input is treated as a single output.
        actions_py: PyTorch outputs for the same samples, in the same layout.

    Raises:
        ValueError: if the two arrays do not have the same shape. Without this
            check, NumPy broadcasting would silently yield meaningless metrics.
    """
    ltspice_y = np.asarray(actions, dtype=float)
    pytorch_y = np.asarray(actions_py, dtype=float)
    # Treat 1-D input as a single output column so per-output indexing is uniform
    if ltspice_y.ndim == 1:
        ltspice_y = ltspice_y.reshape(-1, 1)
    if pytorch_y.ndim == 1:
        pytorch_y = pytorch_y.reshape(-1, 1)
    if ltspice_y.shape != pytorch_y.shape:
        raise ValueError(
            f"Shape mismatch: LTspice outputs {ltspice_y.shape} vs PyTorch outputs {pytorch_y.shape}."
        )

    # Plot scatter graph
    fig = go.Figure()
    x_axis = np.arange(ltspice_y.shape[0])
    for idx in range(ltspice_y.shape[1]):
        fig.add_trace(go.Scatter(
            x=x_axis,
            y=ltspice_y[:, idx],
            mode='markers',
            name=f'NNOUT{idx + 1}(LTspice)'
        ))
        fig.add_trace(go.Scatter(
            x=x_axis,
            y=pytorch_y[:, idx],
            mode='markers',
            name=f'NNOUT{idx + 1}(PyTorch)'
        ))
    fig.update_layout(
        title=f"ENV={ENV_NAME}<br>MODEL={MODEL_NAME}",
        xaxis_title='Sample index',
        yaxis_title='Output'
    )
    fig.show()

    # Print MAE/MSE per output
    diff = ltspice_y - pytorch_y
    mae_per_output = np.mean(np.abs(diff), axis=0)
    mse_per_output = np.mean(diff ** 2, axis=0)
    for idx, (mae_val, mse_val) in enumerate(zip(mae_per_output, mse_per_output), start=1):
        print(f"  NNOUT{idx}: MAE={mae_val:.6f}, MSE={mse_val:.6f}")

---
## Execution

In [None]:
ENV_NAME = "env_buck_9x2"
MODEL_NAME = "mlp"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)

In [None]:
ENV_NAME = "env_buck_9x2"
MODEL_NAME = "rnn_linear"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)

In [None]:
ENV_NAME = "env_buck_9x2"
MODEL_NAME = "gru_linear"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)

In [None]:
ENV_NAME = "env_buck_9x2"
MODEL_NAME = "linear_lstm_linear"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)

In [None]:
ENV_NAME = "env_buck_9x1"
MODEL_NAME = "mlp"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)

In [None]:
ENV_NAME = "env_buck_9x1"
MODEL_NAME = "rnn_linear"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)

In [None]:
ENV_NAME = "env_buck_9x1"
MODEL_NAME = "gru_linear"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)

In [None]:
ENV_NAME = "env_buck_9x1"
MODEL_NAME = "linear_lstm_linear"
SUBCKT_NAME = MODEL_NAME                      # re-sync: the config-cell value is a one-time snapshot
PY_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output python file name
SP_FILENAME = ENV_NAME + '_' + MODEL_NAME     # output LTspice sub-circuit file name (without .sp)
actor = step1()
actions, actions_py = step2(actor)
step3(actions, actions_py)