In [4]:
import plotly.graph_objects as go
import plotly.io as pio

# Font stack applied to every text element of the template.
_FONT_FAMILY = 'Open Sans, sans-serif'


def _axis_style(anchor):
    """Return the axis styling shared by the x and y axes of the base template.

    Parameters:
        anchor (str): Axis id this axis is anchored to ('y' for xaxis, 'x' for yaxis).
    """
    return dict(
        anchor=anchor,
        showgrid=True,
        gridcolor='#888888',  # Darker grid lines
        tickfont=dict(size=24, family=_FONT_FAMILY),
        # `title.font` replaces the deprecated `titlefont` shorthand, which
        # newer Plotly releases no longer accept.
        title=dict(font=dict(size=26, family=_FONT_FAMILY)),
        linecolor='#333333',
        linewidth=2,  # Thickness of the axis line
    )


# Base Plotly template: cream background, Open Sans text, dark grid lines and a
# hand-picked colorway. Registered below as the default template.
base = go.layout.Template(
    layout=go.Layout(
        paper_bgcolor='#FFF5CC',
        plot_bgcolor='#FFF5CC',
        height=800,
        width=800 * 1.618,  # Golden-ratio aspect relative to the height
        xaxis=_axis_style('y'),
        yaxis=_axis_style('x'),
        font=dict(
            color='#333333',
            size=28,
            family=_FONT_FAMILY
        ),
        # Colorway chosen so that consecutive traces are easy to tell apart
        colorway=["#470945",  # H: Violet
                  "#E67E5A",  # H: Orange (Sienna)
                  "#297FB9",  # H: Blue (Steel)
                  "#163748",  # D: Charcoal
                  "#4F1787",  # H: Purple
                  "#EFE04E",  # H: Yellow (Maize)
                  "#214F70",  # D: Indigo
                  "#DF14AA",  # H: Pink (Cerise)
                  "#100B1A",  # D: Black
                  "#12C4CF",  # H: Teal
                  "#14193D",  # D: Space
                  "#CC5500"],  # H: Orange (Burnt)
        title=go.layout.Title(
            text='',
            font=dict(
                size=34,
                color='#333333',
                family=_FONT_FAMILY
            ),
            x=0.05,  # Left-align the title near the plot edge
        )
    ),
    data=dict(
        scatter=[
            go.Scatter(
                line=dict(width=3)  # Default line width for scatter traces
            )
        ]
    )
)

# Register the template and make it the default for every new figure
pio.templates['base'] = base
pio.templates.default = 'base'

In [5]:
import pandas as pd

# Raw input CSV (path relative to the notebook; change it if the file lives elsewhere)
file_name = "d_aus_suicide.csv"
data = pd.read_csv(file_name)

# Function to group age ranges
def group_age(age_group):
    if age_group in ["15-19"]:
        return "<20"
    elif age_group in ["20-24", "25-29"]:
        return "20s"
    elif age_group in ["30-34", "35-39"]:
        return "30s"
    elif age_group in ["40-44", "45-49"]:
        return "40s"
    elif age_group in ["50-54", "55-59"]:
        return "50s"
    elif age_group in ["60-64", "65-69"]:
        return "60s"
    elif age_group in ["70-74", "75-79"]:
        return "70s"
    elif age_group in ["80-84", "85+"]:
        return "80+"
    else:
        return "Unknown"

import warnings

# Add a decade-level age bucket next to the original five-year band
data["age_group_grouped"] = data["age_group"].apply(group_age)

# Surface any bands that group_age does not recognise instead of silently
# lumping them into "Unknown" (they would otherwise appear as an extra cohort).
unmapped_bands = sorted(
    data.loc[data["age_group_grouped"] == "Unknown", "age_group"].astype(str).unique()
)
if unmapped_bands:
    warnings.warn(f"Age bands mapped to 'Unknown': {unmapped_bands}")

