# Worldwide data

Read the worldwide data, and then see if we can predict anything about them!

Note, cells 3 & 4 get a list of all the samples.

These lines in cell 4 are key:

```
# so far, PRJEB20836 has the best
sample_id = worldwide_samples[3]
print(f"Sample is {sample_id}", file=sys.stderr)
wwdf, wwmd = cf_analysis_lib.read_worldwide_data(sample_id)
wwdf
```

You just change the index of the `worldwide_sample_number` to reflect the one you are interested in, and restart the kernel and run all cells.

**_NOTE:_** This notebook is explicitly for PRJNA510441 because it has the SHAP waterfall plots at the end!

**_NOTE2:_** This notebook uses the MinION model

In [1]:
worldwide_sample_number = 15

In [2]:
import os
import sys
from socket import gethostname

hostname = gethostname()

if hostname.startswith('hpc-node'):
    IN_DEEPTHOUGHT = True
    sys.path.append('..')
else:
    IN_DEEPTHOUGHT = False
from cf_analysis_lib.load_libraries import *
import cf_analysis_lib


# Read our data and models

This is our normal models we use for predicting everything

In [3]:
sequence_type = "MinION"
datadir = '..'

sslevel='subsystems'
ss_normalisation='norm_ss'
taxa = "family"
# all_taxa: True = everything, False = bacteria
all_taxa = False 

our_df, our_metadata = cf_analysis_lib.read_the_data(sequence_type, datadir, all_taxa=all_taxa, sslevel=f"{sslevel}_{ss_normalisation}.tsv.gz", taxa=taxa)

#encoder_models = 'cluster_gbrfs_eukaryotes'
encoder_models = 'cluster_gbrfs_minion'
if not os.path.exists(os.path.join(encoder_models, 'clusters.json')):
    print("Please run the autoencoder code before trying to load the models.", file=sys.stderr)
    exit(1)

if not os.path.exists(os.path.join(encoder_models, 'pc_df.tsv')):
    print("Please create and save the PCA before trying to load the models.", file=sys.stderr)
    exit(1)

with open(os.path.join(encoder_models, 'clusters.json'), 'r') as file:
    data = json.load(file)
tmpjsondf = pd.DataFrame(list(data.items()), columns=['Cluster', 'Feature'])
cluster_assignments = tmpjsondf.explode('Feature').reset_index(drop=True)
cluster_assignments['Cluster'] = cluster_assignments['Cluster'].astype(int)

pc_df = pd.read_csv(os.path.join(encoder_models, 'pc_df.tsv'), sep="\t", index_col=0)
pc_df.shape

Using ../MinION/FunctionalAnalysis/subsystems/subsystems_norm_ss.tsv.gz for the subsystems


(61, 150)

# Read one of the data sets

Samples: 
['PRJNA1081394', 'PRJNA516442', 'PRJEB51171', 'PRJNA71831', 'PRJEB20836', 'PRJNA516870', 'PRJNA1126024', 'PRJEB14440', 'PRJEB32062', 'PRJEB54014', 'PRJNA1055940', 'PRJNA1091195', 'PRJNA1101448', 'PRJNA316056', 'PRJNA316588', 'PRJNA510441', 'PRJNA615628', 'PRJNA644285', 'PRJNA825831', 'PRJNA839435', 'PRJNA846291', 'PRJNA931830']

Initially, lets see what's there?

In [4]:
worldwide_samples = cf_analysis_lib.worldwide_samples()
print(f"There are {len(worldwide_samples)}")
print(worldwide_samples)

There are 24
['PRJNA1081394', 'PRJNA516442', 'PRJEB51171', 'PRJNA71831', 'PRJEB20836', 'PRJNA516870', 'PRJNA1126024', 'PRJEB14440', 'PRJEB32062', 'PRJEB54014', 'PRJNA1055940', 'PRJNA1091195', 'PRJNA1101448', 'PRJNA316056', 'PRJNA316588', 'PRJNA510441', 'PRJNA615628', 'PRJNA644285', 'PRJNA825831', 'PRJNA839435', 'PRJNA846291', 'PRJNA931830', 'not_analysed', 'papers']


In [5]:
# find one sample
worldwide_samples.index('PRJNA510441')

