## Polars home page (https://pola.rs/) 
## The User Guide home page (https://docs.pola.rs/)

In [16]:
###TODOS
#Assume that ABM , Why python + polars?

#jump to workshop a dataframe 

#create a powerpoint slides

#design update question

#Sep 16
#1. Julian: Hands on -> update checking death in optimized way
#2. Nefel (DONE) + Julian: (N):introduce a simple example of disease state update + (J) hands on example of updating status of individuals (inf -> sus)
#3. (DONE) Nefel: rather than removing individuals, add a "Dead" status example
#4. Nefel + Julian: (N):introduce a simple example (J) load csv -> collect -> group by agg and select(), pivot tables
#5. Julian: Ask Rob about the best way to run an interactive workshop, e.g. how to use github.dev?


In [4]:
import polars as pl
import numpy as np

## Data structures in Polars

## Model Initialization

In [18]:
# Build a small population in which everyone starts out susceptible
pop_size = 3

# Option 1: a dict mapping column names to column values
population = pl.DataFrame(
    {
        "id": range(pop_size),
        "state": ["Susceptible"] * pop_size,
    }
)
print(population)

# Option 2: the same dataframe assembled from a list of named pl.Series
population2 = pl.DataFrame(
    [
        pl.Series("id", range(pop_size), dtype=pl.Int64),
        pl.Series("state", ["Susceptible"] * pop_size),
    ]
)
print(population2)


shape: (3, 2)
┌─────┬─────────────┐
│ id  ┆ state       │
│ --- ┆ ---         │
│ i64 ┆ str         │
╞═════╪═════════════╡
│ 0   ┆ Susceptible │
│ 1   ┆ Susceptible │
│ 2   ┆ Susceptible │
└─────┴─────────────┘
shape: (3, 2)
┌─────┬─────────────┐
│ id  ┆ state       │
│ --- ┆ ---         │
│ i64 ┆ str         │
╞═════╪═════════════╡
│ 0   ┆ Susceptible │
│ 1   ┆ Susceptible │
│ 2   ┆ Susceptible │
└─────┴─────────────┘


### Polars: with_columns()
A very useful method that adds new columns or replaces existing ones (a column whose name already exists is overwritten in place).

See [documentation link](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.with_columns.html).

In [19]:
# add a new column using with_columns
population = population.with_columns(age = pl.lit(30))
print(population)

shape: (3, 3)
┌─────┬─────────────┬─────┐
│ id  ┆ state       ┆ age │
│ --- ┆ ---         ┆ --- │
│ i64 ┆ str         ┆ i32 │
╞═════╪═════════════╪═════╡
│ 0   ┆ Susceptible ┆ 30  │
│ 1   ┆ Susceptible ┆ 30  │
│ 2   ┆ Susceptible ┆ 30  │
└─────┴─────────────┴─────┘


In [20]:
# Change the status of a single individual (id 2) to "Infectious", keeping
# everyone else's state, via pl.when().then().otherwise() inside with_columns()
new_state = (
    pl.when(pl.col("id") == 2)
    .then(pl.lit("Infectious"))
    .otherwise(pl.col("state"))
)
population = population.with_columns(state=new_state)
print(population)

shape: (3, 3)
┌─────┬─────────────┬─────┐
│ id  ┆ state       ┆ age │
│ --- ┆ ---         ┆ --- │
│ i64 ┆ str         ┆ i32 │
╞═════╪═════════════╪═════╡
│ 0   ┆ Susceptible ┆ 30  │
│ 1   ┆ Susceptible ┆ 30  │
│ 2   ┆ Infectious  ┆ 30  │
└─────┴─────────────┴─────┘


In [21]:
# Write the toy population to CSV, then load the richer example population
# (it adds age_group, age_day, stime_infection and stime_recovery columns)
population_csv = "population.csv"
example_csv = "example_population.csv"

population.write_csv(population_csv)
my_population = pl.read_csv(example_csv)
display(my_population)

id,state,age,age_group,age_day,stime_infection,stime_recovery
i64,str,i64,i64,i64,i64,i64
0,"""Susceptible""",15,1,364,-1,-1
1,"""Susceptible""",20,2,21,-1,-1
2,"""Infectious""",61,6,211,12,20
3,"""Susceptible""",98,9,301,-1,-1
4,"""Susceptible""",100,10,2,-1,-1


## Model Updating

In [22]:
# Advance every individual's age by one time step
print(my_population)

