# In[ ]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Load cleaned dataset
df = pd.read_csv('../data/processed/mental_health_tech_cleaned.csv')

# Set plot style.
# sns.set()/sns.set_theme() default to style='darkgrid', which would override
# the matplotlib whitegrid style applied just before it; request 'whitegrid'
# explicitly so both settings agree, while keeping the larger font scale.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_theme(style='whitegrid', font_scale=1.2)

# Output directory for every figure saved below (no error if it already exists).
os.makedirs('../reports/figures', exist_ok=True)

# 1. Trend Analysis - Mental health benefits over time
# Row-normalised crosstab: for each survey year, the percentage of rows giving
# each `mental_health_benefits` answer. Only the 'Yes' share is plotted.
benefits_by_year = pd.crosstab(df['year'], df['mental_health_benefits'])
benefits_by_year_pct = benefits_by_year.div(benefits_by_year.sum(axis=1), axis=0) * 100

# Take the year span for the title from the data rather than hard-coding it,
# so the label stays correct if the dataset gains or loses survey years.
first_year, last_year = int(df['year'].min()), int(df['year'].max())

fig, ax = plt.subplots(figsize=(12, 6))
benefits_by_year_pct['Yes'].plot(ax=ax, marker='o', linestyle='-', linewidth=2)
ax.set_title(
    f'Percentage of Companies Offering Mental Health Benefits ({first_year}-{last_year})'
)
ax.set_xlabel('Year')
ax.set_ylabel('Percentage of Companies')
ax.grid(True)
fig.tight_layout()
fig.savefig('../reports/figures/benefits_trend.png')
plt.show()

# 2. Company Size vs Mental Health Benefits
# Cross-tabulate company size against the benefits answer, then scale each
# row so it sums to 100 (percentage of rows within each size bucket).
benefits_by_size = pd.crosstab(df['number_of_employees'], df['mental_health_benefits'])
rows_per_size = benefits_by_size.sum(axis=1)
benefits_by_size_pct = benefits_by_size.div(rows_per_size, axis=0) * 100

# Only the 'Yes' share is charted.
# NOTE(review): sort_index() orders the size labels by their raw values; if
# they are range strings, this may not be ascending company size — confirm.
yes_share_by_size = benefits_by_size_pct['Yes'].sort_index()

plt.figure(figsize=(14, 7))
yes_share_by_size.plot(kind='bar')
plt.title('Mental Health Benefits Availability by Company Size')
plt.xlabel('Company Size')
plt.ylabel('Percentage with Benefits')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('../reports/figures/benefits_by_company_size.png')
plt.show()

# 3. Tech vs Non-Tech Companies Comparison
# Row-normalised crosstab: for each `tech_company` value, the percentage of
# rows giving each `mental_health_benefits` answer (one bar per answer).
tech_benefits = pd.crosstab(df['tech_company'], df['mental_health_benefits'])
tech_benefits_pct = tech_benefits.div(tech_benefits.sum(axis=1), axis=0) * 100

# DataFrame.plot() opens a brand-new figure unless it is given an Axes, so
# create the sized figure explicitly and draw into its Axes.
fig, ax = plt.subplots(figsize=(12, 8))
tech_benefits_pct.plot(kind='bar', stacked=False, ax=ax)
ax.set_title('Mental Health Benefits: Tech vs Non-Tech Companies')
ax.set_xlabel('Tech Company')
ax.set_ylabel('Percentage')
ax.legend(title='Benefits Offered')
fig.tight_layout()
fig.savefig('../reports/figures/tech_vs_nontech_benefits.png')
plt.show()

# 4. Comfort discussing mental health
# Side-by-side bar charts of answer counts: one panel per audience
# (coworkers, supervisor), categories ordered by sort_index().
comfort_columns = [
    'mental_health_discussion_comfort_coworkers',
    'mental_health_discussion_comfort_supervisor'
]
comfort_prefix = "mental_health_discussion_comfort_"
n_panels = len(comfort_columns)

