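## Summary

For the slice of Parquet row groups assigned to this SLURM array task (one protein per row group), compute per-mutation multiple-sequence-alignment (MSA) features — residue counts, smoothed log-probabilities, and the `H`/`KL` conservation scores from `Conservation.jl` — and write them to `31_run_msa_analysis/<DATASET_NAME>/result-<TASK_ID>-of-<TASK_COUNT>.parquet`.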

---

## Imports

In [None]:
import concurrent.futures
import os
import re
import shutil
import socket
import subprocess
import sys
import tempfile
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm.notebook import tqdm

## Parameters

In [None]:
# Absolute path of this notebook's working directory (relative to the current
# working directory); created on first run.
NOTEBOOK_DIR = Path(".", "31_run_msa_analysis").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

In [None]:
# When running under SLURM, point TMPDIR at the job's scratch directory so
# temporary files (e.g. the Conservation.jl work dirs) are created there.
slurm_tmpdir = os.getenv("SLURM_TMPDIR")
if slurm_tmpdir is not None:
    os.environ["TMPDIR"] = slurm_tmpdir

print(tempfile.gettempdir())

In [None]:
if "scinet" in socket.gethostname():
    CPU_COUNT = 40
else:
    CPU_COUNT = max(1, len(os.sched_getaffinity(0)))

CPU_COUNT = max(1, CPU_COUNT // 2)

CPU_COUNT

In [None]:
# Job configuration from the environment; every value is None when unset.
DATASET_NAME, DATASET_PATH = os.getenv("DATASET_NAME"), os.getenv("DATASET_PATH")

TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)

# ORIGINAL_ARRAY_TASK_COUNT takes precedence; an unset or empty value falls
# back to SLURM's own array task count.
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT

In [None]:
# Without a SLURM array task id the notebook runs in debug mode with fixed settings.
DEBUG = TASK_ID is None

if DEBUG:
    # NOTE(review): DATASET_NAME says "cagi6-sherloc" but DATASET_PATH points at
    # the humsavar file (the cagi6 path is commented out), so debug output is
    # written under the cagi6-sherloc directory -- confirm the intended pairing.
    DATASET_NAME = "cagi6-sherloc"
    # DATASET_PATH = NOTEBOOK_DIR.parent.joinpath("30_cagi6_sherloc", "input-data-gby-protein.parquet")
    # Kept as a Path, the same type as in the non-debug branch.
    DATASET_PATH = NOTEBOOK_DIR.parent.joinpath("30_humsavar", "humsavar-gby-protein-waln.parquet")
    TASK_ID = 1098
    TASK_COUNT = 12557  # 4182
else:
    # Explicit checks instead of `assert`, so they also run under `python -O`.
    if DATASET_NAME is None:
        raise ValueError("The DATASET_NAME environment variable must be set.")
    if DATASET_PATH is None:
        raise ValueError("The DATASET_PATH environment variable must be set.")
    if TASK_COUNT is None:
        raise ValueError(
            "ORIGINAL_ARRAY_TASK_COUNT or SLURM_ARRAY_TASK_COUNT must be set."
        )
    DATASET_PATH = Path(DATASET_PATH).expanduser().resolve()

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT

In [None]:
# Result file for this array task:
# <NOTEBOOK_DIR>/<DATASET_NAME>/result-<TASK_ID>-of-<TASK_COUNT>.parquet
output_file = NOTEBOOK_DIR / DATASET_NAME / f"result-{TASK_ID}-of-{TASK_COUNT}.parquet"
output_file.parent.mkdir(exist_ok=True)

output_file

In [None]:
# Stop early when this task's result file already exists, so a re-submitted
# array job does not redo finished work. FileExistsError is still an
# Exception subclass, and the message keeps its "Already finished!" prefix.
if output_file.is_file():
    raise FileExistsError(f"Already finished! {output_file}")

## Workspace

### Load data

In [None]:
# Open the input dataset lazily; row groups are read one at a time below.
pfile = pq.ParquetFile(source=DATASET_PATH)

pfile.num_row_groups

In [None]:
# Scratch check of numpy's floor/astype behaviour; the result is only displayed
# by the notebook and is not assigned to anything.
np.floor(1.4).astype(int)

In [None]:
# Row groups handled per array task: ceiling division, so the tasks together
# cover every row group (the last task may get fewer).
rows_per_chunk = np.ceil(np.divide(pfile.num_row_groups, TASK_COUNT)).astype(int)

rows_per_chunk

In [None]:
# Half-open range [start, stop) of 0-based row groups for this task (TASK_ID is
# 1-based). Clamp stop to num_row_groups: with ceiling division the last
# task's nominal end can overshoot, and row group `num_row_groups` does not exist.
start = (TASK_ID - 1) * rows_per_chunk
stop = min(pfile.num_row_groups, TASK_ID * rows_per_chunk)

start, stop

### Helper functions

In [None]:
# One-letter codes of the 20 standard amino acids. The string order defines the
# column order of the count / log-probability matrices; characters not in it
# (gaps, "X", lowercase letters, ...) are not counted.
AMINO_ACIDS = "ARNDCEQGHILKMFPSTWYV"

In [None]:
def sequences_to_counts(sequences):
    """Count residues in every alignment column.

    Returns a float64 matrix with one row per position of ``sequences[0]`` and
    one column per letter of ``AMINO_ACIDS`` (in that order). Characters
    outside ``AMINO_ACIDS`` are ignored.
    """
    n_columns = len(sequences[0])
    counts_mat = np.zeros((n_columns, len(AMINO_ACIDS)), dtype=np.float64)
    for col_idx, column in enumerate(zip(*sequences)):
        tally = Counter(residue for residue in column if residue in AMINO_ACIDS)
        for aa_idx, aa in enumerate(AMINO_ACIDS):
            counts_mat[col_idx, aa_idx] = tally[aa]
    return counts_mat

In [None]:
def counts_to_probas(counts_mat, pseudocount=1):
    """Convert per-position residue counts into smoothed log-probabilities.

    Adds ``pseudocount`` to every cell (additive smoothing) and normalizes each
    row, so unobserved residues still get a finite log-probability.

    Args:
        counts_mat: Array of counts; the last axis indexes residue types
            (one row per alignment position for a 2-D matrix).
        pseudocount: Count added to every cell. The default of 1 gives
            ``(c + 1) / (N + 20)`` for the 20-column matrices of
            ``sequences_to_counts``.

    Returns:
        Array of the same shape holding natural-log probabilities (despite the
        name, values are logs).
    """
    counts_mat = np.asarray(counts_mat, dtype=np.float64)
    # Number of residue types is taken from the matrix instead of a hard-coded 20.
    n_symbols = counts_mat.shape[-1]
    totals = counts_mat.sum(axis=-1, keepdims=True)
    # Denominator N + pseudocount * K makes every row sum to 1 in probability space.
    return np.log((counts_mat + pseudocount) / (totals + pseudocount * n_symbols))

In [None]:
def run_convervation_script(sequences):
    """Score per-column conservation of an alignment with ``Conservation.jl``.

    Writes ``sequences`` as FASTA into a fresh temporary directory, copies
    ``Conservation.jl`` from ``NOTEBOOK_DIR`` next to it, runs
    ``julia Conservation.jl -f FASTA aln.fasta`` there, and reads the CSV the
    script writes (``aln.fasta.conservation.csv``, lines starting with ``#``
    ignored).

    Args:
        sequences: Aligned sequences of equal length; the first is the query.

    Returns:
        DataFrame with one row per alignment column (the caller reads its
        ``i``, ``H`` and ``KL`` columns).

    Raises:
        ValueError: If ``sequences`` is empty.
        subprocess.CalledProcessError: If julia exits non-zero; its stderr is
            printed first so the cause is visible in the notebook output.
        AssertionError: If the CSV does not have one row per alignment column.
    """
    # NOTE: the misspelled name is kept because callers use it.
    if len(sequences) == 0:
        raise ValueError("Cannot score conservation of an empty alignment.")

    script_name = "Conservation.jl"
    fasta_string = "".join(f">{i}\n{seq}\n" for i, seq in enumerate(sequences))

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)
        shutil.copy(NOTEBOOK_DIR.joinpath(script_name), tmp_path.joinpath(script_name))
        tmp_path.joinpath("aln.fasta").write_text(fasta_string)
        cmd = ["julia", script_name, "-f", "FASTA", "aln.fasta"]
        try:
            subprocess.run(cmd, cwd=tmp_dir, capture_output=True, check=True)
        except subprocess.CalledProcessError as err:
            # With check=True alone, julia's error message stays hidden in err.stderr.
            print(err.stderr.decode(errors="replace"), file=sys.stderr)
            raise
        result_df = pd.read_csv(tmp_path.joinpath("aln.fasta.conservation.csv"), comment="#")

    assert len(result_df) == len(sequences[0]), (
        f"Expected {len(sequences[0])} conservation rows, got {len(result_df)}."
    )
    return result_df

In [None]:
@dataclass
class Row:
    """Per-protein MSA statistics, reused across consecutive row groups.

    Every field defaults to ``None`` so an empty instance can be filled in
    field by field.
    """

    protein_id: Optional[str] = None
    # Residue counts: one row per alignment position, one column per AMINO_ACIDS letter.
    counts_mat: Optional[np.ndarray] = None
    # Smoothed log-probabilities, same shape as counts_mat.
    probas_mat: Optional[np.ndarray] = None
    # Number of sequences in the alignment.
    msa_length: Optional[int] = None
    # Mean log-probability of the query residues.
    msa_proba: Optional[float] = None
    # Conservation.jl output, one row per alignment position.
    conservation_df: Optional[pd.DataFrame] = None

In [None]:
def _alignment_to_sequences(alignment):
    """Return the aligned sequences in ``alignment``, an iterable of text lines.

    Empty lines and header lines (starting with ">") are skipped. Each remaining
    line is stripped of surrounding whitespace and of all lowercase characters.
    """
    # NOTE(review): dropping lowercase letters matches the A3M convention for
    # insertions relative to the query -- confirm the alignment format.
    return [
        "".join(aa for aa in line.strip() if not aa.islower())
        for line in alignment
        if line and not line.startswith(">")
    ]


def _build_row(protein_id, sequences):
    """Compute the per-protein MSA statistics stored in a ``Row``."""
    row = Row()
    row.protein_id = protein_id
    row.counts_mat = sequences_to_counts(sequences)
    row.probas_mat = counts_to_probas(row.counts_mat)
    row.msa_length = len(sequences)
    # Mean log-probability of the query residues. Residues outside AMINO_ACIDS
    # (e.g. "U", "X") have no column, so they are skipped rather than letting
    # str.index raise ValueError.
    query_log_probas = [
        row.probas_mat[i, AMINO_ACIDS.index(aa)]
        for i, aa in enumerate(sequences[0])
        if aa in AMINO_ACIDS
    ]
    row.msa_proba = np.mean(query_log_probas) if query_log_probas else np.nan
    row.conservation_df = run_convervation_script(sequences)
    return row


def _mutation_records(tup, query, row, task_idx):
    """Build one output record per mutation of ``tup`` that matches ``query``.

    Mutations are strings ``<wt><1-based position><mut>`` (e.g. "A123G").
    Mutations that do not match the query, or that involve a residue outside
    AMINO_ACIDS, are reported and skipped.
    """
    records = []
    for mutation_id, mutation in zip(tup.mutation_id, tup.mutation):
        aa_wt, pos, aa_mut = mutation[0], int(mutation[1:-1]), mutation[-1]
        if not 1 <= pos <= len(query) or query[pos - 1] != aa_wt:
            print(f"Mutation {mutation!r} does not match sequence for {task_idx=}.")
            continue
        if aa_wt not in AMINO_ACIDS or aa_mut not in AMINO_ACIDS:
            print(f"Mutation {mutation!r} has a non-standard residue for {task_idx=}.")
            continue

        cons = row.conservation_df.iloc[pos - 1]
        assert cons.i == pos

        idx_wt = AMINO_ACIDS.index(aa_wt)
        idx_mut = AMINO_ACIDS.index(aa_mut)
        records.append(
            {
                "protein_id": tup.protein_id,
                "mutation_id": mutation_id,
                "mutation": mutation,
                "msa_count_wt": row.counts_mat[pos - 1, idx_wt],
                "msa_count_mut": row.counts_mat[pos - 1, idx_mut],
                "msa_count_total": row.counts_mat[pos - 1].sum().item(),
                "msa_proba_wt": row.probas_mat[pos - 1, idx_wt],
                "msa_proba_mut": row.probas_mat[pos - 1, idx_mut],
                "msa_proba_total": row.probas_mat[pos - 1].sum().item(),
                "msa_length": row.msa_length,
                "msa_proba": row.msa_proba,
                "msa_H": cons.H,
                "msa_KL": cons.KL,
            }
        )
    return records


# ParquetWriter creates its file immediately. Writing straight to output_file
# would leave a partial file after a crash, which the "Already finished!" check
# would treat as done. So write to a temporary file and rename it only after
# every row group succeeded.
tmp_output_file = output_file.with_name(output_file.name + ".tmp")
writer = None
previous_row = None
try:
    for task_idx in tqdm(range(start, stop)):
        input_df = pfile.read_row_group(task_idx).to_pandas(integer_object_nulls=True)
        assert len(input_df) == 1
        tup = next(input_df.itertuples())
        sequences = _alignment_to_sequences(tup.alignment)
        assert tup.sequence == sequences[0]
        assert all(len(tup.sequence) == len(seq) for seq in sequences)

        # Reuse the statistics when consecutive row groups share a protein.
        if previous_row is not None and previous_row.protein_id == tup.protein_id:
            row = previous_row
        else:
            row = _build_row(tup.protein_id, sequences)
        previous_row = row

        mutation_results = _mutation_records(tup, sequences[0], row, task_idx)
        if not mutation_results:
            # An empty DataFrame has no columns; its schema would not match the
            # writer's schema (or would become the schema if written first).
            continue
        table = pa.Table.from_pandas(pd.DataFrame(mutation_results), preserve_index=False)
        if writer is None:
            writer = pq.ParquetWriter(tmp_output_file, table.schema)
        writer.write_table(table)
finally:
    if writer is not None:
        writer.close()

if writer is None:
    print(f"No mutation results for row groups [{start}, {stop}); nothing written.")
else:
    tmp_output_file.replace(output_file)