t_step = 1  # day
a_year = 365  # day
tstep_per_year = a_year / t_step
years_age_group = 10  # 10-year age band in each age group

# Day counter after this step, and how many whole years it rolls over
days_after_step = pl.col("age_day") + t_step
years_rolled_over = days_after_step // a_year

# All expressions inside one with_columns() read the pre-update columns,
# so pl.col("age") below is still the old age in every line
my_population = my_population.with_columns(
    (pl.col("age") + years_rolled_over).alias("age"),
    (days_after_step % a_year).alias("age_day"),
    ((pl.col("age") + years_rolled_over) // years_age_group).alias("age_group"),
)

print("age structure is updated")
print(my_population)

shape: (5, 7)
┌─────┬─────────────┬─────┬───────────┬─────────┬─────────────────┬────────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ age_day ┆ stime_infection ┆ stime_recovery │
│ --- ┆ ---         ┆ --- ┆ ---       ┆ ---     ┆ ---             ┆ ---            │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆ i64     ┆ i64             ┆ i64            │
╞═════╪═════════════╪═════╪═══════════╪═════════╪═════════════════╪════════════════╡
│ 0   ┆ Susceptible ┆ 15  ┆ 1         ┆ 364     ┆ -1              ┆ -1             │
│ 1   ┆ Susceptible ┆ 20  ┆ 2         ┆ 21      ┆ -1              ┆ -1             │
│ 2   ┆ Infectious  ┆ 61  ┆ 6         ┆ 211     ┆ 12              ┆ 20             │
│ 3   ┆ Susceptible ┆ 98  ┆ 9         ┆ 301     ┆ -1              ┆ -1             │
│ 4   ┆ Susceptible ┆ 100 ┆ 10        ┆ 2       ┆ -1              ┆ -1             │
└─────┴─────────────┴─────┴───────────┴─────────┴─────────────────┴────────────────┘
age structure is updated
shape: (5, 7)
┌─────┬─────

In [23]:
yearly_death_rate = pl.read_csv("death_rates.csv")

# Convert each annual death probability into a per-time-step probability:
#   p_step = 1 - (1 - p_year) ** (1 / tstep_per_year)
annual_factor = 1.0 / tstep_per_year
survival_per_step = (1 - pl.col("death_rate")).pow(annual_factor)
adj_death_rate = yearly_death_rate.with_columns(
    adj_death_rate=1 - survival_per_step,
)
display(adj_death_rate.head())


age,death_rate,adj_death_rate
i64,f64,f64
0,0.00322,9e-06
1,0.00023,6.3021e-07
2,0.000125,3.4249e-07
3,0.000105,2.8769e-07
4,9e-05,2.4659e-07


In [24]:
# Check deaths at this time step.
# Each individual's per-step death probability is looked up by POSITION with
# gather(): row i of adj_death_rate must hold the rate for age i. Make that
# assumption explicit so a re-ordered or gapped death_rates.csv fails loudly
# instead of silently assigning the wrong rates.
assert adj_death_rate["age"].to_list() == list(range(adj_death_rate.height)), \
    "death_rates.csv must list ages 0, 1, 2, ... in order"
assert my_population.is_empty() or (
    0 <= my_population["age"].min()
    and my_population["age"].max() < adj_death_rate.height
), "population contains an age with no death rate in death_rates.csv"

# Per-individual death probability, aligned with the rows of my_population
death_prob = adj_death_rate["adj_death_rate"].gather(my_population["age"]).to_numpy()

print(my_population)

# Example 1: only keep people who are alive (survive when the draw exceeds the rate)
random_seed = 92
rng = np.random.RandomState(random_seed)
my_population1 = my_population.filter(rng.rand(my_population.height) > death_prob)
print(my_population1)

# Example 2: keep dead individuals in the population with state "Dead".
# Re-seed so both examples draw the same random numbers (the same people die).
random_seed = 92
rng = np.random.RandomState(random_seed)
dead_status = rng.rand(my_population.height) <= death_prob

my_population2 = my_population.with_columns(
    pl.when(pl.lit(dead_status)).then(pl.lit("Dead"))
    .otherwise(pl.col("state")).alias("state"))
print(my_population2)

shape: (5, 7)
┌─────┬─────────────┬─────┬───────────┬─────────┬─────────────────┬────────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ age_day ┆ stime_infection ┆ stime_recovery │
│ --- ┆ ---         ┆ --- ┆ ---       ┆ ---     ┆ ---             ┆ ---            │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆ i64     ┆ i64             ┆ i64            │
╞═════╪═════════════╪═════╪═══════════╪═════════╪═════════════════╪════════════════╡
│ 0   ┆ Susceptible ┆ 16  ┆ 1         ┆ 0       ┆ -1              ┆ -1             │
│ 1   ┆ Susceptible ┆ 20  ┆ 2         ┆ 22      ┆ -1              ┆ -1             │
│ 2   ┆ Infectious  ┆ 61  ┆ 6         ┆ 212     ┆ 12              ┆ 20             │
│ 3   ┆ Susceptible ┆ 98  ┆ 9         ┆ 302     ┆ -1              ┆ -1             │
│ 4   ┆ Susceptible ┆ 100 ┆ 10        ┆ 3       ┆ -1              ┆ -1             │
└─────┴─────────────┴─────┴───────────┴─────────┴─────────────────┴────────────────┘
shape: (4, 7)
┌─────┬─────────────┬─────┬──────────

### Polars: update()

In [78]:
# Check for new infections, assuming homogeneous mixing.
# Re-seed so this example is reproducible on its own.
random_seed = 9
rng = np.random.RandomState(random_seed)

print(my_population)

# assume contacts are homogeneous
prob_transmission_per_contact = 0.1
avg_contacts_per_day = 2
current_time = 20
avg_inf_duration = 8  # days

# Start from the current population so `my_population1` never shows a stale
# frame left over from an earlier cell when nobody gets infected.
my_population1 = my_population

# only susceptibles can become infected
susceptibles = my_population.filter(pl.col("state") == "Susceptible")

if susceptibles.height:
    # at least one susceptible individual in the population

    # Fraction of infectious individuals among the (N - 1) possible contacts of
    # a susceptible individual; max(..., 1) avoids dividing by zero when N == 1.
    n_infectious = my_population.filter(pl.col("state") == "Infectious").height
    frac_infected_in_contacts = n_infectious / max(my_population.height - 1, 1)
    print("frac_infected_in_contacts: %s" % frac_infected_in_contacts)

    # per-step infection probability for a susceptible individual
    prob_transmission = (prob_transmission_per_contact
                         * avg_contacts_per_day * frac_infected_in_contacts)

    # identify individuals who will be infected
    will_infected = susceptibles.filter(
        rng.rand(susceptibles.height) <= prob_transmission
    )
    print("will infected: %s" % will_infected)

    # Sample infection durations and round the recovery time up to a whole day,
    # so stime_recovery keeps the integer (i64) dtype of the population table;
    # a float here would make update() coerce the whole column to f64.
    dur_infections = rng.exponential(avg_inf_duration, will_infected.height)
    stime_recovery = np.ceil(current_time + dur_infections).astype(np.int64)
    new_infected = will_infected.with_columns(
        pl.lit("Infectious").alias("state"),
        pl.lit(current_time, dtype=pl.Int64).alias("stime_infection"),
        pl.Series("stime_recovery", stime_recovery),
    )
    print("new infected: %s" % new_infected)

    if new_infected.height:
        # only update the rows of the newly infected individuals (matched on id)
        my_population1 = my_population.update(new_infected, on="id", how="left")

print("updated population %s" % my_population1)

shape: (5, 7)
┌─────┬─────────────┬─────┬───────────┬─────────┬─────────────────┬────────────────┐
│ id  ┆ state       ┆ age ┆ age_group ┆ age_day ┆ stime_infection ┆ stime_recovery │
│ --- ┆ ---         ┆ --- ┆ ---       ┆ ---     ┆ ---             ┆ ---            │
│ i64 ┆ str         ┆ i64 ┆ i64       ┆ i64     ┆ i64             ┆ i64            │
╞═════╪═════════════╪═════╪═══════════╪═════════╪═════════════════╪════════════════╡
│ 0   ┆ Susceptible ┆ 16  ┆ 1         ┆ 0       ┆ -1              ┆ -1             │
│ 1   ┆ Susceptible ┆ 20  ┆ 2         ┆ 22      ┆ -1              ┆ -1             │
│ 2   ┆ Infectious  ┆ 61  ┆ 6         ┆ 212     ┆ 12              ┆ 20             │
│ 3   ┆ Susceptible ┆ 98  ┆ 9         ┆ 302     ┆ -1              ┆ -1             │
│ 4   ┆ Susceptible ┆ 100 ┆ 10        ┆ 3       ┆ -1              ┆ -1             │
└─────┴─────────────┴─────┴───────────┴─────────┴─────────────────┴────────────────┘
frac_infected_in_contacts: 0.25
will infected: shap

In [80]:
## Hands on..
# Exercise cell: the recovery step is left for participants to implement.
current_time = 20

#Implement a recovery process
# Hints (one possible approach):
#   1. pick individuals whose state is "Infectious" and whose stime_recovery
#      (set in the infection cell as infection time + sampled duration)
#      is <= current_time;
#   2. switch their state to "Recovered", e.g. with
#      pl.when().then().otherwise() inside with_columns(), or DataFrame.update();
#   3. print the updated population.

#...



# Collect summary statistics

In [81]:
results = pl.read_csv("example_output.csv")
display(results.head())

# Population size = sum of all state counts at the first time point of the first run
in_first_run = pl.col("run_no") == pl.col("run_no").min()
at_first_time = pl.col("t") == pl.col("t").min()
pop_size = results.filter(in_first_run & at_first_time)["count"].sum()
print(pop_size)

state,count,t,run_no
str,i64,i64,i64
"""Susceptible""",499,0,0
"""Exposed""",0,0,0
"""Infectious""",1,0,0
"""Recovered""",0,0,0
"""Susceptible""",499,1,0


500


In [59]:
# Final epidemic size per run, as a fraction of the population.
# Assumes nobody is Recovered at t = 0, so the Recovered fraction at the last
# time point of a run is used as that run's final size.
recovered = results.filter(pl.col("state") == "Recovered").with_columns(
    pl.col("count") / pop_size
)

# Take the value at the latest t explicitly (sort_by inside the group) instead
# of relying on the CSV rows being ordered by time within each run.
final_sizes = (
    recovered.group_by("run_no")
    .agg(pl.col("count").sort_by("t").last())
    .rename({"count": "fsize_fraction"})
)

# mean and central 95% interval of the final size across runs
sum_final_size = final_sizes.select(
    pl.col("fsize_fraction").mean().name.suffix("_mean"),
    pl.col("fsize_fraction").quantile(0.025).name.suffix("_025"),
    pl.col("fsize_fraction").quantile(0.975).name.suffix("_975"),
)
display(sum_final_size)

fsize_fraction_mean,fsize_fraction_025,fsize_fraction_975
f64,f64,f64
0.305,0.002,0.692


In [None]:
#Julian: another agg example?

# Additional examples and functions

In [None]:
# ADDING NEW ROWS - COMBINING TWO DATAFRAMES
# https://stackoverflow.com/questions/71654966/how-can-i-append-or-concatenate-two-dataframes-in-python-polars
pop_size = 2
population = pl.DataFrame(
    {"id": range(pop_size), "state": ["Susceptible"] * pop_size}
)

# One extra (infectious) individual to append
population2 = pl.DataFrame(
    {
        "id": pl.Series("id", [2], dtype=pl.Int64),
        "state": pl.Series("state", ["Infectious"]),
    }
)
print(population2)

# pl.concat(..., rechunk=True): copies both frames into one new contiguous memory slab
new_population = pl.concat([population, population2], rechunk=True)

# vstack: links the chunks of both frames without copying the data
new_population2 = population.vstack(population2)

# extend: appends population2 to `population` itself, in place (may reallocate)
population.extend(population2)

print(population, new_population, new_population2)

In [None]:
# rename(): rename columns, keeping their position ("id" becomes "unique_id")
population2 = population.rename({"id": "unique_id"})
print(population2)

# CAREFUL: "renaming" via alias() + drop() appends the new column at the end,
# so the column order changes ("state" now comes before "unique_id")
population1 = population.with_columns(pl.col("id").alias("unique_id")).drop("id")
print(population1)

## need to combine different dataframes on certain axes? Check join()
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.join.html

## need sorting? Check sort()
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.sort.html

## need the number of rows of a dataframe? Check height
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.height.html

## need to check data types? Check the schema property
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.schema.html

## need to introduce columns with list or dict structures? Check struct()
https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.struct.html


## Introduction to Lazy DataFrame


In [None]:
# Reference snippets from a larger (multi-strain) model, kept as inert
# triple-quoted strings: they are NOT executed and use names (self.P.I,
# death_rates, period, self.strain_distribution, ...) that are not defined here.

# update pop dynamics
# NOTE(review): unlike the age-update cell above (365-day year, 10-year age
# groups), this snippet uses a 364-day year and 5-year age groups - confirm.
# NOTE(review): "alive" is True when random < death rate, which looks inverted
# relative to the death check above (survive when rand > rate), unless
# death_rates holds survival probabilities - confirm before reusing.
# NOTE(review): "alive" reads the existing `random` column; the new draws made
# in the same with_columns() call only take effect at the next step.
# NOTE(review): `.list.take` is named `.list.gather` in newer Polars versions.
"""
self.P.I = (
               self.P.I
               .with_columns(
                   
                  ( pl.col('age').add((pl.col('age_days') + period) // 364).alias('age')),
                   
                   ((pl.col('age_days') + period) % 364).alias('age_days'),
                   
                  ((pl.col('age') + (pl.col('age_days') + period)// 364) // 5)
                  .alias('age_group'),
                    
                  #(will_live(self.P.I["age"], self.death_rates[t], self.rng)),
                  
                  
                  ((pl.col('random') < death_rates.list.take(self.P.I["age"])[0]).alias("alive")),
                  
                  (pl.lit(self.rng.rand(self.P.I.height))).alias("random"),
                  
                   ).filter(pl.col("alive")).drop("alive")
               )

"""

# assign an exposed from column
# Draws one strain per individual with probabilities strain_distribution["fraction"]
# (rng.choice with p=...), stores it in exposed_strains, then sorts the table by
# random, no_of_strains and age_group.
"""
P.I = (
                    P.I.with_columns(
                    exposed_strains = self.strain_distribution["strain"][
                        rng.choice(len(self.strain_distribution["strain"]),
                                    p = self.strain_distribution["fraction"],
                                              size = len(P.I))]
             )).sort("random", "no_of_strains", "age_group", descending=False)
"""

#vaccine a target group
# TODO: not implemented yet


#calculate age-group specific foi and infect individuals and update the individuals
# TODO: not implemented yet


#identify individuals that are going to be infected by exposed strains
# will_infected is True when the existing `random` draw is <= the age-group
# infection probability multiplied by:
#   * 0 once no_of_strains reaches max_no_coinfections,
#   * (1 - reduction_in_susceptibility * no_of_strains),
#   * 0 when exposed_strains is already in strain_list (per is_in - confirm
#     its semantics on a list column in the Polars version used).
# A fresh `random` column is drawn in the same call for the next step.
# NOTE(review): `.list.take` is named `.list.gather` in newer Polars versions.
"""P.I = (P.I.with_columns(
                   (  pl.col("random") <= (\
                (1 * (pl.col("no_of_strains") < self.max_no_coinfections)) *\
                    (prob_infection.list.take(P.I["age_group"])[0] *\
                     (1 -  (self.reduction_in_susceptibility) * \
                          (pl.col("no_of_strains"))) * \
                            (1 - \
            pl.col("exposed_strains").is_in(pl.col("strain_list")))     
                                 )
                   )).alias("will_infected"),
               
                   (pl.lit(rng.rand(P.I.height)).alias("random")),
                   ))
"""


#recover individuals with multiple infections
# Selects individuals with at least one infection whose end time is <= day.
# For them it keeps only the end times (and the matching strains) that are
# still > day, and recounts no_of_strains. The rows are then written back
# into P.I by id with DataFrame.update.
# NOTE(review): `.list.lengths()` / `.list.take` are named `.list.len()` /
# `.list.gather` in newer Polars versions.
"""
recovered = (
            P.I.filter((pl.col("endTimes").list.eval(pl.element()
                            .filter(pl.element() <= day))).list.lengths() > 0)
            .select(["id", "strain_list", "endTimes", "no_of_strains"])
        )
        
       
        
recovered = (
            recovered.with_columns(
            (pl.col("endTimes").list.eval(pl.element()
                     .filter(pl.element() > day))),
            (pl.col("strain_list").list.eval(pl.element()).list
             .take(pl.col("endTimes").list
        .eval((pl.element() > day).arg_true()))),
           ).with_columns(
               (((pl.col("strain_list").list.lengths()))
                 .cast(pl.Int32).alias("no_of_strains"))))
               #.filter(pl.col("endTimesIndexes").list.lengths() > 0)
            
P.I = P.I.update(recovered, on="id", how="left")
"""



#check the antibody levels of individuals given a vaccine list and antibody log
# TODO: not implemented yet (the string below is an empty placeholder)
"""


"""