In [20]:
import glob
import re
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import zarr
from alive_progress import alive_it
from omegaconf import OmegaConf
from rich import print
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score
from sortedcontainers import SortedList

from walrus_workshop.utils import get_key_value_from_string
from walrus_workshop.walrus import get_trajectory
from walrus_workshop.model import load_sae
from walrus_workshop.metrics import coarsen_field, compute_enstrophy, compute_deformation

# Run the SAE on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [2]:
# Load the config
cfg = OmegaConf.load("configs/train.yaml")

# Load the trajectory
trajectory_id = 50  # 56 also interesting
trajectory, trajectory_metadata = get_trajectory(cfg.walrus.dataset, trajectory_id)

# Load file list of the activations
activations_dir = (
    Path("activations")
    / "test"
    / "blocks.20.space_mixing.activation"
    / cfg.walrus.dataset
)
# The glob alone also matches other trajectories sharing the prefix (e.g. `_traj_500`
# when trajectory_id is 50), so keep only names where the id is not followed by a digit.
traj_pattern = re.compile(rf"_traj_{trajectory_id}(?!\d)")
act_files = [
    file_name
    for file_name in glob.glob(str(activations_dir / f"*_traj_{trajectory_id}*"))
    if traj_pattern.search(Path(file_name).name)
]
# Order chronologically by the numeric step (plain string sorting puts step 10 before
# step 2 unless step numbers are zero-padded); the file name breaks ties deterministically.
act_files.sort(key=lambda file_name: (int(get_key_value_from_string(file_name, "step")), file_name))
assert act_files, f"No activation files for trajectory {trajectory_id} in {activations_dir}"
# List of steps with activations (starting step), aligned index-for-index with act_files
steps = np.array([int(get_key_value_from_string(file_name, "step")) for file_name in act_files])

# Load the trained SAE
checkpoint_path = (
    Path("checkpoints")
    / "sae_checkpoint_blocks.20.space_mixing.activation_source_test_k_active=32_k_aux=2048_latent=22528_beta=0.1.pt"
)
sae_model, sae_config = load_sae(checkpoint_path)
sae_model = sae_model.to(device).eval()

In [14]:
@dataclass
class DataChunk:
    """One activation window reduced to its final timestep, as built by ``get_data_chunk``.

    Field order defines the positional constructor signature.
    """

    step: int  # starting simulation step of the window
    n_neurons: int  # number of raw activation channels (act.shape[1] in get_data_chunk)
    n_features: int  # number of SAE latent features (code.shape[1] in get_data_chunk)
    n_timesteps: int  # number of timesteps kept in the arrays below (get_data_chunk sets 1)
    simulation: np.ndarray  # last simulation frame of the window; assumed (H, W, channels) — TODO confirm
    neurons: np.ndarray  # raw activations on the token grid, indexed [row, col, neuron]
    code: np.ndarray  # SAE codes on the token grid, indexed [row, col, feature]
    target: np.ndarray  # coarsened target field on the token grid, indexed [row, col]

