# Revealing TCRβ CDR3 sequence motifs for a specific epitope using attention weights

- Attention 분석을 통해 특정 epitope을 인지하는 TCRb CDR3 서열 모티프를 확인한다

## Global configurations

In [11]:
# IPython shell escape: print the current working directory.
!pwd

/Users/hym/projects/TCRBert/notebook


In [4]:
# Global notebook setup: paths, display/logging options, target experiment,
# fine-tuned model, and the listener that captures attention weights.
import logging
import logging.config
import os
import sys
import warnings
from enum import auto

import numpy as np
import pandas as pd
# `IPython.core.display.display` is a deprecated import path (IPython >= 7.14);
# `IPython.display` is the public, stable location of the same function.
from IPython.display import display

# Project layout (absolute paths on the author's machine).
rootdir = '/Users/hym/projects/TCRBert'
workdir = '%s/notebook' % rootdir
datadir = '%s/data' % rootdir
srcdir = '%s/tcrbert' % rootdir
outdir = '%s/output' % rootdir

# Relative paths used below (e.g. '../config/logging.conf') resolve against
# the notebook directory.
os.chdir(workdir)

# Make the project importable *before* importing from it, and skip entries
# that are already present so re-running this cell does not keep growing
# sys.path with duplicates.
for _path in (rootdir, srcdir):
    if _path not in sys.path:
        sys.path.append(_path)

from tcrbert.exp import Experiment
from tcrbert.predlistener import PredResultRecoder

display(sys.path)

# Display: use the canonical pandas option names. The previous
# 'display.max.rows' / 'display.max.columns' spellings only worked because
# pandas matches option keys as regular expressions.
pd.set_option('display.max_rows', 999)
pd.set_option('display.max_columns', 999)

# Logger
warnings.filterwarnings('ignore')
logging.config.fileConfig('../config/logging.conf')
logger = logging.getLogger('tcrbert')
logger.setLevel(logging.INFO)

# Target experiment
exp_key = 'testexp'
exp = Experiment.from_key(exp_key)

exp_conf = exp.exp_conf
eval_conf = exp_conf['eval']

display(exp_conf)

# Final finetuned model
model = exp.load_eval_model()
display(model)

# Prediction-result recorder (the class name is spelled `PredResultRecoder`).
# It is created with output_attentions=True; a later cell reads
# eval_recoder.result_map['attentions'] after model.predict() has run.
eval_recoder = PredResultRecoder(output_attentions=True)
model.add_pred_listener(eval_recoder)

['/Users/hym/projects/TCRBert/notebook',
 '/Users/hym/projects/TCRBert',
 '/Users/hym/projects/DeepTCR',
 '/Users/hym/projects/epidab',
 '/Users/hym/projects/TCRGP',
 '/Users/hym/projects/epidab/epidab',
 '/Users/hym/opt/anaconda3/lib/python37.zip',
 '/Users/hym/opt/anaconda3/lib/python3.7',
 '/Users/hym/opt/anaconda3/lib/python3.7/lib-dynload',
 '',
 '/Users/hym/opt/anaconda3/lib/python3.7/site-packages',
 '/Users/hym/opt/anaconda3/lib/python3.7/site-packages/aeosa',
 '/Users/hym/opt/anaconda3/lib/python3.7/site-packages/IPython/extensions',
 '/Users/hym/.ipython',
 '/Users/hym/projects/TCRBert',
 '/Users/hym/projects/TCRBert/tcrbert',
 '/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev',
 '/Applications/PyCharm.app/Contents/plugins/python/helpers-pro/jupyter_debug',
 '/Users/hym/projects/TCRBert',
 '/Users/hym/projects/TCRBert/tcrbert']

