In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
import pickle

from collections import defaultdict
from itertools import combinations
from pathlib import Path

import torch

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from IPython.display import display

from procyon.data.data_utils import DATA_DIR

In [3]:
from procyon.data.constants import CANONICAL_SPLITS, RETRIEVAL_SUBSETS
from procyon.data.data_utils import DATA_DIR

# Root of the v1 integrated data; the relation CSVs and text tables used in
# this notebook are all resolved relative to this directory.
data_path = os.path.join(DATA_DIR, "integrated_data", "v1")

# Every relations dataset as an (aaseq_type, dataset_name) pair.
all_datasets = [
    (aaseq_type, name)
    for aaseq_type, names in (
        ("domain", ("go", "pfam")),
        (
            "protein",
            ("disgenet", "drugbank", "ec", "go", "gtop", "omim", "reactome", "uniprot"),
        ),
    )
    for name in names
]

# The datasets that are annotated below: everything except uniprot and gtop.
test_datasets = [
    (aaseq_type, name)
    for aaseq_type, name in all_datasets
    if name not in ("uniprot", "gtop")
]

def load_rels(aaseq_type: str, name: str) -> pd.DataFrame:
    """Read the unified relations CSV of one dataset.

    The file is located under `data_path`, in the `{aaseq_type}_{name}`
    directory, inside the sub-directory named by `CANONICAL_SPLITS[name]`.
    """
    dataset_dir = f"{aaseq_type}_{name}"
    csv_name = f"{dataset_dir}_relations.unified.csv"
    return pd.read_csv(
        os.path.join(data_path, dataset_dir, CANONICAL_SPLITS[name], csv_name)
    )


def get_train_relations() -> pd.DataFrame:
    """Get training relations for all datasets

    Returns a DF with one row per (dataset, phenotype text):
     - dataset_idx - the phenotype's `text_id` within the dataset
     - dataset - the dataset name, `{aaseq_type}_{name}` (e.g. `protein_go`)
     - unique_id - concatenation of `dataset` and `dataset_idx`, unique identifier
                   across all datasets
     - seq_id - set containing the protein IDs related to this phenotype
     - count - number of proteins related to this phenotype, i.e. length of `seq_id`
    """
    all_rels = []
    for aaseq_type, name in all_datasets:
        # `all_datasets` holds (aaseq_type, name) tuples. The label must be a
        # single string: assigning the tuple itself makes pandas treat it as a
        # length-2 column, and the `unique_id` concatenation needs a string.
        # Including the aaseq type keeps domain_go and protein_go distinct.
        dataset_name = f"{aaseq_type}_{name}"
        all_rels.append(
            load_rels(aaseq_type, name)
            .query("split == 'CL_train'")
            .groupby("text_id")
            .seq_id
            .apply(set)
            .reset_index()
            .assign(
                dataset=dataset_name,
                count=lambda x: x.seq_id.map(len),
            )
            .rename(columns={"text_id": "dataset_idx"})
        )
    return (
        pd.concat(all_rels, ignore_index=True)
        .assign(unique_id=lambda x: x.dataset + "_" + x.dataset_idx.astype(str))
    )

# Exact term mappings

In [4]:
# Directory containing the `<db>2go.txt` and Reactome-to-GO mapping files.
mapper_dir = Path(DATA_DIR) / "experimental_data" / "db_mappers"

In [5]:
def load_mapper(db_name: str) -> pd.DataFrame:
    """Load `{db_name}2go.txt` from `mapper_dir` as a `{db_name}_id` / `go_id` table.

    Lines starting with "!" are skipped as comments. Each remaining line is
    split on " > " into a source part and a GO part:
     - `go_id` is the text after the last ";" of the GO part, stripped.
     - `{db_name}_id` is the first whitespace-separated token of the source
       part with everything up to the first ":" removed.

    Returns:
        DataFrame with columns `{db_name}_id` and `go_id`.
    """
    id_col = f"{db_name}_id"
    mapper = (
        pd.read_csv(
            mapper_dir / f"{db_name}2go.txt",
            sep=" > ",
            comment="!",
            names=[id_col, "oth"],
            # Multi-character separators are only handled by the python
            # parser. Requesting it explicitly avoids the ParserWarning that
            # pandas emits when it has to fall back from the C engine.
            engine="python",
        )
        .assign(
            go_id=lambda x: x.oth.str.split(";").str[-1].str.strip(),
        )
        .drop(columns=["oth"])
    )
    mapper[id_col] = (
        mapper[id_col]
        .str.split()
        .str[0]
        .str.split(":")
        .str[1]
    )
    return mapper