def get_data_chunk(step, step_index, act_files, trajectory, cfg, sae_model, device, verbose=False, target='tke', grid_size=32):
    """Load activations, SAE codes and a coarsened target field for one window.

    Only the last timestep of the window is kept in the returned chunk.

    Args:
        step: Starting simulation step of the window. Must match the ``step``
            key encoded in ``act_files[step_index]``.
        step_index: Index into ``act_files`` of the activation file for ``step``.
        act_files: Paths to zarr activation arrays. Each array is assumed to
            have rows laid out as flattened (timestep, row, col) tokens —
            TODO confirm against the activation writer.
        trajectory: Mapping with an ``'input_fields'`` array, indexed here as
            ``[0, t, row, col, 0, channel]``.
        cfg: Config object; ``cfg.walrus.n_steps_input`` sets the window length.
        sae_model: SAE called as ``sae_model(act)``. Its second output is the
            sparse code.
        device: Torch device the activations are moved to before encoding.
        verbose: If True, print the activation file name and the simulation
            chunk shape.
        target: Velocity component used as the target, ``'u'`` or ``'v'``.
            The default ``'tke'`` is not currently supported and raises
            ValueError, so pass ``target`` explicitly.
        grid_size: Side length of the token grid that activations are laid
            out on and that the target is coarsened to (default 32).

    Returns:
        DataChunk for the final timestep of the window.

    Raises:
        ValueError: If ``target`` is unsupported, or if the activation file's
            step does not match ``step``.
    """
    # Validate first so an unsupported target fails before any file I/O or SAE work.
    target_channels = {'u': 2, 'v': 3}  # channel index of each velocity component in input_fields
    if target not in target_channels:
        raise ValueError(
            f"Unsupported target {target!r}; expected one of {sorted(target_channels)}. "
            "Subgrid-stress/TKE targets are not implemented here."
        )

    act_path = Path(act_files[step_index])
    if verbose:
        print(f"Opening activation file {act_path.stem}")
    # Make sure the file at step_index really holds this step; cast the same way `steps` was built.
    file_step = int(get_key_value_from_string(act_path.stem, "step"))
    if file_step != int(step):
        raise ValueError(f"Activation file {act_path.name} is for step {file_step}, expected step {step}")

    # Get SAE features
    act = torch.from_numpy(np.array(zarr.open(act_files[step_index], mode="r"))).to(device)
    with torch.no_grad():
        _, code, _ = sae_model(act)
    code = code.cpu().numpy()
    neurons = act.cpu().numpy()

    # Get simulation chunk
    simulation_chunk = trajectory['input_fields'][0, step:step + cfg.walrus.n_steps_input, :, :, 0, :]
    if verbose:
        print(f"Simulation chunk shape: {simulation_chunk.shape}")

    # Only the final timestep is returned, so coarsen just that frame (float64, as before).
    target_field = np.zeros((grid_size, grid_size))
    target_field[...] = coarsen_field(
        simulation_chunk[-1, ..., target_channels[target]], (grid_size, grid_size), method='mean'
    )

    def _last_frame(tokens):
        # Unflatten (timestep * row * col, channels) rows onto the grid and keep the final timestep.
        # Raises if the row count is not a whole number of grid frames, instead of silently mis-shaping.
        return tokens.reshape(-1, grid_size, grid_size, tokens.shape[-1])[-1]

    return DataChunk(
        step=step,
        n_neurons=act.shape[1],
        n_features=code.shape[1],
        n_timesteps=1,
        simulation=simulation_chunk[-1],
        neurons=_last_frame(neurons),
        code=_last_frame(code),
        target=target_field,
    )

In [15]:
# Build the regression dataset. Each activation window gives one sample: the raw
# activations at a fixed grid point are the features, and the coarsened target at
# that point is the label.
ix, iy = (16, 16)  # grid point: ix is the column, iy the row on the token grid
assert len(steps) > 0, "No activation steps to process; check act_files"

neuron_rows = []
target_values = []
for step_index in alive_it(range(len(steps)), force_tty=True):
    step = steps[step_index]
    data_chunk = get_data_chunk(step, step_index, act_files, trajectory, cfg, sae_model, device, verbose=False, target='u')
    neuron_rows.append(data_chunk.neurons[iy, ix, :])
    target_values.append(data_chunk.target[iy, ix])

X = np.stack(neuron_rows).astype(np.float64)  # (n_steps, n_neurons)
y = np.asarray(target_values, dtype=np.float64)  # (n_steps,)

|████████████████████████████████████████| 34/34 [100%] in 17.2s (1.95/s)       


In [21]:
print(X.shape, y.shape)

# Ridge probe from neuron activations to the target, scored with 5-fold CV.
# cross_val_score clones the estimator for each fold, so it is irrelevant whether `r`
# has been fitted yet. The fit on all samples is kept as `model` for later inspection.
r = Ridge(alpha=100)
scores = cross_val_score(r, X, y, cv=5)
model = r.fit(X, y)

mean_r2 = scores.mean()
std_r2 = scores.std()
print(f"Ridge R² (5-fold): {mean_r2:.4f} ± {std_r2:.4f}")
