In [1]:
from pyspark.sql import SparkSession

In [2]:
# Creating a Spark session
# This initializes a Spark session which is the entry point to using Spark functionality.
# The app name is set to "HearDiseaseAnalysis" for identifying the job in Spark UI.

spark = SparkSession.builder.appName("HearDiseaseAnalysis").getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/18 22:21:20 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Define the file path for the dataset and output path for saving the result
file_path = "data/heart.csv"  # Path to the heart disease dataset in CSV format
output_path = "output_parquet"  # Path where output will be saved in Parquet format

In [4]:
try:
    # Load the heart disease dataset
    # Reads the CSV file into a DataFrame, with the header set to True to treat the first row as column names
    # and inferSchema set to True to automatically detect the data types of each column.
    df = spark.read.csv(file_path, header = True, inferSchema = True)

    # Showing the first 5 rows of the dataset
    # This gives us a preview of the loaded data to understand its structure and content.
    df.show(5)
    
except Exception as e:
    # Error handling to capture any issues while loading the dataset
    # If an error occurs, it will print a message with the exception details.
    print(f"Error loading data: {e}")


+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+
|age|sex| cp|trestbps|chol|fbs|restecg|thalach|exang|oldpeak|slope| ca|thal|target|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+
| 52|  1|  0|     125| 212|  0|      1|    168|    0|    1.0|    2|  2|   3|     0|
| 53|  1|  0|     140| 203|  1|      0|    155|    1|    3.1|    0|  0|   3|     0|
| 70|  1|  0|     145| 174|  0|      1|    125|    1|    2.6|    0|  0|   3|     0|
| 61|  1|  0|     148| 203|  0|      1|    161|    0|    0.0|    2|  1|   3|     0|
| 62|  0|  0|     138| 294|  1|      1|    106|    0|    1.9|    1|  3|   2|     0|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+
only showing top 5 rows



In [5]:
from pyspark.sql.functions import sum, col, when, count

# Function for check and remove missing values
def check_and_remove_missing_values(df):
    try:
        missing_counts = df.select(
            [sum(col(c).isNull().cast("int")).alias(c) for c in df.columns]
        )

        print("Missing values by column")
        missing_counts.show()

        # If there are any missing values, remove them
        if missing_counts.count() > 1:
            df_cleaned = df.na.drop()
            print("Column with missing values are removed")
            return df_cleaned  # Return the cleaned DataFrame
        else:
            print("There are no missing values.")
            return df # Return the original DataFrame if there are no missing values
    except Exception as e:
        print(f"An error occurred while checking for missing values: {e}")
    
df = check_and_remove_missing_values(df) # Call the function on the DataFrame 'df'


Missing values by column
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+
|age|sex| cp|trestbps|chol|fbs|restecg|thalach|exang|oldpeak|slope| ca|thal|target|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+
|  0|  0|  0|       0|   0|  0|      0|      0|    0|      0|    0|  0|   0|     0|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+

There are no missing values.


In [6]:
df.printSchema() #  Function in PySpark is used to print the schema of a DataFrame. It displays the column names, data types, and whether or not the column allows null values

root
 |-- age: integer (nullable = true)
 |-- sex: integer (nullable = true)
 |-- cp: integer (nullable = true)
 |-- trestbps: integer (nullable = true)
 |-- chol: integer (nullable = true)
 |-- fbs: integer (nullable = true)
 |-- restecg: integer (nullable = true)
 |-- thalach: integer (nullable = true)
 |-- exang: integer (nullable = true)
 |-- oldpeak: double (nullable = true)
 |-- slope: integer (nullable = true)
 |-- ca: integer (nullable = true)
 |-- thal: integer (nullable = true)
 |-- target: integer (nullable = true)



In [7]:
# Function to check the values in a specific column and remove rows with invalid values
def check_column_values(df, column_name, min_values, max_values):
    try:
        print(f"Checking the values in column {column_name}:")  # Inform the user which column is being checked

        # Filtering rows where values are not within the specified range (min_values to max_values)
        invalid_values = df.filter(~col(column_name).between(min_values, max_values))

        # If there are invalid values, remove them from the DataFrame and display them
        if invalid_values.count() > 0:
            df = df.subtract(invalid_values)  # Remove rows with invalid values
            print(f"Inconsistent values in the column '{column_name}'")
            invalid_values.show()  # Show the rows with invalid values
        else:
            # If no invalid values are found, inform the user
            print(f"There are no inconsistent values in the column '{column_name}'")

        # Return the cleaned DataFrame (with invalid rows removed)
        return df
    except Exception as e:
        # Catch and print any errors that occur during the validation process
        print(f"An error occurred while validating the column: {e}")

In [8]:
# Checking and cleaning the 'age' column (valid range: 0-120)
df = check_column_values(df, 'age', 0, 120)

# Checking and cleaning the 'sex' column (valid values: 0 or 1)
df = check_column_values(df, 'sex', 0, 1)

# Checking and cleaning the 'cp' column (valid values: 0, 1, 2, 3)
df = check_column_values(df, 'cp', 0, 3)

# Checking and cleaning the 'trestbps' column (valid range: 0-300)
df = check_column_values(df, 'trestbps', 0, 300)

# Checking and cleaning the 'chol' column (valid range: 50-1000)
df = check_column_values(df, 'chol', 50, 1000)

# Checking and cleaning the 'fbs' column (valid values: 0 or 1)
df = check_column_values(df, 'fbs', 0, 1)

# Checking and cleaning the 'restecg' column (valid values: 0, 1, 2)
df = check_column_values(df, 'restecg', 0, 2)

# Checking and cleaning the 'thalach' column (valid range: 20-230)
df = check_column_values(df, 'thalach', 20, 230)

# Checking and cleaning the 'exang' column (valid values: 0 or 1)
df = check_column_values(df, 'exang', 0, 1)

# Checking and cleaning the 'oldpeak' column (valid range: -3 to 8)
df = check_column_values(df, 'oldpeak', -3, 8)

# Checking and cleaning the 'slope' column (valid values: 0, 1, 2)
df = check_column_values(df, 'slope', 0, 2)

# Checking and cleaning the 'ca' column (valid range: 0-4)
df = check_column_values(df, 'ca', 0, 4)

# Checking and cleaning the 'thal' column (valid range: 0-3)
df = check_column_values(df, 'thal', 0, 3)

# Checking and cleaning the 'target' column (valid values: 0 or 1)
df = check_column_values(df, 'target', 0, 1)

Checking the values in column age:
There are no inconsistent values in the column 'age'
Checking the values in column sex:
There are no inconsistent values in the column 'sex'
Checking the values in column cp:
There are no inconsistent values in the column 'cp'
Checking the values in column trestbps:
There are no inconsistent values in the column 'trestbps'
Checking the values in column chol:
There are no inconsistent values in the column 'chol'
Checking the values in column fbs:
There are no inconsistent values in the column 'fbs'
Checking the values in column restecg:
There are no inconsistent values in the column 'restecg'
Checking the values in column thalach:
There are no inconsistent values in the column 'thalach'
Checking the values in column exang:
There are no inconsistent values in the column 'exang'
Checking the values in column oldpeak:
There are no inconsistent values in the column 'oldpeak'
Checking the values in column slope:
There are no inconsistent values in the colum

In [9]:
from pyspark.sql.types import IntegerType, DoubleType

# When loading data, we set the parameter inferSchema = true and thus enabled 
# automatic determination of data types, but we can also do it manually

# Implementation of data type conversion
def data_type_conversion(df):
    try:
        # Converting columns to their appropriate data types (IntegerType or DoubleType)
        df = df.withColumn("age", col("age").cast(IntegerType())) \
            .withColumn("sex", col("sex").cast(IntegerType())) \
            .withColumn("cp", col("cp").cast(IntegerType())) \
            .withColumn("trestbps", col("trestbps").cast(IntegerType())) \
            .withColumn("chol", col("chol").cast(IntegerType())) \
            .withColumn("fbs", col("fbs").cast(IntegerType())) \
            .withColumn("restecg", col("restecg").cast(IntegerType())) \
            .withColumn("thalach", col("thalach").cast(IntegerType())) \
            .withColumn("exang", col("exang").cast(IntegerType())) \
            .withColumn("oldpeak", col("oldpeak").cast(DoubleType())) \
            .withColumn("slope", col("slope").cast(IntegerType())) \
            .withColumn("ca", col("ca").cast(IntegerType())) \
            .withColumn("thal", col("thal").cast(IntegerType())) \
            .withColumn("target", col("target").cast(IntegerType()))
        return df # Returning the DataFrame with updated column data types
    except Exception as e:
        # Printing error message if there's an issue during the conversion
        print(f"An error occurred while converting types: {e}")

# Applying the data type conversion function to the DataFrame
df = data_type_conversion(df)

In [10]:
# Add a new column 'high_risk_category' to categorize patients into high risk (1) or not (0)
df = df.withColumn(
    "high_risk_category",  # New column name
    when(col("age") > 60, 1)  # High risk if age is greater than 60
    .when(col("chol") > 250, 1)  # High risk if cholesterol is greater than 250
    .when(col("trestbps") > 140, 1)  # High risk if resting blood pressure is greater than 140
    .when(col("thalach") < 100, 1)  # High risk if maximum heart rate is less than 100
    .otherwise(0)  # Otherwise, assign 0 (low risk)
)

# Create a new column 'age_group' based on the age of the patient
df = df.withColumn(
    "age_group",  # New column name
    when(col("age") < 30, "Under 30")  # Age group 'Under 30'
    .when((col("age") >= 30) & (col("age") < 40), "30-40")  # Age group '30-40'
    .when((col("age") >= 40) & (col("age") < 50), "40-50")  # Age group '40-50'
    .when((col("age") >= 50) & (col("age") < 60), "50-60")  # Age group '50-60'
    .when(col("age") >= 60, "60+")  # Age group '60+'
    .otherwise("Unknown")  # Default value if age doesn't fit any of the categories
)

# Show the updated DataFrame with the new columns
df.show(5)

# Save the resulting DataFrame in Parquet format
try:
    # Writing the data to the specified output path in Parquet format
    df.write.parquet(output_path)  
    print(f"Data successfully saved in Parquet format at {output_path}")
    
except Exception as e:
    # Error handling if saving to Parquet fails
    print(f"An error occurred while saving to Parquet format: {e}")


+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+------------------+---------+
|age|sex| cp|trestbps|chol|fbs|restecg|thalach|exang|oldpeak|slope| ca|thal|target|high_risk_category|age_group|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+------------------+---------+
| 52|  1|  0|     125| 212|  0|      1|    168|    0|    1.0|    2|  2|   3|     0|                 0|    50-60|
| 53|  1|  0|     140| 203|  1|      0|    155|    1|    3.1|    0|  0|   3|     0|                 0|    50-60|
| 70|  1|  0|     145| 174|  0|      1|    125|    1|    2.6|    0|  0|   3|     0|                 1|      60+|
| 61|  1|  0|     148| 203|  0|      1|    161|    0|    0.0|    2|  1|   3|     0|                 1|      60+|
| 62|  0|  0|     138| 294|  1|      1|    106|    0|    1.9|    1|  3|   2|     0|                 1|      60+|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+------+-------------

                                                                                

Data successfully saved in Parquet format at output_parquet


In [11]:
from pyspark.sql import functions as F

# Funkcija za analizu odnosa između visoko-rizičnih osoba i targeta
def analyze_high_risk_vs_target(df):
    try:
        # Grupiranje podataka prema visokom riziku i targetu
        result = df.groupBy("high_risk_category", "target") \
            .agg(F.count("*").alias("count")) \
            .orderBy("high_risk_category", "target")

        print("Relation between high risk category and heart disease (target):")
        result.show()

        # Spremanje rezultata u Parquet format
        result.write.parquet("high_risk_vs_target_distribution.parquet")
        print("High risk vs target distribution successfully saved in Parquet format.")
    
    except Exception as e:
        print(f"An error occurred: {e}")

# Poziv funkcije za analizu
analyze_high_risk_vs_target(df)


Relation between high risk category and heart disease (target):
+------------------+------+-----+
|high_risk_category|target|count|
+------------------+------+-----+
|                 0|     0|  151|
|                 0|     1|  251|
|                 1|     0|  348|
|                 1|     1|  275|
+------------------+------+-----+

High risk vs target distribution successfully saved in Parquet format.


In [12]:
# Function to analyze the distribution of heart disease by age group and sex
def analyze_heart_disease_distribution(df):
    try:
        # Grouping data by age group, sex, and target (heart disease status), then calculating the count
        result = df.groupBy("age_group", "sex", "target") \
            .agg(F.count("*").alias("count"))  # Counting the occurrences in each group
        
        # Ordering the result by age group, sex, and target for better readability
        result = result.orderBy(
            F.when(F.col("age_group") == "Under 30", 0)  # Assign numeric values to age groups for sorting
             .when(F.col("age_group") == "30-40", 1)
             .when(F.col("age_group") == "40-50", 2)
             .when(F.col("age_group") == "50-60", 3)
             .when(F.col("age_group") == "60+", 4)
             .otherwise(5),
            "sex",  # Sort by sex
            "target"  # Sort by heart disease status (target)
        )

        # Displaying the results
        print("Heart disease distribution by age and sex:")
        result.show()

        # Saving the results in Parquet format for future use
        result.write.parquet("heart_disease_distribution_by_age_and_sex.parquet")
        print("Heart disease distribution successfully saved in Parquet format.")
    
    except Exception as e:
        # Error handling in case of any issues during the analysis
        print(f"An error occurred during the analysis of heart disease distribution: {e}")

# Calling the function to analyze the distribution
analyze_heart_disease_distribution(df)


Heart disease distribution by age and sex:
+---------+---+------+-----+
|age_group|sex|target|count|
+---------+---+------+-----+
| Under 30|  1|     1|    4|
|    30-40|  0|     1|   17|
|    30-40|  1|     0|   15|
|    30-40|  1|     1|   21|
|    40-50|  0|     0|    4|
|    40-50|  0|     1|   55|
|    40-50|  1|     0|   76|
|    40-50|  1|     1|  102|
|    50-60|  0|     0|   35|
|    50-60|  0|     1|   74|
|    50-60|  1|     0|  181|
|    50-60|  1|     1|  132|
|      60+|  0|     0|   47|
|      60+|  0|     1|   80|
|      60+|  1|     0|  141|
|      60+|  1|     1|   41|
+---------+---+------+-----+

Heart disease distribution successfully saved in Parquet format.


In [13]:
df.printSchema()

root
 |-- age: integer (nullable = true)
 |-- sex: integer (nullable = true)
 |-- cp: integer (nullable = true)
 |-- trestbps: integer (nullable = true)
 |-- chol: integer (nullable = true)
 |-- fbs: integer (nullable = true)
 |-- restecg: integer (nullable = true)
 |-- thalach: integer (nullable = true)
 |-- exang: integer (nullable = true)
 |-- oldpeak: double (nullable = true)
 |-- slope: integer (nullable = true)
 |-- ca: integer (nullable = true)
 |-- thal: integer (nullable = true)
 |-- target: integer (nullable = true)
 |-- high_risk_category: integer (nullable = false)
 |-- age_group: string (nullable = false)



In [14]:
# Function to calculate and rank correlations of all numerical features with the target variable
def calculate_all_correlations(df):
    try:
        # Extracting all numerical columns from the DataFrame, excluding 'target'
        numerical_columns = [field.name for field in df.schema.fields if isinstance(field.dataType, (IntegerType, DoubleType)) and field.name != "target"]

        # Initializing a dictionary to store correlations
        correlations = {}

        # Computing the correlation for each numerical variable with the target
        for column in numerical_columns:
            correlation_value = df.stat.corr(column, "target")
            correlations[column] = abs(correlation_value)  # Absolute correlation value

        # Sorting correlations in descending order
        sorted_correlations = sorted(correlations.items(), key=lambda item: item[1], reverse=True)

        # Saving all correlations to Parquet format
        result_df = spark.createDataFrame(sorted_correlations, ["Factor", "Correlation"])
        result_df.write.parquet("all_correlations.parquet")
        print("All variable correlations successfully saved in Parquet format.")

        # Displaying the top 3 most significant factors
        print("Top 3 most significant factors correlated with heart disease:")
        for i in range(3):
            factor, correlation = sorted_correlations[i]
            print(f"{i+1}. {factor} - Correlation: {correlation:.4f}")

        # Saving the top 3 most significant factors to Parquet format
        top_3_df = spark.createDataFrame(sorted_correlations[:3], ["Factor", "Correlation"])
        top_3_df.write.parquet("top_3_significant_factors.parquet")
        print("Top 3 most significant factors successfully saved in Parquet format.")

    except Exception as e:
        # Handling errors during correlation analysis
        print(f"An error occurred during correlation analysis: {e}")

# Calling the function to analyze all correlations
calculate_all_correlations(df)


25/01/18 22:25:47 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

All variable correlations successfully saved in Parquet format.
Top 3 most significant factors correlated with heart disease:
1. oldpeak - Correlation: 0.4384
2. exang - Correlation: 0.4380
3. cp - Correlation: 0.4349
Top 3 most significant factors successfully saved in Parquet format.


