An example notebook showing how to train a [Rhino](https://arxiv.org/abs/2210.14706) model on Ecoli Data.

This demonstrates how to assemble the various components of the library and how to perform training.

In [1]:
import os
from dataclasses import dataclass

import networkx as nx
import numpy as np
import pytorch_lightning as pl
import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader

from causica.datasets.causica_dataset_format import CAUSICA_DATASETS_PATH, DataEnum, load_data
from causica.datasets.tensordict_utils import (
    tensordict_shapes,
)
from causica.datasets.timeseries_dataset import IndexedTimeseriesDataset
from causica.distributions import (
    AdjacencyDistribution,
    ContinuousNoiseDist,
    DistributionModule,
    ENCOAdjacencyDistributionModule,
    GibbsDAGPrior,
    JointNoiseModule,
    create_noise_modules,
    RhinoLaggedAdjacencyDistributionModule,
    TemporalAdjacencyDistributionModule,
)
from causica.datasets.variable_types import VariableTypeEnum
from causica.functional_relationships import TemporalEmbedFunctionalRelationships
from causica.graph.dag_constraint import calculate_dagness
from causica.sem.sem_distribution import TemporalSEMDistributionModule
from causica.sem.temporal_distribution_parameters_sem import (
    concatenate_lagged_and_instaneous_values,
    split_lagged_and_instanteneous_values,
)
from causica.training.auglag import AugLagLossCalculator, AugLagLR, AugLagLRConfig

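Load the training data and wrap it in a `DataLoader`. The batch size can be overridden through the `TEST_RUN` environment variable, which the tests use to run the notebook as a script.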

In [2]:
# Seed the global RNGs so that shuffling, parameter initialisation and graph sampling are
# repeatable under "Restart Kernel -> Run All".
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Remote locations of the training time series and of the adjacency matrix handed to the dataset.
data_path = "https://azuastoragepublic.z6.web.core.windows.net/Ecoli1_100/train.csv"
adj_path = "https://azuastoragepublic.z6.web.core.windows.net/Ecoli1_100/adj_matrix.npy"

device = "cpu"
dataset_train = IndexedTimeseriesDataset(series_index_key=0, data=data_path, adjacency_matrix=adj_path, device=device)

# The TEST_RUN environment variable, when set, overrides the default batch size of 8.
batch_size = int(os.environ.get("TEST_RUN", 8))
dataloader_train = DataLoader(
    dataset=dataset_train,
    collate_fn=torch.stack,  # stack the individual dataset items into one batch
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [3]:
# One node per key of the dataset's underlying data.
num_nodes = len(dataset_train._data.keys())
# Length of the temporal context shared by the prior and the functional relationships.
context_length = 21
# Derived from context_length instead of repeating the literal, so the two cannot drift apart.
# NOTE(review): presumably the remaining step is the instantaneous one - confirm.
lags = context_length - 1

prior = GibbsDAGPrior(num_nodes=num_nodes, sparsity_lambda=0.3, context_length=context_length)

Create the Variational Posterior Distribution over Adjacency Matrices, which we will optimize.

In [4]:
# Build the two component distributions first, then combine them into the temporal
# adjacency distribution that is optimised during training.
instantaneous_adjacency = ENCOAdjacencyDistributionModule(num_nodes)
lagged_adjacency = RhinoLaggedAdjacencyDistributionModule(num_nodes, lags)

adjacency_dist: DistributionModule[AdjacencyDistribution] = TemporalAdjacencyDistributionModule(
    instantaneous_adjacency,
    lagged_adjacency,
)

Create the graph neural network that will estimate the functional relationships. More info can be found [here](https://openreview.net/forum?id=S2pNPZM-w-f).

In [5]:
# Per-variable shapes derived from the dataset's underlying data.
node_shapes = tensordict_shapes(dataset_train._data)

# Network used to model the functional relationships; context_length is the same value
# that is passed to the prior.
functional_relationships = TemporalEmbedFunctionalRelationships(
    shapes=node_shapes,
    embedding_size=32,
    out_dim_g=32,
    num_layers_g=2,
    num_layers_zeta=2,
    context_length=context_length,
)

Create the Noise Distributions for each node assuming they are all continuous.

In [6]:
variable_shapes = tensordict_shapes(dataset_train._data)
# Every variable in the dataset is treated as continuous.
types_dict = dict.fromkeys(dataset_train._data.keys(), VariableTypeEnum.CONTINUOUS)

# One Gaussian noise module per variable, joined into a single module.
noise_submodules = create_noise_modules(
    variable_shapes,
    types_dict,
    ContinuousNoiseDist.GAUSSIAN,
)
noise_module = JointNoiseModule(noise_submodules)

Create the SEM Module which combines the variational adjacency distribution, the functional relationships and the noise distributions for each node.

In [7]:
# Combine the adjacency distribution, functional relationships and noise module into the SEM
# module, and move it to the target device in the same statement.
sem_module: TemporalSEMDistributionModule = TemporalSEMDistributionModule(
    adjacency_dist,
    functional_relationships,
    noise_module,
).to(device)

Define the [Augmented Lagrangian Scheduler](https://en.wikipedia.org/wiki/Augmented_Lagrangian_method).

This allows Rhino to optimize towards a DAG, by slowly increasing the alpha and rho parameters as the optimization takes
place.

In [8]:
# Augmented-Lagrangian loss with its initial alpha and rho.
auglag_loss = AugLagLossCalculator(init_alpha=0.0, init_rho=1.0)

# Initial learning rate for each named parameter group; the same names are used as keys
# when the optimizer's parameter groups are built.
lr_init_dict = dict(
    functional_relationships=3e-3,
    vardist_inst=1e-2,
    vardist_lagged=1e-2,
    noise_dist=3e-4,
)

auglag_config = AugLagLRConfig(lr_init_dict=lr_init_dict)
scheduler = AugLagLR(config=auglag_config)

Create the optimizer, with separate learning rates for each module.

In [9]:
def _param_group(name, module):
    """Build one optimizer parameter group with the initial learning rate configured for ``name``."""
    return {"params": module.parameters(), "lr": auglag_config.lr_init_dict[name], "name": name}


# Sub-modules of the SEM, keyed by the names used in the learning-rate dictionary.
modules = {
    "functional_relationships": sem_module.functional_relationships,
    "vardist_inst": sem_module.adjacency_module.inst_dist_module,
    "vardist_lagged": sem_module.adjacency_module.lagged_dist_module,
    "noise_dist": sem_module.noise_module,
}

parameter_list = [_param_group(group_name, group_module) for group_name, group_module in modules.items()]

optimizer = torch.optim.Adam(parameter_list)

The main training loop.

For each batch, we:
* Sample a graph from the SEM.
* Calculate the log probability of that batch, given the graph.
* Create the ELBO to be optimized.
* Calculate the DAG constraint on the instantaneous graph.
* Combine the DAG constraint with the ELBO to get the loss.

In [10]:
@dataclass(frozen=True)
class TrainingConfig:
    """Hyper-parameters of the training loop in this cell.

    NOTE(review): within this notebook only ``max_epoch``, ``gumbel_temp``, ``max_graph_resamples`` and
    ``log_every_n_epochs`` are read. The other fields are not consumed by any visible cell (the noise
    distribution, batch size, prior sparsity and aug-lag initial values are set where the respective
    objects are created) and are kept unchanged for backward compatibility.
    """

    noise_dist: ContinuousNoiseDist = ContinuousNoiseDist.SPLINE
    batch_size: int = int(os.environ.get("TEST_RUN", 128))
    max_epoch: int = int(os.environ.get("TEST_RUN", 20))  # used by testing to run the notebook as a script
    gumbel_temp: float = 0.25
    averaging_period: int = 10
    prior_sparsity_lambda: float = 5.0
    init_rho: float = 1.0
    init_alpha: float = 0.0
    max_graph_resamples: int = 2  # extra sampling attempts allowed when a sampled graph contains NaNs
    log_every_n_epochs: int = 10  # metrics are printed on the first batch of every n-th epoch


def sample_valid_sem(sem_distribution, temperature, max_resamples):
    """Draw a relaxed (soft) SEM sample, redrawing while its graph contains NaNs.

    Args:
        sem_distribution: Object exposing ``relaxed_sample(sample_shape, temperature=...)``.
        temperature: Temperature forwarded to ``relaxed_sample``.
        max_resamples: Maximum number of additional draws after the first one.

    Returns:
        A tuple ``(sem, resamples)`` with the sampled SEM and the number of redraws that were needed.

    Raises:
        ValueError: If the graph still contains NaNs after ``max_resamples`` redraws.
    """
    sem, *_ = sem_distribution.relaxed_sample(torch.Size([]), temperature=temperature)  # soft sample
    resamples = 0
    while sem.graph.isnan().any() and resamples < max_resamples:
        sem, *_ = sem_distribution.relaxed_sample(torch.Size([]), temperature=temperature)
        resamples += 1
    if sem.graph.isnan().any():
        raise ValueError(f"Failed to sample a valid graph after {resamples} tries")
    return sem, resamples


training_config = TrainingConfig()

num_samples = len(dataset_train)
# `batch` and `sem` keep their values from the final iteration and are reused by later cells.
for epoch in range(training_config.max_epoch):
    for i, batch in enumerate(dataloader_train):
        batch = batch.to_tensordict()  # Force dense stacking for tensordict<0.4.0
        batch.batch_size = batch.batch_size[:1]  # Do not consider the time axis part of the batch dims

        optimizer.zero_grad()
        sem_distribution = sem_module()
        sem, graph_tries = sample_valid_sem(
            sem_distribution,
            temperature=training_config.gumbel_temp,
            max_resamples=training_config.max_graph_resamples,
        )
        if graph_tries > 0:
            print(f"Used {graph_tries} tries to sample a valid graph")

        batch_log_prob = sem.log_prob(batch).mean()
        sem_distribution_entropy = sem_distribution.entropy()
        prior_term = prior.log_prob(sem.graph)
        # Entropy and prior terms are divided by the dataset size; the batch log-probability is a mean.
        objective = (-sem_distribution_entropy - prior_term) / num_samples - batch_log_prob
        constraint = calculate_dagness(sem.graph[..., -1, :, :])  # Calculate dagness on instantaneous graph

        # NOTE(review): the loss uses constraint / num_samples while the scheduler below receives the
        # unscaled constraint - confirm this is intended.
        loss = auglag_loss(objective, constraint / num_samples)

        loss.backward()
        optimizer.step()
        # update the Auglag parameters
        scheduler.step(
            optimizer=optimizer,
            loss=auglag_loss,
            loss_value=loss,
            lagrangian_penalty=constraint,
        )
        # log metrics
        if epoch % training_config.log_every_n_epochs == 0 and i == 0:
            print(
                f"epoch:{epoch} loss:{loss.item():.5g} nll:{-batch_log_prob.item():.5g} "
                f"dagness:{constraint.item():.5f} num_edges:{(sem.graph > 0.0).sum()} "
                f"alpha:{auglag_loss.alpha:.5g} rho:{auglag_loss.rho:.5g} "
                f"step:{scheduler.outer_opt_counter}|{scheduler.step_counter} "
                f"num_lr_updates:{scheduler.num_lr_updates}"
            )

epoch:0 loss:2.3792e+17 nll:123.08 dagness:31731050496.00000 num_edges:102465 alpha:0 rho:1 step:0|1 num_lr_updates:0

epoch:10 loss:2.1545e+17 nll:90.606 dagness:30195685376.00000 num_edges:96460 alpha:0 rho:1 step:0|51 num_lr_updates:0


In [11]:
def count_nan_samples(dist_module, temperature, num_draws=1000):
    """Count how many relaxed samples drawn from ``dist_module()`` contain at least one NaN.

    Args:
        dist_module: Callable returning a distribution that exposes ``relaxed_sample(temperature=...)``.
        temperature: Temperature forwarded to ``relaxed_sample``.
        num_draws: Number of independent samples to draw.

    Returns:
        The number of draws whose sample contained a NaN.
    """
    # Gradients are not needed for this diagnostic.
    with torch.no_grad():
        return sum(
            bool(dist_module().relaxed_sample(temperature=temperature).isnan().any()) for _ in range(num_draws)
        )


# Training used to throw errors because of NaNs in the sampled graphs, so measure how often each
# adjacency distribution produces a NaN-containing relaxed sample.
# NOTE(review): this was described as an ephemeral issue that is not present in the final model, yet
# the recorded run of this cell reported non-zero counts - confirm.
inst_nans = count_nan_samples(sem_module.adjacency_module.inst_dist_module, training_config.gumbel_temp)
lagged_nans = count_nan_samples(sem_module.adjacency_module.lagged_dist_module, training_config.gumbel_temp)

print(f"Instantaneous nans: {inst_nans} Lagged nans: {lagged_nans}")

Instantaneous nans: 1 Lagged nans: 16



In [12]:
# Demonstrates the sampling utilities of the SEM. `batch` and `sem` are the values left over
# from the last iteration of the training loop.
print(f"batch shape: {batch[batch.sorted_keys[0]].shape}, {batch.batch_size}")
# Sample to noise: map the observed batch to noise values
noise = sem.sample_to_noise(batch)
print(f"Sample to noise shape: {noise[noise.sorted_keys[0]].shape}")

# Sampled noise: draw fresh noise from the SEM
sampled_noise = sem.sample_noise()
print(f"Sampled noise shape: {sampled_noise[sampled_noise.sorted_keys[0]].shape}")

# Noise to sample still fails when given the noise alone, as it requires the history:
# sem.noise_to_sample(noise)

# Split the batch into its history and instantaneous parts, then attach the history to the noise
# so that noise_to_sample receives the history it needs. `inst` is not used afterwards.
history, inst = split_lagged_and_instanteneous_values(batch)
concat_noise = concatenate_lagged_and_instaneous_values(history, noise)

new_sample = sem.noise_to_sample(concat_noise)
print(f"Generated sample shape: {new_sample[new_sample.sorted_keys[0]].shape}")
# Because sampling requires the history, sem.sample() will fail as the history is not provided

# Build an all-zero intervention on the first node. It is only consumed by the commented-out code below.
intervention_variable = sem.node_names[0]
intervention = TensorDict({intervention_variable: torch.zeros_like(batch[intervention_variable][0])}, batch_size=[])
# Interventions currently fail because we are intervening on the instantaneous values but want to use the history for
# forward predictions.
# do_sem = sem.do(intervention)
# do_sample = do_sem.noise_to_sample(concat_noise.clone())

# print(f"Generated sample after intervention shape: {do_sample[do_sample.sorted_keys[0]].shape}")

batch shape: torch.Size([8, 21, 1]), torch.Size([8])

Sample to noise shape: torch.Size([8, 1])
Sampled noise shape: torch.Size([1])
Generated sample shape: torch.Size([8, 1])
