In [305]:
import os
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np

import torch
from torch.nn import Linear

import ipywidgets as widgets
from IPython.display import display, clear_output

In [306]:
# List the contents of the input directory. Sorted because os.listdir order is
# filesystem-dependent, which would otherwise make this output non-reproducible.
print(sorted(os.listdir("Data/Input")))

['Life Expectancy Data.csv', 'Test', 'Train']


# **1 &nbsp;&nbsp;&nbsp; Load Data and Model**

In [307]:
# Load the pre-split train/test CSVs, then build the float32 feature matrix from every
# training column except the identifiers ("Country", "Year") and the target column.
train_df = pd.read_csv("Data/Input/Train/train.csv")
test_df = pd.read_csv("Data/Input/Test/test.csv")
features = torch.tensor(
    train_df.drop(columns=["Country", "Year", "Life expectancy "]).to_numpy(),
    dtype=torch.float32,
)

In [308]:
# Regression target: the training set's "Life expectancy " column (note the trailing
# space in the raw column name) as a 1-D float32 tensor.
target = torch.tensor(
    train_df["Life expectancy "].to_numpy(),
    dtype=torch.float32,
)

In [309]:
# Application model: a single linear layer over all features, restored from the saved
# "all features" checkpoint. The input width comes from `features`, so it cannot drift
# from the column list used to build the feature matrix.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
app_model = Linear(in_features=features.shape[1], out_features=1)
app_model.load_state_dict(
    torch.load("Data/Output/AllFeaturesModel/all_features_model.pth", map_location="cpu")
)

<All keys matched successfully>

# **2 &nbsp;&nbsp;&nbsp; Widgets**

In [310]:
# Sample input used to pre-fill the widgets: the first training row.
# Status is read from the same row instead of being hard-coded, so the sample is a real,
# self-consistent record. The numeric encoding mirrors get_input_data
# (0 -> "Developed", anything else -> "Developing").
# Numeric values are selected by column name rather than by position (iloc[0, 4:]), so
# the unpacking does not depend on where the identifier/target columns sit.
sample_row = train_df.iloc[0]
status_value = "Developed" if sample_row["Status"] == 0 else "Developing"
[adult_mort_value, infant_death_value, alcohol_value, exp_value, hep_value, measles_value, bmi_value,
 under5_value, polio_value, total_exp_value, diph_value, hiv_value, gdp_value, pop_value, thin1_value,
 thin5_value, income_value, school_value] = sample_row.drop(["Country", "Year", "Status", "Life expectancy "])

## **2.1 &nbsp;&nbsp;&nbsp; Input Data**

In [311]:
# Development-status selector, pre-selected from the sample status.
status_list = ['Developed', 'Developing']
status_input = widgets.Dropdown(
    description='Status:',
    options=status_list,
    value=status_value,
    layout=widgets.Layout(width='400px'),
)

In [312]:
# Adult mortality, pre-filled from the sample row.
adult_mort_input = widgets.FloatText(
    description='Adult Mortality:',
    value=adult_mort_value,
    layout=widgets.Layout(width='400px'),
)

In [313]:
# Infant deaths (integer field), pre-filled from the sample row.
# NOTE(review): the sample value comes from a numeric CSV column and may be a float;
# IntText is expected to coerce it to int — confirm.
infant_death_input = widgets.IntText(
    description='Infant Deaths:',
    value=infant_death_value,
    layout=widgets.Layout(width='400px'),
)

In [314]:
# Alcohol consumption, pre-filled from the sample row.
alcohol_input = widgets.FloatText(
    description='Alcohol:',
    value=alcohol_value,
    layout=widgets.Layout(width='400px'),
)

In [315]:
# Percentage expenditure, pre-filled from the sample row.
exp_input = widgets.FloatText(
    description='Percentage Expenditure:',
    value=exp_value,
    layout=widgets.Layout(width='400px'),
)

In [316]:
# Hepatitis B, pre-filled from the sample row.
hep_input = widgets.FloatText(
    description='Hepatitis B:',
    value=hep_value,
    layout=widgets.Layout(width='400px'),
)

In [317]:
# Measles cases (integer field), pre-filled from the sample row.
measles_input = widgets.IntText(
    description='Measles:',
    value=measles_value,
    layout=widgets.Layout(width='400px'),
)

