In [None]:
import pandas as pd
from pytorch_lightning import Trainer

from mirror.encoders import CensusEncoder
from mirror.encoders.maps import rename
from mirror.models.cmodules import ConditionalBlock, DecoderBlock, EncoderBlock

In [3]:
census = pd.read_csv("data/census.csv.zip")
census = census.set_index("resident_id_m")
census = census.apply(lambda col: col.astype("category"))
print(len(census))

uniques = census.drop_duplicates()
p = len(uniques) / len(census)
print(f"Probability of unique person = {p:.3}")

census.describe()

604351
Probability of unique person = 0.616


Unnamed: 0,approx_social_grade,country_of_birth_3a,economic_activity_status_10m,ethnic_group_tb_6a,health_in_general,hh_families_type_6a,hours_per_week_worked,in_full_time_education,industry_10a,iol22cd,legal_partnership_status_6a,occupation_10a,region,religion_tb,residence_type,resident_age_7d,sex,usual_short_student
count,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351,604351
unique,5,3,10,6,6,6,5,3,10,3,6,10,10,10,2,7,2,3
top,2,1,1,4,1,2,-8,2,-8,-8,2,-8,E12000008,2,1,1,1,1
freq,155374,496377,223809,487868,289229,320211,326132,449456,171052,514862,217340,171052,94344,275536,593416,111272,308536,596020


In [4]:
census.head()

Unnamed: 0_level_0,approx_social_grade,country_of_birth_3a,economic_activity_status_10m,ethnic_group_tb_6a,health_in_general,hh_families_type_6a,hours_per_week_worked,in_full_time_education,industry_10a,iol22cd,legal_partnership_status_6a,occupation_10a,region,religion_tb,residence_type,resident_age_7d,sex,usual_short_student
resident_id_m,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
PTS000000588097,4,1,1,4,1,4,4,2,4,-8,1,5,E12000003,2,1,4,2,1
PTS000000000320,-8,1,5,4,2,1,-8,2,7,-8,1,2,E12000005,2,1,7,2,1
PTS000000397448,-8,2,5,4,2,1,-8,2,7,-8,1,3,E12000002,2,1,7,2,1
PTS000000082442,-8,1,5,4,3,2,-8,2,8,-8,2,8,E12000006,2,1,7,1,1
PTS000000016066,4,1,8,4,2,1,-8,2,9,-8,1,9,E12000002,1,1,2,2,1


In [5]:
# from mirror.encoders.maps import rename, lookup

# for col, mapping in lookup.items():
#     census[col] = census[col].map(mapping)


# census.head()

In [12]:
census = census.rename(columns=rename)

controls = ["sex", "age_group", "region"]
census_controls = census[controls]
target_census = census.drop(columns=controls)

controls_encoder = CensusEncoder(census_controls)
controls_encoder.names()

['sex', 'age_group', 'region']

In [14]:
controls_encoder.data_types

{'sex': 'categorical', 'age_group': 'categorical', 'region': 'categorical'}

In [10]:
target_encoder = CensusEncoder(target_census)
target_encoder.names()

['social',
 'country_of_birth',
 'employment_status',
 'ethnicity',
 'health',
 'household_type',
 'hours_worked',
 'full_time_student',
 'industry',
 'inner/outer_london',
 'marital_status',
 'occupaion',
 'religion',
 'residence_type',
 'residency_type']

In [7]:
dataloader = census_encoder.encode(data=census)

In [None]:
conditional_block = ConditionalBlock(
    encoder_types=controls_encoder.types(),
    encoder_sizes=controls_encoder.sizes(),
    depth=2,
    hidden_size=64,
)

encoder = EncoderBlock(
    encodings=census_encoder.encodings(),
    embed_size=32,
    hidden_n=2,
    hidden_size=64,
    latent_size=8,
)
decoder = DecoderBlock(
    encodings=census_encoder.encodings(),
    embed_size=32,
    hidden_n=2,
    hidden_size=64,
    latent_size=8,
)
vae = VAE(
    names=census_encoder.names(),
    encodings=census_encoder.encodings(),
    encoder=encoder,
    decoder=decoder,
    beta=0.001,
    lr=0.001,
)

In [None]:
trainer = Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=vae, train_dataloaders=dataloader)