In [26]:
from pathlib import Path
import os
import pandas as pd
import polars as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch import argmax, concat, set_float32_matmul_precision, stack
from torch.utils.data import DataLoader

from mirror.dataloaders.loader import DataModule
from mirror.encoders import TableEncoder, YXDataset, YZDataset, ZDataset
from mirror.encoders.maps import rename
from mirror.models.cvae import CVAE
from mirror.models.cvae_components import (
    CVAEDecoderBlock,
    CVAEEncoderBlock,
    LabelsEncoderBlock,
)

In [2]:
LOGDIR = Path("demo_logs")

In [3]:
census = pd.read_csv("data/census.csv.zip")
census = census.set_index("resident_id_m")
census = census.apply(lambda col: col.astype("category"))
print(len(census))

uniques = census.drop_duplicates()
p = len(uniques) / len(census)
print(f"Probability of unique person = {p:.3}")

census = census.rename(columns=rename)
census.describe()

604351
Probability of unique person = 0.616


Unnamed: 0,social,country_of_birth,employment_status,ethnicity,health,household_type,hours_worked,full_time_student,industry,inner/outer_london,marital_status,occupaion,region,religion,residence_type,age_group,sex,residency_type
count,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351
unique,5,3,10,6,6,6,5,3,10,3,6,10,10,10,2,7,2,3
top,2,1,1,4,1,2,-8,2,-8,-8,2,-8,E12000008,2,1,1,1,1
freq,155374,496377,223809,487868,289229,320211,326132,449456,171052,514862,217340,171052,94344,275536,593416,111272,308536,596020


In [4]:
census.head()

Unnamed: 0_level_0,social,country_of_birth,employment_status,ethnicity,health,household_type,hours_worked,full_time_student,industry,inner/outer_london,marital_status,occupaion,region,religion,residence_type,age_group,sex,residency_type
resident_id_m,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
PTS000000588097,4,1,1,4,1,4,4,2,4,-8,1,5,E12000003,2,1,4,2,1
PTS000000000320,-8,1,5,4,2,1,-8,2,7,-8,1,2,E12000005,2,1,7,2,1
PTS000000397448,-8,2,5,4,2,1,-8,2,7,-8,1,3,E12000002,2,1,7,2,1
PTS000000082442,-8,1,5,4,3,2,-8,2,8,-8,2,8,E12000006,2,1,7,1,1
PTS000000016066,4,1,8,4,2,1,-8,2,9,-8,1,9,E12000002,1,1,2,2,1


In [5]:
# Define controls aka labels aka Y, that will be used to condition generation
controls = ["sex", "age_group", "region"]
census_controls = census[controls]
target_census = census.drop(columns=controls)

controls_encoder = TableEncoder(census_controls)
y_dataset = controls_encoder.encode(data=census_controls)
controls_encoder.names()

tensor(0.) tensor(9.)


['sex', 'age_group', 'region']

In [6]:
# Define census aka X, that will be generated
census_encoder = TableEncoder(target_census)
x_dataset = census_encoder.encode(data=target_census)
census_encoder.names()

tensor(0.) tensor(9.)


['social',
 'country_of_birth',
 'employment_status',
 'ethnicity',
 'health',
 'household_type',
 'hours_worked',
 'full_time_student',
 'industry',
 'inner/outer_london',
 'marital_status',
 'occupaion',
 'religion',
 'residence_type',
 'residency_type']

In [7]:
# combine into dataset object
yx_dataset = YXDataset(x_dataset, y_dataset)
dataloader = DataModule(
    dataset=yx_dataset,
    val_split=0.1,
    test_split=0.1,
    train_batch_size=512,
    val_batch_size=512,
    test_batch_size=512,
    num_workers=4,
    pin_memory=False,
)

In [8]:
latent = 12

# encoder block to embed labels into vec with hidden size
labels_encoder_block = LabelsEncoderBlock(
    encoder_types=controls_encoder.types(),
    encoder_sizes=controls_encoder.sizes(),
    depth=2,
    hidden_size=32,
)

# encoder and decoder block to process census data
encoder = CVAEEncoderBlock(
    encoder_types=census_encoder.types(),
    encoder_sizes=census_encoder.sizes(),
    depth=2,
    hidden_size=32,
    latent_size=latent,
)
decoder = CVAEDecoderBlock(
    encoder_types=census_encoder.types(),
    encoder_sizes=census_encoder.sizes(),
    depth=2,
    hidden_size=32,
    latent_size=latent,
)

# CVAE model
cvae = CVAE(
    embedding_names=census_encoder.names(),
    embedding_types=census_encoder.types(),
    labels_encoder_block=labels_encoder_block,
    encoder_block=encoder,
    decoder_block=decoder,
    beta=0.1,
    lr=0.001,
)

In [9]:
set_float32_matmul_precision("medium")

