# Example of running pairwise feature attribution with Integrated Gradients

### Indices:
* Zn near-centroid (4L9P:B:ZN:601): 11332
* Mg near-centroid (4OKE:A:MG:202): 7591
* Ca near-centroid (3A09:A:CA:601): 14673

### Two comparisons:
* Mg <-> Zn
* Ca <-> Zn

In [2]:
#%pip install captum # required for feature attribution calculation (uncomment to install into the kernel's environment)

In [3]:
import numpy as np
import os
import pandas as pd
import json

import torch

import mic
from mic.data.luna_utils import process_entry
from mic.data.data_loading import sym_expand
from mic.models.MIC import MIC

from luna.mol.entry import Entry

### feature attribution
from typing import Type, Union, Optional
from collections import OrderedDict
import torch
import captum
from captum.attr import IntegratedGradients

Module 'simplejson' not available. Built-in module 'json' will be imported.
Module 'simplejson' not available. Built-in module 'json' will be imported.


In [4]:
# Lift the row limit so displayed DataFrames are shown in full instead of being truncated.
pd.set_option('display.max_rows', None)

### Utility functions
* Generating bit-wise information about each structure, feature attribution, formatting results

In [19]:
import logging

# Raise the root logger threshold so only WARNING and more severe records get through.
logging.getLogger().setLevel(logging.WARNING)

In [20]:
def run_entry(entry_str, include_bitinfo = True,
              pdb_path = "/srv/ds/set-1/user/lshub/skiniotis/raw_pdbs/"):
    """Run the fingerprinting pipeline for one entry string.

    Parameters
    ----------
    entry_str : str
        Four ':'-separated fields: PDB id, chain, residue name and residue
        number, e.g. ``'4L9P:B:ZN:601'``.
    include_bitinfo : bool, default True
        Forwarded to ``process_entry`` as ``return_bitinfo``.
    pdb_path : str
        Directory forwarded to ``process_entry`` as ``pdb_path``. Defaults to
        the location this notebook was originally run against; override it
        when the structures live elsewhere.

    Returns
    -------
    Whatever ``process_entry`` returns. Callers in this notebook unpack it as
    ``(fp_info, bit_info)`` when ``include_bitinfo`` is True and use
    ``fp_info[0]`` as the fingerprint.

    Raises
    ------
    ValueError
        If ``entry_str`` does not have exactly four ':'-separated fields, or
        if the residue number is not an integer.
    """
    fields = entry_str.split(':')
    if len(fields) != 4:
        # Same exception type as the bare tuple-unpack would raise, but the
        # message names the offending input.
        raise ValueError(
            f"entry_str must look like 'PDB:CHAIN:RES:RESNUM', got {entry_str!r}")
    pdb, chain, res, resi = fields

    # NOTE(review): presumably this produces the symmetry-expanded structure
    # that the '<pdb>.symexp' entry below refers to - confirm against sym_expand.
    sym_expand(f'{pdb}.pdb')

    test_entry = Entry(pdb + '.symexp', chain, res, int(resi), is_hetatm=True, sep=':')
    results = process_entry(test_entry, prune = True, ifp_type = 'eifp',
                  ifp_num_levels = 18,
                  ifp_radius_step = 0.25,
                  ifp_length = 4096,
                  ifp_count = True,
                  pdb_path = pdb_path,
                  blind=True,
                  normfunc = np.log1p,
                  return_atm_mngr = False,
                  return_bitinfo = include_bitinfo)

    return results

In [21]:
class EncoderPairwiseAttributionHead():
    """Callable wrapper around ``encoder`` used as the attribution target.

    Calling the instance embeds two fingerprints with ``encoder`` and returns
    the squared Euclidean distance between the two embeddings.
    """

    def __init__(
            self,
            encoder):
        """Store the encoder.

        ``encoder`` must be callable on a fingerprint tensor and expose
        ``eval()`` (e.g. a ``torch.nn.Module``).
        """
        self.encoder = encoder

    def __call__(
            self,
            fp_A: torch.Tensor,
            fp_B: torch.Tensor) -> torch.Tensor:
        """Return the squared distance between the embeddings of two inputs.

        Parameters
        ----------
        fp_A, fp_B : torch.Tensor
            Fingerprint tensors; both must have ``requires_grad`` set.

        Returns
        -------
        torch.Tensor
            Squared differences of the two embeddings summed over the last
            dimension.

        Raises
        ------
        ValueError
            If either input is not tracking gradients.
        """
        # Gradients must flow back to both inputs; fail loudly otherwise.
        for label, fp in (("A", fp_A), ("B", fp_B)):
            if not fp.requires_grad:
                raise ValueError(
                    f"Tensor for fingerprint {label} must have gradient tracking on")
        # Switch the encoder to evaluation mode on every call, so the result
        # does not depend on the mode it was left in.
        self.encoder.eval()
        emb_A = self.encoder(fp_A)
        emb_B = self.encoder(fp_B)
        # The attribution objective is the embedding-space distance itself.
        return torch.sum((emb_A - emb_B) ** 2, dim=-1)

