In [None]:
# Functions for 2_visualize_analyze_data.ipynb

def plot_sum_timeline(df, variables, period, xy_labels, plot_title, label_colours_dict, save_name):
    '''
    Plot the per-period sum of one or more columns of df over time, one subplot per column.

    - df: dataframe holding the period columns ('year', plus 'month' when period is 'month')
        and the columns named in variables
    - variables: col name (str) or list of col names (str)
    - period: 'month' (sum per year and month), 'year' (sum per year), or a list of
        date-part column names understood by pd.to_datetime (eg ['year', 'month'])
    - xy_labels: (x_labels, y_labels); xy_labels[0][i] / xy_labels[1][i] label subplot i
    - plot_title: figure title
    - label_colours_dict: dict variable (col name): colour. The legend label is the last
        '_'-separated part of the col name, eg 'number_of_cyclist_injured' -> 'injured'
    - save_name: file name of the figure. It is written to os.path.join(results_dir, save_name)
        in the format save_fig_as (results_dir and save_fig_as are module-level globals)
    '''

    # Accept a single column name as well as a list of names
    if isinstance(variables, str):
        variables = [variables]
    else:
        variables = list(variables)

    # Normalise period to a list of grouping columns
    if period == 'month':
        period_cols = ['year', 'month']
    elif isinstance(period, str):
        period_cols = [period]
    else:
        period_cols = list(period)

    # Sum of the variables per period (uses the dataframe passed in, not a global)
    period_sum = df.groupby(period_cols)[variables].sum().reset_index()

    # pd.to_datetime needs year, month and day: default the missing parts to 1
    date_parts = period_sum[period_cols].copy()
    for part in ('month', 'day'):
        if part not in date_parts.columns:
            date_parts[part] = 1
    period_sum['date'] = pd.to_datetime(date_parts)

    # Plot; squeeze=False keeps axes a 2D array even when there is a single subplot
    plt.close()
    fig, axes = plt.subplots(len(variables), 1, figsize=(12, 4 * len(variables)),
                             sharex=False, squeeze=False)
    axes = axes.flatten()
    plt.suptitle(plot_title)
    for i, var in enumerate(variables):
        label = var.split('_')[-1]
        axes[i].plot(period_sum['date'], period_sum[var], marker='o',
                     color=label_colours_dict[var], label=label)
        axes[i].set_xlabel(xy_labels[0][i])
        axes[i].set_ylabel(xy_labels[1][i])
        axes[i].tick_params(axis='x', rotation=45)
        axes[i].grid()
        axes[i].legend()
    plt.tight_layout()  # once, after all subplots are populated

    fig.savefig(os.path.join(results_dir, save_name), format=save_fig_as)  # save figure
    plt.show()
    
def plot_combined_bike_metrics(df):
    '''
    Plot two monthly percentages on one figure and save it:
    - collisions involving a bike, as % of all collisions
    - cyclists injured or killed in bike collisions, as % of all collisions

    - df: collisions dataframe with columns 'year', 'month', 'bike_involved'
        (rows equal to 'bike' count as bike collisions), 'number_of_cyclists_injured'
        and 'number_of_cyclists_killed'

    Side effect: adds the columns 'is_bike_collision' and 'crash_month' to df.
    The figure is written to results_dir as 'bike_collisions.<save_fig_as>'
    (results_dir and save_fig_as are module-level globals).
    '''
    month_keys = ['year', 'month']
    is_bike = df['bike_involved'].eq('bike')

    # Columns written onto the caller's dataframe
    # NOTE(review): kept in case later notebook cells read them - confirm
    df['is_bike_collision'] = is_bike
    df['crash_month'] = pd.to_datetime(df[month_keys].assign(day=1))

    # Explicit copy so the new column is not set on a slice of df
    bike_rows = df.loc[is_bike].copy()
    bike_rows['cyclist_hurt'] = (
        bike_rows['number_of_cyclists_injured'] + bike_rows['number_of_cyclists_killed']
    )

    # Monthly totals: all collisions, bike collisions, cyclists injured or killed
    total_collisions = df.groupby(month_keys).size().reset_index(name='total_collisions')
    bike_collisions = bike_rows.groupby(month_keys).size().reset_index(name='bike_collisions')
    cyclist_hurt = (
        bike_rows.groupby(month_keys)['cyclist_hurt']
        .sum()
        .reset_index(name='total_cyclist_hurt')
    )

    # Left merges keep months without any bike collision; those get 0
    monthly = (
        total_collisions
        .merge(bike_collisions, on=month_keys, how='left')
        .merge(cyclist_hurt, on=month_keys, how='left')
        .fillna(0)
    )
    monthly['bike_collision_percentage'] = monthly['bike_collisions'] / monthly['total_collisions'] * 100
    monthly['cyclist_hurt_percentage'] = monthly['total_cyclist_hurt'] / monthly['total_collisions'] * 100
    monthly['date'] = pd.to_datetime(monthly[month_keys].assign(day=1))  # first day of each month

    plt.close()
    fig = plt.figure(figsize=(12, 6))

    # % collisions involving bike
    plt.plot(monthly['date'], monthly['bike_collision_percentage'], marker='o', color='blue',
             label='Percentage of bike collisions (% of total collisions)')
    # % cyclist injuries/fatalities from total collisions
    plt.plot(monthly['date'], monthly['cyclist_hurt_percentage'], marker='o', color='orange',
             label='Cyclist injury/fatality (% of total collisions)')

    plt.title('Bike collisions')
    plt.xlabel('Date (year, month)')
    plt.ylabel('Percentage (%)')
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    # Save before showing, consistent with the other plotting functions
    fig.savefig(os.path.join(results_dir, 'bike_collisions.'+save_fig_as), format=save_fig_as)  # save figure
    plt.show()


