In [37]:
from scrapy import Selector
import requests

import re
from typing import List

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, FloatType, LongType, StringType, DoubleType
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import when, col
from pyspark.ml import Pipeline, Transformer
from pyspark.ml.feature import StringIndexer, VectorAssembler, Imputer
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import pyspark.sql.functions as F
from itertools import combinations
import os

## Check the Python interpreter path

In [20]:
# Display the path of the Python interpreter this kernel is running on.
import sys
sys.executable

'/tmp/demos/bin/python3'

In [21]:
# Directory holding the input CSV; read_data() joins it with "heart_disease.csv".
DATA_FOLDER = "../data"

# NOTE(review): the three constants below are not referenced by any cell shown
# here; presumably they configure cross-validation and the train/test split — confirm.
NUMBER_OF_FOLDS = 3
SPLIT_SEED = 7576
TRAIN_TEST_SPLIT = 0.9

## Reading the data

In [22]:

def read_data(spark: SparkSession) -> DataFrame:
    """
    Load the heart-disease CSV from DATA_FOLDER into a Spark DataFrame.

    The file has a header row, so column names are taken from it and Spark
    is asked to infer the type of every column.
    """
    csv_path = os.path.join(DATA_FOLDER, "heart_disease.csv")

    # Configure the CSV reader first, then point it at the file
    reader = (
        spark.read
        .format("csv")
        .option("header", "true")
        .option("inferSchema", "true")
    )
    return reader.load(csv_path)

## Cleaning

In [28]:
def retain_cols(data: DataFrame) -> DataFrame:
    """
    Project the DataFrame down to the columns the rest of the notebook works with.
    """
    kept = (
        'age', 'sex', 'painloc', 'painexer', 'cp', 'trestbps',
        'smoke', 'fbs', 'prop', 'nitr', 'pro', 'diuretic',
        'thaldur', 'thalach', 'exang', 'oldpeak', 'slope', 'target',
    )
    return data.select(*kept)
    
def replace_out_of_range(data: DataFrame) -> DataFrame:
    """
    Repair values that fall outside the range accepted for each column.

    - Most columns are clipped to [low, high] (below -> low, above -> high).
    - 'trestbps' only has a floor of 100.
    - 'slope' values outside [1, 3] are replaced by null instead of clipped.
    Values already inside the range (and nulls) are left as they are.
    """
    def clipped(name, low, high):
        # Below the range -> low bound, above the range -> high bound
        value = col(name)
        return when(value < low, low).when(value > high, high).otherwise(value)

    def nulled_outside(name, low, high):
        # Anything outside the range becomes null
        value = col(name)
        return when(value < low, None).when(value > high, None).otherwise(value)

    # One replacement expression per column, applied in this order
    rules = {
        'painloc': clipped('painloc', 0, 1),
        'painexer': clipped('painexer', 0, 1),
        'trestbps': when(col('trestbps') < 100, 100).otherwise(col('trestbps')),
        'oldpeak': clipped('oldpeak', 0, 4),
        'fbs': clipped('fbs', 0, 1),
        'prop': clipped('prop', 0, 1),
        'nitr': clipped('nitr', 0, 1),
        'pro': clipped('pro', 0, 1),
        'diuretic': clipped('diuretic', 0, 1),
        'exang': clipped('exang', 0, 1),
        'slope': nulled_outside('slope', 1, 3),
    }

    for name, expression in rules.items():
        data = data.withColumn(name, expression)
    return data

