In [1]:
#First importing the required packages

!pip install lightning

!pip install torchmetrics

!pip install obonet

!pip install pyvis

!pip install bio

Collecting lightning
  Downloading lightning-2.2.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.11.0-py3-none-any.whl (25 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<4.0,>=1.13.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import random

# Graphs & Protein sequences
import networkx as nx
import obonet
from pyvis.network import Network
from Bio import SeqIO

# Deep Learning
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score
from torchmetrics.classification import MultilabelAccuracy
import lightning as L

# Embeddings
# from transformers import AutoTokenizer, BioGptModel

# biogpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
# biogpt_model = BioGptModel.from_pretrained("microsoft/biogpt")

# Formatting
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [5]:
from google.colab import drive
drive.mount('/content/drive/')

MessageError: Error: credential propagation was unsuccessful

In [None]:
train_root = "Train"
test_root = "Test (Targets)"

# Task Description

The objective of our model is to predict the terms (functions) of a protein sequence. One protein sequence can have many functions and can thus be classified into any number of terms. Each term is uniquely identified by a GO Term ID. Thus, our model has to predict all the GO Term IDs for a protein sequence, which means that the task at hand is a **multi-label classification** problem.

# File Descriptions

## Training Data

### About the $X$'s

#### 1. `train_sequences.fasta`

* Contains the protein sequences for the training dataset
    * Most protein sequences were extracted from the Swiss-Prot database, but a subset of proteins that are not represented in Swiss-Prot were extracted from the TrEMBL database. More info [here](https://www.uniprot.org/help/uniprotkb_sections).
* Header indicates from which database the sequence originate, either Swiss-Prot or TrEMBL, both being parts of UniProtKB.
    * For example, `sp|P9WHI7|RECN_MYCT` in the FASTA header indicates the protein...
        1. with UniProt ID `P9WHI7`
        2. with entry name `RECN_MYCT`
        3. taken from Swiss-Prot (`sp`)
            * Any sequences taken from TrEMBL will have `tr` in the header instead of sp.
    * Detailed format of the fasta header can be found [here](https://www.uniprot.org/help/fasta-headers)
* This file contains only sequences for proteins with annotations in the dataset (labeled proteins). To obtain the full set of protein sequences for unlabeled proteins, the Swiss-Prot and TrEMBL databases can be found [here](https://www.uniprot.org/help/downloads).

#### 2. `train_taxonomy.tsv`

* Contains the list of **proteins** and the **species to which they belong**, represented by a "taxonomic identifier" (taxon ID) number.
    * The first column `EntryID`: the protein UniProt accession ID
    * The second column `taxonomyID`: the taxon ID
* More information about taxonomies can he found [here](https://www.uniprot.org/help/taxonomic_identifier).

### About the $y$'s

#### 3. `train_terms.tsv`

* Contains the list of **annotated terms (ground truth)** for the proteins in `train_sequences.fasta`.
    * The first column `EntryID`: the protein's UniProt accession ID
    * The second column `term`: the GO term ID
    * The third column `aspect`: the ontology in which the term appears

#### 4. `go-basic.obo`

* The Gene Ontology (GO) data in OBO format, which indicates the **functional properties of proteins**. If a protein is labeled with a term, it means that this protein has this function validated by experimental evidence. A protein can have multiple labels. The absence of a term annotation does not necessarily mean a protein does not have this function, only that this annotation does not exist (yet) in the GO.
* GO describes our understanding of the biological domain with respect to the below three aspects, represented as subgraphs. Biologically, each subgraph represents a different aspect of the protein's function: what it does on a molecular level (MF), which biological processes it participates in (BP) and where in the cell it is located (CC).

    1. **Molecular Function (MF)**
        1. Activities that occur at the molecular level, such as “catalysis” or “transport”
        2. generally correspond to activities that can be performed by individual gene products (i.e. a protein or RNA), but some activities are performed by molecular complexes composed of multiple gene products:
            1. Examples of broad functional terms: *catalytic activity* and *ransporter activity*
            2. Examples of narrower functional terms: *adenylate cyclase activity* or *Toll-like receptor binding*
    
    2. **Cellular Component (CC)**
        1. Locations of gene products, relative to cellular compartments and structures, occupied by a macromolecular machine
        2. There are two ways to describe:
            1. *Cellular anatomical entities*, in which a gene product carries out a molecular function, including:
                1. Cellular structures, such as the plasma membrane and the cytoskeleton
                2. Membrane-enclosed cellular compartments, such as the mitochondrion
            2. the stable *Macromolecular complexes* of which they are parts
    
    3. **Biological Process (BP)**
        1. The larger processes, or ‘biological programs’ accomplished by multiple molecular activities
            1. Examples of broad biological process terms: DNA repair or signal transduction
            2. Examples of more specific terms: pyrimidine nucleobase biosynthetic process or glucose transmembrane transport  
        
* The three GO ontologies are "*is-a disjoint*", meaning that no *is-a* relations operate between terms from the different ontologies.
    * However, other relationships such as *part-of* and *regulates* do operate between the GO ontologies.
    * For example, the molecular function term ‘cyclin-dependent protein kinase activity’ is part of the biological process ‘cell cycle’.
* A term may have more than one parent term!!
* More about gene ontology [here](https://geneontology.github.io/docs/ontology-documentation/)
* This is the 2023-01-01 release of the GO graph
    * The nodes in this graph are indexed by the term name:
    > ```python
      subontology_roots = {'BPO':'GO:0008150',  # Biological Process (BP)
                           'CCO':'GO:0005575',  # Cellular Component (CC)
                           'MFO':'GO:0003674'}  # Molecular Function (MF)
      ```

#### 5. `IA.txt` - Information Accretion

* Contains the **information accretion (weights) for each GO term**.
* These weights are used to compute weighted precision and recall, as described in the Evaluation section.
* The values of this file were computed using the following code [repo](https://github.com/claradepaolis/InformationAccretion).
    * Information accretion ($ia$), introduced in [Clark and Radivojac, 2013](https://watermark.silverchair.com/bioinformatics_29_13_i53.pdf?token=AQECAHi208BE49Ooan9kkhW_Ercy7Dm3ZL_9Cf3qfKAc485ysgAAA38wggN7BgkqhkiG9w0BBwagggNsMIIDaAIBADCCA2EGCSqGSIb3DQEHATAeBglghkgBZQMEAS4wEQQM_6yBfcYBeK9gGSVMAgEQgIIDMu5V50JH8avL5_BSrU_NcCzyt7-4kGItwHrgn3SubEw_I-Jxi7uzqi-OLZOfba85vy2uD1UkBZokynBNWyxYEuWInuJnDRtbfmwsw7gTmsBiEMwd81gVpfSYeGCERK2J_79hEvsFrdvCPvOUiOmnSrJXYkta_h4TgDREr1IVMwYJcWtLv_Kcj5uwko9BTo226PiWJsV4GkTSdgYNpIqsbbrK28ffYaNIOXG99jnNVu_Z_cLbtzqdU-ZelaiYCzbpFL-dO8Rtwc90pVrxpOrFPcdHirHTkXJaZrUO0fFat-SpGThD5Vnk7If7CJWRHSoNHYaLlv5G1Tm3_vIrHO-tAWbGiU-rYuMs7ye5Dl1GzD0tucx1TveVLWNnj_hgurQVfbFhMjpgl3iuLoVIYi9b_Wv5bSpptksQERMXMTK9a-_l4L4M2w2OpXpECvp-ZLTOml784hHBHelaKfQvZMxwn1kGsEgMLH1b_jiAc7MNWCvwLj4Us3hz1z6hEz0Sv5oPXZsYd9iV-3e1k5fTUmLhPLUXtRoYwKFHtKhAa8X5XnP6b6Y-XNy0CXm-fAuvLLDrxeA2d_uz7ED6RLMrxCjqYEwmS-uiYjbCz05W6_tgd7JQwqwUFAYeDT8-YM0-MQPv5pWOuG4N1PUa54iifeVMFazWdEJDGvEUR5vSNl2RHCc1g9WQ7sCEBqapNAn4jNQDq_yqGttPJYc9EpdPHTfSq5E4tWRwzO_uV0xx5axldiN1DZMlL23hCl1ISSrOWPx60auCvI10fNX0Yldtet29fJ702lDzJ_igwUQ6dkGYWprOtUKl0vmIQ36DRAinj6Fvi7AKmFuVicW4CJPlDlpCj3quai9ApsNX-PIRfJufa7kaVOhppdBKBTh6gEa63KWsd3RQXko44RspOqr_jAVXVxFLtni_4Ferk-U9kWbnkFflFHCD4Wh11GUQaW6ZSYPXow1aebJz4QSik01KdRg86jv0A0PXwO_Ybb1uWln5XaBU2Olj9BwpLP2r5Tkr3nrfZDb-7mo9MwBqmuwRYG4BsDN4IC0OhUmaGsks3m618K2A7CYPdMhKgnX_gbx3xNC4Dpq9), is a measure of how much information is added to an ontology annotation by node $v$ given that its parents $\mathcal{P}_a(v)$ are already annotated. Specifically,
$$
ia(v) = \log_2 \frac{1}{Pr(v|\mathcal{P}_a(v))} = \log_2 \frac{Pr(\mathcal{P}_a(v))}{Pr(\mathcal{P}_a(v)|v) \cdot Pr(v)}
$$
* [A great post](https://www.kaggle.com/competitions/cafa-5-protein-function-prediction/discussion/405237) about the intuition behind IA.

## Testing Data

### About the $X$'s

#### 1. `testsuperset.fasta`

* Contains protein sequences on which the participants are asked to submit predictions
* Header contains the protein's UniProt accession ID and the Taxon ID of the species this protein belongs to
* Only a small subset of those sequences will accumulate functional annotations and will constitute the test set

#### 2. `testsuperset-taxon-list.tsv`

* A set of taxon IDs for the proteins in the test superset.

# First Look at the data

In [None]:
# Read in train data
train_sequences_fasta = list(SeqIO.parse(os.path.join(train_root, "train_sequences.fasta"), "fasta"))
train_taxonomy = pd.read_csv(os.path.join(train_root, "train_taxonomy.tsv"), sep="\t")

train_terms = pd.read_csv(os.path.join(train_root, "train_terms.tsv"), sep="\t")

with open(os.path.join(train_root, "go-basic.obo")) as obo_file:
    go_graph = obonet.read_obo(obo_file)

information_accretion = pd.read_csv("IA.txt", sep="\t", header=None, names=["GO_term", "IA"])

# Read in test data
test_sequences_fasta = list(SeqIO.parse(os.path.join(test_root, "testsuperset.fasta"), "fasta"))
test_taxonomy = pd.read_csv(os.path.join(test_root, "testsuperset-taxon-list.tsv"), sep="\t", encoding="ISO-8859-1")

In [None]:
def plot_dag(graph, term, radius=1):
    """Create a smaller subgraph with `radius`-hop neighbors given a center node `term`"""
    ng_graph = nx.ego_graph(graph, term, radius=radius)

    for node in ng_graph.nodes(data=True):
        # concatenate label of the node with its attribute
        node[1]["label"] = node[0] + " " + node[1]["name"]

    net = Network(
        directed=True,
        notebook=True,
        # cdn_resources="in_line",
        height="750px",
        width="100%",
    )
    net.from_nx(ng_graph)

    return net.show("network.html")

In [None]:
# train_sequences_fasta
print("train_sequences_fasta:")
for seq_record in train_sequences_fasta:
    print(seq_record.id)
    print(seq_record.description)
    print(repr(seq_record.seq))
    print(len(seq_record))
    break
print("=" * 60)

# train_taxonomy
print("train_taxonomy:")
train_taxonomy.head()
train_taxonomy.shape
print("=" * 60)

# train_terms
print("train_terms:")
train_terms.head()
train_terms.shape
print("=" * 60)

# go_graph
print("go_graph")
print(f"Gene Ontology: {type(go_graph)}")
print(f"Number of nodes: {go_graph.number_of_nodes()}")
print(f"Number of edges: {go_graph.number_of_edges()}")
print("=" * 60)

# information_accretion
print("information_accretion")
information_accretion.head()
information_accretion.shape
print("=" * 60)

# test_sequences_fasta
print("test_sequences_fasta")
for seq_record in test_sequences_fasta:
    print(seq_record.id)
    print(seq_record.description)
    print(repr(seq_record.seq))
    print(len(seq_record))
    break
print("=" * 60)

# test_taxonomy
print("test_taxonomy:")
test_taxonomy.head()
test_taxonomy.shape

train_sequences_fasta:
P20536
P20536 sp|P20536|UNG_VACCC Uracil-DNA glycosylase OS=Vaccinia virus (strain Copenhagen) OX=10249 GN=UNG PE=1 SV=1
Seq('MNSVTVSHAPYTITYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPDKFFIQLK...FIY')
218
train_taxonomy:


Unnamed: 0,EntryID,taxonomyID
0,Q8IXT2,9606
1,Q04418,559292
2,A8DYA3,7227
3,Q9UUI3,284812
4,Q57ZS4,185431


(142246, 2)

train_terms:


Unnamed: 0,EntryID,term,aspect
0,A0A009IHW8,GO:0008152,BPO
1,A0A009IHW8,GO:0034655,BPO
2,A0A009IHW8,GO:0072523,BPO
3,A0A009IHW8,GO:0044270,BPO
4,A0A009IHW8,GO:0006753,BPO


(5363863, 3)

go_graph
Gene Ontology: <class 'networkx.classes.multidigraph.MultiDiGraph'>
Number of nodes: 43248
Number of edges: 84805
information_accretion


Unnamed: 0,GO_term,IA
0,GO:0000001,0.0
1,GO:0000002,3.103836
2,GO:0000003,3.439404
3,GO:0000011,0.056584
4,GO:0000012,6.400377


(43248, 2)

test_sequences_fasta
Q9CQV8
Q9CQV8	10090
Seq('MTMDKSELVQKAKLAEQAERYDDMAAAMKAVTEQGHELSNEERNLLSVAYKNVV...GEN')
246
test_taxonomy:


Unnamed: 0,ID,Species
0,9606,homo sapiens[All Names]
1,10090,mus musculus[All Names]
2,10116,Rattus norvegicus
3,3702,Arabidopsis thaliana[All Names]
4,83333,Escherichia coli K-12[all names]


(90, 2)

# Explore the GO graph and NetworkX (library) functionalities

#### What is in each node (GO term)?

##### Essential elements

* `term`: The unique identifier of the GO term
* `name`: a human-readable name for the GO term
* `namespace`: the aspect - which of the three sub-ontologies (CC, BP or MF) the term belongs to.
* `def`: A textual description of what the term represents, plus reference(s) to the source of the information.
* `is_a`: Relationships to other terms in the ontology.
    * All terms (other than the root terms representing each aspect) have an *is a* sub-class relationship to another term
    * For example, *glucose transmembrane transport (GO:1904659)* is a *monosaccharide transport (GO:0015749)*.
    * The Gene Ontology employs a number of other relations; the relations documentation page describes the relations used in the ontology.

##### Optional elements

* `alt_id`: Secondary IDs come about when two or more terms are identical in meaning, and are merged into a single term. All terms IDs are preserved so that no information (for example, annotations to the merged IDs) is lost.
* `comment`: Any extra information about the term and its usage.
* `synonym`: Alternative words or phrases closely related in meaning to the term name, with indication of the relationship between the name and synonym given by the synonym scope. The scopes for GO synonyms are:
    * __Exact__: an exact equivalent; interchangeable with the term name; for e.g. ornithine cycle is an exact synonym of urea cycle
    * __Broad__: the synonym is broader than the term name; for e.g. cell division is a broad synonym of cytokinesis
    * __Narrow__: the synonym is narrower or more precise than the term name; for e.g. pyrimidine-dimer repair by photolyase is a narrow synonym of photoreactive repair
    * __Related__: the terms are related in some imprecise way; for e.g. cytochrome bc1 complex is a related synonym of ubiquinol-cytochrome-c reductase activity virulence is a related synonym of pathogenesis.
    * Custom synonym types are also used in the ontology. For example, a number of synonyms are designated as systematic synonyms; synonyms of this type are exact synonyms of the term name.
* `xref`: Database cross-references, or dbxrefs, refer to identical or very similar objects in other databases.
    * For instance, the molecular function term retinal isomerase activity (GO:0004744) is cross-referenced with RHEA:24124
    * Another example, the biological process term ulfate assimilation (GO:0000103) has the InterPro cross-reference Sulphate adenylyltransferase (IPR002650).
* `subset`: Indicates that the term belongs to a designated subset of terms, e.g. one of the [GO subsets](https://geneontology.org/docs/go-subset-guide/) (also known as GO slims). It offers an overall sense of the key biological functions that are vital to an organism.
    * More about GO subsets [here](https://geneontology.org/docs/go-subset-guide/)


More information about elements in each node can be found [here](https://geneontology.org/docs/GO-term-elements)

#### Exploration

In [None]:
def process_node(node):

    # Essential elements
    name = node.get("name", np.nan)
    namespace = node.get("namespace", np.nan)
    definition = node.get("def", np.nan)
    is_a = node.get("is_a", np.nan)

    # Optional elements
    alt_id = node.get("alt_id", np.nan)
    comment = node.get("comment", np.nan)
    subset = node.get("subset", np.nan)
    xref = node.get("xref", np.nan)

    node_row = pd.DataFrame({
        "GO_name": name,
        "GO_namespace": namespace,
        "GO_definition": definition,
        "GO_is_a": [is_a],

        "GO_alt_id": [alt_id],
        "GO_comment": comment,
        "GO_xref": [xref],
        "GO_subset": [subset],
    })

    return node_row


# An example
entry_id = "A0A009IHW8"
train_terms_subset = train_terms.query("EntryID == @entry_id").reset_index(drop=True).copy()
node_info = []
for row in train_terms_subset.itertuples(index=False):
    node_row = process_node(go_graph.nodes[row.term])
    node_info.append(node_row)

node_info = pd.concat(node_info, ignore_index=True)
pd.concat([train_terms_subset, node_info], axis=1)

Unnamed: 0,EntryID,term,aspect,GO_name,GO_namespace,GO_definition,GO_is_a,GO_alt_id,GO_comment,GO_xref,GO_subset
0,A0A009IHW8,GO:0008152,BPO,metabolic process,biological_process,"""The chemical reactions and pathways, includin...",[GO:0008150],"[GO:0044236, GO:0044710]",Note that metabolic processes do not include s...,[Wikipedia:Metabolism],"[gocheck_do_not_manually_annotate, goslim_chem..."
1,A0A009IHW8,GO:0034655,BPO,nucleobase-containing compound catabolic process,biological_process,"""The chemical reactions and pathways resulting...","[GO:0006139, GO:0019439, GO:0044270, GO:004670...",,,,[goslim_chembl]
2,A0A009IHW8,GO:0072523,BPO,purine-containing compound catabolic process,biological_process,"""The chemical reactions and pathways resulting...","[GO:0019439, GO:0044270, GO:0046700, GO:007252...",,,,
3,A0A009IHW8,GO:0044270,BPO,cellular nitrogen compound catabolic process,biological_process,"""The chemical reactions and pathways resulting...","[GO:0034641, GO:0044248]",,,,
4,A0A009IHW8,GO:0006753,BPO,nucleoside phosphate metabolic process,biological_process,"""The chemical reactions and pathways involving...","[GO:0006796, GO:0019637, GO:0055086]",,,,
5,A0A009IHW8,GO:1901292,BPO,nucleoside phosphate catabolic process,biological_process,"""The chemical reactions and pathways resulting...","[GO:0006753, GO:0034655, GO:0046434]",,,,
6,A0A009IHW8,GO:0044237,BPO,cellular metabolic process,biological_process,"""The chemical reactions and pathways by which ...","[GO:0008152, GO:0009987]",,,,[gocheck_do_not_annotate]
7,A0A009IHW8,GO:1901360,BPO,organic cyclic compound metabolic process,biological_process,"""The chemical reactions and pathways involving...",[GO:0071704],,,,
8,A0A009IHW8,GO:0008150,BPO,biological_process,biological_process,"""A biological process represents a specific ob...",,"[GO:0000004, GO:0007582, GO:0044699]","Note that, in addition to forming the root of ...",[Wikipedia:Biological_process],"[goslim_candida, goslim_chembl, goslim_metagen..."
9,A0A009IHW8,GO:1901564,BPO,organonitrogen compound metabolic process,biological_process,"""The chemical reactions and pathways involving...","[GO:0006807, GO:0071704]",,,,


In [None]:
# Visualize a subgraph (with the center being a given term)
term = "GO:0034655"
display(go_graph.nodes[term])
# display(go_graph.nodes["GO:0048308"])
# display(go_graph.nodes["GO:0006996"])
# display(go_graph.nodes["GO:0016043"])
# display(go_graph.nodes["GO:0071840"])
# display(go_graph.nodes["GO:0009987"])

# plot_dag(go_graph, term, radius=5)

{'name': 'nucleobase-containing compound catabolic process',
 'namespace': 'biological_process',
 'def': '"The chemical reactions and pathways resulting in the breakdown of nucleobases, nucleosides, nucleotides and nucleic acids." [GOC:mah]',
 'subset': ['goslim_chembl'],
 'synonym': ['"nucleobase, nucleoside, nucleotide and nucleic acid breakdown" EXACT []',
  '"nucleobase, nucleoside, nucleotide and nucleic acid catabolic process" RELATED [GOC:dph, GOC:tb]',
  '"nucleobase, nucleoside, nucleotide and nucleic acid catabolism" EXACT []',
  '"nucleobase, nucleoside, nucleotide and nucleic acid degradation" EXACT []'],
 'is_a': ['GO:0006139',
  'GO:0019439',
  'GO:0044270',
  'GO:0046700',
  'GO:1901361']}

In [None]:
# Get all nodes
pd.DataFrame(go_graph.nodes, columns=["node"])

Unnamed: 0,node
0,GO:0000001
1,GO:0000002
2,GO:0000003
3,GO:0000006
4,GO:0000007
...,...
43243,GO:2001313
43244,GO:2001314
43245,GO:2001315
43246,GO:2001316


In [None]:
# Get neighbors, neighboring edges (iterator)
list(go_graph.adjacency())[10:18] # node, adjacency

# Get neighbors of a term
list(go_graph.neighbors(term))

# Get a subgraph with n-hop neighbors
nx.ego_graph(go_graph, term, radius=3)

[('GO:0000015', {'GO:1902494': {'is_a': {}}, 'GO:0005829': {'part_of': {}}}),
 ('GO:0000016', {'GO:0004553': {'is_a': {}}}),
 ('GO:0000017', {'GO:0042946': {'is_a': {}}}),
 ('GO:0000018', {'GO:0051052': {'is_a': {}}, 'GO:0006310': {'regulates': {}}}),
 ('GO:0000019', {'GO:0000018': {'is_a': {}}, 'GO:0006312': {'regulates': {}}}),
 ('GO:0000022',
  {'GO:0051231': {'is_a': {}},
   'GO:1903047': {'is_a': {}},
   'GO:0000070': {'part_of': {}},
   'GO:0007052': {'part_of': {}}}),
 ('GO:0000023', {'GO:0005984': {'is_a': {}}}),
 ('GO:0000024', {'GO:0000023': {'is_a': {}}, 'GO:0046351': {'is_a': {}}})]

['GO:0006139', 'GO:0019439', 'GO:0044270', 'GO:0046700', 'GO:1901361']

<networkx.classes.multidigraph.MultiDiGraph at 0x2c1bc3700>

In [None]:
# Get out-degrees (num of parents)
pd.DataFrame(go_graph.degree, columns=["node", "out_degrees"]) # number of parents

Unnamed: 0,node,out_degrees
0,GO:0000001,2
1,GO:0000002,2
2,GO:0000003,8
3,GO:0000006,1
4,GO:0000007,1
...,...,...
43243,GO:2001313,5
43244,GO:2001314,3
43245,GO:2001315,3
43246,GO:2001316,6


In [None]:
# Types of edges
edges = pd.DataFrame(go_graph.edges, columns=["child", "parent", "edge"])
edges.edge.value_counts(dropna=False)

edge
is_a                    69350
part_of                  6851
regulates                3157
negatively_regulates     2729
positively_regulates     2718
Name: count, dtype: int64

In [None]:
# Get shortest path between nodes
shortest_path_lengths = dict(nx.all_pairs_shortest_path_length(go_graph))
# shortest_path_lengths["GO:0000001"]["GO:0009987"] # shortest path between child: 1 and parent: 9987
# shortest_path_lengths["GO:0000001"]["GO:0000002"] # KeyError if no path exists between them

def get_shortest_path(child, parent):
    """Return the number of hops to get from `child` to `parent`"""
    try:
        return shortest_path_lengths[child][parent]  # shortest path between child and parent
    except KeyError:
        return None                                  # KeyError if no path exists between them
    except Exception as e:
        raise e

print(get_shortest_path("GO:0000001", "GO:0009987"))
print(get_shortest_path("GO:0000001", "GO:0000002"))

5
None


In [None]:
# Get subontology graphs
subontology_roots = {
    'BPO':'GO:0008150',  # Biological Process (BP)
    'CCO':'GO:0005575',  # Cellular Component (CC)
    'MFO':'GO:0003674',  # Molecular Function (MF)
}

BP_subgraph = nx.induced_subgraph(go_graph, nx.ancestors(go_graph, source=subontology_roots["BPO"]))
CC_subgraph = nx.induced_subgraph(go_graph, nx.ancestors(go_graph, source=subontology_roots["CCO"]))
MF_subgraph = nx.induced_subgraph(go_graph, nx.ancestors(go_graph, source=subontology_roots["MFO"]))

print("BP_subgraph: #nodes = {:,}, #edges = {:,}".format(BP_subgraph.number_of_nodes(), BP_subgraph.number_of_edges()))
print("CC_subgraph: #nodes = {:>6,}, #edges = {:>6,}".format(CC_subgraph.number_of_nodes(), CC_subgraph.number_of_edges()))
print("MF_subgraph: #nodes = {:,}, #edges = {:,}".format(MF_subgraph.number_of_nodes(), MF_subgraph.number_of_edges()))

# pd.DataFrame(BP_subgraph.edges, columns=["child", "parent", "edge"])
# pd.DataFrame(CC_subgraph.edges, columns=["child", "parent", "edge"])
# pd.DataFrame(MF_subgraph.edges, columns=["child", "parent", "edge"])

BP_subgraph: #nodes = 27,941, #edges = 64,536
CC_subgraph: #nodes =  4,042, #edges =  6,495
MF_subgraph: #nodes = 11,262, #edges = 13,714


In [None]:
nx.number_connected_components(go_graph.to_undirected())

3

In [None]:
for i, pairs in enumerate(nx.all_pairs_lowest_common_ancestor(go_graph)):
    print(pairs)
    if i == 10:
        break

(('GO:0000001', 'GO:0000001'), 'GO:0000001')
(('GO:0000001', 'GO:0006996'), 'GO:0000001')
(('GO:0000001', 'GO:0007005'), 'GO:0000001')
(('GO:0000001', 'GO:0008150'), 'GO:0000001')
(('GO:0000001', 'GO:0009987'), 'GO:0000001')
(('GO:0000001', 'GO:0016043'), 'GO:0000001')
(('GO:0000001', 'GO:0048308'), 'GO:0000001')
(('GO:0000001', 'GO:0048311'), 'GO:0000001')
(('GO:0000001', 'GO:0051179'), 'GO:0000001')
(('GO:0000001', 'GO:0051640'), 'GO:0000001')
(('GO:0000001', 'GO:0051646'), 'GO:0000001')


In [None]:
# Check if is DAG
nx.is_directed_acyclic_graph(go_graph)

True

In [None]:
# Topological sort of the graphs
# Kahn’s algorithm: It first finds a list of “start nodes” which have no incoming edges
BP_top_sorts = list(nx.topological_generations(BP_subgraph))
CC_top_sorts = list(nx.topological_generations(CC_subgraph))
MF_top_sorts = list(nx.topological_generations(MF_subgraph))

len(BP_top_sorts)
len(CC_top_sorts)
len(MF_top_sorts)

18

14

12

In [None]:
# Degree Centrality - % of neighbors out of all nodes
GO_degree_centrality = nx.centrality.degree_centrality(go_graph)

# Top 5 nodes with highest degree centrality
(sorted(GO_degree_centrality.items(), key=lambda item: item[1], reverse=True))[:5]

[('GO:0110165', 0.009827271255809652),
 ('GO:0016616', 0.008069923925358985),
 ('GO:0016709', 0.006775036418711125),
 ('GO:0032991', 0.006358822577288598),
 ('GO:0016758', 0.004832705158739335)]

In [None]:
# EXTREMELY SLOW!!
# Betweenness centrality - the number of times a node lies on the shortest path between other nodes, meaning it acts as a bridge
# GO_betweenness_centrality = nx.centrality.betweenness_centrality(go_graph)

# Top 5 nodes with highest betweenness centrality
# (sorted(GO_betweenness_centrality.items(), key=lambda item: item[1], reverse=True))[:5]

In [None]:
# Closeness centrality - their ‘closeness’ to all other nodes in the network
# the higher the closeness centrality of a node, the closer it is located to the center of the network.
closeness_centrality = nx.centrality.closeness_centrality(go_graph)

# Top 5 nodes with highest closeness centrality
(sorted(closeness_centrality.items(), key=lambda item: item[1], reverse=True))[:8]

[('GO:0008150', 0.13249641986781371),
 ('GO:0009987', 0.09177009024191234),
 ('GO:0050789', 0.05450615735840855),
 ('GO:0003674', 0.052467978206565916),
 ('GO:0065007', 0.05023079543531534),
 ('GO:0008152', 0.04454669778432706),
 ('GO:0003824', 0.0426376148572408),
 ('GO:0071704', 0.04236188137222803)]

In [None]:
# Eigenvector Centrality - how connected a node is to other important nodes in the network
# A high eigenvector centrality means that the node is connected to other nodes who themselves have high eigenvector centralities.
GO_eigenvector_centrality = nx.centrality.eigenvector_centrality(nx.DiGraph(go_graph))

# Top 5 nodes with highest eigenvector centrality
(sorted(GO_eigenvector_centrality.items(), key=lambda item: item[1], reverse=True))[:5]

[('GO:0008150', 0.9570036605845309),
 ('GO:0008152', 0.23355013653192808),
 ('GO:0009987', 0.13004250621401445),
 ('GO:0065007', 0.05374101319562961),
 ('GO:0032502', 0.05273139959773638)]

In [None]:
# The clustering coefficient of a node is defined as the probability that two randomly selected friends of are friends with each other.
# As a result, the average clustering coefficient is the average of clustering coefficients of all the nodes.
nx.average_clustering(nx.DiGraph(BP_subgraph))
nx.average_clustering(nx.DiGraph(CC_subgraph))
nx.average_clustering(nx.DiGraph(MF_subgraph))

0.04148416899611462

0.02700774532392186

0.0002592562235609651

In [None]:
# Bridges: Deleting the edge would cause A and B to lie in two different components
def check_bridges(graph):
    if not nx.has_bridges(graph.to_undirected()):
        print("No bridges in this graph")
        return

    bridges = nx.bridges(graph.to_undirected())
    num_bridges = len(list(bridges))
    print(f"Number of bridges: {num_bridges:,}")
    print(f"Total number of edges: {graph.number_of_edges():,}")
    print(f"% of bridges: {num_bridges / graph.number_of_edges():%}")
    print()

check_bridges(go_graph)

check_bridges(BP_subgraph)
check_bridges(CC_subgraph)
check_bridges(MF_subgraph)

Number of bridges: 12,605
Total number of edges: 84,805
% of bridges: 14.863510%

Number of bridges: 2,753
Total number of edges: 64,536
% of bridges: 4.265836%

Number of bridges: 1,484
Total number of edges: 6,495
% of bridges: 22.848345%

Number of bridges: 8,374
Total number of edges: 13,714
% of bridges: 61.061689%



In [None]:
# Assortativity
# the preference for a network’s nodes to attach to others that are "similar" in some way.
# "Similar" here means having the same degree
nx.degree_assortativity_coefficient(go_graph)

nx.degree_assortativity_coefficient(BP_subgraph)
nx.degree_assortativity_coefficient(CC_subgraph)
nx.degree_assortativity_coefficient(MF_subgraph)

-0.12418034755903223

0.11234612296704903

-0.02642072932299072

-0.1843306292754108

In [None]:
# Network community
# A community is a group of nodes, so that nodes inside the group are connected with many more edges than between groups.
communities = nx.community.label_propagation_communities(go_graph.to_undirected())

len(communities) # number of communities
np.mean([len(comm) for comm in communities]) # average number of members in each community

6590

6.562670713201821

# Preprocessing

## Preprocess Protein Sequences

In [None]:
def create_sequence_dataframe_from_fasta(fasta_sequences: list[Bio.SeqRecord.SeqRecord]) -> pd.DataFrame:
    """Create DataFrame for train/test sequences"""

    seq_record_dict = {
        "EntryID": [],
        "description": [],
        "sequence": [],
    }
    for seq_record in fasta_sequences:
        seq_record_dict["EntryID"].append(seq_record.id)
        seq_record_dict["description"].append(seq_record.description)
        seq_record_dict["sequence"].append(str(seq_record.seq))

    return pd.DataFrame(seq_record_dict)

def augment_features_from_fasta_description(sequence_dataframe: pd.DataFrame) -> pd.DataFrame:
    """Extract features from the fasta header text"""
    sequence_dataframe = sequence_dataframe.copy()

    # Extract information from fasta headers
    description_parts = sequence_dataframe.description.str.split(r"\|", n=2, expand=True)
    extracted_groups = sequence_dataframe.description.str.extract(r"([A-Z0-9]+)_([A-Z0-9]+) (.*?) OS=")

    # Create features
    sequence_dataframe = sequence_dataframe.assign(
        db=description_parts[0].str[-2:],
        description=description_parts[2],

        entry_name_prefix=extracted_groups[0],
        entry_name_suffix=extracted_groups[1],
        protein_name=extracted_groups[2],

        organism_name=train_sequences.description.str.extract(r"OS=(.+) OX="),            # The source organism's name
        organism_id=train_sequences.description.str.extract(r"OX=(.+?) ").astype(float),  # The source organism's id
        gene_name=train_sequences.description.str.extract(r"GN=(.+?) "),                  # Optional: Might be empty
        protein_existence=train_sequences.description.str.extract(r"PE=(.+?) ").astype(float),
        sequence_version=train_sequences.description.str.extract(r"SV=(.+?)$").astype(float),
    )

    # Drop unneeded column
    return sequence_dataframe.drop(columns="description")

In [None]:
# Preprocess the train sequences
train_sequences = create_sequence_dataframe_from_fasta(train_sequences_fasta)
train_sequences = augment_features_from_fasta_description(train_sequences)
train_sequences = train_sequences.merge(train_taxonomy, on="EntryID", how="left")

train_sequences.shape
train_sequences.head()

(142246, 11)

Unnamed: 0,EntryID,sequence,db,entry_name_prefix,entry_name_suffix,protein_name,organism_id,gene_name,protein_existence,sequence_version,taxonomyID
0,P20536,MNSVTVSHAPYTITYHDDWEPVMSQLVEFYNEVASWLLRDETSPIP...,sp,UNG,VACCC,Uracil-DNA glycosylase,10249.0,UNG,1.0,1.0,10249
1,O73864,MTEYRNFLLLFITSLSVIYPCTGISWLGLTINGSSVGWNQTHHCKL...,sp,WNT11,DANRE,Protein Wnt-11,7955.0,wnt11,2.0,1.0,7955
2,O95231,MRLSSSPPRGPQQLSSFGSVDWLSQSSCSGPTHTPRPADFSLGSLP...,sp,VENTX,HUMAN,Homeobox protein VENTX,9606.0,VENTX,1.0,1.0,9606
3,A0A0B4J1F4,MGGEAGADGPRGRVKSLGLVFEDESKGCYSSGETVAGHVLLEAAEP...,sp,ARRD4,MOUSE,Arrestin domain-containing protein 4,10090.0,Arrdc4,1.0,1.0,10090
4,P54366,MVETNSPPAGYTLKRSPSDLGEQQQPPRQISRSPGNTAAYHLTTAM...,sp,GSC,DROME,Homeobox protein goosecoid,7227.0,Gsc,2.0,2.0,7227


In [None]:
# Unfortunately, we have to discard most of the features because the test data does not have those information in its fasta header
train_sequences = train_sequences[["EntryID", "sequence", "organism_id", "taxonomyID"]].copy()

In [None]:
# Preprocess the test sequences
test_sequences = create_sequence_dataframe_from_fasta(test_sequences_fasta)
test_sequences = (test_sequences
    .assign(organism_id=test_sequences.description.str.split("\t", expand=True)[1])
    .drop(columns="description")
)
test_sequences = test_sequences.merge(train_taxonomy, on="EntryID", how="left")

test_sequences.shape
test_sequences.head()

(141865, 4)

Unnamed: 0,EntryID,sequence,organism_id,taxonomyID
0,Q9CQV8,MTMDKSELVQKAKLAEQAERYDDMAAAMKAVTEQGHELSNEERNLL...,10090,10090.0
1,P62259,MDDREDLVYQAKLAEQAERYDEMVESMKKVAGMDVELTVEERNLLS...,10090,10090.0
2,P68510,MGDREQLLQRARLAEQAERYDDMASAMKAVTELNEPLSNEDRNLLS...,10090,10090.0
3,P61982,MVDREQLVQKARLAEQAERYDDMAAAMKNVTELNEPLSNEERNLLS...,10090,10090.0
4,O70456,MERASLIQKAKLAEQAERYEDMAAFMKSAVEKGEELSCEERNLLSV...,10090,10090.0


In [None]:
# Note: The train_sequence and train_terms contains the same set of IDs!
# set(train_sequences.EntryID) == set(train_terms.EntryID) # True

In [None]:
# Labels of the train sequences
train_terms.shape
train_terms.head()

(5363863, 3)

Unnamed: 0,EntryID,term,aspect
0,A0A009IHW8,GO:0008152,BPO
1,A0A009IHW8,GO:0034655,BPO
2,A0A009IHW8,GO:0072523,BPO
3,A0A009IHW8,GO:0044270,BPO
4,A0A009IHW8,GO:0006753,BPO


## Preprocess GO terms

In [None]:
# Encode texts (example)
inputs = biogpt_tokenizer(["Hello, my dog is cute", "Algorithms in Structural Bioinformatics"], return_tensors="pt", padding=True)
outputs = biogpt_model(**inputs)

last_hidden_states = outputs.last_hidden_state
last_hidden_states.shape

torch.Size([2, 10, 1024])

In [None]:
go_graph.

<networkx.classes.multidigraph.MultiDiGraph at 0x16d25f040>

# Feature Extraction Components

### Naive
The score that a protein $P_j$ is associated with a GO term $G_i$ is defined as $G_i$'s relative frequency in the training data.

In [None]:
# Naive
naive_component = train_terms["term"].value_counts(dropna=False, normalize=True)
naive_component

term
GO:0005575    1.732184e-02
GO:0008150    1.719097e-02
GO:0110165    1.701870e-02
GO:0003674    1.466052e-02
GO:0005622    1.319665e-02
                  ...     
GO:0031772    1.864328e-07
GO:0042324    1.864328e-07
GO:0031771    1.864328e-07
GO:0051041    1.864328e-07
GO:0102628    1.864328e-07
Name: proportion, Length: 31466, dtype: float64

# Baseline Model

In [None]:
# Setup device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
display(DEVICE)

# Fix random seed for reproducibility
SEED = 42
def set_seed(seed: int=SEED) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return None

set_seed()

device(type='cpu')


In [None]:
train_sequences.sequence

Unnamed: 0,EntryID,sequence,organism_id
0,P20536,MNSVTVSHAPYTITYHDDWEPVMSQLVEFYNEVASWLLRDETSPIP...,10249.0
1,O73864,MTEYRNFLLLFITSLSVIYPCTGISWLGLTINGSSVGWNQTHHCKL...,7955.0
2,O95231,MRLSSSPPRGPQQLSSFGSVDWLSQSSCSGPTHTPRPADFSLGSLP...,9606.0
3,A0A0B4J1F4,MGGEAGADGPRGRVKSLGLVFEDESKGCYSSGETVAGHVLLEAAEP...,10090.0
4,P54366,MVETNSPPAGYTLKRSPSDLGEQQQPPRQISRSPGNTAAYHLTTAM...,7227.0
...,...,...,...
142241,A0A286YAI0,METEVDDFPGKASIFSQVNPLYSNNMKLCEAERYDFQHSEPKTMKS...,7955.0
142242,A0A1D5NUC4,MSAAASAEMIETPPVLNFEEIDYKEIEVEEVVGRGAFGVVCKAKWR...,9031.0
142243,Q5RGB0,MADKGPILTSVIIFYLSIGAAIFQILEEPNLNSAVDDYKNKTNNLL...,7955.0
142244,A0A2R8QMZ5,MGRKKIQITRIMDERNRQVTFTKRKFGLMKKAYELSVLCDCEIALI...,7955.0
