In [1]:
import numpy as np
import pandas as pd
import os 
import networkx as nx
import obonet
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio import SeqIO
from collections import defaultdict
from sklearn.model_selection import ShuffleSplit
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
from collections import Counter
import subprocess
import requests, sys
import json

In [2]:
np.random.seed(42)

# Read in the raw datasets

In [3]:
import os

# without this mmseqs isn't recognized as a module

os.environ["PATH"] = (
    "/shared/ifbstor1/projects/deepmar/conda_environments/conda_envs/PlasmoFP/bin:"
    + os.environ["PATH"]
)

In [5]:
TSV_path = "../uniprot_taxonomy_bacteria_20_11_2025.tsv"

#read in the datasets
uniprot = pd.read_csv(TSV_path, sep="\t")
print("Uniprot Data:", uniprot.shape)

Uniprot Data: (329022, 9)


In [6]:
def create_seq_records(data):
    seq_records = []
    for index, row in data.iterrows():
        seq_record = SeqRecord(Seq(row['Sequence']), id=row['Entry'], description=row['Organism'])
        seq_records.append(seq_record)
    assert len(seq_records) == data.shape[0]
    return seq_records

In [7]:
uniprot["GOAssertion"] = "AA"

In [8]:
total_bacteria_function = uniprot[uniprot["Gene Ontology (molecular function)"].notnull()]

print(total_bacteria_function.shape)

(301300, 10)


In [9]:
total_bacteria_records = create_seq_records(uniprot)
print(len(total_bacteria_records))

329022


In [10]:
SeqIO.write(total_bacteria_records, "total_bacteria.fasta", "fasta")

329022

In [11]:
seq_records_bacteria_function = create_seq_records(total_bacteria_function)

assert len(seq_records_bacteria_function) == total_bacteria_function.shape[0]

In [12]:
SeqIO.write(seq_records_bacteria_function, "total_bacteria_function_for_reduce.fasta", "fasta")

301300

In [13]:
import shutil
shutil.which("mmseqs")

'/shared/ifbstor1/projects/deepmar/conda_environments/conda_envs/PlasmoFP/bin/mmseqs'

In [13]:
! mkdir reduced_90

In [None]:
subprocess.run([
    "./mmseq_cluster.sh",
    "/shared/projects/deepmar/data/bacteria_sequences/function_only/reduced_90",
     "/shared/projects/deepmar/data/bacteria_sequences/function_only/total_bacteria_function_for_reduce.fasta",
    "/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta",    
    "/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta",
     "0.9"]
, check=True)

Creating databases...
function/function_redundant_db exists and will be overwritten
createdb /shared/projects/deepmar/data/bacteria_sequences/function_only/total_bacteria_function_for_reduce.fasta function/function_redundant_db 

MMseqs Version:                    	18.8cc5c
Database type                      	0
Shuffle input database             	true
Createdb mode                      	0
Write lookup file                  	1
Offset of numeric ids              	0
Threads                            	256
Compressed                         	0
Mask residues                      	0
Mask residues probability          	0.9
Mask lower case residues           	0
Mask lower letter repeating N times	0
Use GPU                            	0
Verbosity                          	3

Converting sequences
Time for merging to function_redundant_db_h: 0h 0m 0s 113ms
Time for merging to function_redundant_db: 0h 0m 0s 221ms
Database type: Aminoacid
Time for processing: 0h 0m 1s 135ms
component/component_red

CompletedProcess(args=['/shared/projects/deepmar/PlasmoFP_public/scripts/mmseq_cluster.sh', '/shared/projects/deepmar/data/bacteria_sequences/function_only/reduced_90', '/shared/projects/deepmar/data/bacteria_sequences/function_only/total_bacteria_function_for_reduce.fasta', '/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta', '/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta', '0.9'], returncode=0)

# Pick cluster representatives from reduced dataset (0.9)

