In [None]:
import pathlib
import os
from time import sleep
import warnings
import MDAnalysis as mda
from MDAnalysis.analysis.dihedrals import Dihedral
import multiprocessing as mp
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from tqdm.auto import tqdm
from functools import lru_cache
warnings.filterwarnings("ignore", module="MDAnalysis")

## Detect pocket residues

In [None]:
def make_meshgrid(pocket, radius=2.5, step=1):
    """Create a grid of points arranged in spheres around the given points.

    A regular grid with spacing ``step`` is laid over the bounding box of
    ``pocket`` (padded by ``radius + step`` on every side) and only the grid
    points lying within ``radius`` of at least one pocket point are kept.

    Parameters
    ----------
    pocket : numpy.ndarray
        ``(N, 3)`` coordinates of the points around which the grid points
        will be generated.
    radius : float
        Radius of each sphere.
    step : float
        Resolution (spacing) of the grid.

    Returns
    -------
    numpy.ndarray
        ``(M, 3)`` array of the retained grid points. When no grid point
        falls inside any sphere the result still has shape ``(0, 3)``.
    """
    pocket = np.asarray(pocket)
    # One (min, max) row per axis, padded so the outermost spheres fit.
    bounds = np.column_stack((pocket.min(axis=0), pocket.max(axis=0)))
    if not np.issubdtype(bounds.dtype, np.floating):
        # Integer input would reject the in-place float padding below.
        bounds = bounds.astype(float)
    bounds[:, 0] -= radius + step
    bounds[:, 1] += radius + step
    axes = [np.arange(begin, end, step) for begin, end in bounds]
    x, y, z = np.meshgrid(*axes)
    # Keep a grid point if it lies inside the sphere of any pocket point.
    mask = np.zeros(x.shape, dtype=bool)
    for x0, y0, z0 in pocket:
        mask |= np.sqrt((x - x0)**2 + (y - y0)**2 + (z - z0)**2) <= radius
    return np.column_stack((x[mask], y[mask], z[mask]))

# Pocket atoms are the centres of the grid spheres.
pocket = mda.Universe("pocket.pdb").atoms.positions
xyz = make_meshgrid(pocket)

# Dump the grid as dummy carbon HETATM records (fixed-width PDB columns) so it
# can be inspected in a viewer, then read it back so that `grid` holds exactly
# the (PDB-rounded) coordinates written to the file.
grid_lines = [
    f"HETATM{serial!s:>5}  C01 UNK{1!s:>6}{gx:12.3f}{gy:8.3f}{gz:8.3f}{0:6.2f}{0:6.2f}{'C':>12s}\n"
    for serial, (gx, gy, gz) in enumerate(xyz, start=1)
]
with open("grid.pdb", "w") as fh:
    fh.writelines(grid_lines)
    fh.write("END\n")

grid = mda.Universe("grid.pdb").atoms.positions

In [None]:
def get_residues_near_pocket(u, pocket, cutoff=10):
    """Return the residues of ``u`` whose geometric centre is near the pocket.

    A residue is selected when the mean position of its atoms lies within
    ``cutoff`` of at least one point of ``pocket``. The residues are returned
    in their original order.
    """
    centres = np.array([residue.atoms.positions.mean(axis=0) for residue in u.residues])
    hits_per_point = cKDTree(centres).query_ball_point(pocket, cutoff)
    selected = sorted({idx for hits in hits_per_point for idx in hits})
    return u.residues[selected]

def get_angle(a, b, c):
    """Angle ABC in degrees, i.e. the angle between vectors BA and BC.

    Parameters
    ----------
    a, b, c : array_like
        Coordinates of the three points; ``b`` is the vertex.

    Returns
    -------
    float
        Angle in degrees in ``[0, 180]``. It is NaN when ``a`` or ``c``
        coincides with ``b``, because the direction is then undefined.
    """
    b = np.asarray(b)
    ba = np.asarray(a) - b
    bc = np.asarray(c) - b
    cosine_angle = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
    # Rounding can push the cosine of (anti)parallel vectors just outside
    # [-1, 1], where arccos returns NaN; clip it back into the domain.
    cosine_angle = np.clip(cosine_angle, -1.0, 1.0)
    return np.degrees(np.arccos(cosine_angle))

