# Smoothen one pocket

Smoothen one pocket using the pre-trained model from `train.ipynb`.

To smoothen the pocket, you need the following:
- the pocket to be smoothened
- sequence
- ESM-2 embedding (`https://github.com/skrhakv/esm2-generator/blob/master/compute-esm.py`)
- distance matrix (`compute_distance_matrix.py`)

In [5]:
import torch
import numpy as np
import os

import torch.nn as nn

POSITIVE_DISTANCE_THRESHOLD = 15
NEGATIVE_DISTANCE_THRESHOLD = 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def process_single_sequence(structure_name: str, chain_id: str, binding_residues: list[str], sequence: str, embedding_path: str, distance_matrix_path: str):
    id = structure_name.lower() + chain_id
    if not os.path.exists(embedding_path):
        raise FileNotFoundError(f'Embedding file for {id} not found in {embedding_path}')
    
    embedding = np.load(embedding_path)
    distance_matrix = np.load(distance_matrix_path)

    Xs = []
    Ys = []
    idx = []

    binding_residues_indices = [int(residue[1:]) for residue in binding_residues]
    
    negative_examples_indices = set()
    
    for (aa, residue_idx) in [(residue[0], int(residue[1:])) for residue in binding_residues]:
        assert sequence[residue_idx] == aa
        close_residues_indices = np.where(distance_matrix[residue_idx] < POSITIVE_DISTANCE_THRESHOLD)[0]
        close_binding_residues_indices = np.intersect1d(close_residues_indices, binding_residues_indices)
        
        concatenated_embedding = np.concatenate((embedding[residue_idx], np.mean(embedding[close_binding_residues_indices], axis=0)))
        Xs.append(concatenated_embedding)
        Ys.append(1)  # positive example
        idx.append(residue_idx)
        
        really_close_residues_indices = np.where(distance_matrix[residue_idx] < NEGATIVE_DISTANCE_THRESHOLD)[0]
        negative_examples_indices.update(set(list(really_close_residues_indices)) - set(list(binding_residues_indices)))

    for residue_idx in negative_examples_indices:
        close_residues_indices = np.where(distance_matrix[residue_idx] < POSITIVE_DISTANCE_THRESHOLD)[0]
        close_binding_residues_indices = np.intersect1d(close_residues_indices, binding_residues_indices)
        concatenated_embedding = np.concatenate((embedding[residue_idx], np.mean(embedding[close_binding_residues_indices], axis=0)))
        Xs.append(concatenated_embedding)
        Ys.append(0)
        idx.append(residue_idx)

    return np.array(Xs), np.array(Ys), np.array(idx)

def predict_single_sequence(Xs, Ys, idx, model_3):
    Xs = torch.tensor(Xs, dtype=torch.float32).to(device)
    Ys = torch.tensor(Ys, dtype=torch.int64).to(device)
    idx = torch.tensor(idx, dtype=torch.int64).to(device)

    test_logits = model_3(Xs).squeeze()
    test_pred = torch.sigmoid(test_logits)

    return {'predictions': test_pred.detach().cpu().numpy(), 'indices': idx.detach().cpu().numpy()}

DECISION_THRESHOLD = 0.8
DROPOUT = 0.3
LAYER_WIDTH = 256
ESM2_DIM  = 1280 * 2

class CryptoBenchClassifier(nn.Module):
    def __init__(self, input_dim=ESM2_DIM):
        super().__init__()
        self.layer_1 = nn.Linear(in_features=input_dim, out_features=LAYER_WIDTH)
        self.dropout1 = nn.Dropout(DROPOUT)

        self.layer_2 = nn.Linear(in_features=LAYER_WIDTH, out_features=LAYER_WIDTH)
        self.dropout2 = nn.Dropout(DROPOUT)

        self.layer_3 = nn.Linear(in_features=LAYER_WIDTH, out_features=1)

        self.relu = nn.ReLU()

    def forward(self, x):
        # Intersperse the ReLU activation function between layers
        return self.layer_3(self.dropout2(self.relu(self.layer_2(self.dropout1(self.relu(self.layer_1(x)))))))

In [None]:
MODEL_STATE_DICT_PATH = '/home/polakluk/cryptobench-data/model/cryptobench_classifier.pt'
EMBEDDING_PATH = '/home/polakluk/cryptobench-data/650M-embeddings/1arlA.npy'
DISTANCE_MATRIX_PATH = '/home/polakluk/cryptobench-data/distance-matrices/1arlA.npy'