In [318]:
# BMI, pre-filled from the sample row.
bmi_input = widgets.FloatText(
    description='BMI:',
    value=bmi_value,
    layout=widgets.Layout(width='400px'),
)

In [319]:
# Under-five deaths (integer field), pre-filled from the sample row.
under5_input = widgets.IntText(
    description='Under-five Deaths:',
    value=under5_value,
    layout=widgets.Layout(width='400px'),
)

In [320]:
# Polio, pre-filled from the sample row.
polio_input = widgets.FloatText(
    description='Polio:',
    value=polio_value,
    layout=widgets.Layout(width='400px'),
)

In [321]:
# Total expenditure, pre-filled from the sample row.
total_exp_input = widgets.FloatText(
    description='Total Expenditure:',
    value=total_exp_value,
    layout=widgets.Layout(width='400px'),
)

In [322]:
# Diphtheria, pre-filled from the sample row.
diph_input = widgets.FloatText(
    description='Diphtheria:',
    value=diph_value,
    layout=widgets.Layout(width='400px'),
)

In [323]:
# HIV/AIDS, pre-filled from the sample row.
hiv_input = widgets.FloatText(
    description='HIV/AIDS:',
    value=hiv_value,
    layout=widgets.Layout(width='400px'),
)

In [324]:
# GDP, pre-filled from the sample row.
gdp_input = widgets.FloatText(
    description='GDP:',
    value=gdp_value,
    layout=widgets.Layout(width='400px'),
)

In [325]:
# Population, pre-filled from the sample row.
pop_input = widgets.FloatText(
    description='Population:',
    value=pop_value,
    layout=widgets.Layout(width='400px'),
)

In [326]:
# Thinness 1-19 years, pre-filled from the sample row.
thin1_input = widgets.FloatText(
    description='Thinness 1-19 years:',
    value=thin1_value,
    layout=widgets.Layout(width='400px'),
)

In [327]:
# Thinness 5-9 years, pre-filled from the sample row.
thin5_input = widgets.FloatText(
    description='Thinness 5-9 years:',
    value=thin5_value,
    layout=widgets.Layout(width='400px'),
)

In [328]:
# Income composition of resources, pre-filled from the sample row.
income_input = widgets.FloatText(
    description='Income Composition of Resources:',
    value=income_value,
    layout=widgets.Layout(width='400px'),
)

In [329]:
# Schooling, pre-filled from the sample row.
school_input = widgets.FloatText(
    description='Schooling:',
    value=school_value,
    layout=widgets.Layout(width='400px'),
)

## **2.2 &nbsp;&nbsp;&nbsp; Prediction Button**

In [None]:
# Prediction controls: a button (wired to on_predict_click when the UI is displayed) and
# an Output area that receives the printed prediction.
# Sizing must go through `layout=`: width/height are not traits of widgets.Output, so
# passing them as plain keyword arguments drops them silently (any resulting warning is
# hidden by the global filterwarnings("ignore") at the top of the notebook).
predict_button = widgets.Button(
    description="Predict Life Expectancy",
    button_style='success',
    layout=widgets.Layout(width='400px'),
)
output = widgets.Output(layout=widgets.Layout(width='700px', height='500px'))

# Left-hand column: every input widget, then the button and its output area.
input_box = widgets.VBox(
    [status_input, adult_mort_input, infant_death_input, alcohol_input,
     exp_input, hep_input, measles_input, bmi_input,
     under5_input, polio_input, total_exp_input, diph_input,
     hiv_input, gdp_input, pop_input, thin1_input,
     thin5_input, income_input, school_input,
     predict_button, output],
    layout=widgets.Layout(align_items='flex-start'),
)

In [331]:
def predict_life_expectancy(x, model=None):
    """Evaluate the linear model on a single standardized input row.

    Parameters
    ----------
    x : torch.Tensor
        Standardized features. Presumably shaped (1, n_features); `.item()` below
        requires the model output to contain exactly one element.
    model : torch.nn.Module, optional
        Model to evaluate. Defaults to the notebook's `app_model`.

    Returns
    -------
    float
        The raw model output, i.e. the prediction in the standardized target scale
        (callers map it back to years with the target mean/std).
    """
    if model is None:
        model = app_model
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        return model(x).item()

