In [180]:
from pathlib import Path
import random
import polars as pl
import yaml
from typing import Iterable

In [181]:
# Seed the global RNG so the random sampling done by the helpers below is
# reproducible across "Restart & Run All".
SEED = 42
random.seed(SEED)

# Directory holding the survey CSV extracts. pathlib does not expand "~" on
# its own, so expand it here instead of relying on each consumer to do so.
root = Path("~/Data/foundata/LTDS2425").expanduser()

In [182]:
def euro_sampler(bounds: tuple[int, int]) -> int:
    """Draw a random integer between the two bounds (inclusive) and scale it by 0.9.

    Both bounds are coerced with ``int()`` before sampling, and the scaled
    value is truncated back to an integer.
    NOTE(review): the 0.9 factor is presumably a currency conversion to
    euros (going by the function name) - confirm the intended rate.
    """
    lower, upper = bounds
    draw = random.randint(int(lower), int(upper))
    return int(draw * 0.9)


def sampler(bounds: tuple[int, int]) -> int:
    """Draw a random integer between the two bounds (inclusive).

    Each bound is coerced with ``int()`` first, so numeric strings such as
    ``"65"`` are accepted.
    """
    lower, upper = (int(edge) for edge in bounds)
    return random.randint(lower, upper)


def default(config, year):
    """Return the entry of ``config`` for ``year``, falling back to ``"default"``.

    The ``"default"`` entry is only looked up when ``year`` is absent, so a
    config that defines the requested year but has no ``"default"`` key still
    works.

    Raises:
        KeyError: if neither ``year`` nor ``"default"`` is present.
    """
    if year in config:
        return config[year]
    return config["default"]

In [183]:
def load_mapping(path: Path, k_name: str, v_name: str) -> dict:
    """Read a CSV file and build a lookup dict from two of its columns.

    Args:
        path: location of the CSV file.
        k_name: name of the column supplying the dict keys.
        v_name: name of the column supplying the dict values.

    Returns:
        A dict pairing each ``k_name`` value with the ``v_name`` value from
        the same row. If a key occurs more than once the last row wins.
    """
    table = pl.read_csv(path)
    keys = table[k_name]
    values = table[v_name]
    return {key: value for key, value in zip(keys, values)}

In [184]:
def preprocess_hhs(hhs, config: dict, year: str):
    """Select, rename and recode the household table for one survey year.

    Args:
        hhs: raw household frame containing at least the columns named as
            keys of the ``column_mappings`` entry in ``config``.
        config: household dictionary; the ``column_mappings``, ``hh_income``
            and ``hh_structure`` sections are read, each resolved for ``year``
            via ``default``.
        year: survey-year key used to pick the config entries.

    Returns:
        The recoded frame with rows containing any null removed. The row
        count is printed before and after that removal.
    """
    renames = default(config["column_mappings"], year)
    income_bands = default(config["hh_income"], year)
    structure_labels = default(config["hh_structure"], year)
    # HABORO code -> TYPE lookup, read from the module-level ``root`` directory.
    zone_types = load_mapping(root / "HABORO_T.csv", "HABORO", "TYPE")

    hhs = (
        hhs.select(renames.keys())
        .rename(renames)
        # year: add 2000 to the stored value
        .with_columns(pl.col("year") + 2000)
        # income: code -> bounds pair (unmapped codes become (0, 0)), then
        # draw one value per row with euro_sampler
        .with_columns(
            pl.col("hh_income")
            .replace_strict(income_bands, default=pl.lit((0, 0)), return_dtype=pl.List)
            .map_elements(euro_sampler, return_dtype=pl.Float64)
        )
        # structure: code -> label; nulls become "unknown"
        .with_columns(
            pl.col("hh_structure").replace_strict(structure_labels).fill_null("unknown")
        )
        # urban/rural: zone -> type, keeping the zone value itself when it is
        # not in the lookup; the source "zone" column is then dropped
        .with_columns(
            pl.col("zone")
            .replace_strict(zone_types, default=pl.col("zone"))
            .fill_null("unknown")
            .alias("urban/rural")
        )
        .drop("zone")
    )

    # Report how many rows are lost to the null filter.
    print(len(hhs))
    hhs = hhs.drop_nulls()
    print(len(hhs))

    return hhs


# Load the household dictionary; the context manager closes the file handle
# once the YAML has been parsed.
with open("ltds/hh_dictionary.yaml") as stream:
    config = yaml.safe_load(stream)