plt.figure(figsize=(14, 7))
for panel, column in enumerate(comfort_columns, start=1):
    plt.subplot(1, n_panels, panel)

    answer_counts = df[column].value_counts().sort_index()
    sns.barplot(x=answer_counts.index, y=answer_counts.values)

    # Panel title is the column suffix, e.g. "Coworkers" / "Supervisor".
    audience = column.replace(comfort_prefix, "").title()
    plt.title(f'Comfort Level: {audience}')
    plt.xlabel('Comfort Level')
    plt.ylabel('Count')
    plt.xticks(rotation=45)

plt.tight_layout()
plt.savefig('../reports/figures/comfort_levels.png')
plt.show()

# 5. Correlation Heatmap
# Benefit-related survey columns to binarise alongside the comfort columns.
# `benefit_columns` was previously referenced without ever being defined
# (NameError). The year-over-year section reads
# `mental_health_benefits_binary` and `mental_health_benefits_awareness_binary`,
# so both source columns must be listed here.
# NOTE(review): add any other benefit questions present in the cleaned data.
benefit_columns = [
    'mental_health_benefits',
    'mental_health_benefits_awareness',
]

# Create binary variables for correlation analysis: 1 where the answer is
# exactly 'Yes', 0 for every other value (including missing values).
binary_columns = []
for col in benefit_columns + comfort_columns:
    binary_col = f'{col}_binary'
    df[binary_col] = (df[col] == 'Yes').astype(int)
    binary_columns.append(binary_col)

# Pairwise (Pearson, the DataFrame.corr default) correlations between the
# indicators; the upper triangle incl. the diagonal is masked as redundant.
plt.figure(figsize=(12, 10))
corr_matrix = df[binary_columns].corr()
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
sns.heatmap(corr_matrix, mask=mask, annot=True, cmap='coolwarm', fmt='.2f',
            linewidths=0.5, vmin=-1, vmax=1)
plt.title('Correlation Between Mental Health Factors')
plt.tight_layout()
plt.savefig('../reports/figures/correlation_heatmap.png')
plt.show()

# 6. Year-over-year changes
# Summary frame: per survey year, the percentage of rows whose answer was
# 'Yes' for each key metric, computed from the *_binary indicator columns
# built in the correlation section.
metric_sources = {
    'pct_benefits_offered': 'mental_health_benefits_binary',
    'pct_benefits_awareness': 'mental_health_benefits_awareness_binary',
    'pct_comfortable_supervisor': 'mental_health_discussion_comfort_supervisor_binary',
}

yearly_stats = pd.DataFrame(index=sorted(df['year'].unique()))
by_year = df.groupby('year')
for metric, source_col in metric_sources.items():
    # Mean of a 0/1 indicator is the share of 1s; scale to percent.
    # The assignment aligns on the year index.
    yearly_stats[metric] = by_year[source_col].mean() * 100

# Year span for the title comes from the data, not a hard-coded range.
first_year, last_year = int(df['year'].min()), int(df['year'].max())

# DataFrame.plot() opens a new figure unless given an Axes, so draw into an
# explicitly sized one.
fig, ax = plt.subplots(figsize=(14, 8))
yearly_stats.plot(marker='o', linewidth=2, ax=ax)
ax.set_title(f'Mental Health Metrics Trend ({first_year}-{last_year})')
ax.set_xlabel('Year')
ax.set_ylabel('Percentage')
ax.grid(True)
ax.legend(title='Metrics')
fig.tight_layout()
fig.savefig('../reports/figures/yearly_metrics_trend.png')
plt.show()

# Write the dataframe to disk, including the *_binary indicator columns that
# were added during the correlation analysis (row index is not written).
analyzed_path = '../data/processed/mental_health_tech_analyzed.csv'
df.to_csv(analyzed_path, index=False)

print("Exploratory data analysis complete!")