15

### Read the ww data

In [6]:
sample_id = worldwide_samples[worldwide_sample_number]
cf_analysis_lib.show_green(title="Analysis Report", message=f"Analysing: {sample_id}")
os.makedirs(os.path.join("img", "worldwide", sample_id, "img"), exist_ok=True)
print(f"Sample is {sample_id}", file=sys.stderr)

wwdf, wwmd = cf_analysis_lib.read_worldwide_data(sample_id, sslevel=sslevel, ss_normalisation=ss_normalisation,
                        taxonomy=taxa, all_taxa=all_taxa, raw_taxa=False, drop_amplicon=True,
                        drop_suspected_amplicon=True, verbose=True)

# now we add any columns from our_df that are not in wwdf. This make subsequent computations easier (and correct!)
missing = [c for c in our_df.columns if c not in wwdf.columns]
missing_df = pd.DataFrame(0, index=wwdf.index, columns=missing)
wwdf = pd.concat([wwdf, missing_df], axis=1).copy()

print(f"We added {len(missing)} columns to ww df that were missing", file=sys.stderr)
wwdf

Sample is PRJNA510441
Reading subsystems from ../WorldWideDataAnalysis/Atavide/PRJNA510441/subsystems/PRJNA510441_subsystems_norm_ss.tsv.gz
Before dropping suspects, shape is (615, 14) and after is (615, 14)
Read 14 samples and 615 subsystems
Reading taxonomy from ../WorldWideDataAnalysis/Atavide/PRJNA510441/taxonomy_summary/PRJNA510441_family.norm.tsv.gz
Read 14 samples and 0 family
Read 14 samples and 7 metadata columns
We added 433 columns to ww df that were missing


Unnamed: 0,"2,3-diacetamido-2,3-dideoxy-d-mannuronic acid",2-O-alpha-mannosyl-D-glycerate utilization,2-aminophenol Metabolism,2-ketoacid oxidoreductases disambiguation,2-oxoglutarate dehydrogenase,2-phosphoglycolate salvage,3-amino-5-hydroxybenzoic Acid Synthesis,5-methylaminomethyl-2-thiouridine,A Hypothetical Protein Related to Proline Metabolism,A new toxin - antitoxin system,...,Thermodesulfobiaceae,Petrotogaceae,Thermotogaceae,Unnamed: 15,Methylacidiphilaceae,Unnamed: 17,Akkermansiaceae,Verrucomicrobiaceae,Unnamed: 20,Unnamed: 21
SRR8334087,49.467681,1879.771864,445.209126,676.058302,123.669202,321.539924,354.518378,4996.235745,700.792142,0.0,...,0,0,0,0,0,0,0,0,0,0
SRR8334088,0.0,1556.179701,552.192797,363.945253,192.430823,363.945253,332.570662,3714.751544,811.55608,0.0,...,0,0,0,0,0,0,0,0,0,0
SRR8334089,77.604605,1034.728061,724.309642,120.718274,564.789066,659.639139,482.873095,2819.633965,379.400289,51.736403,...,0,0,0,0,0,0,0,0,0,0
SRR8334090,27.080198,1299.849486,649.924743,45.133663,344.821183,352.042569,333.989104,3709.987074,704.085138,162.481186,...,0,0,0,0,0,0,0,0,0,0
SRR8334091,0.0,5327.887049,5327.887049,0.0,0.0,0.0,0.0,0.0,1331.971762,0.0,...,0,0,0,0,0,0,0,0,0,0
SRR8334092,21.559847,2845.899764,344.957547,172.478774,141.576327,711.474941,463.536704,4311.969339,596.489092,0.0,...,0,0,0,0,0,0,0,0,0,0
SRR8334093,364.81134,0.0,625.390869,26.057953,656.660413,638.419846,191.091655,1511.361267,425.61323,0.0,...,0,0,0,0,0,0,0,0,0,0
SRR8334094,0.0,517.950212,517.950212,712.181542,258.975106,485.578324,453.206436,4532.064355,625.856506,0.0,...,0,0,0,0,0,0,0,0,0,0
SRR8334095,0.0,0.0,5035.488203,1258.872051,419.624017,629.436025,0.0,1258.872051,1258.872051,0.0,...,0,0,0,0,0,0,0,0,0,0
SRR8334096,0.0,0.0,0.0,1246.59166,415.530553,207.765277,207.765277,831.061107,1246.59166,0.0,...,0,0,0,0,0,0,0,0,0,0


