In [3]:
# notebooks/multimodal_analysis.ipynb
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr
import os
from pathlib import Path

# Set up paths
current_dir = Path.cwd()
figures_dir = current_dir.parent/"figures/multimodal"
os.makedirs(figures_dir, exist_ok=True)

# Load cleaned data
df = pd.read_csv("../data/processed/cleaned_flood_data.csv")

# 1. PREPROCESSING --------------------------------
# Convert date to datetime if not already
if not pd.api.types.is_datetime64_any_dtype(df['DATE']):
    df['DATE'] = pd.to_datetime(df['DATE'])

# Select numeric columns for correlation analysis
numeric_cols = ['PERSONS_AFFECTED', 'DISPLACED_PERSONS', 'SEVERITY_RATIO']
numeric_df = df[numeric_cols]

# 2. CORRELATION MATRIX --------------------------
plt.figure(figsize=(10, 8))
corr_matrix = numeric_df.corr()
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))  # Mask upper triangle

heatmap = sns.heatmap(corr_matrix, 
                     annot=True, 
                     fmt=".2f", 
                     cmap="coolwarm", 
                     vmin=-1, 
                     vmax=1,
                     mask=mask,
                     cbar_kws={'label': 'Correlation Coefficient'})
plt.title("Flood Impact Correlation Matrix", pad=20)
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(figures_dir/"flood_impact_correlation.png", dpi=300, bbox_inches='tight')
plt.close()

# 3. TEMPORAL-SPATIAL ANALYSIS -------------------
# Monthly flood impact by state
plt.figure(figsize=(14, 8))
top_states = df.groupby('STATE')['PERSONS_AFFECTED'].sum().nlargest(5).index.tolist()
month_order = ['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec']

# Filter data for top states
temp_df = df[df['STATE'].isin(top_states)].copy()
temp_df['MONTH_NAME'] = temp_df['DATE'].dt.month_name().str[:3]

sns.lineplot(data=temp_df, 
             x='MONTH_NAME', 
             y='PERSONS_AFFECTED',
             hue='STATE',
             estimator='sum',
             err_style=None,
             marker='o',
             sort=False,
             palette='viridis')

plt.title("Monthly Flood Impact by State (2022)", pad=20)
plt.xlabel("Month")
plt.ylabel("Total Persons Affected")
plt.xticks(range(12), month_order)
plt.grid(True, alpha=0.3)
plt.legend(title='State')
plt.tight_layout()
plt.savefig(figures_dir/"monthly_state_impact.png", dpi=300, bbox_inches='tight')
plt.close()

# 4. SEVERITY ANALYSIS (Fully Corrected Version) ---------------------------
plt.figure(figsize=(12, 6))

# Create main axis first
ax = plt.gca()

# Create the scatter plot on our axis
scatter = sns.scatterplot(data=df,
                         x='PERSONS_AFFECTED',
                         y='DISPLACED_PERSONS',
                         hue='SEVERITY_RATIO',
                         size='SEVERITY_RATIO',
                         sizes=(20, 200),
                         palette='magma',
                         alpha=0.7,
                         ax=ax)  # Explicitly specify axis

plt.xscale('log')
plt.yscale('log')
plt.title("Flood Impact vs. Displacement (Log Scale)", pad=20)
plt.xlabel("Persons Affected (log)")
plt.ylabel("Persons Displaced (log)")

# Create normalization based on your data
norm = plt.Normalize(vmin=df['SEVERITY_RATIO'].min(), 
                    vmax=df['SEVERITY_RATIO'].max())

# Create a ScalarMappable for the colorbar
sm = plt.cm.ScalarMappable(cmap='magma', norm=norm)
sm.set_array(df['SEVERITY_RATIO'])

# Add colorbar to the existing axis
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label('Severity Ratio')

# Adjust legend for size encoding
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:3], labels[:3], title='Size by Severity')

plt.grid(True, which='both', alpha=0.3)
plt.tight_layout()
plt.savefig(figures_dir/"impact_vs_displacement.png", dpi=300, bbox_inches='tight')
plt.close()

# 5. STATE-LEVEL RISK PROFILE --------------------
if 'STATE' in df.columns:
    state_stats = df.groupby('STATE').agg({
        'PERSONS_AFFECTED': 'sum',
        'DISPLACED_PERSONS': 'sum',
        'SEVERITY_RATIO': 'mean'
    }).sort_values('PERSONS_AFFECTED', ascending=False).head(10)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Total affected
    sns.barplot(ax=axes[0],
                x=state_stats['PERSONS_AFFECTED'],
                y=state_stats.index,
                palette='Blues_r')
    axes[0].set_title("Total Persons Affected")
    axes[0].set_xlabel("Count")
    
    # Total displaced
    sns.barplot(ax=axes[1],
                x=state_stats['DISPLACED_PERSONS'],
                y=state_stats.index,
                palette='Oranges_r')
    axes[1].set_title("Total Displaced Persons")
    axes[1].set_xlabel("Count")
    
    # Average severity
    sns.barplot(ax=axes[2],
                x=state_stats['SEVERITY_RATIO'],
                y=state_stats.index,
                palette='Greens_r')
    axes[2].set_title("Average Severity Ratio")
    axes[2].set_xlabel("Ratio")
    axes[2].set_xlim(0, 1)
    
    plt.suptitle("Top 10 States by Flood Impact", y=1.05)
    plt.tight_layout()
    plt.savefig(figures_dir/"state_risk_profiles.png", dpi=300, bbox_inches='tight')
    plt.close()

print(f"✅ Multimodal analysis complete. Figures saved to: {figures_dir}")


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(ax=axes[0],

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(ax=axes[1],

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(ax=axes[2],


✅ Multimodal analysis complete. Figures saved to: /Users/raheeminioluwa/Documents/Flood-EDA-Nigeria/figures/multimodal