In [14]:
bacteria_function_clusters = pd.read_csv('reduced_90/function/function_clusters_redundant.tsv', sep='\t', header=None)
bacteria_function_clusters.columns = ["Cluster", "Entry"]

print(bacteria_function_clusters.shape)

assert bacteria_function_clusters["Entry"].isin(total_bacteria_function["Entry"]).all()

(301300, 2)


In [15]:
total_bacteria_function_merged = pd.merge(total_bacteria_function, bacteria_function_clusters, on="Entry") 
total_bacteria_function_merged_grouped = total_bacteria_function_merged.groupby("Cluster")

total_bacteria_function_GO = total_bacteria_function_merged_grouped.apply(lambda x: x.loc[x['Gene Ontology (molecular function)'].str.findall(r"GO:\d+").str.len().idxmax()])

print(total_bacteria_function_GO.shape)

total_bacteria_function_GO_terms = total_bacteria_function_GO["Gene Ontology (molecular function)"].str.findall(r"GO:\d+")

(150860, 11)


  total_bacteria_function_GO = total_bacteria_function_merged_grouped.apply(lambda x: x.loc[x['Gene Ontology (molecular function)'].str.findall(r"GO:\d+").str.len().idxmax()])


In [16]:
print(total_bacteria_function_GO.shape)

(150860, 11)


In [17]:
seq_records_bacteria_function_GO = create_seq_records(total_bacteria_function_GO)

assert len(seq_records_bacteria_function_GO) == total_bacteria_function_GO.shape[0]

function_path = "reduced_90_function.fasta"

SeqIO.write(seq_records_bacteria_function_GO, function_path, "fasta")

total_bacteria_function_GO.to_csv("reduced_90_function.tsv", sep="\t", index=False)

with open('reduced_90_expected_shapes.txt', 'w') as f:
    print("Function:", total_bacteria_function_GO.shape, file=f)

# Propagate

In [18]:
total_bacteria_function_GO_raw_terms = total_bacteria_function_GO['Gene Ontology (molecular function)'].str.findall(r"GO:\d+")

total_bacteria_function_GO['Raw GO terms'] = total_bacteria_function_GO_raw_terms

In [19]:
def get_unique_go_terms(go_terms):
    return set([term for terms in go_terms for term in terms])

unique_go_terms_function = get_unique_go_terms(total_bacteria_function_GO_raw_terms)

print("Unique GO terms in function:", len(unique_go_terms_function))

#conver to lists 
unique_go_terms_function = list(unique_go_terms_function)

Unique GO terms in function: 3643


In [20]:
obo_file_path = "go-basic.obo"
graph = obonet.read_obo(obo_file_path)


remove_edges = [(i, j, k) for i, j, k in graph.edges if not(k=="is_a" or k=="part_of")]

graph.remove_edges_from(remove_edges)

    
# There should not be any cross-ontology edges, but we verify here
crossont_edges = [(i, j, k) for i, j, k in graph.edges if
                      graph.nodes[i]['namespace']!= graph.nodes[j]['namespace']]
if len(crossont_edges)>0:
        graph.remove_edges_from(crossont_edges)

        
graph_rev = graph.reverse(copy=True) # so edges point towards children
all_go_terms_function = []

def get_ancestors_from_obo(go_terms_batch, graph):
    """
    Returns ancestors for a batch of GO terms from the OBO graph.
    """
    ancestors_dict_batch = {}
    json_dict_batch = {}
    all_go_terms_batch = set(go_terms_batch)

    for term in go_terms_batch:
        if term not in graph:
            ancestors_dict_batch[term] = None
            json_dict_batch[term] = {"is_obsolete":True} # just assume this is the case
            continue

        ancestors = nx.ancestors(graph, term)
        ancestors_dict_batch[term] = list(ancestors) if ancestors else None
        
        node_data = graph.nodes[term]
        is_obsolete = node_data.get("is_obsolete") == "true"

        json_dict_batch[term] = {
            "id": term,
            "is_obsolete": is_obsolete,
            "ancestors": [{"id": a, "distance": nx.shortest_path_length(graph, source=a, target=term)} for a in ancestors]
        }

        all_go_terms_batch.update(ancestors)

    return json_dict_batch, ancestors_dict_batch, all_go_terms_batch

