# In [None]:
print("Pandemic response patterns")

# In [4]:
def extract_country_features(country_name):
    """Summarise one location's rows of the global ``df`` into clustering features.

    Parameters
    ----------
    country_name : str
        Value matched against ``df['location']``.

    Returns
    -------
    dict or None
        ``None`` when the location has fewer than 30 rows. Otherwise a dict
        with the keys ``country``, ``cases_per_million``,
        ``deaths_per_million``, ``case_fatality_rate``, ``stringency_index``,
        ``vaccination_rate``, ``population_density``, ``median_age``,
        ``peak_intensity`` and ``wave_count``. A feature whose source
        column(s) are absent from ``df`` is ``np.nan``.
    """
    # The slice is only read from, so no defensive copy is required.
    country_data = df[df['location'] == country_name]

    if len(country_data) < 30:
        return None

    present = set(country_data.columns)

    def _stat(column, method):
        """Call the Series method named *method* on *column*, or NaN if absent."""
        if column not in present:
            return np.nan
        return getattr(country_data[column], method)()

    def _first(column):
        """Return the first row's value of *column*, or NaN if absent."""
        if column not in present:
            return np.nan
        return country_data[column].iloc[0]

    # Case fatality rate from the largest cumulative totals. Guarded like
    # every other feature so a missing column yields NaN, not a KeyError.
    cfr = np.nan
    if {'total_cases', 'total_deaths'} <= present:
        peak_total_cases = country_data['total_cases'].max()
        if peak_total_cases > 0:
            cfr = country_data['total_deaths'].max() / peak_total_cases * 100

    # Count peaks in the smoothed daily-case curve. A peak must rise at
    # least 10% of the series maximum and sit 14+ samples from its neighbour.
    wave_count = np.nan
    if 'cases_7day_avg' in present:
        smooth_cases = country_data['cases_7day_avg'].fillna(0).values
        if len(smooth_cases) > 0 and np.max(smooth_cases) > 0:
            peaks, _ = find_peaks(
                smooth_cases,
                prominence=np.max(smooth_cases) * 0.1,
                distance=14
            )
            wave_count = len(peaks)

    return {
        'country': country_name,
        'cases_per_million': _stat('cases_per_million', 'max'),
        'deaths_per_million': _stat('deaths_per_million', 'max'),
        'case_fatality_rate': cfr,
        'stringency_index': _stat('stringency_index', 'mean'),
        'vaccination_rate': _stat('people_fully_vaccinated_per_hundred', 'max'),
        'population_density': _first('population_density'),
        'median_age': _first('median_age'),
        'peak_intensity': _stat('new_cases_per_million', 'max'),
        'wave_count': wave_count
    }


print("Extracting features for clustering analysis...")

# Walk every distinct location and keep the feature dicts of those that
# (a) are not one of the listed aggregate names and (b) had enough rows for
# extract_country_features to return something.
countries = df['location'].unique()
features_list = []

for country in countries:
    is_aggregate = country in ('World', 'International', 'Europe', 'Asia',
                               'North America', 'South America', 'Africa', 'Oceania')
    if not is_aggregate:
        features = extract_country_features(country)
        if features is not None:
            features_list.append(features)

# One row per retained country.
features_df = pd.DataFrame(features_list)

def _tercile_label(stats, cluster_id, column, low, middle, high):
    """Classify one cluster's mean for *column* against the other clusters.

    Returns *low* when the cluster's value is below the 33rd percentile of the
    per-cluster means, *high* when above the 67th, and *middle* otherwise.
    """
    value = stats.loc[cluster_id, column]
    if value < stats[column].quantile(0.33):
        return low
    if value > stats[column].quantile(0.67):
        return high
    return middle