def points_towards_pocket(res, pocket, angle_cutoff=30, dist_cutoff=8):
    """Check whether a residue's side chain points towards the pocket.

    The side-chain direction is the vector from the CA atom to the centre of
    mass of the non-backbone atoms. The residue counts as pointing towards
    the pocket if, for at least one pocket point within ``dist_cutoff`` of
    that centre of mass, the angle (at CA) between CA->point and
    CA->side chain is at most ``angle_cutoff`` degrees. Glycine always
    returns ``False``.

    Parameters
    ----------
    res : MDAnalysis Residue
    pocket : numpy.ndarray
        ``(N, 3)`` pocket point coordinates.
    angle_cutoff : float
        Maximum angle in degrees.
    dist_cutoff : float
        Search radius around the side-chain centre of mass.

    Returns
    -------
    bool
    """
    # Glycine has no side chain: decide before doing any atom selection.
    if res.resname == "GLY":
        return False
    ca = res.atoms.select_atoms("backbone and name CA").positions[0]
    sidechain_com = res.atoms.select_atoms("not backbone").center_of_mass()
    nearby = cKDTree(pocket).query_ball_point(sidechain_com, dist_cutoff)
    return any(get_angle(p, ca, sidechain_com) <= angle_cutoff
               for p in pocket[nearby])

def get_pocket_residues(u, pocket):
    """Return the residues that are both close to and oriented towards the pocket.

    Candidates come from ``get_residues_near_pocket``; each candidate is kept
    only if ``points_towards_pocket`` accepts it.
    """
    candidates = get_residues_near_pocket(u, pocket)
    facing = [points_towards_pocket(candidate, pocket) for candidate in candidates]
    return candidates[facing]

## Calculate helicity

In [None]:
# read simplified TAS2R alignment file
# Sheet layout: row 0 = TM number, row 1 = BW position, later rows = one
# receptor each (column 0 = receptor name, other columns = aligned residues).
# NOTE(review): row 27 is dropped on purpose -- document why.
msa = pd.read_excel("msa.xlsx", header=None).drop(index=[27])

# Receptor names are rewritten to "hT2R<number>" (text after the last "R").
receptors = msa[0][2:].apply(lambda name: f"hT2R{name.split('R')[-1]}")
receptors.name = "TAS2R"
tm_labels = msa.loc[0][1:].rename("TM")
bw_labels = msa.loc[1][1:].rename("BW")

ali = msa.loc[2:, 1:]
ali.columns = pd.MultiIndex.from_arrays([tm_labels, bw_labels])
ali.index = receptors
ali

In [None]:
@lru_cache(maxsize=None)
def get_bw(receptor, resid):
    """Return the BW label ("<TM>.<BW>") of residue ``resid`` of ``receptor``.

    Residues are numbered 1..n over the non-gap ("-") columns of the
    receptor's row in the global alignment ``ali``. ``None`` is returned when
    the residue sits in an alignment column without a TM number. A residue
    number that is not in 1..n raises ``KeyError``.
    """
    row = ali.loc[receptor].copy()
    residues = row[row != "-"]
    # Replace the amino acids by their 1-based residue numbers ...
    residues[:] = np.arange(1, len(residues) + 1)
    # ... and make those numbers the index of a (TM, BW) lookup table.
    table = residues.reset_index().set_index(receptor)
    helix, position = table.loc[resid]
    return None if np.isnan(helix) else f"{helix:.0f}.{position:.0f}"

@lru_cache(maxsize=None)
def get_resid(receptor, tm, bw):
    """Return the residue number of ``receptor`` at BW position ``tm.bw``.

    Residues are numbered 1..n over the non-gap ("-") columns of the
    receptor's row in the global alignment ``ali``. When the receptor has no
    residue at that (TM, BW) column (a gap, or the column does not exist),
    the label itself is returned as the string ``f"{tm}.{bw}"``.
    """
    row = ali.loc[receptor]
    residues = row[row != "-"]
    columns = residues.index  # MultiIndex with levels "TM" and "BW"
    hit = np.flatnonzero((columns.get_level_values("TM") == tm)
                         & (columns.get_level_values("BW") == bw))
    # Positions among the non-gap columns are 0-based; residue numbers are 1-based.
    return hit[0] + 1 if hit.size else f"{tm}.{bw}"

