## Extracting kmers from sequence

In [1]:
#Not all are necessarily needed for the commands I included
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sbn
from Bio import SeqIO
from scipy.spatial import distance
from scipy.cluster.hierarchy import linkage, dendrogram, cut_tree
from scipy import stats
import glob
import os
import csv

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
from sklearn.feature_selection import VarianceThreshold, RFECV
from sklearn.model_selection import train_test_split, GroupKFold, GroupShuffleSplit
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import Lasso, LassoCV, LinearRegression
from sklearn.svm import SVR
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import KFold

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import pickle

import shap

In [11]:
#Functions for extracting k-mers from sequences and building a k-mer count matrix

import pandas as pd
import numpy as np
from collections import defaultdict

#from ch_util.sequences import extract_proteins_from_seq_record #our own package, but not needed here

# Translation table for str.translate() that maps each upper-case DNA base to its complement
# (A<->T, C<->G). Characters that are not in the table are left unchanged by str.translate().
_complement_table = str.maketrans({"A": "T", "T": "A", "C": "G", "G": "C"})


def extract_kmer_features(genomes, k, mode="nucl", dtype="uint8"):
    """
    Build a genome-by-kmer count table from a collection of genomes.

    Parameters:
    -----------
    genomes: dict
        Maps genome (strain) names to lists of contigs, each contig being a SeqIO.SeqRecord.
        When 'mode' is 'prot', the values should instead be SeqRecords of protein sequences.
    k: int
        Length of the k-mers to count
    mode: str {'nucl', 'prot', 'cds'}
        Selects how k-mers are extracted.
        'nucl': nucleotide k-mers are counted in the provided contigs. A k-mer and its reverse complement
        share one count, and only the alphabetically smaller of the two is kept as a column.
        'prot': amino acid k-mers are counted in the provided protein SeqRecords.
        'cds': proteins are first extracted from the contigs (which must already carry CDS annotations),
        and amino acid k-mers are counted in those proteins.
    dtype: str
        Numpy dtype of the count arrays. The default uint8 can hold counts up to 255; choose uint16 if
        higher counts are expected.

    Returns:
    --------
    pandas.DataFrame
        One row per genome (in the iteration order of 'genomes') and one column per k-mer, holding the
        number of times each k-mer was found in each genome.

    Raises:
    -------
    ValueError
        If 'mode' is not one of 'nucl', 'prot' or 'cds'.
    """
    if mode not in ("nucl", "prot", "cds"):
        raise ValueError(f"Invalid mode: '{mode}'. Must be 'nucl', 'prot', or 'cds'")

    strain_names = list(genomes)
    n_strains = len(strain_names)

    if mode == "cds":
        # Replace every genome's contigs by the proteins encoded on them
        genomes = {strain: _proteins_from_contigs(records) for strain, records in genomes.items()}

    def zero_counts():
        # Factory for a fresh per-kmer count vector with one slot per genome
        return np.zeros(n_strains, dtype=dtype)

    is_nucleotide = mode == "nucl"
    # Nucleotide k-mers use the reverse-complement-aware dictionary, everything else a plain defaultdict
    counts = DnaKmerDict(zero_counts) if is_nucleotide else defaultdict(zero_counts)

    # Count the k-mers of each genome into that genome's slot
    for row, strain in enumerate(strain_names):
        for kmer in _generate_kmers(genomes[strain], k, mode):
            counts[kmer][row] += 1

    if is_nucleotide:
        # Drop the alias keys so each reverse-complement pair is represented by a single column
        counts.clear_redundant_entries()

    return pd.DataFrame(counts, index=strain_names)


def _generate_kmers(seq_records, k, mode):
    """
    Extracts k-mers from an iterable of SeqRecord objects

    Parameters:
    -----------
    seq_records: iterable
        An iterable, e.g. list, containing SeqRecord objects
    k: int
        The size of the k-mers to find
    mode: str
        Kmer counting model. Refer to 'extract_kmer_features' documentation for description.

    Returns:
    --------
    kmers: generator
        Generator of all the kmers found in the input SeqRecords
    """
    if mode == "nucl":
        for record in seq_records:
            # record = str(record.seq)  # Convert to string for faster substring retrieval
            for i in range(len(record) - k + 1):
                kmer = record[i: i + k]
                if "N" not in kmer:
                    yield kmer
    else:
        for record in seq_records:
            record = str(record.seq)  # Convert to string for faster substring retrieval
            for i in range(len(record) - k + 1):
                # TODO: Exclude ambiguous amino acids? Do they frequently occur?
                yield record[i: i + k]