model = CryptoBenchClassifier().to(device)
model.load_state_dict(torch.load(MODEL_STATE_DICT_PATH, map_location=device), strict=True)

# 1arl;A;UNKNOWN;H68 E71 R126 N143 R144 H195 S196 L202 I242 I246 Y247 A249 G252 S253 I254 T267 E269;ARSTNTFNYATYHTLDEIYDFMDLLVAEHPQLVSKLQIGRSYEGRPIYVLKFSTGGSNRPAIWIDLGIHSREWITQATGVWFAKKFTEDYGQDPSFTAILDSMDIFLEIVTNPDGFAFTHSQNRLWRKTRSVTSSSLCVGVDANRNWDAGFGKAGASSSPCSETYHGKYANSEVEVKSIVDFVKDHGNFKAFLSIHSYSQLLLYPYGYTTQSIPDKTELNQVAKSAVAALKSLYGTSYKYGSIITTIYQASGGSIDWSYNQGIKYSFTFELRDTGRYGFLLPASQIIPTAQETWLGVLTIMEHTVNN

single_for_prediction = process_single_sequence('1arl', 'A', ['H68', 'E71', 'R126', 'N143', 'R144', 'H195', 'S196', 'L202', 'I242', 'I246', 'Y247', 'A249', 'G252', 'S253', 'I254', 'T267', 'E269'],
                        'ARSTNTFNYATYHTLDEIYDFMDLLVAEHPQLVSKLQIGRSYEGRPIYVLKFSTGGSNRPAIWIDLGIHSREWITQATGVWFAKKFTEDYGQDPSFTAILDSMDIFLEIVTNPDGFAFTHSQNRLWRKTRSVTSSSLCVGVDANRNWDAGFGKAGASSSPCSETYHGKYANSEVEVKSIVDFVKDHGNFKAFLSIHSYSQLLLYPYGYTTQSIPDKTELNQVAKSAVAALKSLYGTSYKYGSIITTIYQASGGSIDWSYNQGIKYSFTFELRDTGRYGFLLPASQIIPTAQETWLGVLTIMEHTVNN', 
                        EMBEDDING_PATH, DISTANCE_MATRIX_PATH)

smoothened_prediction = predict_single_sequence(*single_for_prediction, model_3=model)

print(smoothened_prediction)

SMOOTHENED_THRESHOLD = 0.7 # this is defined by the training data - best F1 score was achieved with this threshold 

selected_indices = np.where(smoothened_prediction['predictions'] > SMOOTHENED_THRESHOLD)[0]
selected_indices_mapped = smoothened_prediction['indices'][selected_indices]

print(f'Selected binding residues: {selected_indices_mapped}')

{'predictions': array([9.36800182e-01, 9.32939410e-01, 9.56092358e-01, 8.79242003e-01,
       8.41791093e-01, 9.73467410e-01, 8.17485511e-01, 2.16683656e-01,
       7.85170346e-02, 9.20639873e-01, 9.75117385e-01, 6.62114799e-01,
       1.65152818e-01, 1.66030377e-01, 4.09017861e-01, 8.67445320e-02,
       9.09414828e-01, 2.30749901e-02, 8.31042416e-03, 7.10558966e-02,
       4.93570417e-02, 4.23515541e-03, 9.00857002e-02, 2.18222085e-02,
       1.01847798e-02, 3.59809175e-02, 3.57238087e-03, 4.00694273e-03,
       2.75908887e-01, 2.66343355e-01, 6.97652623e-02, 3.76036926e-03,
       4.13294742e-03, 7.80426383e-01, 3.50182247e-03, 4.37276810e-03,
       7.94239994e-03, 1.18232474e-01, 1.34831788e-02, 9.84929875e-03,
       6.17025137e-01, 2.05906741e-02, 8.01444411e-01, 7.97918797e-01,
       6.55539036e-01, 1.16083324e-01, 1.27658453e-02, 1.67769776e-03,
       3.84957209e-04, 8.73115100e-03, 6.25527918e-01, 3.54148680e-03,
       4.59358888e-03, 1.95668135e-02, 1.25064538e-03, 5.7633