2021-09-10 14:53:46 [INFO]: Loaded exp_conf: {'title': 'testexp', 'description': 'testexp', 'paper': 'testexp', 'model_config': '../config/bert-base/', 'train': {'pretrained_model': {'type': 'tape', 'location': '../config/bert-base/'}, 'data_parallel': False, 'rounds': [{'data': 'test', 'test_size': 0.2, 'batch_size': 32, 'n_epochs': 3, 'n_workers': 2, 'metrics': ['accuracy'], 'optimizer': {'name': 'adam', 'lr': 1e-05}, 'train_bert_encoders': [-4, None], 'early_stopper': {'monitor': 'accuracy', 'patience': 1}, 'model_checkpoint': {'chk': '../output/testexp/train.0.model_{epoch}.chk', 'monitor': 'accuracy', 'save_best_only': True, 'period': 1}, 'result': '../output/testexp/train.0.result.json'}]}, 'eval': {'data_parallel': False, 'batch_size': 32, 'n_workers': 2, 'metrics': ['accuracy', 'f1', 'roc_auc'], 'output_attentions': False, 'tests': [{'data': 'test', 'result': '../output/testexp/eval.test.result.json'}]}}


{'title': 'testexp',
 'description': 'testexp',
 'paper': 'testexp',
 'model_config': '../config/bert-base/',
 'train': {'pretrained_model': {'type': 'tape',
   'location': '../config/bert-base/'},
  'data_parallel': False,
  'rounds': [{'data': 'test',
    'test_size': 0.2,
    'batch_size': 32,
    'n_epochs': 3,
    'n_workers': 2,
    'metrics': ['accuracy'],
    'optimizer': {'name': 'adam', 'lr': 1e-05},
    'train_bert_encoders': [-4, None],
    'early_stopper': {'monitor': 'accuracy', 'patience': 1},
    'model_checkpoint': {'chk': '../output/testexp/train.0.model_{epoch}.chk',
     'monitor': 'accuracy',
     'save_best_only': True,
     'period': 1},
    'result': '../output/testexp/train.0.result.json'}]},
 'eval': {'data_parallel': False,
  'batch_size': 32,
  'n_workers': 2,
  'metrics': ['accuracy', 'f1', 'roc_auc'],
  'output_attentions': False,
  'tests': [{'data': 'test',
    'result': '../output/testexp/eval.test.result.json'}]}}

2021-09-10 14:53:46 [INFO]: Create TAPE model using config: ../config/bert-base/
2021-09-10 14:53:47 [INFO]: Loading the eval model from ../output/testexp/train.0.model_0.chk