In [15]:
# Function to create and count risk profiles based on age and other features
def create_and_count_risk_profiles(df):
    try:
        # Create new risk profiles based on age and other health features
        df = df.withColumn(
            "risk_profile", 
            F.when((F.col("age_group") == "Under 30") & (F.col("trestbps") > 130) & (F.col("chol") > 240), "High Risk")  # High risk condition for age < 30
            .when((F.col("age_group") == "30-40") & (F.col("trestbps") > 140) & (F.col("thalach") < 120), "High Risk")  # High risk condition for age between 30-40
            .when((F.col("age_group") == "40-50") & (F.col("chol") > 250) & (F.col("thalach") < 130), "High Risk")  # High risk condition for age between 40-50
            .when((F.col("age_group") == "50-60") & (F.col("trestbps") > 150) & (F.col("chol") > 260), "High Risk")  # High risk condition for age between 50-60
            .when((F.col("age_group") == "60+") & (F.col("trestbps") > 160) & (F.col("thalach") < 100), "High Risk")  # High risk condition for age > 60
            .when((F.col("age_group") == "Under 30") & (F.col("trestbps") <= 130) & (F.col("chol") <= 240), "Low Risk")  # Low risk condition for age < 30
            .when((F.col("age_group") == "30-40") & (F.col("trestbps") <= 140) & (F.col("thalach") >= 120), "Low Risk")  # Low risk condition for age between 30-40
            .when((F.col("age_group") == "40-50") & (F.col("chol") <= 250) & (F.col("thalach") >= 130), "Low Risk")  # Low risk condition for age between 40-50
            .when((F.col("age_group") == "50-60") & (F.col("trestbps") <= 150) & (F.col("chol") <= 260), "Low Risk")  # Low risk condition for age between 50-60
            .when((F.col("age_group") == "60+") & (F.col("trestbps") <= 160) & (F.col("thalach") >= 100), "Low Risk")  # Low risk condition for age > 60
            .otherwise("Moderate Risk")  # If none of the above conditions are met, assign "Moderate Risk"
        )

        # Group data by age group and risk profile, then count the number of occurrences in each group
        risk_profile_counts = df.groupBy("age_group", "risk_profile").count()

        # Print the risk profile counts for each category
        print("Risk profile counts by category:")
        risk_profile_counts.show()

        # Save the result in Parquet format
        risk_profile_counts.write.parquet("risk_profile_counts.parquet")
        print("Risk profile counts successfully saved in Parquet format.")

    except Exception as e:
        # Error handling in case of any issues during the risk profile creation or counting
        print(f"Error occurred while creating and counting risk profiles: {e}")

# Call the function to create and count risk profiles
create_and_count_risk_profiles(df)


Risk profile counts by category:
+---------+-------------+-----+
|age_group| risk_profile|count|
+---------+-------------+-----+
|    50-60|     Low Risk|  245|
|    50-60|    High Risk|   38|
|    40-50|     Low Risk|  144|
|    40-50|Moderate Risk|   78|
| Under 30|     Low Risk|    4|
|      60+|     Low Risk|  277|
|    50-60|Moderate Risk|  139|
|    40-50|    High Risk|   15|
|      60+|Moderate Risk|   32|
|    30-40|     Low Risk|   53|
+---------+-------------+-----+

Risk profile counts successfully saved in Parquet format.


In [16]:
from pyspark.sql.window import Window  # Import Window for windowed operations

# Function to analyze the relationship between chest pain type and heart disease
def analyze_chest_pain_and_heart_disease(df):
    try:
        # Group the data by chest pain type ('cp') and presence of heart disease ('target')
        result = (
            df.groupBy("cp", "target")
            .agg(F.count("*").alias("count"))  # Count the number of occurrences in each group
            .withColumn(
                "percentage",  # Calculate the percentage of each target within each chest pain type
                F.round(
                    F.col("count") / F.sum("count").over(Window.partitionBy("cp")) * 100, 
                    2
                )
            )
        )

        
        # The 'Window.partitionBy("cp")' creates a window partitioned by chest pain type ('cp'),
        # meaning that for each unique chest pain type, we are calculating the sum of counts for each group 
        # to get the percentage. This ensures that the percentage calculation is done within each chest pain type.
        # The 'over(Window.partitionBy("cp"))' ensures that the sum of counts is calculated for each partition (chest pain type)
        # rather than for the entire dataset.

        # Sort the result for better readability
        result = result.orderBy("cp", "target")

        # Print the result
        print("Relationship between chest pain type and presence of heart disease:")
        result.show(truncate=False)

        # Save the result in Parquet format
        result.write.parquet("chest_pain_heart_disease_analysis.parquet")
        print("Results successfully saved in Parquet format.")

    except Exception as e:
        # Error handling in case of any issues during the analysis
        print(f"Error occurred while analyzing the relationship between chest pain and heart disease: {e}")

# Call the function to analyze the relationship
analyze_chest_pain_and_heart_disease(df)


Relationship between chest pain type and presence of heart disease:
+---+------+-----+----------+
|cp |target|count|percentage|
+---+------+-----+----------+
|0  |0     |375  |75.45     |
|0  |1     |122  |24.55     |
|1  |0     |33   |19.76     |
|1  |1     |134  |80.24     |
|2  |0     |65   |22.89     |
|2  |1     |219  |77.11     |
|3  |0     |26   |33.77     |
|3  |1     |51   |66.23     |
+---+------+-----+----------+

Results successfully saved in Parquet format.