In [7]:
# print run counts
run_counts = {"PRJEB14440": "- The entire data set is 5 runs, and 1,330,576,074 bp", "PRJEB20836": "- The entire data set is 1396 runs, and 1,007,979,034,098 bp", "PRJEB32062": "- The entire data set is 27 runs, and 113,773,728,053 bp", "PRJEB51171": "- The entire data set is 64 runs, and 29,874,922,519 bp", "PRJEB54014": "- The entire data set is 80 runs, and 447,488,092,200 bp", "PRJNA1055940": "- The entire data set is 61 runs, and 386,942,064,540 bp", "PRJNA1081394": "- The entire data set is 549 runs, and 510,156,291,748 bp", "PRJNA1091195": "- The entire data set is 44 runs, and 146,735,286,366 bp", "PRJNA1101448": "- The entire data set is 323 runs, and 89,493,823,284 bp", "PRJNA1126024": "- The entire data set is 2 runs, and 1,071,033,642 bp", "PRJNA316056": "- The entire data set is 12 runs, and 36,012,000,000 bp", "PRJNA316588": "- The entire data set is 18 runs, and 121,257,187,412 bp", "PRJNA510441": "- The entire data set is 14 runs, and 906,511,907 bp", "PRJNA516442": "- The entire data set is 93 runs, and 34,371,719,753 bp", "PRJNA516870": "- The entire data set is 79 runs, and 39,170,766,613 bp", "PRJNA615628": "- The entire data set is 71 runs, and 95,889,827,074 bp", "PRJNA644285": "- The entire data set is 12 runs, and 29,679,938,510 bp", "PRJNA71831": "- The entire data set is 38 runs, and 1,179,409,307 bp", "PRJNA825831": "- The entire data set is 117 runs, and 58,764,823,681 bp", "PRJNA839435": "- The entire data set is 12 runs, and 8,336,976,960 bp", "PRJNA846291": "- The entire data set is 98 runs, and 87,154,850,753 bp", "PRJNA931830": "- The entire data set is 260 runs, and 740,690,138,100 bp"}
cf_analysis_lib.show_green(title="Analysis Report", message=f"{run_counts[sample_id]} (including amplicon sequences)")

In [8]:
cf_analysis_lib.show_green(title="Analysis Report", message=f"- We analysed {len(wwdf.index)} metagenomic sequence runs.")

## Create the PCA

Now we need to create the PCA df using our original data, and use those axis to fit our new data. We need the cluster assignments and the raw data!

In [9]:
grouped = cluster_assignments.groupby("Cluster")

# pc_df = pd.DataFrame(index=our_df.index, columns=[f"Cluster {x+1}" for x in range(len(cluster_assignments.groupby("Cluster").size()))])
# ww_pc_df = pd.DataFrame(index=wwdf.index, columns=[f"Cluster {x+1}" for x in range(len(cluster_assignments.groupby("Cluster").size()))])
fit_pca = {}
pc_df_data = {}
ww_pc_data = {}
for cluster_id, group in grouped:
    df_clust = our_df[cluster_assignments.loc[cluster_assignments["Cluster"] == cluster_id, "Feature"]]
    ww_clust = wwdf[cluster_assignments.loc[cluster_assignments["Cluster"] == cluster_id, "Feature"]]
    pca = PCA(n_components=1)
    pca.fit(df_clust)
    fit_pca[cluster_id] = pca
    pc_df_data[f"Cluster {cluster_id}"] = pca.transform(df_clust)[:,0]
    ww_pc_data[f"Cluster {cluster_id}"] = pca.transform(ww_clust)[:,0]

pc_df = pd.DataFrame(pc_df_data, index=our_df.index)
ww_pc_df = pd.DataFrame(ww_pc_data, index=wwdf.index)

KeyError: "['Haloferacaceae'] not in index"

In [None]:
ww_pc_df