batch_size = 50  

def process_go_terms_in_batches_obo(go_terms, graph, batch_size=20):
    json_dict = {}
    ancestors_dict = {}
    all_go_terms = set()

    for i in tqdm(range(0, len(go_terms), batch_size)):
        batch = go_terms[i:i + batch_size]
        json_data_batch, ancestors_dict_batch, all_go_terms_batch = get_ancestors_from_obo(batch, graph)

        if json_data_batch and ancestors_dict_batch:
            json_dict.update(json_data_batch)
            ancestors_dict.update(ancestors_dict_batch)
            all_go_terms.update(all_go_terms_batch)

    return json_dict, ancestors_dict, all_go_terms


function_json_dict, function_ancestors_dict, function_all_go_terms = process_go_terms_in_batches_obo(unique_go_terms_function, graph_rev, batch_size=batch_size)

with open('function_data_ancestors_dict.json', 'w') as f:
    json.dump(function_ancestors_dict, f, indent=4)
with open('function_data_json_dict.json', 'w') as f:
    json.dump(function_json_dict, f, indent=4)



assert len(function_json_dict) == len(unique_go_terms_function)
assert len(function_ancestors_dict) == len(unique_go_terms_function)

100%|██████████| 73/73 [00:00<00:00, 241.76it/s]


In [21]:
function_OBSELETE = [term for term in function_json_dict if function_json_dict[term]["is_obsolete"]]

print("Function Obsolete:", len(function_OBSELETE))

Function Obsolete: 22


In [22]:
def add_ancestors(go_terms_list, ancestors_dict):
    all_ancestors = []
    
    for term_list in tqdm(go_terms_list, desc="Adding ancestors"):
        terms_return = set() 
        for t in term_list:
            # Check if the term exists in the dictionary and has valid ancestors
            if t in ancestors_dict and ancestors_dict[t] is not None:
                # Add ancestors, excluding the original term itself
                terms_return.update(a for a in ancestors_dict[t] if a != t)
                
            # I (Helena) don't think it's correct to exclude the original term
            # I couldn't see a place where they were added back in
            terms_return.update([t])
        
        # Append the filtered ancestors found for this protein (as a list)
        all_ancestors.append(list(terms_return))

    return all_ancestors

total_bacteria_function_GO_raw_terms_propogated = add_ancestors(list(total_bacteria_function_GO_raw_terms), function_ancestors_dict)

Adding ancestors: 100%|██████████| 150860/150860 [00:00<00:00, 227813.67it/s]


In [23]:
non_propogated_function_terms = [i for i, x in enumerate(total_bacteria_function_GO_raw_terms_propogated) if len(x) <= 1]

print("Non-propogated function terms:", len(non_propogated_function_terms))
# these are obsolete go terms.
total_bacteria_function_GO.iloc[non_propogated_function_terms]['Gene Ontology (molecular function)'].unique()

Non-propogated function terms: 74


array(['citrate (Si)-synthase activity [GO:0004108]',
       'aldo-keto reductase (NADPH) activity [GO:0004033]',
       'electron transporter, transferring electrons within cytochrome b6/f complex of photosystem II activity [GO:0045158]',
       'citrate (Re)-synthase activity [GO:0050450]',
       'androsterone dehydrogenase (B-specific) activity [GO:0047042]',
       'organic anion transmembrane transporter activity [GO:0008514]',
       'electron transporter, transferring electrons within the cyclic electron transport pathway of photosynthesis activity [GO:0045156]'],
      dtype=object)

