In [None]:
from pathlib import Path

import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch import argmax, concat, manual_seed, stack
from torch import set_float32_matmul_precision
from torch.utils.data import DataLoader

from mirror.dataloaders.loader import DataModule
from mirror.encoders import TableEncoder, YXDataset, YZDataset, ZDataset
from mirror.encoders.maps import rename
from mirror.models.cvae import CVAE
from mirror.models.cvae_components import (
    CVAEDecoderBlock,
    CVAEEncoderBlock,
    LabelsEncoderBlock,
)

In [2]:
# Notebook-wide configuration.

# Root directory for run artefacts: the W&B logger writes here and model
# checkpoints go into its "checkpoints" sub-directory.
LOGDIR = Path("demo_logs")

# Seed torch's global RNG before any model or data object is created, so a
# Restart & Run All repeats the same run.
# NOTE(review): this only covers code that draws from torch's global generator
# (presumably weight init, dataset splitting and latent sampling - confirm);
# NumPy and Python's `random` are not seeded here.
SEED = 42
manual_seed(SEED);  # trailing ';' hides the returned Generator repr in the cell output

In [3]:
# Load the census extract, key it by resident id and treat every attribute
# as categorical (each column gets its own category set).
census = (
    pd.read_csv("data/census.csv.zip")
    .set_index("resident_id_m")
    .astype("category")
)
n_residents = len(census)
print(n_residents)

# Share of rows that remain after collapsing identical attribute
# combinations (the index is ignored by drop_duplicates).
uniques = census.drop_duplicates()
p = len(uniques) / n_residents
print(f"Probability of unique person = {p:.3}")

# Apply the project's column rename map, then summarise each column.
census = census.rename(columns=rename)
census.describe()

604351
Probability of unique person = 0.616


Unnamed: 0,social,country_of_birth,employment_status,ethnicity,health,household_type,hours_worked,full_time_student,industry,inner/outer_london,marital_status,occupaion,region,religion,residence_type,age_group,sex,residency_type
count,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351
unique,5,3,10,6,6,6,5,3,10,3,6,10,10,10,2,7,2,3
top,2,1,1,4,1,2,-8,2,-8,-8,2,-8,E12000008,2,1,1,1,1
freq,155374,496377,223809,487868,289229,320211,326132,449456,171052,514862,217340,171052,94344,275536,593416,111272,308536,596020


In [4]:
# Preview the first five residents to eyeball the categorical codes.
census.head(n=5)

Unnamed: 0_level_0,social,country_of_birth,employment_status,ethnicity,health,household_type,hours_worked,full_time_student,industry,inner/outer_london,marital_status,occupaion,region,religion,residence_type,age_group,sex,residency_type
resident_id_m,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
PTS000000588097,4,1,1,4,1,4,4,2,4,-8,1,5,E12000003,2,1,4,2,1
PTS000000000320,-8,1,5,4,2,1,-8,2,7,-8,1,2,E12000005,2,1,7,2,1
PTS000000397448,-8,2,5,4,2,1,-8,2,7,-8,1,3,E12000002,2,1,7,2,1
PTS000000082442,-8,1,5,4,3,2,-8,2,8,-8,2,8,E12000006,2,1,7,1,1
PTS000000016066,4,1,8,4,2,1,-8,2,9,-8,1,9,E12000002,1,1,2,2,1


In [5]:
# Conditioning variables (a.k.a. controls / labels / Y) that steer generation.
controls = ["sex", "age_group", "region"]

# Split the census column-wise: the target keeps everything except the
# controls, the controls frame keeps only them.
target_census = census.drop(controls, axis="columns")
census_controls = census.loc[:, controls]

# Build an encoder from the controls frame and encode that same frame into Y.
controls_encoder = TableEncoder(census_controls)
y_dataset = controls_encoder.encode(data=census_controls)

# Show the names the encoder reports.
controls_encoder.names()

tensor(0.) tensor(9.)


['sex', 'age_group', 'region']

In [6]:
# Define census aka X, that will be generated.
# X is `target_census`; it gets its own TableEncoder, built from that frame
# and then used to encode the same frame.
census_encoder = TableEncoder(target_census)
x_dataset = census_encoder.encode(data=target_census)
# Last expression of the cell: display the names the encoder reports.
census_encoder.names()

tensor(0.) tensor(9.)


['social',
 'country_of_birth',
 'employment_status',
 'ethnicity',
 'health',
 'household_type',
 'hours_worked',
 'full_time_student',
 'industry',
 'inner/outer_london',
 'marital_status',
 'occupaion',
 'religion',
 'residence_type',
 'residency_type']

In [7]:
# Pair the encoded targets (X) with their encoded controls (Y) in one dataset.
yx_dataset = YXDataset(x_dataset, y_dataset)

# One batch size for every stage, and the same hold-out fraction for the
# validation and test splits.
batch_size = 512
holdout_fraction = 0.1