def plot_injured_killed(df, category, iax=0):
    '''
    Draw a bar plot with overlaid individual points of 'count_sum' for pedestrians
    and cyclists, for one category, and annotate it with a p-value.

    - df: dataframe with columns 'type', 'group' and 'count_sum'
    - category: value of 'type' to plot; also used in the title and the printed line
    - iax: index into the global axs selecting the axes to draw on

    Relies on the globals axs, colour_palette, sign, not_sign and on statistical_test.
    NOTE(review): plt.close() is called before drawing on the existing global axs -
    confirm this is intended.
    '''
    plt.close()
    target_ax = axs[iax]
    selected = df[df['type'] == category]
    group_order = ['pedestrians','cyclists']

    sns.barplot(data=selected, x='group', y='count_sum', ax=target_ax,
                palette=colour_palette, alpha=0.7, order=group_order)
    sns.stripplot(data=selected, x='group', y='count_sum', ax=target_ax, color='black',
                  dodge=True, size=4, jitter=True, edgecolor='gray', linewidth=0.5,
                  legend=False, order=group_order)
    target_ax.set_title(f'Average number of people {category}')
    target_ax.set_ylabel('Avg. number of people/year (#)')

    # Paired comparison of cyclists vs pedestrians, added to the plot as text
    samples = {
        "cyclists": selected.loc[selected['group'] == 'cyclists', 'count_sum'],
        "pedestrians": selected.loc[selected['group'] == 'pedestrians', 'count_sum'],
    }
    stats_ = statistical_test(samples, pairing='paired')
    p_value = stats_["p-value"].item()
    text_colour = color_p(p_value, sign, not_sign)  # colour depends on whether p < 0.05
    print(f'{stats_.test.item()} for {category}: p-value = {p_value:.6f}')
    target_ax.annotate(f'p-value={p_value:.4f}', xy=(0.5, 0.9), xycoords='axes fraction',
                       ha='center', va='bottom', fontsize=10, xytext=(0, 10),
                       textcoords='offset points', color=text_colour)
    plt.show()
    