In [24]:
total_bacteria_function_GO['Raw propagated GO terms'] = total_bacteria_function_GO_raw_terms_propogated

In [25]:
total_bacteria_function_GO = total_bacteria_function_GO[total_bacteria_function_GO['Raw propagated GO terms'].str.len() > 0]

In [26]:
seq_records_bacteria_function_GO_propagated = create_seq_records(total_bacteria_function_GO)

assert len(seq_records_bacteria_function_GO_propagated) == total_bacteria_function_GO.shape[0]

function_path = "reduced_90_function_propagated.fasta"

SeqIO.write(seq_records_bacteria_function_GO_propagated, function_path, "fasta")

total_bacteria_function_GO.to_csv("reduced_90_function_propagated.tsv", sep="\t", index=False)

with open('reduced_90_expected_shapes_propagated.txt', 'w') as f:
    print("Function:", total_bacteria_function_GO.shape, file=f)

Now we can preform the sequence similarity splits for the train, dev, and test sets

In [None]:
! mkdir reduced_90_similarity_30

In [None]:
subprocess.run(["./mmseq_cluster.sh",
    "/shared/projects/deepmar/data/bacteria_sequences/function_only/reduced_90_similarity_30",
    "/shared/projects/deepmar/data/bacteria_sequences/function_only/reduced_90_function_propagated.fasta",
    "/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta",
    "/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta",
    "0.3"]
)

Creating databases...
createdb /shared/projects/deepmar/data/bacteria_sequences/function_only/reduced_90_function_propagated.fasta function/function_redundant_db 

MMseqs Version:                    	18.8cc5c
Database type                      	0
Shuffle input database             	true
Createdb mode                      	0
Write lookup file                  	1
Offset of numeric ids              	0
Threads                            	256
Compressed                         	0
Mask residues                      	0
Mask residues probability          	0.9
Mask lower case residues           	0
Mask lower letter repeating N times	0
Use GPU                            	0
Verbosity                          	3

Converting sequences
Time for merging to function_redundant_db_h: 0h 0m 0s 112ms
Time for merging to function_redundant_db: 0h 0m 0s 146ms
Database type: Aminoacid
Time for processing: 0h 0m 0s 955ms
createdb /shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta compo

CompletedProcess(args=['/shared/projects/deepmar/PlasmoFP_public/scripts/mmseq_cluster.sh', '/shared/projects/deepmar/data/bacteria_sequences/function_only/reduced_90_similarity_30', '/shared/projects/deepmar/data/bacteria_sequences/function_only/reduced_90_function_propagated.fasta', '/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta', '/shared/projects/deepmar/data/bacteria_sequences/function_only/empty.fasta', '0.3'], returncode=0)

# Train, val, test split

In [27]:
function_mmseq_out = pd.read_csv("reduced_90_similarity_30/function/function_clusters_redundant.tsv", sep='\t', header=None)
function_mmseq_out.columns = ['Cluster_2', 'Entry']


print("There are {} clusters in the function dataset.".format(function_mmseq_out['Cluster_2'].nunique()))

assert len(function_mmseq_out) == len(total_bacteria_function_GO)

There are 9826 clusters in the function dataset.


In [28]:
#merge with the original datasets
function_merged = pd.merge(function_mmseq_out, total_bacteria_function_GO, on='Entry')

assert len(function_merged) == len(function_mmseq_out)