# K-means clustering of countries on their pandemic features, followed by a
# textual interpretation of each cluster and two scatter plots.
if len(features_df) > 10:
    print(f"Collected features for {len(features_df)} countries")

    cluster_features = ['cases_per_million', 'deaths_per_million', 'case_fatality_rate',
                        'stringency_index', 'peak_intensity', 'wave_count']

    # Keep only features that are populated for at least 70% of the countries.
    available_features = [col for col in cluster_features if features_df[col].notna().mean() >= 0.7]

    if len(available_features) >= 3:
        print(f"Using features for clustering: {', '.join(available_features)}")

        # Explicit copy: a 'cluster' column is added below, and writing to a
        # dropna() result would otherwise be a chained assignment.
        cluster_data = features_df.dropna(subset=available_features).copy()

        if len(cluster_data) >= 30:
            print(f"Clustering {len(cluster_data)} countries with complete data")

            scaler = StandardScaler()
            scaled_features = scaler.fit_transform(cluster_data[available_features])

            # Try k = 2..max_clusters and keep the k with the best silhouette.
            silhouette_scores = []
            max_clusters = min(8, len(cluster_data) // 5)

            for k in range(2, max_clusters + 1):
                kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
                cluster_labels = kmeans.fit_predict(scaled_features)
                silhouette_avg = silhouette_score(scaled_features, cluster_labels)
                silhouette_scores.append(silhouette_avg)

            # +2 because the first score in the list belongs to k == 2.
            optimal_k = silhouette_scores.index(max(silhouette_scores)) + 2
            print(f"Optimal number of clusters identified: {optimal_k}")

            kmeans = KMeans(n_clusters=optimal_k, random_state=42, n_init=10)
            cluster_data['cluster'] = kmeans.fit_predict(scaled_features)

            cluster_stats = cluster_data.groupby('cluster').agg({
                'country': 'count',
                'cases_per_million': 'mean',
                'deaths_per_million': 'mean',
                'case_fatality_rate': 'mean',
                'stringency_index': 'mean' if 'stringency_index' in available_features else 'count',
                'peak_intensity': 'mean' if 'peak_intensity' in available_features else 'count',
                'wave_count': 'mean' if 'wave_count' in available_features else 'count'
            }).rename(columns={'country': 'count'})

            print("\nCluster analysis results:")
            print(cluster_stats.round(2))

            print("\nCluster interpretation:")
            for cluster_id in range(optimal_k):
                cluster_countries = cluster_data[cluster_data['cluster'] == cluster_id]['country'].tolist()
                sample_countries = ', '.join(cluster_countries[:5])
                if len(cluster_countries) > 5:
                    sample_countries += f" and {len(cluster_countries) - 5} more"

                print(f"\nCluster {cluster_id+1} ({len(cluster_countries)} countries, including {sample_countries}):")

                # Cases, deaths and CFR are always described; stringency and
                # wave count only when they were actually used for clustering.
                characteristics = [
                    _tercile_label(cluster_stats, cluster_id, 'cases_per_million',
                                   "LOW case rates", "MODERATE case rates", "HIGH case rates"),
                    _tercile_label(cluster_stats, cluster_id, 'deaths_per_million',
                                   "LOW death rates", "MODERATE death rates", "HIGH death rates"),
                    _tercile_label(cluster_stats, cluster_id, 'case_fatality_rate',
                                   "LOW case fatality", "MODERATE case fatality", "HIGH case fatality"),
                ]

                if 'stringency_index' in available_features:
                    characteristics.append(
                        _tercile_label(cluster_stats, cluster_id, 'stringency_index',
                                       "LENIENT policies", "MODERATE policies", "STRICT policies"))

                if 'wave_count' in available_features:
                    characteristics.append(
                        _tercile_label(cluster_stats, cluster_id, 'wave_count',
                                       "FEW waves", "AVERAGE number of waves", "MANY waves"))

                print(f"  Characteristics: {', '.join(characteristics)}")

            from sklearn.decomposition import PCA

            # Project the scaled features onto two components for plotting.
            pca = PCA(n_components=2)
            principal_components = pca.fit_transform(scaled_features)

            # Cluster ids are categories, not magnitudes: plot them as strings
            # with a discrete palette and a fixed legend order.
            cluster_order = [str(cluster_id) for cluster_id in range(optimal_k)]

            pca_df = pd.DataFrame(data=principal_components, columns=['PC1', 'PC2'])
            # cluster_data keeps features_df's index (with gaps after dropna)
            # while pca_df has a fresh 0..n-1 index, so copy positionally
            # instead of letting pandas align on index labels.
            pca_df['cluster'] = cluster_data['cluster'].astype(str).values
            pca_df['country'] = cluster_data['country'].values

            fig_clusters = px.scatter(
                pca_df,
                x='PC1',
                y='PC2',
                color='cluster',
                hover_name='country',
                title='Country Clustering Based on COVID-19 Response Patterns',
                labels={'PC1': 'Principal Component 1', 'PC2': 'Principal Component 2'},
                category_orders={'cluster': cluster_order},
                color_discrete_sequence=px.colors.qualitative.G10
            )
            fig_clusters.show()

            # 3D view in the raw (unscaled) space of the first three features;
            # at least three are guaranteed by the enclosing check.
            selected_features = available_features[:3]
            axis_titles = [name.replace('_', ' ').title() for name in selected_features]

            fig_3d = px.scatter_3d(
                cluster_data.assign(cluster=cluster_data['cluster'].astype(str)),
                x=selected_features[0],
                y=selected_features[1],
                z=selected_features[2],
                color='cluster',
                hover_name='country',
                title='Country Clustering in 3D Feature Space',
                labels=dict(zip(selected_features, axis_titles)),
                category_orders={'cluster': cluster_order},
                color_discrete_sequence=px.colors.qualitative.G10
            )

            fig_3d.update_layout(
                scene=dict(
                    xaxis_title=axis_titles[0],
                    yaxis_title=axis_titles[1],
                    zaxis_title=axis_titles[2]
                )
            )

            fig_3d.show()

        else:
            print(f"Not enough countries ({len(cluster_data)}) with complete data for clustering")
    else:
        print(f"Not enough features available for clustering (need at least 3, found {len(available_features)})")
else:
    print("Not enough countries with sufficient data for clustering analysis")


print("\n=== DEMOGRAPHIC IMPACT ANALYSIS ===")

# Correlate country-level demographic indicators with COVID-19 severity
# (deaths per million and, when present, case fatality rate).
demographic_columns = ['median_age', 'aged_65_older', 'aged_70_older', 'gdp_per_capita',
                       'extreme_poverty', 'cardiovasc_death_rate', 'diabetes_prevalence',
                       'life_expectancy', 'human_development_index']

available_demographics = [col for col in demographic_columns if col in df.columns]

if available_demographics:
    print(f"Available demographic indicators: {', '.join(available_demographics)}")

    # One row per location: groupby(...).last() takes the last non-null value
    # of every column in date order.
    latest_data = df.sort_values('date').groupby('location').last().reset_index()

    demo_analysis = latest_data[latest_data['total_cases'] >= 10000].copy()

    if len(demo_analysis) >= 30:  # Need reasonable sample size
        print(f"Analyzing demographic factors for {len(demo_analysis)} countries with significant outbreaks")

        if 'deaths_per_million' in demo_analysis.columns:
            correlation_results = []
            # Display label -> source column, so the plotting loop below does
            # not have to reverse-engineer column names from the labels.
            factor_columns = {}

            for demo_col in available_demographics:
                if demo_col in demo_analysis.columns:
                    factor_label = demo_col.replace('_', ' ').title()
                    factor_columns[factor_label] = demo_col

                    corr_deaths = demo_analysis[demo_col].corr(demo_analysis['deaths_per_million'])

                    if 'case_fatality_rate' in demo_analysis.columns:
                        corr_cfr = demo_analysis[demo_col].corr(demo_analysis['case_fatality_rate'])
                    else:
                        corr_cfr = np.nan

                    correlation_results.append({
                        'Demographic Factor': factor_label,
                        'Correlation with Deaths/Million': corr_deaths,
                        'Correlation with CFR': corr_cfr
                    })

            corr_df = pd.DataFrame(correlation_results)
            # Strongest relationships first, regardless of sign.
            corr_df = corr_df.sort_values('Correlation with Deaths/Million', key=abs, ascending=False)

            print("\nCorrelation of demographic factors with COVID-19 impact:")
            print(corr_df.round(3))

            top_demos = corr_df.iloc[:5]['Demographic Factor'].tolist()
            fig_corr = go.Figure()

            fig_corr.add_trace(go.Bar(
                x=corr_df['Demographic Factor'],
                y=corr_df['Correlation with Deaths/Million'],
                name='Correlation with Deaths/Million',
                marker_color='darkred'
            ))

            # Add the CFR bars only when at least one CFR correlation exists.
            if not corr_df['Correlation with CFR'].isna().all():
                fig_corr.add_trace(go.Bar(
                    x=corr_df['Demographic Factor'],
                    y=corr_df['Correlation with CFR'],
                    name='Correlation with Case Fatality Rate',
                    marker_color='darkblue'
                ))

            fig_corr.update_layout(
                title='Correlation of Demographic Factors with COVID-19 Severity',
                xaxis_title='Demographic Factor',
                yaxis_title='Correlation Coefficient',
                barmode='group',
                xaxis={'categoryorder':'total descending'}
            )

            fig_corr.show()

            # NOTE(review): the original loop body was cut off in the source
            # (the trailing `if` had no body, so the file did not parse). The
            # scatter below is a reconstruction of the presumably intended
            # per-factor plot -- confirm against the original notebook.
            for demo_factor in top_demos[:3]:
                demo_col = factor_columns[demo_factor]
                fig_demo = px.scatter(
                    demo_analysis,
                    x=demo_col,
                    y='deaths_per_million',
                    hover_name='location',
                    title=f"{demo_factor} vs COVID-19 Deaths per Million",
                    labels={demo_col: demo_factor, 'deaths_per_million': 'Deaths per Million'}
                )
                fig_demo.show()

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings
from scipy import stats
import statsmodels.api as sm
from statsmodels.tsa.seasonal import seasonal_decompose
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score

             
# Suppress every warning category for the rest of the session.
warnings.filterwarnings('ignore')

# Default look for matplotlib / seaborn figures.
plt.style.use('ggplot')
sns.set_palette("Set2")
print("Setting up environment and loading packages...")


# Detect Google Colab by attempting to import its helper package; IN_COLAB
# records the outcome for later cells.
try:
    from google.colab import output
    IN_COLAB = True
    print("Running in Google Colab environment")
    # Still inside the try: an ImportError raised here is also caught below
    # and resets IN_COLAB to False.
    from plotly.offline import init_notebook_mode
    init_notebook_mode(connected=True)
except ImportError:
    IN_COLAB = False
    print("Not running in Google Colab")

print("\n=== Loading and Preparing COVID-19 Data ===")
print("Loading COVID-19 data from Our World in Data...")

# Download the CSV and convert the 'date' column to datetimes.
OWID_CSV_URL = 'https://covid.ourworldindata.org/data/owid-covid-data.csv'
df = pd.read_csv(OWID_CSV_URL)
df['date'] = pd.to_datetime(df['date'])

# Quick overview of what was loaded.
print(f"Dataset shape: {df.shape}")
print(f"Date range: {df['date'].min():%Y-%m-%d} to {df['date'].max():%Y-%m-%d}")
print(f"Number of countries/regions: {df['location'].nunique()}")
print(f"Number of columns: {len(df.columns)}")

def _columns_containing(*keywords):
    """Return the df columns whose lower-cased name contains any keyword."""
    return [col for col in df.columns if any(word in col.lower() for word in keywords)]


# Group the dataset's columns into topical categories by substring match.
# 'Demographics' is a fixed list and is not checked against df.columns.
metric_categories = {
    'Cases': _columns_containing('case'),
    'Deaths': _columns_containing('death'),
    'Tests': _columns_containing('test'),
    'Vaccinations': _columns_containing('vaccine', 'vaccination'),
    'Hospital & ICU': _columns_containing('hosp', 'icu'),
    'Policy': _columns_containing('stringency', 'policy'),
    'Demographics': ['population', 'population_density', 'median_age', 'aged_65_older', 'aged_70_older', 'gdp_per_capita', 'life_expectancy']
}

print("\n=== Available Data Categories ===")
for category, metrics in metric_categories.items():
    if not metrics:
        continue
    suffix = " and more..." if len(metrics) > 3 else ""
    print(f"{category}: {len(metrics)} metrics")
    print(f"  Example metrics: {', '.join(metrics[:3])}{suffix}")

print("\n=== Data Cleaning and Preparation ===")
print("Handling missing values...")

key_metrics = ['new_cases', 'new_deaths', 'total_cases', 'total_deaths']
# Report and fill only the key metrics that actually exist, so the report and
# the fill loop agree (previously only the fill loop was guarded).
present_key_metrics = [col for col in key_metrics if col in df.columns]

missing_stats = pd.DataFrame({
    'Missing (%)': {col: df[col].isna().mean() * 100 for col in present_key_metrics}
})
print("Missing values in key metrics:")
print(missing_stats)


for col in present_key_metrics:
    if col.startswith('total_'):
        # Cumulative totals: carry the last reported value forward within each
        # location (in date order) instead of letting a gap reset the running
        # total to 0, which would distort growth rates and fatality ratios.
        # The result keeps df's index labels, so assignment realigns the rows.
        df[col] = df.sort_values('date').groupby('location')[col].ffill()
    # Whatever is still missing (daily counts, or totals before the first
    # report) is treated as zero.
    df[col] = df[col].fillna(0)


print("\nCreating advanced metrics...")


def _trailing_week_mean(series):
    """7-row rolling mean that also yields values for the first rows."""
    return series.rolling(window=7, min_periods=1).mean()


def _trailing_week_sum(series):
    """7-row rolling sum (NaN until a full window is available)."""
    return series.rolling(window=7).sum()


# Smoothed daily counts per location.
df['cases_7day_avg'] = df.groupby('location')['new_cases'].transform(_trailing_week_mean)
df['deaths_7day_avg'] = df.groupby('location')['new_deaths'].transform(_trailing_week_mean)

# Day-over-day percentage growth of the cumulative totals.
for kind in ('case', 'death'):
    df[f'{kind}_growth_rate'] = df.groupby('location')[f'total_{kind}s'].pct_change() * 100

# Doubling time implied by that growth rate: ln(2) / ln(1 + r).
for kind in ('case', 'death'):
    growth_fraction = df[f'{kind}_growth_rate'] / 100
    df[f'{kind}_doubling_days'] = np.log(2) / (np.log(1 + growth_fraction))

# Drop infinities and cap each doubling-time column at its 95th percentile.
for col in ['case_doubling_days', 'death_doubling_days']:
    finite_values = df[col].replace([np.inf, -np.inf], np.nan)
    upper_limit = finite_values.quantile(0.95)  # 95th percentile
    df[col] = finite_values.clip(upper=upper_limit)

# Case fatality rate in percent; undefined (NaN) where there are no cases.
fatality_pct = (df['total_deaths'] / df['total_cases']) * 100
df['case_fatality_rate'] = fatality_pct.where(df['total_cases'] > 0)

if 'population' in df.columns:
    print("Calculating per capita metrics...")
    per_capita_metrics = {
        'cases_per_million': 'total_cases',
        'deaths_per_million': 'total_deaths',
        'daily_cases_per_million': 'new_cases',
        'daily_deaths_per_million': 'new_deaths'
    }

    for new_col, base_col in per_capita_metrics.items():
        if base_col not in df.columns:
            continue
        df[new_col] = df[base_col].mul(1000000).div(df['population'])

# Week-over-week case ratio as a rough reproduction-number proxy, computed
# only when the previous week had more than 10 cases, then limited to [0, 10].
df['cases_7day_sum'] = df.groupby('location')['new_cases'].transform(_trailing_week_sum)
df['prev_cases_7day_sum'] = df.groupby('location')['cases_7day_sum'].shift(7)

weekly_ratio = df['cases_7day_sum'] / df['prev_cases_7day_sum']
df['approx_rt'] = weekly_ratio.where(df['prev_cases_7day_sum'] > 10)
df['approx_rt'] = df['approx_rt'].clip(lower=0, upper=10)

print("Data preparation complete.\n")

print("\n=== GLOBAL PANDEMIC OVERVIEW ===")

# Summing every 'location' would count the same cases several times when the
# dataset also carries aggregate rows (e.g. 'World', continents).
# NOTE(review): assumes aggregate rows have a null 'continent' while real
# countries have one -- confirm against the dataset codebook.
if 'continent' in df.columns:
    country_rows = df[df['continent'].notna()]
else:
    country_rows = df

# Worldwide totals per day.
global_df = country_rows.groupby('date')[['new_cases', 'new_deaths', 'total_cases', 'total_deaths']].sum().reset_index()

# 7-day rolling means; the first six rows are NaN (no min_periods).
global_df['cases_7day_avg'] = global_df['new_cases'].rolling(window=7).mean()
global_df['deaths_7day_avg'] = global_df['new_deaths'].rolling(window=7).mean()

# Two stacked panels sharing the date axis: cases on top, deaths below.
fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=("Global Daily COVID-19 Cases", "Global Daily COVID-19 Deaths"),
    vertical_spacing=0.15,
    shared_xaxes=True
)

