In [26]:
pip install pyreadstat requests scikit-learn numpy

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [1]:
import numpy as np
import pandas as pd
import pyreadstat
import requests
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.linear_model import Lasso, LinearRegression
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score#, confusion_matrix, ConfusionMatrixDisplay


In [2]:
# Dataframe for the demographics data

url_demo ='https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/DEMO_L.xpt'

file_name_demo = "DEMO_L.xpt"

response = requests.get(url_demo)
if response.status_code == 200:
    with open(file_name_demo, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

demo, meta_demo = pyreadstat.read_xport(file_name_demo, encoding='latin1')
print(demo.info())
print(demo.head())

File downloaded successfully.
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11933 entries, 0 to 11932
Data columns (total 27 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   SEQN      11933 non-null  float64
 1   SDDSRVYR  11933 non-null  float64
 2   RIDSTATR  11933 non-null  float64
 3   RIAGENDR  11933 non-null  float64
 4   RIDAGEYR  11933 non-null  float64
 5   RIDAGEMN  377 non-null    float64
 6   RIDRETH1  11933 non-null  float64
 7   RIDRETH3  11933 non-null  float64
 8   RIDEXMON  8860 non-null   float64
 9   RIDEXAGM  2787 non-null   float64
 10  DMQMILIZ  8301 non-null   float64
 11  DMDBORN4  11914 non-null  float64
 12  DMDYRUSR  1875 non-null   float64
 13  DMDEDUC2  7794 non-null   float64
 14  DMDMARTZ  7792 non-null   float64
 15  RIDEXPRG  1503 non-null   float64
 16  DMDHHSIZ  11933 non-null  float64
 17  DMDHRGND  4115 non-null   float64
 18  DMDHRAGZ  4124 non-null   float64
 19  DMDHREDZ  3746 non-null   float64
 20

In [3]:
demo_cols = {
'SEQN':'id',
'SDDSRVYR' : 'Data release cycle',
'RIDSTATR' : 'Interview/Examination status',
'RIAGENDR' : 'Gender',
'RIDAGEYR' : 'Age in years at screening',
'RIDAGEMN' : 'Age in months at screening - 0 to 24 mos',
'RIDRETH1' : 'Race/Hispanic origin',
'RIDRETH3' : 'Race/Hispanic origin w/ NH Asian',
'RIDEXMON' : 'Six-month time period',
'RIDEXAGM' : 'Age in months at exam - 0 to 19 years',
'DMQMILIZ' : 'Served active duty in US Armed Forces',
'DMDBORN4' : 'Country of birth',
'DMDYRUSR' : 'Length of time in US',
'DMDEDUC2' : 'Education level - Adults 20+',
'DMDMARTZ' : 'Marital status',
'RIDEXPRG' : 'Pregnancy status at exam',
'DMDHHSIZ' : 'Total number of people in the Household',
'DMDHRGND' : 'HH ref persons gender',
'DMDHRAGZ' : 'HH ref persons age in years',
'DMDHREDZ': 'HH ref persons education level',
'DMDHRMAZ' : 'HH ref persons marital status',
'DMDHSEDZ' : 'HH ref persons spouses education level',
'WTINT2YR' : 'Full sample 2-year interview weight',
'WTMEC2YR' : 'Full sample 2-year MEC exam weight',
'SDMVSTRA' : 'Masked variance pseudo-stratum',
'SDMVPSU' : 'Masked variance pseudo-PSU',
'INDFMPIR' : 'Ratio of family income to poverty'
}

demo = demo.rename(columns=demo_cols)
demo.head()

Unnamed: 0,id,Data release cycle,Interview/Examination status,Gender,Age in years at screening,Age in months at screening - 0 to 24 mos,Race/Hispanic origin,Race/Hispanic origin w/ NH Asian,Six-month time period,Age in months at exam - 0 to 19 years,...,HH ref persons gender,HH ref persons age in years,HH ref persons education level,HH ref persons marital status,HH ref persons spouses education level,Full sample 2-year interview weight,Full sample 2-year MEC exam weight,Masked variance pseudo-stratum,Masked variance pseudo-PSU,Ratio of family income to poverty
0,130378.0,12.0,2.0,1.0,43.0,,5.0,6.0,2.0,,...,,,,,,50055.450807,54374.463898,173.0,2.0,5.0
1,130379.0,12.0,2.0,1.0,66.0,,3.0,3.0,2.0,,...,,,,,,29087.450605,34084.721548,173.0,2.0,5.0
2,130380.0,12.0,2.0,2.0,44.0,,2.0,2.0,1.0,,...,,,,,,80062.674301,81196.277992,174.0,1.0,1.41
3,130381.0,12.0,2.0,2.0,5.0,,5.0,7.0,1.0,71.0,...,2.0,2.0,2.0,3.0,,38807.268902,55698.607106,182.0,2.0,1.53
4,130382.0,12.0,2.0,1.0,2.0,,3.0,3.0,2.0,34.0,...,2.0,2.0,3.0,1.0,2.0,30607.519774,36434.146346,182.0,2.0,3.6


In [4]:
demo['id'].nunique()

11933

In [5]:
# Blood pressure

url_bp = 'https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/BPXO_L.xpt'

file_name_bp = "BPXO_L.xpt"

response = requests.get(url_bp)
if response.status_code == 200:
    with open(file_name_bp, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

blood_pressure, meta_bp = pyreadstat.read_xport(file_name_bp)
blood_pressure.info()

File downloaded successfully.
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7801 entries, 0 to 7800
Data columns (total 12 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   SEQN      7801 non-null   float64
 1   BPAOARM   7801 non-null   object 
 2   BPAOCSZ   7611 non-null   float64
 3   BPXOSY1   7517 non-null   float64
 4   BPXODI1   7517 non-null   float64
 5   BPXOSY2   7505 non-null   float64
 6   BPXODI2   7505 non-null   float64
 7   BPXOSY3   7480 non-null   float64
 8   BPXODI3   7480 non-null   float64
 9   BPXOPLS1  7517 non-null   float64
 10  BPXOPLS2  7505 non-null   float64
 11  BPXOPLS3  7480 non-null   float64
dtypes: float64(11), object(1)
memory usage: 731.5+ KB


In [6]:
bp_cols = {
    'SEQN' : 'id',
    'BPAOARM' : 'Arm selected - oscillometric',
    'BPAOCSZ' : 'Coded cuff size - oscillometric',
    'BPXOSY1' : 'Systolic - 1st oscillometric reading',
    'BPXODI1' : 'Diastolic - 1st oscillometric reading',
    'BPXOSY2' : 'Systolic - 2nd oscillometric reading',
    'BPXODI2' : 'Diastolic - 2nd oscillometric reading',
    'BPXOSY3' : 'Systolic - 3rd oscillometric reading',
    'BPXODI3' : 'Diastolic - 3rd oscillometric reading',
    'BPXOPLS1' : 'Pulse - 1st oscillometric reading',
    'BPXOPLS2' : 'Pulse - 2nd oscillometric reading',
    'BPXOPLS3' : 'Pulse - 3rd oscillometric reading',
}

blood_pressure = blood_pressure.rename(columns=bp_cols)

blood_pressure.head()

Unnamed: 0,id,Arm selected - oscillometric,Coded cuff size - oscillometric,Systolic - 1st oscillometric reading,Diastolic - 1st oscillometric reading,Systolic - 2nd oscillometric reading,Diastolic - 2nd oscillometric reading,Systolic - 3rd oscillometric reading,Diastolic - 3rd oscillometric reading,Pulse - 1st oscillometric reading,Pulse - 2nd oscillometric reading,Pulse - 3rd oscillometric reading
0,130378.0,R,4.0,135.0,98.0,131.0,96.0,132.0,94.0,82.0,79.0,82.0
1,130379.0,R,4.0,121.0,84.0,117.0,76.0,113.0,76.0,72.0,71.0,73.0
2,130380.0,R,4.0,111.0,79.0,112.0,80.0,104.0,76.0,84.0,83.0,77.0
3,130386.0,R,4.0,110.0,72.0,120.0,74.0,115.0,75.0,59.0,64.0,64.0
4,130387.0,R,4.0,143.0,76.0,136.0,74.0,145.0,78.0,80.0,80.0,77.0


In [7]:
blood_pressure['avg_pulse'] = blood_pressure[['Pulse - 1st oscillometric reading', 'Pulse - 2nd oscillometric reading', 'Pulse - 3rd oscillometric reading']].mean(axis=1)

In [8]:
blood_pressure['avg_systolic'] = blood_pressure[['Systolic - 1st oscillometric reading','Systolic - 2nd oscillometric reading','Systolic - 3rd oscillometric reading']].mean(axis=1)

In [9]:
blood_pressure['avg_diastolic'] = blood_pressure[['Diastolic - 1st oscillometric reading','Diastolic - 2nd oscillometric reading','Diastolic - 3rd oscillometric reading']].mean(axis=1)

In [10]:
blood_pressure = blood_pressure[['id','avg_systolic','avg_diastolic','avg_pulse']]

In [11]:
blood_pressure.describe()

Unnamed: 0,id,avg_systolic,avg_diastolic,avg_pulse
count,7801.0,7518.0,7518.0,7518.0
mean,136349.487117,119.094418,72.20728,73.041789
std,3449.490842,18.151729,11.471177,12.564442
min,130378.0,70.0,34.0,34.0
25%,133335.0,106.333333,64.0,64.333333
50%,136382.0,116.333333,71.666667,72.166667
75%,139325.0,129.0,79.333333,81.0
max,142310.0,232.333333,139.0,151.0


In [12]:
# Total nutrient intake, Day 1

url_tn ='https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/DR1TOT_L.xpt'

file_name_tn = "DR1IFF_L.xpt"

response = requests.get(url_tn)
if response.status_code == 200:
    with open(file_name_tn, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

total_nutrients_1, meta_tn = pyreadstat.read_xport(file_name_tn)

File downloaded successfully.


In [13]:
tn_cols = {
    "SEQN": "id",
    "WTDRD1": "Dietary day one sample weight",
    "WTDR2D": "Dietary two-day sample weight",
    "DR1DRSTZ": "Dietary recall status",
    "DR1EXMER": "Interviewer ID code",
    "DRABF": "Breast-fed infant (either day)",
    "DRDINT": "Number of days of intake",
    "DR1DBIH": "# of days b/w intake and HH interview",
    "DR1DAY": "Intake day of the week",
    "DR1LANG": "Language respondent used mostly",
    "DR1MRESP": "Main respondent for this interview",
    "DR1HELP": "Helped in responding for this interview",
    "DBQ095Z": "Type of table salt used",
    "DBD100": "How often add salt to food at table",
    "DRQSPREP": "Salt used in preparation?",
    "DR1STY": "Salt used at table yesterday?",
    "DR1SKY": "Type of salt used yesterday",
    "DRQSDIET": "On special diet?",
    "DRQSDT1": "Weight loss/Low calorie diet",
    "DRQSDT2": "Low fat/Low cholesterol diet",
    "DRQSDT3": "Low salt/Low sodium diet",
    "DRQSDT4": "Sugar free/Low sugar diet",
    "DRQSDT5": "Low fiber diet",
    "DRQSDT6": "High fiber diet",
    "DRQSDT7": "Diabetic diet",
    "DRQSDT8": "Weight gain/Muscle building diet",
    "DRQSDT9": "Low carbohydrate diet",
    "DRQSDT10": "High protein diet",
    "DRQSDT11": "Gluten-free/Celiac diet",
    "DRQSDT12": "Renal/Kidney diet",
    "DRQSDT91": "Other special diet",
    "DR1TNUMF": "Number of foods/beverages reported",
    "DR1TKCAL": "Energy (kcal)",
    "DR1TPROT": "Protein (gm)",
    "DR1TCARB": "Carbohydrate (gm)",
    "DR1TSUGR": "Total sugars (gm)",
    "DR1TFIBE": "Dietary fiber (gm)",
    "DR1TTFAT": "Total fat (gm)",
    "DR1TSFAT": "Total saturated fatty acids (gm)",
    "DR1TMFAT": "Total monounsaturated fatty acids (gm)",
    "DR1TPFAT": "Total polyunsaturated fatty acids (gm)",
    "DR1TCHOL": "Cholesterol (mg)",
    "DR1TATOC": "Vitamin E as alpha-tocopherol (mg)",
    "DR1TATOA": "Added alpha-tocopherol (Vitamin E) (mg)",
    "DR1TRET": "Retinol (mcg)",
    "DR1TVARA": "Vitamin A, RAE (mcg)",
    "DR1TACAR": "Alpha-carotene (mcg)",
    "DR1TBCAR": "Beta-carotene (mcg)",
    "DR1TCRYP": "Beta-cryptoxanthin (mcg)",
    "DR1TLYCO": "Lycopene (mcg)",
    "DR1TLZ": "Lutein + zeaxanthin (mcg)",
    "DR1TVB1": "Thiamin (Vitamin B1) (mg)",
    "DR1TVB2": "Riboflavin (Vitamin B2) (mg)",
    "DR1TNIAC": "Niacin (mg)",
    "DR1TVB6": "Vitamin B6 (mg)",
    "DR1TFOLA": "Total folate (mcg)",
    "DR1TFA": "Folic acid (mcg)",
    "DR1TFF": "Food folate (mcg)",
    "DR1TFDFE": "Folate, DFE (mcg)",
    "DR1TCHL": "Total choline (mg)",
    "DR1TVB12": "Vitamin B12 (mcg)",
    "DR1TB12A": "Added vitamin B12 (mcg)",
    "DR1TVC": "Vitamin C (mg)",
    "DR1TVD": "Vitamin D (D2 + D3) (mcg)",
    "DR1TVK": "Vitamin K (mcg)",
    "DR1TCALC": "Calcium (mg)",
    "DR1TPHOS": "Phosphorus (mg)",
    "DR1TMAGN": "Magnesium (mg)",
    "DR1TIRON": "Iron (mg)",
    "DR1TZINC": "Zinc (mg)",
    "DR1TCOPP": "Copper (mg)",
    "DR1TSODI": "Sodium (mg)",
    "DR1TPOTA": "Potassium (mg)",
    "DR1TSELE": "Selenium (mcg)",
    "DR1TCAFF": "Caffeine (mg)",
    "DR1TTHEO": "Theobromine (mg)",
    "DR1TALCO": "Alcohol (gm)",
    "DR1TMOIS": "Moisture (gm)",
    "DR1TS040": "SFA 4:0 (Butanoic) (gm)",
    "DR1TS060": "SFA 6:0 (Hexanoic) (gm)",
    "DR1TS080": "SFA 8:0 (Octanoic) (gm)",
    "DR1TS100": "SFA 10:0 (Decanoic) (gm)",
    "DR1TS120": "SFA 12:0 (Dodecanoic) (gm)",
    "DR1TS140": "SFA 14:0 (Tetradecanoic) (gm)",
    "DR1TS160": "SFA 16:0 (Hexadecanoic) (gm)",
    "DR1TS180": "SFA 18:0 (Octadecanoic) (gm)",
    "DR1TM161": "MFA 16:1 (Hexadecenoic) (gm)",
    "DR1TM181": "MFA 18:1 (Octadecenoic) (gm)",
    "DR1TM201": "MFA 20:1 (Eicosenoic) (gm)",
    "DR1TM221": "MFA 22:1 (Docosenoic) (gm)",
    "DR1TP182": "PFA 18:2 (Octadecadienoic) (gm)",
    "DR1TP183": "PFA 18:3 (Octadecatrienoic) (gm)",
    "DR1TP184": "PFA 18:4 (Octadecatetraenoic) (gm)",
    "DR1TP204": "PFA 20:4 (Eicosatetraenoic) (gm)",
    "DR1TP205": "PFA 20:5 (Eicosapentaenoic) (gm)",
    "DR1TP225": "PFA 22:5 (Docosapentaenoic) (gm)",
    "DR1TP226": "PFA 22:6 (Docosahexaenoic) (gm)",
    "DR1_300": "Compare food consumed yesterday to usual",
    "DR1_320Z": "Total plain water drank yesterday (gm)",
    "DR1_330Z": "Total tap water drank yesterday (gm)",
    "DR1BWATZ": "Total bottled water drank yesterday (gm)",
    "DR1TWSZ": "Tap water source",
    "DRD340": "Shellfish eaten during past 30 days",
    "DRD350A": "Clams eaten during past 30 days",
    "DRD350AQ": "# of times clams eaten in past 30 days",
    "DRD350B": "Crabs eaten during past 30 days",
    "DRD350BQ": "# of times crabs eaten in past 30 days",
    "DRD350C": "Crayfish eaten during past 30 days",
    "DRD350CQ": "# of times crayfish eaten past 30 days",
    "DRD350D": "Lobsters eaten during past 30 days",
    "DRD350DQ": "# of times lobsters eaten past 30 days",
    "DRD350E": "Mussels eaten during past 30 days",
    "DRD350EQ": "# of times mussels eaten in past 30 days",
    "DRD350F": "Oysters eaten during past 30 days",
    "DRD350FQ": "# of times oysters eaten in past 30 days",
    "DRD350G": "Scallops eaten during past 30 days",
    "DRD350GQ": "# of times scallops eaten past 30 days",
    "DRD350H": "Shrimp eaten during past 30 days",
    "DRD350HQ": "# of times shrimp eaten in past 30 days",
    "DRD350I": "Other shellfish eaten past 30 days",
    "DRD350IQ": "# of times other shellfish eaten",
    "DRD350J": "Other unknown shellfish eaten past 30 days",
    "DRD350JQ": "# of times other unknown shellfish eaten",
    "DRD350K": "Refused on shellfish eaten past 30 days",
    "DRD360": "Fish eaten during past 30 days",
    "DRD370A": "Breaded fish products eaten past 30 days",
    "DRD370AQ": "# of times breaded fish products eaten",
    "DRD370B": "Tuna eaten during past 30 days",
    "DRD370BQ": "# of times tuna eaten in past 30 days",
    "DRD370C": "Bass eaten during past 30 days",
    "DRD370CQ": "# of times bass eaten in past 30 days",
    "DRD370D": "Catfish eaten during past 30 days",
    "DRD370DQ": "# of times catfish eaten in past 30 days",
    "DRD370E": "Cod eaten during past 30 days",
    "DRD370EQ": "# of times cod eaten in past 30 days",
    "DRD370F": "Flatfish eaten during past 30 days",
    "DRD370FQ": "# of times flatfish eaten past 30 days",
    "DRD370G": "Haddock eaten during past 30 days",
    "DRD370GQ": "# of times haddock eaten in past 30 days",
    "DRD370H": "Mackerel eaten during past 30 days",
    "DRD370HQ": "# of times mackerel eaten past 30 days",
    "DRD370I": "Perch eaten during past 30 days",
    "DRD370IQ": "# of times perch eaten in past 30 days",
    "DRD370J": "Pike eaten during past 30 days",
    "DRD370JQ": "# of times pike eaten in past 30 days",
    "DRD370K": "Pollock eaten during past 30 days",
    "DRD370KQ": "# of times pollock eaten in past 30 days",
    "DRD370L": "Porgy eaten during past 30 days",
    "DRD370LQ": "# of times porgy eaten in past 30 days",
    "DRD370M": "Salmon eaten during past 30 days",
    "DRD370MQ": "# of times salmon eaten in past 30 days",
    "DRD370N": "Sardines eaten during past 30 days",
    "DRD370NQ": "# of times sardines eaten past 30 days",
    "DRD370O": "Sea bass eaten during past 30 days",
    "DRD370OQ": "# of times sea bass eaten past 30 days",
    "DRD370P": "Shark eaten during past 30 days",
    "DRD370PQ": "# of times shark eaten in past 30 days",
    "DRD370Q": "Swordfish eaten during past 30 days",
    "DRD370QQ": "# of times swordfish eaten past 30 days",
    "DRD370R": "Trout eaten during past 30 days",
    "DRD370RQ": "# of times trout eaten in past 30 days",
    "DRD370S": "Walleye eaten during past 30 days",
    "DRD370SQ": "# of times walleye eaten in past 30 days",
    "DRD370T": "Other fish eaten during past 30 days",
    "DRD370TQ": "# of times other fish eaten past 30 days",
    "DRD370U": "Other unknown fish eaten in past 30 days",
    "DRD370UQ": "# of times other unknown fish eaten",
    "DRD370V": "Refused on fish eaten past 30 days"
}

total_nutrients_1 = total_nutrients_1.rename(columns=tn_cols)

total_nutrients_1.head()

Unnamed: 0,id,Dietary day one sample weight,Dietary two-day sample weight,Dietary recall status,Interviewer ID code,Breast-fed infant (either day),Number of days of intake,# of days b/w intake and HH interview,Intake day of the week,Language respondent used mostly,...,# of times swordfish eaten past 30 days,Trout eaten during past 30 days,# of times trout eaten in past 30 days,Walleye eaten during past 30 days,# of times walleye eaten in past 30 days,Other fish eaten during past 30 days,# of times other fish eaten past 30 days,Other unknown fish eaten in past 30 days,# of times other unknown fish eaten,Refused on fish eaten past 30 days
0,130378.0,61366.555827,70554.222162,1.0,73.0,2.0,2.0,40.0,4.0,1.0,...,,2.0,,2.0,,2.0,,2.0,,2.0
1,130379.0,34638.05648,36505.468348,1.0,73.0,2.0,2.0,19.0,4.0,1.0,...,1.0,2.0,,2.0,,2.0,,2.0,,2.0
2,130380.0,84728.26156,103979.190677,1.0,73.0,2.0,2.0,16.0,4.0,1.0,...,,2.0,,2.0,,1.0,4.0,2.0,,2.0
3,130381.0,61737.133446,75009.220819,1.0,91.0,2.0,2.0,23.0,5.0,1.0,...,,2.0,,2.0,,2.0,,2.0,,2.0
4,130382.0,75846.746917,172361.851828,1.0,73.0,2.0,2.0,27.0,6.0,1.0,...,,2.0,,2.0,,2.0,,2.0,,2.0


In [14]:
total_nutrients_1.id.nunique()

8860

In [15]:
# Medical conditions questionaire data

url_mc ='https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/MCQ_L.xpt'

file_name_mc = "MCQ_L.xpt"

response = requests.get(url_mc)
if response.status_code == 200:
    with open(file_name_mc, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

med_conditions, meta_mc = pyreadstat.read_xport(file_name_mc)

File downloaded successfully.


In [16]:
med_conditions_cols = {
    "SEQN": "id",
    "MCQ010": "Ever been told you have asthma",
    "MCQ035": "Still have asthma",
    "MCQ040": "Had asthma attack in past year",
    "MCQ050": "Emergency care visit for asthma/past yr",
    "AGQ030": "Did SP have episode of hay fever/past yr",
    "MCQ053": "Taking treatment for anemia/past 3 mos",
    "MCQ145": "CHECK ITEM",
    "MCQ149": "Menstrual periods started yet?",
    "MCQ157": "CHECK ITEM",
    "MCQ160A": "Doctor ever said you had arthritis",
    "MCQ195": "Which type of arthritis was it?",
    "MCQ160B": "Ever told had congestive heart failure",
    "MCQ160C": "Ever told you had coronary heart disease",
    "MCQ160D": "Ever told you had angina/angina pectoris",
    "MCQ160E": "Ever told you had heart attack",
    "MCQ160F": "Ever told you had a stroke",
    "MCQ160M": "Ever told you had thyroid problem",
    "MCQ170M": "Do you still have thyroid problem",
    "MCQ160P": "Ever told you had COPD, emphysema, ChB",
    "MCQ160L": "Ever told you had any liver condition",
    "MCQ170L": "Do you still have a liver condition",
    "MCQ500": "Ever told you had any liver condition",
    "MCQ510A": "Liver condition: Fatty liver",
    "MCQ510B": "Liver condition: Liver fibrosis",
    "MCQ510C": "Liver condition: Liver cirrhosis",
    "MCQ510D": "Liver condition: Viral hepatitis",
    "MCQ510E": "Liver condition: Autoimmune hepatitis",
    "MCQ510F": "Liver condition: Other liver disease",
    "MCQ515": "CHECK ITEM",
    "MCQ550": "Has DR ever said you have gallstones",
    "MCQ560": "Ever had gallbladder surgery?",
    "MCQ220": "Ever told you had cancer or malignancy",
    "MCQ230A": "1st cancer - what kind was it?",
    "MCQ230B": "2nd cancer - what kind was it?",
    "MCQ230C": "3rd cancer - what kind was it?",
    "MCQ230D": "More than 3 kinds of cancer",
    "OSQ230": "Any metal objects inside your body?"
}

med_conditions = med_conditions.rename(columns=med_conditions_cols)
med_conditions.head()

Unnamed: 0,id,Ever been told you have asthma,Still have asthma,Had asthma attack in past year,Emergency care visit for asthma/past yr,Did SP have episode of hay fever/past yr,Taking treatment for anemia/past 3 mos,Menstrual periods started yet?,Doctor ever said you had arthritis,Which type of arthritis was it?,...,Liver condition: Autoimmune hepatitis,Liver condition: Other liver disease,Has DR ever said you have gallstones,Ever had gallbladder surgery?,Ever told you had cancer or malignancy,1st cancer - what kind was it?,2nd cancer - what kind was it?,3rd cancer - what kind was it?,More than 3 kinds of cancer,Any metal objects inside your body?
0,130378.0,2.0,,,,2.0,2.0,,1.0,2.0,...,,,2.0,2.0,2.0,,,,,2.0
1,130379.0,2.0,,,,2.0,2.0,,2.0,,...,,,2.0,2.0,1.0,30.0,,,,1.0
2,130380.0,2.0,,,,2.0,2.0,,2.0,,...,,,2.0,2.0,2.0,,,,,1.0
3,130381.0,2.0,,,,1.0,2.0,,,,...,,,,,,,,,,
4,130382.0,2.0,,,,2.0,2.0,,,,...,,,,,,,,,,


In [17]:
# Income

url_inc ='https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/INQ_L.xpt'

file_name_inc = "INQ_L.xpt"

response = requests.get(url_inc)
if response.status_code == 200:
    with open(file_name_inc, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

income, meta_inc = pyreadstat.read_xport(file_name_inc)

inc_cols = {
    "SEQN": "id",
    "INDFMMPI": "Family monthly poverty level index",
    "INDFMMPC": "Family monthly poverty level category",
    "INQ300": "Family has savings more than $20,000",
    "IND310": "Total savings/cash assets for the family"
}

income = income.rename(columns=inc_cols)

income.head()

File downloaded successfully.


Unnamed: 0,id,Family monthly poverty level index,Family monthly poverty level category,"Family has savings more than $20,000",Total savings/cash assets for the family
0,130378.0,5.0,3.0,1.0,
1,130379.0,5.0,3.0,1.0,
2,130380.0,1.4,2.0,2.0,1.0
3,130381.0,0.33,1.0,2.0,1.0
4,130382.0,4.32,3.0,1.0,


In [18]:
income.isna()

Unnamed: 0,id,Family monthly poverty level index,Family monthly poverty level category,"Family has savings more than $20,000",Total savings/cash assets for the family
0,False,False,False,False,True
1,False,False,False,False,True
2,False,False,False,False,False
3,False,False,False,False,False
4,False,False,False,False,True
...,...,...,...,...,...
11928,False,False,False,False,False
11929,False,True,False,False,True
11930,False,False,False,False,True
11931,False,False,False,False,False


In [19]:
income = income.drop(['Total savings/cash assets for the family'], axis=1)

In [20]:
# Diabetes
"""
SEQN - Respondent sequence number
DIQ010 - Doctor told you have diabetes
DID040 - Age when first told you had diabetes
DIQ159 - CHECK ITEM
DIQ160 - Ever told you have prediabetes
DIQ180 - Had blood tested past three years
DIQ050 - Taking insulin now
DID060 - How long taking insulin
DIQ060U - Unit of measure (month/year)
DIQ065 - CHECK ITEM
DIQ070 - Take diabetic pills to lower blood sugar

"""

url_diabetes = 'https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/DIQ_L.xpt'

file_name_diabetes = 'DIQ_L.xpt'

response = requests.get(url_diabetes)
if response.status_code == 200:
    with open(file_name_diabetes, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

diabetes, meta_inc = pyreadstat.read_xport(file_name_diabetes)


File downloaded successfully.


In [21]:
diabetes_cols = {
    'SEQN':'id',
    'DIQ010':'pos_diabetes'
}



In [22]:
diabetes = diabetes.rename(columns=diabetes_cols)

In [23]:
diabetes = diabetes[['id','pos_diabetes']]

In [24]:
diabetes.head()

Unnamed: 0,id,pos_diabetes
0,130378.0,2.0
1,130379.0,2.0
2,130380.0,1.0
3,130381.0,2.0
4,130382.0,2.0


In [25]:
# Body measurements

url_body = 'https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/BMX_L.xpt'

file_name_body = 'BMX_L.xpt'

response = requests.get(url_body)
if response.status_code == 200:
    with open(file_name_body, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

body, meta_body = pyreadstat.read_xport(file_name_body)

body_cols = {
    "SEQN": "id",
    "BMDSTATS": "Body Measures Component Status Code",
    "BMXWT": "Weight (kg)",
    "BMIWT": "Weight Comment",
    "BMXRECUM": "Recumbent Length (cm)",
    "BMIRECUM": "Recumbent Length Comment",
    "BMXHEAD": "Head Circumference (cm)",
    "BMIHEAD": "Head Circumference Comment",
    "BMXHT": "Standing Height (cm)",
    "BMIHT": "Standing Height Comment",
    "BMXBMI": "Body Mass Index (kg/m²)",
    "BMDBMIC": "BMI Category - Children/Youth",
    "BMXLEG": "Upper Leg Length (cm)",
    "BMILEG": "Upper Leg Length Comment",
    "BMXARML": "Upper Arm Length (cm)",
    "BMIARML": "Upper Arm Length Comment",
    "BMXARMC": "Arm Circumference (cm)",
    "BMIARMC": "Arm Circumference Comment",
    "BMXWAIST": "Waist Circumference (cm)",
    "BMIWAIST": "Waist Circumference Comment",
    "BMXHIP": "Hip Circumference (cm)",
    "BMIHIP": "Hip Circumference Comment"
}

body = body.rename(columns=body_cols)

body.head()

File downloaded successfully.


Unnamed: 0,id,Body Measures Component Status Code,Weight (kg),Weight Comment,Recumbent Length (cm),Recumbent Length Comment,Head Circumference (cm),Head Circumference Comment,Standing Height (cm),Standing Height Comment,...,Upper Leg Length (cm),Upper Leg Length Comment,Upper Arm Length (cm),Upper Arm Length Comment,Arm Circumference (cm),Arm Circumference Comment,Waist Circumference (cm),Waist Circumference Comment,Hip Circumference (cm),Hip Circumference Comment
0,130378.0,1.0,86.9,,,,,,179.5,,...,42.8,,42.0,,35.7,,98.3,,102.9,
1,130379.0,1.0,101.8,,,,,,174.2,,...,38.5,,38.7,,33.7,,114.7,,112.4,
2,130380.0,1.0,69.4,,,,,,152.9,,...,38.5,,35.5,,36.3,,93.5,,98.0,
3,130381.0,1.0,34.3,,,,,,120.1,,...,,,25.4,,23.4,,70.4,,,
4,130382.0,3.0,13.6,,,1.0,,,,1.0,...,,,,1.0,,1.0,,1.0,,


In [33]:
# Insulin

url_insulin = 'https://wwwn.cdc.gov/Nchs/Data/Nhanes/Public/2021/DataFiles/INS_L.xpt'

file_name_insulin = 'INS_L.xpt'

response = requests.get(url_insulin)
if response.status_code == 200:
    with open(file_name_insulin, 'wb') as f:
        f.write(response.content)
    print("File downloaded successfully.")
else:
    print("Failed to download file.")

insulin, meta_insulin = pyreadstat.read_xport(file_name_insulin)

insulin_cols = {
    'SEQN':'id',
    'LBXIN': 'Insulin (uU/mL)',
    'LBDINSI':'Insulin (pmol/L)'
}

insulin = insulin.rename(columns=insulin_cols)

insulin = insulin[['id','Insulin (uU/mL)','Insulin (pmol/L)']]

insulin.head()



File downloaded successfully.


Unnamed: 0,id,Insulin (uU/mL),Insulin (pmol/L)
0,130378.0,15.53,93.18
1,130379.0,19.91,119.46
2,130380.0,16.33,97.98
3,130386.0,11.38,68.28
4,130394.0,7.2,43.2


In [34]:
total_data = pd.merge(blood_pressure, med_conditions, on='id', how='left')

total_data.shape

(7801, 38)

In [35]:
total_data = pd.merge(total_data, total_nutrients_1, on='id', how='left')
total_data.shape

(7801, 205)

In [36]:
total_data = pd.merge(total_data, demo, on='id', how='left')
total_data.shape

(7801, 231)

In [37]:
total_data = pd.merge(total_data, income, on='id', how='left')

total_data.shape

(7801, 234)

In [38]:
total_data = pd.merge(total_data, diabetes, on='id', how='left')

total_data.shape

(7801, 235)

In [39]:
total_data = pd.merge(total_data, body, on='id', how='left')

total_data.shape

(7801, 256)

In [40]:
total_data = pd.merge(total_data, insulin, on='id', how='left')

total_data.shape

(7801, 258)

In [41]:
all_cols = total_data.columns.tolist()

all_cols

['id',
 'avg_systolic',
 'avg_diastolic',
 'avg_pulse',
 'Ever been told you have asthma',
 'Still have asthma',
 'Had asthma attack in past year',
 'Emergency care visit for asthma/past yr',
 'Did SP have episode of hay fever/past yr',
 'Taking treatment for anemia/past 3 mos',
 'Menstrual periods started yet?',
 'Doctor ever said you had arthritis',
 'Which type of arthritis was it?',
 'Ever told had congestive heart failure',
 'Ever told you had coronary heart disease',
 'Ever told you had angina/angina pectoris',
 'Ever told you had heart attack',
 'Ever told you had a stroke',
 'Ever told you had thyroid problem',
 'Do you still have thyroid problem',
 'Ever told you had COPD, emphysema, ChB',
 'Ever told you had any liver condition',
 'Do you still have a liver condition',
 'Ever told you had any liver condition',
 'Liver condition: Fatty liver',
 'Liver condition: Liver fibrosis',
 'Liver condition: Liver cirrhosis',
 'Liver condition: Viral hepatitis',
 'Liver condition: Autoim

In [62]:
sel_cols = ['id',
#  'avg_systolic',
#  'avg_diastolic',
#  'avg_pulse',
 'Energy (kcal)',
 'Protein (gm)',
 'Carbohydrate (gm)',
 'Total sugars (gm)',
 'Dietary fiber (gm)',
 'Total fat (gm)',
 'Total saturated fatty acids (gm)',
 'Total monounsaturated fatty acids (gm)',
 'Total polyunsaturated fatty acids (gm)',
 'Cholesterol (mg)',
 'Vitamin E as alpha-tocopherol (mg)',
 'Added alpha-tocopherol (Vitamin E) (mg)',
 'Retinol (mcg)',
 'Vitamin A, RAE (mcg)',
 'Alpha-carotene (mcg)',
 'Beta-carotene (mcg)',
 'Beta-cryptoxanthin (mcg)',
 'Lycopene (mcg)',
 'Lutein + zeaxanthin (mcg)',
 'Thiamin (Vitamin B1) (mg)',
 'Riboflavin (Vitamin B2) (mg)',
 'Niacin (mg)',
 'Vitamin B6 (mg)',
 'Total folate (mcg)',
 'Folic acid (mcg)',
 'Food folate (mcg)',
 'Folate, DFE (mcg)',
 'Total choline (mg)',
 'Vitamin B12 (mcg)',
 'Added vitamin B12 (mcg)',
 'Vitamin C (mg)',
 'Vitamin D (D2 + D3) (mcg)',
 'Vitamin K (mcg)',
 'Calcium (mg)',
 'Phosphorus (mg)',
 'Magnesium (mg)',
 'Iron (mg)',
 'Zinc (mg)',
 'Copper (mg)',
 'Sodium (mg)',
 'Potassium (mg)',
 'Selenium (mcg)',
 'Caffeine (mg)',
 'Theobromine (mg)',
 'Alcohol (gm)',
 'Moisture (gm)',
 'Total plain water drank yesterday (gm)',
#  'Gender',
#  'Age in years at screening',
#  'Education level - Adults 20+',
#  'Ratio of family income to poverty',
#  'Family monthly poverty level index',
#  'Family has savings more than $20,000',
 'pos_diabetes',
 'Insulin (uU/mL)',
 'Insulin (pmol/L)']
#  'Body Mass Index (kg/m²)']

len(sel_cols)

51

In [43]:
# sel_cols = ['id',
#  'avg_systolic',
#  'avg_diastolic',
#  'avg_pulse',
#  'Ever told you had cancer or malignancy',
#  'Number of foods/beverages reported',
#  'Energy (kcal)',
#  'Protein (gm)',
#  'Carbohydrate (gm)',
#  'Total sugars (gm)',
#  'Dietary fiber (gm)',
#  'Total fat (gm)',
#  'Total saturated fatty acids (gm)',
#  'Total monounsaturated fatty acids (gm)',
#  'Total polyunsaturated fatty acids (gm)',
#  'Cholesterol (mg)',
#  'Vitamin E as alpha-tocopherol (mg)',
#  'Added alpha-tocopherol (Vitamin E) (mg)',
#  'Retinol (mcg)',
#  'Vitamin A, RAE (mcg)',
#  'Alpha-carotene (mcg)',
#  'Beta-carotene (mcg)',
#  'Beta-cryptoxanthin (mcg)',
#  'Lycopene (mcg)',
#  'Lutein + zeaxanthin (mcg)',
#  'Thiamin (Vitamin B1) (mg)',
#  'Riboflavin (Vitamin B2) (mg)',
#  'Niacin (mg)',
#  'Vitamin B6 (mg)',
#  'Total folate (mcg)',
#  'Folic acid (mcg)',
#  'Food folate (mcg)',
#  'Folate, DFE (mcg)',
#  'Total choline (mg)',
#  'Vitamin B12 (mcg)',
#  'Added vitamin B12 (mcg)',
#  'Vitamin C (mg)',
#  'Vitamin D (D2 + D3) (mcg)',
#  'Vitamin K (mcg)',
#  'Calcium (mg)',
#  'Phosphorus (mg)',
#  'Magnesium (mg)',
#  'Iron (mg)',
#  'Zinc (mg)',
#  'Copper (mg)',
#  'Sodium (mg)',
#  'Potassium (mg)',
#  'Selenium (mcg)',
#  'Caffeine (mg)',
#  'Theobromine (mg)',
#  'Alcohol (gm)',
#  'Moisture (gm)',
#  'Compare food consumed yesterday to usual',
#  'Total plain water drank yesterday (gm)',
#  'Gender',
#  'Age in years at screening',
#  'Education level - Adults 20+',
#  'Marital status', 
#  'Ratio of family income to poverty',
#  'Family monthly poverty level index',
#  'Family monthly poverty level category',
#  'Family has savings more than $20,000',
#  'pos_diabetes',
#  'Body Mass Index (kg/m²)']

In [63]:
sel_data = total_data[sel_cols]

sel_data.head()

Unnamed: 0,id,Energy (kcal),Protein (gm),Carbohydrate (gm),Total sugars (gm),Dietary fiber (gm),Total fat (gm),Total saturated fatty acids (gm),Total monounsaturated fatty acids (gm),Total polyunsaturated fatty acids (gm),...,Potassium (mg),Selenium (mcg),Caffeine (mg),Theobromine (mg),Alcohol (gm),Moisture (gm),Total plain water drank yesterday (gm),pos_diabetes,Insulin (uU/mL),Insulin (pmol/L)
0,130378.0,1740.0,80.46,169.66,43.71,10.1,55.07,21.0,17.04,11.401,...,1917.0,93.5,242.0,38.0,34.2,3306.98,1080.0,2.0,15.53,93.18
1,130379.0,2741.0,86.45,314.86,113.58,29.6,67.18,14.544,19.386,25.622,...,5470.0,89.5,10.0,41.0,80.3,4115.06,0.0,2.0,19.91,119.46
2,130380.0,1995.0,69.86,281.67,114.66,21.1,65.38,26.936,18.638,11.687,...,2116.0,98.4,51.0,0.0,0.0,4103.91,2535.0,1.0,16.33,97.98
3,130386.0,2422.0,132.41,207.53,37.89,25.1,116.54,33.935,46.022,28.569,...,2680.0,161.4,192.0,0.0,0.0,3113.61,1920.0,2.0,11.38,68.28
4,130387.0,3849.0,68.93,376.54,227.99,22.0,234.39,49.849,63.889,104.428,...,3800.0,63.1,100.0,0.0,0.0,2881.37,434.7,2.0,,


In [64]:
sel_data.shape

(7801, 51)

In [65]:
sel_data.isna().any()

id                                         False
Energy (kcal)                               True
Protein (gm)                                True
Carbohydrate (gm)                           True
Total sugars (gm)                           True
Dietary fiber (gm)                          True
Total fat (gm)                              True
Total saturated fatty acids (gm)            True
Total monounsaturated fatty acids (gm)      True
Total polyunsaturated fatty acids (gm)      True
Cholesterol (mg)                            True
Vitamin E as alpha-tocopherol (mg)          True
Added alpha-tocopherol (Vitamin E) (mg)     True
Retinol (mcg)                               True
Vitamin A, RAE (mcg)                        True
Alpha-carotene (mcg)                        True
Beta-carotene (mcg)                         True
Beta-cryptoxanthin (mcg)                    True
Lycopene (mcg)                              True
Lutein + zeaxanthin (mcg)                   True
Thiamin (Vitamin B1)

In [47]:
sel_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7801 entries, 0 to 7800
Data columns (total 49 columns):
 #   Column                                   Non-Null Count  Dtype  
---  ------                                   --------------  -----  
 0   id                                       7801 non-null   float64
 1   Energy (kcal)                            5971 non-null   float64
 2   Protein (gm)                             5971 non-null   float64
 3   Carbohydrate (gm)                        5971 non-null   float64
 4   Total sugars (gm)                        5971 non-null   float64
 5   Dietary fiber (gm)                       5971 non-null   float64
 6   Total fat (gm)                           5971 non-null   float64
 7   Total saturated fatty acids (gm)         5971 non-null   float64
 8   Total monounsaturated fatty acids (gm)   5971 non-null   float64
 9   Total polyunsaturated fatty acids (gm)   5971 non-null   float64
 10  Cholesterol (mg)                         5971 no

In [67]:
sel_data = sel_data.dropna()
sel_data.shape

(2773, 51)

In [68]:
sel_data.pos_diabetes.value_counts()

pos_diabetes
2.0    2336
1.0     350
3.0      87
Name: count, dtype: int64

In [69]:
y = sel_data['pos_diabetes'].astype(int)
X = sel_data.drop(['pos_diabetes'], axis=1)


In [70]:
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, train_size=0.8 )


In [71]:
model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)

y_pred = model.predict(X_test) 
accuracy_score(y_test, y_pred)

0.7603603603603604

In [72]:
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)

y_pred = model.predict(X_test)
accuracy_score(y_test, y_pred)

0.8486486486486486

In [54]:
hyperparams = {
    'n_estimators': [50, 100, 150],
    'criterion': ['gini', 'entropy'], 
    'max_depth': [5, 15, 45, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 3, 9]
}

# grid = GridSearchCV(model, hyperparams, scoring='accuracy', cv=5)

# grid.fit(X_train, y_train)

# print(f"Best hyperparameters: {grid.best_params_}")

In [73]:
opt_model = RandomForestClassifier(criterion='gini', max_depth=45, min_samples_leaf=1, min_samples_split=10, n_estimators=50, random_state=42)
opt_model.fit(X_train, y_train)

y_pred_opt = opt_model.predict(X_test)

accuracy_score(y_test, y_pred)

# for k, v in {'Accuracy':accuracy_score, 'Precision':precision_score, 'Recall':recall_score, 'F1':f1_score}.items():
#     print(f'{k}:{v(y_test, y_pred_opt)}')

0.8486486486486486

In [56]:
scaler = MinMaxScaler()

X_train_scaled = scaler.fit_transform(X_train)

X_test_scaled = scaler.transform(X_test)


In [57]:
lasso_model = Lasso(alpha=0.02, max_iter=7000)
lasso_model.fit(X_train_scaled, y_train)

In [58]:
from sklearn.linear_model import LassoCV

# Use cross-validation to find the best alpha
lasso = LassoCV(cv=5).fit(X_train_scaled, y_train)

print(f"Optimal alpha: {lasso.alpha_}")  # Check the best regularization parameter


  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descen

Optimal alpha: 9.32619853677089e-05


  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descent_gram(
  model = cd_fast.enet_coordinate_descen

In [59]:
y_pred = lasso_model.predict(X_test_scaled)

print(f"Mean squared error: {mean_squared_error(y_test, y_pred)}")
print(f"Coefficient of determination: {r2_score(y_test, y_pred)}")

Mean squared error: 0.15365837220228074
Coefficient of determination: -0.000647542315341143


In [60]:
lasso_model.coef_


array([ 0.,  0.,  0.,  0.,  0.,  0., -0., -0., -0.,  0., -0.,  0.,  0.,
        0.,  0.,  0.,  0., -0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0., -0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -0.,
        0., -0.,  0.,  0., -0.,  0.,  0.,  0.,  0.])

In [None]:
nonzero_indices = np.nonzero(lasso_model.coef_)

cols = []
for idx in nonzero_indices[0]:
    cols.append(X_train.columns.tolist()[idx])

print(cols)

NameError: name 'numpy' is not defined

In [None]:
model = LinearRegression()
model.fit(X_train_scaled, y_train)

y_pred = model.predict(X_test_scaled)

print(f"Mean squared error: {mean_squared_error(y_test, y_pred)}")
print(f"Coefficient of determination: {r2_score(y_test, y_pred)}")

Mean squared error: 0.15115248987272845
Coefficient of determination: 0.015671158439193644