In [22]:
def pairwise_attribution(fp_A, fp_B, encoder, n_steps=50, method='gausslegendre'):
    """Attribute the embedding distance between two fingerprints to their features.

    Integrated Gradients is run on an ``EncoderPairwiseAttributionHead`` built
    around ``encoder``, using an all-zero baseline for each fingerprint.

    Parameters
    ----------
    fp_A, fp_B : array-like or torch.Tensor
        The two fingerprints; each is copied into a float32 tensor.
    encoder
        Model passed to ``EncoderPairwiseAttributionHead``.
    n_steps : int, default 50
        Forwarded to ``IntegratedGradients.attribute``.
    method : str, default 'gausslegendre'
        Forwarded to ``IntegratedGradients.attribute``.

    Returns
    -------
    list of tuple
        A one-element list ``[(attr_A, attr_B)]`` of NumPy arrays. Each array
        keeps the leading batch axis of size 1 that is added before the call,
        so callers index ``attrs[0][0][0]`` for the attributions of ``fp_A``.
    """
    epah = EncoderPairwiseAttributionHead(encoder)
    ig = IntegratedGradients(
            epah,
            multiply_by_inputs=True)

    def _as_leaf(fp):
        # Private float32 copy that tracks gradients; the caller's array or
        # tensor is never modified.
        return torch.as_tensor(fp, dtype=torch.float32).detach().clone().requires_grad_()

    fpa = _as_leaf(fp_A)
    fpb = _as_leaf(fp_B)

    # One zero baseline per input, matching that input's shape and dtype.
    baseline_a = torch.zeros_like(fpa)
    baseline_b = torch.zeros_like(fpb)

    attr_A, attr_B = ig.attribute(
        inputs=(torch.unsqueeze(fpa, 0), torch.unsqueeze(fpb, 0)),
        baselines=(torch.unsqueeze(baseline_a, 0), torch.unsqueeze(baseline_b, 0)),
        n_steps=n_steps,
        method=method)

    # Kept as a list of pairs because existing callers index the result that way.
    return [(
        attr_A.detach().numpy(),
        attr_B.detach().numpy(),
    )]

In [23]:
def features_to_csv(feature_indices, feature_list, out_csv = None):
    """Flatten ranked features into a table with one row per shell.

    Parameters
    ----------
    feature_indices : iterable
        One item per ranked position; item ``[0]`` fills the 'index' column
        and item ``[1]`` fills the 'attribution' column.
    feature_list : iterable
        Parallel to ``feature_indices``. Each item is an iterable of entries
        whose ``[0]`` fills the 'bit' column and whose ``[1]`` is an iterable
        of shell objects. Each shell supplies ``level``, ``radius``,
        ``central_atm_grp.atoms`` (each atom with ``full_atom_name``) and
        ``interactions``.
    out_csv : optional
        When not None, the table is also written there with ``DataFrame.to_csv``.

    Returns
    -------
    pandas.DataFrame
        Columns: index, attribution, bit, level, radius, central_atoms
        (atom names joined with ';') and num_interactions.
    """
    column_names = ['index', 'attribution', 'bit', 'level', 'radius', 'central_atoms', 'num_interactions']
    records = [
        (
            ranked[0],
            ranked[1],
            bit_entry[0],
            shell.level,
            shell.radius,
            ';'.join(atom.full_atom_name for atom in shell.central_atm_grp.atoms),
            len(shell.interactions),
        )
        for ranked, bit_entries in zip(feature_indices, feature_list)
        for bit_entry in bit_entries
        for shell in bit_entry[1]
    ]
    table = pd.DataFrame(records, columns = column_names)
    if out_csv is not None:
        table.to_csv(out_csv)
    return table

In [24]:
### Loading trained model
# MIC() is built with no arguments; later cells pass its `.model` attribute to
# pairwise_attribution as the encoder.
# NOTE(review): presumably the no-argument constructor loads the trained weights - confirm.
mic_trained = MIC()

### Running feature attribution following IFPencoder implementation

### Indices (positions in `entry_strs` and rows of `fps`):
* Zn near-centroid (4L9P:B:ZN:601): 0
* Mg near-centroid (4OKE:A:MG:202): 1
* Ca near-centroid (3A09:A:CA:601): 2

In [25]:
# Sites to compare, written as 'PDB:chain:residue:residue_number' (the fields run_entry splits on ':').
# The order fixes the positions used later: 0 = Zn, 1 = Mg, 2 = Ca.
entry_strs = ['4L9P:B:ZN:601', '4OKE:A:MG:202', '3A09:A:CA:601']