In [None]:
# Helix boundaries taken from the templates: for each TM, the [first, last]
# BW position (BW notation) used when scoring helicity.
TM_TO_INDEX = {
    1: [31, 60],
    2: [38, 66],
    3: [20, 55],
    4: [39, 62],
    5: [38, 68],
    6: [27, 59],
    7: [31, 53],
}
# Inclusive interval of Ramachandran numbers counted as alpha helix.
alpha_helix = (0.32, 0.38)

In [None]:
def ramachandran_number(phi, psi, sigma=1e5, lambda_=360):
    """Map phi/psi angles (degrees, scalars or arrays) to a Ramachandran number.

    The (phi, psi) plane is rotated by 45 degrees, discretised with
    resolution ``sigma`` and flattened into a single integer, which is then
    normalised so that phi = psi = -180 maps to 0 and phi = psi = 180 maps
    to 1.
    """
    root2 = np.sqrt(2)
    span = np.around(root2 * lambda_ * sigma)
    lowest = np.around(lambda_ * sigma / root2)
    highest = lowest + span ** 2
    diag_minus = np.around((phi - psi + lambda_) * sigma / root2)
    diag_plus = np.around((phi + psi + lambda_) * sigma / root2)
    flattened = diag_minus + span * diag_plus
    return (flattened - lowest) / (highest - lowest)

def rama_scores(atomgroup):
    """Return the Ramachandran numbers of the residues in an MDAnalysis AtomGroup.

    Each residue is scored from its own (phi, psi) pair. The first and last
    residue of the selection are left out so that, for a contiguous residue
    selection, both dihedrals only involve neighbouring residues that are
    themselves part of the selection. Residues for which MDAnalysis returns
    no phi or psi selection (``None``) are skipped.

    Returns
    -------
    numpy.ndarray
        Flattened array of Ramachandran numbers (one per scored residue for
        a single-frame model).
    """
    inner = atomgroup.residues[1:-1]
    # Pair phi_i with psi_i of the SAME residue. Slicing the residue group
    # differently for phi and psi would pair phi_{i+1} with psi_i.
    pairs = [(phi_ag, psi_ag)
             for phi_ag, psi_ag in zip(inner.phi_selections(), inner.psi_selections())
             if phi_ag is not None and psi_ag is not None]
    phi_sel = [phi_ag for phi_ag, _ in pairs]
    psi_sel = [psi_ag for _, psi_ag in pairs]
    phi = Dihedral(phi_sel).run().angles
    psi = Dihedral(psi_sel).run().angles
    rs = ramachandran_number(phi, psi)
    return rs.flatten()

def helix_score(scores, threshold=alpha_helix):
    """Fraction of Ramachandran numbers that fall inside the alpha-helix window.

    ``threshold`` is an inclusive ``(low, high)`` interval; ``scores`` is an
    array of Ramachandran numbers.
    """
    low, high = threshold[0], threshold[1]
    inside = (scores >= low) & (scores <= high)
    return inside.mean()

def tm_score(u, receptor, tm):
    """Helicity score of transmembrane helix ``tm`` in model ``u``.

    The TM_TO_INDEX BW boundaries of the helix are converted to residue
    numbers of ``receptor``, and the result is ``helix_score`` of the
    Ramachandran numbers of that residue range.
    """
    first_bw, last_bw = TM_TO_INDEX[tm]
    first = get_resid(receptor, tm, first_bw)
    last = get_resid(receptor, tm, last_bw)
    selection = u.select_atoms(f"resid {first}-{last}")
    return helix_score(rama_scores(selection))

In [None]:
# Core part of each TM (first, last BW position) that must be structured as a
# perfect helix; used to discard models with a misfolded TM.
GOOD_TM = {
    1: [34, 58],
    2: [39, 62],
    3: [25, 55],
    4: [43, 60],
    5: [40, 66],
    6: [31, 55],
    7: [35, 53],
}

