# Can we predict anything about the antibiotic usage?

Let's start with '3 tobramycin_IV', because lots of our pwCF (people with CF) have that recorded!

In [39]:
import os
import sys

import re
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
import matplotlib.colors as mcolors
import matplotlib.dates as mdates
from matplotlib.colors import ListedColormap
import pandas as pd
import seaborn as sns
import json

from itertools import cycle

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
from sklearn.inspection import permutation_importance

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import mean_squared_error

from scipy.stats import linregress


# sklearn emits a FutureWarning (noted as coming from StandardScaler) that is
# really noisy in notebook output, so ignore it.
# NOTE: this filter has no module restriction, so it silences ALL
# FutureWarnings in this Python process, not only sklearn's.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Detect whether we are running on Google Colab and choose the data directory.
try:
  import google.colab
  IN_COLAB = True
  # `!` is IPython shell-escape syntax: this line only works inside a
  # notebook/Colab kernel, not when run as a plain Python script.
  !pip install adjustText
  # On Colab the data lives on Google Drive, which must be mounted first.
  from google.colab import drive
  drive.mount('/content/drive')
  datadir = '/content/drive/MyDrive/Projects/CF/Adelaide/CF_Data_Analysis'
except ImportError:
  # Not on Colab: data is expected in the parent of the working directory.
  IN_COLAB = False
  datadir = '..'

from adjustText import adjust_text

import cf_analysis_lib


In [62]:
def random_forest_regression(X, y):
  """
  Fit a random-forest regressor on an 80/20 train/test split.

  Parameters
  ----------
  X : pandas.DataFrame
      Feature matrix; its column names label the returned importances.
  y : array-like
      Continuous target values aligned with the rows of X.

  Returns
  -------
  tuple[float, pandas.DataFrame]
      The mean squared error on the held-out 20% of rows, and a one-column
      ('importance') DataFrame of feature importances indexed by feature
      name, sorted most important first.
  """
  # One fixed seed keeps both the split and the forest reproducible.
  seed = 42
  train_features, test_features, train_target, test_target = train_test_split(
    X, y, test_size=0.2, random_state=seed
  )

  # 1000 trees; tune hyperparameters here if needed.
  forest = RandomForestRegressor(random_state=seed, n_estimators=1000)
  forest.fit(train_features, train_target)

  # Score on the held-out rows only.
  error = mean_squared_error(test_target, forest.predict(test_features))

  ranked = pd.DataFrame(
    forest.feature_importances_, index=X.columns, columns=['importance']
  ).sort_values(by='importance', ascending=False)
  return error, ranked

def random_forest_classifier(X, y):
  """
  Fit a random-forest classifier on an 80/20 train/test split.

  Parameters
  ----------
  X : pandas.DataFrame
      Feature matrix; its column names label the returned importances.
  y : array-like
      Class labels aligned with the rows of X.

  Returns
  -------
  tuple[float, pandas.DataFrame]
      An error score on the held-out 20% of rows, and a one-column
      ('importance') DataFrame of feature importances indexed by feature
      name, sorted most important first.

      The error score is the mean squared error between true and predicted
      labels when sklearn can treat the labels as numbers (unchanged
      behaviour). For non-numeric labels (e.g. strings) MSE is undefined and
      sklearn raises, so the misclassification rate (fraction of wrong
      predictions) is returned instead; for 0/1 labels the two coincide.
  """

  # Split the data into training and testing sets
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

  # Initialize and train a RandomForestClassifier model
  model = RandomForestClassifier(random_state=42, n_estimators=1000)  # You can adjust hyperparameters
  model.fit(X_train, y_train)

  # Make predictions on the test set
  y_pred = model.predict(X_test)

  # Evaluate the model
  try:
    mse = mean_squared_error(y_test, y_pred)
  except (TypeError, ValueError):
    # Non-numeric class labels cannot be converted for MSE; fall back to the
    # misclassification rate.
    mse = float(np.mean(np.asarray(y_test) != np.asarray(y_pred)))

  # Feature importance
  feature_importances = pd.DataFrame(model.feature_importances_, index=X.columns, columns=['importance'])
  feature_importances_sorted = feature_importances.sort_values(by='importance', ascending=False)
  return mse, feature_importances_sorted