In [None]:
print(f"Our data: min: {min(pc_df['Cluster 56']):.2f} max: {max(pc_df['Cluster 56']):.2f}")
print(f"Sample {sample_id}: min: {min(ww_pc_df['Cluster 56']):.2f} max: {max(ww_pc_df['Cluster 56']):.2f}")

# Scale the data frames. 

We use a robust scaler since all the PCAs have different units. This normalises all the data.

In [None]:
# scaler = StandardScaler() # - not good because too many values at 0
scaler = RobustScaler() 
# scaler = MinMaxScaler() 

# train the scaler on pc_df and then apply the _same scalar_ to ww_pc_df

trained_scaler = scaler.fit(pc_df)

pc_scaled = pd.DataFrame(
    trained_scaler.transform(pc_df),
    index=pc_df.index,
    columns=pc_df.columns
)

ww_pc_scaled = pd.DataFrame(
    trained_scaler.transform(ww_pc_df),
    index=ww_pc_df.index,
    columns=ww_pc_df.columns
)

In [None]:
print(f"Our data: min: {min(pc_scaled['Cluster 56']):.2f} max: {max(pc_scaled['Cluster 56']):.2f}")
print(f"Sample {sample_id}: min: {min(ww_pc_scaled['Cluster 56']):.2f} max: {max(ww_pc_scaled['Cluster 56']):.2f}")

# Build the GBRF with our data first

This is from our separate model analyses

In [None]:
#intcol = 'Pseudomonas Culture' # note these two columns have the same information, but 'CS_Pseudomonas aeruginosa' is a category, while 'Pseudomonas culture' is a float!
intcol = 'CS_Pseudomonas aeruginosa'

intcol_title = replace_index.sub('', intcol).replace('_', ' ')
merged_df = pc_scaled.join(our_metadata[[intcol]]).dropna(subset=intcol)

categorical_data, custom_labels = cf_analysis_lib.create_custom_labels(our_metadata, intcol, merged_df)

X = merged_df.drop(intcol, axis=1)
y = merged_df[intcol]

if categorical_data:
  model, mse, feature_importances_sorted = cf_analysis_lib.gb_classifier_model(X, y, n_estimators=1000, n_iter_no_change=20)
  met = 'classifier'
else:
  model, mse, feature_importances_sorted = cf_analysis_lib.gb_regressor_model(X, y)
  met = 'regressor'

print(f"We used {model.n_estimators_} estimators for the random forest {met}", file=sys.stderr)

## Apply our model to the new data

We use the model from the GBRF to predict the outcomes based on the clustered data.


In [None]:
predictions = model.predict(ww_pc_scaled)
perc = predictions.sum()/len(predictions) * 100
cf_analysis_lib.show_green(title="Analysis Report", message=f"- We predicted {predictions.sum()} samples out of {len(predictions)} ({perc:0.1f}%) have _Pseudomonas aeruginosa_")
predictions

In [None]:
# how confident are we in our predictions?
def confidence_label(prob, threshold_low=0.6, threshold_high=0.8):
    if prob < threshold_low:
        return "Low"
    elif prob < threshold_high:
        return "Medium"
    else:
        return "High"

In [None]:
probs = model.predict_proba(ww_pc_scaled)
print(probs[:10])

In [None]:
pred_class = np.argmax(probs, axis=1)
pred_conf = probs[np.arange(len(probs)), pred_class]
label_map = {0: "Negative", 1: "Positive"}

# Confidence labels
confidence = [confidence_label(p) for p in pred_conf]

output = ["Sample | Pseudomonas Prediction | Confidence | Certainty\n --- | --- | --- | ---"]

for i, (cls, conf, label) in enumerate(zip(pred_class, pred_conf, confidence)):
    output.append(f"{ww_pc_scaled.index[i]} | {label_map[cls]} | {conf:.2f} | {label}")
cf_analysis_lib.show_green(title="Analysis Report", message="\n".join(output))

In [None]:
interesting_cluster = 56
fig, axes = plt.subplots(figsize=(24, 12), nrows=1, ncols=2, sharex=False, sharey=True)

df_clust = our_df[cluster_assignments.loc[cluster_assignments["Cluster"] == interesting_cluster, "Feature"]]
df_clust_corr = df_clust.corr()
g = sns.heatmap(df_clust_corr, ax=axes[0], annot=False, cmap='coolwarm')
g.set_title(f'Cluster {interesting_cluster} Correlation Matrix for our data')