# GO mapping table for each non-GO database, keyed by database name.
orig_mappers = {
    **{db_name: load_mapper(db_name) for db_name in ("ec", "pfam")},
    # Reactome comes as separate pathway and reaction tables; stack them and
    # rename the columns to the `<db>_id` / `go_id` convention used above.
    "reactome": (
        pd.concat([
            pd.read_table(mapper_dir / file_name)
            for file_name in (
                "Pathways2GoTerms_human.txt",
                "Reactions2GoTerms_human.txt",
            )
        ])
        .rename(columns={"Identifier":"reactome_id", "GO_Term":"go_id"})
        .drop(columns=["Name"])
        .reset_index(drop=True)
    ),
}

  pd.read_csv(
  pd.read_csv(


In [6]:
orig_dest_dbs = list(orig_mappers)

# `mappers[src][dest]` is a DataFrame containing the columns `{src}_id` and
# `{dest}_id`. Every pair of databases is registered in both directions and
# both directions share the same DataFrame object.
mappers = defaultdict(dict)

# Direct mappings: each database <-> GO.
for db in orig_dest_dbs:
    go_mapper = orig_mappers[db]
    mappers["go"][db] = mappers[db]["go"] = go_mapper
    print(f"{db} <-> GO: {len(go_mapper)} exact mappings")

# Transitive mappings: link two non-GO databases through a shared GO term.
for db_a, db_b in combinations(orig_dest_dbs, 2):
    trans_mapper = (
        orig_mappers[db_a]
        .merge(orig_mappers[db_b], on="go_id")
        .drop(columns="go_id")
    )
    mappers[db_a][db_b] = trans_mapper
    mappers[db_b][db_a] = trans_mapper
    print(f"{db_a} <-> {db_b}: {len(trans_mapper)} exact mappings")


ec <-> GO: 5236 exact mappings
pfam <-> GO: 10095 exact mappings
reactome <-> GO: 7363 exact mappings
ec <-> pfam: 3829 exact mappings
ec <-> reactome: 16233 exact mappings
pfam <-> reactome: 36090 exact mappings


In [7]:
# Fan-out of the transitive pfam -> reactome mapping: for each number of mapper
# rows per pfam ID (index), how many pfam IDs have that many rows (value).
mappers["pfam"]["reactome"].pfam_id.value_counts().value_counts().sort_index()

count
1      921
2      460
3      289
4      156
5       70
      ... 
310      4
385      1
509     10
510      7
532      1
Name: count, Length: 72, dtype: int64

In [8]:
# Number of pfam -> reactome mapper rows per pfam ID, largest first.
mappers["pfam"]["reactome"].pfam_id.value_counts()

pfam_id
PF09036    532
PF09202    510
PF08826    510
PF15785    510
PF00433    510
          ... 
PF19052      1
PF19088      1
PF00023      1
PF05937      1
PF05955      1
Name: count, Length: 2552, dtype: int64

In [9]:
# Inspect the mapper rows of a single pfam ID (the one with the most rows in
# the counts above).
mappers["pfam"]["reactome"].query("pfam_id == 'PF09036'")

Unnamed: 0,pfam_id,reactome_id
24764,PF09036,R-HSA-109702
24765,PF09036,R-HSA-109822
24766,PF09036,R-HSA-109823
24767,PF09036,R-HSA-109860
24768,PF09036,R-HSA-109862
...,...,...
25291,PF09036,R-HSA-9018745
25292,PF09036,R-HSA-9018806
25293,PF09036,R-HSA-9624893
25294,PF09036,R-HSA-9693282


In [10]:
# Precomputed per-text similarity table for the evaluation phenotypes
# (one row per query text; see the `max_train_sim` column).
eval_sims_path = Path(DATA_DIR).joinpath(
    "experimental_data",
    "phenotype_embeddings",
    "eval_phenotype_sims.tsv.gz",
)
eval_sims = pd.read_table(eval_sims_path)
eval_sims.head()

Unnamed: 0,query_id,query_db_id,query_text_col,max_train_sim,max_train_sim_count,num_matched_seqs,matched_text_count,dataset
0,25,GO:0000152,go_def,0.624915,1.0,0,0,go
1,360,GO:0004143,go_def,0.627025,1.6,0,0,go
2,361,GO:0004175,go_def,0.611923,1.0,0,0,go
3,363,GO:0004180,go_def,0.646375,1.0,0,0,go
4,365,GO:0004190,go_def,0.612427,15.2,0,0,go


In [11]:
# Some phenotypes are associated with multiple texts (e.g. drugbank MoA and indication),
# so one `query_db_id` can span several rows. This counts how many IDs have
# 1, 2, 3, ... rows.
eval_sims.query_db_id.value_counts().value_counts()

count
1    1531
4    1224
2    1058
3      75
Name: count, dtype: int64

In [12]:
def unique_and_check(x):
    """Return the single distinct value in `x`.

    Raises:
        AssertionError: if `x` does not contain exactly one distinct value.
    """
    distinct = np.unique(x)
    n_distinct = len(distinct)
    assert n_distinct == 1, f"got: {distinct}"
    return distinct[0]

# Collapse to one row per phenotype ID: average the similarity over the
# phenotype's texts, keep the largest matched-sequence count, and require
# that all rows of an ID come from the same dataset.
#
# The dict passed to `agg` already restricts the result to the listed columns,
# so no explicit `drop` is needed. Leaving it out also makes this cell safe to
# re-run: it overwrites `eval_sims`, and a `drop` of the already-removed
# columns would raise a KeyError on the second run.
eval_sims = (
    eval_sims
    .groupby("query_db_id")
    .agg({
        "max_train_sim": "mean",
        "num_matched_seqs": "max",
        "dataset": unique_and_check,
    })
    .reset_index()
)
eval_sims.head()

Unnamed: 0,query_db_id,max_train_sim,num_matched_seqs,dataset
0,1.1.1.104,0.588074,1,ec
1,1.1.1.153,0.654462,1,ec
2,1.1.1.159,0.672711,1,ec
3,1.1.1.170,0.72878,0,ec
4,1.1.1.189,0.688617,1,ec


In [13]:
# Number of distinct phenotype IDs (one row per `query_db_id` after aggregation).
len(eval_sims)

3888

In [14]:
# For each text database, the aaseq types whose relation files are loaded
# when checking for training relations in that database.
aaseq_map = dict(
    pfam=["domain"],
    reactome=["protein"],
    ec=["protein"],
    go=["domain", "protein"],
)

In [23]:
def flag_from_exact_matches(
    rels: pd.DataFrame,
    dataset: str,
) -> pd.DataFrame:
    """Flag texts that map exactly onto a term with training relations elsewhere.

    For every database that `mappers` links to `dataset`, and for every aaseq
    type listed for that database in `aaseq_map`, the other dataset's relations
    are loaded and the terms with at least one `CL_train` relation collected.
    Every row of `rels` whose `{dataset}_id` maps to one of those terms gets
    `,{dest_db}_train` appended to its `reason`.

    Args:
        rels: relations with a `{dataset}_id` column and a string `reason`
            column. Modified in place.
        dataset: name of the text database `rels` belongs to.

    Returns:
        `rels` (the same object); untouched when `dataset` has no mappers.
    """
    if dataset not in mappers:
        return rels

    src_col = f"{dataset}_id"
    for dest_db, map_df in mappers[dataset].items():
        dest_col = f"{dest_db}_id"
        for oth_aaseq in aaseq_map[dest_db]:
            oth_rels = load_rels(oth_aaseq, dest_db)
            # Only membership matters, so look IDs up with `isin` instead of
            # joining relation rows x mapper rows x other relation rows (a
            # many-to-many blow-up when one ID maps to hundreds of terms).
            train_dest_ids = oth_rels.loc[oth_rels["split"] == "CL_train", "text_id"]
            has_train_rel = map_df.loc[map_df[dest_col].isin(train_dest_ids), src_col]
            rels.loc[rels[src_col].isin(has_train_rel), "reason"] += f",{dest_db}_train"
    return rels

def flag_from_sims(
    rels: pd.DataFrame,
    dataset: str,
    max_sim: float = 0.8,
    max_sim_w_match: float = 0.6,
) -> pd.DataFrame:
    """Flag zero-shot texts by their `max_train_sim` value in `eval_sims`.

    A text in the `eval_zero_shot` split is flagged when its `max_train_sim`
    is at least `max_sim`, or when it has a non-zero `num_matched_seqs` and
    its `max_train_sim` is at least `max_sim_w_match`. Every row of `rels`
    carrying a flagged `{dataset}_id` gets `,train_sim` appended to `reason`.

    Args:
        rels: relations with `split`, `{dataset}_id` and `reason` columns.
            Modified in place.
        dataset: value of the `dataset` column of `eval_sims` to select.
        max_sim: similarity threshold applied to every text.
        max_sim_w_match: lower threshold applied to texts with matched seqs.

    Returns:
        `rels` (the same object).

    Raises:
        AssertionError: if the merge with `eval_sims` does not yield exactly
            one row per zero-shot text.
    """
    id_col = f"{dataset}_id"

    is_zero_shot = rels["split"] == "eval_zero_shot"
    zero_shot_texts = rels[is_zero_shot].drop_duplicates(id_col)

    dataset_sims = eval_sims[eval_sims["dataset"] == dataset]
    want_sims = dataset_sims.rename(columns={"query_db_id": id_col})
    merged = zero_shot_texts.merge(want_sims, on=id_col)

    # Each zero-shot text is expected to match exactly one row of the sims table
    assert len(merged) == len(zero_shot_texts)

    above_max_sim = merged["max_train_sim"] >= max_sim
    above_match_sim = (
        (merged["num_matched_seqs"] != 0)
        & (merged["max_train_sim"] >= max_sim_w_match)
    )
    flagged = merged.loc[above_max_sim | above_match_sim, id_col]

    rels.loc[rels[id_col].isin(flagged), "reason"] += ",train_sim"
    return rels

def annotate_relations(
    aaseq: str,
    dataset: str,
) -> pd.DataFrame:
    """Load one dataset's relations and attach a `reason` column.

    `reason` is a comma-separated list of the flags added by
    `flag_from_exact_matches` and `flag_from_sims` (empty string when nothing
    was flagged). Two summary lines on the zero-shot relations and texts are
    printed as a side effect.

    Returns:
        All relations of the dataset (every split) with the added `reason`
        column; the temporary `{dataset}_id` column is removed again.
    """
    id_col = f"{dataset}_id"

    all_rels = load_rels(aaseq, dataset).assign(**{
        id_col: lambda x: x.text_id.astype(str),
        "reason": "",
    })
    all_rels = flag_from_exact_matches(all_rels, dataset=dataset)
    all_rels = flag_from_sims(all_rels, dataset=dataset)
    # Each flag is appended as ",<flag>", so trim the leading comma.
    all_rels = all_rels.assign(reason=lambda x: x.reason.str.strip(","))

    # eval_zero_shot_hard is included so that the numbers can be reproduced
    # once the relabelled splits have been written back to disk.
    in_zero_shot = all_rels["split"].isin(["eval_zero_shot", "eval_zero_shot_hard"])
    zero_shot_rels = all_rels[in_zero_shot].assign(flagged=lambda x: x.reason != "")

    n_rels = len(zero_shot_rels)
    n_rels_flagged = int(zero_shot_rels["flagged"].sum())
    print(f"{aaseq}-{dataset}: {n_rels_flagged} / {n_rels} ({n_rels_flagged / n_rels: .1%}) zero-shot relations are flagged")

    # A text counts as flagged when any of its relations is flagged.
    text_flagged = zero_shot_rels.groupby(id_col)["flagged"].any()
    n_texts = len(text_flagged)
    n_texts_flagged = int(text_flagged.sum())
    print(f"{aaseq}-{dataset}: {n_texts_flagged} / {n_texts} ({n_texts_flagged / n_texts:.1%}) zero-shot texts are flagged")

    return all_rels.drop(columns=id_col)

In [24]:
# For every test dataset: annotate its relations with flag reasons, move the
# unflagged `eval_zero_shot` relations to `eval_zero_shot_hard`, and write the
# result back.
# WARNING: both relation CSVs are overwritten in place (same paths they were
# read from).
all_annotated = []
for aaseq, dataset in test_datasets:
    split_name  = CANONICAL_SPLITS[dataset]
    # Same file that `load_rels(aaseq, dataset)` reads.
    rels_path = os.path.join(
        data_path,
        f"{aaseq}_{dataset}",
        split_name,
        f"{aaseq}_{dataset}_relations.unified.csv"
    )

    # Companion "indexed" relations file in the same directory.
    rels_indexed_path = os.path.join(
        data_path,
        f"{aaseq}_{dataset}",
        split_name,
        f"{aaseq}_{dataset}_relations_indexed.unified.csv"
    )
    rels_indexed = pd.read_csv(rels_indexed_path)

    # Text table of the dataset; provides the `index` <-> `{dataset}_id` pairing.
    text_df = pd.read_pickle(os.path.join(
        data_path,
        dataset,
        f"{dataset}_info_filtered_composed.pkl"
    ))

    # Despite the name, this maps each text ID (`{dataset}_id`) to the value
    # of the text table's `index` column.
    seq_to_index = {}
    for index, text_id in text_df[["index", f"{dataset}_id"]].itertuples(index=False):
        seq_to_index[text_id] = index

    annotated = (
        annotate_relations(aaseq, dataset)
        .assign(text_idx=lambda x: x.text_id.map(seq_to_index))
    )

    # Flagged zero-shot relations (non-empty `reason`) stay in `eval_zero_shot`;
    # unflagged ones become `eval_zero_shot_hard`. Other splits are untouched.
    new_rels = (
        annotated
        .assign(
            split=lambda x: np.where(
                x.split == "eval_zero_shot",
                np.where(
                    x.reason != "",
                    "eval_zero_shot",
                    "eval_zero_shot_hard",
                ),
                x.split,
            )
        )
    )

    # Apply the same relabelling to the indexed file by joining the per-text
    # `reason` on the text index.
    # NOTE(review): presumably `text_id` in the indexed file holds the text
    # table's `index` value -- confirm. The assertion below guards against
    # rows being lost or duplicated by the join.
    old_len = len(rels_indexed)
    new_rels_indexed = (
        rels_indexed
        .merge(new_rels[["reason", "text_idx"]].rename(columns={"text_idx": "text_id"}).drop_duplicates(), on="text_id")
        .assign(
            split=lambda x: np.where(
                x.split == "eval_zero_shot",
                np.where(
                    x.reason != "",
                    "eval_zero_shot",
                    "eval_zero_shot_hard",
                ),
                x.split,
            )
        )
    )
    assert len(new_rels_indexed) == old_len

    # Overwrite the unified relations file, without the helper columns.
    (
        new_rels
        .drop(columns=["reason", "text_idx"])
        .to_csv(rels_path, index=False)
    )

    # Overwrite the indexed relations file, without the helper column.
    (
        new_rels_indexed
        .drop(columns=["reason"])
        .to_csv(rels_indexed_path, index=False)
    )
    # Keep `annotated` (split labels as read from disk), not `new_rels`.
    all_annotated.append(annotated)

# From here on `all_annotated` is a single DataFrame rather than a list.
all_annotated = pd.concat(all_annotated)

domain-go: 1861 / 3086 ( 60.3%) zero-shot relations are flagged
domain-go: 33 / 70 (47.1%) zero-shot texts are flagged
domain-pfam: 303 / 1036 ( 29.2%) zero-shot relations are flagged
domain-pfam: 183 / 638 (28.7%) zero-shot texts are flagged
protein-disgenet: 208 / 406 ( 51.2%) zero-shot relations are flagged
protein-disgenet: 69 / 179 (38.5%) zero-shot texts are flagged
protein-drugbank: 202 / 1589 ( 12.7%) zero-shot relations are flagged
protein-drugbank: 27 / 269 (10.0%) zero-shot texts are flagged
protein-ec: 107 / 145 ( 73.8%) zero-shot relations are flagged
protein-ec: 46 / 70 (65.7%) zero-shot texts are flagged
protein-go: 5956 / 12073 ( 49.3%) zero-shot relations are flagged
protein-go: 153 / 274 (55.8%) zero-shot texts are flagged
protein-omim: 306 / 645 ( 47.4%) zero-shot relations are flagged
protein-omim: 294 / 623 (47.2%) zero-shot texts are flagged
protein-reactome: 355 / 506 ( 70.2%) zero-shot relations are flagged
protein-reactome: 182 / 258 (70.5%) zero-shot texts are

In [28]:
# All zero-shot relations across datasets, with three boolean views of `reason`:
#  - flagged:     any reason at all
#  - flagged_sim: a `train_sim` reason and no `*_train` mapping reason
#  - flagged_map: a `*_train` mapping reason and no `train_sim` reason
zero_shot_rels = (
    all_annotated
    .loc[lambda x: x["split"].isin(["eval_zero_shot", "eval_zero_shot_hard"])]
    .assign(
        flagged=lambda x: x["reason"] != "",
        flagged_sim=lambda x: (
            x["reason"].str.contains("train_sim")
            & ~x["reason"].str.contains("_train")
        ),
        flagged_map=lambda x: (
            ~x["reason"].str.contains("train_sim")
            & x["reason"].str.contains("_train")
        ),
    )
)
print(f"total: {int(zero_shot_rels['flagged'].sum())} / {len(zero_shot_rels)} ({int(zero_shot_rels['flagged'].sum()) / len(zero_shot_rels): .1%}) zero-shot relations are flagged")

# One row per text (text_type + text_id); a text carries a flag when any of
# its relations does.
zero_shot_texts = (
    zero_shot_rels
    .assign(unique_id=lambda x: x["text_type"].astype(str) + "_" + x["text_id"].astype(str))
    .groupby("unique_id")
    [["flagged", "flagged_sim", "flagged_map"]]
    .any()
)
print(f"total: {int(zero_shot_texts['flagged'].sum())} / {len(zero_shot_texts)} ({int(zero_shot_texts['flagged'].sum()) / len(zero_shot_texts):.1%}) zero-shot texts are flagged")

total: 9298 / 19486 ( 47.7%) zero-shot relations are flagged
total: 987 / 2381 (41.5%) zero-shot texts are flagged


In [29]:
# Texts with at least one relation flagged by an exact mapping and not by similarity.
zero_shot_texts.flagged_map.value_counts()

flagged_map
False    2145
True      236
Name: count, dtype: int64

In [30]:
# Texts with at least one relation flagged by similarity and not by an exact mapping.
zero_shot_texts.flagged_sim.value_counts()

flagged_sim
False    1698
True      683
Name: count, dtype: int64

In [31]:
# Texts with `flagged_sim` set and `flagged_map` unset.
# NOTE(review): `flagged_sim` is already defined to exclude mapping reasons at
# the relation level, which is presumably why these counts equal the plain
# `flagged_sim` counts -- confirm whether an inclusive definition was intended.
(zero_shot_texts.flagged_sim & ~zero_shot_texts.flagged_map).value_counts()

False    1698
True      683
Name: count, dtype: int64

In [32]:
# Texts with at least one relation flagged for any reason.
zero_shot_texts.flagged.value_counts()

flagged
False    1394
True      987
Name: count, dtype: int64