# 1. Jet Data

In [1]:
from utils import Configs
from jetdata import JetDataclass

# Load the experiment configuration from the YAML file.
config = Configs("configs.yaml")

# ...take small data samples:
# Build the jet dataset from the configuration and run its preprocessing.
# NOTE(review): presumably the sample size is controlled by configs.yaml — confirm.

jets = JetDataclass(config=config)
jets.preprocess()


INFO: created experiment instance GaussNoise_to_AspenOpenJets_HybridEPiC_2024.12.17_14h16_1569


  discrete = torch.tensor(discrete).long()


### Inspect data

In [None]:
# Display one example cloud from the source and one from the target;
# each call uses its own example index and marker scale.
jets.source.display_cloud(idx=0, scale_marker=10.0)
jets.target.display_cloud(idx=9, scale_marker=5.0)


In [None]:
# Shapes of the continuous features, the discrete features and the masks,
# each listed as a (source, target) pair. The tuple is the cell's displayed value.
(
    jets.source.continuous.shape,
    jets.target.continuous.shape,
    jets.source.discrete.shape,
    jets.target.discrete.shape,
    jets.source.mask.shape,
    jets.target.mask.shape,
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Overlay the multiplicity histograms of the source and target jets on one axis.
fig, ax = plt.subplots(1, 1, figsize=(3, 3))

# Styling common to both histograms; only the data, the label and (for the
# source) the colour differ between the two calls.
shared_style = dict(
    element="step",
    fill=False,
    discrete=True,
    lw=0.75,
    stat="density",
    log_scale=(False, False),
    ax=ax,
)
sns.histplot(jets.source.multiplicity, color="r", label="source", **shared_style)
sns.histplot(jets.target.multiplicity, label="target", **shared_style)

ax.legend(fontsize=6)
ax.set_xlabel("Particle Multiplicity")
plt.show()

# 2. Absorbing Bridge Matching
A permutation-equivariant architecture for point clouds.

In [2]:
import torch
import lightning as L
from torch import nn
from dataclasses import dataclass

from architecture import HybridEPiC
from bridges import LinearUniformBridge, TelegraphBridge


@dataclass
class BridgeState:
    """State passed to the model: a time plus three per-mode tensors.

    Every field defaults to ``None``.
    """

    time: torch.Tensor = None
    continuous: torch.Tensor = None
    discrete: torch.Tensor = None
    absorbing: torch.Tensor = None

    def append(self, state):
        """Return a new ``BridgeState`` with ``state`` concatenated after ``self``.

        Each of the four fields is joined with ``torch.cat`` along dimension 0.
        Neither ``self`` nor ``state`` is modified.
        """
        joined = {
            name: torch.cat([getattr(self, name), getattr(state, name)], dim=0)
            for name in ("time", "continuous", "discrete", "absorbing")
        }
        return BridgeState(**joined)


@dataclass
class OutputHeads:
    """Bundle of the three outputs produced by the encoder, one per data mode."""

    # All fields default to None.
    # Compared against the continuous bridge drift with an MSE loss.
    continuous: torch.Tensor = None
    # Reshaped to (-1, vocab_size) and used as logits for the cross-entropy loss.
    discrete: torch.Tensor = None
    # Multiplied into both per-element losses as a mask/weight.
    absorbing: torch.Tensor = None


class AbsorbingBridgeMatching(L.LightningModule):
    """Model for hybrid data with varying size.

    Wraps a ``HybridEPiC`` encoder together with one bridge per data mode:
    a ``LinearUniformBridge`` for the continuous features and a
    ``TelegraphBridge`` each for the discrete features and the absorbing
    mask. The training objective is the sum of a masked mean-square-error
    loss on the continuous head and a masked cross-entropy loss on the
    discrete head.
    """

    def __init__(self, config):
        """Build the encoder, the bridges and the element-wise loss functions.

        Args:
            config: experiment configuration. This class reads
                ``config.data.vocab_size.features``, ``config.pipeline`` and
                ``config.train`` from it; the encoder and the bridges are
                constructed from the same object.
        """
        super().__init__()
        self.config = config
        self.vocab_size = config.data.vocab_size.features

        self.encoder = HybridEPiC(config)
        self.bridge_continuous = LinearUniformBridge(config)
        self.bridge_discrete = TelegraphBridge(config)
        self.bridge_absorbing = TelegraphBridge(config)

        # reduction="none" keeps per-element losses so that they can be
        # masked before being averaged.
        self.loss_continuous_fn = nn.MSELoss(reduction="none")
        self.loss_discrete_fn = nn.CrossEntropyLoss(reduction="none")

        self.save_hyperparameters()

    def forward(self, state, batch):
        """Run the encoder on a bridge state and wrap its outputs.

        Args:
            state: ``BridgeState`` supplying time, continuous, discrete and
                absorbing inputs.
            batch: data batch; ``context_continuous`` and ``context_discrete``
                are forwarded when present and replaced by ``None`` otherwise.

        Returns:
            ``OutputHeads`` holding the three encoder outputs.
        """
        continuous, discrete, absorbing = self.encoder(
            t=state.time,
            x=state.continuous,
            k=state.discrete,
            mask=state.absorbing,
            context_continuous=getattr(batch, "context_continuous", None),
            context_discrete=getattr(batch, "context_discrete", None),
        )
        return OutputHeads(continuous, discrete, absorbing)

    def sample_bridges(self, batch):
        """Sample stochastic bridges at one uniformly random time per example.

        Returns:
            ``BridgeState`` with the sampled time and the continuous, discrete
            and absorbing bridge samples.
        """
        t = torch.rand(
            batch.target_continuous.shape[0], device=batch.target_continuous.device
        ).type_as(batch.target_continuous)
        time = self.reshape_time(
            t, batch.target_continuous
        )  # shape: (b, 1,...) with len as len(x1)

        continuous = self.bridge_continuous.sample(
            time, batch.source_continuous, batch.target_continuous
        )

        discrete = self.bridge_discrete.sample(
            time, batch.source_discrete, batch.target_discrete
        )

        # NOTE(review): the target mask is passed as both endpoints of the
        # absorbing bridge, whereas predict_step starts from
        # batch.source_mask. Confirm whether the first argument should be
        # batch.source_mask.
        absorbing = self.bridge_absorbing.sample(
            time, batch.target_mask, batch.target_mask
        )
        return BridgeState(time, continuous, discrete, absorbing)

    def loss_continuous(self, heads: OutputHeads, state: BridgeState, batch):
        """Mean square error loss for the velocity field.

        The element-wise MSE between the continuous head and the bridge drift
        is multiplied by ``heads.absorbing`` and normalised by the sum of that
        mask. The result is not finite when the mask sums to zero.
        """
        vector = heads.continuous
        # NOTE(review): the mask comes from the model output rather than from
        # state.absorbing or batch.target_mask; confirm this is intended.
        mask = heads.absorbing

        ut = self.bridge_continuous.drift(
            t=state.time,
            x=state.continuous,
            x0=batch.source_continuous,
            x1=batch.target_continuous,
        ).to(vector.device)
        loss_mse = self.loss_continuous_fn(vector, ut) * mask
        return loss_mse.sum() / mask.sum()

    def loss_discrete(self, heads: OutputHeads, batch):
        """Cross-entropy loss for the discrete state classifier.

        Logits are flattened to ``(-1, vocab_size)`` and targets to a 1-D long
        tensor; the per-element loss is multiplied by the flattened
        ``heads.absorbing`` mask and normalised by the sum of that mask.
        """
        logits = heads.discrete.reshape(-1, self.vocab_size)
        targets = batch.target_discrete.reshape(-1).long()
        targets = targets.to(logits.device)
        mask = heads.absorbing.reshape(-1)
        loss_ce = self.loss_discrete_fn(logits, targets) * mask
        return loss_ce.sum() / mask.sum()

    def reshape_time(self, t, x):
        """Reshape ``t`` to ``(-1, 1, ..., 1)`` with as many dimensions as ``x``.

        Python ``float``/``int`` values are returned unchanged.
        """
        if isinstance(t, (float, int)):
            return t
        else:
            return t.reshape(-1, *([1] * (x.dim() - 1)))

    # ...Lightning functions:

    def _shared_step(self, batch):
        """Sample bridges, run the model and return the summed loss."""
        state = self.sample_bridges(batch)
        heads = self.forward(state, batch)
        loss_continuous = self.loss_continuous(heads, state, batch)
        loss_discrete = self.loss_discrete(heads, batch)
        return loss_continuous + loss_discrete

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it as ``train_loss``."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, on_step=True, on_epoch=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """Generate target data from source data using the trained dynamics.

        Starts from the source fields of ``batch`` and repeatedly applies the
        continuous and discrete bridge solver steps on a uniform time grid
        from ``0`` to ``1 - config.pipeline.time_eps`` with
        ``config.pipeline.num_timesteps`` points.

        Returns:
            The ``BridgeState`` after the last solver step.
        """
        time_steps = torch.linspace(
            0.0, 1.0 - self.config.pipeline.time_eps, self.config.pipeline.num_timesteps
        )
        # Uniform spacing of the grid above.
        delta_t = (time_steps[-1] - time_steps[0]) / (len(time_steps) - 1)
        state = BridgeState(
            time_steps[0],
            batch.source_continuous,
            batch.source_discrete,
            batch.source_mask,
        )
        # NOTE(review): the loop starts at time_steps[1], so the model is never
        # evaluated at time_steps[0] and the initial state is paired with the
        # next grid time. Confirm this is intended (an explicit Euler scheme
        # would iterate over time_steps[:-1]).
        for time in time_steps[1:]:
            # NOTE(review): the time tensor here has shape (b, 1), whereas
            # training uses reshape_time; it also relies on batch[0] while the
            # rest of the class accesses the batch by attribute. Confirm both.
            state.time = torch.full(
                (len(batch[0]), 1), time.item(), device=batch[0].device
            )
            heads = self.forward(state, batch)
            state = self.bridge_continuous.solver_step(state, heads, delta_t)
            state = self.bridge_discrete.solver_step(state, heads, delta_t)
            # NOTE(review): bridge_absorbing.solver_step is never called here.
        return state

    def configure_optimizers(self):
        """Return an Adam optimizer and a cosine-annealing LR scheduler.

        The learning rate, ``T_max`` and ``eta_min`` are read from
        ``config.train``.
        """
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.config.train.optimizer.params.lr
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.config.train.scheduler.params.T_max,  # Adjust as needed
            eta_min=self.config.train.scheduler.params.eta_min,  # Adjust as needed
        )
        return [optimizer], [scheduler]


  Referenced from: <6F778E65-02B4-30B0-8D19-83A1F2195F4D> /Users/dario/anaconda3/lib/python3.10/site-packages/torchvision/image.so
  Expected in:     <A84DFEFF-287E-3B94-A7DB-731FA5F9CBBC> /Users/dario/anaconda3/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [3]:
from dataloader import DataloaderModule

# Wrap the jet dataset in a dataloader module built from the same configuration.
dataloader = DataloaderModule(config=config, dataclass=jets)
# Pull a single batch from the training loader.
databatch = next(iter(dataloader.train))
# Instantiate the bridge-matching model.
abm = AbsorbingBridgeMatching(config)


INFO: building dataloaders...
INFO: train/val/test split ratios: 0.7/0.1/0.2
INFO: train size: 47736, validation size: 6819, testing sizes: 13640




In [4]:
from lightning.pytorch import Trainer
# from lightning.pytorch.loggers import MLFlowLogger

# mlflow_logger = MLFlowLogger(
#     experiment_name=config.experiment.name,
#     tracking_uri="file:./results",  # or some external tracking server URI
# )

# NOTE: despite its name, `model` holds the Lightning Trainer; the network
# being trained is `abm`.
model = Trainer(
    max_epochs=1,
    log_every_n_steps=1,
    #  logger=mlflow_logger,  (optional MLflow logger, currently disabled)
    accelerator=config.train.device,
    devices=1,
)

# Fit `abm` with the training and validation dataloaders.
model.fit(
    abm,
    train_dataloaders=dataloader.train,
    val_dataloaders=dataloader.valid)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name               | Type             | Params | Mode 
----------------------------------------------------------------
0 | encoder            | HybridEPiC       | 826 K  | train
1 | loss_continuous_fn | MSELoss          | 0      | train
2 | loss_discrete_fn   | CrossEntropyLoss | 0      | train
----------------------------------------------------------------
826 K     Trainable params
0         Non-trainable params
826 K     Total params
3.305     Total estimated model params size (MB)
70        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/dario/anaconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/dario/anaconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [5]:
# Run `abm.predict_step` over the test dataloader using the Trainer stored in `model`.
generation = model.predict(abm, dataloaders=dataloader.test)

# `generation` is iterated batch by batch; each item is the BridgeState
# returned by `predict_step` for that batch.
for batch_idx, state in enumerate(generation):
    # `state.continuous`, `state.discrete`, etc. are available here for
    # evaluation metrics or further processing.
    # For now, only the shape of the generated continuous features is reported.
    print(f"Batch {batch_idx} generated data:", state.continuous.shape)

/Users/dario/anaconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/dario/anaconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined