# Student Mental Health Data Analysis

This notebook demonstrates how to:
1. **Load and Clean** the student mental health dataset using PySpark.
2. **Perform Feature Engineering** (e.g., stress index, sleep category, normalization).
3. **Run Various Analyses** (distribution, correlation, aggregations, risk analysis).
4. **Visualize** the results as readable charts using pandas, matplotlib, and seaborn.

We’ll largely reuse the logic from your `main.py` code, but adapt it for an interactive notebook environment.

In [None]:
import sys
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit, avg, udf
from pyspark.sql.types import FloatType, IntegerType
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StringType, StructType, StructField, DoubleType, IntegerType
)

# Global seaborn theme: white background with grid lines for every later plot.
sns.set_style(style="whitegrid")

## 1. Creating the Spark Session

In [None]:
def create_spark_session(app_name="StudentMentalHealth"):
    """Return a SparkSession named *app_name* (reusing an active one if present).

    The session is configured so that Parquet output uses the snappy codec.
    """
    builder = SparkSession.builder.appName(app_name)
    builder = builder.config("spark.sql.parquet.compression.codec", "snappy")
    return builder.getOrCreate()


spark = create_spark_session()

## 2. Define Schema & Helper Functions

We’ll define the schema and helper functions (mostly copied over from your `main.py`).

In [None]:

# Explicit CSV schema: (column name, Spark type) in the order the CSV columns
# are expected to appear. Every field is nullable.
schema = StructType([
    StructField(name, spark_type, True)
    for name, spark_type in [
        ("id", IntegerType()),
        ("Gender", StringType()),
        ("Age", DoubleType()),
        ("City", StringType()),
        ("Profession", StringType()),
        ("Academic Pressure", DoubleType()),
        ("Work Pressure", DoubleType()),
        ("CGPA", DoubleType()),
        ("Study Satisfaction", DoubleType()),
        ("Job Satisfaction", DoubleType()),
        # Read as text; clean_data converts it to a number of hours.
        ("Sleep Duration", StringType()),
        ("Dietary Habits", StringType()),
        ("Degree", StringType()),
        ("Have you ever had suicidal thoughts ?", StringType()),
        ("Work/Study Hours", DoubleType()),
        ("Financial Stress", DoubleType()),
        ("Family History of Mental Illness", StringType()),
        ("Depression", IntegerType()),
    ]
])

def load_data(spark, input_path: str):
    """Read the student CSV at *input_path* using the module-level ``schema``.

    The first line of the file is treated as a header row.

    Args:
        spark: Active SparkSession used for reading.
        input_path: Path (or glob) of the CSV file(s) to load.

    Returns:
        A Spark DataFrame typed according to ``schema``.

    Raises:
        Exception: whatever Spark raises while resolving the source (for
            example, presumably, a missing path) is printed and then
            re-raised with its full traceback.
    """
    reader = spark.read.option("header", "true").schema(schema)
    try:
        return reader.csv(input_path)
    except Exception as e:
        # sys.exit() inside a notebook only raises SystemExit and hides the
        # underlying traceback; re-raising keeps the real cause visible.
        print(f"Error loading data: {e}")
        raise

def _parse_sleep_hours(raw_col):
    """Return a FloatType Column holding *raw_col* (sleep-duration text) in hours, or null.

    Recognised shapes (quotes, surrounding spaces and letter case ignored,
    optional trailing "hour(s)"/"hr(s)"):
      * plain number "7" / "7.5"   -> that number
      * range "a-b"                -> midpoint (a + b) / 2
      * "less than X"              -> X - 0.5
      * "more than X"              -> X + 0.5
    Anything else becomes null.
    """
    num = r"(\d+(?:\.\d+)?)"
    unit = r"\s*(?:hours?|hrs?)?$"
    plain = rf"^{num}{unit}"
    span = rf"^{num}\s*-\s*{num}{unit}"
    less = rf"^less than\s*{num}{unit}"
    more = rf"^more than\s*{num}{unit}"

    text = F.lower(F.trim(F.regexp_replace(raw_col, "['\"]", "")))

    def grab(pattern, group=1):
        return F.regexp_extract(text, pattern, group).cast(FloatType())

    # Each cast sits inside the branch whose rlike matched, so it only ever
    # sees a digit string (an empty/invalid string would fail in ANSI mode).
    hours = (
        when(text.rlike(plain), grab(plain))
        .when(text.rlike(span), (grab(span, 1) + grab(span, 2)) / 2)
        .when(text.rlike(less), grab(less) - 0.5)
        .when(text.rlike(more), grab(more) + 0.5)
    )
    return hours.cast(FloatType())


def clean_data(df):
    """
    Drop rows with missing or out-of-range values and make Sleep Duration numeric.

    ``Sleep Duration`` is declared as a string in ``schema``. A bare
    ``cast(FloatType())`` turns any textual value (e.g. something like
    "5-6 hours") into null, or an error under ANSI mode, and the range filter
    would then silently discard those rows. The value is therefore parsed by
    ``_parse_sleep_hours``. Values it cannot interpret become null and are
    dropped.

    Returns the cleaned DataFrame. In it, Sleep Duration is FloatType within
    [0, 24] and Age is strictly positive.
    """
    df = df.dropna(subset=["Sleep Duration", "Age"])

    df = df.withColumn("Sleep Duration", _parse_sleep_hours(col("Sleep Duration")))

    # Unparseable durations are null at this point; drop them explicitly.
    df = df.filter(col("Sleep Duration").isNotNull())
    df = df.filter(col("Sleep Duration").between(0, 24))
    df = df.filter(col("Age") > 0)
    return df

def report_data_quality(df):
    """
    Print data-quality metrics for *df*.

    First prints the null count of every column. Then shows ``describe()``
    statistics (count, mean, stddev, min, max) for every numeric column.
    """
    from pyspark.sql.types import NumericType

    # count() ignores nulls, so counting a when(isNull) expression counts
    # exactly the null cells. All columns are handled in one Spark job
    # instead of one filter().count() job per column.
    if df.columns:
        null_counts = df.agg(
            *[F.count(when(col(c).isNull(), lit(1))).alias(c) for c in df.columns]
        ).collect()[0]
        for c in df.columns:
            print(f"{c}: {null_counts[c]} nulls")

    # Any numeric type qualifies. Most measures in ``schema`` are DoubleType,
    # which a Float/Integer-only check would skip.
    numeric_cols = [f.name for f in df.schema.fields if isinstance(f.dataType, NumericType)]
    if numeric_cols:
        df.select(numeric_cols).describe().show()

def feature_engineering(df):
    """
    Add derived feature columns to *df* and return the resulting DataFrame.

    Added columns:
      * stress_index   -- mean of Academic Pressure, Work Pressure and
                          Financial Stress (null if any of them is null).
      * sleep_category -- "Low" (< 6), "Normal" (6-8), "High" (> 8).
      * age_group      -- "<18", "18-21", "22-25", "26-30", ">30".
      * <c>_normalized -- min-max scaling of CGPA, Depression and
                          stress_index; 0.0 when the column is constant.
      * Gender_<value> -- 0/1 indicator per distinct non-null Gender value.

    A null Sleep Duration or Age yields a null sleep_category or age_group.
    """
    # stress_index: arithmetic mean of the three stress-related scores.
    df = df.withColumn(
        "stress_index",
        (col("Academic Pressure") + col("Work Pressure") + col("Financial Stress")) / lit(3.0),
    )

    # sleep_category: explicit bands. A null duration stays null instead of
    # falling through to "High".
    sleep = col("Sleep Duration")
    df = df.withColumn(
        "sleep_category",
        when(sleep < 6, lit("Low"))
        .when(sleep <= 8, lit("Normal"))
        .when(sleep > 8, lit("High")),
    )

    # age_group: Age is a double, so the bands are half-open. Fractional ages
    # (e.g. 21.5) cannot slip between the integer boundaries, and ages under
    # 18 get their own label instead of being reported as ">30".
    age = col("Age")
    df = df.withColumn(
        "age_group",
        when(age < 18, lit("<18"))
        .when(age < 22, lit("18-21"))
        .when(age < 26, lit("22-25"))
        .when(age <= 30, lit("26-30"))
        .when(age > 30, lit(">30")),
    )

    # Min-max normalisation. One pass collects every min/max.
    numeric_cols = ["CGPA", "Depression", "stress_index"]
    stats = df.select(
        *[F.min(c).alias(f"{c}_min") for c in numeric_cols],
        *[F.max(c).alias(f"{c}_max") for c in numeric_cols],
    ).collect()[0]

    for c in numeric_cols:
        min_val = stats[f"{c}_min"]
        max_val = stats[f"{c}_max"]
        if min_val != max_val:
            df = df.withColumn(f"{c}_normalized", (col(c) - lit(min_val)) / lit(max_val - min_val))
        else:
            # Constant (or entirely null) column: avoid dividing by zero.
            df = df.withColumn(f"{c}_normalized", lit(0.0))

    # One 0/1 indicator per observed Gender value. Nulls are skipped because
    # they would otherwise produce an all-zero "Gender_None" column. Values
    # are sorted so the column order is deterministic between runs.
    genders = sorted(
        row["Gender"]
        for row in df.select("Gender").distinct().collect()
        if row["Gender"] is not None
    )
    for gender in genders:
        df = df.withColumn(
            f"Gender_{gender}",
            when(col("Gender") == gender, lit(1)).otherwise(lit(0)),
        )

    return df

def distribution_analysis(df):
    """
    Summarise two distributions of *df*.

    Returns a tuple ``(dep_by_demo, cgpa_by_sleep)``:
      * dep_by_demo   -- mean Depression per (age_group, Profession), in
                         column ``avg_depression``;
      * cgpa_by_sleep -- per sleep_category, the mean (``avg_cgpa``) and
                         sample standard deviation (``stddev_cgpa``) of CGPA.
    """
    depression_mean = F.avg("Depression").alias("avg_depression")
    cgpa_mean = F.avg("CGPA").alias("avg_cgpa")
    cgpa_spread = F.stddev("CGPA").alias("stddev_cgpa")

    by_demographics = df.groupBy("age_group", "Profession").agg(depression_mean)
    by_sleep = df.groupBy("sleep_category").agg(cgpa_mean, cgpa_spread)
    return by_demographics, by_sleep

def correlation_analysis(df, exclude_cols=("id",)):
    """
    Compute pairwise Pearson correlations among the numeric columns of *df*.

    Considered columns are every column with a numeric Spark type, except:
      * those listed in *exclude_cols* (by default the ``id`` column);
      * ``*_normalized`` columns. Min-max rescaling is affine, so these would
        only duplicate their source column's correlations (e.g. Depression
        vs Depression_normalized is always 1).

    Returns ``(correlation_df, top5_dep_corr)``:
      * correlation_df -- one row per unordered column pair, with columns
        ``column1``, ``column2`` and ``correlation``. The correlation is null
        (or NaN) when undefined, e.g. for a constant column.
      * top5_dep_corr -- the five pairs involving ``Depression`` with the
        largest absolute correlation, plus an ``abs_corr`` column. Undefined
        correlations are left out of this ranking.
    """
    from itertools import combinations

    from pyspark.sql.types import NumericType

    skipped = set(exclude_cols or ())
    numeric_cols = [
        field.name
        for field in df.schema.fields
        if isinstance(field.dataType, NumericType)
        and field.name not in skipped
        and not field.name.endswith("_normalized")
    ]
    pairs = list(combinations(numeric_cols, 2))

    # A single aggregation computes every coefficient in one Spark job,
    # instead of launching one df.stat.corr job per pair.
    corr_rows = []
    if pairs:
        stats = df.agg(
            *[F.corr(col(a), col(b)).alias(f"corr_{k}") for k, (a, b) in enumerate(pairs)]
        ).collect()[0]
        for k, (a, b) in enumerate(pairs):
            value = stats[f"corr_{k}"]
            corr_rows.append((a, b, None if value is None else float(value)))

    # Explicit schema: inference fails on an empty list or an all-null column.
    result_schema = StructType([
        StructField("column1", StringType(), False),
        StructField("column2", StringType(), False),
        StructField("correlation", DoubleType(), True),
    ])
    correlation_df = df.sparkSession.createDataFrame(corr_rows, result_schema)

    # Spark sorts NaN above every number, so undefined correlations would
    # otherwise head a descending ranking.
    depression_corr = (
        correlation_df
        .filter((F.col("column1") == "Depression") | (F.col("column2") == "Depression"))
        .filter(F.col("correlation").isNotNull() & ~F.isnan(F.col("correlation")))
        .withColumn("abs_corr", F.abs(F.col("correlation")))
        .orderBy(F.col("abs_corr").desc())
    )

    top5_dep_corr = depression_corr.limit(5)

    return correlation_df, top5_dep_corr

def aggregations(df):
    """
    Compute three grouped summaries of *df*.

    Returns ``(city_degree_stats, demographic_stress, sleep_performance)``:
      * city_degree_stats  -- per (City, Degree): ``avg_depression`` and the
                              row count ``count_students``;
      * demographic_stress -- per (age_group, Gender): ``avg_stress_index``;
      * sleep_performance  -- per sleep_category: ``avg_cgpa`` and
                              ``avg_academic_score``.

    Note: ``avg_academic_score`` is the mean of the *Academic Pressure*
    column, not of an academic score. The column name is kept for
    compatibility with existing callers.
    """
    def summarize(keys, *metrics):
        return df.groupBy(*keys).agg(*metrics)

    city_degree_stats = summarize(
        ["City", "Degree"],
        F.avg("Depression").alias("avg_depression"),
        F.count("*").alias("count_students"),
    )
    demographic_stress = summarize(
        ["age_group", "Gender"],
        F.avg("stress_index").alias("avg_stress_index"),
    )
    sleep_performance = summarize(
        ["sleep_category"],
        F.avg("CGPA").alias("avg_cgpa"),
        F.avg("Academic Pressure").alias("avg_academic_score"),
    )
    return city_degree_stats, demographic_stress, sleep_performance

def risk_analysis(df, stress_threshold=0.7, sleep_threshold=5, financial_threshold=7):
    """
    Identify high-risk students and return them as a DataFrame.

    A row is flagged when at least one condition holds:
      * ``stress_index_normalized > stress_threshold``  -> "High Stress"
      * ``Sleep Duration < sleep_threshold``            -> "Poor Sleep"
      * ``Financial Stress > financial_threshold``      -> "High Financial Stress"

    Default thresholds are 0.7, 5 and 7 respectively. The added
    ``risk_reason`` column lists, comma-separated, the conditions that
    actually fired for that row. A null input never fires its condition.

    Requires the ``stress_index_normalized`` column (created by
    ``feature_engineering``).
    """
    high_stress = col("stress_index_normalized") > lit(stress_threshold)
    poor_sleep = col("Sleep Duration") < lit(sleep_threshold)
    money_stress = col("Financial Stress") > lit(financial_threshold)

    # when() without otherwise() yields null for unmet conditions, and
    # concat_ws skips nulls, so only the triggered reasons are joined.
    reason = F.concat_ws(
        ", ",
        when(high_stress, lit("High Stress")),
        when(poor_sleep, lit("Poor Sleep")),
        when(money_stress, lit("High Financial Stress")),
    )
    high_risk_df = df.filter(high_stress | poor_sleep | money_stress).withColumn("risk_reason", reason)
    return high_risk_df


@udf(returnType=StringType())
def my_custom_udf(value):
    """Placeholder Spark UDF: ignores *value* and always yields "transformed_value".

    No cell shown in this notebook calls it. Replace the body with a real
    transformation before using it.
    """
    placeholder = "transformed_value"
    return placeholder


## 3. Load & Clean the Data

Change the `INPUT_PATH` to the location of your CSV.

In [None]:
# Point this at your copy of the dataset.
INPUT_PATH = "src/data/StudentDepressionDataset.csv"

raw_df = load_data(spark, INPUT_PATH)
cleaned_df = clean_data(raw_df)
report_data_quality(cleaned_df)

n_cleaned_rows = cleaned_df.count()
print("\nNumber of rows (cleaned):", n_cleaned_rows)

## 4. Feature Engineering

In [None]:
feat_df = feature_engineering(cleaned_df)
print("Feature engineering done.")

n_featured_rows = feat_df.count()
print("Number of rows (featured):", n_featured_rows)


## 5. Distribution Analysis

1. **Depression scores by age group and profession**
2. **CGPA stats by sleep category**

Because these aggregated results are small, we collect them into pandas DataFrames and visualize them.

In [None]:
dep_by_demo_df, cgpa_by_sleep_df = distribution_analysis(feat_df)

# Grouped results are expected to be small, so collecting them is cheap.
dep_by_demo_pd, cgpa_by_sleep_pd = (
    summary.toPandas() for summary in (dep_by_demo_df, cgpa_by_sleep_df)
)

dep_by_demo_pd.head()

### 5.1 Bar Plot: Depression by Age Group & Profession

We can create a grouped bar chart showing average depression by profession within each age group.

In [None]:


fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(data=dep_by_demo_pd, x="age_group", y="avg_depression", hue="Profession", ax=ax)
ax.set_title("Average Depression by Age Group & Profession")
ax.set_ylabel("Average Depression")
ax.set_xlabel("Age Group")
# Legend outside the plot area so it does not hide bars.
ax.legend(title="Profession", bbox_to_anchor=(1.05, 1), loc="upper left")
fig.tight_layout()
plt.show()

### 5.2 CGPA by Sleep Category

We plot a bar chart of average CGPA by sleep category. (A boxplot would need the row-level CGPA values rather than these aggregated averages.)

In [None]:


fig, ax = plt.subplots(figsize=(6, 4))
sns.barplot(data=cgpa_by_sleep_pd, x="sleep_category", y="avg_cgpa", ax=ax)
ax.set_title("Average CGPA by Sleep Category")
ax.set_xlabel("Sleep Category")
ax.set_ylabel("Average CGPA")
plt.show()