# Only read the source columns that the column mapping refers to.
columns = list(default(config["column_mappings"], year="2425").keys())
hhs = pl.read_csv(root / "Household.csv", columns=columns)

hhs = preprocess_hhs(hhs, config, year="2425")

hhs.head()

8208
8208


hid,year,day,hh_structure,hh_size,hh_income,num_vehicles,num_bikes,weight,urban/rural
i64,i64,i64,str,i64,f64,i64,i64,f64,str
24009181,2024,6,"""lone parent""",4,61444.0,1,0,251.969105,"""suburban"""
24010021,2024,2,"""couple""",6,152933.0,1,3,481.303435,"""suburban"""
24010041,2024,2,"""couple""",4,139837.0,1,0,481.303435,"""suburban"""
24010051,2024,7,"""lone parent""",3,122488.0,0,1,348.118159,"""suburban"""
24010091,2024,2,"""single adult""",1,61312.0,1,0,810.201572,"""suburban"""


In [185]:
def preprocess_persons(persons, config: dict, year: str):
    """Select, rename and recode the person table for one survey year.

    Args:
        persons: raw person frame containing at least the columns named as
            keys of the ``column_mappings`` entry in ``config``.
        config: person dictionary; the ``column_mappings``, ``sex``,
            ``relationship`` and ``race`` sections are read, each resolved
            for ``year`` via ``default``.
        year: survey-year key used to pick the config entries.

    Returns:
        The recoded person frame.
    """
    renames = default(config["column_mappings"], year)
    sex_labels = default(config["sex"], year)
    relationship_labels = default(config["relationship"], year)
    race_labels = default(config["race"], year)

    return (
        persons.select(renames.keys())
        .rename(renames)
        # age: the open-ended "65+" band is treated as "65-100"; every band
        # "lo-hi" is split on "-" and one value is drawn with sampler
        .with_columns(
            pl.col("age")
            .replace_strict({"65+": "65-100"}, default=pl.col("age"))
            .str.split("-")
            .map_elements(sampler, return_dtype=pl.Float64)
        )
        # sex: code -> label
        .with_columns(pl.col("sex").replace_strict(sex_labels))
        # relationship: code -> label
        .with_columns(pl.col("relationship").replace_strict(relationship_labels))
        # race: code -> label, keeping unmapped values as they are; nulls
        # become "unknown"
        .with_columns(
            pl.col("race")
            .replace_strict(race_labels, default=pl.col("race"))
            .fill_null("unknown")
        )
    )


year = "2425"

# Load the person dictionary; the context manager closes the file handle
# once the YAML has been parsed.
with open("ltds/person_dictionary.yaml") as stream:
    config = yaml.safe_load(stream)

# Only read the source columns that the column mapping refers to.
columns = list(default(config["column_mappings"], year).keys())

persons = pl.read_csv(root / "person.csv", columns=columns)

persons = preprocess_persons(persons, config, year=year)

persons.head()

pid,hid,age,sex,relationship,race
i64,i64,f64,str,str,str
2412723104,24127231,3.0,"""male""","""child""","""unknown"""
2405718303,24057183,2.0,"""male""","""child""","""unknown"""
2426119103,24261191,1.0,"""female""","""child""","""unknown"""
2437901111,24379011,3.0,"""female""","""child""","""unknown"""
2439211107,24392111,0.0,"""female""","""other""","""unknown"""


In [None]:
def preprocess_persons_data(persons, config: dict, year: str):
    """Select, rename and recode the person-data table for one survey year.

    Args:
        persons: raw person-data frame containing at least the columns named
            as keys of the ``column_mappings`` entry in ``config``.
        config: person-data dictionary; the ``column_mappings``,
            ``has_licence`` and ``employment_status`` sections are read, each
            resolved for ``year`` via ``default``.
        year: survey-year key used to pick the config entries.

    Returns:
        The recoded frame, with ``no_trips`` replaced by a boolean column.
    """
    renames = default(config["column_mappings"], year)

    licence_labels = default(config["has_licence"], year)
    employment_labels = default(config["employment_status"], year)
    # TODO: the occupation recode is currently disabled; it would read
    # config["occupation"] and apply replace_strict to an "occupation" column.

    return (
        persons.select(renames.keys())
        .rename(renames)
        # has_license: code -> label; codes missing from the mapping become null
        .with_columns(
            pl.col("has_license").replace_strict(
                licence_labels, default=None, return_dtype=pl.String
            )
        )
        # employment: code -> label
        .with_columns(pl.col("employment_status").replace_strict(employment_labels))
        # Boolean flag: True where the raw "no_trips" value is negative.
        # NOTE(review): the original comment here read "has trips?" - confirm
        # what a negative code means in the source data.
        .with_columns((pl.col("no_trips") < 0))
    )