# (row, daily column, average column, bar name, line name, bar colour, line colour)
timeline_panels = [
    (1, 'new_cases', 'cases_7day_avg', "Daily Cases", "7-Day Average (Cases)",
     'rgba(55, 128, 191, 0.5)', 'rgb(40, 61, 163)'),
    (2, 'new_deaths', 'deaths_7day_avg', "Daily Deaths", "7-Day Average (Deaths)",
     'rgba(219, 64, 82, 0.5)', 'rgb(168, 42, 42)'),
]

for panel_row, daily_col, avg_col, bar_name, line_name, bar_color, line_color in timeline_panels:
    # Translucent daily bars with the smoothed average drawn on top.
    fig.add_trace(
        go.Bar(
            x=global_df['date'],
            y=global_df[daily_col],
            name=bar_name,
            marker_color=bar_color
        ),
        row=panel_row, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=global_df['date'],
            y=global_df[avg_col],
            name=line_name,
            line=dict(color=line_color, width=2)
        ),
        row=panel_row, col=1
    )

fig.update_layout(
    height=800,
    title_text="Global COVID-19 Pandemic Timeline",
    hovermode="x unified",
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
)

fig.update_yaxes(title_text="Number of Cases", row=1, col=1)
fig.update_yaxes(title_text="Number of Deaths", row=2, col=1)
fig.update_xaxes(title_text="Date", row=2, col=1)