def ddf_plot_grouped_timeline(ddf, variable, group_by, xy_labels, plot_title, save_name):
    '''
    Plot, per group and per month, the count (top) and the sum (bottom) of a variable.

    - ddf: lazy dataframe exposing .compute(), with a datetime column 'start_datetime'
        plus the columns variable and group_by
    - variable: col name (str) that is counted and summed
    - group_by: col name (str); one line is drawn per distinct value
    - xy_labels: (x_labels, y_labels); index 0 labels the count plot, index 1 the sum plot
    - plot_title: figure title
    - save_name: path the figure is saved to, always in jpg format
    '''
    # Drop rows where variable or group_by is NaN, then load into memory.
    # No global sort: the groupby below already returns months in ascending order.
    rides = ddf.dropna(subset=[variable, group_by]).compute()
    rides['year'] = rides['start_datetime'].dt.year.astype('int32')
    rides['month'] = rides['start_datetime'].dt.month.astype('int32')

    # Count and sum per (year, month, group) in a single pass
    monthly = (
        rides.groupby(['year', 'month', group_by])[variable]
        .agg(rental_count='count', rental_total='sum')
        .reset_index()
    )
    monthly['date'] = pd.to_datetime(monthly[['year', 'month']].assign(day=1))

    def _scientific_formatter():
        # A formatter instance binds to one axis, so every axis gets its own
        formatter = ScalarFormatter(useMathText=True)
        formatter.set_scientific(True)
        formatter.set_powerlimits((0, 4))  # adjust power limits to enforce scientific notation
        return formatter

    plt.close()
    fig, axes = plt.subplots(2, 1, figsize=(12, 12), sharex=True)
    axes = axes.flatten()
    plt.suptitle(plot_title)

    for group in monthly[group_by].unique():
        group_rows = monthly[monthly[group_by] == group]
        # Plot the count per group
        axes[0].plot(group_rows['date'], group_rows['rental_count'], marker='o',
                     linewidth=3, markersize=5, label=f'{group}')
        # Plot the sum per group
        axes[1].plot(group_rows['date'], group_rows['rental_total'], marker='o',
                     linewidth=3, markersize=5, label=f'{group}')

    for idx, panel_title in enumerate(('Rental count', 'Rental duration')):
        axes[idx].set_xlabel(xy_labels[0][idx])
        axes[idx].set_ylabel(xy_labels[1][idx])
        axes[idx].set_title(panel_title)
        axes[idx].yaxis.set_major_formatter(_scientific_formatter())  # scientific y-values for readability
        axes[idx].tick_params(axis='x', rotation=45)
        axes[idx].grid()
        axes[idx].legend(title='Age group')
    plt.tight_layout()

    fig.savefig(save_name, format='jpg')  # save figure
    plt.show()

def plot_grouped_ride_duration_bars(df, variable, group_by, xy_labels, plot_title, save_name, filter_years=None):
    '''
    Bar plot of the sum of a variable per group.

    - df: lazy dataframe exposing .compute(), with a datetime column 'start_datetime'
        plus the columns variable and group_by
    - variable: col name (str) that is summed per group
    - group_by: col name (str) defining the bars
    - xy_labels: (x_label, y_label)
    - plot_title: figure title; ' (first-last)' is appended when filter_years is given
    - save_name: path the figure is saved to, in jpg format; '_first-last' is appended
        when filter_years is given
    - filter_years: optional [first_year, last_year] (inclusive) to restrict the rows
    '''

    if filter_years:  # if select only specific years
        first_year, last_year = filter_years[0], filter_years[1]
        df = df.loc[df['start_datetime'].dt.year.between(first_year, last_year)]
        plot_title = plot_title + f' ({first_year}-{last_year})'  # update title
        save_name = save_name + f'_{first_year}-{last_year}'  # update save name

    df_yearly = df.groupby(group_by).agg({
        variable: 'sum'
    }).reset_index().compute()

    # Number of bars actually plotted, taken from the (filtered) data passed in
    # rather than from a global dataframe; at least 1 to keep the aspect positive
    n_groups = max(len(df_yearly), 1)

    plt.close()
    sns.set_theme(style='whitegrid')
    g = sns.catplot(height=4, aspect=0.34 * n_groups,
        data=df_yearly, kind='bar',
        x=group_by, y=variable, color='navy',
    )
    g.despine(left=True)
    g.set_axis_labels(xy_labels[0], xy_labels[1])
    plt.title(plot_title)

    g.savefig(save_name, format='jpg')  # save figure
    plt.show()

def plot_grouped_ride_count_bars(df, variable, group_by, xy_labels, plot_title, save_name, filter_years=None):
    '''
    Bar plot of the number of rows per group.

    - df: lazy dataframe exposing .compute(), with a datetime column 'start_datetime'
        plus the column group_by
    - variable: not used for the count; kept so the signature matches
        plot_grouped_ride_duration_bars
    - group_by: col name (str) defining the bars
    - xy_labels: (x_label, y_label)
    - plot_title: figure title; ' (first-last)' is appended when filter_years is given
    - save_name: path the figure is saved to, in jpg format; '_first-last' is appended
        when filter_years is given
    - filter_years: optional [first_year, last_year] (inclusive) to restrict the rows
    '''

    if filter_years:  # if select only specific years
        first_year, last_year = filter_years[0], filter_years[1]
        df = df.loc[df['start_datetime'].dt.year.between(first_year, last_year)]
        plot_title = plot_title + f' ({first_year}-{last_year})'  # update title
        save_name = save_name + f'_{first_year}-{last_year}'  # update save name

    group_sizes = df.groupby(group_by).size().reset_index()  # rows per group
    group_sizes.columns = [group_by, 'count']  # rename columns
    df_yearly = group_sizes.compute()

    # Number of bars actually plotted, taken from the (filtered) data passed in
    # rather than from a global dataframe; at least 1 to keep the aspect positive
    n_groups = max(len(df_yearly), 1)

    plt.close()
    sns.set_theme(style='whitegrid')
    g = sns.catplot(height=4, aspect=0.37 * n_groups,
        data=df_yearly, kind='bar',
        x=group_by, y='count', color='lightblue',
    )
    g.despine(left=True)
    g.set_axis_labels(xy_labels[0], xy_labels[1])
    plt.title(plot_title)

    g.savefig(save_name, format='jpg')  # save figure
    plt.show()
    