def _reverse_complement_kmer(kmer):
    """
    Return the reverse complement of a nucleotide string: every base is swapped via '_complement_table'
    and the result is reversed.
    """
    complemented = kmer.translate(_complement_table)
    return complemented[::-1]


def _proteins_from_contigs(contigs):
    """
    Helper function to convert a list of nucleotide SeqRecords (contigs) into a corresponding list of protein SeqRecords
    """
    proteins = []
    for contig in contigs:
        proteins.extend(extract_proteins_from_seq_record(contig))
    return proteins


class DnaKmerDict(dict):
    """
    Dictionary for DNA kmers in which a kmer and its reverse complement share a single value.

    Entries are created on first access through a default-function, in the manner of a
    collections.defaultdict. Values should never be assigned explicitly; they should be mutable objects
    that are created by the default-function and then modified in place. Numpy arrays are recommended
    for counting the kmers.

    Parameters:
    default_factory: callable
        Called without arguments to create the value for a kmer pair that has not been seen before
    """
    def __init__(self, default_factory):
        self.default_factory = default_factory

    @staticmethod
    def _reverse_complement(kmer):
        """
        Return the reverse complement of a nucleotide string
        """
        complemented = kmer.translate(_complement_table)
        return complemented[::-1]

    def __missing__(self, key):
        """
        Called on lookup of a key that is not stored yet.
        A key that is alphabetically larger than its reverse complement becomes an alias of the reverse
        complement's value; any other key receives a new value from the default-function.
        """
        partner = self._reverse_complement(key)
        if key > partner:
            # Look up (and, if necessary, create) the partner's value and share that same object
            shared = self[partner]
            self[key] = shared
            return shared
        fresh = self.default_factory()
        self[key] = fresh
        return fresh

    def clear_redundant_entries(self):
        """
        Delete every key that is alphabetically larger than its reverse complement, leaving one key per
        reverse-complement pair.
        """
        redundant = [kmer for kmer in self if kmer > self._reverse_complement(kmer)]
        for kmer in redundant:
            del self[kmer]

## Hierarchical clustering and dendrogram

In [None]:
#Hierarchical clustering: build the linkage from the rows of `matrix` (Ward's method, Euclidean distances)
#NOTE(review): `matrix` is not defined in the visible cells - presumably the k-mer feature table; confirm.
link = linkage(matrix, metric="euclidean", method="ward")

In [None]:
#Plot dendrogram with phenotype data as heatmap
#NOTE(review): `plot_df` and `phenotype_values` are not defined in the visible cells; `plot_df.index` is
#assumed to label the rows that `link` was computed from, and `phenotype_values` to be a Series sharing
#that index - confirm.
fig = plt.figure(figsize=[5, 20])

#Left panel: the dendrogram, drawn explicitly on ax1 instead of on whatever axes happen to be current
ax1 = fig.add_axes([0, 0, 0.8, 1])
dendro = dendrogram(link, orientation="left", labels=list(plot_df.index), leaf_font_size=1, ax=ax1)
#"ivl" is the list of leaf labels in the order they are placed along the leaf axis. Taking it from the
#return value keeps the original label objects, whereas reading the tick-label text gives strings that
#would no longer match a non-string index in the reindex below.
leaf_order = dendro["ivl"]

#Right panel: the heatmap, with a small axes for its colour bar
ax2 = fig.add_axes([0.82, 0, 0.1, 1])

ax3 = fig.add_axes([0.01, 0.89, 0.05, 0.1])

heatmap_data = phenotype_values.to_frame()

#The heatmap draws its first row at the top, so the leaf order is reversed to line up with the dendrogram
heatmap_data = heatmap_data.reindex(leaf_order[::-1])
sbn.heatmap(heatmap_data, yticklabels=True, ax=ax2, cbar_ax=ax3)
ax2.yaxis.tick_right()
ax2.set_yticklabels(ax2.get_yticklabels(), rotation=0, fontsize=2)


plt.savefig("dendrogram.pdf", bbox_inches="tight")

#A bare None as last expression keeps the notebook from echoing a return value
None

## Machine learning procedure

In [None]:
#Splitting into a training set and a test set
#NOTE(review): `X` and `Y` are not defined in the visible cells - presumably the feature table and the
#phenotype values; confirm. No random_state is passed, so expect a different split on every run.
x_train, x_test, y_train, y_test =  train_test_split(X, Y)