In [30]:
def prepare_data_and_split(go_terms_df, clusters_col='NA', terms_col='NA'):
    cluster_to_terms = defaultdict(list)
    term_to_clusters = defaultdict(set)

    for _, row in tqdm(go_terms_df.iterrows(), total=go_terms_df.shape[0], desc="Processing rows"):
        cluster, terms = row[clusters_col], row[terms_col]
        cluster_to_terms[cluster].extend(terms)
        for term in terms:
            term_to_clusters[term].add(cluster)

    print("Identifying unique and multi-occurrence clusters")

    unique_clusters = {cluster for cluster, terms in cluster_to_terms.items()
                       if all(len(term_to_clusters[term]) == 1 for term in terms)}
    multi_clusters = set(cluster_to_terms.keys()) - unique_clusters

    mlb = MultiLabelBinarizer()
    all_terms = [terms for terms in cluster_to_terms.values()]
    mlb.fit(all_terms)

    print("ShuffleSplit for training, validation, and test sets from multi occurrence clusters")
    ss = ShuffleSplit(n_splits=1, test_size=0.1, random_state=38)
    multi_clusters_list = list(multi_clusters)
    train_val_idx, test_idx = next(ss.split(multi_clusters_list))

    train_val_clusters = [multi_clusters_list[i] for i in train_val_idx]
    ss_train_val = ShuffleSplit(n_splits=1, test_size=0.1, random_state=38)
    train_idx, val_idx = next(ss_train_val.split(train_val_clusters))

    train_clusters = [train_val_clusters[i] for i in train_idx] + list(unique_clusters)
    val_clusters = [train_val_clusters[i] for i in val_idx]
    test_clusters = [multi_clusters_list[i] for i in test_idx]

    print("Extracting rows for each set")
    train_df = go_terms_df[go_terms_df[clusters_col].isin(train_clusters)]
    val_df = go_terms_df[go_terms_df[clusters_col].isin(val_clusters)]
    test_df = go_terms_df[go_terms_df[clusters_col].isin(test_clusters)]

    train_terms = set(sum(train_df[terms_col].tolist(), []))
    val_terms = set(sum(val_df[terms_col].tolist(), []))
    test_terms = set(sum(test_df[terms_col].tolist(), []))

    missing_val_terms = val_terms - train_terms
    missing_test_terms = test_terms - train_terms

    print("Moving clusters with missing terms from validation/test to training")
    if missing_val_terms:
        val_clusters_to_move = [cluster for cluster in val_clusters if any(term in missing_val_terms for term in cluster_to_terms[cluster])]
        train_clusters.extend(val_clusters_to_move)
        val_clusters = list(set(val_clusters) - set(val_clusters_to_move))

    if missing_test_terms:
        test_clusters_to_move = [cluster for cluster in test_clusters if any(term in missing_test_terms for term in cluster_to_terms[cluster])]
        train_clusters.extend(test_clusters_to_move)
        test_clusters = list(set(test_clusters) - set(test_clusters_to_move))

    train_df = go_terms_df[go_terms_df[clusters_col].isin(train_clusters)]
    val_df = go_terms_df[go_terms_df[clusters_col].isin(val_clusters)]
    test_df = go_terms_df[go_terms_df[clusters_col].isin(test_clusters)]

    train_terms = set(sum(train_df[terms_col].tolist(), []))
    val_terms = set(sum(val_df[terms_col].tolist(), []))
    test_terms = set(sum(test_df[terms_col].tolist(), []))

    missing_val_terms = val_terms - train_terms
    missing_test_terms = test_terms - train_terms

    print("Validation terms not in training:", missing_val_terms)
    print("Test terms not in training:", missing_test_terms)

    propagated_terms_col = 'Raw propagated GO terms'  

    min_training_instances = 50 
    train_terms_counter = Counter([term for terms in train_df[propagated_terms_col].tolist() for term in terms])
    terms_below_threshold = {term for term, count in train_terms_counter.items() if count < min_training_instances}

    print(f"Terms below threshold ({min_training_instances} occurrences):", terms_below_threshold)

    if terms_below_threshold:
        print(f"Filtering out {len(terms_below_threshold)} terms from the training set.")
        train_df[propagated_terms_col] = train_df[propagated_terms_col].apply(lambda terms: [term for term in terms if term not in terms_below_threshold])
        train_df = train_df[train_df[propagated_terms_col].apply(len) > 0]  # Remove rows with no terms after filtering

    print("Filtering validation and test sets to remove terms below training threshold.")
    val_df[propagated_terms_col] = val_df[propagated_terms_col].apply(lambda terms: [term for term in terms if term not in terms_below_threshold])
    test_df[propagated_terms_col] = test_df[propagated_terms_col].apply(lambda terms: [term for term in terms if term not in terms_below_threshold])

    val_df = val_df[val_df[propagated_terms_col].apply(len) > 0]
    test_df = test_df[test_df[propagated_terms_col].apply(len) > 0]

    filtered_train_counts = Counter([term for terms in train_df[propagated_terms_col].tolist() for term in terms])
    filtered_train_counter_terms_less_than_X = {term: count for term, count in filtered_train_counts.items() if count < min_training_instances}
    print("Filtered function training set terms with less than {} occurrences:".format(min_training_instances), filtered_train_counter_terms_less_than_X)
    print("Filtered function training set total terms:", len(filtered_train_counts))

    return train_df, val_df, test_df, mlb


