In [None]:
from pathlib import Path

import polars as pl
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from robin.dataloaders.loader import DataModule
from robin.encoders import TableEncoder, YXDataset
from robin.runners import helpers

In [6]:
data = pl.read_csv(
    Path("~/Data/ethiopia/ethiopia_data.csv"),
    columns=[
        "age",
        "sex",
        "dist_road",
        "dist_market",
        "dist_border",
        "dist_popcenter",
        "dist_admhq",
        "denomination",
        "total_cons_ann",
        "nom_totcons_aeq",
        "walls_material",
        "roof_material",
        "floor_material",
    ],
)
# note that hh structure is being ignored

# fill missing numeric values with column means
means = {
    col: data[col].mean() for col in data.select(pl.col(pl.Float64)).columns
}
data = data.with_columns(
    [pl.col(col).fill_null(means[col]) for col in means.keys()]
)

# fill missing categorical values with column modes
modes = {
    col: data[col].mode()[0] for col in data.select(pl.col(pl.Utf8)).columns
}
data = data.with_columns(
    [pl.col(col).fill_null(modes[col]) for col in modes.keys()]
)

data.head()

dist_road,dist_market,dist_border,dist_popcenter,dist_admhq,age,sex,denomination,total_cons_ann,nom_totcons_aeq,walls_material,roof_material,floor_material
f64,f64,f64,f64,f64,str,str,str,f64,f64,str,str,str
7.7,162.3,82.9,0.4,0.0,"""65+""","""Female""","""urban""",226020.0,144884.625,"""Plastered hallow blocks""","""Corrugated iron sheet""","""Plastic tiles"""
7.7,162.3,82.9,0.4,0.0,"""31-50""","""Female""","""urban""",226020.0,144884.625,"""Plastered hallow blocks""","""Corrugated iron sheet""","""Plastic tiles"""
7.7,162.3,82.9,0.4,0.0,"""0-17""","""Female""","""urban""",226020.0,144884.625,"""Plastered hallow blocks""","""Corrugated iron sheet""","""Plastic tiles"""
7.7,162.3,82.9,0.4,0.0,"""31-50""","""Female""","""urban""",248090.0,62967.003906,"""Wood and mud""","""Corrugated iron sheet""","""Plastic tiles"""
7.7,162.3,82.9,0.4,0.0,"""0-17""","""Female""","""urban""",248090.0,62967.003906,"""Wood and mud""","""Corrugated iron sheet""","""Plastic tiles"""


In [9]:
write_path = Path("../data/ethiopia/ethiopia_data_cleaned.csv")
write_path.parent.mkdir(parents=True, exist_ok=True)
data.write_csv(Path("../data/ethiopia/ethiopia_data_cleaned.csv"))

In [None]:
save_dir = Path("tmp/logs")
project = "ethopia"
name = "demo"

# create directories
save_dir.mkdir(exist_ok=True, parents=True)

# logger
logger = WandbLogger(save_dir=save_dir, project=project, name=name)

seed = 12345
torch.manual_seed(seed)

# split x and y (y are the control columns)
controls = ["age", "sex", "denomination"]
y = data.select(controls)
x = data.drop(controls)

# encoders
x_encoder = TableEncoder(x, verbose=True)
x_dataset = x_encoder.encode(data=x)
y_encoder = TableEncoder(y, verbose=True)
y_dataset = y_encoder.encode(data=y)

xy_dataset = YXDataset(x_dataset, y_dataset)
datamodule = DataModule(
    dataset=xy_dataset,
    val_split=0.1,
    test_split=0.0,
    train_batch_size=512,
    val_batch_size=512,
    test_batch_size=512,
    gen_batch_size=512,
    num_workers=4,
    pin_memory=True,
)

model = helpers.build_model(
    config={
        "model": {
            "latent_size": 20,
            "beta": 1,
            "lr": 0.1,
            "controls_encoder": {"depth": 4, "hidden_size": 64},
            "encoder": {"depth": 4, "hidden_size": 64},
            "decoder": {"depth": 4, "hidden_size": 64},
        }
    },
    x_encoder=x_encoder,
    y_encoder=y_encoder,
    ckpt_path=None,
)

callbacks = helpers.build_callbacks(config={"early_stopping": {"patience": 10}})

trainer = Trainer(
    callbacks=callbacks,
    logger=logger,
    min_epochs=5,
    max_epochs=100,
    check_val_every_n_epoch=1,
)

  encoded = Tensor(encoded.to_numpy()).int()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


ContinuousEncoder: min: 0.0, max: 72.8, range: 72.8, dtype: Float64
ContinuousEncoder: min: 0.4, max: 448.7, range: 448.3, dtype: Float64
ContinuousEncoder: min: 3.2, max: 496.3, range: 493.1, dtype: Float64
ContinuousEncoder: min: 0.4, max: 285.1, range: 284.70000000000005, dtype: Float64
ContinuousEncoder: min: 0.0, max: 0.6, range: 0.6, dtype: Float64
ContinuousEncoder: min: 2400.66650390625, max: 2798544.0, range: 2796143.3334960938, dtype: Float64
ContinuousEncoder: min: 1854.05407714844, max: 748273.8125, range: 746419.7584228516, dtype: Float64
CategoricalTokeniser: size: 11, categories: {'Bricks': 0, 'Corrugated iron': 1, 'Mud bricks': 2, 'Other': 3, 'Plastered hallow blocks': 4, 'Reed/bamboo': 5, 'Stone and cement': 6, 'Stone and mud/Stone only': 7, 'Unplastered hallow blocks': 8, 'Wood and mud': 9, 'Wood and thatch/Wood only': 10}, dtype: String
CategoricalTokeniser: size: 8, categories: {'Asbestos': 0, 'Bamboo or reed': 1, 'Concrete or cement': 2, 'Corrugated iron sheet': 3,

In [4]:
trainer.fit(model=model, train_dataloaders=datamodule)

[34m[1mwandb[0m: Currently logged in as: [33mfredjshone[0m ([33mfredjshone-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type               | Params | Mode  | FLOPs
----------------------------------------------------------------------------
0 | labels_encoder_block | LabelsEncoderBlock | 8.9 K  | train | 0    
1 | encoder_block        | CVAEEncoderBlock   | 13.6 K | train | 0    
2 | decoder_block        | CVAEDecoderBlock   | 7.8 K  | train | 0    
3 | criterion            | ModuleList         | 0      | train | 0    
----------------------------------------------------------------------------
30.3 K    Trainable params
0         Non-trainable params
30.3 K    Total params
0.121     Total estimated model params size (MB)
92        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/fred/Projects/robin/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:317: The number of training batches (45) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Trainer was signaled to stop but the required `min_epochs=5` or `min_steps=None` has not been met. Training will continue...


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]