# Flow Matching for Mixed (Numerical + Categorical) Tabular Data

This notebook implements **Flow Matching** for datasets with **both numerical and categorical features**.

### Architecture
- **Data Processing**: Uses `EFVFMDataset` from `ef-vfm` for data loading and quantile normalization.
- **Network**: Simple MLP with sinusoidal time encoding (from `Flow_Matching.ipynb`). Input is `[x_num_t, x_cat_t_onehot, time_enc(t)]`, output is `[pred_num, pred_logits_cat]`.
- **Training**: Mixed loss combining numerical and categorical objectives:
  - Numerical: MSE between predicted denoised values and true numerical features.
  - Categorical: Cross-entropy loss from `ef-vfm` (`_absorbed_closs`).
- **Sampling**: ODE integration via `torchdiffeq.odeint` with a velocity field derived from the model's denoised predictions.
- **Evaluation**: `TabMetrics` from `ef-vfm`.

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
import json
import pandas as pd
from typing import *
from torchdiffeq import odeint

# Make the local `ef-vfm` directory importable; the path is resolved relative to
# the current working directory, so the notebook must be launched from its parent.
sys.path.append(os.path.join(os.getcwd(), 'ef-vfm'))

# Project-local modules, presumably found through the path added above — verify.
from utils_train import EFVFMDataset
from ef_vfm.metrics import TabMetrics

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Mixed Flow Matching

Data is represented as $x = [x^{\text{num}}, x^{\text{cat}}_{\text{oh}}]$ where numerical features are continuous
and categorical features are one-hot encoded.

The OT conditional flow interpolates between noise $x_0 \sim \mathcal{N}(0, I)$ and data $x_1$:

$$ \psi_t(x_0) = (1 - (1 - \sigma_{min})t)\, x_0 + t\, x_1 $$

The model predicts the **denoised** data $\hat{x}_1 = f_\theta(x_t, t)$, split into:
- $\hat{x}_1^{\text{num}}$: predicted numerical values
- $\hat{x}_1^{\text{cat}}$: predicted logits per categorical feature

The mixed training loss combines both objectives:

$$ \mathcal{L} = \underbrace{\| \hat{x}_1^{\text{num}} - x_1^{\text{num}} \|^2}_{\text{MSE (numerical)}} + \underbrace{\left( -\sum_j \log p_\theta(c_j \mid x_t, t) \right)}_{\text{Cross-Entropy (categorical)}} $$

In [None]:
class MixedFlowMatching:
    """Flow Matching for mixed numerical + categorical data."""

    def __init__(self, d_numerical: int, categories: List[int], sig_min: float = 1e-4) -> None:
        self.d_numerical = d_numerical
        self.categories = categories
        self.sig_min = sig_min
        self.d_cat_oh = sum(categories)
        self.d_input = d_numerical + self.d_cat_oh

    def to_one_hot(self, x_cat: torch.Tensor) -> torch.Tensor:
        parts = []
        for i, k in enumerate(self.categories):
            parts.append(F.one_hot(x_cat[:, i].long(), num_classes=k))
        return torch.cat(parts, dim=-1).float()

    def psi_t(self, x_0: torch.Tensor, x_1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return (1 - (1 - self.sig_min) * t) * x_0 + t * x_1

    def _absorbed_closs(self, logits: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:
        """Cross-entropy loss over categorical features (from ef-vfm)."""
        cum_sum = 0
        losses = torch.zeros(len(self.categories), device=logits.device)
        for i, val in enumerate(self.categories):
            dist = torch.distributions.Categorical(logits=logits[:, cum_sum:cum_sum + val])
            losses[i] = -dist.log_prob(x_cat[:, i].long()).mean()
            cum_sum += val
        return losses.sum()

    def loss(self, model: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Compute mixed training loss.
        x: [batch, d_numerical + n_cat_features] where first d_numerical cols are float,
           remaining cols are integer category indices.
        """
        b = x.shape[0]
        dev = x.device

        x_num = x[:, :self.d_numerical].float()
        x_cat = x[:, self.d_numerical:]

        t = torch.rand(b, device=dev)
        t_expand = t[:, None]

        # Numerical interpolation
        if self.d_numerical > 0:
            x_0_num = torch.randn_like(x_num)
            x_t_num = self.psi_t(x_0_num, x_num, t_expand)
        else:
            x_t_num = torch.zeros_like(x_num)

        # Categorical interpolation (in one-hot space)
        if self.d_cat_oh > 0:
            x_1_cat_oh = self.to_one_hot(x_cat)
            x_0_cat = torch.randn(b, self.d_cat_oh, device=dev)
            x_t_cat = self.psi_t(x_0_cat, x_1_cat_oh, t_expand)
        else:
            x_t_cat = torch.zeros_like(x_cat)

        # Concatenate and predict
        x_t = torch.cat([x_t_num, x_t_cat], dim=1)
        pred = model(t, x_t)

        pred_num = pred[:, :self.d_numerical]
        pred_logits = pred[:, self.d_numerical:]

        # Numerical loss: MSE on denoised prediction
        if self.d_numerical > 0:
            num_loss = F.mse_loss(pred_num, x_num)
        else:
            num_loss = torch.tensor(0.0, device=dev)

        # Categorical loss: cross-entropy
        if self.d_cat_oh > 0:
            cat_loss = self._absorbed_closs(pred_logits, x_cat)
        else:
            cat_loss = torch.tensor(0.0, device=dev)

        return num_loss, cat_loss

## 2. Neural Network (MLP with Sinusoidal Time Encoding)

A simple MLP that takes `[x_num_t, x_cat_t_onehot, time_encoding(t)]` and outputs `[pred_num, pred_logits_cat]`.
The output dimension equals `d_numerical + sum(categories)`.

In [None]:
class Net(nn.Module):
    """Fully connected network conditioned on time through sinusoidal features.

    The input ``x`` is concatenated with ``2 * n_frequencies`` cosine/sine
    features of ``t``, passed through a stack of ``Linear -> LeakyReLU`` blocks
    (the last of which already has width ``out_dim``), and finished with one
    extra linear layer of width ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int, h_dims: List[int], n_frequencies: int) -> None:
        super().__init__()
        self.n_frequencies = n_frequencies

        # Consecutive widths of the stack; the time features widen the first layer.
        widths = [in_dim + 2 * n_frequencies, *h_dims, out_dim]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.append(nn.Sequential(nn.Linear(fan_in, fan_out), nn.LeakyReLU()))
        self.layers = nn.ModuleList(blocks)

        # Final projection without a non-linearity.
        self.top = nn.Sequential(nn.Linear(out_dim, out_dim))

    def time_encoder(self, t: torch.Tensor) -> torch.Tensor:
        """Return ``[cos(2*pi*k*t), sin(2*pi*k*t)]`` for ``k = 0 .. n_frequencies - 1``.

        A trailing axis of size ``2 * n_frequencies`` is appended to ``t``.
        """
        harmonics = torch.arange(self.n_frequencies, device=t.device)
        angles = (2 * harmonics * torch.pi) * t.unsqueeze(-1)
        return torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` concatenated with the encoding of ``t``."""
        hidden = torch.cat((x, self.time_encoder(t)), dim=-1)
        for block in self.layers:
            hidden = block(hidden)
        return self.top(hidden)

## 3. Mixed Velocity Wrapper (for ODE Sampling)

During sampling, the model predicts denoised data, which is converted to a velocity field:

$$ v_t(x_t) = \frac{\hat{x}_1 - (1 - \sigma_{min}) \cdot x_t}{1 - (1 - \sigma_{min}) \cdot t} $$

For numerical features, $\hat{x}_1^{\text{num}}$ is used directly.
For categorical features, $\hat{x}_1^{\text{cat}} = \text{softmax}(\text{logits})$, applied separately to each feature's block of logits, maps the logits onto the probability simplex.

After ODE integration:
- Numerical outputs are used directly as continuous values.
- Categorical outputs are discretized via `argmax` per feature segment.

In [None]:
class MixedVelocity(nn.Module):
    """Turn a denoising model into a velocity field for ODE sampling.

    The wrapped model predicts the clean sample ``x1_hat`` from ``(t, x_t)``.
    The velocity is ``(x1_hat - (1 - sig_min) * x_t) / (1 - (1 - sig_min) * t)``,
    where the categorical part of ``x1_hat`` is the per-feature softmax of the
    predicted logits.
    """

    def __init__(self, model: nn.Module, d_numerical: int, categories: List[int], sig_min: float = 1e-4):
        super().__init__()
        self.model = model
        self.d_numerical = d_numerical
        self.categories = categories
        self.sig_min = sig_min

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the velocity at time ``t`` for the batch ``x``.

        Args:
            t: scalar (0-dim) tensor or a tensor expandable to ``[batch]``.
            x: ``[batch, d_numerical + sum(categories)]`` current state.

        Returns:
            Tensor shaped like ``x``: numerical velocities followed by the
            categorical (one-hot space) velocities.
        """
        # The ODE solver may hand over a 0-dim time; the model wants one value per row.
        if t.dim() == 0:
            t = t.unsqueeze(0)
        t_rows = t.expand(x.shape[0])

        x1_hat = self.model(t_rows, x)
        shrink = 1 - self.sig_min
        denom = 1 - shrink * t_rows.unsqueeze(1)

        n_num = self.d_numerical
        x_num, x_cat = x[:, :n_num], x[:, n_num:]

        # Numerical part: the prediction already is the denoised value.
        if n_num > 0:
            v_num = (x1_hat[:, :n_num] - shrink * x_num) / denom
        else:
            v_num = torch.zeros_like(x_num)

        # Categorical part: softmax of each feature's logits is the denoised value.
        # (Using the raw logits here instead was tried and performed badly.)
        if len(self.categories) > 0:
            logits = x1_hat[:, n_num:]
            pieces = []
            start = 0
            for width in self.categories:
                stop = start + width
                probs = F.softmax(logits[:, start:stop], dim=-1)
                pieces.append((probs - shrink * x_cat[:, start:stop]) / denom)
                start = stop
            v_cat = torch.cat(pieces, dim=1)
        else:
            v_cat = torch.zeros_like(x_cat)

        return torch.cat([v_num, v_cat], dim=1)

## 4. Data Loading

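Load a mixed dataset using `EFVFMDataset`. The notebook is intended for datasets with both `d_numerical > 0` and `len(categories) > 0`; the loss and velocity code also contain branches for the case where one of the two parts is empty.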

Available mixed datasets: `adult`, `default`, `shoppers`, `beijing`, `news`, etc.

In [None]:
# ============ Configuration ============
# Dataset name; it selects the data directory below and is reused by later cells
# to build evaluation and results paths.
dataname = 'news_numerical_only'
# NOTE(review): data is read from `ef-vfm-dev/` while the import path added at the
# top of the notebook points at `ef-vfm/` — confirm both directories are intended.
data_dir = f'ef-vfm-dev/data/{dataname}'
info_path = f'ef-vfm-dev/data/{dataname}/info.json'

# Dataset metadata; later cells read keys such as 'task_type', 'num_col_idx',
# 'cat_col_idx', 'target_col_idx', 'idx_mapping' and 'idx_name_mapping' from it.
with open(info_path, 'r') as f:
    info = json.load(f)

# Hyperparameters
batch_size = 2048       # rows per optimisation step
n_epochs = 100          # full passes over the training data
lr = 1e-3               # Adam learning rate
sigma_min = 1e-4        # minimum noise scale of the conditional flow
h_dims = [512] * 5      # hidden widths of the MLP
n_frequencies = 10      # number of frequencies in the sinusoidal time encoding

# Load Data
# presumably `dequant_dist='none'` / `int_dequant_factor=0.0` switch off
# dequantization — TODO confirm against EFVFMDataset.
train_data = EFVFMDataset(
    dataname, data_dir, info, isTrain=True,
    dequant_dist='none', int_dequant_factor=0.0
)

d_numerical = train_data.d_numerical
categories = train_data.categories.tolist()

# Checks currently disabled; presumably so that datasets with only one feature
# kind (such as the numerical-only one configured above) can be loaded — verify.
# assert d_numerical > 0, f"Expected numerical features, but d_numerical={d_numerical}"
# assert len(categories) > 0, f"Expected categorical features, but categories is empty"

# Width of the one-hot encoded categorical part and of the full model input/output.
d_cat_oh = sum(categories)
d_total = d_numerical + d_cat_oh

print(f"Dataset: {dataname}")
print(f"d_numerical: {d_numerical}")
print(f"categories (classes per feature): {categories}")
print(f"d_cat_oh (total one-hot dim): {d_cat_oh}")
print(f"d_total (model input/output dim): {d_total}")
print(f"Training data shape: {train_data.X.shape}")

# The whole training tensor is moved to the device once, so the loader yields
# batches that already live there.
dataset = train_data.X.to(device)
tensor_dataset = TensorDataset(dataset)
dataloader = DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)

print(f"Batch size: {batch_size}")
print(f"Number of batches: {len(dataloader)}")

## 5. Training

Train the MLP with the mixed loss: MSE for numerical denoising + cross-entropy for categorical denoising.

In [None]:
# The network maps the full (numerical + one-hot) vector back onto itself.
in_dim = out_dim = d_total

mixed_fm = MixedFlowMatching(d_numerical, categories, sig_min=sigma_min)
net = Net(in_dim, out_dim, h_dims, n_frequencies).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

n_params = sum(p.numel() for p in net.parameters())
print(f"Model parameters: {n_params:,}")
print(f"Training for {n_epochs} epochs...")

# Loss histories with one entry per optimisation step.
losses_num, losses_cat, losses_total = [], [], []

net.train()
bar = tqdm(range(n_epochs), ncols=100)
for epoch in bar:
    for (x,) in dataloader:
        num_loss, cat_loss = mixed_fm.loss(net, x)
        loss = num_loss + cat_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses_num.append(num_loss.item())
        losses_cat.append(cat_loss.item())
        losses_total.append(loss.item())
    # The progress bar reports the losses of the last batch of the epoch.
    bar.set_postfix(total=loss.item(), num=num_loss.item(), cat=cat_loss.item())

# One log-scale panel per loss component.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = [
    (losses_total, "Total Loss"),
    (losses_num, "Numerical Loss (MSE)"),
    (losses_cat, "Categorical Loss (CE)"),
]
for ax, (history, title) in zip(axes, panels):
    ax.plot(history)
    ax.set_title(title)
    ax.set_yscale('log')
    ax.grid(True, alpha=0.3)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
plt.tight_layout()
plt.show()

## 6. Sampling

Generate synthetic data by:
1. Sampling $x_0 \sim \mathcal{N}(0, I)$ in $\mathbb{R}^{d_{\text{num}} + \sum K_j}$
2. ODE integration from $t=0$ to $t \approx 1$ (the code stops at $t = 0.999$) using the mixed velocity field
3. Numerical features are taken directly; categorical features are discretized via `argmax`

In [None]:
def split_num_cat_target(syn_data, info, num_inverse, int_inverse, cat_inverse):
    """Split generated rows into numerical, categorical and target parts.

    The leading columns of ``syn_data`` are treated as numerical and the rest
    as categorical. Each part is mapped back through the inverse transforms,
    then the target columns are taken from the front of the numerical part
    (regression) or of the categorical part (any other task type).

    Args:
        syn_data: 2-D array of generated samples.
        info: dataset metadata with 'task_type', 'num_col_idx', 'cat_col_idx'
            and 'target_col_idx'.
        num_inverse: callable undoing the numerical transform; if it has an
            ``n_features_in_`` attribute, that value is the numerical width.
        int_inverse: callable applied after ``num_inverse``.
        cat_inverse: callable undoing the categorical encoding; if it has an
            ``n_features_in_`` attribute, that value is the categorical width.
            It is not called when the categorical width is zero.

    Returns:
        ``(syn_num, syn_cat, syn_target)``.
    """
    task_type = info['task_type']
    num_col_idx = info['num_col_idx']
    cat_col_idx = info['cat_col_idx']
    target_col_idx = info['target_col_idx']

    is_regression = task_type == 'regression'
    n_target = len(target_col_idx)

    # The target is stored with the numerical block for regression and with the
    # categorical block otherwise, so it widens exactly one of the two.
    if hasattr(num_inverse, 'n_features_in_'):
        num_width = num_inverse.n_features_in_
    else:
        num_width = len(num_col_idx) + (n_target if is_regression else 0)

    if hasattr(cat_inverse, 'n_features_in_'):
        cat_width = cat_inverse.n_features_in_
    else:
        cat_width = len(cat_col_idx) + (0 if is_regression else n_target)

    num_block = syn_data[:, :num_width]
    cat_block = syn_data[:, num_width:]

    num_block = num_inverse(num_block).astype(np.float32)
    num_block = int_inverse(num_block).astype(np.float32)
    if cat_width > 0:
        cat_block = cat_inverse(cat_block)
    else:
        cat_block = np.empty((syn_data.shape[0], 0))

    if is_regression:
        return num_block[:, n_target:], cat_block, num_block[:, :n_target]
    return num_block, cat_block[:, n_target:], cat_block[:, :n_target]


def recover_data(syn_num, syn_cat, syn_target, info):
    """Reassemble the split synthetic arrays into a table in original column order.

    For every original column index ``i`` the value ``idx_mapping[i]`` is its
    position in the concatenated ``[numerical | categorical | target]`` layout;
    the offset of the owning block is subtracted to index into that block.
    The column layout does not depend on the task type.

    Args:
        syn_num: 2-D array of numerical features.
        syn_cat: 2-D array of categorical features.
        syn_target: 2-D array of target column(s).
        info: dataset metadata with 'num_col_idx', 'cat_col_idx',
            'target_col_idx' and 'idx_mapping' (keys may be strings, as they
            are after a JSON round-trip).

    Returns:
        ``pandas.DataFrame`` whose columns are the integers ``0 .. n_columns - 1``.
    """
    num_col_idx = info['num_col_idx']
    cat_col_idx = info['cat_col_idx']
    target_col_idx = info['target_col_idx']
    idx_mapping = {int(key): value for key, value in info['idx_mapping'].items()}

    # Build the membership sets once rather than on every iteration.
    num_cols = set(num_col_idx)
    cat_cols = set(cat_col_idx)
    n_num = len(num_col_idx)
    n_cat = len(cat_col_idx)

    # Collect the columns first and construct the frame in one go; inserting
    # columns one by one into a DataFrame fragments it for wide tables.
    columns = {}
    for i in range(n_num + n_cat + len(target_col_idx)):
        pos = idx_mapping[i]
        if i in num_cols:
            columns[i] = syn_num[:, pos]
        elif i in cat_cols:
            columns[i] = syn_cat[:, pos - n_num]
        else:
            columns[i] = syn_target[:, pos - n_num - n_cat]
    return pd.DataFrame(columns)

In [None]:
# Generate as many synthetic rows as there are training rows.
num_samples = len(train_data)
print(f"Generating {num_samples} samples...")

net.eval()
velocity = MixedVelocity(net, d_numerical, categories, sig_min=sigma_min)

sample_batch_size = 4096
all_samples = []
num_generated = 0

# Rows containing NaN are dropped and regenerated. If the sampler only ever
# yields NaN rows that would loop forever, so give up after a few empty batches.
max_empty_batches = 5
empty_batches = 0

# The integration interval is identical for every batch; it ends at t = 0.999
# rather than at t = 1.
t_span = torch.tensor([0.0, 0.999], device=device)

with torch.no_grad():
    while num_generated < num_samples:
        cur_batch = min(sample_batch_size, num_samples - num_generated)
        x_0 = torch.randn(cur_batch, d_total, device=device)
        trajectory = odeint(velocity, x_0, t_span, method='dopri5', rtol=1e-5, atol=1e-5)
        out = trajectory[1]  # state at the final time point

        # Numerical part: take directly
        sample_num = out[:, :d_numerical]

        # Categorical part: argmax over each feature's one-hot segment
        cat_out = out[:, d_numerical:]
        sample_cat = torch.zeros(cur_batch, len(categories), device=device)
        oh_idx = 0
        for i, k in enumerate(categories):
            sample_cat[:, i] = torch.argmax(cat_out[:, oh_idx:oh_idx + k], dim=1).float()
            oh_idx += k

        sample = torch.cat([sample_num, sample_cat], dim=1)

        # Keep only rows without NaN; the loop tops the total back up.
        mask_nan = torch.any(sample.isnan(), dim=1)
        sample = sample[~mask_nan]
        all_samples.append(sample.cpu())
        num_generated += sample.shape[0]
        print(f"  Generated {num_generated}/{num_samples} samples")

        if sample.shape[0] == 0:
            empty_batches += 1
            if empty_batches >= max_empty_batches:
                raise RuntimeError(
                    f"Sampling produced only NaN rows in {max_empty_batches} "
                    "consecutive batches; aborting."
                )
        else:
            empty_batches = 0

syn_tensor = torch.cat(all_samples, dim=0)[:num_samples]
syn_data_np = syn_tensor.numpy()
print(f"Generated data shape: {syn_data_np.shape}")

# Inverse transforms supplied by the dataset object.
num_inverse = train_data.num_inverse
int_inverse = train_data.int_inverse
cat_inverse = train_data.cat_inverse

syn_num, syn_cat, syn_target = split_num_cat_target(
    syn_data_np, info, num_inverse, int_inverse, cat_inverse
)
syn_df = recover_data(syn_num, syn_cat, syn_target, info)

# Replace integer column positions with the original column names.
idx_name_mapping = info['idx_name_mapping']
idx_name_mapping = {int(key): value for key, value in idx_name_mapping.items()}
syn_df.rename(columns=idx_name_mapping, inplace=True)

print(f"\nSampled Data Head:")
print(syn_df.head())
print(f"\nGenerated {len(syn_df)} samples")

## 7. Evaluation

Using `ef-vfm`'s `TabMetrics` to evaluate the generated samples with density, MLE, and C2ST metrics.

In [None]:
# Reference CSVs used by the evaluation.
real_data_path = f'ef-vfm-dev/synthetic/{dataname}/real.csv'
test_data_path = f'ef-vfm-dev/synthetic/{dataname}/test.csv'
val_data_path = f'ef-vfm-dev/synthetic/{dataname}/val.csv'

if not os.path.exists(val_data_path):
    print(f"{dataname} does not have a validation set. MLE evaluation will split from training set.")
    val_data_path = None

# Datasets whose name contains 'dcr' are evaluated with the DCR metric only.
is_dcr = 'dcr' in dataname
if is_dcr:
    metric_list = ["dcr"]
else:
    metric_list = ["density", "mle", "c2st"]

print(f"Initializing TabMetrics with metrics: {metric_list}")
metrics = TabMetrics(
    real_data_path=real_data_path,
    test_data_path=test_data_path,
    val_data_path=val_data_path,
    info=info,
    device=device,
    metric_list=metric_list
)

print(f"Real data size: {metrics.real_data_size}")
print(f"Generated data size: {len(syn_df)}")

# Regenerate when the number of synthetic rows differs from the real data size.
if len(syn_df) != metrics.real_data_size:
    print(f"\nRegenerating samples to match real data size ({metrics.real_data_size})...")
    net.eval()
    target_n = metrics.real_data_size
    all_samples = []
    num_generated = 0

    # Rows containing NaN are dropped and regenerated; stop after a few batches
    # with no valid row instead of looping forever.
    max_empty_batches = 5
    empty_batches = 0

    # Same integration interval for every batch (ends at t = 0.999).
    t_span = torch.tensor([0.0, 0.999], device=device)

    with torch.no_grad():
        while num_generated < target_n:
            cur_batch = min(sample_batch_size, target_n - num_generated)
            x_0 = torch.randn(cur_batch, d_total, device=device)
            trajectory = odeint(velocity, x_0, t_span, method='dopri5', rtol=1e-5, atol=1e-5)
            out = trajectory[1]  # state at the final time point

            # Numerical columns are used as-is; categorical columns are the
            # argmax over each feature's one-hot segment.
            sample_num = out[:, :d_numerical]
            cat_out = out[:, d_numerical:]
            sample_cat = torch.zeros(cur_batch, len(categories), device=device)
            oh_idx = 0
            for i, k in enumerate(categories):
                sample_cat[:, i] = torch.argmax(cat_out[:, oh_idx:oh_idx + k], dim=1).float()
                oh_idx += k
            sample = torch.cat([sample_num, sample_cat], dim=1)

            mask_nan = torch.any(sample.isnan(), dim=1)
            sample = sample[~mask_nan]
            all_samples.append(sample.cpu())
            num_generated += sample.shape[0]

            if sample.shape[0] == 0:
                empty_batches += 1
                if empty_batches >= max_empty_batches:
                    raise RuntimeError(
                        f"Sampling produced only NaN rows in {max_empty_batches} "
                        "consecutive batches; aborting."
                    )
            else:
                empty_batches = 0

    syn_tensor = torch.cat(all_samples, dim=0)[:target_n]
    syn_data_np = syn_tensor.numpy()
    syn_num, syn_cat, syn_target = split_num_cat_target(
        syn_data_np, info, num_inverse, int_inverse, cat_inverse
    )
    syn_df = recover_data(syn_num, syn_cat, syn_target, info)
    syn_df.rename(columns=idx_name_mapping, inplace=True)
    print(f"Regenerated {len(syn_df)} samples")

In [None]:
import tempfile

print("\n" + "=" * 50)
print("Starting Evaluation...")
print("=" * 50)

# Round-trip the synthetic table through a CSV file before evaluating it.
# The handle is closed before pandas reopens the path (reopening a still-open
# named temporary file is not portable), and the file is removed even if
# writing or reading fails.
temp_csv = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False)
temp_csv.close()
try:
    syn_df.to_csv(temp_csv.name, index=False)
    syn_df_for_eval = pd.read_csv(temp_csv.name)
finally:
    os.unlink(temp_csv.name)

# Replace column names with integer positions before handing the table to the evaluator.
syn_df_for_eval.columns = range(len(syn_df_for_eval.columns))

out_metrics, extras = metrics.evaluate(syn_df_for_eval)

print("\n" + "=" * 50)
print("Evaluation Results:")
print("=" * 50)
for key, value in out_metrics.items():
    # Plain numbers are printed with four decimals; everything else verbatim.
    if isinstance(value, (int, float)):
        print(f"{key}: {value:.4f}")
    else:
        print(f"{key}: {value}")

In [None]:
from datetime import datetime

# Every run writes into its own timestamped directory under results/<dataname>/.
run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
results_dir = f"results/{dataname}/fm_mixed_{run_stamp}"
os.makedirs(results_dir, exist_ok=True)

# Scalar metrics; values json cannot encode natively are stringified.
metrics_path = os.path.join(results_dir, "metrics.json")
with open(metrics_path, 'w') as f:
    json.dump(out_metrics, f, indent=4, default=str)
print(f"Metrics saved to: {metrics_path}")

# The synthetic table itself, with the original column names.
syn_data_path = os.path.join(results_dir, "synthetic_samples.csv")
syn_df.to_csv(syn_data_path, index=False)
print(f"Synthetic data saved to: {syn_data_path}")

# Extra evaluator outputs: DataFrames become CSV files, dicts become JSON files,
# anything else is skipped.
for name, extra in (extras.items() if extras else ()):
    if isinstance(extra, pd.DataFrame):
        extra_path = os.path.join(results_dir, f"{name}.csv")
        extra.to_csv(extra_path, index=False)
    elif isinstance(extra, dict):
        extra_path = os.path.join(results_dir, f"{name}.json")
        with open(extra_path, 'w') as f:
            json.dump(extra, f, indent=4, default=str)
    else:
        continue
    print(f"Extra {name} saved to: {extra_path}")

print(f"\nAll results saved to: {results_dir}")