In [29]:
# Fingerprint every entry and keep its bit info under the residue name,
# i.e. the third ':'-separated field of the entry string (e.g. 'ZN').
bit_info = {}
fp_rows = []
for entry_str in entry_strs:
    residue_name = entry_str.split(':')[2]
    fp_info, bit_info[residue_name] = run_entry(entry_str, True)
    # The first element of fp_info is the fingerprint used downstream.
    fp_rows.append(fp_info[0])
# One row per entry, in the same order as entry_strs.
fps = np.stack(fp_rows)

 Symmetry: Found 2 symmetry operators.
 Symmetry: Found 4 symmetry operators.
 Symmetry: Found 2 symmetry operators.


In [30]:
# Attribution for Zn (fps[0]) against Mg (fps[1]).
# NOTE(review): the next cell assigns `attrs` from this identical call again, so this cell looks redundant.
attrs = pairwise_attribution(fps[0], fps[1], mic_trained.model)

In [31]:
# mg -> zn
# Zn (fps[0]) is passed as fingerprint A and Mg (fps[1]) as fingerprint B.
attrs = pairwise_attribution(fps[0], fps[1], mic_trained.model)
# attrs[0] is the single (attr_A, attr_B) pair, the next [0] selects attr_A (Zn),
# and the last [0] takes the only row of its batch axis.
zn_mg_attrs = attrs[0][0][0]
# Ten (position, attribution) pairs with the largest attribution, highest first.
zn_mg_feats = sorted(
    enumerate(zn_mg_attrs),
    key=lambda pair: pair[1],
    reverse=True,
)[:10]
# Bit information recorded for Zn at each of those positions.
zn_feats_recovered = [bit_info['ZN'][position] for position, _ in zn_mg_feats]

In [32]:
# Expand the top Zn positions from the Zn/Mg comparison into one row per shell.
# No out_csv is given, so the table is only displayed, not written to disk.
features_to_csv(zn_mg_feats, 
                zn_feats_recovered)

Unnamed: 0,index,attribution,bit,level,radius,central_atoms,num_interactions
0,1369,0.008236,480707929,11,2.75,prot/0/C/HOH/102/O,0
1,1369,0.008236,480707929,11,2.75,prot/0/C/HOH/103/O,0
2,1369,0.008236,480707929,11,2.75,prot/0/B/HOH/710/O,0
3,1369,0.008236,480707929,11,2.75,prot/0/B/HOH/766/O,0
4,1369,0.008236,480707929,11,2.75,prot/0/B/HOH/780/O,0
5,1369,0.008236,480707929,11,2.75,prot/0/B/HOH/789/O,0
6,1369,0.008236,480707929,11,2.75,prot/0/B/HOH/804/O,0
7,1369,0.008236,480707929,11,2.75,prot/0/B/HOH/827/O,0
8,1885,0.007072,1480615773,12,3.0,prot/0/C/HOH/102/O,0
9,1885,0.007072,1480615773,12,3.0,prot/0/C/HOH/103/O,0


In [33]:
# ca -> zn
# Zn (fps[0]) is passed as fingerprint A and Ca (fps[2]) as fingerprint B.
attrs = pairwise_attribution(fps[0], fps[2], mic_trained.model)
# attrs[0] is the single (attr_A, attr_B) pair, the next [0] selects attr_A (Zn),
# and the last [0] takes the only row of its batch axis.
zn_ca_attrs = attrs[0][0][0]
# Ten (position, attribution) pairs with the largest attribution, highest first.
zn_ca_feats = sorted(
    enumerate(zn_ca_attrs),
    key=lambda pair: pair[1],
    reverse=True,
)[:10]
# Bit information recorded for Zn at each of those positions.
zn_feats_recovered = [bit_info['ZN'][position] for position, _ in zn_ca_feats]

In [34]:
# Expand the top Zn positions from the Zn/Ca comparison into one row per shell.
# No out_csv is given, so the table is only displayed, not written to disk.
features_to_csv(zn_ca_feats, 
                zn_feats_recovered)

Unnamed: 0,index,attribution,bit,level,radius,central_atoms,num_interactions
0,1369,0.011139,480707929,11,2.75,prot/0/C/HOH/102/O,0
1,1369,0.011139,480707929,11,2.75,prot/0/C/HOH/103/O,0
2,1369,0.011139,480707929,11,2.75,prot/0/B/HOH/710/O,0
3,1369,0.011139,480707929,11,2.75,prot/0/B/HOH/766/O,0
4,1369,0.011139,480707929,11,2.75,prot/0/B/HOH/780/O,0
5,1369,0.011139,480707929,11,2.75,prot/0/B/HOH/789/O,0
6,1369,0.011139,480707929,11,2.75,prot/0/B/HOH/804/O,0
7,1369,0.011139,480707929,11,2.75,prot/0/B/HOH/827/O,0
8,1885,0.009902,1480615773,12,3.0,prot/0/C/HOH/102/O,0
9,1885,0.009902,1480615773,12,3.0,prot/0/C/HOH/103/O,0