fig.show()


print("\nAnalyzing global pandemic waves...")


from scipy.signal import find_peaks

# global_df['cases_7day_avg'] comes from rolling(window=7) without
# min_periods, so its first rows are NaN. find_peaks is documented to give
# unexpected results on NaN input, so replace them before searching.
smoothed_global_cases = global_df['cases_7day_avg'].fillna(0)

# A wave is a peak whose prominence is at least 5% of the series maximum.
peaks, _ = find_peaks(smoothed_global_cases, prominence=smoothed_global_cases.max() * 0.05)
wave_dates = global_df.iloc[peaks]['date']
wave_heights = global_df.iloc[peaks]['cases_7day_avg']

print(f"Detected {len(peaks)} major global COVID-19 waves:")
for i, (date, height) in enumerate(zip(wave_dates, wave_heights), start=1):
    print(f"  Wave {i}: Peak on {date.strftime('%Y-%m-%d')} with {height:.0f} average daily cases")


print("\n=== REGIONAL ANALYSIS ===")

# Per-continent daily case curves, plus the first date on which each
# continent's 7-day average exceeded 1,000 cases.
if 'continent' not in df.columns:
    print("Continent data not available in the dataset.")
else:
    continent_data = (
        df.groupby(['continent', 'date'])[['new_cases', 'new_deaths']]
        .sum()
        .reset_index()
    )

    def _week_mean(series):
        """7-row rolling mean that also yields values for the first rows."""
        return series.rolling(window=7, min_periods=1).mean()

    continent_data['cases_7day_avg'] = (
        continent_data.groupby('continent')['new_cases'].transform(_week_mean)
    )

    fig_continent = px.line(
        continent_data,
        x='date',
        y='cases_7day_avg',
        color='continent',
        title="COVID-19 Cases by Continent (7-day Average)",
        labels={'cases_7day_avg': '7-day Average Cases', 'date': 'Date', 'continent': 'Continent'}
    )
    fig_continent.update_layout(
        xaxis_title="Date",
        yaxis_title="7-day Average Cases",
        legend_title="Continent",
        hovermode="x unified"
    )
    fig_continent.show()

    def _first_date_above_1000(group):
        """Earliest date with a 7-day average above 1000, or NaT if none."""
        above = group['cases_7day_avg'] > 1000
        if above.any():
            return group.loc[above, 'date'].min()
        return pd.NaT

    pandemic_timing = continent_data.groupby('continent').apply(_first_date_above_1000).reset_index()
    pandemic_timing.columns = ['continent', 'date_reached_1000_cases']

    reached_anywhere = not pandemic_timing['date_reached_1000_cases'].isna().all()
    if reached_anywhere:
        pandemic_timing = pandemic_timing.sort_values('date_reached_1000_cases')

        print("\nPandemic progression across continents:")
        print("When each continent first reached 1,000 daily cases (7-day avg):")

        for timing in pandemic_timing.itertuples(index=False):
            if pd.isna(timing.date_reached_1000_cases):
                print(f"  {timing.continent}: Never reached 1,000 daily cases")
            else:
                print(f"  {timing.continent}: {timing.date_reached_1000_cases.strftime('%Y-%m-%d')}")


