In [None]:
import pandas as pd
import numpy as np
import geopandas as gpd
import itertools
import os
import sys
from typing import List, Tuple, Dict
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from geofeather import to_geofeather, from_geofeather
from pathlib import Path
from shapely import Polygon
import datetime
import libpysal as lps
import matplotlib.pyplot as plt
from libpysal.weights.distance import get_points_array
from scipy.spatial import cKDTree

sys.path.append('./utils')
import utils
data_folder = Path('../data/')

## Load testing data

In [None]:
# Load the toy COVID test records, build point geometries from their
# longitude/latitude columns (EPSG:4326) and reproject them to EPSG:2056
df_covid = pd.read_csv(data_folder / 'covid_data_example.csv')
df_covid = gpd.GeoDataFrame(
    df_covid,
    geometry=gpd.points_from_xy(x=df_covid.longitude, y=df_covid.latitude),
    crs=4326,
).to_crs(2056)

## Load geographic elements

### Populated hectares

In [None]:
## VD statpop
statpop_ha_file = "statpop_communes_vd_ha.feather"
statpop_ha_gdf = from_geofeather(
    data_folder.joinpath('ag-b-00.03-vz2020statpop', statpop_ha_file)
)
# Lightweight copy holding only the hectare identifier and its geometry
statpop_ha_gdf_light = statpop_ha_gdf[['RELI', 'geometry']]

In [None]:
# Entire CH statpop (semicolon-separated CSV)
ch_statpop_ha = pd.read_csv(
    data_folder.joinpath('ag-b-00.03-vz2020statpop', 'STATPOP2020.csv'),
    sep=';',
)

### Municipalities

In [None]:
communes = gpd.read_file(
    data_folder / 'Administrative units' / 'swissBOUNDARIES3D_1_3_TLM_HOHEITSGEBIET.shp',
    engine='pyogrio',
)
# Keep canton number 22, drop rows without geometry and the three lake polygons
communes_vd = (
    communes[
        (communes.KANTONSNUM == 22)
        & ~communes.geometry.isnull()
        & ~communes.NAME.isin(['Lac Léman (VD)', 'Lac de Neuchâtel (VD)', 'Lac de Morat (VD)'])
    ]
    .rename(columns={'geom': 'geometry'})
    .reset_index(drop=True)
)
# The CRS is declared (not reprojected) as EPSG:2056 before rebuilding the frame
communes_vd.crs = 2056
communes_vd = gpd.GeoDataFrame(communes_vd, crs=2056, geometry=communes_vd['geometry'])

In [None]:
# Lausanne: bounding box of the listed municipalities, and the hectares inside it
xmin, ymin, xmax, ymax = communes_vd.loc[
    communes_vd.NAME.isin([
        'Lausanne', 'Renens (VD)', 'Prilly', 'Ecublens (VD)',
        'Pully', 'Le Mont-sur-Lausanne', 'Paudex', 'Lutry',
    ])
].total_bounds
gdf_lausanne = statpop_ha_gdf.cx[xmin:xmax, ymin:ymax]
bbox_geometry = [
    Polygon([(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)])
]
bbox_lausanne = gpd.GeoDataFrame(['Lausanne'], crs=2056, geometry=bbox_geometry)

## Define study periods

In [None]:
# Define time periods for analysis
p1_start = '2020-01-10'
p2_start = '2020-06-30'
p3_start = '2020-12-16'
p4_start = '2021-05-07'
p5_start = '2021-11-28'
p5_end = '2022-04-16'

# Label every test with its study period. Periods are half-open,
# [start, next start), which matches the `>= start` / `< end` filters used when
# the positive cases are split by period. Dates are first normalised to ISO
# 'YYYY-MM-DD' strings so the string comparison does not depend on how
# date_reception is formatted in the raw CSV.
reception_days = pd.to_datetime(df_covid['date_reception']).dt.strftime('%Y-%m-%d')
period_bounds = [p1_start, p2_start, p3_start, p4_start, p5_start, p5_end]
for period_number, (period_start, period_end) in enumerate(
    zip(period_bounds[:-1], period_bounds[1:]), start=1
):
    in_period = (reception_days >= period_start) & (reception_days < period_end)
    df_covid.loc[in_period, 'period'] = f'p{period_number}'

## Select cases by period

In [None]:
# Normalise reception dates to ISO 'YYYY-MM-DD' strings and derive calendar keys
df_covid['date_reception'] = pd.to_datetime(df_covid['date_reception']).dt.strftime('%Y-%m-%d')
reception_ts = pd.to_datetime(df_covid['date_reception'])
iso_calendar = reception_ts.dt.isocalendar()
df_covid['month'] = reception_ts.dt.strftime('%Y-%m')
df_covid['week'] = iso_calendar['week']
df_covid['year'] = iso_calendar['year']
# Monday of the ISO week of each test. 'week' and 'year' are ISO-8601 values, so
# they must not be re-parsed with "%Y-%W-%w": %W counts weeks from the first
# Monday of the calendar year, which shifts every 2020 date by one week.
# Subtracting the weekday (Monday == 0) gives the ISO week start directly.
df_covid['week_str'] = reception_ts - pd.to_timedelta(reception_ts.dt.weekday, unit='D')