#Define and fit model
#NOTE(review): the GradientBoostingRegressor assigned first is immediately replaced by the XGBRegressor on
#the next line, so only the XGBRegressor is fitted - keep just the model you want to use.
#NOTE(review): `xgboost` is not imported in the visible cells - confirm that `import xgboost` exists.
model = GradientBoostingRegressor()
model = xgboost.XGBRegressor()
model.fit(x_train, y_train)

#Both models accept many more parameters; consult their documentation for tuning options.

In [None]:
#Compare predictions and measurements, by plot and correlation coefficient
#Each panel has the model predictions on the x-axis and the measured values on the y-axis.

plt.figure(figsize=[10, 5])

#Left panel: training set
plt.subplot(1, 2, 1)
#Dashed diagonal spanning the measured range: points on it are perfect predictions
plt.plot([y_train.min(), y_train.max()], [y_train.min(), y_train.max()], "--")
plt.plot(model.predict(x_train), y_train, ".")


#Right panel: test set
plt.subplot(1, 2, 2)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], "--")
Z = model.predict(x_test)
plt.plot(Z, y_test, ".")
#Spearman rank correlation between test-set predictions and measurements
print(stats.spearmanr(Z, y_test))

In [None]:
#Quantity of relevant features
#A good idea to reduce the number of features to avoid overfitting
#When predictions do not improve by adding features, this would correspond to a good number of features to include.
#NOTE(review): `rankings` (a Series of feature rankings indexed by feature name, lower = better) and
#`xgboost` are not defined in the visible cells - confirm they exist before running this cell.
#NOTE(review): performance is measured on the training data the model was fitted on; consider scoring on
#held-out data as well to actually see overfitting.
n_features_max = 70  #largest number of top-ranked features to try

#Feature names ordered from best to worst ranking; sorted once instead of once per iteration
ranked_features = rankings.sort_values().index

train_performances = []
for n_selected in range(1, n_features_max + 1):
    model_train = xgboost.XGBRegressor()
    #Train and evaluate on the n_selected best-ranked features only
    X_train_subset = x_train[ranked_features[:n_selected]]
    model_train.fit(X_train_subset, y_train)
    z_train = model_train.predict(X_train_subset)
    #Pearson correlation between measured and predicted training values
    train_performances.append(stats.pearsonr(y_train, z_train)[0])

plt.plot(range(1, n_features_max + 1), train_performances)

In [None]:
#Recursive feature elimination
#NOTE(review): `X`, `Y` and `xgboost` are not defined in the visible cells - confirm they exist.

x_train, x_test, y_train, y_test =  train_test_split(X, Y)

#Initial model, fitted on all features
model_1 = xgboost.XGBRegressor()
model_1.fit(x_train, y_train)

#Identify features with very low feature importances
#Removing these first speeds up the recursive elimination below
reduced_features_1 = pd.Series(model_1.feature_importances_, index=x_train.columns)
#Keep only the names of features with an importance of at least 0.001
reduced_features_1 = reduced_features_1[reduced_features_1 >= 0.001].index
#NB: <model>.feature_importances_ holds the basic feature importances of a fitted model.
#They can be useful for identifying features that give high expression, as can the shap values below.

#Recursive feature elimination with cross-validation
rfe = RFECV(xgboost.XGBRegressor())

#Train with reduced features
rfe.fit(x_train[reduced_features_1], y_train)

#Get ranking of features
therankings = pd.Series(rfe.ranking_, index=reduced_features_1)
#Get features with best ranking (rank 1)
#A looser cut-off is possible as well, e.g. therankings[therankings <= 30].index
reduced_features_2 = therankings[therankings == 1].index

#Train model on reduced features
model_2 = xgboost.XGBRegressor()
model_2.fit(x_train[reduced_features_2], y_train)

In [None]:
#Save the model
#The 'models' directory must already exist. The with-block makes sure the file is flushed and closed.
with open('models/model.sav', 'wb') as model_file:
    pickle.dump(model, model_file)

In [None]:
#Basic SHAP commands
#Load the SHAP JavaScript needed for interactive plots in the notebook
shap.initjs()

#Compute SHAP values for the test set (x_test is also passed to the explainer as second argument)
explainer = shap.TreeExplainer(model, x_test)
shap_values = explainer.shap_values(x_test)

#Summary plot of the SHAP values of the test set
shap.summary_plot(shap_values, x_test)

#NOTE(review): `global_shap_importance` and `modelX` are not defined in the visible cells - confirm
#where they come from (modelX presumably stands for one of the fitted models).
shap_importance_df = global_shap_importance(modelX, x_test)
#First 20 rows of the importance table
shap_importance_df[0:20]