valid_features = [f for f in cluster_assignments.loc[cluster_assignments["Cluster"] == interesting_cluster, "Feature"] if f in wwdf.columns]
df_clust = wwdf[valid_features]
df_clust_corr = df_clust.corr()

sns.heatmap(df_clust_corr, ax=axes[1], annot=False, cmap='coolwarm')
plt.title(f'Cluster {interesting_cluster} Correlation Matrix for sample {sample_id}')

plt.show()

# Plot a t-SNE of our data and the worldwide data

In [None]:
X_combined = np.vstack([pc_df, ww_pc_df])
sample_labels = (["Adelaide"] * len(our_df)) + ([sample_id] * len(wwdf))
group_labels = list(our_metadata[intcol]) + list(predictions)
sample_ids = list(pc_df.index) + list(ww_pc_df.index)

tsne = TSNE(n_components=2, perplexity=30, random_state=42)
tsne_result = tsne.fit_transform(X_combined)

df_tsne = pd.DataFrame(tsne_result, columns=["TSNE1", "TSNE2"])
df_tsne["SampleType"] = sample_labels  # Or use group_labels
df_tsne["Pseudomonas Predictions"] = group_labels  # Or use group_labels
df_tsne['Pseudomonas Predictions'] = df_tsne['Pseudomonas Predictions'].map({0: 'Negative', 1: 'Positive'})
df_tsne['Sample IDs'] = sample_ids

# Basic comparison: New vs Original
g = sns.scatterplot(data=df_tsne, x="TSNE1", y="TSNE2", palette='coolwarm', style="SampleType", hue="Pseudomonas Predictions")
g.set_xlabel("t-SNE1")
g.set_ylabel("t-SNE2")
#plt.title(f"t-SNE of Adelaide and {sample_id} Samples")
plt.title("")
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

imgout = os.path.join("img", "worldwide", sample_id, "img", f"{sample_id}_Pseudomonas_tSNE.png")
plt.savefig(imgout)
cf_analysis_lib.show_green(title="Analysis Report", message=f"## t-SNE\n![Comparison of Adelaide and {sample_id} samples by t-SNE](img/{sample_id}_Pseudomonas_tSNE.png 'Fig. t-SNE of all the analysed sequence data coloured by whether Pseudomonas is predicted')")

plt.show()

# Create a PCA of all the data

In [None]:
pca_model = PCA(n_components=2)
pca_combined_result = pca_model.fit_transform(X_combined)

# Put into DataFrame for plotting
pca_combined_df = pd.DataFrame(pca_combined_result, columns=["PC1", "PC2"])
pca_combined_df["Source"] = (["Adelaide"] * len(our_df)) + ([sample_id] * len(wwdf))
pca_combined_df["Predictions"] = list(our_metadata[intcol]) + list(predictions)
pca_combined_df["Predictions"] = pca_combined_df["Predictions"].map({0: 'Negative', 1: 'Positive'})

pca_combined_df['IDS'] = list(our_metadata.index) + list(wwdf.index)

custom_labels = {0.0: 'No', 1.0: 'Yes'}

# Get the loadings (contributions of each original variable to the PCs)
loadings = pd.DataFrame(pca_model.components_.T, 
                        index=pc_df.columns, 
                        columns=["PC1", "PC2"])

# Scale arrows for better visibility
arrow_scale = 450000

plt.figure(figsize=(8, 6))
ax = sns.scatterplot(data=pca_combined_df, x="PC1", palette='coolwarm',
                y="PC2", style="Source", hue='Predictions', alpha=0.8, s=60)
plt.title(f"PCA Comparison of Adelaide and {sample_id} Samples")
plt.xlabel(f"PC1 ({pca_model.explained_variance_ratio_[0]:.2%}%)")
plt.ylabel(f"PC2 ({pca_model.explained_variance_ratio_[1]:.2%}%)")
# Plot top N most important loadings (e.g., based on magnitude of PC1^2 + PC2^2)
top_features = loadings.pow(2).sum(axis=1).sort_values(ascending=False).head(10).index