def window_func(x, center=.35, std=.07, threshold=3):
    """Return ``True`` if a window of Ramachandran numbers looks helical.

    Counts the values lying within ``std`` of ``center`` (by default the
    middle of the alpha-helix range) and returns ``True`` when that count is
    strictly greater than ``threshold``. A ``False`` result therefore flags a
    locally misfolded stretch of helix.

    ``x`` may be a pandas Series or any array-like, so the function also works
    with ``rolling(...).apply(..., raw=True)``.
    """
    return (np.abs(np.asarray(x) - center) <= std).sum() > threshold

def is_good_helix(rscores, window=6, skip=(.08, -.08), **kwargs):
    """Check whether a helix, given its Ramachandran numbers, is well folded.

    The ends of the helix are trimmed according to ``skip`` = ``(begin, end)``.
    Float entries are fractions of ``len(rscores)``, int entries are residue
    counts (slice bounds). A rolling window of size ``window`` then runs over
    the remaining values, and every full window must pass ``window_func``
    (extra keyword arguments are forwarded to it).

    Returns
    -------
    bool
        ``True`` if every full window is helical (vacuously ``True`` when no
        full window fits).
    """
    n = len(rscores)
    # Resolve fractions into local slice bounds; never mutate `skip`, which
    # may be the shared default or a list owned by the caller.
    begin, end = (round(val * n) if isinstance(val, float) else val for val in skip)
    # A trailing trim that rounds to 0 means "keep up to the end", not an
    # empty slice.
    if end == 0:
        end = None
    flags = (pd.Series(rscores[begin:end])
             .rolling(window)
             .apply(lambda r: window_func(r, **kwargs))
             .dropna())
    return flags.all()

def keep_tm_score(u, rec, tm, **kwargs):
    """Return whether TM ``tm`` of model ``u`` is a well-folded helix.

    The GOOD_TM span of the TM is converted to residue numbers of receptor
    ``rec``; the Ramachandran numbers of that range are judged by
    ``is_good_helix`` (extra keyword arguments are forwarded to it).
    """
    first, last = (get_resid(rec, tm, bw) for bw in GOOD_TM[tm])
    segment = u.select_atoms(f"resid {first}-{last}")
    return is_good_helix(rama_scores(segment), **kwargs)

## Run Analysis

In [None]:
# folders: one row per aligned_models_<protocol>/<receptor>/<template> directory.
# Path.parts is used instead of splitting on "/" so this also works on Windows;
# sorting makes the row order independent of the filesystem.
dirs = sorted(d.parts for d in pathlib.Path(".").glob("aligned_models_*/*/*") if d.is_dir())
dd = pd.DataFrame(dirs, columns=["Protocol", "Receptor", "Template"])
dd["Protocol"] = dd["Protocol"].str.replace("aligned_models_", "", regex=False)
dd = dd.loc[~dd["Protocol"].isin(["beta2", "beta2F"])].reset_index(drop=True)
dd

In [None]:
# score models based on helicity and pocket residues
warnings.filterwarnings("ignore", message="Reader has no dt information")

def job(args):
    """Score one PDB model.

    ``args`` is ``(protocol, receptor, template, pdb_path)``. The returned
    dict holds the identifiers, ``True`` for every BW position whose residue
    faces the pocket grid, ``helix_TM<n>`` for each TM in TM_TO_INDEX, their
    mean as ``Helicity``, and ``keep_TM<n>`` for each TM in GOOD_TM.
    """
    proto, rec, template, pdb = args
    u = mda.Universe(pdb)
    item = dict(Protocol=proto, Receptor=rec, Template=template, Model=pdb.name)

    # Flag the BW positions of residues facing the pocket (None = outside a TM).
    pocket_bws = (get_bw(rec, resid) for resid in get_pocket_residues(u, grid).resids)
    item.update((bw, True) for bw in pocket_bws if bw)

    helix = [tm_score(u, rec, tm) for tm in TM_TO_INDEX]
    item.update((f"helix_TM{tm}", score) for tm, score in zip(TM_TO_INDEX, helix))
    item["Helicity"] = np.array(helix).mean()

    item.update((f"keep_TM{tm}", keep_tm_score(u, rec, tm)) for tm in GOOD_TM)
    return item