function_train_df, function_val_df, function_test_df, function_mlb = prepare_data_and_split(function_merged, clusters_col='Cluster_2', terms_col='Raw GO terms')
print("Function training set shape:", function_train_df.shape)
print("Function validation set shape:", function_val_df.shape)
print("Function test set shape:", function_test_df.shape)
print("----")

Processing rows: 100%|██████████| 150860/150860 [00:07<00:00, 19078.82it/s]


Identifying unique and multi-occurrence clusters
ShuffleSplit for training, validation, and test sets from multi occurrence clusters
Extracting rows for each set
Moving clusters with missing terms from validation/test to training
Validation terms not in training: set()
Test terms not in training: set()
Terms below threshold (50 occurrences): {'GO:0008761', 'GO:0018549', 'GO:0016815', 'GO:0047492', 'GO:0047922', 'GO:0015478', 'GO:1990829', 'GO:0016409', 'GO:0050540', 'GO:0052823', 'GO:0102682', 'GO:0036041', 'GO:0052855', 'GO:0015527', 'GO:0015643', 'GO:0047355', 'GO:0016647', 'GO:0003681', 'GO:0046625', 'GO:0043802', 'GO:0043837', 'GO:0051425', 'GO:0018493', 'GO:0018459', 'GO:1901359', 'GO:0047727', 'GO:0050833', 'GO:0004366', 'GO:0052880', 'GO:0070026', 'GO:0004574', 'GO:0051020', 'GO:0160105', 'GO:0047120', 'GO:0033940', 'GO:0051903', 'GO:0015117', 'GO:0015519', 'GO:0047971', 'GO:0051009', 'GO:0008719', 'GO:0050231', 'GO:0102933', 'GO:0000822', 'GO:0016430', 'GO:0016672', 'GO:0018620

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df[propagated_terms_col] = train_df[propagated_terms_col].apply(lambda terms: [term for term in terms if term not in terms_below_threshold])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_df[propagated_terms_col] = val_df[propagated_terms_col].apply(lambda terms: [term for term in terms if term not in terms_below_threshold])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pyd

Filtering validation and test sets to remove terms below training threshold.
Filtered function training set terms with less than 50 occurrences: {}
Filtered function training set total terms: 1098
Function training set shape: (129887, 14)
Function validation set shape: (10804, 14)
Function test set shape: (10130, 14)
----


In [31]:
temp_terms = function_train_df['Raw propagated GO terms'].tolist()
temp_terms= [item for sublist in temp_terms for item in sublist]
print("Total number of GO terms in the function training set:", len(temp_terms))
print("Total number of unique GO terms in the function training set:", len(set(temp_terms)))

Total number of GO terms in the function training set: 1885828
Total number of unique GO terms in the function training set: 1098


In [32]:
#are there any terms that are not in the training set but are in the validation set and test set

temp_terms_train = function_train_df['Raw propagated GO terms'].tolist()
temp_terms_train = [item for sublist in temp_terms_train for item in sublist]
temp_terms_val = function_val_df['Raw propagated GO terms'].tolist()
temp_terms_val = [item for sublist in temp_terms_val for item in sublist]
temp_terms_test = function_test_df['Raw propagated GO terms'].tolist()
temp_terms_test = [item for sublist in temp_terms_test for item in sublist]

print("Terms in validation set but not in training set:", set(temp_terms_val) - set(temp_terms_train))
print("Terms in test set but not in training set:", set(temp_terms_test) - set(temp_terms_train))

Terms in validation set but not in training set: set()
Terms in test set but not in training set: set()


In [103]:
! mkdir processed_data_90_30

mkdir: cannot create directory ‘processed_data_90_30’: File exists


In [33]:
os.chdir("processed_data_90_30")
function_train_df.to_csv("function_train.tsv", sep="\t", index=False)
function_val_df.to_csv("function_val.tsv", sep="\t", index=False)
function_test_df.to_csv("function_test.tsv", sep="\t", index=False)

#write file that has contains how many unique GO terms are in the training set
function_train_terms = function_train_df['Raw propagated GO terms'].tolist()
function_train_terms = [item for sublist in function_train_terms for item in sublist]
function_train_terms_unique = len(set(function_train_terms))

with open("number_of_terms_trained_on.txt", "w") as f:
    print("Function:", len(set(function_train_terms)), file=f)

In [34]:
with open("df_shapes.txt", "w") as f:
    print("Function train:", function_train_df.shape, file=f)
    print("Function val:", function_val_df.shape, file=f)
    print("Function test:", function_test_df.shape, file=f)

In [35]:
function_merged.to_csv("presplit_function_merged.tsv", sep="\t", index=False)

In [37]:
assert len(function_train_df) + len(function_val_df) + len(function_test_df) == len(function_merged)

AssertionError: 

In [38]:
#create fasta files 
seq_records_function_train = create_seq_records(function_train_df)
seq_records_function_val = create_seq_records(function_val_df)
seq_records_function_test = create_seq_records(function_test_df)


assert len(seq_records_function_train) == function_train_df.shape[0]
assert len(seq_records_function_val) == function_val_df.shape[0]
assert len(seq_records_function_test) == function_test_df.shape[0]

In [39]:
#write out the fasta files
SeqIO.write(seq_records_function_train, "function_train.fasta", "fasta")
SeqIO.write(seq_records_function_val, "function_val.fasta", "fasta")
SeqIO.write(seq_records_function_test, "function_test.fasta", "fasta")

10130

# Get embeddings

In [None]:
%run ../generate_embeddings/generate_embeddings.py \
    --input ./reduced_90_function.fasta \
    --output ./embeddings/ \
    --output_format npz --tm_vec_model /shared/projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model.ckpt --device cuda --tm_vec_config /shared/projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model_params.json

  from .autonotebook import tqdm as notebook_tqdm


Embedding generation started!
Output directory: embeddings
Log file: embeddings/embedding_generation.log
Verbose mode: OFF
Using device: cuda
2026-01-16 12:06:56,566 - INFO - Using device: cuda
Found 1 FASTA file(s) to process:
   1. reduced_90_function.fasta

2026-01-16 12:06:56,568 - INFO - Found 1 FASTA file(s) to process
Loading ProtT5 tokenizer and model...
2026-01-16 12:06:56,570 - INFO - Loading ProtT5 tokenizer and model...


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading TM-Vec model...
2026-01-16 12:07:51,249 - INFO - Loading TM-Vec model...


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.6.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../../../../projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model.ckpt`


Models loaded successfully on device: cuda
2026-01-16 12:07:53,046 - INFO - Models loaded successfully on device: cuda
Starting processing...
[1/1] Processing: reduced_90_function.fasta
Processing file: reduced_90_function.fasta
2026-01-16 12:07:53,046 - INFO - Processing file: reduced_90_function.fasta
Found 150860 sequences in reduced_90_function.fasta
2026-01-16 12:07:59,058 - INFO - Found 150860 sequences in reduced_90_function.fasta
Generating embeddings... (this may take a while)
2026-01-16 12:07:59,059 - INFO - Generating embeddings...


 23%|██▎       | 35207/150860 [1:31:09<4:30:08,  7.14it/s] 

In [40]:
os.chdir("..")

In [41]:
function_all_np = np.load("embeddings/reduced_90_function_embeddings.npy", allow_pickle=True)
function_all_df = pd.read_csv("reduced_90_function.tsv", sep='\t')
assert function_all_np.shape[0] == function_all_df.shape[0]

In [42]:
function_train_df= pd.read_csv("processed_data_90_30/function_train.tsv", sep="\t")
function_val_df.to_csv("processed_data_90_30/function_val.tsv", sep="\t")
function_test_df.to_csv("processed_data_90_30/function_test.tsv", sep="\t")

In [43]:
function_all_dict = dict(zip(function_all_df['Entry'], function_all_np))
function_train_np = np.array([function_all_dict[entry] for entry in function_train_df['Entry']])
function_val_np = np.array([function_all_dict[entry] for entry in function_val_df['Entry']])
function_test_np = np.array([function_all_dict[entry] for entry in function_test_df['Entry']])


assert function_train_np.shape[0] == function_train_df.shape[0]
assert function_val_np.shape[0] == function_val_df.shape[0]
assert function_test_np.shape[0] == function_test_df.shape[0]


In [44]:
np.save("processed_data_90_30/function_train.npy", function_train_np)
np.save("processed_data_90_30/function_val.npy", function_val_np)
np.save("processed_data_90_30/function_test.npy", function_test_np)


with open("expected_shapes.txt", "w") as f:
    print("Function train:", function_train_np.shape, file=f)
    print("Function val:", function_val_np.shape, file=f)
    print("Function test:", function_test_np.shape, file=f)

# Generate IA TSV using the training data 

In [45]:
total_bacteria_function_GO = pd.read_csv("reduced_90_function_propagated.tsv", sep="\t")


In [46]:
with open('function_data_json_dict.json') as json_data:
    function_all_json_dict = json.load(json_data)

function_OBSELETE = [term for term in function_all_json_dict if function_all_json_dict[term]["is_obsolete"]]

print("Function Obselete:", len(function_OBSELETE))

Function Obselete: 22


In [75]:
function_train_df_subset = total_bacteria_function_GO[['Entry', 'Raw propagated GO terms']].copy()
function_train_df_subset.columns = ['EntryID', 'term']

# Adding aspect to each respective dataframe
function_train_df_subset['aspect'] = 'MFO'

# Convert stringified list → real list
function_train_df_subset['term'] = (
    function_train_df_subset['term']
        .str.strip('[]')
        .str.replace("'", "", regex=False)
        .str.split(', ')
)
exploded_df = function_train_df_subset.explode('term')


#REMOVE ALL_OBSELETE_TERMS
exploded_df = exploded_df[~exploded_df['term'].isin(function_OBSELETE)]

In [76]:
exploded_df.head()


Unnamed: 0,EntryID,term,aspect
0,A0A009IHW8,GO:0016798,MFO
0,A0A009IHW8,GO:0061809,MFO
0,A0A009IHW8,GO:0050135,MFO
0,A0A009IHW8,GO:0016799,MFO
0,A0A009IHW8,GO:0016787,MFO


In [77]:
exploded_df.to_csv("terms_for_IA_calculation.tsv", sep="\t", index=False)


In [78]:
%run ia.py -a terms_for_IA_calculation.tsv -o IA_all.tsv -g go-basic.obo

Counting Terms
0 []
0 []
0 []
Computing Information Accretion
Saving to file IA_all.tsv
