In [None]:
#| default_exp pinn.pirnn
#| default_cls_lvl 3

In [None]:
#| export
import torch
import torch.nn as nn
from tsfast.prediction.fransys import Diag_RNN, FranSysLearner
from tsfast.models.rnn import RNN
from tsfast.models.layers import SeqLinear
from tsfast.learner.callbacks import CB_TruncateSequence
from tsfast.learner.losses import SkipNLoss
from fastai.basics import *
from fastcore.basics import store_attr
from functools import partial


## PIRNN Model

Physics-Informed RNN (PIRNN) extends the FranSys architecture with a **StateEncoder** for single-state initialization. This enables training with collocation points that don't have full sequence observations.

**Architecture:**
- **SequenceEncoder** (Diagnosis RNN): Processes observation sequences → hidden state (from FranSys)
- **StateEncoder** (NEW): Maps single physical state → hidden state (for collocation points)
- **Prognosis RNN**: Predicts future from hidden state (from FranSys)
- **Final Layer**: Hidden state → outputs (from FranSys)

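**Encoder modes** (chosen with the `encoder_mode` argument of `forward`; `'default'` uses `default_encoder_mode`):
1. `'sequence'`: real observation sequences → SequenceEncoder
2. `'state'`: single (e.g. randomly sampled) physical states → StateEncoder
3. `'none'`: a hidden state passed in directly → Prognosis RNN, no encoder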


In [None]:
#| export
class PIRNN(nn.Module):
    '''Physics-Informed RNN with dual encoders: Sequence and State

    Reuses the FranSys building blocks (diagnosis `Diag_RNN`, prognosis `RNN`,
    output `SeqLinear`) and adds `state_encoder`, an MLP that maps a single
    physical state of size `n_y` to an initial hidden state for the prognosis
    RNN. `forward` selects explicitly (via `encoder_mode`) how that initial
    hidden state is obtained.
    '''
    
    def __init__(self,
                 n_u:int, # Number of inputs
                 n_y:int, # Number of outputs
                 init_sz:int, # Initialization sequence length
                 n_x:int = 0, # Number of extra states
                 hidden_size:int = 100, # Hidden state size
                 rnn_layer:int = 1, # Number of RNN layers
                 state_encoder_hidden:int = 64, # Hidden size for state encoder MLP
                 linear_layer:int = 1, # Linear layers in diagnosis RNN
                 final_layer:int = 0, # Final layer complexity
                 init_diag_only:bool = False, # Stored only; the diagnosis RNN always sees just the first `init_sz` steps
                 default_encoder_mode:str = 'sequence', # Mode used when `forward` receives encoder_mode='default'
                 **kwargs # Forwarded to both the diagnosis and the prognosis RNN
                ):
        '''Initialize PIRNN with both sequence and state encoders'''
        super().__init__()
        store_attr('n_u,n_y,n_x,init_sz,init_diag_only,hidden_size,rnn_layer,default_encoder_mode')
        
        # FranSys components: the diagnosis RNN reads all n_u+n_x+n_y channels, the prognosis RNN only u
        self.rnn_diagnosis = Diag_RNN(
            n_u+n_x+n_y, hidden_size, 
            hidden_size=hidden_size,
            output_layer=rnn_layer,
            rnn_layer=rnn_layer,
            linear_layer=linear_layer,
            **kwargs
        )
        
        rnn_kwargs = dict(hidden_size=hidden_size, num_layers=rnn_layer, ret_full_hidden=True)
        rnn_kwargs = dict(rnn_kwargs, **kwargs)
        self.rnn_prognosis = RNN(n_u, **rnn_kwargs)
        
        self.final = SeqLinear(hidden_size, n_y, hidden_layer=final_layer)
        
        # State encoder: physical state [n_y] -> flattened hidden state for all `rnn_layer` layers
        self.state_encoder = nn.Sequential(
            nn.Linear(n_y, state_encoder_hidden),
            nn.ReLU(),
            nn.Linear(state_encoder_hidden, hidden_size * rnn_layer)
        )
    
    def forward(
        self,
        x:torch.Tensor, # Input [batch, seq, channels]; the first n_u channels are the inputs u
        init_state:list = None, # Hidden state ('none'/'sequence') or physical state [batch, n_y] ('state'); optional
        encoder_mode:str = 'default' # 'default', 'none', 'sequence', or 'state'
    ) -> torch.Tensor: # Output predictions
        '''Run the model with the encoder selected by `encoder_mode`.

        The mode is never inferred from the input shape. 'default' resolves
        to `self.default_encoder_mode`. 'none' runs the prognosis RNN over the
        whole of `u` starting from `init_state`. 'sequence' and 'state' derive
        the initial hidden state from the first `init_sz` steps (or from
        `init_state`) and run the prognosis RNN on the remaining steps.

        Raises ValueError for an unknown `encoder_mode`.
        '''
        if encoder_mode == 'default':
            encoder_mode = self.default_encoder_mode

        u = x[:,:,:self.n_u]
        x_init = x[:,:self.init_sz,:self.n_u+self.n_x+self.n_y]

        if encoder_mode == 'none':
            return self._forward_predictor(u, init_state)
        elif encoder_mode == 'sequence':
            return self._forward_sequence_encoder(u[:,self.init_sz:], x_init, init_state)
        elif encoder_mode == 'state':
            return self._forward_state_encoder(u[:,self.init_sz:], x_init, init_state)
        else:
            raise ValueError(f"encoder_mode must be 'none', 'sequence', or 'state', got {encoder_mode}")
    
    def _forward_sequence_encoder(
        self,
        u:torch.Tensor, # Prognosis input: u after the first init_sz steps [batch, seq-init_sz, n_u]
        x_init:torch.Tensor, # Initialization window [batch, init_sz, n_u+n_x+n_y]
        init_state:list = None # Initial hidden state (optional); overrides the diagnosis estimate
    ) -> torch.Tensor: # Output predictions
        '''Forward using sequence encoder (diagnosis RNN).

        The diagnosis RNN always runs over `x_init`, and its outputs form the
        first `init_sz` steps of the result. Unless `init_state` is given, its
        last step is converted into the prognosis RNN's initial hidden state.
        The prognosis RNN's returned hidden state is kept in `self.new_hidden`.
        '''
        out_init,_ = self.rnn_diagnosis(x_init)
        if init_state is None:
            init_state = self.rnn_diagnosis.output_to_hidden(out_init, -1)
        out_prog,self.new_hidden = self.rnn_prognosis(u, init_state)
        # Prepend the diagnosis outputs along the time axis (dim 2 of the per-layer outputs)
        out_prog = torch.cat([out_init, out_prog], 2)

        # Only the last RNN layer's outputs feed the final layer
        result = self.final(out_prog[-1])
        return result
    
    def _forward_state_encoder(
        self,
        u:torch.Tensor, # Prognosis input: u after the first init_sz steps [batch, seq-init_sz, n_u]
        x_init:torch.Tensor, # Initialization window [batch, init_sz, n_u+n_x+n_y]
        init_state, # Physical state [batch, n_y], or None to read it from x_init
    ) -> torch.Tensor: # Output predictions
        '''Forward using the state encoder (single physical state -> hidden state).

        Without `init_state`, the physical state is the last `n_y` channels of
        the last initialization step. In training mode, the first `init_sz`
        output steps are the final layer applied to the encoded state and
        repeated. Otherwise those steps are zero-padded. In both cases the
        output length matches the full input sequence.

        Raises ValueError when `init_state` is None and `x` has fewer than
        n_u+n_y channels, so no physical state can be read from it.
        '''
        if init_state is None:
            # Without output channels in x, the slice below would silently read input
            # channels or fail with an opaque shape error inside the state encoder.
            if x_init.shape[-1] < self.n_u + self.n_y:
                raise ValueError(
                    f"encoder_mode='state' without init_state needs at least n_u+n_y={self.n_u+self.n_y} "
                    f"input channels to read the physical state from, got {x_init.shape[-1]}")
            init_state = x_init[:,-1,-self.n_y:]
        hidden = self.encode_single_state(init_state)
        pred = self._forward_predictor(u, hidden)
        if self.training:
            # Decode the last layer's encoded hidden state and hold it over the init window
            init_out = self.final(hidden[-1].squeeze(0).unsqueeze(1))
            return torch.cat([init_out.repeat(1, self.init_sz, 1), pred], 1)
        # Zero-pad the init window so the output keeps the full sequence length
        return nn.functional.pad(pred, (0, 0, self.init_sz, 0))


    def _forward_predictor(
        self,
        u:torch.Tensor, # Input tensor
        init_state:list # Initial hidden state
    ) -> torch.Tensor: # Output predictions
        '''Forward using predictor RNN: run the prognosis RNN and map its last layer to outputs'''
        out_prog,_ = self.rnn_prognosis(u, init_state)
        return self.final(out_prog[-1])
    
    def encode_single_state(
        self,
        physical_state:torch.Tensor # Physical state [batch, n_y]
    ) -> list: # rnn_layer tensors, each [1, batch, hidden_size]
        '''Convert single physical state to RNN-compatible hidden state'''
        batch_size = physical_state.shape[0]
        
        # Encode: [batch, n_y] -> [batch, hidden_size * rnn_layer]
        h_flat = self.state_encoder(physical_state)
        
        # Reshape to RNN format: [rnn_layer, batch, hidden_size]
        h = h_flat.view(batch_size, self.rnn_layer, self.hidden_size)
        h = h.transpose(0, 1).contiguous()  # [rnn_layer, batch, hidden_size]
        
        # One [1, batch, hidden_size] tensor per layer, i.e. the list layout that
        # Diag_RNN.output_to_hidden produces for the prognosis RNN
        return [h[i:i+1] for i in range(self.rnn_layer)]


### Example Usage

Create a PIRNN and run it once with the sequence encoder and once with the state encoder:


In [None]:
model = PIRNN(n_u=1, n_y=2, init_sz=10, hidden_size=50, rnn_layer=2)

# Test 1: sequence encoder (mode passed explicitly): the diagnosis RNN reads u and y over the first init_sz steps
x_full = torch.randn(4, 20, 3)  # [batch, seq, n_u+n_y]
output_seq = model(x_full, encoder_mode='sequence')
print(f"Sequence encoder output: {output_seq.shape}")
assert output_seq.shape == (4, 20, 2)

# Test 2: state encoder: an explicit physical state [batch, n_y] replaces the measured init window
x_input = torch.randn(4, 20, 1)  # [batch, seq, n_u]
physical_state = torch.randn(4, 2)
output_state = model(x_input, init_state=physical_state, encoder_mode='state')
print(f"State encoder output: {output_state.shape}")
assert output_state.shape == (4, 20, 2)


Sequence encoder output: torch.Size([4, 20, 2])
State encoder output: torch.Size([4, 20, 2])


## PIRNN Learner

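`PIRNNLearner` builds a PIRNN sized from one batch of `dls`. It adds a `PredictionCallback` when `attach_output=True` and truncates long sequences during training. It also wraps the loss and metrics so the first `init_sz` steps are ignored.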


In [None]:
#| export
@delegates(PIRNN, keep=True)
def PIRNNLearner(
    dls, # DataLoaders
    init_sz:int, # Initialization sequence length
    attach_output:bool = False, # Whether to attach output to input
    loss_func = nn.L1Loss(), # Loss function
    metrics = None, # Metrics
    opt_func = Adam, # Optimizer
    lr:float = 3e-3, # Learning rate
    cbs = None, # Additional callbacks
    **kwargs # Additional arguments for PIRNN
):
    '''Create PIRNN learner with appropriate configuration

    Input and output widths are read from one batch of `dls`.
    - With `attach_output=True`, the batch input width is used as `n_u`, and a
      `PredictionCallback(0)` is appended unless `cbs` already contains one.
    - Otherwise the batch inputs are assumed to already include the outputs,
      so `n_u` is the input width minus the output width.

    Sequences longer than `init_sz + 300` get `CB_TruncateSequence(init_sz + 100)`
    unless such a callback is already present. Loss and metrics are wrapped in
    `SkipNLoss` so the first `init_sz` steps are ignored. `metrics` defaults to
    `[fun_rmse]`.
    '''
    from tsfast.prediction.core import PredictionCallback
    from tsfast.learner.losses import fun_rmse

    callbacks = [] if cbs is None else list(cbs)
    if metrics is None:
        metric_fns = [fun_rmse]
    elif is_iter(metrics):
        metric_fns = list(metrics)
    else:
        metric_fns = [metrics]

    sample_batch = dls.one_batch()
    n_in = sample_batch[0].shape[-1]
    n_out = sample_batch[1].shape[-1]

    if not attach_output:
        model = PIRNN(n_in - n_out, n_out, init_sz, **kwargs)
    else:
        model = PIRNN(n_in, n_out, init_sz, **kwargs)
        if not any(isinstance(cb, PredictionCallback) for cb in callbacks):
            prediction_cb = PredictionCallback(0)
            prediction_cb.init_normalize(sample_batch)
            callbacks.append(prediction_cb)

    # Truncate long training sequences unless the caller already configured truncation
    max_extra_len, truncated_extra_len = 300, 100
    too_long = sample_batch[0].shape[1] > init_sz + max_extra_len
    if too_long and not any(isinstance(cb, CB_TruncateSequence) for cb in callbacks):
        callbacks.append(CB_TruncateSequence(init_sz + truncated_extra_len))

    # The first `init_sz` steps belong to the initialization window and are not scored
    skip_init = partial(SkipNLoss, n_skip=init_sz)
    wrapped_metrics = [skip_init(fn) for fn in metric_fns]
    wrapped_loss = skip_init(loss_func)

    return Learner(dls, model, loss_func=wrapped_loss, metrics=wrapped_metrics,
                   cbs=callbacks, opt_func=opt_func, lr=lr)


In [None]:
from tsfast.pinn.core import diff1_forward
from tsfast.datasets.core import create_dls
from tsfast.learner.losses import zero_loss, fun_rmse
from tsfast.pinn.core import PhysicsLossCallback, CollocationPointsCB, generate_excitation_signals

In [None]:
path = Path("../../test_data/pinn")
dls = create_dls(
    u=['u'],  # Input signal names
    y=['x','v'],  # Output signal names
    dataset=path,
    win_sz=300,  # Window length in time steps per sample
    stp_sz=1,  # Step between consecutive training windows; 1 << win_sz, so windows overlap heavily
    valid_stp_sz=1,  # Step between consecutive validation windows
    bs=16,
    n_batches_train=50  # presumably caps the training batches per epoch — TODO confirm
).cpu()

In [None]:
# Spring-damper parameters (must match dataset generation)
MASS, SPRING_CONSTANT, DAMPING_COEFFICIENT = 1.0, 1.0, 0.1
DT = 0.01  # Time step used for finite differences and excitation signal generation

init_sz = 10  # Initialization window length in time steps
def spring_damper_physics(u, y_pred, y_ref):
    '''Physics loss for spring-damper: ma + cv + kx = u

    Channel 0 of `y_pred`/`y_ref` is x, channel 1 is v, and channel 0 of `u`
    is the applied force. Derivatives are forward differences with step `DT`.

    Returns a dict of mean squared terms:
    - 'physics': residual of MASS*dv/dt + DAMPING_COEFFICIENT*v + SPRING_CONSTANT*x - u
    - 'derivative': mismatch between predicted v and dx/dt
    - 'initial' (only when `y_ref` is not None): error of x and v against
      `y_ref` over the first `init_sz` steps
    '''
    pos, vel = y_pred[:, :, 0], y_pred[:, :, 1]
    force = u[:, :, 0]

    accel = diff1_forward(vel, DT)
    pos_rate = diff1_forward(pos, DT)

    residual = MASS * accel + DAMPING_COEFFICIENT * vel + SPRING_CONSTANT * pos - force
    losses = {
        'physics': (residual ** 2).mean(),
        'derivative': ((vel - pos_rate) ** 2).mean(),
    }

    if y_ref is not None:
        err_pos = pos[:, :init_sz] - y_ref[:, :init_sz, 0]
        err_vel = vel[:, :init_sz] - y_ref[:, :init_sz, 1]
        losses['initial'] = (err_pos ** 2).mean() + (err_vel ** 2).mean()

    return losses

In [None]:
set_seed(42)  # Seed Python/NumPy/torch RNGs so model init, shuffling and collocation sampling are repeatable

learn = PIRNNLearner(
    dls,
    init_sz=init_sz,
    attach_output=True,
    rnn_type='gru',
    rnn_layer=1,
    hidden_size=20,
    state_encoder_hidden=32,
    loss_func=zero_loss,  # No supervised data loss; the physics callbacks below provide the training signal
    metrics=[fun_rmse]
)

norm_input = dls.train.after_batch[0]  # Input normalization transform shared by both callbacks

# 1. Physics loss on real data (uses SequenceEncoder)
learn.add_cb(PhysicsLossCallback(
    norm_input=norm_input,
    physics_loss_func=spring_damper_physics,
    weight=1.0,
    loss_weights={'physics': 1.0, 'derivative': 1.0, 'initial': 10.0},
    n_inputs=1
))

# 2. Collocation with StateEncoder (random physical states → hidden states)
learn.add_cb(CollocationPointsCB(
    norm_input=norm_input,
    generate_pinn_input=lambda bs, sl, dev: generate_excitation_signals(
        bs, sl, n_inputs=1, dt=DT, device=dev,
        amplitude_range=(0.5, 2.0),
        frequency_range=(0.1, 3.0)
    ),
    physics_loss_func=spring_damper_physics,
    weight=0.5,
    loss_weights={'physics': 1.0, 'derivative': 1.0, 'initial': 5.0},
    init_mode='state_encoder',
    output_ranges=[(-1.0, 1.0), (-2.0, 2.0)]
))

learn.fit_flat_cos(1, 3e-3)

epoch,train_loss,valid_loss,fun_rmse,time
0,3.841751,0.0,0.213517,00:02


In [None]:
#| include: false
# Regenerate the library modules from the `#| export` cells of this notebook
import nbdev
nbdev.nbdev_export()