# Run `job` on every model of every experiment. The inner bar tracks the
# models of the current experiment and is reset for each experiment.
records = []
with mp.Pool(20) as pool, tqdm(desc="PDB models", total=1000, position=1, leave=True) as model_bar:
    experiments = tqdm(dd.iterrows(), total=len(dd), desc="Experiments", position=0)
    for _, (proto, rec, template) in experiments:
        folder = pathlib.Path(f"aligned_models_{proto}/{rec}/{template}")
        pdbs = sorted(folder.glob("*.pdb"))
        model_bar.reset(total=len(pdbs))
        tasks = ((proto, rec, template, pdb) for pdb in pdbs)
        for record in pool.imap_unordered(job, tasks):
            model_bar.update()
            records.append(record)
df = pd.DataFrame(records)
del records
# BW positions a model did not flag are missing -> False.
df = df.fillna(False)
index = ["Protocol", "Receptor", "Template", "Model"]
df = df.set_index(index).sort_index(axis=1)
df.to_pickle("all_scores_models_raw.pkl")
df

In [None]:
# add modeller scores
# No MODELLER score file is read for the gpcrDB and bitterDB protocols.
frames = []
for _, (proto, rec, template) in tqdm(dd.iterrows(), total=len(dd)):
    if proto in ["gpcrDB", "bitterDB"]:
        continue
    if proto == "Gomodo":
        # Gomodo template folders are looked up in lower case.
        template = template.lower()
    pattern = f"models_{proto}/*/{rec}/{template}/modeller_scores.out"
    path = next(pathlib.Path("..").glob(pattern), None)
    if path is None:
        raise FileNotFoundError(f"no MODELLER score file matches ../{pattern}")
    temp = pd.read_csv(path, sep="\t", header=0)[["Filename", "DOPE score"]]
    temp = temp.rename(columns={"Filename": "Model"})
    temp["Protocol"] = proto
    temp["Receptor"] = rec
    temp["Template"] = template.upper()
    frames.append(temp.set_index(["Protocol", "Receptor", "Template", "Model"]))

# Concatenate once: DataFrame.append was removed in pandas 2.0 and appending in
# a loop copies the growing frame every iteration.
modeller = pd.concat(frames) if frames else pd.DataFrame()
modeller

In [None]:
# merge: attach the DOPE score to every model row. The left join keeps models
# that have no MODELLER score (their DOPE score stays missing).
df2 = df.join(other=modeller, how="left")
df2.to_pickle("all_scores_models_raw_with_dope.pkl")
df2

In [None]:
# filter models

df2 = pd.read_pickle("all_scores_models_raw_with_dope.pkl")
n_rec = df2.index.get_level_values("Receptor").nunique()

# discard "refined loops" models (".BL" in the model name). regex=False makes
# the "." match literally instead of meaning "any character".
is_refined_loop = df2.index.get_level_values("Model").str.contains(".BL", regex=False)
df2 = df2.loc[~is_refined_loop]

# remove models with bad overall helicity or with any misfolded TM
HELICITY_CUTOFF = .789
keep_cols = [f"keep_TM{i}" for i in range(1, 8)]
df2 = df2.loc[(df2["Helicity"] >= HELICITY_CUTOFF) & df2[keep_cols].all(axis=1)]
df2 = df2.drop(columns=keep_cols)

# sort rows by (Protocol, Receptor, Template, Model)
df2 = df2.sort_index()
df2.to_pickle("all_scores_models_filtered.pkl")
df2