def annotate_pvals(plot, p_values, ax=None):
    '''
    Write each p-value as 'p=0.123' just above y=0 at x positions 0, 1, 2, ...

    - plot: object whose .ax attribute is used when ax is not given
    - p_values: iterable of p-values; the i-th one is placed at x=i
    - ax: axes to annotate (optional)

    The text colour comes from color_p with its default colours.
    '''
    target_ax = plot.ax if ax is None else ax
    for x_position, pval in enumerate(p_values):
        target_ax.annotate(
            f'p={pval:.3f}', (x_position, 0), xytext=(0, 5), textcoords='offset points',
            ha='center', va='bottom', fontsize=10, color=color_p(pval),
        )
        
def color_p(pval, sign="crimson", not_sign="slategrey"):
    '''
    Determine the colour of a p-value depicted in a figure.

    - pval: p-value to classify
    - sign: colour returned when pval < 0.05
    - not_sign: colour returned otherwise (this includes pval == 0.05)
    '''
    return sign if pval < 0.05 else not_sign

def get_ddf_sum_and_count(df, variable, group_by):
    '''
    Return the per-group sum and the per-group count of a variable.

    - df: dataframe to aggregate; lazy dataframes exposing .compute() are materialised,
        plain pandas dataframes are used as they are
    - variable: col name (str) that is summed / counted
    - group_by: col name (str) to group on; both results are sorted by it

    Returns (df_sum, df_count), each with the columns group_by and variable.
    '''

    def _aggregate(how):
        # Aggregate the dataframe passed in (not a global) on the requested column
        grouped = df.groupby(group_by).agg({variable: how}).reset_index()
        if hasattr(grouped, 'compute'):
            grouped = grouped.compute()
        # Sort the small in-memory result rather than the lazy frame
        return grouped.sort_values(by=group_by)

    df_sum = _aggregate('sum')
    df_count = _aggregate('count')

    return df_sum, df_count


def get_df_count(df, variable, group_by):
    '''
    Count the non-null values of a variable per group in a pandas dataframe.

    - df: pandas dataframe
    - variable: col name (str) whose non-null values are counted
    - group_by: col name (str) to group on; the result is sorted by it

    Returns a dataframe with the columns group_by and variable (holding the counts).
    '''
    counts_per_group = df.groupby(group_by).agg({variable: 'count'})
    flat_counts = counts_per_group.reset_index()
    return flat_counts.sort_values(by=group_by)  # pandas result, nothing to compute

def assign_season(df, time_datetime):
    '''
    Return a copy of df, without the rows whose time_datetime is missing, with a
    'season' column derived from the date.

    - df: pandas dataframe
    - time_datetime: name (str) of the datetime column the season is based on

    Date ranges used (inclusive):
    - spring: 21 Mar - 20 Jun
    - summer: 21 Jun - 22 Sep
    - autumn: 23 Sep - 20 Dec
    - winter: 21 Dec - 20 Mar
    '''

    df = df.dropna(subset=[time_datetime])  # remove rows where the datetime is NaN

    # Encode month and day as one integer (eg 21 Mar -> 321) so every season
    # is a single range check
    stamps = df[time_datetime].dt
    month_day = stamps.month * 100 + stamps.day

    conditions = [
        month_day.between(321, 620).to_numpy(),
        month_day.between(621, 922).to_numpy(),
        month_day.between(923, 1220).to_numpy(),
    ]
    # Everything outside the three ranges (21 Dec - 20 Mar) is winter
    seasons = np.select(conditions, ['spring', 'summer', 'autumn'], default='winter')

    # assign builds the string column directly on a new frame: no float NaN
    # placeholder column and no write onto the dropna result
    return df.assign(season=seasons)