# Attach each test to the hectare it intersects. Sorting by population
# (descending) makes the first match the most populated hectare.
df_covid_in_ha = gpd.sjoin(
    statpop_ha_gdf[['RELI', 'B20BTOT', 'NAME', 'EINWOHNERZ', 'geometry']],
    df_covid[['id_demande_study2', 'date_reception', 'res_cov_txt', 'week', 'week_str', 'month', 'year', 'period', 'geometry']],
    predicate='intersects',
    how='right',
).sort_values('B20BTOT', ascending=False)
# Keep only the first match per test (the most populated hectare)
df_covid_in_ha = df_covid_in_ha.drop_duplicates(subset=['id_demande_study2'], keep='first')
df_covid_in_ha = df_covid_in_ha.drop('index_left', axis=1)
# NOTE(review): utils.apply_min_dist is defined outside this notebook; presumably it
# assigns the nearest hectare to tests that fall outside every hectare — confirm.
df_covid_in_ha['RELI'] = df_covid_in_ha.apply(lambda row: utils.apply_min_dist(row, statpop_ha_gdf), axis=1)
# 'p1'..'p5' -> 1..5; tests outside every study period become 0
df_covid_in_ha['period'] = df_covid_in_ha['period'].str[1].fillna('0').astype(int)

In [None]:
# Positive tests only, then one half-open [start, end) slice per study period
cases = df_covid_in_ha.loc[df_covid_in_ha.res_cov_txt == 'POSITIVE']

cases_first_period = cases[cases.date_reception.between(p1_start, p2_start, inclusive='left')]
cases_second_period = cases[cases.date_reception.between(p2_start, p3_start, inclusive='left')]
cases_third_period = cases[cases.date_reception.between(p3_start, p4_start, inclusive='left')]
cases_fourth_period = cases[cases.date_reception.between(p4_start, p5_start, inclusive='left')]
cases_fifth_period = cases[cases.date_reception.between(p5_start, p5_end, inclusive='left')]

In [None]:
# Restrict every case table to hectares inside the Lausanne bounding box
cases_lausanne = cases.loc[cases['RELI'].isin(gdf_lausanne['RELI'])]

cases_first_period_lausanne = cases_first_period.loc[cases_first_period['RELI'].isin(gdf_lausanne['RELI'])]
cases_second_period_lausanne = cases_second_period.loc[cases_second_period['RELI'].isin(gdf_lausanne['RELI'])]
cases_third_period_lausanne = cases_third_period.loc[cases_third_period['RELI'].isin(gdf_lausanne['RELI'])]
cases_fourth_period_lausanne = cases_fourth_period.loc[cases_fourth_period['RELI'].isin(gdf_lausanne['RELI'])]
cases_fifth_period_lausanne = cases_fifth_period.loc[cases_fifth_period['RELI'].isin(gdf_lausanne['RELI'])]

## Testing rates

In [None]:
# Duplicate every test under a synthetic 'full' period so that whole-study
# rates can be computed alongside the per-period ones
df_covid_in_ha_fullperiod = df_covid_in_ha.assign(period='full')
df_covid_in_ha_w_fullperiod = pd.concat(
    [df_covid_in_ha, df_covid_in_ha_fullperiod]
)

In [None]:
# Number of tests and hectare population per (hectare, period)
df_testing_rates = (
    df_covid_in_ha_w_fullperiod
    .groupby(['RELI', 'period'])
    .agg({'id_demande_study2': 'size', 'B20BTOT': 'first'})
)

In [None]:
# Rename the two aggregated columns and turn the (RELI, period) index into columns
df_testing_rates = (
    df_testing_rates
    .set_axis(['n_tests', 'population'], axis=1)
    .reset_index()
)

In [None]:
# Tests per 100 inhabitants of the hectare
df_testing_rates['testing_rate'] = (
    df_testing_rates['n_tests'].div(df_testing_rates['population']).mul(100)
)

In [None]:
# Cap the rate at 100 (hectares with more tests than inhabitants)
df_testing_rates['testing_rate'] = df_testing_rates['testing_rate'].clip(upper=100.0)

## Functions for MST-DBSCAN