In [None]:
# BW positions whose residues should face the binding pocket in the TAS2R
# models according to mutagenesis data. "common" lists the positions shared by
# all receptors; the per-receptor lists add receptor-specific positions.
bws_t2r = {
    "common": [
        "3.29", "3.33", "3.34", "3.38",
        "5.46",
        "6.44", "6.47", "6.48",
        "7.35", "7.39", "7.42", "7.43",
    ],
    "hT2R1": ["2.61", "3.38"],
    "hT2R4": ["6.55"],
    "hT2R7": ["2.60", "3.29", "3.34", "5.38", "7.32"],
    "hT2R9": ["5.47"],
    "hT2R10": ["3.30", "5.40", "5.43", "7.42"],
    "hT2R14": [
        "3.33", "3.34",
        "5.46", "5.47",
        "6.50",
        "7.36", "7.39", "7.42",
    ],
    "hT2R16": [
        "2.53",
        "3.30", "3.34", "3.35", "3.37", "3.38",
        "5.43", "5.46", "5.47",
        "6.48",
        "7.34", "7.35", "7.38", "7.39",
    ],
    "hT2R20": ["2.60", "3.29", "7.39"],
    "hT2R31": ["7.39", "7.42"],
    "hT2R38": ["3.33", "5.46", "6.52"],
    "hT2R46": [
        "2.61", "2.66",
        "3.33",
        "5.39",
        "6.47",
        "7.35", "7.39", "7.42", "7.43",
    ],
}

In [None]:
# calculate pocket score

df3 = pd.read_pickle("all_scores_models_filtered.pkl")

# BW-position columns are recognised by the dot in their label (e.g. "3.29").
bw_cols = {col for col in df3.columns if "." in col}
tms = {int(col.split(".")[0]) for col in bw_cols}
other_cols = df3.columns.drop(bw_cols).tolist()
# Every BW position listed in bws_t2r, deduplicated and sorted.
final_cols = sorted({bw for positions in bws_t2r.values() for bw in positions})

def set_pocket_scores(args):
    """Add pocket scores and the overall score to one model row.

    Parameters
    ----------
    args : tuple
        ``(index, row)`` pair as produced by ``DataFrame.iterrows()``. The
        row's name must be ``(Protocol, Receptor, Template, Model)``.

    Returns
    -------
    pandas.Series
        A copy of the row extended with ``pocket_TM<n>`` (fraction of the
        expected BW positions of TM n flagged as pocket-facing), ``Pocket``
        (fraction over all expected positions) and ``Score`` (mean of
        ``Helicity`` and ``Pocket``). Expected positions are the receptor's
        entry in ``bws_t2r`` plus ``bws_t2r["common"]``.
    """
    _, row = args
    _, rec, _, _ = row.name
    row = row.copy()  # leave the caller's Series untouched
    expected = sorted(set(bws_t2r.get(rec, []) + bws_t2r["common"]))
    # A BW column only exists if some model had that position in the pocket,
    # so a missing column simply means "not in the pocket".
    in_pocket = row.reindex(expected, fill_value=False).astype(int)
    helices = sorted({int(bw.split(".")[0]) for bw in expected})
    for tm in helices:
        tm_bws = [bw for bw in expected if int(bw.split(".")[0]) == tm]
        row[f"pocket_TM{tm}"] = in_pocket[tm_bws].mean()
    row["Pocket"] = in_pocket.mean()
    row["Score"] = row[["Helicity", "Pocket"]].mean()
    return row

s = []
with mp.Pool(24) as pool:
    for row in tqdm(pool.imap_unordered(set_pocket_scores, df3.iterrows()),
                    total=len(df3)):
        s.append(row)
df3 = pd.DataFrame(s, index=pd.MultiIndex.from_tuples([x.name for x in s],
                                                      names=["Protocol", "Receptor", "Template", "Model"]))
# imap_unordered yields rows in completion order; sort so the saved table is
# reproducible from run to run.
df3 = df3.sort_index()
# A BW column only exists if at least one model had that position in the
# pocket; add any missing expected position as "never in the pocket".
df3 = df3.assign(**{bw: False for bw in final_cols if bw not in df3.columns})
# TMs that carry at least one expected pocket position (derived, not hard-coded).
pocket_tms = sorted({int(bw.split(".")[0]) for bw in final_cols})
df3 = df3[["Score", "Helicity", "Pocket", "DOPE score"] +
          [f"helix_TM{i}" for i in range(1, 8)] +
          [f"pocket_TM{i}" for i in pocket_tms] +
          final_cols]
df3.to_pickle("all_scores_models_final.pkl")
df3