LOGDIR.mkdir(parents=True, exist_ok=True)
log_dir = str(Path(LOGDIR))

logger = WandbLogger(project="nomis_demo", dir=log_dir)

callbacks = [
    EarlyStopping(monitor="val_loss", patience=5, mode="min"),
    ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        dirpath=Path(log_dir, "checkpoints"),
        save_weights_only=False,
    ),
]
trainer = Trainer(
    min_epochs=1,
    max_epochs=10,
    callbacks=callbacks,
    logger=logger,
    check_val_every_n_epoch=1,
)
trainer.fit(model=cvae, train_dataloaders=dataloader)
# trainer.validate(model=cvae, dataloaders=dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mfredjshone[0m ([33mfredjshone-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/fred/Projects/mirror/.venv/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /home/fred/Projects/mirror/demo_logs/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type               | Params | Mode 
--------------------------------------------------------------------
0 | labels_encoder_block | LabelsEncoderBlock | 2.7 K  | train
1 | encoder_block        | CVAEEncoderBlock   | 5.7 K  | train
2 | decoder_block        | CVAEDecoderBlock   | 4.4 K  | train
--------------------------------------------------------------------
12.8 K    Trainable params
0         Non-trainable params
12.8 K    Total params
0.051     Total estimated model params size (MB)
88        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [10]:
n = len(yx_dataset)
z_loader = ZDataset(n, latent_size=latent)
yz_loader = YZDataset(z_loader, y_dataset)
gen_loader = DataLoader(
    yz_loader, batch_size=512, num_workers=4, persistent_workers=True
)

ys, xs, zs = zip(*trainer.predict(dataloaders=gen_loader))
ys = concat(ys)
xs = concat([stack([argmax(x, dim=1) for x in xb], dim=-1) for xb in xs], dim=0)
zs = concat(zs)

Restoring states from the checkpoint path at /home/fred/Projects/mirror/demo_logs/checkpoints/epoch=0-step=945-v2.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/fred/Projects/mirror/demo_logs/checkpoints/epoch=0-step=945-v2.ckpt


Predicting: |          | 0/? [00:00<?, ?it/s]

In [15]:
controls_df = controls_encoder.decode(ys)
census_df = census_encoder.decode(xs).drop(columns=["pid"])
df = pd.concat([controls_df, census_df], axis=1)
df.head()

Unnamed: 0,pid,sex,age_group,region,social,country_of_birth,employment_status,ethnicity,health,household_type,hours_worked,full_time_student,industry,inner/outer_london,marital_status,occupaion,religion,residence_type,residency_type
0,0,2,4,E12000003,2,1,1,4,1,2,3,2,8,-8,2,2,1,1,1
1,1,2,7,E12000005,-8,1,5,4,2,2,-8,2,8,-8,2,-8,2,1,1
2,2,2,7,E12000002,-8,1,5,4,2,2,-8,2,8,-8,2,2,2,1,1
3,3,1,7,E12000006,-8,1,5,4,2,2,-8,2,8,-8,2,2,2,1,1
4,4,2,2,E12000002,2,1,1,4,1,2,-8,2,-8,-8,1,-8,1,1,1


In [27]:
os.makedirs("tmp", exist_ok=True)
path = Path("tmp") / "demo_synthetic.csv"
print(f"Writing synthetic data to {path}")
df.to_csv(path, index=False)

Writing synthetic data to tmp/demo_synthetic.csv


In [38]:
df = pl.read_csv(path).drop(["pid"])
df.describe()