In [32]:
def smoke_1(data: DataFrame) -> DataFrame:
    """
    Fill missing 'smoke_1' values with an age-based smoking rate scraped from the ABS site.

    The second table on the page is read; from every body row the first cell
    (an age band such as "18–24") and the cell at index 10 (a percentage) are kept.
    A missing 'smoke_1' is replaced by the percentage of the first band whose
    upper bound is >= the row's age (the last band catches everything else),
    divided by 100. Rows that already have a value are left untouched.

    Parameters:
        data: DataFrame with an 'age' column and a 'smoke_1' column.

    Returns:
        The DataFrame with nulls in 'smoke_1' replaced. The replacement is NaN
        when the age cannot be parsed or no usable rate could be scraped.

    Raises:
        requests.RequestException: if the page cannot be fetched (network
            error, timeout or HTTP error status).
        IndexError: if the page has fewer than two tables.
    """
    url1 = 'https://www.abs.gov.au/statistics/health/health-conditions-and-risks/smoking-and-vaping/latest-release'
    # Bound the wait and fail on an HTTP error instead of parsing an error page
    response = requests.get(url1, timeout=30)
    response.raise_for_status()

    # create a selector object from the raw HTML
    full_sel = Selector(text=response.content)

    # select all tables in page -> returns a SelectorList object
    tables = full_sel.xpath('//table')
    # NOTE(review): assumes the second table is the smokers-by-age table — confirm against the live page
    smokers_by_age = tables[1]
    rows = smokers_by_age.xpath('./tbody//tr')

    def parse_row_1(row: Selector) -> List[str]:
        '''
        Returns the cleaned text of cells 0 and 10 of an html row
        '''
        cells = row.xpath('.//th | .//td')
        row_data = []

        for i, cell in enumerate(cells):
            # NOTE(review): index 10 is presumably the most recent survey column — confirm
            if i == 0 or i == 10:
                cell_text = cell.xpath('normalize-space(.)').get()
                cell_text = re.sub(r'<.*?>', ' ', cell_text)  # Remove remaining HTML tags
                cell_text = cell_text.replace('\xa0', '')  # Remove \xa0 characters
                row_data.append(cell_text)

        return row_data

    table_data = [parse_row_1(row) for row in rows]

    def get_rate_1(age):
        # Best effort: any age or table row that cannot be parsed yields NaN.
        # float('nan') is used because numpy is not imported in this notebook.
        try:
            age = int(age)
            last = len(table_data) - 1
            for i, row in enumerate(table_data):
                if i == last:
                    # the last band has no upper bound
                    return float(row[1])
                cutoff = row[0].split('–')[1]
                if age <= int(cutoff):
                    return float(row[1])
        except (TypeError, ValueError, IndexError):
            pass
        # also reached when the scraped table is empty
        return float('nan')

    # Register the UDF; the scraped figure is a percentage, hence / 100
    get_rate_1_udf = F.udf(lambda age: get_rate_1(age) / 100, DoubleType())

    data = data.withColumn('smoke_1', when(col('smoke_1').isNull(), get_rate_1_udf(col('age'))).otherwise(col('smoke_1')))

    return data