# Save the augmented dataset; the plotting cells read this file back in
output_file_name = "d_aus_suicide_grouped.csv"
data.to_csv(output_file_name, index=False)

# Preview the first rows (rich display instead of plain-text print)
data.head()


     sex age_group  year  suicide_rate age_group_grouped
0  Males     15-19  1907           3.3               <20
1  Males     20-24  1907           9.1               20s
2  Males     25-29  1907          18.3               20s
3  Males     30-34  1907          18.7               30s
4  Males     35-39  1907          24.2               30s


In [101]:
import pandas as pd
import plotly.graph_objects as go

def _flatten_label(col):
    """Turn a pivot-table column key into a flat string label such as "Males_<20"."""
    # With two grouping columns pandas yields tuple keys; with a single grouping
    # column it yields plain scalars, which must not be joined character by character.
    if isinstance(col, tuple):
        return "_".join(map(str, col))
    return str(col)


def _normalized_trends(data, min_year, max_year, age_groups, group_columns, rolling_window):
    """Return the mean suicide rate per year and group, rebased so that min_year == 100.

    Parameters:
        data (DataFrame): Rows with at least "year", "suicide_rate" and the
            columns named in group_columns (plus "age_group_grouped" when
            age_groups is given).
        min_year (int): First year kept; also the normalisation base year.
        max_year (int): Last year kept (inclusive).
        age_groups (list or None): Values of "age_group_grouped" to keep; None keeps all.
        group_columns (list of str): Columns to split lines by; empty means one "Overall" line.
        rolling_window (int or None): Trailing rolling-mean window; falsy disables smoothing.

    Returns:
        DataFrame indexed by year, one column per flat string group label.

    Raises:
        ValueError: If no rows fall inside the filter, or min_year has no data.
    """
    filtered = data[data["year"].between(min_year, max_year)]
    if age_groups is not None:
        filtered = filtered[filtered["age_group_grouped"].isin(age_groups)]
    if filtered.empty:
        raise ValueError(
            f"No rows between {min_year} and {max_year} for the requested age groups."
        )

    if group_columns:
        pivot = filtered.pivot_table(
            index="year",
            columns=group_columns,
            values="suicide_rate",
            aggfunc="mean",
        )
        pivot.columns = [_flatten_label(col) for col in pivot.columns]
    else:
        pivot = filtered.groupby("year")["suicide_rate"].mean().to_frame("Overall")

    if min_year not in pivot.index:
        raise ValueError(
            f"min_year={min_year} has no data, so it cannot be used as the normalisation base."
        )

    # Rebase every line so its value in min_year equals 100
    normalized = pivot.div(pivot.loc[min_year]).mul(100)

    if rolling_window:
        # Trailing mean; min_periods=1 keeps the first years instead of producing NaN
        normalized = normalized.rolling(window=rolling_window, min_periods=1).mean()
    return normalized


def _line_color(label, split_by_gender):
    """Pick the line colour for a group label: per-sex colour when split by gender, else neutral."""
    gender_colors = {
        "Males": '#297fb9',
        "Females": "#4f1787",
    }
    neutral_color = "#163748"  # Charcoal from the base template colorway
    if split_by_gender:
        # Labels start with the sex value, e.g. "Males_<20" or just "Females"
        return gender_colors.get(label.split("_", 1)[0], neutral_color)
    return neutral_color