year = "2425"

# Load the person-data dictionary; the context manager closes the file handle
# once the YAML has been parsed.
with open("ltds/person_data_dictionary.yaml") as stream:
    config = yaml.safe_load(stream)

# Only read the source columns that the column mapping refers to.
columns = list(default(config["column_mappings"], year).keys())

persons_data = pl.read_csv(root / "person data.csv", columns=columns)

persons_data = preprocess_persons_data(persons_data, config, year=year)

persons_data.head()

pid,hid,has_license,employment_status,no_trips
i64,i64,str,str,bool
2400412104,24004121,"""no""","""student""",False
2400412105,24004121,"""no""","""student""",False
2400412106,24004121,"""no""","""student""",False
2400410101,24004101,"""yes""","""employed""",False
2400410102,24004101,"""yes""","""unemployed""",True


In [187]:
def sample_minute(base: int) -> int:
    """Draw a random integer from ``int(base)`` to ``int(base) + 5`` inclusive."""
    lower = int(base)
    upper = lower + 5
    return random.randint(lower, upper)


def sample_tst(row) -> "int | None":
    """Sample a trip start time that is consistent with the reported time windows.

    ``row`` is a mapping with ``"tst"``, ``"tet"`` and ``"duration"`` entries.
    The caller in ``preprocess_trips`` multiplies ``tst`` and ``tet`` by 60
    before calling, so all three values are in minutes here. The start is
    drawn uniformly from the values that keep the start within
    ``[tst, tst + 60]`` and the end (start + duration) within
    ``[tet, tet + 60]``.

    Returns:
        The sampled start. ``None`` is returned when any of the three inputs
        is missing, so such rows surface as nulls instead of raising a
        ``TypeError``. When the two windows do not overlap a warning is
        printed and ``int((tst + tet + duration) / 2)`` is returned.
    """
    window_start, window_end, duration = row["tst"], row["tet"], row["duration"]
    if window_start is None or window_end is None or duration is None:
        return None

    # earliest start: the later of the start window opening and the point
    # where the trip would end exactly at the opening of the end window
    earliest = max(window_start, window_end - duration)
    # latest start: the earlier of the start window closing and the point
    # where the trip would end exactly at the closing of the end window
    latest = min(window_start + 60, window_end + 60 - duration)
    if latest < earliest:
        print("warning: bad overlap")
        # NOTE(review): this fallback is not the midpoint of earliest/latest;
        # confirm that adding (rather than subtracting) duration is intended.
        return int((window_start + window_end + duration) / 2)
    return random.randint(earliest, latest)

In [188]:
def preprocess_trips(trips, config: dict, year: str):
    """Select, rename and recode the trip table for one survey year.

    Args:
        trips: raw trip frame containing at least the columns named as keys
            of the ``column_mappings`` entry in ``config``.
        config: trip dictionary; the ``column_mappings``, ``mode`` and
            ``act`` sections are read, each resolved for ``year`` via
            ``default``.
        year: survey-year key used to pick the config entries.

    Returns:
        The recoded trips, restricted to persons (``pid``) none of whose
        trips contain a null in any column. The row count is printed before
        and after that restriction.
    """
    column_mapping = default(config["column_mappings"], year)
    trips = trips.select(column_mapping.keys()).rename(column_mapping)
    # HABORO code -> TYPE lookup, read from the module-level ``root`` directory.
    zone_mapping = load_mapping(root / "HABORO_T.csv", "HABORO", "TYPE")

    # modes & acts: code -> label
    mode_map = default(config["mode"], year)
    act_map = default(config["act"], year)
    trips = trips.with_columns(
        pl.col("mode").replace_strict(mode_map),
        pl.col("oact").replace_strict(act_map),
        pl.col("dact").replace_strict(act_map),
    )

    # duration: draw a value in [duration, duration + 5]. The return dtype is
    # given explicitly, as in the other map_elements calls in this notebook,
    # so polars does not have to infer it from the first result.
    trips = trips.with_columns(
        pl.col("duration").map_elements(sample_minute, return_dtype=pl.Int64)
    )

    # time to minutes
    trips = trips.with_columns(
        pl.col("tst") * 60,
        pl.col("tet") * 60,
    )

    # sample a start time per trip, then derive the end as start + duration
    trips = trips.with_columns(
        pl.struct("tst", "tet", "duration")
        .map_elements(sample_tst, return_dtype=pl.Int32)
        .alias("tst")
    )
    trips = trips.with_columns((pl.col("tst") + pl.col("duration")).alias("tet"))

    # urban/rural: zone code -> type for origin and destination
    trips = trips.with_columns(
        pl.col("ozone").replace_strict(zone_mapping),
        pl.col("dzone").replace_strict(zone_mapping),
    )

    # Drop every trip of any person who has at least one trip with a null in
    # any column, so the remaining persons keep all of their trips.
    mask = pl.any_horizontal(pl.all().is_null())
    print(len(trips))
    keep = (
        trips.group_by("pid")
        .agg(mask.any().alias("flag"))
        .filter(~pl.col("flag"))
        .select("pid")
    )
    trips = trips.join(keep, on="pid")
    print(len(trips))

    return trips