BertTCREpitopeModel(
  (bert): ProteinBertModel(
    (embeddings): ProteinBertEmbeddings(
      (word_embeddings): Embedding(30, 768, padding_idx=0)
      (position_embeddings): Embedding(8192, 768)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ProteinBertEncoder(
      (layer): ModuleList(
        (0): ProteinBertLayer(
          (attention): ProteinBertAttention(
            (self): ProteinBertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ProteinBertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm()
              (dropout): Drop

## Shomuradova et al

- This dataset contains the COVID-19 S-protein epitope S<sub>269-277</sub> (YLQPRTFLL) together with the 352 epitope-specific TCRβ sequences from {Shomuradova:2020}
- TCRβ CDR3 sequence motifs와 attention weights 간의 상관관계는?

In [17]:
from torch.utils.data import DataLoader
from tcrbert.dataset import TCREpitopeSentenceDataset

# Build the evaluation dataset for the 'shomuradova' key.
eval_ds = TCREpitopeSentenceDataset.from_key('shomuradova')

# A single batch spanning the whole dataset, served in dataset order.
eval_data_loader = DataLoader(
    dataset=eval_ds,
    batch_size=len(eval_ds),
    shuffle=False,
    num_workers=2,
)

# Run prediction with the metrics listed in the experiment's eval config.
model.predict(metrics=eval_conf['metrics'], data_loader=eval_data_loader)

2021-09-10 15:04:50 [INFO]: Begin predict...
2021-09-10 15:04:50 [INFO]: use_cuda, device: False, cpu
2021-09-10 15:04:50 [INFO]: model: BertTCREpitopeModel(
  (bert): ProteinBertModel(
    (embeddings): ProteinBertEmbeddings(
      (word_embeddings): Embedding(30, 768, padding_idx=0)
      (position_embeddings): Embedding(8192, 768)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ProteinBertEncoder(
      (layer): ModuleList(
        (0): ProteinBertLayer(
          (attention): ProteinBertAttention(
            (self): ProteinBertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ProteinBertSelfOutput(
       

2021-09-10 15:04:50 [INFO]: n_data: 610
2021-09-10 15:04:50 [INFO]: batch_size: 610
2021-09-10 15:04:51 [INFO]: Begin 0/1 prediction batch
2021-09-10 15:05:58 [INFO]: End 0/1 prediction batch
2021-09-10 15:05:58 [INFO]: Done to predict...


In [38]:
import numpy as np
from tcrbert.dataset import CN

# Collect the recorded attention weights into one array and split the samples
# into positive (label == 1) and negative (label == 0) groups.
n_layers = model.config.num_hidden_layers
n_data = len(eval_ds)
n_heads = model.config.num_attention_heads
max_len = eval_ds.max_len

eval_df = eval_ds.df_enc

# One slot per layer; each slot is filled below with that layer's attentions.
attentions = np.zeros((n_layers, n_data, n_heads, max_len, max_len))
print('attentions.shape: %s' % str(attentions.shape))

for li, layer_attns in enumerate(eval_recoder.result_map['attentions']):
    # (n_data, n_heads, max_len, max_len)
    # detach() + cpu() make the conversion safe when the tensor lives on a
    # CUDA device or still requires grad; a bare .numpy() raises in both
    # cases. For CPU tensors without grad this is a no-op.
    attentions[li] = layer_attns.detach().cpu().numpy()

print('eval_df.label: %s' % eval_df[CN.label])
# Positional (row-order) indices of the positive and negative samples.
pos_indices = np.where(eval_df[CN.label] == 1)[0]
neg_indices = np.where(eval_df[CN.label] == 0)[0]

print('pos_indices: %s(%s)' % (pos_indices, str(pos_indices.shape)))
print('neg_indices: %s(%s)' % (neg_indices, str(neg_indices.shape)))

# Select along the sample axis (axis 1), keeping every layer.
pos_attns = attentions[:, pos_indices]
neg_attns = attentions[:, neg_indices]

print('pos_attns.shape: %s, neg_attns.shape: %s' % (str(pos_attns.shape), str(neg_attns.shape)))


attentions.shape: (12, 610, 12, 40, 40)
eval_df.label: YLQPRTFLL_CASSFQNTGELFF           1
YLQPRTFLL_CASSSVNNNEQFF           1
YLQPRTFLL_CAVGEANTGELFF           1
YLQPRTFLL_CAYQEVNTGELFF           1
YLQPRTFLL_CSARDDQAVNTGELFF        1
YLQPRTFLL_CSAGQRNTGELFF           1
YLQPRTFLL_CASSLEIEAFF             1
YLQPRTFLL_CAGDYLNTGELFF           1
YLQPRTFLL_CASSPDIACTF             1
YLQPRTFLL_CASSVDNTGELFF           1
YLQPRTFLL_CASSPDIEAFF             1
YLQPRTFLL_CAGQDLNTGELFF           1
YLQPRTFLL_CASSPDIVAFF             1
YLQPRTFLL_CAAQNLNTGELFF           1
YLQPRTFLL_CASSLDIEAFF             1
YLQPRTFLL_CSAGDRNTGELFF           1
YLQPRTFLL_CSARGGQGQNTGELFF        1
YLQPRTFLL_CASSPDIEQYF             1
YLQPRTFLL_CASTDLNTGELFF           1
YLQPRTFLL_CASSELNTGELFF           1
YLQPRTFLL_CATQDVNTGELFF           1
YLQPRTFLL_CASSDLSTGELFF           1
YLQPRTFLL_CASSDLNSGEQYF           1
YLQPRTFLL_CANQDSNTGELFF           1
YLQPRTFLL_CASGDLSSGEQYF           1
YLQPRTFLL_CASSDQNGNIQYF           1
YLQPRTFLL

pos_attns.shape: (12, 305, 12, 40, 40), neg_attns.shape: (12, 305, 12, 40, 40)


In [24]:
# Average over samples (axis 1) and both sequence-position axes (3, 4) of the
# (n_layers, n_data, n_heads, max_len, max_len) array, leaving one mean
# attention value per (layer, head).
np.mean(attentions, axis=(1, 3, 4)).shape

(12, 12)