def get_input_data():
    """Snapshot the current widget values as a one-row DataFrame.

    Status is encoded as 0 for "Developed" and 1 for anything else. Columns are listed
    in the order intended to match the model's feature columns; both callers only use
    `.values`, so the order matters and the labels are purely descriptive.
    """
    # (column label, current widget value), in feature order.
    fields = [
        ("Status", 0 if status_input.value == "Developed" else 1),
        ("Adult Mortality", adult_mort_input.value),
        ("Infant Deaths", infant_death_input.value),
        ("Alcohol", alcohol_input.value),
        ("Percentage Expenditure", exp_input.value),
        ("Hepatitis B", hep_input.value),
        ("Measles", measles_input.value),
        ("BMI", bmi_input.value),
        ("Under-five Deaths", under5_input.value),
        ("Polio", polio_input.value),
        ("Total expenditure", total_exp_input.value),
        ("Diphtheria", diph_input.value),
        ("HIV/AIDS", hiv_input.value),
        ("GDP", gdp_input.value),
        ("Population", pop_input.value),
        ("Thinness 1-19 Years", thin1_input.value),
        ("Thinness 5-9 Years", thin5_input.value),
        ("Income Composition of Resources", income_input.value),
        ("Schooling", school_input.value),
    ]
    return pd.DataFrame(data={label: [value] for label, value in fields})
    
def on_predict_click(b):
    """Button callback: standardize the widget inputs, predict, and print the result.

    `b` is the Button instance ipywidgets passes to click handlers; it is unused.
    Inputs are standardized with the training-feature mean/std, and the model output is
    mapped back to years with the training-target mean/std.
    """
    raw = torch.tensor(get_input_data().values, dtype=torch.float32)
    mu, sigma = features.mean(dim=0), features.std(dim=0)
    scaled = (raw - mu) / sigma

    with output:
        clear_output()
        le_scaled = predict_life_expectancy(scaled)
        years = le_scaled * target.std() + target.mean()
        print(f"Life Expectancy: {years:.2f} years")

## **2.3 &nbsp;&nbsp;&nbsp; Draw Plot**

In [332]:
# Plot controls: a feature selector, the plot area and a "Draw Plot" button, stacked in a
# right-aligned column. Dropdown labels are translated to raw CSV column names through
# df_cols_dict in the next cell.
features_list = [
    'Status', 'Adult Mortality', 'Infant Deaths', 'Alcohol', 'Percentage Expenditure',
    'Hepatitis B', 'Measles', 'BMI', 'Under-five Deaths', 'Polio',
    'Total Expenditure', 'Diphtheria', 'HIV/AIDS', 'GDP', 'Population',
    'Thinness 1-19 years', 'Thinness 5-9 years', 'Income Composition of Resources', 'Schooling',
]
features_input = widgets.Dropdown(
    description='Feature:',
    options=features_list,
    layout=widgets.Layout(width='400px'),
)
draw_button = widgets.Button(
    description="Draw Plot",
    button_style='success',
    layout=widgets.Layout(width='400px'),
)
out_plot = widgets.Output(
    layout=widgets.Layout(border='1px solid gray', width='700px', height='590px')
)

out_box = widgets.VBox(
    [features_input, out_plot, draw_button],
    layout=widgets.Layout(align_items='flex-end'),
)

In [333]:
# Map each dropdown label to its raw column name in the CSVs. The raw names keep the
# dataset's irregular spacing/casing (e.g. ' BMI ', ' thinness  1-19 years'), so they
# must be copied exactly.
df_cols_dict = {
    'Status': 'Status',
    'Adult Mortality': 'Adult Mortality',
    'Infant Deaths': 'infant deaths',
    'Alcohol': 'Alcohol',
    'Percentage Expenditure': 'percentage expenditure',
    'Hepatitis B': 'Hepatitis B',
    'Measles': 'Measles ',
    'BMI': ' BMI ',
    'Under-five Deaths': 'under-five deaths ',
    'Polio': 'Polio',
    'Total Expenditure': 'Total expenditure',
    'Diphtheria': 'Diphtheria ',
    'HIV/AIDS': ' HIV/AIDS',
    'GDP': 'GDP',
    'Population': 'Population',
    'Thinness 1-19 years': ' thinness  1-19 years',
    'Thinness 5-9 years': ' thinness 5-9 years',
    'Income Composition of Resources': 'Income composition of resources',
    'Schooling': 'Schooling'
}