def smoke_2(data: DataFrame) -> DataFrame:
    """
    Fill missing 'smoke_2' values with a sex- and age-based smoking rate scraped from the CDC site.

    The first two <ul class="block-list"> lists on the page are read: the first
    gives a percentage per gender, the second a percentage per age band.
    A missing 'smoke_2' is replaced by the percentage of the first age band
    whose upper bound is >= the row's age (the last band catches everything
    else), divided by 100. For rows whose 'sex' is not 0 that percentage is
    first multiplied by the man/woman ratio of the gender percentages.
    Rows that already have a value are left untouched.

    Parameters:
        data: DataFrame with 'sex', 'age' and 'smoke_2' columns.

    Returns:
        The DataFrame with nulls in 'smoke_2' replaced. The replacement is NaN
        when the age cannot be parsed or no usable rate could be scraped.

    Raises:
        requests.RequestException: if the page cannot be fetched (network
            error, timeout or HTTP error status).
        IndexError / ValueError: if the page does not have the expected lists
            or their text does not have the expected shape.
    """
    url2 = 'https://www.cdc.gov/tobacco/data_statistics/fact_sheets/adult_data/cig_smoking/index.htm'
    # Bound the wait and fail on an HTTP error instead of parsing an error page
    response = requests.get(url2, timeout=30)
    response.raise_for_status()

    # Create a scrapy Selector from the response content
    selector = Selector(text=response.content)

    ul_sel_list = selector.xpath('//ul[@class="block-list"]')
    # NOTE(review): assumes list 0 is by gender and list 1 is by age — confirm against the live page
    genders = ul_sel_list[0]
    ages = ul_sel_list[1]

    def clean_gender_percents(rows):
        # Map 'woman' / 'man' to the percentage found between '(' and '%'
        percents = {}
        for row in rows:
            gender = 'woman' if 'women' in row.split('(')[0] else 'man'
            percents[gender] = float(row.split('(')[1].split('%')[0])
        return percents

    def clean_age_percents(rows):
        # Turn each sentence into [upper age bound, percentage]
        parsed = []
        last = len(rows) - 1
        for i, row in enumerate(rows):
            if i < last:
                # bands are written "low–high years"
                age = int(row.split('–')[1].split(' ')[0])
            else:
                # NOTE(review): the last band's age is taken from the 8th word of the sentence — fragile, confirm
                age = int(row.split(' ')[7])

            percent = float(row.split('(')[1].split('%')[0])
            parsed.append([age, percent])
        return parsed

    def parse_row_2(row: Selector) -> List[str]:
        '''
        Returns the cleaned text of every <li> directly under an html list
        '''
        row_data = []
        for cell in row.xpath('./li'):
            cell_text = cell.xpath('normalize-space(.)').get()
            cell_text = re.sub(r'<.*?>', ' ', cell_text)  # Remove remaining HTML tags
            cell_text = cell_text.replace('\xa0', '')  # Remove \xa0 characters
            row_data.append(cell_text)
        return row_data

    per_by_gender = clean_gender_percents(parse_row_2(genders))
    per_by_age = clean_age_percents(parse_row_2(ages))

    def get_rate_2(sex, age):
        # Best effort: an unparsable age or missing gender figure yields NaN.
        # float('nan') is used because numpy is not imported in this notebook.
        try:
            age = int(age)
            last = len(per_by_age) - 1
            for i, (upper_age, percent) in enumerate(per_by_age):
                if i == last or age <= upper_age:
                    # NOTE(review): presumably sex == 0 encodes female — confirm with the dataset codebook
                    if sex == 0:
                        return percent
                    return percent * per_by_gender['man'] / per_by_gender['woman']
        except (TypeError, ValueError, KeyError, ZeroDivisionError):
            pass
        # also reached when no age bands were scraped
        return float('nan')

    # Register the UDF; the scraped figure is a percentage, hence / 100
    get_rate_2_udf = F.udf(lambda sex, age: get_rate_2(sex, age) / 100, DoubleType())

    data = data.withColumn('smoke_2', when(col('smoke_2').isNull(), get_rate_2_udf(col('sex'), col('age'))).otherwise(col('smoke_2')))

    return data

def impute_smoke(data: DataFrame) -> DataFrame:
    """
    Replace 'smoke' with two imputed variants, 'smoke_1' and 'smoke_2'.

    Both new columns start as copies of 'smoke'; smoke_1() and then smoke_2()
    fill the nulls of their own column. The original 'smoke' column is
    dropped afterwards.
    """
    observed = F.col('smoke')
    seeded = data.withColumn('smoke_1', observed).withColumn('smoke_2', observed)

    imputed = smoke_2(smoke_1(seeded))
    return imputed.drop('smoke')

## The ML pipeline