In [None]:
def extract_multiple_survival_times(x: pd.Series, first_date: str, last_date: str) -> Tuple[List[int], List[int], List[int], List[int]]:
    """
    Split a sequence of per-time-step statuses into cluster episodes and gaps.

    The statuses 'keep', 'increase' and 'decrease' are all treated as
    'cluster'. Consecutive identical statuses are merged into runs; each run
    of 'cluster' is a survival time, every other run is a time between
    clusters. A run that ends exactly on the last time step of the analysis
    window gets event flag 0, any other run gets 1.

    Args:
        x (pd.Series): Statuses, one string per time step, in time order.
        first_date (str): Start date of the analysis period.
        last_date (str): End date of the analysis period.

    Returns:
        Tuple[List[int], List[int], List[int], List[int]]: Survival times,
        time between clusters, status (event flags of the cluster runs), and
        status_resistance (event flags of the non-cluster runs).
    """
    def _collapse(raw_status):
        # Same substitutions, in the same order, for every status string
        for label in ('keep', 'increase', 'decrease'):
            raw_status = raw_status.replace(label, 'cluster')
        return raw_status

    runs = [
        (label, len(list(members)))
        for label, members in itertools.groupby(map(_collapse, x))
    ]

    # Number of time steps in the window, both end dates included
    window_length = (pd.Timestamp(last_date) - pd.Timestamp(first_date)).days + 1

    survival_times, time_between_clusters, status, status_resistance = [], [], [], []
    elapsed = 0
    for label, run_length in runs:
        elapsed += run_length
        # 0 when the run reaches the end of the window, 1 otherwise
        event_flag = int(elapsed != window_length)
        if label == 'cluster':
            durations, flags = survival_times, status
        else:
            durations, flags = time_between_clusters, status_resistance
        durations.append(run_length)
        flags.append(event_flag)

    return survival_times, time_between_clusters, status, status_resistance