year = "2425"

# Load the trip dictionary; the context manager closes the file handle once
# the YAML has been parsed.
with open("ltds/trip_dictionary.yaml") as stream:
    config = yaml.safe_load(stream)

# Only read the source columns that the column mapping refers to.
columns = list(default(config["column_mappings"], year).keys())

trips = pl.read_csv(
    root / "Trip.csv",
    columns=columns,
    # null_values="Missing",
)
trips = preprocess_trips(trips, config, year=year)

trips.head()

36170
36170


hid,pid,tid,mode,tst,tet,duration,oact,dact,ozone,dzone
i64,i64,i64,str,i32,i64,i64,str,str,str,str
24007421,2400742102,240074210207,"""car""",1096,1106,10,"""home""","""leisure""","""suburban""","""suburban"""
24007421,2400742101,240074210103,"""car""",1099,1109,10,"""home""","""leisure""","""suburban""","""suburban"""
24007421,2400742101,240074210104,"""car""",1360,1372,12,"""leisure""","""home""","""suburban""","""suburban"""
24007421,2400742102,240074210208,"""car""",1364,1374,10,"""leisure""","""home""","""suburban""","""suburban"""
24007421,2400742105,240074210503,"""car""",1107,1122,15,"""home""","""leisure""","""suburban""","""suburban"""


In [189]:
def preprocess_stages(stages, config: dict, year: str):
    """Collapse the stage table to one total distance per trip.

    Args:
        stages: raw stage frame containing at least the columns named as keys
            of the ``column_mappings`` entry in ``config``.
        config: stage dictionary; only ``column_mappings`` is read, resolved
            for ``year`` via ``default``.
        year: survey-year key used to pick the config entry.

    Returns:
        A frame with one row per (``pid``, ``tid``) group holding the summed
        ``distance``; the ``pid`` column is dropped from the result.
    """
    renames = default(config["column_mappings"], year)

    per_trip = (
        stages.select(renames.keys())
        .rename(renames)
        # total the stage distances within each person/trip pair
        .group_by(["pid", "tid"])
        .agg(pl.col("distance").sum().alias("distance"))
        # the person id is only needed for grouping
        .drop("pid")
    )

    return per_trip


year = "2425"

# Load the stage dictionary; the context manager closes the file handle once
# the YAML has been parsed.
with open("ltds/stage_dictionary.yaml") as stream:
    config = yaml.safe_load(stream)

# Only read the source columns that the column mapping refers to.
columns = list(default(config["column_mappings"], year).keys())

stages = pl.read_csv(
    root / "Stage.csv",
    columns=columns,
    # null_values="Missing",
)
stages = preprocess_stages(stages, config, year=year)

stages.head()

tid,distance
i64,f64
243981610106,24.066
240932210101,3.913
243921110302,7.567
242036310301,2.833
244581710202,17.254


In [190]:
def check_overlap(table_a: "pl.DataFrame", table_b: "pl.DataFrame", on: str) -> set:
    """Report join keys that are present in only one of the two tables.

    Args:
        table_a: first table; ``table_a[on]`` must yield its key values.
        table_b: second table; ``table_b[on]`` must yield its key values.
        on: name of the key column shared by both tables.

    Returns:
        The set of keys found in exactly one of the tables (empty when both
        hold the same keys). A warning listing the affected keys is printed
        for each side that is missing some.
    """
    on_a = set(table_a[on])
    on_b = set(table_b[on])
    missing_in_a = on_b - on_a
    missing_in_b = on_a - on_b

    # Percentages are computed only inside the branches: a non-empty
    # "missing" set guarantees a non-zero denominator, so empty key columns
    # no longer cause a ZeroDivisionError.
    if missing_in_a:
        n_a = len(missing_in_a)
        perc_a = n_a / len(on_b) * 100
        print(
            f"Warning: Missing {n_a} ({perc_a:.2f}%) of '{on}' keys in table_a: {missing_in_a}"
        )
    if missing_in_b:
        n_b = len(missing_in_b)
        perc_b = n_b / len(on_a) * 100
        print(
            f"Warning: Missing {n_b} ({perc_b:.2f}%) of '{on}' keys in table_b: {missing_in_b}"
        )

    # The two sets are disjoint by construction, so their intersection is
    # always empty; the union is the set of unmatched keys.
    return missing_in_a | missing_in_b


