## Polars home page (https://pola.rs/) 
## The User Guide home page (https://docs.pola.rs/)

In [12]:
import polars as pl
import numpy as np

## DataFrame Initialization

Many more details in documentation: [here](https://docs.pola.rs/api/python/stable/reference/dataframe/index.html).

Some simple idioms:

In [13]:
pop_size = 3

# Create a dataframe via a dictionary
population = pl.DataFrame(
                        {
                        "id": range(pop_size),
                        "state": ["Susceptible"] * pop_size,
                        }
                        )
print(population)

# Create the same dataframe via list of pl.Series
population2 = [
                pl.Series("id", range(pop_size), dtype=pl.Int64),
                pl.Series("state",  ["Susceptible"] * pop_size),  
              ]
population2 = pl.DataFrame(population2)
print(population2)

shape: (3, 2)
┌─────┬─────────────┐
│ id  ┆ state       │
│ --- ┆ ---         │
│ i64 ┆ str         │
╞═════╪═════════════╡
│ 0   ┆ Susceptible │
│ 1   ┆ Susceptible │
│ 2   ┆ Susceptible │
└─────┴─────────────┘
shape: (3, 2)
┌─────┬─────────────┐
│ id  ┆ state       │
│ --- ┆ ---         │
│ i64 ┆ str         │
╞═════╪═════════════╡
│ 0   ┆ Susceptible │
│ 1   ┆ Susceptible │
│ 2   ┆ Susceptible │
└─────┴─────────────┘


### with_columns()
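`with_columns()` adds new columns, or overwrites existing ones when a new column is given the same name as an existing column. It returns a new DataFrame rather than modifying the original, which is why the result is assigned back to `population` below.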

See [documentation link](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.with_columns.html).

In [14]:
# add a new column using with_columns
population = population.with_columns(age = pl.lit(30))
print(population)

shape: (3, 3)
┌─────┬─────────────┬─────┐
│ id  ┆ state       ┆ age │
│ --- ┆ ---         ┆ --- │
│ i64 ┆ str         ┆ i32 │
╞═════╪═════════════╪═════╡
│ 0   ┆ Susceptible ┆ 30  │
│ 1   ┆ Susceptible ┆ 30  │
│ 2   ┆ Susceptible ┆ 30  │
└─────┴─────────────┴─────┘


### What if we want to update some rows based on a condition?

pl.when().then().otherwise(), inside a with_columns()

Documentation: [here](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.when.html).

In [15]:
# change status of individual with id 2 to "Infectious"

population = population.with_columns(
        (pl.when(pl.col("id") == 2)
        .then(pl.lit("Infectious"))
        .otherwise(pl.col("state"))
        .alias("state"))
    )
print(population)

shape: (3, 3)
┌─────┬─────────────┬─────┐
│ id  ┆ state       ┆ age │
│ --- ┆ ---         ┆ --- │
│ i64 ┆ str         ┆ i32 │
╞═════╪═════════════╪═════╡
│ 0   ┆ Susceptible ┆ 30  │
│ 1   ┆ Susceptible ┆ 30  │
│ 2   ┆ Infectious  ┆ 30  │
└─────┴─────────────┴─────┘


## Writing/reading a dataframe to/from a file

Many formats are available, including but by no means limited to CSV.

See [here](https://docs.pola.rs/api/python/stable/reference/io.html).

In [16]:
# Save a dataframe to a csv file
population.write_csv("population.csv")
# Read a (different!) dataframe from a csv file, for later use
my_population = pl.read_csv("example_population.csv")
print(my_population)

shape: (5, 7)
┌─────┬─────────────┬─────┬───────────┬─────────┬─────────────────┬────────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ age_day ┆ stime_infection ┆ stime_recovery │
│ --- ┆ ---         ┆ --- ┆ ---       ┆ ---     ┆ ---             ┆ ---            │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆ i64     ┆ i64             ┆ i64            │
╞═════╪═════════════╪═════╪═══════════╪═════════╪═════════════════╪════════════════╡
│ 0   ┆ Susceptible ┆ 15  ┆ 1         ┆ 364     ┆ -1              ┆ -1             │
│ 1   ┆ Susceptible ┆ 20  ┆ 2         ┆ 21      ┆ -1              ┆ -1             │
│ 2   ┆ Infectious  ┆ 61  ┆ 6         ┆ 211     ┆ 12              ┆ 20             │
│ 3   ┆ Susceptible ┆ 98  ┆ 9         ┆ 301     ┆ -1              ┆ -1             │
│ 4   ┆ Susceptible ┆ 100 ┆ 10        ┆ 2       ┆ -1              ┆ -1             │
└─────┴─────────────┴─────┴───────────┴─────────┴─────────────────┴────────────────┘


## DataFrame Updating

In [17]:
# Update the age structure after a time step of one day

t_step = 1 # day
a_year = 365 # days
tstep_per_year = a_year / t_step
years_age_group = 10 # 10-year age band in each age group

my_population = my_population.with_columns(
                # overwrite the age column, adding 1 if today is the individual's birthday and 0 otherwise
                pl.col('age').add((pl.col('age_day') + t_step) // a_year).alias('age'), 
                # overwrite age_day (days since the last birthday), wrapping back to 0 on a birthday
                ((pl.col('age_day') + t_step) % a_year).alias('age_day'),
                # recompute the age group from the updated age; expressions in one with_columns call
                # all see the original columns, so the birthday increment is repeated here
                ((pl.col('age') + (pl.col('age_day') + t_step) // a_year) // years_age_group).alias('age_group'))

print(my_population)

shape: (5, 7)
┌─────┬─────────────┬─────┬───────────┬─────────┬─────────────────┬────────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ age_day ┆ stime_infection ┆ stime_recovery │
│ --- ┆ ---         ┆ --- ┆ ---       ┆ ---     ┆ ---             ┆ ---            │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆ i64     ┆ i64             ┆ i64            │
╞═════╪═════════════╪═════╪═══════════╪═════════╪═════════════════╪════════════════╡
│ 0   ┆ Susceptible ┆ 16  ┆ 1         ┆ 0       ┆ -1              ┆ -1             │
│ 1   ┆ Susceptible ┆ 20  ┆ 2         ┆ 22      ┆ -1              ┆ -1             │
│ 2   ┆ Infectious  ┆ 61  ┆ 6         ┆ 212     ┆ 12              ┆ 20             │
│ 3   ┆ Susceptible ┆ 98  ┆ 9         ┆ 302     ┆ -1              ┆ -1             │
│ 4   ┆ Susceptible ┆ 100 ┆ 10        ┆ 3       ┆ -1              ┆ -1             │
└─────┴─────────────┴─────┴───────────┴─────────┴─────────────────┴────────────────┘


## What if we want to implement deaths?

One way is to just calculate the daily effective death rate by age, and give everyone a chance to die every timestep.
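
The conversion used in the next cell can be checked by hand. A yearly death probability `p` corresponds to a per-step probability `1 - (1 - p) ** (1 / steps_per_year)`, chosen so that surviving every step of the year has the same probability as surviving the year. A quick sanity check, using the age-0 rate from the table below as the example value:

```python
yearly_rate = 0.00322      # example yearly death probability
steps_per_year = 365       # one step per day

# Per-step probability that compounds back to the yearly probability
daily_rate = 1 - (1 - yearly_rate) ** (1 / steps_per_year)

# Surviving 365 consecutive days should equal surviving one year
survive_year = (1 - daily_rate) ** steps_per_year

print(daily_rate)
print(survive_year, 1 - yearly_rate)
```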

In [18]:
yearly_death_rate = pl.read_csv("death_rates.csv")

# take yearly death rates and calculate adjusted daily death rates
annual_factor = 1.0 / tstep_per_year
adj_death_rate = yearly_death_rate.with_columns(
    (1 - (1 - pl.col("death_rate")) ** (annual_factor)).alias("adj_death_rate")
)
print(adj_death_rate.head())

# let's join this death rate df to the population df (we'll talk more about joins later)
df_with_rates = my_population.join(adj_death_rate, on='age')
print(df_with_rates)


shape: (5, 3)
┌─────┬────────────┬────────────────┐
│ age ┆ death_rate ┆ adj_death_rate │
│ --- ┆ ---        ┆ ---            │
│ i64 ┆ f64        ┆ f64            │
╞═════╪════════════╪════════════════╡
│ 0   ┆ 0.00322    ┆ 0.000009       │
│ 1   ┆ 0.00023    ┆ 6.3021e-7      │
│ 2   ┆ 0.000125   ┆ 3.4249e-7      │
│ 3   ┆ 0.000105   ┆ 2.8769e-7      │
│ 4   ┆ 0.00009    ┆ 2.4659e-7      │
└─────┴────────────┴────────────────┘
shape: (5, 9)
┌─────┬─────────────┬─────┬───────────┬───┬──────────────┬──────────────┬────────────┬─────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ … ┆ stime_infect ┆ stime_recove ┆ death_rate ┆ adj_death_r │
│ --- ┆ ---         ┆ --- ┆ ---       ┆   ┆ ion          ┆ ry           ┆ ---        ┆ ate         │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆   ┆ ---          ┆ ---          ┆ f64        ┆ ---         │
│     ┆             ┆     ┆           ┆   ┆ i64          ┆ i64          ┆            ┆ f64         │
╞═════╪═════════════╪═════╪═══════════╪═══╪══════

In [19]:
notrandom_seed = 115
rng = np.random.default_rng(notrandom_seed)

# Let's go forward a day and see who survives

# Two choices: i) drop those who die from the DataFrame, or ii) change the state of those who die:
# i)
df_with_dropped = df_with_rates.filter(pl.col("adj_death_rate") < rng.random(df_with_rates.height))
    
print(df_with_dropped)

# ii)
rng = np.random.default_rng(notrandom_seed)
df_with_state_changed = df_with_rates.with_columns(
    pl.when(pl.col('adj_death_rate') > rng.random(df_with_rates.height))
    .then(pl.lit("Dead"))
    .otherwise(pl.col("state")).alias("state"))
print(df_with_state_changed)

shape: (4, 9)
┌─────┬─────────────┬─────┬───────────┬───┬──────────────┬──────────────┬────────────┬─────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ … ┆ stime_infect ┆ stime_recove ┆ death_rate ┆ adj_death_r │
│ --- ┆ ---         ┆ --- ┆ ---       ┆   ┆ ion          ┆ ry           ┆ ---        ┆ ate         │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆   ┆ ---          ┆ ---          ┆ f64        ┆ ---         │
│     ┆             ┆     ┆           ┆   ┆ i64          ┆ i64          ┆            ┆ f64         │
╞═════╪═════════════╪═════╪═══════════╪═══╪══════════════╪══════════════╪════════════╪═════════════╡
│ 0   ┆ Susceptible ┆ 16  ┆ 1         ┆ … ┆ -1           ┆ -1           ┆ 0.00025    ┆ 6.8502e-7   │
│ 1   ┆ Susceptible ┆ 20  ┆ 2         ┆ … ┆ -1           ┆ -1           ┆ 0.00041    ┆ 0.000001    │
│ 2   ┆ Infectious  ┆ 61  ┆ 6         ┆ … ┆ 12           ┆ 20           ┆ 0.00548    ┆ 0.000015    │
│ 3   ┆ Susceptible ┆ 98  ┆ 9         ┆ … ┆ -1           ┆ -1           ┆ 0.2

## Above "works", but is not very efficient, even though it's in polars!

We could instead compute the "age at death" of each agent at the start of the simulation, rather than generating random numbers for each person every day to check.

Much more efficient!

In [20]:
from additional_functions import pick_ages_at_death_given_ages

rng = np.random.default_rng(notrandom_seed)
ages_at_death = pick_ages_at_death_given_ages('death_rates.csv', my_population['age'].to_numpy(), rng)
print(ages_at_death)

[ 91  89  88 100 100]
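
The source of `pick_ages_at_death_given_ages` is not shown here, so the following is only a guess at the idea, not the actual implementation. The hypothetical helper `sample_ages_at_death` takes a list of yearly death probabilities indexed by age (an assumed input format) and, for each agent, steps forward one year at a time from the current age until the agent "dies" or reaches the last age in the table.

```python
import numpy as np

def sample_ages_at_death(yearly_death_prob, current_ages, rng):
    """Hypothetical helper: draw one age at death per agent, given current ages."""
    yearly_death_prob = np.asarray(yearly_death_prob, dtype=float)
    max_age = len(yearly_death_prob) - 1
    ages_at_death = np.empty(len(current_ages), dtype=np.int64)
    for i, age in enumerate(current_ages):
        a = min(int(age), max_age)
        # Each year the agent dies with probability yearly_death_prob[a];
        # otherwise they survive to age a + 1. Everyone dies by max_age.
        while a < max_age and rng.random() >= yearly_death_prob[a]:
            a += 1
        ages_at_death[i] = a
    return ages_at_death

rng = np.random.default_rng(115)
toy_rates = [0.01] * 101    # made-up flat 1% yearly death probability for ages 0..100
ages_at_death = sample_ages_at_death(toy_rates, [16, 20, 61], rng)
print(ages_at_death)
```

The random draws happen once per agent at the start, so the daily update only needs to compare each agent's age with their stored age at death.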


# Let's implement some disease spread using polars functionality!

In [24]:
#checking new infections

    
print(my_population)

#assume contacts are homogeneous; 
prob_transmission_per_contact = 0.1
avg_contacts_per_day = 2
current_time = 20
avg_inf_duration = 8 # days

#only consider susceptibles
susceptibles = my_population.filter(pl.col("state") == "Susceptible")

notrandom_seed = 11
rng = np.random.default_rng(notrandom_seed)

if susceptibles.height > 0:
    #at least one susceptible ind in the population

    #calculate infected fraction in contacts of a susceptible individual
    frac_infected_in_contacts = my_population.filter(
    pl.col("state") == "Infectious").height / (my_population.height - 1)
    print("frac_infected_in_contacts: %s" %frac_infected_in_contacts)

    #calculate prob_transmission for a susceptible individual
    prob_transmission = prob_transmission_per_contact * avg_contacts_per_day * frac_infected_in_contacts 
    #identify individuals who will be infected
    will_infected = susceptibles.filter(
        rng.random(susceptibles.height) <= prob_transmission
    )

    dur_infections = rng.exponential(avg_inf_duration, will_infected.height)
    new_infected = will_infected.with_columns(
        pl.lit("Infectious").alias("state"),
        pl.lit(current_time).alias("stime_infection"),
        (pl.lit(current_time + dur_infections)).alias("stime_recovery"),
        )
    print("new infected: %s"%new_infected)

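    if new_infected.height > 0:
        #only update the rows of the newly infected population 
        df_with_new_inf = my_population.update(new_infected, on="id", how="left")
    else:
        #nobody was infected in this time step, so the population is unchanged
        df_with_new_inf = my_population
else:
    #no susceptibles left, so there is nothing to update
    df_with_new_inf = my_population

print("updated population %s"%df_with_new_inf)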

shape: (5, 7)
┌─────┬─────────────┬─────┬───────────┬─────────┬─────────────────┬────────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ age_day ┆ stime_infection ┆ stime_recovery │
│ --- ┆ ---         ┆ --- ┆ ---       ┆ ---     ┆ ---             ┆ ---            │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆ i64     ┆ i64             ┆ i64            │
╞═════╪═════════════╪═════╪═══════════╪═════════╪═════════════════╪════════════════╡
│ 0   ┆ Susceptible ┆ 16  ┆ 1         ┆ 0       ┆ -1              ┆ -1             │
│ 1   ┆ Susceptible ┆ 20  ┆ 2         ┆ 22      ┆ -1              ┆ -1             │
│ 2   ┆ Infectious  ┆ 61  ┆ 6         ┆ 212     ┆ 12              ┆ 20             │
│ 3   ┆ Susceptible ┆ 98  ┆ 9         ┆ 302     ┆ -1              ┆ -1             │
│ 4   ┆ Susceptible ┆ 100 ┆ 10        ┆ 3       ┆ -1              ┆ -1             │
└─────┴─────────────┴─────┴───────────┴─────────┴─────────────────┴────────────────┘
frac_infected_in_contacts: 0.25
new infected: shape

## Your turn!

Let's implement a recovery "update" step that checks whether anyone moves from the Infectious state back to the Susceptible state in a time step.

In [25]:
# 20 days later...
current_time = 20

# Insert some code here!

# Collect summary statistics

In [26]:
results = pl.read_csv("example_output.csv")
print(results)
pop_size = results.filter((pl.col("run_no") == pl.col("run_no").min()) &
                          ((pl.col("t") == pl.col("t").min())))["count"].sum()
print(pop_size)

shape: (4_136, 4)
┌─────────────┬───────┬─────┬────────┐
│ state       ┆ count ┆ t   ┆ run_no │
│ ---         ┆ ---   ┆ --- ┆ ---    │
│ str         ┆ i64   ┆ i64 ┆ i64    │
╞═════════════╪═══════╪═════╪════════╡
│ Susceptible ┆ 499   ┆ 0   ┆ 0      │
│ Exposed     ┆ 0     ┆ 0   ┆ 0      │
│ Infectious  ┆ 1     ┆ 0   ┆ 0      │
│ Recovered   ┆ 0     ┆ 0   ┆ 0      │
│ Susceptible ┆ 499   ┆ 1   ┆ 0      │
│ …           ┆ …     ┆ …   ┆ …      │
│ Recovered   ┆ 278   ┆ 86  ┆ 19     │
│ Susceptible ┆ 221   ┆ 87  ┆ 19     │
│ Exposed     ┆ 0     ┆ 87  ┆ 19     │
│ Infectious  ┆ 0     ┆ 87  ┆ 19     │
│ Recovered   ┆ 279   ┆ 87  ┆ 19     │
└─────────────┴───────┴─────┴────────┘
500


In [59]:
#final size calculation given initial recovered == 0
recovered = results.filter(pl.col("state") == "Recovered"
                          ).with_columns(pl.col("count")/pop_size)

final_sizes = recovered.group_by("run_no").agg(pl.col("count").last()
                                              ).rename({"count": "fsize_fraction"})

#.filter(pl.col("t").last())
sum_final_size = final_sizes.select(
                        pl.col("fsize_fraction").mean().name.suffix("_mean"),
                        pl.col("fsize_fraction").quantile(0.025).name.suffix("_025"),
                        pl.col("fsize_fraction").quantile(0.975).name.suffix("_975"),
                    )
display(sum_final_size)

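shape: (1, 3)
┌─────────────────────┬────────────────────┬────────────────────┐
│ fsize_fraction_mean ┆ fsize_fraction_025 ┆ fsize_fraction_975 │
│ ---                 ┆ ---                ┆ ---                │
│ f64                 ┆ f64                ┆ f64                │
╞═════════════════════╪════════════════════╪════════════════════╡
│ 0.305               ┆ 0.002              ┆ 0.692              │
└─────────────────────┴────────────────────┴────────────────────┘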



# Additional examples and functions

In [None]:
# Adding new rows by combining two DataFrames
# See https://stackoverflow.com/questions/71654966/how-can-i-append-or-concatenate-two-dataframes-in-python-polars
pop_size = 2
population = pl.DataFrame(
                        {
                        "id": range(pop_size),
                        "state": ["Susceptible"] * pop_size,
                        }
                        )

population2 = pl.DataFrame([
    pl.Series("id", [2], dtype=pl.Int64),
    pl.Series("state",  ["Infectious"]),
])
print(population2)

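# concat with rechunk=True copies the data into one new contiguous block of memory
new_population = pl.concat([population, population2], rechunk=True)

# vstack returns a new dataframe that appends the rows without copying the data
new_population2 = population.vstack(population2)

# extend appends in place: population itself is modified
population.extend(population2)
print(population, new_population, new_population2)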

In [None]:
#rename(): rename columns (the mapping is {old_name: new_name})
population2 = population.rename({"id":"unique_id"})
print(population2)
#CAREFUL: using alias to rename changes the order of the columns
population1 = population.with_columns(pl.col("id").alias("unique_id")).drop("id")
print(population1)

## need to combine different dataframes on certain axes? Check join()
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html

## need sorting? Check sort()
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.sort.html

## need the number of rows of a dataframe? Check height
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.height.html

## need to check data types? Check schema
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.schema.html

## need to introduce columns with list or dict structures? Check struct()
https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.struct.html


# Still not fast enough?

Lazy dataframes ([documentation link](https://docs.pola.rs/api/python/stable/reference/lazyframe/index.html)) can help, because Polars optimizes the whole query before running it, but they add some development overhead, since you can't "look" at intermediate results and debug as easily.