
# Generative GSM pipeline (modularized)

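This notebook delegates channel generation, feature preparation, model training, sampling, and evaluation to the `gsm_pipeline` package. By default it trains the repo-backed DSBM (Diffusion Schrödinger Bridge Matching) solver (`TrainingConfig(method="dsbm_repo")`). Any other `method` value trains the in-package learner (the DSBM placeholder built via `build_model`/`train_model`), which exposes hooks for SB-CFM when needed. Generated samples are compared against reference data in SVD feature space and via the learned score distribution, and the trained weights are saved as a checkpoint.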


In [1]:
# Optional environment setup: uncomment to install the dependencies.
# Prefer `%pip install` (it targets the running kernel's environment) and pin
# exact versions so a Restart & Run All reproduces the results below.
#!pip install --upgrade numpy tensorflow torchcfm sionna

## Imports & reproducibility

In [2]:
import random
from pathlib import Path

import numpy as np
import torch

from gsm_pipeline import (
    ChannelConfig,
    FeatureConfig,
    TrainingConfig,
    SamplingConfig,
    ScoreEvalConfig,
    build_model,
    generate_channel_tensor,
    prepare_feature_tensors,
    create_feature_dataloaders,
    train_model,
    train_repo_bridge,
    sample_sb_sde,
    compute_swd_between_features,
    evaluate_svd_statistics,
    evaluate_grassmann_metric,
    collect_score_distribution,
    score_distribution_summary,
    sliced_wasserstein_between_scores,
)
from gsm_pipeline.features import amp_phase_features_to_complex, complex_to_svd_feature_tensor

# Seed every RNG the pipeline may draw from (Python `random`, NumPy, PyTorch)
# so that Restart & Run All is reproducible. `torch.manual_seed` seeds all
# devices, including CUDA. The training cell passes its own `seed=` to
# TrainingConfig for the data split.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


## Channel generation

In [3]:

# Simulate the frequency-domain channel with the default ChannelConfig; the raw
# tensor feeds the feature-extraction cell below.
channel_cfg = ChannelConfig()
channel_artifacts = generate_channel_tensor(channel_cfg)

# Only the channel tensor itself is used downstream.
H_freq = channel_artifacts.channel_tensor

# Last recorded run printed (64, 4, 256, 20, 40); the meaning of each axis is
# defined by gsm_pipeline's ChannelConfig -- TODO confirm before relying on it.
print(f"H_freq shape: {H_freq.shape}")


H_freq shape: (64, 4, 256, 20, 40)


## Feature extraction

In [4]:

# Build both feature representations from the channel tensor:
# - amplitude/phase features: the default training space (see TRAIN_SPACE);
# - SVD features: the space used by the evaluation metrics further down.
feature_cfg = FeatureConfig()
feature_data = prepare_feature_tensors(H_freq, feature_cfg)

amp_phase_shape = feature_data.amp_phase_features.shape
svd_shape = feature_data.svd_features.shape
print("Amp/phase feature tensor:", amp_phase_shape, "| SVD feature tensor:", svd_shape)


Amp/phase feature tensor: (800, 512, 256) | SVD feature tensor: (800, 548, 256)


## Training data & DSBM model

In [10]:
# Select the feature space the bridge is trained in. Only "amp_phase" and "svd"
# are supported. Downstream cells branch on `TRAIN_SPACE == "svd"` and otherwise
# assume amp/phase features, so an unknown value must fail here rather than
# silently converting the wrong tensors later.
TRAIN_SPACE = "amp_phase"  # switch to 'svd' if desired
if TRAIN_SPACE == "svd":
    feature_tensor = feature_data.svd_features
    feature_metadata = feature_data.svd_metadata
elif TRAIN_SPACE == "amp_phase":
    feature_tensor = feature_data.amp_phase_features
    feature_metadata = feature_data.amp_phase_metadata
else:
    raise ValueError(f"Unknown TRAIN_SPACE {TRAIN_SPACE!r}; expected 'amp_phase' or 'svd'")

training_cfg = TrainingConfig(
    # "dsbm_repo" selects the repo-backed bridge solver; any other value trains
    # the in-package `model_type` via build_model/train_model (else-branch below).
    method="dsbm_repo",  # use repo-backed solvers by default
    batch_size=16,
    learning_rate=5e-4,
    weight_decay=1e-4,
    bridge_solver="dsbm",
    bridge_num_steps=20,
    # Judging by the logged output, each outer iteration runs one backward
    # ("DSBM b") and one forward ("DSBM f") pass of `bridge_inner_iters`
    # updates: 3 x 2 x 500 = 3000 inner updates.
    bridge_inner_iters=500,
    bridge_outer_iters=3,
    bridge_sigma=1.0,
    val_split=0.1,
    seed=0,
    conv_kwargs={"base_dim": 64, "num_layers": 4, "kernel_size": 3},
)

# `val_loader` is not consumed in this notebook. `val_tensor` is used by the
# score-distribution cell, which falls back to `train_tensor` if it is None.
train_loader, val_loader, train_tensor, val_tensor = create_feature_dataloaders(
    feature_tensor,
    batch_size=training_cfg.batch_size,
    val_split=training_cfg.val_split,
    seed=training_cfg.seed,
)
# Presumably (num_examples, channels, length) -- TODO confirm against
# create_feature_dataloaders.
feature_channels = train_tensor.shape[1]
sequence_length = train_tensor.shape[2]

use_repo_bridge = training_cfg.method.lower() == "dsbm_repo"
repo_result = None
if use_repo_bridge:
    # Repo-backed solver trains directly on the tensor; the loaders are unused.
    repo_result = train_repo_bridge(train_tensor, training_cfg)
    trained_model = None
    training_result = None
    train_device = repo_result.device
    print(
        f"Repo solver {training_cfg.bridge_solver} trained on {train_device}; "
        f"{len(repo_result.history)} inner updates"
    )
else:
    model = build_model(
        training_cfg.model_type,
        feature_channels,
        sequence_length,
        conv_kwargs=training_cfg.conv_kwargs,
        mlp_kwargs=training_cfg.mlp_kwargs,
    )
    training_result = train_model(model, train_loader, training_cfg)
    trained_model = training_result.model
    train_device = training_result.device
    print(f"Training done on {train_device}; last epochs: {training_result.history[-3:]}")


DSBM b:   0%|          | 0/500 [00:00<?, ?it/s]

DSBM f:   0%|          | 0/500 [00:00<?, ?it/s]

DSBM b:   0%|          | 0/500 [00:00<?, ?it/s]

DSBM f:   0%|          | 0/500 [00:00<?, ?it/s]

DSBM b:   0%|          | 0/500 [00:00<?, ?it/s]

DSBM f:   0%|          | 0/500 [00:00<?, ?it/s]

Repo solver dsbm trained on cuda; 3000 inner updates


## Sampling and conversion to SVD space

In [6]:
# Draw samples from the trained bridge.
# NOTE(review): sampling uses steps=4 while training used bridge_num_steps=20;
# confirm this mismatch is intended. `sde_sigma` and `schedule` are only passed
# to the non-repo sampler.
sampling_cfg = SamplingConfig(num_samples=100, steps=4, sde_sigma=0.0, device=str(train_device))
if repo_result is not None:
    generated_features = repo_result.sample(
        sampling_cfg.num_samples,
        steps=sampling_cfg.steps,
    )
else:
    generated_features = sample_sb_sde(
        trained_model,
        num_samples=sampling_cfg.num_samples,
        shape=(feature_channels, sequence_length),
        steps=sampling_cfg.steps,
        sde_sigma=sampling_cfg.sde_sigma,
        schedule=sampling_cfg.schedule,
        device=train_device,
    )

# One detached CPU copy is used for every conversion below, so generated and
# reference tensors always live on the same device in both TRAIN_SPACE branches.
# (Previously the amp/phase branch received the raw, possibly-CUDA tensor.)
generated_features_cpu = generated_features.detach().cpu()

# Reference set: the first N training examples, N = number of generated samples.
reference_tensor = train_tensor[: generated_features.shape[0]].detach().cpu()

if TRAIN_SPACE == "svd":
    gen_svd_features = generated_features_cpu
    ref_svd_features = reference_tensor
else:
    # Map amp/phase features -> complex channel -> SVD features, so that the
    # metrics are always computed in SVD space.
    gen_complex = amp_phase_features_to_complex(generated_features_cpu, feature_data.amp_phase_metadata)
    ref_complex = amp_phase_features_to_complex(reference_tensor, feature_data.amp_phase_metadata)
    gen_svd_np = complex_to_svd_feature_tensor(gen_complex, feature_data.svd_metadata)
    ref_svd_np = complex_to_svd_feature_tensor(ref_complex, feature_data.svd_metadata)
    gen_svd_features = torch.from_numpy(gen_svd_np)
    ref_svd_features = torch.from_numpy(ref_svd_np)

print("Generated features:", generated_features.shape)
print("Reference SVD features:", ref_svd_features.shape)


Generated features: torch.Size([100, 512, 256])
Reference SVD features: torch.Size([100, 548, 256])


## Classical SVD metrics

In [7]:

# Compare reference and generated samples in SVD feature space using three
# views: a sliced Wasserstein distance, per-component summary statistics
# (real vs generated), and a Grassmann metric.
swd_value = compute_swd_between_features(ref_svd_features, gen_svd_features)
svd_stats = evaluate_svd_statistics(ref_svd_features, gen_svd_features, feature_data.svd_metadata)
grassmann_metrics = evaluate_grassmann_metric(ref_svd_features, gen_svd_features, feature_data.svd_metadata)

print(f"Sliced Wasserstein distance (SVD feature space): {swd_value:.6f}")
for stats_label in svd_stats:
    print(stats_label, svd_stats[stats_label])
print("Grassmann metric:", grassmann_metrics["grassmann_metric"])


Sliced Wasserstein distance (SVD feature space): 0.250810
sigma_real {'label': 'sigma_real', 'mean': 7.126501083374023, 'std': 3.6349661350250244, 'min': 0.6837253570556641, 'max': 19.554134368896484}
sigma_gen {'label': 'sigma_gen', 'mean': 8.711811065673828, 'std': 1.1155962944030762, 'min': 5.835947036743164, 'max': 12.855240821838379}
amp_u_real {'label': 'amp_u_real', 'mean': 0.11312510818243027, 'std': 0.053176119923591614, 'min': 6.92337125656195e-05, 'max': 0.4704986810684204}
amp_u_gen {'label': 'amp_u_gen', 'mean': 0.11038912832736969, 'std': 0.05864511802792549, 'min': 2.719094482017681e-05, 'max': 0.48063966631889343}
phase_u_real {'label': 'phase_u_real', 'mean': -0.01248952466994524, 'std': 1.7829346656799316, 'min': -3.1415915489196777, 'max': 3.1415910720825195}
phase_u_gen {'label': 'phase_u_gen', 'mean': 0.001736940466798842, 'std': 1.7985862493515015, 'min': -3.1415908336639404, 'max': 3.141592502593994}
amp_v_real {'label': 'amp_v_real', 'mean': 0.4604397118091583, 

## Score distribution comparison (validation vs generated)

In [8]:
# Evaluate the learned score at t = score_cfg.t_value on held-out data and on
# generated samples, then compare the two score distributions.
score_cfg = ScoreEvalConfig(t_value=0.5, batch_size=32)
# Fall back to the training tensor when no validation split was produced.
val_eval_tensor = val_tensor if val_tensor is not None else train_tensor
gen_eval_tensor = generated_features.detach().cpu()

if repo_result is not None:
    def score_callable(batch, _):
        """Score of the repo-backed bridge at the fixed ``score_cfg.t_value``.

        The second argument passed by ``collect_score_distribution`` is ignored;
        the time is always taken from ``score_cfg``.
        """
        return repo_result.score(batch, score_cfg.t_value)

    # The repo path supplies its score through `score_fn` and passes no model.
    score_model = None
    score_kwargs = {"score_fn": score_callable}
else:
    score_model = trained_model
    score_kwargs = {}


def _collect_scores(eval_tensor):
    """Collect score values for ``eval_tensor`` with the active score source."""
    return collect_score_distribution(
        score_model,
        eval_tensor,
        device=train_device,
        batch_size=score_cfg.batch_size,
        t_value=score_cfg.t_value,
        **score_kwargs,
    )


val_scores = _collect_scores(val_eval_tensor)
gen_scores = _collect_scores(gen_eval_tensor)

score_swd = sliced_wasserstein_between_scores(val_scores, gen_scores, score_cfg)
print(f"Score distribution SWD: {score_swd:.6f}")
print("Validation score stats:", score_distribution_summary(val_scores))
print("Generated score stats:", score_distribution_summary(gen_scores))


Score distribution SWD: 0.173023
Validation score stats: {'mean': -0.00017555504746269435, 'std': 0.042140327394008636, 'min': -0.3531128764152527, 'max': 0.3576325476169586}
Generated score stats: {'mean': -0.00038395184674300253, 'std': 0.2280501276254654, 'min': -1.3087732791900635, 'max': 1.2545868158340454}


## Save checkpoint

In [9]:
# Persist the trained weights together with the feature metadata and metrics
# needed to rebuild features and interpret results.
# NOTE(review): on the repo path only `repo_result.forward_model` is saved and
# there is no optimizer state. The training log shows both "DSBM f" and
# "DSBM b" passes, so confirm a backward network does not also need saving.
# The filename also says "sbcfm" even when training_method is repo::dsbm.
ckpt_path = Path('channel_sbcfm_svd_features.pth')

if repo_result is None:
    model_state, optimizer_state, training_history, training_method, repo_solver = (
        trained_model.state_dict(),
        training_result.optimizer_state,
        training_result.history,
        training_result.method,
        None,
    )
else:
    model_state, optimizer_state, training_history, training_method, repo_solver = (
        repo_result.forward_model.state_dict(),
        None,
        repo_result.history,
        f"repo::{training_cfg.bridge_solver}",
        training_cfg.bridge_solver,
    )

checkpoint = {
    # Weights and optimizer.
    'model_state_dict': model_state,
    'optimizer_state_dict': optimizer_state,
    # Feature / architecture metadata.
    'feature_meta': feature_data.svd_metadata,
    'model_type': training_cfg.model_type,
    'conv_kwargs': training_cfg.conv_kwargs,
    'mlp_kwargs': training_cfg.mlp_kwargs,
    'svd_summary': feature_data.svd_summary,
    'train_space': TRAIN_SPACE,
    'train_feature_meta': feature_metadata,
    'amp_phase_meta': feature_data.amp_phase_metadata,
    # Training provenance and evaluation.
    'training_history': training_history,
    'training_method': training_method,
    'repo_solver': repo_solver,
    'score_swd': score_swd,
}
torch.save(checkpoint, ckpt_path)
print('Saved checkpoint to', ckpt_path)


Saved checkpoint to channel_sbcfm_svd_features.pth