def table_joiner(table_a: pl.DataFrame, table_b: pl.DataFrame, on: str) -> pl.DataFrame:
    """Join two tables on ``on`` after printing diagnostics about the join.

    Warnings are printed (via ``check_overlap``) for keys present in only one
    table, and for column names other than the key that occur in both tables.

    Returns:
        ``table_a.join(table_b, on=on)``.
    """
    # report keys that will not find a partner in the other table
    check_overlap(table_a, table_b, on)

    # report non-key column names that appear on both sides
    shared = set(table_a.columns).intersection(table_b.columns)
    shared.discard(on)
    if shared:
        print(f"Warning: Duplicate columns (other than join key): {shared}")

    joined = table_a.join(table_b, on=on)
    return joined

In [191]:
# Person attributes: persons + person data (joined on pid), then the
# household columns (joined on hid). "hid" is dropped from persons_data first
# because persons already carries it.
attributes = table_joiner(
    table_joiner(persons, persons_data.drop("hid"), on="pid"),
    hhs,
    on="hid",
)
attributes.head()



pid,hid,age,sex,relationship,race,has_license,employment_status,no_trips,year,day,hh_structure,hh_size,hh_income,num_vehicles,num_bikes,weight,urban/rural
i64,i64,f64,str,str,str,str,str,bool,i64,i64,str,i64,f64,i64,i64,f64,str
2400412104,24004121,8.0,"""female""","""child""","""asian""","""no""","""student""",False,2024,4,"""couple""",6,69229.0,1,4,375.787943,"""suburban"""
2400412105,24004121,16.0,"""male""","""child""","""asian""","""no""","""student""",False,2024,4,"""couple""",6,69229.0,1,4,375.787943,"""suburban"""
2400412106,24004121,12.0,"""female""","""child""","""asian""","""no""","""student""",False,2024,4,"""couple""",6,69229.0,1,4,375.787943,"""suburban"""
2400410101,24004101,29.0,"""male""","""head""","""white""","""yes""","""employed""",False,2024,5,"""couple""",3,18684.0,1,0,289.160438,"""suburban"""
2400410102,24004101,42.0,"""female""","""spouse/partner""","""white""","""yes""","""unemployed""",True,2024,5,"""couple""",3,18684.0,1,0,289.160438,"""suburban"""


In [192]:
# Attach the per-trip distance from stages.
# NOTE: this rebinds `trips` from itself, so re-running the cell without
# re-running the trip-loading cell joins stages onto an already-joined frame.
trips = table_joiner(table_a=trips, table_b=stages, on="tid")



In [193]:
# Compare the person ids present in the attributes and trips tables.
check_overlap(table_a=attributes, table_b=trips, on="pid")



set()

In [194]:
# Spot check: all trips recorded for one person.
trips.filter(pl.col("pid").eq(2419306112))

hid,pid,tid,mode,tst,tet,duration,oact,dact,ozone,dzone,distance
i64,i64,i64,str,i32,i64,i64,str,str,str,str,f64
24193061,2419306112,241930611201,"""walk""",1223,1240,17,"""home""","""other""","""urban""","""urban""",0.677
24193061,2419306112,241930611202,"""walk""",1361,1380,19,"""other""","""home""","""urban""","""urban""",0.677


In [199]:
# Spot check: the attribute row for one person.
attributes.filter(pl.col("pid").eq(2406318103))

pid,hid,age,sex,relationship,race,has_license,employment_status,no_trips,year,day,hh_structure,hh_size,hh_income,num_vehicles,num_bikes,weight,urban/rural
i64,i64,f64,str,str,str,str,str,bool,i64,i64,str,i64,f64,i64,i64,f64,str
2406318103,24063181,45.0,"""female""","""parent""","""asian""","""no""","""unemployed""",True,2024,7,"""other""",4,85292.0,0,0,435.063613,"""suburban"""