def plot_normalized_suicide_trends(file_name, min_year, max_year, age_groups=None, split_by_gender=True, split_by_age_group=True, template="base", rolling_window=None, highlight_lines=None, show_legend=True, width=1920, height=1080):
    """
    Plots normalized suicide rate trends with options to group or split by gender and age groups, apply a rolling window average, and highlight specific lines.

    Each line is the mean suicide rate per year for its group, rebased so that
    its min_year value equals 100. Lines not listed in highlight_lines are drawn
    thin and faded; when highlight_lines is None or empty, every line is faded.

    Parameters:
        file_name (str): Path to the CSV file (needs "year", "suicide_rate", "sex", "age_group_grouped").
        min_year (int): Minimum year for filtering data; also the normalisation base year.
        max_year (int): Maximum year for filtering data (inclusive).
        age_groups (list or None): List of age groups to include (e.g., ["<20", "20s"]). If None, include all age groups.
        split_by_gender (bool): Whether to split trends by gender.
        split_by_age_group (bool): Whether to split trends by age groups.
        template (str): Plotly template to apply (default: "base").
        rolling_window (int or None): Size of the rolling window for smoothing. If None, no smoothing is applied.
        highlight_lines (list of str or None): Labels to highlight (e.g., ["Males_<20", "Females_20s"]).
            Labels join the split values with "_"; with a single split they are just
            that value (e.g., "Males" or "<20"), and with no split the label is "Overall".
        show_legend (bool): Whether to include the legend in the plot.
        width (int): Width of the plot (default: 1920 for widescreen).
        height (int): Height of the plot (default: 1080 for widescreen).

    Raises:
        ValueError: If no rows match the filters or min_year has no data to normalise against.
    """
    data = pd.read_csv(file_name)

    # Columns that define one line each
    group_columns = []
    if split_by_gender:
        group_columns.append("sex")
    if split_by_age_group:
        group_columns.append("age_group_grouped")

    normalized_data = _normalized_trends(
        data, min_year, max_year, age_groups, group_columns, rolling_window
    )

    highlighted = set(highlight_lines or ())

    fig = go.Figure()
    for label in normalized_data.columns:
        series = normalized_data[label]
        is_highlighted = label in highlighted

        if is_highlighted:
            # Wider white line underneath acts as a border around the highlighted line
            fig.add_trace(
                go.Scatter(
                    x=series.index,
                    y=series,
                    mode="lines",
                    line=dict(color="white", width=14),
                    showlegend=False,   # Border is decoration, not a legend entry
                    hoverinfo="skip",   # Avoid an unnamed "trace N" tooltip
                )
            )

        fig.add_trace(
            go.Scatter(
                x=series.index,
                y=series,
                mode="lines",
                name=label.replace("_", " "),
                line=dict(
                    color=_line_color(label, split_by_gender),
                    width=10 if is_highlighted else 4,
                ),
                opacity=1.0 if is_highlighted else 0.15,
            )
        )

    # Layout is applied once, after all traces have been added
    fig.update_layout(
        title=dict(
            text="Suicide Rate Australia (Normalised)",
            font=dict(size=50),
        ),
        xaxis_title=None,
        yaxis_title=None,
        template=template,
        legend_title="Group",
        showlegend=show_legend,
        width=width,
        height=height,
        xaxis=dict(tickfont=dict(size=35)),
        yaxis=dict(tickfont=dict(size=35)),
    )

    fig.show()


In [97]:
# Baseline view: no highlight_lines, so every cohort is drawn faded
plot_normalized_suicide_trends(
    file_name="d_aus_suicide_grouped.csv",
    min_year=2006,
    max_year=2020,
    split_by_gender=True,
    split_by_age_group=True,
    template="base",
    rolling_window=5,
    show_legend=False
)


In [102]:
# Highlight the 70s cohort for both sexes
plot_normalized_suicide_trends(
    file_name="d_aus_suicide_grouped.csv",
    min_year=2006,
    max_year=2020,
    split_by_gender=True,
    split_by_age_group=True,
    template="base",
    rolling_window=5,
    highlight_lines=["Males_70s", "Females_70s"],
    show_legend=False
)

In [103]:
# Highlight three female cohorts: under 20, 20s and 40s
plot_normalized_suicide_trends(
    file_name="d_aus_suicide_grouped.csv",
    min_year=2006,
    max_year=2020,
    split_by_gender=True,
    split_by_age_group=True,
    template="base",
    rolling_window=5,
    highlight_lines=["Females_<20", "Females_20s", "Females_40s"],
    show_legend=False
)