for feature in top_features:
    x = loadings.loc[feature, "PC1"]
    y = loadings.loc[feature, "PC2"]
    plt.arrow(0, 0, x * arrow_scale, y * arrow_scale,
              color='black', alpha=0.7, head_width=0.02)
    plt.text(x * arrow_scale * 1.1, y * arrow_scale * 1.1, feature,
             color='black', ha='center', va='center', fontsize=10)

plt.legend(loc='center left', ncol=2, bbox_to_anchor=(0.3, -0.3))

plt.tight_layout()
plt.show()


# Plot a PCA of just the "Pseudomonas" Cluster

In [None]:
pca = PCA(n_components=2)

intcol = 'Pseudomonas Culture'
valid_features = [f for f in cluster_assignments.loc[cluster_assignments["Cluster"] == interesting_cluster, "Feature"] if f in wwdf.columns]
df_clust = wwdf[valid_features]

pca_result = pca.fit_transform(df_clust)
pca_df = pd.DataFrame(data=pca_result, index=wwdf.index, columns=['PC1', 'PC2'])

# Get loadings
loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
loadings_df = pd.DataFrame(loadings, index=df_clust.columns, columns=['PC1', 'PC2'])

# Create a DataFrame for top loadings
top_loadings_df = loadings_df.loc[loadings_df['PC1'].abs().sort_values(ascending=False).index]
top_loadings_df.head()

explained_variance = pca.explained_variance_ratio_ * 100
pc1_variance = explained_variance[0]
pc2_variance = explained_variance[1]


# don't forget to change the legend
intcol_neg = 0
colours = np.where(predictions == intcol_neg, 'blue', 'red')

# Plot the PCA results
plt.figure(figsize=(10, 8))
plt.scatter(pca_df['PC1'], pca_df['PC2'], c=colours, alpha=0.2)
#plt.title(f"PCA of the 'Pseudomonas' cluster (#56) in {sample_id}")
plt.title("")
plt.xlabel(f'Principal Component 1 ({pc1_variance:.3f}%)')
plt.ylabel(f'Principal Component 2 ({pc2_variance:.3f}%)')

# add the loadings ... we only plot maxloadings here
maxloadings = 5
if len(loadings) < maxloadings:
    maxloadings = len(loadings)

plotscaler = 2
texts = []
colour_cycle = cycle(mcolors.TABLEAU_COLORS)
found_pseudomonas = False
for i in range(maxloadings):
    c = next(colour_cycle)
    xpos = top_loadings_df.iloc[i, 0]*plotscaler
    ypos = top_loadings_df.iloc[i, 1]*plotscaler
    plt.arrow(0, 0, xpos, ypos, 
              color=c, alpha=0.5, width=0.05)
    texts.append(plt.text(xpos, ypos, top_loadings_df.index[i], color=c))

adjust_text(texts)


# Add a legend
blue_patch = plt.Line2D([0], [0], marker='o', color='w', label='Predicted Pseudomonas positive', 
                         markerfacecolor='red', alpha=0.2, markersize=10)
red_patch = plt.Line2D([0], [0], marker='o', color='w', label='Predictied Pseudomonas negative', 
                        markerfacecolor='blue', alpha=0.2, markersize=10)


plt.legend(handles=[blue_patch, red_patch])

imgout = os.path.join("img", "worldwide", sample_id, "img", f"{sample_id}_Pseudomonas_PCA.png")
plt.savefig(imgout)
cf_analysis_lib.show_green(title="Analysis Report", message=f"## PCA\n![This cluster of features are most strongly associated with the presence of Pseudomonas](img/{sample_id}_Pseudomonas_PCA.png 'Fig. PCA of the cluster of features most strongly associated with Pseudomonas colonization in {sample_id}')")
# Show the plot
plt.show()

In [None]:
intcol = 'Pseudomonas Culture'
valid_features = [f for f in cluster_assignments.loc[cluster_assignments["Cluster"] == interesting_cluster, "Feature"] if f in wwdf.columns]
df_clust = pd.concat([our_df[valid_features], wwdf[valid_features]])
df_clust.shape

In [None]:
pca = PCA(n_components=2)