In [33]:
def pipeline(data: DataFrame):
    """
    Clean the raw data and build the 'features' vector column.

    Steps: keep the relevant columns, repair out-of-range values, impute the
    smoking columns, drop rows without a target, cast 'age' to int, then fit
    and apply a Spark ML pipeline that indexes string columns, imputes missing
    values (mode for indexed strings, mean for numerics) and assembles the
    imputed columns into 'features'.

    Every attribute that is numeric is treated as non-categorical; this is
    questionable.

    The label column 'target' is deliberately left out of the imputers and of
    'features': feeding the label into the feature vector would leak it to any
    model trained on the result. It remains available as the 'target' column.

    Parameters:
        data: the raw DataFrame as returned by read_data().

    Returns:
        The transformed DataFrame, with the 'Imputed*' columns and 'features'.
    """
    label_col = 'target'

    data = retain_cols(data)
    data = replace_out_of_range(data)
    data = impute_smoke(data)

    # drop null targets
    data = data.dropna(subset=[label_col])

    # make age an int
    data = data.withColumn("age", data["age"].cast(IntegerType()))

    numeric_types = (DoubleType, FloatType, IntegerType, LongType)
    numeric_features = [
        f.name for f in data.schema.fields
        if isinstance(f.dataType, numeric_types) and f.name != label_col
    ]
    string_features = [
        f.name for f in data.schema.fields
        if isinstance(f.dataType, StringType) and f.name != label_col
    ]

    print(numeric_features)
    print(string_features)

    stages = []
    imputed_columns_string = []

    # Only build the string stages when there is something to index
    if string_features:
        # Fill missing values for string columns with a placeholder before indexing
        data = data.fillna({name: 'null' for name in string_features})

        # Index string features
        indexed_string_columns = [f"{v}Index" for v in string_features]
        indexers = [
            StringIndexer(inputCol=name, outputCol=indexed_name, handleInvalid='keep')
            for name, indexed_name in zip(string_features, indexed_string_columns)
        ]

        # Impute missing values for indexed string columns
        imputed_columns_string = [f"Imputed{v}" for v in indexed_string_columns]
        imputer_string = Imputer(inputCols=indexed_string_columns, outputCols=imputed_columns_string, strategy="mode")

        stages += indexers + [imputer_string]

    # numeric columns
    imputed_columns_numeric = [f"Imputed{v}" for v in numeric_features]
    imputer_numeric = Imputer(inputCols=numeric_features, outputCols=imputed_columns_numeric, strategy="mean")

    # Assemble feature columns into a single feature vector
    assembler = VectorAssembler(
        inputCols=imputed_columns_numeric + imputed_columns_string,
        outputCol="features"
    )

    stages += [imputer_numeric, assembler]

    # Create and fit the pipeline, then transform the same data with it
    ml_pipeline = Pipeline(stages=stages)
    model = ml_pipeline.fit(data)
    transformed_data = model.transform(data)

    return transformed_data

    
    

In [30]:
def main():
    """
    Run the job end to end: start Spark, load the CSV, print its schema and a
    preview, apply pipeline() and show the imputed columns. The Spark session
    is stopped even if a step fails.
    """
    session_builder = SparkSession.builder.appName("Predict Heart Disease")
    spark = session_builder.getOrCreate()

    try:
        raw = read_data(spark)

        # Schema and first rows of the raw input
        raw.printSchema()
        raw.show(5)

        transformed = pipeline(raw)

        # Only the columns produced by the imputers are displayed
        imputed_only = [name for name in transformed.columns if name.startswith("Imputed")]
        transformed.select(imputed_only).show(truncate=False)
    finally:
        spark.stop()

In [38]:
# Run the whole job: read the CSV, apply pipeline() and show the imputed columns.
main()

24/05/26 21:27:31 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


root
 |-- age: string (nullable = true)
 |-- sex: integer (nullable = true)
 |-- painloc: integer (nullable = true)
 |-- painexer: integer (nullable = true)
 |-- relrest: integer (nullable = true)
 |-- pncaden: string (nullable = true)
 |-- cp: integer (nullable = true)
 |-- trestbps: integer (nullable = true)
 |-- htn: integer (nullable = true)
 |-- chol: integer (nullable = true)
 |-- smoke: integer (nullable = true)
 |-- cigs: integer (nullable = true)
 |-- years: integer (nullable = true)
 |-- fbs: integer (nullable = true)
 |-- dm: integer (nullable = true)
 |-- famhist: integer (nullable = true)
 |-- restecg: integer (nullable = true)
 |-- ekgmo: integer (nullable = true)
 |-- ekgday(day: integer (nullable = true)
 |-- ekgyr: integer (nullable = true)
 |-- dig: integer (nullable = true)
 |-- prop: integer (nullable = true)
 |-- nitr: integer (nullable = true)
 |-- pro: integer (nullable = true)
 |-- diuretic: integer (nullable = true)
 |-- proto: integer (nullable = true)
 |-- th

AttributeError: __provides__