def plot_feature_importance(ax, feature_importances_sorted, title, ylabel="Bacteria"):
  """
  Draw a lollipop chart of feature importances on `ax`.

  Parameters
  ----------
  ax : matplotlib.axes.Axes
      Axes to draw on.
  feature_importances_sorted : pandas.DataFrame
      Indexed by feature name with an 'importance' column, ordered most
      important first (as returned by random_forest_regression /
      random_forest_classifier).
  title : str
      Axes title.
  ylabel : str, optional
      Y-axis label. Defaults to "Bacteria"; pass e.g. "Feature" when the
      features are not all taxa.
  """
  # Iterate in reverse: matplotlib places the first string category at the
  # bottom of the y-axis, so this puts the most important feature on top.
  for feature in feature_importances_sorted.index[::-1]:
    importance = feature_importances_sorted.loc[feature, 'importance']
    # Draw the dotted stem first so the marker is rendered on top of it.
    ax.plot([0, importance], [feature, feature], linestyle='dotted', marker='None', c='lightblue')
    ax.plot([importance], [feature], linestyle='None', marker='o', markersize=5, c='blue')

  ax.set_xlabel("Importance")
  ax.set_ylabel(ylabel)
  ax.set_title(title)

def plot_feature_abundance(ax, feature_df, intcol, title):
    """
    Plot min-max normalised feature values, split by the values of `intcol`.

    Each feature column is rescaled to [0, 1] with MinMaxScaler and drawn as
    an unfilled box plot with a dodged, jittered strip plot overlaid, using
    `intcol` as the hue.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    feature_df : pandas.DataFrame
        The feature columns to plot plus the `intcol` column.
    intcol : str
        Name of the grouping column in `feature_df`. It is used as the hue
        and is neither rescaled nor plotted as a feature.
    title : str
        Axes title.

    Use something like this:
    top20 = list(feature_importances_sorted[:20].index)+[intcol]
    plot_feature_abundance(ax, merged_df[top20], intcol, f"Plot of normalised measures that are important to distinguish '{intcol}' usage")
    """
    # Rescale each feature to [0, 1] so features with very different ranges
    # share one axis (swap in StandardScaler() for mean 0 / variance 1).
    # The grouping column is excluded from scaling: its raw values are
    # attached below, and fitting the scaler on it would fail if it is
    # non-numeric.
    features = feature_df.drop(columns=intcol)
    scaler = MinMaxScaler()
    scaled_df = pd.DataFrame(scaler.fit_transform(features), columns=features.columns)
    scaled_df[intcol] = feature_df[intcol].values

    # Long format: one row per (sample, feature) with that sample's intcol value.
    melted_df = pd.melt(scaled_df, id_vars=[intcol], var_name='Feature', value_name='Value')

    # Black unfilled boxes without fliers; the coloured jittered points show
    # the individual samples.
    sns.boxplot(data=melted_df, x='Feature', y='Value', hue=intcol, fill=False, legend=False, color='k', fliersize=0, ax=ax)
    sns.stripplot(data=melted_df, x='Feature', y='Value', hue=intcol, jitter=True, alpha=0.5, dodge=True, ax=ax)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Features', fontsize=12)
    ax.set_ylabel('Normalised Abundance', fontsize=12)


In [7]:
sequence_type = "MGI"
# NOTE: `datadir` is deliberately not reassigned here. It is set by the
# environment-setup cell (the Google Drive path on Colab, '..' locally); a
# hard-coded '..' at this point would silently override the Colab path.

# Subsystem table to load (the commented-out line is an alternative file).
#sslevel = 'level2_norm_ss.tsv.gz'
sslevel = 'subsystems_norm_ss.tsv.gz'
ss_df = cf_analysis_lib.read_subsystems(
    os.path.join(datadir, sequence_type, "FunctionalAnalysis", "subsystems", sslevel),
    sequence_type,
)
# Transposed, presumably so rows are samples: later cells join these tables
# and the metadata on their index.
ss_df = ss_df.T
print(f"The subsystems df has shape: {ss_df.shape}")

taxa = "genus"
genus_otu = cf_analysis_lib.read_taxonomy(datadir, sequence_type, taxa)
genus_otu = genus_otu.T
print(f"The taxonomy df has shape: {genus_otu.shape}")

metadata = cf_analysis_lib.read_metadata(datadir, sequence_type)
print(f"The metadata df has shape: {metadata.shape}")

In [9]:
# Combine the subsystem and genus tables side by side; the inner join on the
# index keeps only samples present in both.
df = pd.merge(ss_df, genus_otu, how='inner', left_index=True, right_index=True)
print(df.shape)
df.head()

In [10]:
# Attach the "interesting column" (intcol) we want to predict to the feature
# table. DataFrame.join defaults to a LEFT join on the index: every sample in
# df is kept, in df's order, and samples without a metadata entry get NaN.
intcol = '3 tobramycin_IV'
merged_df = df.join(metadata[[intcol]])

# Model on every feature in df (subsystems and genera)
X = merged_df.drop(intcol, axis=1)
y = merged_df[intcol]

if metadata[intcol].dtype == 'object':
  # Categorical target: a mean is undefined for labels (y.mean() would
  # raise), so drop unlabelled samples instead of imputing a value.
  labelled = y.notna()
  X, y = X[labelled], y[labelled]
  mse, feature_importances_sorted = random_forest_classifier(X, y)
  met = 'classification'