statistic,sex,age_group,region,social,country_of_birth,employment_status,ethnicity,health,household_type,hours_worked,full_time_student,industry,inner/outer_london,marital_status,occupaion,religion,residence_type,residency_type
str,f64,f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f64,f64,f64,f64,f64
"""count""",604351.0,604351.0,"""604351""",604351.0,604351.0,604351.0,604351.0,604351.0,604351.0,604351.0,604351.0,604351.0,"""604351""",604351.0,604351.0,604351.0,604351.0,604351.0
"""null_count""",0.0,0.0,"""0""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"""0""",0.0,0.0,0.0,0.0,0.0
"""mean""",1.489475,4.013229,,-0.048197,1.0,0.078152,4.0,1.419448,2.0,-3.260698,1.797997,3.208586,,-0.136747,-2.202806,1.482822,1.0,1.0
"""std""",0.49989,2.128897,,4.035702,0.0,4.12367,0.0,0.493469,0.0,5.447145,0.401495,7.328374,,3.760548,4.936043,0.499705,0.0,0.0
"""min""",1.0,1.0,"""E12000001""",-8.0,1.0,-8.0,4.0,1.0,2.0,-8.0,1.0,-8.0,"""-8""",-8.0,-8.0,1.0,1.0,1.0
"""25%""",1.0,2.0,,2.0,1.0,1.0,4.0,1.0,2.0,-8.0,2.0,-8.0,,1.0,-8.0,1.0,1.0,1.0
"""50%""",1.0,4.0,,2.0,1.0,1.0,4.0,1.0,2.0,-8.0,2.0,8.0,,2.0,2.0,1.0,1.0,1.0
"""75%""",2.0,6.0,,2.0,1.0,1.0,4.0,2.0,2.0,3.0,2.0,8.0,,2.0,2.0,2.0,1.0,1.0
"""max""",2.0,7.0,"""W92000004""",2.0,1.0,5.0,4.0,2.0,2.0,3.0,2.0,8.0,"""E13000002""",2.0,2.0,2.0,1.0,1.0


In [31]:
census.head()

Unnamed: 0_level_0,social,country_of_birth,employment_status,ethnicity,health,household_type,hours_worked,full_time_student,industry,inner/outer_london,marital_status,occupaion,region,religion,residence_type,age_group,sex,residency_type
resident_id_m,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
PTS000000588097,4,1,1,4,1,4,4,2,4,-8,1,5,E12000003,2,1,4,2,1
PTS000000000320,-8,1,5,4,2,1,-8,2,7,-8,1,2,E12000005,2,1,7,2,1
PTS000000397448,-8,2,5,4,2,1,-8,2,7,-8,1,3,E12000002,2,1,7,2,1
PTS000000082442,-8,1,5,4,3,2,-8,2,8,-8,2,8,E12000006,2,1,7,1,1
PTS000000016066,4,1,8,4,2,1,-8,2,9,-8,1,9,E12000002,1,1,2,2,1


In [39]:
print(type(df))

# Function to compute individual column frequencies
def compute_column_frequencies(df):
    return {
        col: df.group_by(col).agg(pl.count().alias("count")).sort("count", descending=True)
        for col in df.columns
    }

compute_column_frequencies(df)

<class 'polars.dataframe.frame.DataFrame'>


(Deprecated in version 0.20.5)
  col: df.group_by(col).agg(pl.count().alias("count")).sort("count", descending=True)


{'sex': shape: (2, 2)
 ┌─────┬────────┐
 │ sex ┆ count  │
 │ --- ┆ ---    │
 │ i64 ┆ u32    │
 ╞═════╪════════╡
 │ 1   ┆ 308536 │
 │ 2   ┆ 295815 │
 └─────┴────────┘,
 'age_group': shape: (7, 2)
 ┌───────────┬────────┐
 │ age_group ┆ count  │
 │ ---       ┆ ---    │
 │ i64       ┆ u32    │
 ╞═══════════╪════════╡
 │ 1         ┆ 111272 │
 │ 7         ┆ 111082 │
 │ 3         ┆ 80809  │
 │ 5         ┆ 79272  │
 │ 4         ┆ 77447  │
 │ 6         ┆ 74760  │
 │ 2         ┆ 69709  │
 └───────────┴────────┘,
 'region': shape: (10, 2)
 ┌───────────┬───────┐
 │ region    ┆ count │
 │ ---       ┆ ---   │
 │ str       ┆ u32   │
 ╞═══════════╪═══════╡
 │ E12000008 ┆ 94344 │
 │ E12000007 ┆ 89489 │
 │ E12000002 ┆ 75080 │
 │ E12000006 ┆ 64258 │
 │ E12000005 ┆ 60231 │
 │ E12000009 ┆ 57815 │
 │ E12000003 ┆ 55469 │
 │ E12000004 ┆ 49440 │
 │ W92000004 ┆ 31458 │
 │ E12000001 ┆ 26767 │
 └───────────┴───────┘,
 'social': shape: (2, 2)
 ┌────────┬────────┐
 │ social ┆ count  │
 │ ---    ┆ ---    │
 │ i64   

In [None]:
import polars as pl
from itertools import combinations

# Load CSV files
df1 = pl.read_csv("file1.csv")
df2 = pl.read_csv("file2.csv")

# Function to compute individual column frequencies
def compute_column_frequencies(df):
    return {
        col: df.group_by(col).agg(pl.count().alias("count")).sort("count", descending=True)
        for col in df.columns
    }

# Function to compute joint frequencies for combinations of columns
def compute_joint_frequencies(df, max_combination_size=3):
    joint_freqs = {}
    for r in range(2, max_combination_size + 1):
        for cols in combinations(df.columns, r):
            joint_freqs[cols] = (
                df.group_by(list(cols))
                .agg(pl.count().alias("count"))
                .sort("count", descending=True)
            )
    return joint_freqs

# Compute frequencies
df1_col_freqs = compute_column_frequencies(df)
df2_col_freqs = compute_column_frequencies(df)

df1_joint_freqs = compute_joint_frequencies(df)
df2_joint_freqs = compute_joint_frequencies(df)


