# Finding STk vs. Tk sequence determinants

In this example, we'll tackle the problem of classifying sequences from a multiple sequence alignment (MSA).

Protein kinases (PKs), or phosphate transferases, are regulatory enzymes ubiquitous across all
major known taxa.
They catalyze the transfer of a phosphate moiety from an ATP molecule to the hydroxyl
group of an amino acid residue.
This changes the physicochemical properties of the target protein, altering its overall
behaviour and functionality.
Thus, they act as molecular controllers, deeply embedded in the regulatory mechanisms
governing complex cellular machinery.

Over the course of evolution, along with substrate specificity, PKs developed
a preference towards certain amino acid residues.
The group that likely appeared first comprises PKs transferring
phosphate to serine and threonine residues (STk).
Later, another group emerged, specializing in tyrosine residues (Tk).

| ![Phosphorylation](../fig/Phosphorylation.pdf "Phosphorylation reaction") |
|:--:|
| <b> Fig. 1. Phosphorylation of Ser/Thr and Tyr. Notice the dissimilarities between these two groups </b>|

What could be the sequence differences enabling such a specialization?
Could we find them using feature selection machinery?

In 1994, Juswinder performed a conservation-based analysis of STk and Tk
sequences.
He discovered several MSA positions that, as structural evidence suggests, exploit
physicochemical differences (e.g., volume) between Ser/Thr and Tyr.
(pic with their aln)
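To make the idea of a conservation-based comparison concrete, here is a minimal sketch. It is not the method used in the 1994 analysis: it simply flags alignment columns that are conserved within each group (low Shannon entropy) while having different consensus residues between the groups. The toy sequences, the `max_entropy` threshold, and all function names are assumptions made for illustration.

```python
import math
from collections import Counter


def column_entropy(column):
    """Shannon entropy (in bits) of the residue frequencies in one MSA column."""
    total = len(column)
    return -sum(
        (n / total) * math.log2(n / total) for n in Counter(column).values()
    )


def consensus(column):
    """Most frequent residue in a column."""
    return Counter(column).most_common(1)[0][0]


def group_specific_positions(group_a, group_b, max_entropy=0.5):
    """Return (position, consensus_a, consensus_b) for 1-based columns that
    are conserved within both groups but differ in their consensus residue."""
    hits = []
    columns = zip(zip(*group_a), zip(*group_b))
    for pos, (col_a, col_b) in enumerate(columns, start=1):
        conserved = max(column_entropy(col_a), column_entropy(col_b)) <= max_entropy
        if conserved and consensus(col_a) != consensus(col_b):
            hits.append((pos, consensus(col_a), consensus(col_b)))
    return hits


# Made-up aligned fragments; real input would be STk and Tk rows of one MSA.
stk = ["GKGTFG", "GKGSFG", "GKGTYG", "GKGTFG"]
tk = ["GAGRFG", "GAGRFG", "GAGRWG", "GAGRFG"]
result = group_specific_positions(stk, tk)
print(result)
```

Column 4 is not reported here: although the consensus differs, the STk group is too variable at that column for the chosen threshold.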


## 1. Install dependencies

In [1]:
# ! pip install lXtractor eBoruta ipywidgets xgboost optuna shap

## 2. Prepare data

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import gzip
from collections import Counter
from itertools import chain
from io import StringIO, BytesIO

import pandas as pd
from lXtractor.core.chain import ChainSequence, ChainList
from lXtractor.ext.hmm import PyHMMer
from lXtractor.variables.calculator import GenericCalculator
from lXtractor.variables.sequential import PFP
from lXtractor.variables.manager import Manager
from lXtractor.util.io import fetch_text
from more_itertools import consume

### 2.1 Download and parse initial sequences and profile

The UniProt link is generated from the search API:
we take all reviewed proteins (Swiss-Prot database) belonging to the PK superfamily.
The Pfam link points to the HMM profile of the PF00069 entry.

In [4]:
LINK_UNIPROT = "https://rest.uniprot.org/uniprotkb/stream?fields=accession%2Cid%2Csequence%2Cft_domain%2Cprotein_families&format=tsv&query=%28%28family%3A%22protein%20kinase%20superfamily%22%29%29%20AND%20%28reviewed%3Atrue%29"
LINK_PFAM = 'https://www.ebi.ac.uk/interpro/wwwapi//entry/pfam/PF00069?annotation=hmm'
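The UniProt link above is easier to read and modify when assembled from its parts. The sketch below rebuilds it with the standard library; the names `BASE` and `params` are introduced here for illustration only.

```python
from urllib.parse import quote, urlencode

BASE = "https://rest.uniprot.org/uniprotkb/stream"
params = {
    # Columns of the returned table
    "fields": "accession,id,sequence,ft_domain,protein_families",
    "format": "tsv",
    # Reviewed (Swiss-Prot) members of the protein kinase superfamily
    "query": '((family:"protein kinase superfamily")) AND (reviewed:true)',
}
# quote_via=quote encodes spaces as %20 rather than "+"
link = f"{BASE}?{urlencode(params, quote_via=quote)}"
print(link)
```

Changing the `query` or `fields` entries is then a matter of editing a plain string rather than a percent-encoded one.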

In [5]:
df_up = pd.read_csv(StringIO(fetch_text(LINK_UNIPROT)), sep='\t')

In [6]:
df_up.head()

| | Entry | Entry Name | Sequence | Domain [FT] | Protein families |
|--:|:--|:--|:--|:--|:--|
| 0 | A0A075F7E9 | LERK1_ORYSI | MVALLLFPMLLQLLSPTCAQTQKNITLGSTLAPQGPASSWLSPSGD... | DOMAIN 22..149; /note="Bulb-type lectin"; /evi... | Protein kinase superfamily, Ser/Thr protein ki... |
| 1 | A0A078CGE6 | M3KE1_BRANA | MARQMTSSQFHKSKTLDNKYMLGDEIGKGAYGRVYIGLDLENGDFV... | DOMAIN 20..274; /note="Protein kinase"; /evide... | Protein kinase superfamily, Ser/Thr protein ki... |
| 2 | A0A0B4J2F2 | SIK1B_HUMAN | MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH... | DOMAIN 27..278; /note="Protein kinase"; /evide... | Protein kinase superfamily, CAMK Ser/Thr prote... |
| 3 | A0A0K3AV08 | MLK1_CAEEL | MEQASVPSYVNIPPIAKTRSTSHLAPTPEHHRSVSYEDTTTASTST... | DOMAIN 69..130; /note="SH3"; /evidence="ECO:00... | Protein kinase superfamily, STE Ser/Thr protei... |
| 4 | A0A0P0VIP0 | LRSK7_ORYSJ | MPPRCRRLPLLFILLLAVRPLSAAAASSIAAAPASSYRRISWASNL... | DOMAIN 389..661; /note="Protein kinase"; /evid... | Leguminous lectin family; Protein kinase super... |


In [7]:
prof = PyHMMer(BytesIO(gzip.decompress(
    fetch_text(LINK_PFAM, decode=False)
)))

### 2.2 Convert sequences

In [8]:
def wrap_into_chain_seq(row):
    fam_df = row['Protein families']
    family = 'other'
    if 'protein kinase family' in fam_df:
        if 'Ser/Thr' in fam_df:
            family = 'STk'
        elif 'Tyr' in fam_df:
            family = 'Tk'

    return ChainSequence.from_string(
        row['Sequence'], name=row['Entry Name'], meta={'Family': family}
    )
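The labelling rule inside `wrap_into_chain_seq` can be checked in isolation. Below is a small sketch that extracts the same substring logic into a standalone function (`label_family` is a name introduced here) and runs it on made-up strings written in the style of the `Protein families` column; they are not copied from UniProt.

```python
def label_family(families: str) -> str:
    """Mirror of the rule above: STk, Tk, or 'other'."""
    if 'protein kinase family' in families:
        if 'Ser/Thr' in families:
            return 'STk'
        if 'Tyr' in families:
            return 'Tk'
    return 'other'


# Hypothetical annotations in the style of the `Protein families` column
examples = [
    'Protein kinase superfamily, CAMK Ser/Thr protein kinase family',
    'Protein kinase superfamily, Tyr protein kinase family',
    'Protein kinase superfamily, ADCK protein kinase family',
]
labels = [label_family(s) for s in examples]
print(labels)
```

Note that `Ser/Thr` is tested first, so an annotation mentioning both specificities would be labelled `STk` under this rule.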

In [9]:
chains = ChainList(
    wrap_into_chain_seq(r) for _, r in df_up.iterrows()
)
Counter(c.meta['Family'] for c in chains)

Counter({'STk': 3674, 'Tk': 551, 'other': 248})

In [10]:
chains = chains.filter(lambda x: x.meta['Family'] != 'other')

In [11]:
consume(prof.annotate(
    chains, new_map_name='PK', min_size=200, min_cov_hmm=0.7, min_score=30
));

In [12]:
len(chains), len(chains.collapse_children())

(4225, 3893)

In [13]:
chains = chains.filter(lambda x: len(x.children) > 0)
len(chains), Counter(c.meta['Family'] for c in chains)

(3844, Counter({'STk': 3391, 'Tk': 453}))

### 2.3 Prepare encoded dataset

In [14]:
N_COMP = 3

In [15]:
variables = list(chain.from_iterable(
    (PFP(pos, i) for i in range(1, N_COMP + 1)) for pos in range(1, prof.hmm.M + 1)
))
len(variables)

792
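The count of 792 follows from pairing every profile position with every component. The sketch below reproduces that arithmetic with plain tuples instead of `PFP` objects; the value 264 for the profile length is an assumption inferred from 792 / 3, and the name format is copied from the column names of the encoded table.

```python
from itertools import chain, product

N_COMP = 3
N_POSITIONS = 264  # assumed profile length (792 / N_COMP)

# Same nesting as in the cell above: all components for position 1, then 2, ...
pairs = list(chain.from_iterable(
    ((pos, i) for i in range(1, N_COMP + 1))
    for pos in range(1, N_POSITIONS + 1)
))

# An equivalent and arguably more direct formulation
pairs_product = list(product(range(1, N_POSITIONS + 1), range(1, N_COMP + 1)))

names = [f"PFP(p={pos},i={i})" for pos, i in pairs]
print(len(pairs), names[0], names[-1])
```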

In [16]:
manager = Manager(verbose=True)
calculator = GenericCalculator()
domains = chains.collapse_children()
df = manager.aggregate_from_it(
    manager.calculate(domains, variables, calculator, map_name='PK')
)


In [17]:
cls_map = {'STk': 0, 'Tk': 1}
id2cls = {s.id: cls_map[s.parent.meta['Family']] for s in domains}
df['IsTk'] = df['ObjectID'].map(id2cls)

In [18]:
df.head()

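| | ObjectID | PFP(p=1,i=1) | PFP(p=1,i=2) | PFP(p=1,i=3) | PFP(p=2,i=1) | PFP(p=2,i=2) | PFP(p=2,i=3) | PFP(p=3,i=1) | PFP(p=3,i=2) | PFP(p=3,i=3) | ... | PFP(p=262,i=1) | PFP(p=262,i=2) | PFP(p=262,i=3) | PFP(p=263,i=1) | PFP(p=263,i=2) | PFP(p=263,i=3) | PFP(p=264,i=1) | PFP(p=264,i=2) | PFP(p=264,i=3) | IsTk |
|--:|:--|--:|--:|--:|--:|--:|--:|--:|--:|--:|:--:|--:|--:|--:|--:|--:|--:|--:|--:|--:|--:|
| 0 | PK_1\|526-791<-(LERK1_ORYSI\|1-813) | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0 |
| 1 | PK_1\|21-274<-(M3KE1_BRANA\|1-1299) | NaN | NaN | NaN | 5.11 | 0.19 | -1.02 | 5.76 | -1.33 | -1.71 | ... | -3.82 | -2.31 | 3.45 | 7.33 | 4.55 | 2.77 | 6.58 | -1.73 | -2.49 | 0 |
| 2 | PK_1\|27-278<-(SIK1B_HUMAN\|1-783) | 3.14 | 3.59 | 2.45 | -6.61 | 0.94 | -3.04 | 6.58 | -1.73 | -2.49 | ... | -2.79 | 6.6 | 1.21 | 7.33 | 4.55 | 2.77 | 5.11 | 0.19 | -1.02 | 0 |
| 3 | PK_1\|188-445<-(MLK1_CAEEL\|1-1059) | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0 |
| 4 | PK_1\|392-592<-(LRSK7_ORYSJ\|1-695) | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0 |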
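The `ObjectID` values encode the domain name, its boundaries, and the parent sequence with its own boundaries. A small parser can be handy for inspecting results. The pattern below is inferred solely from the IDs displayed above and is an assumption, not a documented lXtractor format; `parse_object_id` is a helper introduced here.

```python
import re

# Assumed layout: <domain>|<start>-<end><-(<parent>|<start>-<end>)
ID_PATTERN = re.compile(
    r'^(?P<domain>[^|]+)\|(?P<start>\d+)-(?P<end>\d+)'
    r'<-\((?P<parent>[^|]+)\|(?P<p_start>\d+)-(?P<p_end>\d+)\)$'
)


def parse_object_id(object_id: str) -> dict:
    """Split an ID into named parts, converting boundaries to integers."""
    match = ID_PATTERN.match(object_id)
    if match is None:
        raise ValueError(f'Unexpected ID format: {object_id!r}')
    return {
        key: int(value) if value.isdigit() else value
        for key, value in match.groupdict().items()
    }


parsed = parse_object_id('PK_1|526-791<-(LERK1_ORYSI|1-813)')
domain_length = parsed['end'] - parsed['start'] + 1
print(parsed, domain_length)
```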


## 3. Run feature selection

In [60]:
from dataclasses import dataclass

import numpy as np
import optuna
import shap
from eBoruta import eBoruta
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedShuffleSplit
from xgboost import XGBClassifier

In [45]:
@dataclass
class Dataset:
    df: pd.DataFrame
    x_names: list[str]
    y_name: str

    @property
    def x(self) -> pd.DataFrame:
        return self.df[self.x_names]

    @property
    def y(self) -> pd.Series:
        return self.df[self.y_name]


def sel_by_idx(ds, idx):
    return Dataset(ds.df.iloc[idx], ds.x_names, ds.y_name)


def score(ds, model, score_fn=f1_score):
    return score_fn(ds.y.values, model.predict(ds.x))


def cv(ds, model, n=10, score_fn=f1_score, agg_fn=np.mean):
    """Score a fresh copy of `model` on `n` stratified shuffle splits of `ds`."""
    scores = []
    splitter = StratifiedShuffleSplit(n_splits=n)
    for train_idx, test_idx in splitter.split(ds.x, ds.y):
        ds_train = sel_by_idx(ds, train_idx)
        ds_test = sel_by_idx(ds, test_idx)
        # Re-create the model from its parameters so that no fitted state
        # leaks from one split into the next.
        _model = model.__class__(**model.get_params())
        _model.fit(ds_train.x, ds_train.y)
        scores.append(score(ds_test, _model, score_fn))
    return agg_fn(scores)

# def plot_imp_history(df_history: pd.DataFrame):
#     sns.lineplot(x='Step', y='Importance', hue='Feature', data=df_history)
#     sns.lineplot(x='Step', y='Threshold', data=df_history, linestyle='--', linewidth=4)

In [55]:
dataset = Dataset(df, [c for c in df.columns if 'PFP' in c], 'IsTk')
classifier = XGBClassifier(n_jobs=-1)

### 3.2 Cross-validate the initial model

In [54]:
cv(dataset, XGBClassifier(n_jobs=-1))

0.9796576971175099
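The cross-validation helper defaults to the F1 score rather than accuracy, which is a sensible choice for data this imbalanced. The sketch below illustrates why, using the chain-level class counts reported earlier (3391 STk, 453 Tk) as a stand-in for the exact per-domain counts, which may differ slightly. `f1_binary` is a plain-Python helper written for this example, not the scikit-learn function.

```python
def f1_binary(y_true, y_pred):
    """F1 score for the positive class (label 1)."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Approximate class balance: 0 = STk, 1 = Tk
y_true = [0] * 3391 + [1] * 453
# A useless model that always predicts the majority class
y_pred = [0] * len(y_true)

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
f1 = f1_binary(y_true, y_pred)
print(round(accuracy, 3), f1)
```

A model that never predicts Tk still reaches about 88% accuracy, while its F1 score for the Tk class is zero.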

### 3.3 Select Features

In [20]:
boruta = eBoruta(n_iter=100)
boruta.fit(dataset.x, dataset.y, model=classifier)


Traceback (most recent call last):
  File "/Users/ivanreveguk/Projects/eBoruta/eBoruta/dataprep.py", line 127, in has_missing
    assert isinstance(res, bool)
AssertionError
Failed to check input for missing values due to 


Boruta trials:   0%|          | 0/100 [00:00<?, ?it/s]

ntree_limit is deprecated, use `iteration_range` or model slicing instead.

In [58]:
accepted = boruta.features_.accepted
ds_sel = Dataset(
    df[list(accepted) + [dataset.y_name]], list(accepted), dataset.y_name
)

In [59]:
cv(ds_sel, XGBClassifier(n_jobs=-1))

0.9892923297770253
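The selection step rests on the Boruta idea of comparing each feature against shuffled "shadow" copies of the features. The sketch below is a heavily simplified illustration of that idea, not eBoruta's implementation: it uses absolute Pearson correlation as a stand-in for model-derived importance, omits the statistical test on hit counts, and all names (`abs_corr`, `shadow_hits`) and data are made up.

```python
import random


def abs_corr(xs, ys):
    """Absolute Pearson correlation; a stand-in for a model's importance."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return abs(cov) / (vx * vy) ** 0.5 if vx and vy else 0.0


def shadow_hits(features, y, n_iter=20, seed=0):
    """Count how often each feature beats the best shuffled (shadow) feature."""
    rng = random.Random(seed)
    hits = {name: 0 for name in features}
    for _ in range(n_iter):
        shadow_scores = []
        for values in features.values():
            shadow = list(values)
            rng.shuffle(shadow)  # destroys any relation to y
            shadow_scores.append(abs_corr(shadow, y))
        threshold = max(shadow_scores)
        for name, values in features.items():
            if abs_corr(values, y) > threshold:
                hits[name] += 1
    return hits


data_rng = random.Random(1)
y = [i % 2 for i in range(200)]
features = {
    'signal': [label + data_rng.gauss(0, 0.1) for label in y],
    'noise': [data_rng.gauss(0, 1) for _ in y],
}
hits = shadow_hits(features, y)
print(hits)
```

A feature that carries real signal should register a hit in essentially every round, while an uninformative one beats the shadows only occasionally. Boruta-type methods turn these hit counts into accept or reject decisions.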

In [64]:
model = boruta.model_
model.fit(ds_sel.x, ds_sel.y)

In [70]:
explainer = shap.TreeExplainer(model)
shap_vs = explainer.shap_values(ds_sel.x, ds_sel.y)
imp_vs = np.mean(np.abs(shap_vs), axis=0)

ntree_limit is deprecated, use `iteration_range` or model slicing instead.


In [71]:
imp_vs

array([2.78003458e-02, 3.30474880e-03, 4.22273064e-03, 4.26997477e-03,
       2.08926313e-02, 0.00000000e+00, 0.00000000e+00, 1.63884573e-02,
       1.02229714e-02, 5.04431594e-03, 8.13580584e-03, 0.00000000e+00,
       4.60334448e-03, 1.84719786e-01, 1.74306203e-02, 1.75376218e-02,
       2.19746336e-01, 1.48278996e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 1.66387344e-03, 5.89598827e-02, 8.05656835e-02,
       3.17157656e-02, 7.69123137e-02, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 7.03409165e-02, 6.27455637e-02, 0.00000000e+00,
       0.00000000e+00, 4.71823402e-02, 7.57577568e-02, 4.75610457e-02,
       0.00000000e+00, 0.00000000e+00, 2.09255260e-03, 3.00644222e-03,
       5.50829172e-02, 0.00000000e+00, 6.35143667e-02, 1.10375546e-02,
       6.34095259e-03, 5.48221581e-02, 1.76545568e-02, 0.00000000e+00,
       2.56536007e+00, 3.35658669e-01, 8.97402018e-02, 1.53745651e-01,
       6.39048517e-02, 0.00000000e+00, 4.11237702e-02, 7.89283216e-02,
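The bare array is hard to interpret without feature names. The sketch below shows the same reduction (mean absolute value per column) on a tiny made-up matrix and pairs the result with names to obtain a ranking. The numbers are invented, and the helper `mean_abs_by_column` is a plain-Python equivalent of `np.mean(np.abs(...), axis=0)` written for this example.

```python
def mean_abs_by_column(matrix):
    """Mean absolute value of each column of a row-major matrix."""
    n_rows = len(matrix)
    n_cols = len(matrix[0])
    return [sum(abs(row[j]) for row in matrix) / n_rows for j in range(n_cols)]


# Invented per-sample attributions: rows are samples, columns are features
shap_like = [
    [0.5, -0.25, 0.0],
    [-0.75, 0.125, 0.0],
    [0.25, 0.0, 0.0],
]
names = ['PFP(p=25,i=1)', 'PFP(p=14,i=2)', 'PFP(p=6,i=3)']

importance = mean_abs_by_column(shap_like)
ranked = sorted(zip(names, importance), key=lambda kv: kv[1], reverse=True)
for name, value in ranked:
    print(f'{name}: {value:.3f}')
```

With the real data, the names would presumably come from `ds_sel.x_names`, in the same order as the columns passed to the explainer.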
      

In [None]:
boruta

In [62]:
explainer

<shap.explainers._tree.Tree at 0x1712f9300>

### 3.4 (Optional) Finalize the model

In [53]:
def objective(trial, ds, model) -> float:
    params = {
        'learning_rate': trial.suggest_float('learning_rate', 0, 1),
        'max_depth': trial.suggest_int('max_depth', 4, 16),
        'gamma': trial.suggest_float('gamma', 0, 10.0),
        'reg_lambda': trial.suggest_float('reg_lambda', 0, 10.0),
        'reg_alpha': trial.suggest_float('reg_alpha', 0, 10.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.4, 1.0),
        'colsample_bylevel': trial.suggest_float('colsample_bylevel', 0.4, 1.0),
        'scale_pos_weight': trial.suggest_float('scale_pos_weight', 0.0, 10.0),
    }
    # callback = optuna.integration.XGBoostPruningCallback(trial, 'validation_logloss')
    params = {**model.get_params(), **params}
    model = model.__class__(**params)
    return cv(ds, model)
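The search range for `scale_pos_weight` can be sanity-checked against the class balance. A commonly suggested starting value is the ratio of negative to positive examples. The quick calculation below uses the chain-level counts reported earlier (3391 STk, 453 Tk) as an approximation, since the per-domain split is not shown.

```python
n_negative = 3391  # STk (class 0), chain-level count
n_positive = 453   # Tk (class 1), chain-level count

ratio = n_negative / n_positive
# Bounds used for `scale_pos_weight` in the objective above
in_search_range = 0.0 <= ratio <= 10.0
print(round(ratio, 2), in_search_range)
```

The ratio falls inside the 0 to 10 interval explored by the objective, so the heuristic value is reachable by the search.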

In [39]:
hist

| | Feature | Step | Importance | Hit | Decision | Threshold |
|--:|:--|--:|--:|--:|:--|--:|
| 1708 | PFP(p=6,i=3) | 9 | 0.003305 | 1 | Accepted | 0.0 |
| 3008 | PFP(p=11,i=1) | 9 | 0.00427 | 1 | Accepted | 0.0 |
| 3108 | PFP(p=11,i=2) | 9 | 0.019231 | 1 | Accepted | 0.0 |
| 3908 | PFP(p=14,i=1) | 9 | 0.017288 | 1 | Accepted | 0.0 |
| 4012 | PFP(p=14,i=2) | 13 | 0.120567 | 1 | Accepted | 0.025793 |
| 4708 | PFP(p=16,i=3) | 9 | 0.008629 | 1 | Accepted | 0.0 |
| 7008 | PFP(p=24,i=2) | 9 | 0.004603 | 1 | Accepted | 0.0 |
| 7208 | PFP(p=25,i=1) | 9 | 0.15893 | 1 | Accepted | 0.0 |
| 7912 | PFP(p=27,i=2) | 13 | 0.143174 | 1 | Accepted | 0.025793 |
| 9708 | PFP(p=33,i=2) | 9 | 0.056687 | 1 | Accepted | 0.0 |
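Since the question is which MSA positions separate STk from Tk, accepted feature names are more useful when collapsed to positions. A minimal sketch, assuming the `PFP(p=<position>,i=<component>)` naming seen in the table above; `group_by_position` is a helper introduced here, and the input list is copied from the ten rows shown.

```python
import re
from collections import defaultdict

FEATURE_RE = re.compile(r'PFP\(p=(\d+),i=(\d+)\)')


def group_by_position(feature_names):
    """Map each MSA position to the sorted list of its accepted components."""
    positions = defaultdict(list)
    for name in feature_names:
        match = FEATURE_RE.fullmatch(name)
        if match is None:
            continue  # skip anything that is not a PFP feature
        pos, comp = map(int, match.groups())
        positions[pos].append(comp)
    return {pos: sorted(comps) for pos, comps in sorted(positions.items())}


accepted_names = [
    'PFP(p=6,i=3)', 'PFP(p=11,i=1)', 'PFP(p=11,i=2)', 'PFP(p=14,i=1)',
    'PFP(p=14,i=2)', 'PFP(p=16,i=3)', 'PFP(p=24,i=2)', 'PFP(p=25,i=1)',
    'PFP(p=27,i=2)', 'PFP(p=33,i=2)',
]
by_position = group_by_position(accepted_names)
print(by_position)
```

Positions with several accepted components, such as 11 and 14 in this excerpt, may be natural first candidates for structural inspection.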


In [75]:
class A:
    def __getitem__(self, item):
        print(item)
        print(type(item))


In [87]:
dec = features.dec_history.iloc[5].values

In [88]:
dec

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [95]:
y = np.arange(10).reshape((5, 2))
pd.DataFrame(
    y, columns=(
        ['Y']
        if len(y.shape) == 1
        else [f'Y_{i}' for i in range(1, y.shape[1] + 1)]
    )
)

| | Y_1 | Y_2 |
|--:|--:|--:|
| 0 | 0 | 1 |
| 1 | 2 | 3 |
| 2 | 4 | 5 |
| 3 | 6 | 7 |
| 4 | 8 | 9 |