else:
  # Continuous target: impute missing values with the mean.
  y = y.fillna(y.mean())
  mse, feature_importances_sorted = random_forest_regression(X, y)
  met = 'regression'

print(f"Mean Squared Error for all bacteria: {mse}")

fig, axes = plt.subplots(figsize=(10,6), nrows=1, ncols=1)
plot_feature_importance(axes, feature_importances_sorted[:20], f"Top 20 Subsystems that predict {intcol}")

plt.tight_layout()
plt.show()

In [80]:
n = 5
topN = list(feature_importances_sorted[:n].index) + [intcol]
feature_cols = topN[:-1]

# Rescale each of the top-n features to [0, 1] with MinMaxScaler so they can
# share one y-axis (use StandardScaler() instead for mean 0 / variance 1).
# Compare merged_df[feature_cols].max() with scaled_df.max() to see the effect.
# Only the features are scaled: intcol is attached afterwards with its raw
# values, and fitting the scaler on it would fail if it is non-numeric.
scaler = MinMaxScaler()
scaled_df = pd.DataFrame(scaler.fit_transform(merged_df[feature_cols]), columns=feature_cols)
scaled_df[intcol] = merged_df[intcol].values

# Long format: one row per (sample, feature) with that sample's intcol value.
melted_df = pd.melt(scaled_df, id_vars=[intcol], var_name='Feature', value_name='Value')

plt.figure(figsize=(12, 6))
# Unfilled black boxes (no fliers) with the individual samples overlaid as
# jittered points, coloured and dodged by intcol.
# Note: a swarm plot doesn't work here because there are too many 0's
# (maybe with unnormalised numbers, but then the small values are invisible).
sns.boxplot(data=melted_df, x='Feature', y='Value', hue=intcol, fill=False, legend=False, color='k', fliersize=0)
sns.stripplot(data=melted_df, x='Feature', y='Value', hue=intcol, jitter=True, alpha=0.5, dodge=True)

plt.title(f"Plot of normalised measures that are important to distinguish '{intcol}' usage", fontsize=14)
plt.xlabel('Features', fontsize=12)
plt.ylabel('Normalised Abundance', fontsize=12)
plt.xticks(rotation=90)
plt.legend(title=intcol, bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()

In [68]:
# Same abundance plot as above, but drawn through the plot_feature_abundance helper.
n = 5
topN = [*feature_importances_sorted.index[:n], intcol]
title = f"Plot of top {n} normalised measures that are important to distinguish '{intcol}' usage"

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 6))
plot_feature_abundance(axes, merged_df[topN], intcol, title)

# Vertical feature names, legend outside the plot on the right.
plt.setp(axes.get_xticklabels(), rotation=90)
axes.legend(title=intcol, bbox_to_anchor=(1, 1))
fig.tight_layout()
plt.show()

In [73]:
# Antibiotic / medication columns of `metadata`. Each entry is used to index
# metadata directly, so it must match a column header exactly; apparent
# misspellings (e.g. 'Flucloaxcillin', 'Cefopime', 'Methylpredinosolone',
# 'Colistin_IHN') are therefore kept verbatim.
# NOTE(review): the leading digit looks like a grouping code from the
# metadata headers; confirm its meaning.
antibiotics = [
    '1 Cephalexin_PO',
    '1 Flucloaxcillin_PO',
    '1 Itraconazole (Lozenoc)_PO',
    '1 Sulfamethoxazole_trimethoprim (Bactrim)_PO',
    '2 Amikacin_INH',
    '2 Amoxicillin & Potassium clavulanate (Aug Duo)_PO',
    '2 Amphotericin B (Ambisome)_INH',
    '2 Azithromycin_PO',
    '2 Ceftazidime_INH',
    '2 Ciprofloxacin_PO',
    '2 Clarithromycin_PO',
    '2 Clofazimine PO',
    '2 Colistin_IHN',
    '2 prednisolone_PO',
    '2 tobramycin_INH',
    '3 Azithromycin_IV',
    '3 Aztreonam_IV',
    '3 Cefopime_IV',
    '3 Ceftazidime_IV',
    '3 Imipenem',
    '3 Ivacaftor (Kalydeco)',
    '3 Meropenem_IV',
    '3 Methylpredinosolone_IV',
    '3 Omalizumab_SC',
    '3 piperacillin sodium, tazobactam sodium (Tazocin)_IV',
    '3 tobramycin_IV',
    '4 Amikacin_IV',
    '4 Cefoxitin_IV',
    '4 Colistin_IV',
]

In [75]:
# Number of entries in the antibiotics list (shown as this cell's output).
len(antibiotics)

In [95]:
# Per-column totals over all samples for each antibiotic column, shown as
# the cell output (if the columns are 0/1 usage flags this is the number of
# samples on each drug; confirm against the metadata encoding).
metadata.loc[:, antibiotics].sum(axis=0)