In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy.stats.contingency import association
            

dataset = pd.read_csv("../../data/processed/GYTS_dataset.csv")

# Store these survey answers as pandas "category" dtype so they are treated as
# discrete levels in the plots and crosstabs below.
categorical_columns = ["State", "Gender", "Age", "SmokingFriends", "SeenSmokerInPublicPlace",
                       "SeenSmokerInEnclosedPlace", "SeenSmokerInHome", "AttractiveSmoker",
                       "HardQuitSmoke", "SmokerConfidentInCelebrations", "SchoolWarnings",
                       "SeenHealthWarnings", "AntiTobaccoInEvents", "HarmfulPassiveSmoke"]
dataset[categorical_columns] = dataset[categorical_columns].astype('category')

# Convert yes/no answers to real booleans.
# NOTE(review): assumes these columns are encoded as 0/1 or True/False in the CSV
# -- confirm; any other non-empty value (e.g. the string "No") would become True.
boolean_columns = ["Smoke", "SeenSmokerInSchool", "ParentWarnings", "AntiTobaccoInMedia",
                   "BanTobaccoOutdoors", "SmokingFather", "SmokingMother", "WorkingFather",
                   "WorkingMother"]
# astype('bool') maps NaN to True, which would silently inflate the True counts
# in every plot; fail loudly instead of corrupting the data.
missing_per_column = dataset[boolean_columns].isna().sum()
if missing_per_column.any():
    raise ValueError(
        "Missing values in boolean columns would be converted to True:\n"
        f"{missing_per_column[missing_per_column > 0]}"
    )
dataset[boolean_columns] = dataset[boolean_columns].astype('bool')

# NOTE(review): `classes` is not referenced in the visible cells; remove if unused.
classes = dataset['Smoke'].unique()

# Columns analysed by the plotting and association cells.
# NOTE(review): "Age" is converted to a category above but is not listed here --
# confirm whether it should be analysed too.
list_of_columns = ["State", "Gender", "Smoke", "SmokingFriends", 'SmokingFather', 'SmokingMother',
                   'WorkingFather', 'WorkingMother', "SeenSmokerInSchool", "SeenSmokerInPublicPlace",
                   "SeenSmokerInEnclosedPlace", "SeenSmokerInHome", "ParentWarnings",
                   "AttractiveSmoker", "HardQuitSmoke", "SmokerConfidentInCelebrations",
                   "SchoolWarnings", "SeenHealthWarnings", "AntiTobaccoInEvents",
                   "AntiTobaccoInMedia", "BanTobaccoOutdoors", "HarmfulPassiveSmoke"]

In [None]:
# Univariate analysis
print("Univariate analysis")

FIGSIZE = (24, 16)  # inches; together with DPI this produces very large PNG files
DPI = 600
UNIVARIATE_DIR = Path("../../data/processed/univariate_analysis")
MULTIVARIATE_DIR = Path("../../data/processed/multivariate_analysis")
# savefig raises if the target directory does not exist.
UNIVARIATE_DIR.mkdir(parents=True, exist_ok=True)
MULTIVARIATE_DIR.mkdir(parents=True, exist_ok=True)


def save_countplot(path, title, rotate_xticks=False, **countplot_kwargs):
    """Draw a countplot of ``dataset``, label bars with their share of rows, save and show it.

    Args:
        path: File the figure is written to.
        title: Axes title.
        rotate_xticks: Rotate the x tick labels by 25 degrees (right-aligned).
        **countplot_kwargs: Forwarded to ``sns.countplot`` (``x``, ``hue``, ``dodge``, ...).
    """
    n_rows = len(dataset)  # denominator includes rows where the plotted column is missing
    # plt.plot(figsize=...) draws nothing and ignores figsize; create the figure
    # explicitly so the size is honoured and the plot cannot land on a stale axes.
    fig, ax = plt.subplots(figsize=FIGSIZE)
    sns.countplot(data=dataset, palette='rainbow', zorder=10, ax=ax, **countplot_kwargs)
    ax.grid(axis='y', alpha=0.6, zorder=0)
    # Label every container: with hue (explicit, or implied by `palette` in
    # seaborn >= 0.13) each group of bars is its own container.
    for container in ax.containers:
        ax.bar_label(container, fmt=lambda count: f'{(count / n_rows) * 100:0.1f}%')
    if rotate_xticks:
        ax.tick_params(axis='x', labelrotation=25)
        plt.setp(ax.get_xticklabels(), ha="right")
    ax.set(title=title)
    fig.savefig(path, dpi=DPI)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop


# Bar plots: one countplot per column.
for column in list_of_columns:
    save_countplot(
        UNIVARIATE_DIR / f"{column}_histogram.png",
        f"{column} histogram",
        # presumably this column has long category labels that overlap unrotated
        rotate_xticks=(column == "SeenHealthWarnings"),
        x=column,
    )

# Multivariate analysis
print("Multivariate analysis")

# Correlation matrix and heatmap. numeric_only=True keeps the numeric/boolean
# columns and drops the categorical ones.
corr_matrix = dataset[list_of_columns].corr(numeric_only=True)
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(corr_matrix, annot=True, cmap='crest', fmt=".2f", ax=ax)
plt.show()
plt.close(fig)

# Bar plots: smoking status split by every other column.
for column in list_of_columns:
    if column == 'Smoke':
        continue  # Smoke vs Smoke is degenerate; the univariate plot covers it
    save_countplot(
        MULTIVARIATE_DIR / f"{column}_histogram.png",
        f"Smoker vs {column}",
        x='Smoke',
        hue=column,
        dodge=True,
    )

In [None]:
# Pairwise association between all analysed columns, using
# scipy.stats.contingency.association with its default method (Cramér's V).
association_table = pd.DataFrame(index=list_of_columns, columns=list_of_columns, dtype=float)

# Cramér's V is symmetric in its two variables (transposing the contingency
# table changes neither chi2 nor min(rows, cols)), so compute each unordered
# pair once -- diagonal included -- and mirror it.
for i, col1 in enumerate(list_of_columns):
    for col2 in list_of_columns[i:]:
        frequency_table = pd.crosstab(dataset[col1], dataset[col2])
        association_value = association(frequency_table)
        association_table.loc[col1, col2] = association_value
        association_table.loc[col2, col1] = association_value

fig, ax = plt.subplots(figsize=(24, 18))
sns.heatmap(association_table, annot=True, fmt=".2f", ax=ax)
fig.savefig("../../data/processed/association_table.png", dpi=600)
plt.show()
plt.close(fig)

# association_table.to_csv("../../data/processed/association_table.csv")


In [None]:
# Countplots for every ordered pair of distinct columns whose association score
# in `association_table` exceeds ASSOCIATION_THRESHOLD. Both orderings are
# plotted because swapping x and hue gives a different view of the same crosstab.
ASSOCIATION_THRESHOLD = 0.25
pair_plot_dir = Path("../../data/processed/multivariate_analysis")
pair_plot_dir.mkdir(parents=True, exist_ok=True)  # savefig needs an existing directory
n_rows = len(dataset)

for col1 in list_of_columns:
    for col2 in list_of_columns:
        # NaN scores compare False and are skipped.
        if col1 != col2 and association_table.loc[col1, col2] > ASSOCIATION_THRESHOLD:
            # plt.plot(figsize=...) ignores figsize; create the figure explicitly.
            fig, ax = plt.subplots(figsize=(24, 16))
            sns.countplot(x=col1, hue=col2, data=dataset, palette='rainbow', dodge=True,
                          zorder=10, ax=ax)
            ax.grid(axis='y', alpha=0.6, zorder=0)
            for container in ax.containers:
                ax.bar_label(container, fmt=lambda count: f'{(count / n_rows) * 100:0.1f}%')
            ax.set(title=f"{col1} vs {col2}")
            # Name the file after both columns. Keyed on col1 alone, the file was
            # overwritten for every col2 and collided with the
            # "<column>_histogram.png" files written to this same directory.
            fig.savefig(pair_plot_dir / f"{col1}_vs_{col2}.png", dpi=600)
            plt.show()
            plt.close(fig)