# All train + test rows with the identifier columns removed (used by draw_plot).
# ignore_index=True avoids the duplicate row labels that concatenating two RangeIndex
# frames produces; chaining .drop avoids in-place mutation.
df = (
    pd.concat([train_df, test_df], axis=0, ignore_index=True)
    .drop(columns=["Country", "Year"])
)

In [334]:
def draw_plot(b):
    """Button callback: plot life expectancy against the selected feature.

    Three layers are drawn, all in the model's standardized coordinates:

    * every train+test row from `df` (blue points),
    * the model trend along the selected feature, ``w_j * x + bias``, i.e. with every
      other standardized feature held at 0 (its training mean),
    * the current widget input (yellow point) at its full-model prediction.

    Every layer is standardized with the *training* statistics (`features`/`target`),
    the same ones on_predict_click uses, so the points and the trend line share one
    coordinate system. `b` is the clicked Button instance and is unused.
    """
    feat_mean, feat_std = features.mean(dim=0), features.std(dim=0)
    tgt_mean, tgt_std = target.mean(), target.std()

    # Data points: standardized with training statistics (not train+test statistics)
    # so they line up with the trend line and the input point.
    features_tensor = torch.tensor(df.drop(columns="Life expectancy ").values, dtype=torch.float32)
    features_scale = (features_tensor - feat_mean) / feat_std
    target_tensor = torch.tensor(df["Life expectancy "].values, dtype=torch.float32)
    target_scale = (target_tensor - tgt_mean) / tgt_std

    # Current widget input and its prediction (standardized target scale).
    input_tensor = torch.tensor(get_input_data().values, dtype=torch.float32)
    input_scale = (input_tensor - feat_mean) / feat_std
    predict_scale = predict_life_expectancy(input_scale)

    out_plot.layout.border = '0px'
    with out_plot:
        clear_output()

        fig, ax = plt.subplots(figsize=(10, 7))

        # Column position of the selected feature; assumes df's feature columns are in
        # the same order as the training features / model weights.
        feature_index = df.drop(columns="Life expectancy ").columns.get_loc(df_cols_dict[features_input.value])
        x_data = features_scale[:, feature_index]

        # Data points
        ax.scatter(
            x=x_data,
            y=target_scale,
            s=20, alpha=0.6, facecolors="royalblue",
            label="Data Points"
        )

        # Model trend: slope for this feature plus the bias.
        x_line = torch.linspace(x_data.min(), x_data.max(), 200)
        y_line = app_model.weight[:, feature_index].item() * x_line + app_model.bias.item()
        ax.plot(x_line, y_line, linewidth=3, color="crimson", label="Model Trend")

        # Input point
        ax.scatter(
            x=input_scale[:, feature_index],
            y=predict_scale,
            s=150, alpha=1, facecolors="yellow", edgecolors='black', linewidth=2,
            label="Input Point"
        )

        # Labels: both axes are in standardized units, not raw units / years.
        ax.set_xlabel(f"{features_input.value} (standardized)")
        ax.set_ylabel("Life Expectancy (standardized)")
        ax.set_title(f"Life Expectancy vs {features_input.value}")
        ax.legend()

        ax.grid(alpha=0.2)
        plt.show()

In [335]:
# Sanity checks for draw_plot, which looks up the selected feature's position in `df`
# and uses it to index both the training-scaled input and the model weights.
plot_columns = df.drop(columns="Life expectancy ").columns

# 1) Every dropdown label must map to a real column, otherwise selecting it raises KeyError.
unmapped = [label for label in features_list if df_cols_dict.get(label) not in plot_columns]
assert not unmapped, f"Dropdown labels without a matching column: {unmapped}"

# 2) df's feature columns must be in the same order as the training features, otherwise
#    the same index would point at different features in the data and in the model.
assert list(plot_columns) == list(train_df.drop(columns=["Country", "Year", "Life expectancy "]).columns), \
    "Plot feature columns are not in the training feature order"

# Column position of the currently selected feature (as computed in draw_plot).
feature_index = plot_columns.get_loc(df_cols_dict[features_input.value])
print(feature_index)

0


# **3 &nbsp;&nbsp;&nbsp; Display UI**

In [336]:
# Connect the buttons to their callbacks, then render the input panel (left) and the
# plot panel (right) side by side.
predict_button.on_click(on_predict_click)
draw_button.on_click(draw_plot)
display(widgets.HBox([input_box, out_box]))

HBox(children=(VBox(children=(Dropdown(description='Status:', index=1, layout=Layout(width='400px'), options=(…