def _per_million_timeline(data, column, axis_label, title):
    """Build a per-country line chart of *column* over time."""
    figure = px.line(
        data,
        x='date',
        y=column,
        color='location',
        title=title,
        labels={column: axis_label, 'date': 'Date', 'location': 'Country'}
    )
    figure.update_layout(
        xaxis_title="Date",
        yaxis_title=axis_label,
        legend_title="Country",
        hovermode="x unified"
    )
    return figure


print("\n=== COUNTRY COMPARISON ANALYSIS ===")

# Snapshot of every location on the most recent date in the dataset.
latest_date = df['date'].max()
latest = df[df['date'] == latest_date]

# Rank real countries only; aggregate rows (e.g. 'World', continents) would
# otherwise occupy the top of a "top countries" list.
# NOTE(review): assumes aggregate rows have a null 'continent' -- confirm
# against the dataset codebook.
if 'continent' in latest.columns:
    ranking_pool = latest[latest['continent'].notna()]
else:
    ranking_pool = latest

top_countries = ranking_pool.sort_values('total_cases', ascending=False).head(15)['location'].tolist()
print(f"Top countries by total cases: {', '.join(top_countries[:5])}, and more...")

# Full time series for the selected countries.
top_countries_data = df[df['location'].isin(top_countries)].copy()

if 'population' in top_countries_data.columns:
    # Derive the per-million columns here only if they are not already present.
    if 'cases_per_million' not in top_countries_data.columns:
        top_countries_data['cases_per_million'] = top_countries_data['total_cases'] * 1000000 / top_countries_data['population']
    if 'deaths_per_million' not in top_countries_data.columns:
        top_countries_data['deaths_per_million'] = top_countries_data['total_deaths'] * 1000000 / top_countries_data['population']

    fig_cases_per_million = _per_million_timeline(
        top_countries_data,
        'cases_per_million',
        'Cases per Million',
        "COVID-19 Cases per Million Population (Top 15 Countries)"
    )
    fig_cases_per_million.show()

    fig_deaths_per_million = _per_million_timeline(
        top_countries_data,
        'deaths_per_million',
        'Deaths per Million',
        "COVID-19 Deaths per Million Population (Top 15 Countries)"
    )
    fig_deaths_per_million.show()


# Composite severity index: cases and deaths per million, each rescaled to
# 0-100 against the maximum on the latest date, weighted 40% / 60%.
if all(col in latest.columns for col in ['cases_per_million', 'deaths_per_million']):
    # `latest` is a boolean-mask slice of df; take an explicit copy before
    # adding columns so the writes land on an independent frame rather than
    # on a possible view of df.
    latest = latest.copy()

    latest['cases_score'] = 100 * latest['cases_per_million'] / latest['cases_per_million'].max()
    latest['deaths_score'] = 100 * latest['deaths_per_million'] / latest['deaths_per_million'].max()

    latest['severity_index'] = (0.4 * latest['cases_score'] + 0.6 * latest['deaths_score'])

    # Only locations with more than 10,000 cumulative cases are ranked.
    severity_data = latest[latest['total_cases'] > 10000].sort_values('severity_index', ascending=False)

    top_severity = severity_data.head(20)
    fig_severity = px.bar(
        top_severity,
        x='location',
        y='severity_index',
        title="COVID-19 Pandemic Severity Index (Top 20 Countries)",
        color='severity_index',
        labels={'severity_index': 'Severity Index', 'location': 'Country'},
        color_continuous_scale=px.colors.sequential.Plasma
    )

    fig_severity.update_layout(
        xaxis_title="Country",
        yaxis_title="Severity Index (0-100)",
        xaxis={'categoryorder':'total descending'}
    )

    fig_severity.show()

    print("\nTop 10 countries by pandemic severity index:")
    for rank, (_, row) in enumerate(top_severity.head(10).iterrows(), start=1):
        print(f"  {rank}. {row['location']}: {row['severity_index']:.1f} (Cases per million: {row['cases_per_million']:.0f}, Deaths per million: {row['deaths_per_million']:.0f})")

print("\n=== CASE FATALITY RATE (CFR) ANALYSIS ===")

# Restrict to locations with at least 10,000 cases and a defined CFR.
has_enough_cases = latest['total_cases'] >= 10000
has_cfr = latest['case_fatality_rate'].notna()
cfr_data = latest[has_enough_cases & has_cfr]

if not cfr_data.empty:
    cfr_data_sorted = cfr_data.sort_values('case_fatality_rate', ascending=False)

    # Ten highest and ten lowest rates.
    top_cfr = cfr_data_sorted.head(10)
    bottom_cfr = cfr_data_sorted.tail(10)

    fig_cfr = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Highest Case Fatality Rates", "Lowest Case Fatality Rates"),
        specs=[[{"type": "bar"}, {"type": "bar"}]],
        horizontal_spacing=0.1
    )

    # Left panel: highest rates; right panel: lowest rates.
    cfr_panels = [
        (top_cfr, 'darkred', "Highest CFR"),
        (bottom_cfr, 'darkgreen', "Lowest CFR"),
    ]
    for panel_col, (panel_data, bar_color, trace_name) in enumerate(cfr_panels, start=1):
        fig_cfr.add_trace(
            go.Bar(
                x=panel_data['location'],
                y=panel_data['case_fatality_rate'],
                marker_color=bar_color,
                name=trace_name
            ),
            row=1, col=panel_col
        )
        fig_cfr.update_yaxes(title_text="Case Fatality Rate (%)", row=1, col=panel_col)

    fig_cfr.update_layout(
        title_text="Countries with Highest and Lowest Case Fatality Rates (minimum 10,000 cases)",
        height=500,
        showlegend=False
    )

    fig_cfr.show()

    # Pooled rate (total deaths over total cases) versus the median of the
    # individual rates.
    global_cfr = cfr_data['total_deaths'].sum() / cfr_data['total_cases'].sum() * 100
    median_cfr = cfr_data['case_fatality_rate'].median()

    print(f"\nGlobal case fatality rate: {global_cfr:.2f}%")
    print(f"Median country case fatality rate: {median_cfr:.2f}%")

    if {'median_age', 'case_fatality_rate'}.issubset(cfr_data.columns):
        corr = cfr_data['median_age'].corr(cfr_data['case_fatality_rate'])
        print(f"\nCorrelation between median age and CFR: {corr:.2f}")

        # Scatter of median age against CFR with an OLS trend line.
        fig_age_cfr = px.scatter(
            cfr_data,
            x='median_age',
            y='case_fatality_rate',
            hover_name='location',
            title="Relationship Between Country Median Age and COVID-19 Case Fatality Rate",
            trendline="ols",
            labels={'median_age': 'Median Age', 'case_fatality_rate': 'Case Fatality Rate (%)'}
        )
        fig_age_cfr.update_layout(
            xaxis_title="Median Age",
            yaxis_title="Case Fatality Rate (%)"
        )
        fig_age_cfr.show()