def prep_survival(polygons: pd.DataFrame, first_date: str, last_date: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Prepare survival analysis data from polygon data.

    For every polygon, the status columns between ``first_date`` and
    ``last_date`` are passed to ``extract_multiple_survival_times``; the
    resulting per-polygon lists are reshaped to long format (one row per
    cluster episode, and one row per gap between episodes).

    Args:
        polygons (pd.DataFrame): Table with a 'RELI' column and one status
            column per time step. ``first_date`` and ``last_date`` must be
            existing column labels, since they are used to slice the columns.
        first_date (str): Label of the first status column of the analysis period.
        last_date (str): Label of the last status column of the analysis period.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
            - Survival data: 'RELI', 'cluster_n', 'Persistence_survival',
              boolean 'status', and standardised 'Survival_scaled'.
            - Resistance data: 'RELI', 'interval_n', 'Resistance_survival',
              boolean 'status_resistance', and standardised 'Survival_scaled'.
            - Multiple survival times: 'RELI', 'cluster_n', 'Persistence_survival'.
    """
    def _to_long(list_column, counter_name, value_name, renumber=False):
        # One row per (polygon, list position); positions a polygon does not
        # have are NaN after the expansion and are dropped.
        long_frame = list_column.apply(pd.Series).reset_index().melt(id_vars='index').dropna()
        if renumber:
            long_frame = long_frame.reset_index(drop=True)
        long_frame.columns = ['RELI', counter_name, value_name]
        return long_frame

    def _standardise(frame, column):
        # StandardScaler raises on zero rows (e.g. no cluster episode at all),
        # so an empty frame gets an empty scaled column instead.
        if frame.empty:
            return np.empty(0, dtype=float)
        return StandardScaler().fit_transform(frame[[column]]).ravel()

    multiple_survivals = polygons.set_index('RELI').loc[:, first_date:last_date].apply(
        lambda x: extract_multiple_survival_times(x, first_date, last_date), axis=1
    ).to_dict()
    multiple_survivals = pd.DataFrame.from_dict(multiple_survivals).T
    multiple_survivals.columns = ['Survival times', 'Time between clusters', 'status', 'status_resistance']

    df_mult_survival_times = _to_long(multiple_survivals['Survival times'], 'cluster_n', 'Persistence_survival')
    df_time_bt_clusters = _to_long(multiple_survivals['Time between clusters'], 'interval_n', 'Resistance_survival')
    df_status = _to_long(multiple_survivals['status'], 'cluster_n', 'status', renumber=True)
    df_status_resistance = _to_long(multiple_survivals['status_resistance'], 'interval_n', 'status_resistance', renumber=True)

    # Cluster episodes: duration + event flag (True = event observed, False = censored)
    df_survival = pd.merge(df_mult_survival_times, df_status, on=['RELI', 'cluster_n'])
    df_survival['status'] = df_survival['status'].map({1: True, 0: False})
    df_survival['Survival_scaled'] = _standardise(df_survival, 'Persistence_survival')

    # Gaps between episodes, same layout
    df_survival_resistance = pd.merge(df_time_bt_clusters, df_status_resistance, on=['RELI', 'interval_n'])
    df_survival_resistance['status_resistance'] = df_survival_resistance['status_resistance'].map({1: True, 0: False})
    df_survival_resistance['Survival_scaled'] = _standardise(df_survival_resistance, 'Resistance_survival')
    return df_survival, df_survival_resistance, df_mult_survival_times

In [None]:
import pysda
import pandas as pd
from pathlib import Path
from typing import Dict, Any

def analyze_period(
    cases_data: pd.DataFrame,
    period_name: str,
    start_date: str,
    end_date: str,
    population_data: Any,
    res_folder: Path,
    eps_spatial: float = 200,
    eps_temporal_low: float = 1,
    eps_temporal_high: float = 14,
    min_pts: int = 3,
    moving_ratio: float = 0.1,
    area_ratio: float = 0.1,
    file_prefix: str = ''
) -> Dict[str, Any]:
    """
    Analyze a specific period of data using MSTDBSCAN and prepare survival analysis.

    Args:
        cases_data (pd.DataFrame): Case data for the period; must contain a
            'date_reception' column, which is used as the time column.
        period_name (str): Name of the period (e.g., '1st period', '2nd period').
            Written to the 'period' column of the outputs and used in the
            output file names (spaces replaced by underscores).
        start_date (str): Start date of the period in format 'YYYY/MM/DD-HH:MM:SS'.
            Passed to ``prep_survival``, which uses it as a column label of
            the polygon table.
        end_date (str): End date of the period in format 'YYYY/MM/DD-HH:MM:SS'.
            Passed to ``prep_survival`` in the same way.
        population_data (Any): GeoDataFrame passed to ``result.setPolygons``.
        res_folder (Path): Path to the results folder (a str is also accepted).
        eps_spatial (float): Spatial epsilon parameter for MSTDBSCAN.
        eps_temporal_low (float): Lower temporal epsilon parameter for MSTDBSCAN.
        eps_temporal_high (float): Higher temporal epsilon parameter for MSTDBSCAN.
        min_pts (int): Minimum points parameter for MSTDBSCAN.
        moving_ratio (float): Moving ratio parameter for MSTDBSCAN.
        area_ratio (float): Area ratio parameter for MSTDBSCAN.
        file_prefix (str): Text placed before the period name in the pickle
            file names. With the default (empty) prefix, two runs that share
            ``period_name`` and ``res_folder`` overwrite each other's files;
            pass a distinct prefix per study area to keep both.

    Returns:
        Dict[str, Any]: Keys 'clusters', 'polygons', 'points', 'survival'
        and 'time_bt_clusters'.
    """
    # Accept plain strings as well as Path objects for the results folder
    res_folder = Path(res_folder)

    # Read data
    pysda_data = pysda.data.readGDF(cases_data, timeColumn='date_reception', timeUnit="day")

    # Set up and run MSTDBSCAN
    mst = pysda.MSTDBSCAN(pysda_data)
    mst.setParams(eps_spatial, eps_temporal_low, eps_temporal_high, min_pts, moving_ratio, area_ratio)
    mst.run()

    # Process results
    result = mst.result
    result.setPolygons(population_data)
    all_results = result.getAll()

    clusters = all_results["clusters"]
    polygons = all_results["polygons"]
    points = all_results["points"]

    # Prepare survival analysis (third return value is not used here)
    survival, time_bt_clusters, _ = prep_survival(polygons, start_date, end_date)

    # Add period information
    survival['period'] = period_name
    time_bt_clusters['period'] = period_name
    clusters['period'] = period_name

    # Save results
    survival_folder = res_folder / 'Survival analyses'
    survival_folder.mkdir(parents=True, exist_ok=True)

    period_slug = period_name.replace(" ", "_")
    survival.to_pickle(survival_folder / f'{file_prefix}_{period_slug}_survival.pkl')
    time_bt_clusters.to_pickle(survival_folder / f'{file_prefix}_{period_slug}_time_bt.pkl')

    return {
        'clusters': clusters,
        'polygons': polygons,
        'points': points,
        'survival': survival,
        'time_bt_clusters': time_bt_clusters
    }

## MSTDBSCAN

In [None]:
# Folder receiving every MST-DBSCAN output written by this notebook
mstdbscan_folder = data_folder.joinpath('MSTDBSCAN')

### Canton of Vaud

In [None]:
def analyze_all_periods(periods: List[Dict[str, Any]], population_data: pd.DataFrame, output_folder: str) -> Dict[str, Dict[str, Any]]:
    """
    Analyze all periods and return the results.

    Args:
        periods (List[Dict[str, Any]]): One dictionary per period with the keys
            'name', 'cases_data', 'start_date' and 'end_date'.
        population_data (pd.DataFrame): The population data for analysis,
            forwarded unchanged to ``analyze_period``.
        output_folder (str): The folder path for saving output files; a str or
            a ``pathlib.Path``.

    Returns:
        Dict[str, Dict[str, Any]]: Maps each period name to a dictionary with
        the keys 'clusters', 'survival_data' and 'intercluster_times'.
        Periods sharing a name overwrite each other in the result.
    """
    # analyze_period joins sub-paths with `/`, which needs a Path, not a str
    output_folder = Path(output_folder)
    all_results = {}

    for period in periods:
        period_results = analyze_period(
            cases_data=period['cases_data'],
            period_name=period['name'],
            start_date=period['start_date'],
            end_date=period['end_date'],
            population_data=population_data,
            res_folder=output_folder
        )

        # Keep only the three outputs used downstream, under caller-facing names
        all_results[period['name']] = {
            'clusters': period_results['clusters'],
            'survival_data': period_results['survival'],
            'intercluster_times': period_results['time_bt_clusters']
        }

    return all_results

In [None]:
from typing import Dict, Any, List

# Define period information: (name, cases, first day, last day) per study period
analysis_periods = [
    {'name': name, 'cases_data': period_cases, 'start_date': first_day, 'end_date': last_day}
    for name, period_cases, first_day, last_day in [
        ('first_period', cases_first_period, '2020/03/02-00:00:00', '2020/06/29-00:00:00'),
        ('second_period', cases_second_period, '2020/06/30-00:00:00', '2020/12/15-00:00:00'),
        ('third_period', cases_third_period, '2020/12/16-00:00:00', '2021/05/06-00:00:00'),
        ('fourth_period', cases_fourth_period, '2021/05/07-00:00:00', '2021/11/27-00:00:00'),
        ('fifth_period', cases_fifth_period, '2021/11/28-00:00:00', '2022/04/15-00:00:00'),
    ]
]

In [None]:
# Analyze all periods
all_period_results = analyze_all_periods(
    analysis_periods,
    population_data=statpop_ha_gdf_light,
    output_folder=mstdbscan_folder,
)

# Access results: clusters, survival data and inter-cluster times per period
clusters_first_period, survival_first_period, intercluster_times_first_period = (
    all_period_results['first_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_second_period, survival_second_period, intercluster_times_second_period = (
    all_period_results['second_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_third_period, survival_third_period, intercluster_times_third_period = (
    all_period_results['third_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_fourth_period, survival_fourth_period, intercluster_times_fourth_period = (
    all_period_results['fourth_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_fifth_period, survival_fifth_period, intercluster_times_fifth_period = (
    all_period_results['fifth_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

### Lausanne urban area

In [None]:
# Define period information for the Lausanne subset:
# (name, cases, first day, last day) per study period
analysis_periods_lausanne = [
    {'name': name, 'cases_data': period_cases, 'start_date': first_day, 'end_date': last_day}
    for name, period_cases, first_day, last_day in [
        ('first_period', cases_first_period_lausanne, '2020/03/07-00:00:00', '2020/06/29-00:00:00'),
        ('second_period', cases_second_period_lausanne, '2020/06/30-00:00:00', '2020/12/15-00:00:00'),
        ('third_period', cases_third_period_lausanne, '2020/12/16-00:00:00', '2021/05/06-00:00:00'),
        ('fourth_period', cases_fourth_period_lausanne, '2021/05/07-00:00:00', '2021/11/27-00:00:00'),
        ('fifth_period', cases_fifth_period_lausanne, '2021/11/28-00:00:00', '2022/04/15-00:00:00'),
    ]
]

In [None]:
# Analyze all periods
# NOTE(review): same period names and output folder as the canton-wide run, so the
# per-period pickles written by analyze_period are replaced by this run.
all_period_results_lausanne = analyze_all_periods(
    analysis_periods_lausanne,
    population_data=statpop_ha_gdf_light,
    output_folder=mstdbscan_folder,
)

# Access results: clusters, survival data and inter-cluster times per period
clusters_first_period_lausanne, survival_first_period_lausanne, intercluster_times_first_period_lausanne = (
    all_period_results_lausanne['first_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_second_period_lausanne, survival_second_period_lausanne, intercluster_times_second_period_lausanne = (
    all_period_results_lausanne['second_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_third_period_lausanne, survival_third_period_lausanne, intercluster_times_third_period_lausanne = (
    all_period_results_lausanne['third_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_fourth_period_lausanne, survival_fourth_period_lausanne, intercluster_times_fourth_period_lausanne = (
    all_period_results_lausanne['fourth_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

clusters_fifth_period_lausanne, survival_fifth_period_lausanne, intercluster_times_fifth_period_lausanne = (
    all_period_results_lausanne['fifth_period'][key]
    for key in ('clusters', 'survival_data', 'intercluster_times')
)

### Combine the five periods

In [None]:
def save_combined_data(dataframes: List[pd.DataFrame], filename: str, res_folder: Path, file_format: str = 'parquet'):
    """
    Combine multiple dataframes and save the result in the specified format.

    The frames are concatenated row-wise with a fresh index and written to
    ``res_folder / f'{filename}.{file_format}'`` without the index.

    Args:
        dataframes (List[pd.DataFrame]): List of dataframes to combine.
        filename (str): Name of the output file (without extension).
        res_folder (Path): Path to the results folder (a str is also accepted).
        file_format (str): Format to save the file in ('parquet', 'csv', or 'pickle'). Defaults to 'parquet'.

    Raises:
        ValueError: If an unsupported file format is specified. The check runs
            before anything is combined or written.

    Note:
        Errors raised while writing are reported with ``print`` and not
        re-raised, so check the printed message before reading the file back.
    """
    writers = {
        'parquet': lambda frame, path: frame.to_parquet(path, index=False),
        'csv': lambda frame, path: frame.to_csv(path, index=False),
        # Using protocol 4 for compatibility
        'pickle': lambda frame, path: frame.to_pickle(path, protocol=4),
    }
    # Validate outside the try block: previously this ValueError was caught by
    # the generic handler below and only printed, contradicting the docstring.
    if file_format not in writers:
        raise ValueError(f"Unsupported file format: {file_format}")

    combined_df = pd.concat(dataframes, ignore_index=True)
    output_path = Path(res_folder) / f'{filename}.{file_format}'

    try:
        writers[file_format](combined_df, output_path)
    except Exception as e:
        print(f"Error saving {filename}.{file_format}: {e}")
    else:
        print(f"Successfully saved {filename}.{file_format}")

In [None]:
# Canton-wide outputs: stack the five periods and write one parquet file each
save_combined_data(
    dataframes=[clusters_first_period, clusters_second_period, clusters_third_period, clusters_fourth_period, clusters_fifth_period],
    filename='combined_clusters',
    res_folder=mstdbscan_folder,
)

save_combined_data(
    dataframes=[survival_first_period, survival_second_period, survival_third_period, survival_fourth_period, survival_fifth_period],
    filename='combined_survival',
    res_folder=mstdbscan_folder,
)

save_combined_data(
    dataframes=[intercluster_times_first_period, intercluster_times_second_period, intercluster_times_third_period, intercluster_times_fourth_period, intercluster_times_fifth_period],
    filename='combined_time_bt',
    res_folder=mstdbscan_folder,
)

In [None]:

# Lausanne outputs: stack the five periods and write one parquet file each
save_combined_data(
    dataframes=[clusters_first_period_lausanne, clusters_second_period_lausanne, clusters_third_period_lausanne, clusters_fourth_period_lausanne, clusters_fifth_period_lausanne],
    filename='combined_clusters_lausanne',
    res_folder=mstdbscan_folder,
)

save_combined_data(
    dataframes=[survival_first_period_lausanne, survival_second_period_lausanne, survival_third_period_lausanne, survival_fourth_period_lausanne, survival_fifth_period_lausanne],
    filename='combined_survival_lausanne',
    res_folder=mstdbscan_folder,
)

save_combined_data(
    dataframes=[intercluster_times_first_period_lausanne, intercluster_times_second_period_lausanne, intercluster_times_third_period_lausanne, intercluster_times_fourth_period_lausanne, intercluster_times_fifth_period_lausanne],
    filename='combined_time_bt_lausanne',
    res_folder=mstdbscan_folder,
)

## Clusters over the whole period

In [None]:
# Single MST-DBSCAN run over all positive cases of the study
results_whole_period = analyze_period(
    cases_data=cases,
    period_name='Whole period',
    start_date='2020/03/02-00:00:00',
    end_date='2022/04/15-00:00:00',
    population_data=statpop_ha_gdf,
    res_folder=mstdbscan_folder,
)

clusterwholeperiod, whole_survival, whole_time_bt_clusters = (
    results_whole_period[key]
    for key in ('clusters', 'survival', 'time_bt_clusters')
)

In [None]:
# Persist the whole-period survival table
whole_survival.to_pickle(mstdbscan_folder.joinpath('whole_survival - 200m.pkl'))

In [None]:
# Single MST-DBSCAN run over all positive cases inside the Lausanne bounding box
results_whole_period_lausanne = analyze_period(
    cases_data=cases_lausanne,
    period_name='Whole period',
    start_date='2020/03/04-00:00:00',
    end_date='2022/04/15-00:00:00',
    population_data=statpop_ha_gdf,
    res_folder=mstdbscan_folder,
)

clusterwholeperiod_lausanne, whole_survival_lausanne, whole_time_bt_clusters_lausanne = (
    results_whole_period_lausanne[key]
    for key in ('clusters', 'survival', 'time_bt_clusters')
)

In [None]:
# Persist the whole-period survival table of the Lausanne subset
whole_survival_lausanne.to_pickle(mstdbscan_folder.joinpath('whole_survival_lausanne - 200m.pkl'))

## Final data preparation

In [None]:
# Base covariate table: hectare id, population, coordinates and geometry
gdf_covariates = statpop_ha_gdf[
    ['RELI', 'B20BTOT', 'E_KOORD', 'N_KOORD', 'geometry']
].copy()

In [None]:
# Re-wrap as a GeoDataFrame using its own geometry column
gdf_covariates = gpd.GeoDataFrame(gdf_covariates, geometry=gdf_covariates.geometry)

### SES determinants

In [None]:
# Socioeconomic index per hectare
gdf_ses = pd.read_parquet(
    data_folder.joinpath('Socioeconomic determinants', 'df_ses_index.parquet')
)

In [None]:
# Inner join (pandas default): hectares without an SES record are dropped
gdf_covariates = pd.merge(
    gdf_covariates,
    gdf_ses[['RELI', 'index_socio_class', 'index_socio_class_q3', 'index_socio_stand', 'index_socio_class_q3_inv']],
    on='RELI',
)

### Environmental determinants

In [None]:
# Environmental layers, one pickle per variable.
# NOTE(review): pickles execute code on load — only read files from a trusted source.
ns_carnight, ns_carday, pm10, pm25, no2, lst, ndvi = (
    pd.read_pickle(data_folder / 'Environmental determinants' / pickle_name)
    for pickle_name in (
        'ns_car_night.pkl',
        'ns_car_day.pkl',
        'pm10_2020.pkl',
        'pm25_2020.pkl',
        'no2_2020.pkl',
        'gdf_lst.pkl',
        'gdf_ndvi.pkl',
    )
)

In [None]:
# Mean value of every environmental layer, keyed by the target column name
mean_columns = dict(
    LST=lst['mean_lst'],
    NDVI=ndvi['mean_ndvi'],
    pm10=pm10['mean'],
    pm25=pm25['mean'],
    no2=no2['mean'],
    noise_car_day=ns_carday['mean'],
    noise_car_night=ns_carnight['mean'],
)

# NOTE(review): assign() aligns these columns on the row index, not on RELI —
# this assumes every layer shares ch_statpop_ha's index; confirm.
ch_statpop_ha = ch_statpop_ha.assign(**mean_columns)

In [None]:
# Left join: every hectare is kept, environmental values are NaN where missing
gdf_covariates = pd.merge(
    gdf_covariates,
    ch_statpop_ha[['RELI', 'LST', 'NDVI', 'pm10', 'pm25', 'no2', 'noise_car_day', 'noise_car_night']],
    on='RELI',
    how='left',
)

In [None]:
# Divide the three air-pollution columns by 10.
# Not idempotent: re-running this cell divides the values again.
gdf_covariates[['pm25', 'pm10', 'no2']] = gdf_covariates[['pm25', 'pm10', 'no2']].div(10)

### Lagged population

In [None]:
# Spatial weights between hectares: 8 and 24 nearest neighbours, and a
# 200-unit distance band, each built on its own KD-tree of the same points
hectare_points = get_points_array(gdf_covariates.geometry)
wnn8 = lps.weights.KNN(cKDTree(hectare_points), 8)
wnn24 = lps.weights.KNN(cKDTree(hectare_points), 24)
wd_200 = lps.weights.DistanceBand(cKDTree(hectare_points), 200)

# Row-standardise each weights object
wnn8.transform = 'r'
wnn24.transform = 'r'
wd_200.transform = 'r'

# Spatially lagged population under each neighbourhood definition
gdf_covariates['B20BTOT_lag8'] = lps.weights.lag_spatial(wnn8, gdf_covariates['B20BTOT'])
gdf_covariates['B20BTOT_lag24'] = lps.weights.lag_spatial(wnn24, gdf_covariates['B20BTOT'])
gdf_covariates['B20BTOT_lag200m'] = lps.weights.lag_spatial(wd_200, gdf_covariates['B20BTOT'])

### Testing rates

In [None]:
# Attach the whole-study ('full') testing rate of each hectare
gdf_covariates = pd.merge(
    gdf_covariates,
    df_testing_rates.loc[df_testing_rates.period == 'full', ['RELI', 'testing_rate']],
    on='RELI',
    how='left',
)

In [None]:
# Hectares with no recorded test get a testing rate of 0
gdf_covariates['testing_rate'] = gdf_covariates['testing_rate'].fillna(value=0)

## Cox PH modelling

In [None]:
import seaborn as sns

from lifelines import (
    CoxPHFitter,
    ExponentialFitter,
    KaplanMeierFitter,
    LogLogisticFitter,
    LogLogisticAFTFitter,
    LogNormalFitter,
    LogNormalAFTFitter,
    WeibullFitter,
    WeibullAFTFitter,
)
from lifelines.plotting import qq_plot
from lifelines.utils import k_fold_cross_validation

### Whole canton

In [None]:
# Define the periods
periods = ['first_period', 'second_period', 'third_period', 'fourth_period', 'fifth_period']

# One copy of the covariate table per period, tagged with the period name
gdf_covariates_periods = list(
    map(lambda period_name: gdf_covariates.assign(period=period_name), periods)
)

# Stack the per-period copies into a single long table
gdf_covariates_period = pd.concat(gdf_covariates_periods, ignore_index=True)


In [None]:
# Per-period survival table written earlier by save_combined_data
whole_survival_period = pd.read_parquet(mstdbscan_folder.joinpath('combined_survival.parquet'))

In [None]:
# Per hectare (and per period): total time spent in a cluster, and the event
# flag of the last episode
df = (
    whole_survival
    .groupby('RELI')
    .agg({'Persistence_survival': 'sum', 'status': 'last'})
    .reset_index()
)
df_period = (
    whole_survival_period
    .groupby(['RELI', 'period'])
    .agg({'Persistence_survival': 'sum', 'status': 'last'})
    .reset_index()
)

In [None]:
# Right joins keep every hectare of the covariate table, clustered or not
df_covariates = df.merge(gdf_covariates, on='RELI', how='right')
df_period_covariates = df_period.merge(gdf_covariates_period, on=['RELI', 'period'], how='right')

In [None]:
# Hectares that never belonged to a cluster: duration 0 and event flag True
df_covariates = df_covariates.assign(
    Persistence_survival=df_covariates['Persistence_survival'].fillna(0),
    status=df_covariates['status'].fillna(True),
)

df_period_covariates = df_period_covariates.assign(
    Persistence_survival=df_period_covariates['Persistence_survival'].fillna(0),
    status=df_period_covariates['status'].fillna(True),
)

In [None]:
# Using Cox Proportional Hazards model
cph = CoxPHFitter()  # instantiate the fitter
# Fit on hectares with a known standardised SES index
cph.fit(
    df_covariates.loc[
        df_covariates['index_socio_stand'].notna(),
        ['Persistence_survival', 'index_socio_stand', 'status'],
    ],
    duration_col='Persistence_survival',
    event_col='status',
    robust=True,
)
cph.print_summary(2)  # have a look at the significance of the features

### Lausanne urban area

In [None]:
# Per-period survival table of the Lausanne subset written earlier by save_combined_data
whole_survival_period_lausanne = pd.read_parquet(mstdbscan_folder.joinpath('combined_survival_lausanne.parquet'))

In [None]:
# Per hectare (and per period): total time spent in a cluster, and the event
# flag of the last episode
df_lausanne = (
    whole_survival_lausanne
    .groupby('RELI')
    .agg({'Persistence_survival': 'sum', 'status': 'last'})
    .reset_index()
)
df_period_lausanne = (
    whole_survival_period_lausanne
    .groupby(['RELI', 'period'])
    .agg({'Persistence_survival': 'sum', 'status': 'last'})
    .reset_index()
)

In [None]:
# Right joins keep every Lausanne hectare of the covariate table, clustered or not
df_lausanne_covariates = df_lausanne.merge(
    gdf_covariates[gdf_covariates.RELI.isin(gdf_lausanne.RELI)],
    on='RELI',
    how='right',
)
df_period_lausanne_covariates = df_period_lausanne.merge(
    gdf_covariates_period[gdf_covariates_period.RELI.isin(gdf_lausanne.RELI)],
    on=['RELI', 'period'],
    how='right',
)

In [None]:
# Hectares that never belonged to a cluster: duration 0 and event flag True
df_lausanne_covariates = df_lausanne_covariates.assign(
    Persistence_survival=df_lausanne_covariates['Persistence_survival'].fillna(0),
    status=df_lausanne_covariates['status'].fillna(True),
)

df_period_lausanne_covariates = df_period_lausanne_covariates.assign(
    Persistence_survival=df_period_lausanne_covariates['Persistence_survival'].fillna(0),
    status=df_period_lausanne_covariates['status'].fillna(True),
)

In [None]:
# Sanity check: missing values left in the three columns used by the Cox model
df_lausanne_covariates.loc[
    df_lausanne_covariates['index_socio_stand'].notna(),
    ['Persistence_survival', 'index_socio_stand', 'status'],
].isna().sum()

In [None]:
# Using Cox Proportional Hazards model
cph = CoxPHFitter()  # instantiate the fitter
# Fit on Lausanne hectares with a known standardised SES index
cph.fit(
    df_lausanne_covariates.loc[
        df_lausanne_covariates['index_socio_stand'].notna(),
        ['Persistence_survival', 'index_socio_stand', 'status'],
    ],
    duration_col='Persistence_survival',
    event_col='status',
    robust=True,
)
cph.print_summary(2)  # have a look at the significance of the features

## Save processed datasets

In [None]:
# Wrap the four merged tables as GeoDataFrames in EPSG:2056 so they can be
# written as GeoParquet
df_covariates, df_period_covariates, df_lausanne_covariates, df_period_lausanne_covariates = (
    gpd.GeoDataFrame(frame, crs=2056)
    for frame in (
        df_covariates,
        df_period_covariates,
        df_lausanne_covariates,
        df_period_lausanne_covariates,
    )
)

In [None]:
# Canton-wide tables
df_covariates.to_parquet(data_folder.joinpath('df_survival_covariates.geoparquet'))
df_period_covariates.to_parquet(data_folder.joinpath('df_survival_period_covariates.geoparquet'))

# Lausanne tables
df_lausanne_covariates.to_parquet(data_folder.joinpath('df_survival_covariates_lausanne.geoparquet'))
df_period_lausanne_covariates.to_parquet(data_folder.joinpath('df_survival_period_covariates_lausanne.geoparquet'))