# Tutorial 4: Data-driven discovery of Lorenz system: Transformers

#### Author: Taku Ito

7/7/2025

In [None]:
# Import required packages
import numpy as np
import matplotlib.pyplot as plt
import tutorial_ddd
import tutorial_ddd.lorenz
%matplotlib inline
%load_ext autoreload
%autoreload 2


#### Experiment: We want to infer the governing differential equations (e.g., $\dot{x}$, $\dot{y}$, $\dot{z}$) from $x$, $y$, and $z$ using a transformer
* So, we want to infer the derivatives as a learned (generally nonlinear) function of $x$, $y$, and $z$, rather than as a linear combination of hand-picked features of $x$, $y$, and $z$.
* Like MLPs, transformers are universal function approximators, so they can learn any function $y = f(X)$, but the components/bases of $f$ will be difficult to interpret as there is a vast parameter space

Practically, suppose we want to predict $Y = [ \dot{x}, \dot{y}, \dot{z} ]$. To train our model, we will try to learn a transformer $f$ that maps $X = [x, y, z]$ to $Y$, i.e., $Y = f(X)$. We can then assess how well the learned $f$ can be used to simulate/reproduce the Lorenz system under new initial conditions.

#### 4.1: Simulate Lorenz time series with specified parameters

In [None]:
#### Define initial conditions and parameters and simulate
import tutorial_ddd.models


# Initial state and Lorenz parameters (the standard sigma=10, rho=28, beta=8/3 setting).
initial_conditions = [0.1, 0.0, 0.0]  # Starting point [x, y, z]
sigma_val = 10.0
rho_val = 28.0
beta_val = 8/3
delta_t = 0.01  # integration time step passed to the simulator
total_steps = 20000 # More steps to see the chaotic attractor

print(f"Simulating Lorenz system with initial conditions: {initial_conditions}")
print(f"Parameters: sigma={sigma_val}, rho={rho_val}, beta={beta_val}")
print(f"Time step (dt): {delta_t}, Number of steps: {total_steps}")

# Simulate the system. The helper returns the state trajectory and the matching
# derivatives; presumably both are (total_steps, 3) arrays -- TODO confirm
# against tutorial_ddd.lorenz.simulate_lorenz.
lorenz_trajectory, derivatives = tutorial_ddd.lorenz.simulate_lorenz(
    initial_conditions,
    sigma=sigma_val,
    rho=rho_val,
    beta=beta_val,
    dt=delta_t,
    num_steps=total_steps
)

# Optional additive Gaussian observation noise on the states only (the
# derivatives are left clean). A dedicated, seeded generator keeps the notebook
# reproducible under Restart & Run All once noise_amplitude is raised above 0;
# with noise_amplitude = 0 the noise is exactly zero and the data are unchanged.
noise_amplitude = 0
noise_seed = 0
noise_rng = np.random.default_rng(noise_seed)
noise = noise_rng.normal(0, noise_amplitude, lorenz_trajectory.shape)
lorenz_trajectory = lorenz_trajectory + noise

#### 4.2: Initialize and train Transformer (one transformer layer/block) on Lorenz data

In [None]:
##### Import PyTorch (plus pandas for the loss log) and fix the torch seed
import pandas as pd
import torch
torch.manual_seed(701)

## Instantiate a Transformer with a single layer/block and a single attention head
# Targets are the derivatives; inputs are a copy of the (possibly noisy) states.
y = derivatives
X = lorenz_trajectory.copy()
model = tutorial_ddd.models.Transformer(
    input_dim=1,      # every token is a single scalar (one state variable)
    output_dim=1,     # one scalar prediction per token
    nhead=1,
    nlayers=1,
    embedding_dim=4,
)

# Loss (MSE), optimizer, and mini-batch loader
MSE = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
batch_size = 512
dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(X).float(),
    torch.from_numpy(y).float(),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

num_epochs = 500
n_samples = len(dataset)
loss_history = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_X, batch_Y in dataloader:
        # Add a trailing feature axis so each sample becomes a short sequence
        # of scalar tokens (one token per state variable).
        tokens = batch_X.unsqueeze(-1)
        targets = batch_Y.unsqueeze(-1)

        # Clear stale gradients, then forward pass and loss
        optimizer.zero_grad()
        loss = MSE(model(tokens), targets)

        # Backpropagate and take one optimizer step
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is
        # exact even when the final batch is smaller.
        running_loss += loss.item() * tokens.size(0)

    # Per-sample average loss for this epoch
    epoch_loss = running_loss / n_samples
    if epoch % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {epoch_loss}")
    loss_history.append(epoch_loss)

