In [1]:
from datetime import UTC, datetime

import numpy as np
import polars as pl
from pydantic import BaseModel
from retention_data import CohortDataGenerator

In [2]:
# Reproducible seed derived from a word: the sum of the code points of "retention".
seed: int = sum(ord(char) for char in "retention")
# Single generator shared by every simulation step below, so one seed fixes all draws.
rng: np.random.Generator = np.random.default_rng(seed)

In [3]:
class Market(BaseModel):
    """Specification of a single simulated market.

    A pydantic model: fields are validated on construction and must be passed
    as keyword arguments.
    """

    name: str  # label copied into the ``market`` column of the generated data
    start_date: datetime  # date of the market's first cohort
    n_cohorts: int  # number of consecutive cohorts to simulate for this market
    # Size of the market's user population; it only affects the simulation if the
    # consumer forwards it to CohortDataGenerator.
    user_base: int = 10_000


class MarketDataGenerator:
    """Generate synthetic cohort retention data for several markets.

    Each market is simulated with its own ``CohortDataGenerator``. The
    per-market frames are then stacked into one long-format frame with an extra
    ``market`` column identifying the source market.
    """

    def __init__(self, markets: list[Market], rng: np.random.Generator):
        """
        Parameters
        ----------
        markets : list[Market]
            Market specifications to simulate, in output order.
        rng : np.random.Generator
            Random generator handed to every market's cohort generator in turn.
            Because the draws are presumably sequential, reordering ``markets``
            changes the simulated values.
        """
        self.markets = markets
        self.rng = rng

    def run(self) -> pl.DataFrame:
        """Simulate every market and concatenate the results.

        Returns
        -------
        pl.DataFrame
            Rows from all markets stacked vertically. Two columns are added to
            the cohort generator's output: ``retention`` (``n_active_users /
            n_users``) and ``market`` (the market's ``name``).

        Raises
        ------
        ValueError
            If ``markets`` is empty (there is nothing to concatenate).
        """
        if not self.markets:
            raise ValueError("MarketDataGenerator needs at least one market")

        data_dfs: list[pl.DataFrame] = []
        for market in self.markets:
            cohort_generator = CohortDataGenerator(
                rng=self.rng,
                start_cohort=market.start_date,
                n_cohorts=market.n_cohorts,
                # Forward the market's size so that markets with different user
                # bases produce differently sized cohorts.
                user_base=market.user_base,
            )
            data_df = cohort_generator.run().with_columns(
                # Observed retention: share of the cohort still active in the period.
                (pl.col("n_active_users") / pl.col("n_users")).alias("retention"),
                pl.lit(market.name).alias("market"),
            )
            data_dfs.append(data_df)
        return pl.concat(data_dfs)


# One spec per market: (name, first cohort date, number of cohorts, user base).
# Assuming monthly cohorts, every market's last cohort falls in 2023-12.
markets = [
    Market(name=name, start_date=start, n_cohorts=n_cohorts, user_base=user_base)
    for name, start, n_cohorts, user_base in (
        ("a", datetime(2020, 1, 1, tzinfo=UTC), 48, 10_000),
        ("b", datetime(2020, 6, 1, tzinfo=UTC), 43, 3_000),
        ("c", datetime(2022, 1, 1, tzinfo=UTC), 24, 1_000),
    )
]

# All markets share the notebook-level `rng`, so the whole dataset is
# reproducible from the single seed defined in the configuration cell.
market_data_generator = MarketDataGenerator(markets=markets, rng=rng)
data_df = market_data_generator.run()

data_df.head()

cohort,n_users,period,age,cohort_age,retention_true_mu,retention_true,n_active_users,revenue,retention,market
date,i64,date,i64,i64,f64,f64,i64,f64,f64,str
2020-01-01,150,2020-01-01,1430,0,-1.807373,0.140956,150,14019.256906,1.0,"""a"""
2020-01-01,150,2020-02-01,1430,31,-1.474736,0.186224,25,1886.501237,0.166667,"""a"""
2020-01-01,150,2020-03-01,1430,60,-2.281286,0.092685,13,1098.136314,0.086667,"""a"""
2020-01-01,150,2020-04-01,1430,91,-3.20661,0.038918,6,477.852458,0.04,"""a"""
2020-01-01,150,2020-05-01,1430,121,-3.112983,0.042575,2,214.667937,0.013333,"""a"""