## 6. Correlation Analysis

- Compute correlation matrix among numeric columns.
- Identify top 5 factors correlated with depression.

In [None]:
corr_matrix_df, top5_dep_corr_df = correlation_analysis(feat_df)

# Pair-wise results are small; bring both tables to the driver.
corr_pdf = corr_matrix_df.toPandas()
top5_dep_corr_pdf = top5_dep_corr_df.toPandas()

print("Top 5 correlations with Depression:")
display(top5_dep_corr_pdf)

### 6.1 Heatmap of Correlations

We’ll pivot the `corr_pdf` table into a full symmetric matrix and plot it as a heatmap. This works well for a manageable number of numeric columns. With many columns the annotated heatmap becomes unreadable, so consider restricting it to the columns of interest.

In [None]:

import numpy as np

# Every column that appears in at least one pair, alphabetically.
all_numeric_cols = sorted(set(corr_pdf["column1"]) | set(corr_pdf["column2"]))
position = {name: k for k, name in enumerate(all_numeric_cols)}

n = len(all_numeric_cols)
corr_matrix = np.full((n, n), np.nan)

# Fill both triangles from the unordered pair list.
for first, second, value in zip(corr_pdf["column1"], corr_pdf["column2"], corr_pdf["correlation"]):
    a, b = position[first], position[second]
    corr_matrix[a, b] = value
    corr_matrix[b, a] = value

# A column is perfectly correlated with itself.
np.fill_diagonal(corr_matrix, 1.0)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    corr_matrix,
    xticklabels=all_numeric_cols,
    yticklabels=all_numeric_cols,
    cmap="coolwarm",
    annot=True,
    fmt=".2f",
    ax=ax,
)
ax.set_title("Correlation Heatmap")
fig.tight_layout()
plt.show()

## 7. Aggregations

1. **Depression by city and degree**
2. **Stress by age group and gender**
3. **Academic performance by sleep category**


In [None]:
city_degree_df, demo_stress_df, sleep_perf_df = aggregations(feat_df)

city_degree_pd, demo_stress_pd, sleep_perf_pd = [
    summary.toPandas() for summary in (city_degree_df, demo_stress_df, sleep_perf_df)
]

city_degree_pd.head()

### 7.1 Plot: Average Depression by City & Degree

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(data=city_degree_pd, x="City", y="avg_depression", hue="Degree", ax=ax)
ax.set_title("Average Depression by City & Degree")
# Rotate city names so they do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45)
ax.legend(title="Degree", bbox_to_anchor=(1.05, 1), loc="upper left")
fig.tight_layout()
plt.show()

### 7.2 Stress by Age Group and Gender

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(data=demo_stress_pd, x="age_group", y="avg_stress_index", hue="Gender", ax=ax)
ax.set_title("Average Stress Index by Age Group & Gender")
ax.set_xlabel("Age Group")
ax.set_ylabel("Average Stress Index")
ax.legend(title="Gender", bbox_to_anchor=(1.05, 1), loc="upper left")
fig.tight_layout()
plt.show()

### 7.3 Academic Performance by Sleep Category

In [None]:
# One bar chart per metric in sleep_perf_pd, drawn in order.
for metric, title, y_label in [
    ("avg_cgpa", "Average CGPA by Sleep Category", "Average CGPA"),
    (
        "avg_academic_score",
        "Average Academic Pressure by Sleep Category",
        "Average Academic Pressure",
    ),
]:
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=sleep_perf_pd, x="sleep_category", y=metric, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Sleep Category")
    ax.set_ylabel(y_label)
    plt.show()

## 8. Risk Analysis

Identify high-risk students based on the normalized stress index, sleep duration, and financial stress.

In [None]:
high_risk_df = risk_analysis(feat_df)
high_risk_pdf = high_risk_df.toPandas()

n_high_risk = len(high_risk_pdf)
print(f"Number of high-risk students: {n_high_risk}")
high_risk_pdf.head()

We could, for example, look at a distribution of `stress_index_normalized` among high-risk students.

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
sns.histplot(data=high_risk_pdf, x="stress_index_normalized", bins=20, kde=True, ax=ax)
ax.set_title("Distribution of Stress Index (Normalized) - High Risk")
ax.set_xlabel("Stress Index (Normalized)")
ax.set_ylabel("Count")
plt.show()

## 9. Cleanup

It's often good practice to **stop** your Spark session at the end of the notebook if you don't need it anymore.

In [None]:

# Stops the underlying SparkContext; `spark` cannot be used after this.
spark.stop()
print("Spark session stopped.")