### Summary

Calculate CCMpred mutation scores

In [47]:
import concurrent.futures
import itertools
import os
import re
import socket
import subprocess
import sys
import tempfile
from pathlib import Path
import math

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm.notebook import tqdm

#### Parameters

In [2]:
# Working directory for this notebook. The relative name is resolved against the
# kernel's current working directory, and the folder is created on first use.
NOTEBOOK_DIR = Path("31_run_ccmpred").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

PosixPath('/scratch/mjslee/notebooks/run_ccmpred')

In [3]:
# When SLURM provides a temp directory, export it as TMPDIR so that the
# `tempfile` module picks it up as the default location for temporary files.
if (slurm_tmpdir := os.getenv("SLURM_TMPDIR")) is not None:
    os.environ["TMPDIR"] = slurm_tmpdir

# Show the directory that temporary files will actually be written to.
print(tempfile.gettempdir())

/localscratch/mjslee.932477.0


In [4]:
# Detected cores: a fixed 40 on hosts whose name contains "scinet", otherwise
# the number of CPUs in this process's affinity mask (at least one).
CPU_COUNT = 40 if "scinet" in socket.gethostname() else max(1, len(os.sched_getaffinity(0)))

# Only half of the detected cores are used, but never fewer than one.
CPU_COUNT = max(1, CPU_COUNT // 2)

CPU_COUNT

4

In [10]:
# Job settings are handed in through environment variables; unset ones stay None.
DATASET_NAME = os.environ.get("DATASET_NAME")
DATASET_PATH = os.environ.get("DATASET_PATH")

# Position of this job within the SLURM array.
TASK_ID = os.environ.get("SLURM_ARRAY_TASK_ID")
TASK_ID = None if TASK_ID is None else int(TASK_ID)

# Size of the array; ORIGINAL_ARRAY_TASK_COUNT takes precedence when it is set and non-empty.
TASK_COUNT = os.environ.get("ORIGINAL_ARRAY_TASK_COUNT") or os.environ.get("SLURM_ARRAY_TASK_COUNT")
TASK_COUNT = None if TASK_COUNT is None else int(TASK_COUNT)

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT

(None, None, None, None)

In [11]:
# Without a SLURM array task id the notebook runs in debug mode.
DEBUG = TASK_ID is None

if not DEBUG:
    # Batch run: every setting has to be supplied through the environment.
    assert DATASET_NAME is not None
    assert DATASET_PATH is not None
    DATASET_PATH = Path(DATASET_PATH).expanduser().resolve()
    assert TASK_COUNT is not None
else:
    # Debug run: fall back to a fixed example dataset and task.
    DATASET_NAME = "cagi6-sherloc"
    DATASET_PATH = str(NOTEBOOK_DIR.parent / "31_run_ccmpred" / "input-data-gby-protein.parquet")
    TASK_ID = 500
    TASK_COUNT = 4182

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT

('cagi6-sherloc',
 '/scratch/mjslee/notebooks/run_ccmpred/input-data-gby-protein.parquet',
 500,
 4182)

In [12]:
# Per-task result file: <NOTEBOOK_DIR>/<DATASET_NAME>/result-<TASK_ID>-of-<TASK_COUNT>.parquet
output_file = NOTEBOOK_DIR.joinpath(DATASET_NAME, f"result-{TASK_ID}-of-{TASK_COUNT}.parquet")
# Create the dataset sub-directory if it does not exist yet.
output_file.parent.mkdir(exist_ok=True)

output_file

PosixPath('/scratch/mjslee/notebooks/run_ccmpred/cagi6-sherloc/result-500-of-4182.parquet')

In [None]:
# Stop here when this task's result file already exists, so that the
# remaining cells are not executed again for a finished task.
if output_file.is_file():
    raise Exception("Already finished!")

#### Load Data

In [14]:
# Open the input Parquet file; its row groups are the units distributed across array tasks.
pfile = pq.ParquetFile(DATASET_PATH)

pfile.num_row_groups

4182

In [15]:
# Number of row groups per array task, rounded up so that TASK_COUNT tasks
# cover every row group in the file.
rows_per_chunk = np.ceil(pfile.num_row_groups / TASK_COUNT).astype(int)

rows_per_chunk

1

In [16]:
# Half-open range [start, stop) of row-group indices assigned to this task.
# NOTE(review): this assumes TASK_ID is 1-based — confirm against the job array spec.
start = (TASK_ID - 1) * rows_per_chunk
# Row-group indices run from 0 to num_row_groups - 1, so the exclusive upper
# bound must not exceed num_row_groups.
stop = min([pfile.num_row_groups, TASK_ID * rows_per_chunk])

start, stop

(499, 500)

In [17]:
# Read every row group assigned to this task, not only the first one.
# The upper bound is clamped so that no index past the last row group is requested.
row_group_ids = list(range(int(start), int(min(stop, pfile.num_row_groups))))
input_df = pfile.read_row_groups(row_group_ids).to_pandas()

In [18]:
# Display the rows loaded for this task.
input_df

Unnamed: 0,protein_id,mutation_id,mutation,effect,sequence,structure,alignment
0,Q9NVV0,"[NM_018112.2:c.748C>T, NM_018112.2:c.799G>A, N...","[P250S, V267I, G268E, K138E, H197D, I221T, A27...","[Likely benign, Benign, Likely benign, None, N...",MDSPWDELALAFSRTSMFPFFDIAHYLVSVMAVKRQPGAAALAWKN...,HEADER ...,"[>101\n, MDSPWDELALAFSRTSMFPFFDIAHYLVSVMAVKRQP..."


#### Calculate CCMpred scores

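The raw matrix file written by CCMpred contains one L×20 matrix (one row of 20 values per alignment column) followed by $\binom{L}{2}$ matrices of size 21×21, one per pair of alignment columns, corresponding to the amino acid probabilities for each pairwise residue contact in the alignment.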

CCMpred learns a generative Markov Random Field model using vertices with single-residue emission potentials $\varepsilon_i(a)$ and edges with pairwise emission potentials $\varepsilon_{i,j}(a,b)$

arr1 = $\varepsilon_i(a)$ where $i$ = MSA column index and $a$ = amino acid index <br>
arr2 = $\varepsilon_{i,j}(a,b)$ where $i,j$ = MSA column indices and $a,b$ = amino acid indices 

##### Find amino acid index positions of CCMpred

In [19]:
# # CCMpred run to see which indices correspond to which amino acid
# testFile = str(Path(aln_path).parent.joinpath('AAtest.aln'))
# outFile = str(Path(aln_path).parent.joinpath('AAtest.mat'))
# rawFile = str(Path(aln_path).parent.joinpath('AAtest.raw'))

# aa_list = 'ARNDCEQGHILKMFPSTWYV'

# with open(testFile,'w') as f:
#     f.write(aa_list)

# bashCommand = [ccmpred_path,'-r',rawFile,testFile,outFile]
# process = subprocess.run(bashCommand,capture_output=True)

# arr1, _, _ = parse_raw_prob(rawFile)

# aa_true_idx = [np.argmax(i) for i in arr1]
# aa_true_list = ''.join([aa_list[i] for i in aa_true_idx])

# # add gap character to front
# aa_true_list = '-' + aa_true_list

# aa_true_list

##### Average log probabilities of all pair-wise mutations

In [20]:
# Name of the column holding the protein identifier. When more than one
# candidate is present in the frame, the last one listed is used.
matching_columns = [name for name in ["protein_id", "uniprot_id"] if name in input_df]
protein_id_column = matching_columns[-1] if matching_columns else None

assert protein_id_column is not None
protein_id_column

'protein_id'

In [21]:
# Use the first row as a template to find the per-mutation columns, i.e. the
# columns whose value has the same length as the row's list of mutations.
tup = next(input_df.itertuples(index=False))

iterable_fields = []
for field in tup._fields:
    if field == protein_id_column:
        continue
    try:
        same_length = len(getattr(tup, field)) == len(tup.mutation)
    except TypeError:
        # Values without a length cannot be per-mutation columns.
        continue
    if same_length:
        iterable_fields.append(field)

iterable_fields

['mutation_id', 'mutation', 'effect']

In [22]:
# Run CCMpred on one alignment.
# The alignment, output matrix, and raw matrix files are written to the temp folder.
# Warning: the file names depend only on DATASET_NAME, TASK_ID and TASK_COUNT
# ('{outDir}/{DATASET_NAME}_ccmpred_task_{TASK_ID}_of_{TASK_COUNT}.aln'), so every call
# within one task reuses the same three files.

def run_ccmpred(alignment, NOTEBOOK_DIR, DATASET_NAME, TASK_ID, TASK_COUNT):
    """Write `alignment` to disk, run CCMpred on it and return the raw-matrix path.

    Parameters
    ----------
    alignment : iterable of str
        Lines of the alignment. Empty lines and header lines (starting with '>')
        are dropped; lower-case characters are removed from the remaining lines.
    NOTEBOOK_DIR : pathlib.Path
        Directory that contains the `ccmpred/bin/ccmpred` executable.
    DATASET_NAME, TASK_ID, TASK_COUNT
        Used only to build the names of the files in the temp folder.

    Returns
    -------
    str
        Path of the raw-matrix file passed to CCMpred via `-r`. The file is not
        guaranteed to exist: callers must check before reading it.
    """

    def write_to_aln(alignment, outPath, outFile, TASK_ID, TASK_COUNT):
        """Write the sequence lines of `alignment` to an .aln file and return its path."""
        outDir = Path(outPath).resolve()
        outDir.mkdir(exist_ok=True)
        outFile = '{}/{}_ccmpred_task_{}_of_{}.aln'.format(outDir, outFile, TASK_ID, TASK_COUNT)
        with open(outFile, 'w') as fout:
            for line in alignment:
                if line == '' or line[0] == '>':
                    continue
                # remove insertions (lower-case letters in .a3m format)
                fout.write(''.join(x for x in line if not x.islower()))
        return outFile

    ccmpred_path = str(NOTEBOOK_DIR.joinpath('ccmpred/bin/ccmpred'))
    aln_path = Path(write_to_aln(alignment, tempfile.gettempdir(), DATASET_NAME, TASK_ID, TASK_COUNT))
    outFile = str(aln_path.with_suffix('.mat'))
    rawFile = str(aln_path.with_suffix('.raw'))

    # The file names are reused between calls. Remove leftovers from an earlier
    # run so that a failed CCMpred call cannot be mistaken for a successful one
    # by a caller that only checks whether the raw file exists.
    for stale in (outFile, rawFile):
        Path(stale).unlink(missing_ok=True)

    # With raw probability matrix
    bashCommand = [ccmpred_path, '-r', rawFile, str(aln_path), outFile]
    process = subprocess.run(bashCommand, capture_output=True)

    # Surface failures instead of discarding them; the return value is unchanged.
    if process.returncode != 0:
        stderr_text = process.stderr.decode(errors='replace').strip()
        print(f"ccmpred exited with code {process.returncode}: {stderr_text}")

    return rawFile

In [23]:
# Create a lookup of pairwise amino acid matrices from a CCMpred raw file
# Output: (d, single) where d[i][j] is the 21x21 matrix parsed for MSA positions i, j
# (keys are the position indices as strings) and single holds the rows of 20 values

def get_aa_prob_matrix(rawFile):
    """Parse a CCMpred raw matrix file.

    Parameters
    ----------
    rawFile : str or path-like
        Tab-separated file with rows of 20 values, followed by blocks that each
        start with a header line of the form '# i j' and contain rows of 21 values.

    Returns
    -------
    d : dict[str, dict[str, numpy.ndarray]]
        `d[i][j]` and `d[j][i]` both refer to the matrix parsed from block '# i j'.
        The same array is stored under both keys; it is not transposed.
    single_prob_matrix : numpy.ndarray
        Array built from the rows of 20 values, in file order.
    """

    def parse_raw_prob(rawFile):
        """Split the raw file into single-value rows, pairwise blocks and block headers."""
        arr1 = []     # rows with 20 fields
        arr2 = []     # one list of 21-field rows per block
        arr3 = []     # block header lines
        tmp_arr = []  # rows of the block currently being read

        # Stream the file line by line instead of loading it in full.
        with open(rawFile, 'r') as fin:
            for line in fin:
                line_split = line.strip().split('\t')
                length = len(line_split)
                if length == 1:
                    # A header line closes the previous block (if any) and opens a new one.
                    if tmp_arr:
                        arr2.append(tmp_arr)
                    arr3.append(line_split)
                    tmp_arr = []
                elif length == 20:
                    arr1.append(line_split)
                elif length == 21:
                    tmp_arr.append(line_split)

        # add last arr
        arr2.append(tmp_arr)

        # np.float_ was removed in NumPy 2.0; np.asarray with an explicit float64
        # dtype performs the same string-to-float conversion on all versions.
        return np.asarray(arr1, dtype=np.float64), np.asarray(arr2, dtype=np.float64), arr3

    # TODO: very inefficient way to do this
    # probably doesn't scale well with increased L

    single_prob_matrix, pairwise_prob_matrix, aa_pair_id = parse_raw_prob(rawFile)

    # '# i j' -> ['i', 'j']
    aa_pair = [header[0].split(' ')[1:] for header in aa_pair_id]

    # One (initially empty) inner dict per distinct position index seen in the headers.
    d = {str(i): {} for i in range(len(np.unique(aa_pair)))}

    for idx, aa_idx in enumerate(aa_pair):
        d[aa_idx[0]][aa_idx[1]] = pairwise_prob_matrix[idx]
        d[aa_idx[1]][aa_idx[0]] = pairwise_prob_matrix[idx]

    return d, single_prob_matrix

In [24]:
# Score a mutation from the element-wise mean of the pairwise matrices at its position.
# NOTE: a later cell defines another function named get_mean_mut_score with a different
# signature; once that cell has run, this definition is no longer reachable by name.

def get_mean_mut_score(mutation, d, log=False):
    """Return the [wildtype][mutant] entry of the averaged pairwise matrix.

    Parameters
    ----------
    mutation : str
        Mutation such as 'P250S': wildtype residue, 1-based position, mutant residue.
    d : dict
        Lookup keyed by 0-based position (as a string); `d[pos]` maps to the
        pairwise matrices of that position.
    log : bool
        If True, log10 is applied to each matrix before averaging.
    """
    # Obtained by running CCMpred on sample sequence of all amino acids + gap and observing arr1 in get_aa_prob_matrix
    alphabet = '-ARNDCQEGHILKMFPSTWYV'

    # Convert the 1-based position of the mutation into the 0-based string key used by d.
    pos = str(int(mutation[1:-1]) - 1)
    wt_idx = alphabet.index(mutation[0])
    mut_idx = alphabet.index(mutation[-1])

    matrices = list(d[pos].values())

    # log doesn't really work due to negative and zero values
    # should investigate what the ccmpred raw matrix values actually are
    if log:
        matrices = [np.log10(matrix) for matrix in matrices]

    mean_matrix = np.mean(matrices, axis=0)

    return mean_matrix[wt_idx][mut_idx]

In [50]:
def get_mean_mut_score(mutation, d, s):
    """Return (wildtype_score, mutant_score) for one mutation.

    Each score is exp(single + pairwise), where `single` is the entry of `s` for
    the residue at the mutated position and `pairwise` is the sum, over all
    matrices stored in `d` for that position, of the diagonal entry [idx][idx].

    Parameters
    ----------
    mutation : str
        Mutation such as 'P250S': wildtype residue, 1-based position, mutant residue.
    d : dict
        Lookup keyed by 0-based position (as a string); `d[pos]` maps to the
        21x21 matrices of that position.
    s : array-like
        Single-residue values indexed as s[position][residue], without a gap column.
    """
    # Obtained by running CCMpred on sample sequence of all amino acids + gap and observing arr1 in get_aa_prob_matrix
    # NOTE(review): the gap is assumed to be the first symbol of the 21-letter alphabet — confirm against CCMpred.
    alphabet = '-ARNDCQEGHILKMFPSTWYV'

    # 0-based position, kept as a string because that is how d is keyed.
    pos = str(int(mutation[1:-1]) - 1)
    wt_idx = alphabet.index(mutation[0])
    mut_idx = alphabet.index(mutation[-1])

    def calculate_score(idx, pos, d, s):
        # Note: gap character missing in single prob matrix, hence the shift by one.
        single_term = s[int(pos)][idx - 1]

        # d[pos] contains 21x21 matrices for i in L where i != pos
        # \sum_{j=1,i \neq j}^L \epsilon_{i,j} (x_i^n,x_j^n)
        # NOTE(review): the same residue index is used for both axes of every matrix — confirm this is intended.
        pair_term = 0
        for matrix in d[pos].values():
            pair_term += matrix[idx][idx]

        return math.exp(single_term + pair_term)

    return calculate_score(wt_idx, pos, d, s), calculate_score(mut_idx, pos, d, s)

In [40]:
def get_single_aa_score(mutation, s):
    """Return the single-residue values of the wildtype and mutant residues.

    Parameters
    ----------
    mutation : str
        Mutation such as 'P250S': wildtype residue, 1-based position, mutant residue.
    s : array-like
        Single-residue values indexed as s[position][residue].

    Returns
    -------
    tuple
        (value for the wildtype residue, value for the mutant residue) at the
        mutated position.
    """
    # Note: gap character missing in single prob matrix
    alphabet = 'ARNDCQEGHILKMFPSTWYV'

    # The mutation string carries a 1-based position; s is indexed from zero.
    pos = int(mutation[1:-1]) - 1
    wt_idx, mut_idx = (alphabet.index(residue) for residue in (mutation[0], mutation[-1]))

    row = s[pos]
    return row[wt_idx], row[mut_idx]

In [41]:
def validate_mutation(mutation):
    """Return True when `mutation` is a usable substitution, otherwise print why not.

    A usable mutation consists of a residue letter, a position that does not
    start with zero, and a residue letter that differs from the first one.
    """
    aa = "GVALICMFWPDESTYQNKRH"

    well_formed = re.search(f"^[{aa}][1-9]+[0-9]*[{aa}]$", mutation) is not None
    if not well_formed:
        print(f"Skipping mutation {mutation} because it appears to be malformed.")
        return False

    wildtype, mutant = mutation[0], mutation[-1]
    if wildtype == mutant:
        print(
            f"Skipping mutation {mutation} because the wildtype and mutant residues are the same."
        )
        return False

    return True

In [51]:
# Score every valid mutation of every protein row assigned to this task.
alphabet = '-ARNDCQEGHILKMFPSTWYV'
results = []
for tup in input_df.itertuples(index=False):

    # Every per-mutation column must line up with the list of mutations.
    assert all([(len(getattr(tup, field)) == len(tup.mutation)) for field in iterable_fields])

    rawFile = run_ccmpred(tup.alignment, NOTEBOOK_DIR, DATASET_NAME, TASK_ID, TASK_COUNT)

    # Without a raw matrix file there is nothing to score; stop processing further rows.
    if not Path(rawFile).exists():
        print("ccmpred file not found - check memory issues")
        break

    d, s = get_aa_prob_matrix(rawFile)

    # The identifier column is detected at runtime ("protein_id" or "uniprot_id"),
    # so read it by its detected name instead of assuming "protein_id".
    protein_id = getattr(tup, protein_id_column)

    for mutation_idx, mutation in enumerate(tup.mutation):
        if not validate_mutation(mutation):
            continue

        wt_p_score, mut_p_score = get_mean_mut_score(mutation, d, s)
        wt_s_score, mut_s_score = get_single_aa_score(mutation, s)

        results.append(
            {
                "protein_id": protein_id,
                "mutation": mutation,
                "single_prob_wt": wt_s_score,
                "single_prob_mut": mut_s_score,
                "pairwise_prob_wt": wt_p_score,
                "pairwise_prob_mut": mut_p_score
            }
            # | {field: getattr(tup, field)[mutation_idx] for field in iterable_fields}
            # | result
        )

# One row per scored mutation.
results_df = pd.DataFrame(results)

In [52]:
# Preview the first two result rows and report the total number of scored mutations.
display(results_df.head(2))
print(len(results_df))

Unnamed: 0,protein_id,mutation,single_prob_wt,single_prob_mut,pairwise_prob_wt,pairwise_prob_mut
0,Q9NVV0,P250S,-1.79462,-4.63096,0.164664,0.008814
1,Q9NVV0,V267I,-3.44336,-4.48053,0.419924,0.012767


8


In [104]:
# Persist the results only for real (non-debug) runs that produced at least one row.
if not DEBUG and not results_df.empty:
    pq.write_table(pa.Table.from_pandas(results_df, preserve_index=False), output_file)