print("\n=== VACCINATION ANALYSIS ===")

# Vaccination overview: top-20 bar chart, progress over time for a fixed set
# of countries, and vaccination rate vs. new cases. Each step is skipped with
# a message when the required columns or rows are missing.

# Columns whose name mentions 'vaccine' or 'vaccination'.
vax_columns = [col for col in df.columns if 'vaccine' in col.lower() or 'vaccination' in col.lower()]

if vax_columns:
    print(f"Vaccination data available: {', '.join(vax_columns[:5])}" + (" and more..." if len(vax_columns) > 5 else ""))
    
    
    if 'people_fully_vaccinated_per_hundred' in df.columns:
       
        # Rows that actually report a full-vaccination percentage.
        vax_data = df.dropna(subset=['people_fully_vaccinated_per_hundred'])
        if not vax_data.empty:
           
            # One row per location: last non-null value of each column in date order.
            latest_vax = vax_data.sort_values('date').groupby('location').last().reset_index()
            
         
            # Bar chart of the 20 locations with the highest full-vaccination rate.
            top_vax = latest_vax.sort_values('people_fully_vaccinated_per_hundred', ascending=False).head(20)
            fig_vax = px.bar(
                top_vax,
                x='location',
                y='people_fully_vaccinated_per_hundred',
                title="Top 20 Countries by Full Vaccination Rate",
                color='people_fully_vaccinated_per_hundred',
                labels={'people_fully_vaccinated_per_hundred': 'Population Fully Vaccinated (%)', 'location': 'Country'},
                color_continuous_scale=px.colors.sequential.Viridis
            )
            
            fig_vax.update_layout(
                xaxis_title="Country",
                yaxis_title="Population Fully Vaccinated (%)",
                xaxis={'categoryorder':'total descending'}
            )
            
            fig_vax.show()
            
           
            # Vaccination progress over time for a hand-picked set of countries.
            major_countries = ['United States', 'United Kingdom', 'Israel', 'India', 'Brazil', 'South Africa']
            major_vax_data = df[df['location'].isin(major_countries)].dropna(subset=['people_fully_vaccinated_per_hundred'])
            
            if not major_vax_data.empty:
                fig_vax_time = px.line(
                    major_vax_data,
                    x='date',
                    y='people_fully_vaccinated_per_hundred',
                    color='location',
                    title="Vaccination Progress Over Time (Selected Countries)",
                    labels={'people_fully_vaccinated_per_hundred': 'Population Fully Vaccinated (%)', 'date': 'Date', 'location': 'Country'}
                )
                
                fig_vax_time.update_layout(
                    xaxis_title="Date",
                    yaxis_title="Population Fully Vaccinated (%)",
                    legend_title="Country",
                    hovermode="x unified"
                )
                
                fig_vax_time.show()
            
            
            # Relationship between vaccination rate and new cases; needs the
            # 'new_cases_per_million' column, which this file does not create
            # itself (presumably supplied by the source dataset -- verify).
            if 'new_cases_per_million' in latest_vax.columns:
                
                # Off-diagonal entry of the 2x2 correlation matrix.
                vax_case_corr = latest_vax[['people_fully_vaccinated_per_hundred', 'new_cases_per_million']].corr().iloc[0, 1]
                
                print(f"\nCorrelation between vaccination rates and recent cases: {vax_case_corr:.2f}")
                
              
                # Scatter with an OLS trend line.
                fig_vax_cases = px.scatter(
                    latest_vax,
                    x='people_fully_vaccinated_per_hundred',
                    y='new_cases_per_million',
                    hover_name='location',
                    title="Relationship Between Vaccination Rate and Recent COVID-19 Cases",
                    trendline="ols",
                    labels={'people_fully_vaccinated_per_hundred': 'Population Fully Vaccinated (%)', 
                           'new_cases_per_million': 'New Cases per Million'}
                )
                
                fig_vax_cases.update_layout(
                    xaxis_title="Population Fully Vaccinated (%)",
                    yaxis_title="New Cases per Million"
                )
                
                fig_vax_cases.show()
        else:
            # No row reports a full-vaccination percentage.
            print("Not enough data points for vaccination analysis.")
    else:
        print("Full vaccination data not available in this dataset.")
else:
    print("Vaccination data not available in this dataset.")

print("\n=== POLICY RESPONSE ANALYSIS ===")

