# <h1 style="font-family: montserrat; font-size: 50px; font-style: normal; letter-spacing: 3px; background-color: #f1faee; color: #1d3557; border-radius: 10px 10px; text-align: center"> **Model training** <br> <span style="font-family: montserrat; font-size: 35px"> for exploring microbiome data </span> </h1>

In [4]:
%reset
import pandas as pd

otu_table = pd.read_csv("/data/namlhs/omics-data-learners/data/metsim"
                        "/01_raw/clinical_data/formatted/OTUS.txt", 
                        delim_whitespace=True,
                        index_col=0)

tax_table = pd.read_csv("/data/namlhs/omics-data-learners/data/metsim"
                        "/01_raw/clinical_data/formatted/TAXTABLE.txt", 
                        delim_whitespace=True,
                        index_col=0)

df = pd.read_csv('/data/namlhs/omics-data-learners/data/metsim/01_raw/clinical_data/formatted/FINAL_MICROBIOME_DATASET.csv', 
                 index_col=0)

  otu_table = pd.read_csv("/data/namlhs/omics-data-learners/data/metsim"
  tax_table = pd.read_csv("/data/namlhs/omics-data-learners/data/metsim"


In [6]:
# Plain-text preview of the first five rows of the OTU table.
print(otu_table.head(n=5))

             ASV1    ASV2    ASV3    ASV4    ASV5   ASV6  ASV7  ASV8   ASV9  \
Sample_ID                                                                     
MET_0001   279118  288094    1175       0   15601  11118  1119   238   1346   
MET_0002   195256  103634  344333   29106   27040   4575   488   328      0   
MET_0003    12473   28944   61544  196210  226003  67630   355   379  44977   
MET_0005    54922   18382    6742    8003   39117  47958   299   304   9609   
MET_0006    72576   76530   54979    4992   27167   9310   232   166  33013   

           ASV10  ...  ASV23910  ASV23911  ASV23912  ASV23913  ASV23914  \
Sample_ID         ...                                                     
MET_0001     524  ...         0         0         0         0         0   
MET_0002   60996  ...         0         0         0         0         0   
MET_0003    8108  ...         0         0         0         0         0   
MET_0005   13679  ...         0         0         0         0         0

In [5]:
# Relative abundance: divide every row (sample) by its own total count so that
# each row sums to 1.
otu_rel_table = otu_table.div(otu_table.sum(axis=1), axis=0)

# Keep only the first N_ASVS columns. The previous slice `:51` selected 51
# columns, one more than the intended 50.
N_ASVS = 50
# `.copy()` gives an independent frame so the index rewrite below cannot touch
# otu_rel_table.
otu_fil = otu_rel_table.iloc[:, :N_ASVS].copy()

# Rewrite sample IDs such as "MET_0001" to "MET.0001" (literal replacement).
# NOTE(review): presumably this matches the SampleID format of the clinical
# table that is merged on the index later — confirm.
otu_fil.index = otu_fil.index.str.replace('_', '.', regex=False)

In [None]:
# Fraction of missing values in each column of the clinical table.
sparse = df.isna().sum().div(len(df))
display(sparse)

# Columns in which fewer than 20% of the values are missing.
sparse_filtered = sparse[sparse.lt(0.2)]

# Show the columns that pass the filter
print(sparse_filtered)

# Keep only the columns listed in sparse_filtered, then index rows by SampleID.
df_filtered = df[sparse_filtered.index].set_index('SampleID')
display(df_filtered)

In [None]:
# Columns carried forward from the filtered clinical table. 'dm', 'METSIM_ID'
# and 'Time_Point' are kept in match_df but excluded from the correlation
# screen and the model features below.
alt_df = df_filtered[['dm', 'METSIM_ID', 'Time_Point',
                      'Age', 'DMType', 'WHR',
                      'fmass', 'diastbp', 'systbp',
                      'BMI', 'Freq_veg', 'Freq_fruit',
                      'Freq_leanfish', 'Freq_fattyfish', 'Freq_shellfish',
                      'Freq_strongwine', 'Freq_blend',
                      'Freq_wine', 'Freq_alclt3', 'Freq_alclt6', 'Freq_alcge6', 'Freq_liqueur',
                      'Milk', 'Milk_quantity', 'Dairy_other',
                      'Spread_sat', 'Spread_no', 'Spread_marg',
                      'Cookfat_sat', 'Cookfat_no', 'Cookfat_marg', 'Cookfat_oils',
                      'Redmeat_gwk',
                      'Cheese_freq', 'Cheese_g', 'Cheese_gvko', 'Cheese_other',
                      'Cereal_24_serv_wholegrain', 'Cereal_24_serv_wheat',
                      'Cereal_24_serv_pastry']].copy()

# Attach the selected ASV relative abundances by sample ID. pd.merge defaults
# to an inner join, so only samples present in both tables are kept.
match_df = pd.merge(alt_df, otu_fil, left_index=True, right_index=True)
meta_df = match_df.drop(columns=['dm', 'METSIM_ID', 'Time_Point'])

# Pairwise Kendall rank correlation, flattened to a Series indexed by
# (feature_a, feature_b).
df_cor = meta_df.corr(method='kendall')
df_pairs = df_cor.unstack()

sorted_pairs = df_pairs.sort_values(kind='quicksort')

# Strongly correlated pairs (|tau| >= 0.5). Self-pairs are excluded by label
# rather than by testing `!= 1`: the value test would also discard two
# different features that are perfectly correlated, and it depends on the
# diagonal being exactly 1.0 in floating point.
first_feature = sorted_pairs.index.get_level_values(0)
second_feature = sorted_pairs.index.get_level_values(1)
is_self_pair = first_feature == second_feature
remove_pairs = sorted_pairs[(sorted_pairs.abs() >= 0.5) & ~is_self_pair]

display(remove_pairs)

# Check the NaN values
print(meta_df.isnull().sum())

# Drop 'Spread_marg', 'Milk_quantity' and 'Cheese_g'.
# NOTE(review): presumably chosen from the highly correlated pairs displayed
# above — confirm the selection criterion.
chosen_df = meta_df.drop(columns=['Spread_marg', 'Milk_quantity', 'Cheese_g'])
alt_data = chosen_df.values

In [None]:
from sklearn.impute import KNNImputer

# Fill each missing value from the 2 nearest rows (NaN-aware Euclidean
# distance), weighting neighbours by inverse distance.
# NOTE(review): the imputer is fitted on every row and on every column of
# chosen_df, which includes the DMType label and rows that later land in the
# test split — confirm this is intended.
imputer = KNNImputer(n_neighbors=2, weights="distance", metric='nan_euclidean')

array_imputed = imputer.fit_transform(alt_data)
# Rebuild a labelled frame. Passing the original index keeps the sample IDs
# attached to each row instead of replacing them with 0..n-1.
df_imputed = pd.DataFrame(array_imputed,
                          columns=chosen_df.columns,
                          index=chosen_df.index)

# Check the NaN values
print(df_imputed.isnull().sum())

# Collapse DMType to a binary label: every value above 0 becomes 1.
# A single .loc assignment replaces the chained `df['DMType'].loc[...] = 1`,
# which may write to a temporary copy (and never reaches the frame when pandas
# copy-on-write is enabled).
df_imputed.loc[df_imputed['DMType'] > 0, 'DMType'] = 1
display(df_imputed)
# Only the last expression of a cell is rendered, so show the counts explicitly.
display(df_imputed['DMType'].value_counts())
df_imputed.dtypes

Logistic regression

In [None]:
# Separate predictors from the target and hold out 20% of the rows for testing.
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.utils import class_weight

# Target is the DMType column; predictors are all remaining columns.
y = df_imputed['DMType']
X = df_imputed.drop(columns='DMType')

# The fixed random_state makes the split repeatable.
# NOTE(review): rows are split independently. Later cells draw one line per
# METSIM_ID across Time_Point, so the same participant may appear in both
# splits — confirm whether a grouped split is needed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=777
)

In [9]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Standardise features with statistics learned from the training rows only.
scaler = StandardScaler().fit(X_train)

# Scaled version of the full feature matrix X.
X_scale = scaler.transform(X)

# NOTE(review): X_train / X_test are overwritten with their scaled versions,
# so running this cell twice refits the scaler on already-scaled data.
# Re-run the split cell before re-running this one.
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

# Elastic-net penalised logistic regression with balanced class weights.
logreg = LogisticRegression(
    penalty="elasticnet",
    solver="saga",
    l1_ratio=0.15,
    C=1000,
    class_weight='balanced',
    max_iter=100000,
    random_state=777,
)

# fit() returns the estimator, so lr_train refers to the same fitted model
# as logreg.
lr_train = logreg.fit(X_train, y_train)
y_pred = lr_train.predict(X_test)

NameError: name 'X_train' is not defined

In [None]:
import plotnine as p9

# Pair the first row of the fitted coefficients with the predictor names
# (X is expected to be a pandas DataFrame so that .columns is available).
feature_names = list(X.columns)
feature_coef = lr_train.coef_[0]

# One row per predictor, largest coefficient first.
coef_df = pd.DataFrame({'factors': feature_names, 'coef': feature_coef})
coef_df = coef_df.sort_values(by="coef", ascending=False)
display(coef_df)

# Predictor names in sorted order; passed as axis limits to fix the bar order.
factor_list = coef_df['factors']

# Bar chart of the coefficients with flipped coordinates.
(
    p9.ggplot(coef_df, p9.aes(x='factors', y='coef'))
    + p9.geom_col()
    + p9.scale_x_discrete(limits=factor_list)
    + p9.coord_flip()
)

In [None]:
# import the metrics module
from sklearn import metrics

# Visualisation dependencies
import numpy as np
import seaborn as sns

# Confusion matrix of the held-out predictions (rows: actual, columns: predicted).
cnf_matrix = metrics.confusion_matrix(y_test, y_pred)
# Only the last expression of a cell is rendered, so show the raw matrix explicitly.
display(cnf_matrix)

# Set the figure background BEFORE the figure is created; assigning the
# rcParam afterwards only affects figures created later.
plt.rcParams['figure.facecolor'] = '#f2f2f2'

fig, ax = plt.subplots()
# Draw on the explicit axes. heatmap sets its own ticks and tick labels from
# the frame's row/column labels, so no manual tick setup is needed.
sns.heatmap(pd.DataFrame(cnf_matrix), annot=True, cmap="YlGnBu", fmt='g', ax=ax)
ax.xaxis.set_label_position("top")
ax.set_title('Confusion matrix', y=1.1)
ax.set_ylabel('Actual label')
ax.set_xlabel('Predicted label')
# Run the layout pass last so the title and axis labels are accounted for.
fig.tight_layout()

In [None]:
from sklearn.metrics import classification_report

# Text report of the held-out predictions, using display names for the classes.
# NOTE(review): assumes the sorted labels are 0 then 1, so the first name
# applies to label 0 — confirm.
target_names = ['without diabetes', 'with diabetes']
print(
    classification_report(
        y_true=y_test,
        y_pred=y_pred,
        target_names=target_names,
    )
)

In [None]:
# Second column of predict_proba for the test rows.
y_pred_proba = logreg.predict_proba(X_test)[:, 1]

# Area under the ROC curve and the curve's points.
auc = metrics.roc_auc_score(y_test, y_pred_proba)
fpr, tpr, _ = metrics.roc_curve(y_test, y_pred_proba)

plt.plot(fpr, tpr, label="data 1, auc=" + str(auc))
plt.legend(loc=4)
plt.show()

Classify longitudinal patients

Probabilities check

In [None]:
# Class probabilities for every row of the scaled full feature matrix.
proba = logreg.predict_proba(X_scale)

# Keep the second column and store it on match_df as 'Proba'. The values are
# assigned by position.
proba_pos = proba[:, 1]
match_df.loc[:, 'Proba'] = proba_pos

In [None]:
# Shape of the subset where DMType is 0 but the predicted probability is >= 0.5.
match_df.loc[(match_df['Proba'] >= 0.5) & (match_df['DMType'] == 0)].shape

Plot probabilities

In [None]:
# Treat missing DMType as 0, then store the column as a categorical.
# NOTE(review): the model's label came from KNN-imputed DMType, whereas here
# missing values are set to 0 — confirm the two treatments are meant to differ.
match_df['DMType'] = match_df['DMType'].fillna(0).astype('category')

In [None]:
from plotly.tools import mpl_to_plotly

# Predicted probability against time point for every row of match_df:
# one faint line per METSIM_ID, points coloured by DMType.
# NOTE(review): the colour scale supplies two labels ('No', 'Yes'); confirm
# DMType has exactly two categories at this point.
proba_plot = (
    p9.ggplot(data=match_df, mapping=p9.aes(x='Time_Point', y='Proba'))
    + p9.geom_line(p9.aes(group='METSIM_ID'), alpha=0.3)
    + p9.geom_point(p9.aes(color='DMType'))
    + p9.scale_color_discrete(labels=['No', 'Yes'])
    + p9.labs(color='Diabetes', x='Time Point', y='Diabetes Probability')
    + p9.ylim(0, 1)
    + p9.theme(figure_size=(5, 5))
)

proba_plot

In [None]:
# Rows where dm equals 1 or DMType equals 2, and the METSIM_IDs they belong to.
cases_df = match_df.loc[(match_df['dm'] == 1) | (match_df['DMType'] == 2)]
patient_id = cases_df['METSIM_ID'].unique()

# Every row (all time points) of those METSIM_IDs.
patient_df = match_df.loc[match_df['METSIM_ID'].isin(patient_id)]

# Write the selected columns as a tab-separated file.
patient_df.to_csv(
    r'/data/namlhs/visualization/t2d_probs.csv',
    sep='\t',
    columns=['METSIM_ID', 'Time_Point', 'Proba', 'DMType'],
)

In [None]:
# Same probability-over-time plot, restricted to the rows in patient_df:
# one faint line per METSIM_ID, points coloured by DMType.
# NOTE(review): the colour scale supplies two labels ('No', 'Yes'); confirm
# DMType has exactly two categories at this point.
diab_plot = (
    p9.ggplot(data=patient_df, mapping=p9.aes(x='Time_Point', y='Proba'))
    + p9.geom_line(p9.aes(group='METSIM_ID'), alpha=0.3)
    + p9.geom_point(p9.aes(color='DMType'))
    + p9.scale_color_discrete(labels=['No', 'Yes'])
    + p9.labs(color='Diabetes', x='Time Point', y='Diabetes Probability')
    + p9.ylim(0, 1)
    + p9.theme(figure_size=(7, 7))
)

diab_plot