intcol = 'Pseudomonas Culture'
valid_features = [f for f in cluster_assignments.loc[cluster_assignments["Cluster"] == interesting_cluster, "Feature"] if f in wwdf.columns]
df_clust = pd.concat([our_df[valid_features], wwdf[valid_features]])


pca_result = pca.fit_transform(df_clust)
pca_df = pd.DataFrame(data=pca_result, index=df_clust.index, columns=['PC1', 'PC2'])

pca_df["Source"] = (["Adelaide"] * len(our_df)) + ([sample_id] * len(wwdf))
pca_df["Predictions"] = list(our_metadata[intcol]) + list(predictions)
pca_df["Predictions"] = pca_df["Predictions"].map({0: 'Negative', 1: 'Positive'})
pca_df

In [None]:
pca = PCA(n_components=2)

interesting_cluster = 70
valid_features = [f for f in cluster_assignments.loc[cluster_assignments["Cluster"] == interesting_cluster, "Feature"] if f in wwdf.columns]
df_clust = pd.concat([our_df[valid_features], wwdf[valid_features]])

pca_result = pca.fit_transform(df_clust)
pca_df = pd.DataFrame(data=pca_result, index=df_clust.index, columns=['PC1', 'PC2'])

pca_df["Source"] = (["Adelaide"] * len(our_df)) + ([sample_id] * len(wwdf))
pca_df["Pseudomonas Predictions"] = list(our_metadata[intcol]) + list(predictions)
pca_df["Pseudomonas Predictions"] = pca_df["Pseudomonas Predictions"].map({0: 'Negative', 1: 'Positive'})

# Get loadings
loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
loadings_df = pd.DataFrame(loadings, index=df_clust.columns, columns=['PC1', 'PC2'])

# Create a DataFrame for top loadings
top_loadings_df = loadings_df.loc[loadings_df['PC1'].abs().sort_values(ascending=False).index]

explained_variance = pca.explained_variance_ratio_ 
pc1_variance = explained_variance[0]
pc2_variance = explained_variance[1]

# Scale arrows for better visibility
arrow_scale = 5

plt.figure(figsize=(12,12))
ax = sns.scatterplot(data=pca_df, x="PC1", palette='coolwarm',
                y="PC2", style="Source", hue='Pseudomonas Predictions', alpha=0.8, s=60)
plt.title(f"PCA Comparison of Adelaide and {sample_id} Samples cluster {interesting_cluster}")
plt.xlabel(f"PC1 ({pc1_variance:.2%})")
plt.ylabel(f"PC2 ({pc2_variance:.2%})")
# Plot top N most important loadings (e.g., based on magnitude of PC1^2 + PC2^2)

for feature in top_loadings_df.index[:5]:
    x = loadings_df.loc[feature, 'PC1']
    y = loadings_df.loc[feature, 'PC2']
    plt.arrow(0, 0, x * arrow_scale, y * arrow_scale,
              color='black', alpha=0.7, head_width=0.02)
    plt.text(x * arrow_scale * 1.1, y * arrow_scale * 1.1, feature,
             color='black', ha='center', va='center', fontsize=10)

plt.legend(loc='center left', ncol=2, bbox_to_anchor=(0.3, -0.3))

plt.tight_layout()
plt.show()

In [None]:
wwdf['Pseudomonadaceae']

## Just compare cluster 56

In [None]:
plot_df = pd.DataFrame({
    "Value": pd.concat([pc_scaled["Cluster 56"], ww_pc_scaled["Cluster 56"]], axis=0),
    "Dataset": (["Adelaide"] * len(pc_scaled)) + ([sample_id] * len(ww_pc_scaled))
})


plt.figure(figsize=(8, 6))

# Choose one:
# Boxplot + jitter
#sns.boxplot(data=plot_df, x="Dataset", y="Value", whis=1.5, palette="pastel")
#sns.stripplot(data=plot_df, x="Dataset", y="Value", jitter=True, color="k", alpha=0.6)

# Or Violin plot + jitter
sns.violinplot(data=plot_df, x="Dataset", palette='BrBG', y="Value", inner=None)
sns.stripplot(data=plot_df, x="Dataset", y="Value", jitter=True, color="k", alpha=0.6)