# Policy response section: plots the stringency index over time for the
# countries in `top_countries`, then compares each location's average
# stringency with its average daily cases per million over the last 90 days.
if 'stringency_index' in df.columns:
    # Restrict the time-series plot to the countries in `top_countries`.
    stringency_data = df[df['location'].isin(top_countries)].copy()

    fig_stringency = px.line(
        stringency_data,
        x='date',
        y='stringency_index',
        color='location',
        title="COVID-19 Policy Stringency Index Over Time (Top Countries)",
        labels={'stringency_index': 'Stringency Index (0-100)', 'date': 'Date', 'location': 'Country'}
    )

    fig_stringency.update_layout(
        xaxis_title="Date",
        yaxis_title="Stringency Index (0-100)",
        legend_title="Country",
        hovermode="x unified"
    )

    fig_stringency.show()

    # Window for the "recent" comparison: the 90 days leading up to latest_date.
    three_months_ago = latest_date - pd.Timedelta(days=90)
    recent_data = df[df['date'] >= three_months_ago].copy()

    # The per-million case column is treated as optional elsewhere in this
    # file, so guard it here too instead of letting agg() raise a KeyError.
    if 'new_cases_per_million' in recent_data.columns:
        # One row per location: mean stringency and mean daily cases per million.
        policy_effectiveness = recent_data.groupby('location').agg({
            'stringency_index': 'mean',
            'new_cases_per_million': 'mean'
        }).reset_index()

        # Locations missing either average cannot be placed on the scatter plot.
        policy_effectiveness = policy_effectiveness.dropna()

        if not policy_effectiveness.empty:
            fig_policy = px.scatter(
                policy_effectiveness,
                x='stringency_index',
                y='new_cases_per_million',
                hover_name='location',
                title="Relationship Between Policy Stringency and COVID-19 Cases (Last 3 Months)",
                trendline="ols",
                labels={'stringency_index': 'Average Stringency Index (0-100)',
                       'new_cases_per_million': 'Average Daily Cases per Million'}
            )

            fig_policy.update_layout(
                xaxis_title="Average Stringency Index (0-100)",
                yaxis_title="Average Daily Cases per Million"
            )

            fig_policy.show()

            policy_corr = policy_effectiveness['stringency_index'].corr(policy_effectiveness['new_cases_per_million'])

            if pd.isna(policy_corr):
                # corr() yields NaN with fewer than two locations or a constant
                # column; comparing NaN against the thresholds below would fall
                # through to the "no strong correlation" message, which the
                # data does not support.
                print("Correlation between policy stringency and recent cases could not be computed.")
            else:
                print(f"Correlation between policy stringency and recent cases: {policy_corr:.2f}")

                # +/-0.3 is the cut-off this analysis uses for a notable relationship.
                if policy_corr < -0.3:
                    print("There appears to be a negative correlation between policy stringency and cases,")
                    print("suggesting stricter policies may help reduce case counts.")
                elif policy_corr > 0.3:
                    print("There appears to be a positive correlation between policy stringency and cases,")
                    print("which might indicate that countries implement stricter policies in response to rising cases.")
                else:
                    print("There's no strong correlation between policy stringency and cases,")
                    print("suggesting complex relationships between policies and outcomes.")
    else:
        print("Cases-per-million data not available; skipping policy effectiveness comparison.")
else:
    print("Policy stringency data not available in this dataset.")


print("\n=== ADVANCED STATISTICAL ANALYSIS ===")

# Per-country deep dive. For each sample country present in `df` this prints
# headline totals and, where the needed columns/rows exist, produces:
#   1. an additive time-series decomposition of smoothed daily cases,
#   2. reproduction-number (approx_rt) statistics and plot,
#   3. wave (peak) detection on the 7-day average,
#   4. test positivity statistics and plot.
sample_countries = ['United States', 'Germany', 'Brazil', 'India', 'South Korea']
# Build the lookup once rather than recomputing unique() for every candidate.
available_locations = set(df['location'].unique())
countries_present = [country for country in sample_countries if country in available_locations]