fig = plt.figure()
ax = fig.add_axes([0, 0, 0.8, 0.8])

#First call displays the plot; the second call uses show=False so the figure stays open and can be saved
shap.summary_plot(shap_values, x_test)
shap.summary_plot(shap_values, x_test, show=False)
#The 'plots' directory must already exist
plt.savefig("plots/shap_summary_plot3.pdf")
plt.close()

In [None]:
#or dedicated SHAP function

def shap_importances(X,
                     trained_model,
                     n_features=30,
                     plot=None,
                     n_non_genetic_features=0,
                     tax_dict=None):
    """
    ONLY for tree models!
    Gets the most important features, as identified by Shap and makes Shap
    feature importance plots.
    Learn more about Shap: https://github.com/slundberg/shap
    If the plot list contains 'summary', a Shap summary plot will be made.
    If the plot list contains any strain IDs (from the index of X), the
    function also makes individual Shap plots for each prediction for those
    strain IDs. All plots are returned in a dictionary.

    Input:
    -------
    X: Pandas DataFrame.
        Strain IDs in index, features in the columns.
    trained_model: Trained scikit learn tree model.
    n_features: Integer (default=30).
        Number of highest importance features to plot and return.
    plot: List (default=None, treated as an empty list).
        Can contain the values 'summary' and IDs from X.index.
        If the list is empty, no plots are made.
        If it contains 'summary' it makes Shap feature importance plot.
        If it contains IDs from X.index it makes Shap plots for each sample.
    n_non_genetic_features: Integer (default=0).
        Number of non-genetic features. The first n_non_genetic_features
        columns of X are used for names and labels in the individual
        sample plots.
    tax_dict: Dictionary (default=None).
        Strain IDs as keys and taxonomy values. This is used for labels in
        the individual sample plots. If it is None, or a strain ID is not
        in it, the taxonomy is left out of that sample's label.

    Returns:
    --------
    top_features: List of strings.
        List of the highest importance features (corresponding to the plot).
        Only returns the n_features top features.
    figures: Dictionary.
        Dictionary of Shap plots. Plot names as keys, plots as values.
        The plot names are either 'summary' or they start with a strain ID
        and contain information on the condition values of the sample.
        Print the dictionary keys to see: print(figures.keys())

    Raises:
    -------
    ValueError
        If 'plot' contains anything other than 'summary' or IDs from X.index.
    """
    # None replaces a mutable default argument; copy so the caller's list is never touched
    plot = [] if plot is None else list(plot)

    # Validate up front, before any expensive Shap computation
    valid_plot_values = {'summary', *X.index}
    if any(p not in valid_plot_values for p in plot):
        raise ValueError('"plot" contains invalid values.')

    # Explain the model's predictions with SHAP values:
    explainer = shap.TreeExplainer(trained_model)
    shap_values = explainer.shap_values(X)

    figures = {}

    if 'summary' in plot:
        fig = plt.figure()
        # Plot the effects of the features:
        shap.summary_plot(shap_values, X, max_display=n_features, show=False)
        figures['summary'] = fig

    requested = set(plot)

    # Plot the effects of the features for each individual requested sample:
    for i, strain_id in enumerate(X.index):
        if strain_id not in requested:
            continue

        fig = shap.force_plot(explainer.expected_value,
                              shap_values[i, :],
                              X.iloc[i, :],
                              matplotlib=True,
                              show=False)  # If show=True fig is empty

        # Values of the leading non-genetic (condition) columns for this sample
        conditions = [(X.columns[j], X.iloc[i, j])
                      for j in range(n_non_genetic_features)]

        # Put label on x-axis to know which sample it is:
        # strain ID, taxonomy (when known), then the condition values
        label_parts = [str(strain_id)]
        if tax_dict is not None and strain_id in tax_dict:
            label_parts.append(str(tax_dict[strain_id]))
        label_parts.extend(f'{column}={value}' for column, value in conditions)
        plt.xlabel(', '.join(label_parts))

        if n_non_genetic_features > 0:
            # The figure name encodes the strain ID and its condition values
            name = f'{strain_id}__' + '__'.join(
                f'{column}_{value}' for column, value in conditions)
            figures[name] = fig
        else:
            figures[strain_id] = fig

    # Rank features by mean absolute SHAP value, largest first
    top_features = X.columns[np.argsort(np.abs(shap_values).mean(0))[::-1]]
    top_features = list(top_features)[:n_features]

    return top_features, figures