plt.title("Comparison of Cluster 56 Values")
plt.ylabel("Standardized PCA Value")
plt.xlabel("")
plt.tight_layout()
plt.show()

In [None]:
for c in wwdf.columns:
    if "pseudom" in c.lower():
        print(c)

In [None]:
wwdf.loc['SRR8334087', 'Pseudomonadaceae']

In [None]:
ww_pc_df.loc['SRR8334087', 'Cluster 56']

## The problem.

Sample SRR8334087 is predicted as not having _Pseudomonas_:

SRR8334087 | Negative | 0.69 | Medium

However, 

```
ww_pc_df.loc['SRR8334087', 'Cluster 56'] = 8644.405853116295
and
wwdf.loc['SRR8334087', 'Pseudomonadaceae'] = 31044.084329222947
```

so both of these are strong signals!

In [None]:
import shap

# 1. Create a SHAP explainer for your trained model
explainer = shap.TreeExplainer(model)

# 2. Pick the sample you want to explain
run_id = "SRR8334087"
sample_index = ww_pc_df.index.get_loc(run_id)

# 3. Get shap values (no extra class axis)
shap_values = explainer.shap_values(ww_pc_scaled)  # shape: (n_samples, n_features)

# 4. Base value is a single scalar
base_value = explainer.expected_value

# 5. Extract the row for our sample
sample_shap_values = shap_values[sample_index]
sample_features = ww_pc_scaled.iloc[sample_index]

# 6. Waterfall plot
shap.waterfall_plot(
    shap.Explanation(
        values=sample_shap_values,
        base_values=base_value,
        data=sample_features,
        feature_names=ww_pc_scaled.columns
    )
)

In [None]:
fig, axes = plt.subplots(figsize=(24, 24), nrows=2, ncols=2)
# SRR8334087 is wrong
# SRR8334093 is correct
sample_ids = ['SRR8334087', 'SRR8334093', '1128691_20171206_S', '1128691_20171218_S']

explainer = shap.TreeExplainer(model)
shap_values = [explainer.shap_values(ww_pc_scaled), explainer.shap_values(pc_scaled)]
base_value = explainer.expected_value

for i, (ax, run_id) in enumerate(zip(axes.ravel(), sample_ids)):
    print(i)
    plt.sca(ax)  # set current axes
    if i < 2:
        sample_index = ww_pc_scaled.index.get_loc(run_id)
        shapdata = ww_pc_scaled.iloc[sample_index]
        shapnames = ww_pc_scaled.columns
        shapvals = shap_values[0][sample_index]
    else:
        sample_index = pc_scaled.index.get_loc(run_id)
        shapdata = pc_scaled.iloc[sample_index]
        shapnames = pc_scaled.columns
        shapvals = shap_values[1][sample_index]
        
    shap.waterfall_plot(
        shap.Explanation(
            values=shapvals,
            base_values=base_value,
            data=shapdata,
            feature_names=shapnames
        ),
        max_display=10,
        show=False
    )
    ax.set_title(run_id)

plt.tight_layout()
plt.show()

In [None]:
cluster_assignments[cluster_assignments['Cluster'] == 70]

# Read the raw data, not the normalised data.

In [None]:
ssnorm = 'raw'
rawt = True
wwraw, wwrawmd = cf_analysis_lib.read_worldwide_data(sample_id, sslevel=sslevel, ss_normalisation=ssnorm,
                        taxonomy=taxa, all_taxa=all_taxa, raw_taxa=rawt, drop_amplicon=True,
                        drop_suspected_amplicon=True, verbose=True)

wwraw

In [None]:
for c in wwraw:
    if 'pseudo' in c.lower():
        print(c)

In [None]:
wwraw['Pseudomonadaceae']

In [None]:
ssnorm = 'raw'
rawt = True
taxa="species"
wwraw, wwrawmd = cf_analysis_lib.read_worldwide_data(sample_id, sslevel=sslevel, ss_normalisation=ssnorm,
                        taxonomy=taxa, all_taxa=True, raw_taxa=rawt, drop_amplicon=True,
                        drop_suspected_amplicon=True, verbose=True)

In [None]:
for c in wwraw:
    if 'pseudo' in c.lower():
        print(c)

In [None]:
wwraw['Pseudomonas aeruginosa']