if countries_present:
    print(f"Performing detailed analysis for: {', '.join(countries_present)}")

    for country in countries_present:
        print(f"\nDetailed analysis for {country}:")

        # Chronological order matters: positional indexing below (iloc, peaks)
        # relies on rows being sorted by date.
        country_data = df[df['location'] == country].sort_values('date')

        if len(country_data) < 30:
            print(f"  Not enough data points for {country}")
            continue

        # Headline totals: the maxima of the total_cases / total_deaths columns.
        total_cases = country_data['total_cases'].max()
        total_deaths = country_data['total_deaths'].max()
        cfr = (total_deaths / total_cases * 100) if total_cases > 0 else 0

        print(f"  Total Cases: {total_cases:,.0f}")
        print(f"  Total Deaths: {total_deaths:,.0f}")
        print(f"  Case Fatality Rate: {cfr:.2f}%")

        # --- 1. Time-series decomposition ---------------------------------
        if len(country_data) >= 90:  # Need enough data for decomposition
            ts_data = country_data['new_cases'].fillna(0)

            if np.sum(ts_data) > 0:  # Ensure we have non-zero data
                # NOTE(review): a 7-day rolling mean is applied before a
                # period-7 decomposition, which likely damps the weekly
                # pattern being measured - confirm this is intended.
                ts_data_smoothed = ts_data.rolling(window=7, min_periods=1).mean()

                # Best-effort: any failure in decomposition or plotting is
                # reported and the remaining analyses for this country still run.
                try:
                    result = seasonal_decompose(ts_data_smoothed, model='additive', period=7)

                    fig_decomp = make_subplots(
                        rows=4, cols=1,
                        subplot_titles=("Observed", "Trend", "Seasonal", "Residual"),
                        vertical_spacing=0.1,
                        shared_xaxes=True
                    )

                    # One subplot row per component, in the same order as the titles.
                    components = (
                        ("Observed", result.observed, 'blue'),
                        ("Trend", result.trend, 'red'),
                        ("Seasonal", result.seasonal, 'green'),
                        ("Residual", result.resid, 'purple'),
                    )
                    for row_number, (label, series, color) in enumerate(components, start=1):
                        fig_decomp.add_trace(
                            go.Scatter(
                                x=country_data['date'],
                                y=series,
                                name=label,
                                line=dict(color=color)
                            ),
                            row=row_number, col=1
                        )

                    fig_decomp.update_layout(
                        height=800,
                        title_text=f"Time Series Decomposition of COVID-19 Cases in {country}",
                        showlegend=False
                    )

                    fig_decomp.show()

                    # The additive seasonal component is centred, so its plain
                    # mean over one full period is ~0 regardless of the data.
                    # Average the absolute values to get a real magnitude.
                    day_of_week_effect = np.abs(result.seasonal.iloc[:7]).mean()
                    print(f"  Day-of-week reporting effect magnitude: {day_of_week_effect:.2f} cases")

                    # Compare first and last available trend values among the
                    # final 30 rows (NaN entries are dropped first).
                    recent_trend = result.trend.iloc[-30:].dropna()
                    if len(recent_trend) > 0:
                        trend_direction = recent_trend.iloc[-1] - recent_trend.iloc[0]
                        if trend_direction > 0:
                            print(f"  Recent trend: INCREASING by {trend_direction:.2f} cases over the last 30 days")
                        else:
                            print(f"  Recent trend: DECREASING by {abs(trend_direction):.2f} cases over the last 30 days")

                except Exception as e:
                    print(f"  Could not perform time series decomposition: {e}")

        # --- 2. Reproduction number --------------------------------------
        if 'approx_rt' in country_data.columns:
            valid_rt = country_data.dropna(subset=['approx_rt'])

            if len(valid_rt) > 0:
                # Rows are date-sorted, so the last row is the most recent estimate.
                current_rt = valid_rt['approx_rt'].iloc[-1]
                mean_rt = valid_rt['approx_rt'].mean()
                max_rt = valid_rt['approx_rt'].max()

                print(f"  Current effective reproduction number (Rt): {current_rt:.2f}")
                print(f"  Average Rt throughout pandemic: {mean_rt:.2f}")
                print(f"  Maximum Rt recorded: {max_rt:.2f}")

                if current_rt < 1:
                    control_status = "CONTROLLED (Rt < 1)"
                else:
                    control_status = "GROWING (Rt > 1)"

                print(f"  Current pandemic status: {control_status}")

                fig_rt = px.line(
                    valid_rt,
                    x='date',
                    y='approx_rt',
                    title=f"Effective Reproduction Number (Rt) Over Time in {country}",
                    labels={'approx_rt': 'Effective Reproduction Number (Rt)', 'date': 'Date'}
                )

                # Reference line at the Rt = 1 threshold used for the status above.
                fig_rt.add_hline(
                    y=1,
                    line_dash="dash",
                    line_color="red",
                    annotation_text="Rt = 1 (Control Threshold)",
                    annotation_position="bottom right"
                )

                fig_rt.update_layout(
                    xaxis_title="Date",
                    yaxis_title="Effective Reproduction Number (Rt)"
                )

                fig_rt.show()

        # --- 3. Wave detection -------------------------------------------
        if 'cases_7day_avg' in country_data.columns:
            smooth_cases = country_data['cases_7day_avg'].fillna(0).values

            # A peak counts as a wave when its prominence is at least 10% of
            # the series maximum and it is at least 14 samples from its neighbour.
            peaks, properties = find_peaks(
                smooth_cases,
                prominence=np.max(smooth_cases) * 0.1,
                distance=14
            )

            if len(peaks) > 0:
                print(f"  Detected {len(peaks)} major COVID-19 waves:")

                wave_data = []
                for i, peak_idx in enumerate(peaks):
                    # peak_idx is positional within the date-sorted frame.
                    peak_date = country_data.iloc[peak_idx]['date']
                    peak_cases = smooth_cases[peak_idx]
                    prominence = properties['prominences'][i]

                    wave_data.append({
                        'Wave': i+1,
                        'Date': peak_date,
                        'Cases': peak_cases,
                        'Prominence': prominence
                    })

                    print(f"    Wave {i+1}: Peak on {peak_date.strftime('%Y-%m-%d')} with {peak_cases:.0f} avg daily cases (prominence: {prominence:.0f})")

                fig_waves = go.Figure()

                fig_waves.add_trace(
                    go.Scatter(
                        x=country_data['date'],
                        y=country_data['cases_7day_avg'],
                        name="7-day Avg Cases",
                        line=dict(color='blue', width=1)
                    )
                )

                # Mark each detected peak, reusing the values collected above.
                fig_waves.add_trace(
                    go.Scatter(
                        x=[wave['Date'] for wave in wave_data],
                        y=[wave['Cases'] for wave in wave_data],
                        mode='markers+text',
                        marker=dict(color='red', size=10),
                        text=[f"Wave {wave['Wave']}" for wave in wave_data],
                        textposition="top center",
                        name="Wave Peaks"
                    )
                )

                fig_waves.update_layout(
                    title=f"COVID-19 Waves in {country}",
                    xaxis_title="Date",
                    yaxis_title="7-day Average Cases",
                    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
                )

                fig_waves.show()

                if len(wave_data) > 1:
                    # Days between each pair of consecutive peaks.
                    wave_intervals = [
                        (later['Date'] - earlier['Date']).days
                        for earlier, later in zip(wave_data, wave_data[1:])
                    ]

                    avg_interval = np.mean(wave_intervals)
                    print(f"  Average time between waves: {avg_interval:.0f} days")

                    # Severity comparison uses only the first and last peaks.
                    first_wave_size = wave_data[0]['Cases']
                    last_wave_size = wave_data[-1]['Cases']

                    if last_wave_size > first_wave_size:
                        print(f"  Wave progression: INCREASING in severity (Last/First ratio: {last_wave_size/first_wave_size:.2f}x)")
                    else:
                        print(f"  Wave progression: DECREASING in severity (Last/First ratio: {last_wave_size/first_wave_size:.2f}x)")

        # --- 4. Test positivity ------------------------------------------
        if all(col in country_data.columns for col in ['new_tests', 'new_cases']):
            valid_test_data = country_data.dropna(subset=['new_tests', 'new_cases'])
            # Rows with no reported tests would make the ratio inf/NaN and
            # distort the mean, max and plot, so exclude them. The explicit
            # copy keeps the new column off any view of country_data.
            valid_test_data = valid_test_data[valid_test_data['new_tests'] > 0].copy()

            if len(valid_test_data) > 30:
                valid_test_data['positivity_rate'] = (valid_test_data['new_cases'] / valid_test_data['new_tests']) * 100

                fig_pos = px.line(
                    valid_test_data,
                    x='date',
                    y='positivity_rate',
                    title=f"COVID-19 Test Positivity Rate in {country}",
                    labels={'positivity_rate': 'Positivity Rate (%)', 'date': 'Date'}
                )

                # Reference line at the 5% level used for the adequacy verdict below.
                fig_pos.add_hline(
                    y=5,
                    line_dash="dash",
                    line_color="red",
                    annotation_text="WHO 5% Threshold",
                    annotation_position="bottom right"
                )

                fig_pos.update_layout(
                    xaxis_title="Date",
                    yaxis_title="Test Positivity Rate (%)"
                )

                fig_pos.show()

                # Rows are date-sorted, so the last row is the most recent value.
                current_pos = valid_test_data['positivity_rate'].iloc[-1]
                avg_pos = valid_test_data['positivity_rate'].mean()
                max_pos = valid_test_data['positivity_rate'].max()

                print(f"  Current test positivity rate: {current_pos:.2f}%")
                print(f"  Average positivity rate: {avg_pos:.2f}%")
                print(f"  Maximum positivity rate: {max_pos:.2f}%")

                if current_pos < 5:
                    print("  Testing adequacy: SUFFICIENT (below WHO 5% threshold)")
                else:
                    print("  Testing adequacy: INSUFFICIENT (above WHO 5% threshold)")

        print("\n  --- End of detailed analysis ---")

        
        