# NOTE: despite its name, `dataloader` is the DataModule wrapping the dataset.
dataloader = DataModule(
    dataset=yx_dataset,
    val_split=holdout_fraction,
    test_split=holdout_fraction,
    train_batch_size=batch_size,
    val_batch_size=batch_size,
    test_batch_size=batch_size,
    num_workers=4,
    pin_memory=False,
)

In [8]:
# Architecture hyper-parameters shared by all three blocks.
depth = 3
hidden_size = 128
latent_size = 12  # the encoder and decoder are built with the same latent size

# Labels block: embeds the encoded controls (Y) into a hidden-size vector.
labels_encoder_block = LabelsEncoderBlock(
    encoder_types=controls_encoder.types(),
    encoder_sizes=controls_encoder.sizes(),
    depth=depth,
    hidden_size=hidden_size,
)

# Encoder / decoder pair that processes the census data (X); both are
# configured from the census encoder's types and sizes.
encoder = CVAEEncoderBlock(
    encoder_types=census_encoder.types(),
    encoder_sizes=census_encoder.sizes(),
    depth=depth,
    hidden_size=hidden_size,
    latent_size=latent_size,
)
decoder = CVAEDecoderBlock(
    encoder_types=census_encoder.types(),
    encoder_sizes=census_encoder.sizes(),
    depth=depth,
    hidden_size=hidden_size,
    latent_size=latent_size,
)

# Assemble the conditional VAE from the three blocks.
cvae = CVAE(
    embedding_names=census_encoder.names(),
    embedding_types=census_encoder.types(),
    labels_encoder_block=labels_encoder_block,
    encoder_block=encoder,
    decoder_block=decoder,
    beta=1,
    lr=0.001,
)

In [9]:
# Allow reduced-precision float32 matmuls.
set_float32_matmul_precision("medium")

# Make sure the log root exists; W&B takes the directory as a string.
LOGDIR.mkdir(parents=True, exist_ok=True)
log_dir = str(LOGDIR)

logger = WandbLogger(project="nomis_demo", dir=log_dir)

# Both callbacks watch the validation loss (lower is better).
# NOTE: with max_epochs=1 below, the early-stopping patience of 5 validation
# checks cannot be exhausted; it only matters once max_epochs is raised.
early_stopping = EarlyStopping(monitor="val_loss", patience=5, mode="min")
best_checkpoint = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    mode="min",
    dirpath=Path(log_dir, "checkpoints"),
    save_weights_only=False,
)
callbacks = [early_stopping, best_checkpoint]

# Single-epoch demo run, validating after every epoch.
trainer = Trainer(
    min_epochs=1,
    max_epochs=1,
    callbacks=callbacks,
    logger=logger,
    check_val_every_n_epoch=1,
)

# NOTE(review): the DataModule is handed over through `train_dataloaders`;
# this assumes it is a LightningDataModule (TODO confirm), in which case
# `datamodule=dataloader` would be the clearer spelling.
trainer.fit(model=cvae, train_dataloaders=dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfredjshone[0m ([33mfredjshone-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/fred/miniforge3/envs/mirror/lib/python3.13/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/fred/projects/mirror/demo_logs/checkpoints exists and is not empty.

  | Name                 | Type               | Params | Mode 
--------------------------------------------------------------------
0 | labels_encoder_block | LabelsEncoderBlock | 35.5 K | train
1 | encoder_block        | CVAEEncoderBlock   | 47.4 K | train
2 | decoder_block        | CVAEDecoderBlock   | 29.5 K | train
--------------------------------------------------------------------
112 K     Trainable params
0         Non-trainable params
112 K     Total params
0.449     Total estimated model params size (MB)
91        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [30]:
# Generate one synthetic record per record in the training dataset,
# conditioned on the same encoded controls (Y).
# (torch's argmax/concat/stack and DataLoader are imported in the top import
# cell rather than here, so the notebook's dependencies live in one place.)
n = len(yx_dataset)

# latent_size must stay equal to the latent_size the CVAE encoder/decoder
# blocks were built with (12).
z_loader = ZDataset(n, latent_size=12)
yz_loader = YZDataset(z_loader, y_dataset)
gen_loader = DataLoader(
    yz_loader, batch_size=512, num_workers=4, persistent_workers=True
)

# predict() returns one entry per batch; each entry unpacks into three parts,
# presumably (y, x, z) in that order - confirm against CVAE.predict_step.
# zip(*...) regroups them into per-part tuples of batches.
ys, xs, zs = zip(*trainer.predict(cvae, dataloaders=gen_loader))
ys = concat(ys)

# Each x batch is a sequence of tensors (presumably one per generated column,
# shaped (batch, n_categories) - confirm). Take the argmax over dim 1 of each,
# stack them along the last axis, and concatenate the batches along dim 0.
xs = concat(
    [stack([argmax(x, dim=1) for x in xb], dim=-1) for xb in xs],
    dim=0,
)
zs = concat(zs)

Predicting: |          | 0/? [00:00<?, ?it/s]

In [31]:
# Inspect the generated records: `xs` holds the argmax indices built in the
# prediction cell, stacked along the last axis.
xs

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])