# Training log: one row per epoch (Epoch is 0-based)
dataframe = pd.DataFrame({'Loss': loss_history, 'Epoch': list(range(num_epochs))})

#### 4.3: Evaluate model fit: Compute Lorenz system using the learned model

In [None]:
def simulateEstimatedLorenzWithTransformer(model, x0, y0, z0, num_steps=20000, dt=1.0):
    """
    Roll a learned model forward in time from the initial state (x0, y0, z0).

    At every step the current state [x, y, z] is reshaped to (1, 3, 1) -- a batch
    of one sequence of three scalar tokens -- and passed to ``model``. The model
    output at [0, :, 0] is treated as the state update, and the next state is
    ``state + dt * update``.

    Parameters
    ----------
    model : callable (typically a torch.nn.Module)
        Maps a (1, 3, 1) float tensor to a tensor indexable as [0, :, 0] with
        three entries.
    x0, y0, z0 : float
        Initial state.
    num_steps : int, default 20000
        Number of rows in the returned trajectory (including the initial state).
    dt : float, default 1.0
        Scale applied to the model output before it is added to the state. The
        default reproduces the plain ``state + model(state)`` update.
        NOTE(review): that is only a valid integration step if the training
        targets are already per-step increments; if they are true time
        derivatives, pass the simulation time step here -- TODO confirm against
        the data-generating code.

    Returns
    -------
    numpy.ndarray of shape (num_steps, 3)
        The simulated trajectory.
    """
    trajectory = torch.zeros(num_steps, 3)
    trajectory[0] = torch.tensor([x0, y0, z0])

    # Put nn.Modules into eval mode for the rollout (so layers such as dropout
    # are deterministic) and restore the caller's mode afterwards. Plain
    # callables are used as-is.
    is_module = isinstance(model, torch.nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()
    try:
        # No gradients are needed for simulation; without no_grad the autograd
        # graph would grow across every step of the rollout.
        with torch.no_grad():
            for t in range(num_steps - 1):
                state = trajectory[t]
                dxdt = model(state.unsqueeze(-1).unsqueeze(0))[0, :, 0]
                trajectory[t + 1] = state + dt * dxdt
    finally:
        if is_module:
            model.train(was_training)
    return trajectory.detach().numpy()

# Roll the trained model out from the same initial condition and for the same
# number of steps as the reference simulation, so both trajectories have the
# same length when they are plotted against a shared time axis.
model_trajectory = simulateEstimatedLorenzWithTransformer(
    model,
    initial_conditions[0],
    initial_conditions[1],
    initial_conditions[2],
    num_steps=total_steps,
)

#### 4.4: Compare the original Lorenz system with the learned model

In [None]:
import seaborn as sns

# One colour per source, shared by every time-series panel below.
model_color = sns.color_palette('Blues')[1]
lorenz_color = sns.color_palette('Reds')[1]

# 3D phase-space view: reference trajectory vs. model rollout
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.plot(lorenz_trajectory[:, 0], lorenz_trajectory[:, 1], lorenz_trajectory[:, 2], lw=0.5, alpha=0.8, color='black', label='lorenz')
ax.plot(model_trajectory[:, 0], model_trajectory[:, 1], model_trajectory[:, 2], lw=0.5, alpha=0.8, label='model reconstruction -- Transformer')
ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_zlabel("Z-axis")
ax.set_title("Model Reconstruction: Lorenz Attractor")
ax.legend()

# Time axis built from integer sample indices. np.arange with a float step can
# return one element more or fewer than expected, which would make the plot
# calls below fail on a length mismatch.
time_points = np.arange(total_steps) * delta_t

# One panel per state variable: model rollout vs. reference against time
for dim, var_name in enumerate(('X', 'Y', 'Z')):
    ts_fig, ts_ax = plt.subplots(figsize=(12, 3))
    ts_ax.plot(time_points, model_trajectory[:, dim], label=f'{var_name}-Transformer', color=model_color)
    ts_ax.plot(time_points, lorenz_trajectory[:, dim], label=f'{var_name}(t)-original', color=lorenz_color)
    ts_ax.set_xlabel("Time")
    ts_ax.set_ylabel("Value")
    ts_ax.set_title(f"{var_name} Variables Over Time")
    ts_ax.legend()
    ts_ax.grid(True)