In [62]:
from Bio import SeqIO
def records(path):
    records = list(SeqIO.parse(path, "fasta"))
    return records

In [63]:
def ids_from_gopredsim_annotation(pth):
    with open(pth, 'r') as opened:
        ids = set(x.split()[0] for x in opened.readlines())
        print(len(ids))
    return ids

# Determine the sets of train and test ids

This is the training set originally provided in the CAFA3 challenge, although use could also use other information

In [64]:
orig_cafa3_file = "uniprot_sprot_exp.fasta"
orig_cafa3_ids_set = set(str(x.id) for x in records(orig_cafa3_file))
len(orig_cafa3_ids_set)

66841

This is the training set prepared by goPredSim developers,
closely matching the CAFA3 set, with a corresponding temporal cutoff

In [65]:
gopredsim_file = "goa_annotations_exp_2017.txt"
gopredsim_ids_set = ids_from_gopredsim_annotation(gopredsim_file)

68038


We take their intersection

In [66]:
train_ids = orig_cafa3_ids_set.intersection(gopredsim_ids_set)

In [67]:
len(train_ids)

62626

For the test set, we take the set of proteins which got new annotations in the CAFA3 review period

In [68]:
test_ids = set(str(x.id) for x in records("cafa3_targets.fasta"))

In [69]:
len(test_ids)

3328

# Prepare the sets of embeddings in format required by goPredSim

In [70]:
embeddings_folder = "prepared_embeddings"

In [71]:
import numpy as np
import os
def get_dict_from_separate_files(embeddings_file, ids_file):
    embeddings = np.load(embeddings_file, allow_pickle=True)
    ids = np.load(ids_file, allow_pickle=True)
    embed_dict = {id_: em for id_, em in zip(ids, embeddings)}
    return embed_dict

In [72]:
def get_subset(source, ids, allow_missing=False):
    res = dict()
    for i in ids:
        try:
            res[i]=source[i]
        except KeyError as e:
            print(f"missing embedding for {i}")
            if not allow_missing:
                raise e
    return res

In [73]:
import pickle
def save_dict(di, name):
    with open(f"{name}.pkl", 'wb') as fp:
        pickle.dump(di, fp)

In [74]:
def prepare_separate(model, split):
    assert split in ("train", "test")
    model_data_ids = os.path.join(embeddings_folder, f"{model}_{split}_ids.npy")
    model_data_ems = os.path.join(embeddings_folder, f"{model}_{split}_embeddings.npy")
    model_data_full = get_dict_from_separate_files(model_data_ems, model_data_ids)
    ids = train_ids if split == "train" else test_ids
    model_data = get_subset(model_data_full, ids)
    save_dict(model_data, f"{model}.{split}")

In [75]:
import h5py
import numpy as np
import tqdm

# Code based on that of goPredSim
def read_h5_embeddings(embeddings_in):
    """A2ASS6
    Read embeddings from h5 file generated by bio_embeddings pipeline
    :param embeddings_in: 
    :return: 
    """
    embeddings = dict()
    with h5py.File(embeddings_in, 'r') as f:
        for key, embedding in tqdm.tqdm(f.items()):
            original_id = embedding.attrs['original_id']
            embeddings[original_id] = np.array(embedding)
            
    return embeddings

In [76]:
def prepare_from_h5(model, split, allow_missing=False):
    assert split in ("train", "test")
    if split=="train":
        data_path = os.path.join(embeddings_folder, f"{model}_goa_2017.h5")
    else:
        data_path = os.path.join(embeddings_folder, f"{model}_cafa3_targets.h5")
        
    model_data_full = read_h5_embeddings(data_path)
    ids = train_ids if split == "train" else test_ids
    model_data = get_subset(model_data_full, ids, allow_missing=allow_missing)
    save_dict(model_data, f"{model}.{split}")

proteinbert

In [48]:
prepare_separate("proteinbert", "train")
prepare_separate("proteinbert", "test")

protbert

In [49]:
prepare_separate("protbert", "train")
prepare_separate("protbert", "test")

esm2

In [13]:
import pickle
def esm2_filter(file):
    with open(file, 'rb') as opened:
        di = pickle.load(opened)
    out = {i: o[33] for i, o in di.items()}
    with open(file, 'wb') as fp:
        pickle.dump(out, fp)

In [50]:
prepare_separate("esm2", "train")
prepare_separate("esm2", "test")

In [25]:
import pickle
def esm2_filter2(file):
    with open(file, 'rb') as opened:
        di = pickle.load(opened)
    out = {i: o.numpy() for i, o in di.items()}
    with open(file, 'wb') as fp:
        pickle.dump(out, fp)

In [17]:
esm2_filter("esm2.train.pkl")
esm2_filter("esm2.test.pkl")

In [None]:
esm2_filter2("esm2.train.pkl")
esm2_filter2("esm2.test.pkl")

prott5

In [76]:
prepare_from_h5("prott5", "train", allow_missing=True)

100%|██████████| 307278/307278 [01:27<00:00, 3501.78it/s]


missing embedding for A2ASS6
missing embedding for G4SLH0
missing embedding for Q9I7U4
missing embedding for Q8WZ42
missing embedding for Q8WXI7
missing embedding for Q09165


In [77]:
prepare_from_h5("prott5", "test")

100%|██████████| 3328/3328 [00:00<00:00, 3546.41it/s]


seqvec

In [79]:
prepare_from_h5("seqvec", "train", allow_missing=True)
prepare_from_h5("seqvec", "test")

100%|██████████| 307278/307278 [01:32<00:00, 3330.10it/s]


missing embedding for A2ASS6
missing embedding for G4SLH0
missing embedding for Q9I7U4
missing embedding for Q8WZ42
missing embedding for Q8WXI7
missing embedding for Q09165


100%|██████████| 3328/3328 [00:01<00:00, 2725.63it/s]


preexisting protbert

In [77]:
prepare_from_h5("theirprotbert", "train", allow_missing=True)
prepare_from_h5("theirprotbert", "test")

100%|██████████| 307278/307278 [01:35<00:00, 3212.55it/s]


missing embedding for Q09165
missing embedding for Q9I7U4
missing embedding for Q8WZ42
missing embedding for G4SLH0
missing embedding for A2ASS6
missing embedding for Q8WXI7


100%|██████████| 3328/3328 [00:01<00:00, 3231.66it/s]


In [89]:
def load(file):
    with open(file, 'rb') as opened:
        di = pickle.load(opened)
        return di

In [92]:
next(iter(load("project/theirprotbert.test.pkl").values())).shape

(1024,)

In [93]:
next(iter(load("project/protbert.test.pkl").values())).shape

(1024,)

In [80]:
# Prepare the set of annotations for our train set

In [83]:
def filter_annotations(annotations_file, id_set, output_file):
    with open(annotations_file, 'r') as in_file:
        with open(output_file, 'w') as out_file:
            for line in in_file.readlines():
                line_id = line.split()[0]
                if line_id in id_set:
                    print(line, file=out_file, end='') #line already has newline

In [84]:
filter_annotations("goa_annotations_exp_2017.txt", train_ids, "project_annotations.txt")

In [83]:
def produce_configs():
    #os.mkdir("configs")
    for model in ("theirprotbert",):#"prott5", "seqvec", "esm2", "protbert", "proteinbert":
        with open(os.path.join("configs", f"{model}_config.txt"), 'w') as config_file:
            print("go: data/GO/go_cafa3.obo", file=config_file)
            print(f"lookup_set: project/{model}.train.pkl", file=config_file)
            print(f"annotations: project_annotations.txt", file= config_file)
            print(f"targets: project/{model}.test.pkl", file=config_file)
            print(f"onto: all", file=config_file)
            print(f"thresh: 1", file=config_file)
            print(f"modus: num", file=config_file)
            print(f"output: results/{model}", file=config_file)

In [84]:
produce_configs()

In [33]:
from yaml import load, dump

In [40]:
import os
import yaml

In [85]:
def produce_assess():
    #os.mkdir("assess_configs")
    for model in ("theirprotbert",):#"prott5", "seqvec", "esm2", "protbert", "proteinbert":
        for onto in "BPO", "MFO", "CCO": 
            d = {
                "file" : f"predictions/{model}-Tch_1_all_go_{onto}.txt",
                "obo": "./precrec/go_cafa3.obo",
                "benchmark": "./precrec/benchmark/CAFA3_benchmarks/",
                "results": "./results"
            }
            conf = yaml.dump({"assess": d, "plot": {}})
            with open(f"assess_configs/{model}_{onto}.yaml", 'w') as opened:
                print(conf, file=opened, end='')

In [86]:
produce_assess()

In [88]:
#this cell has been created with chatgpt
import os
import pandas as pd

def parse_result_file(file_path):
    result_data = {}
    ontology_matched = False
    model = file_path.split("_")[1]
    ontology = file_path.split("_")[-1].split(".")[0].upper()

    with open(file_path, 'r') as file:
        lines = file.readlines()

    mode = None
    benchmark_type = None

    for line in lines:
        line = line.strip()
        if line.startswith("ontology:"):
            current_ontology = line.split(":")[1].strip().upper()
            if current_ontology == ontology:
                ontology_matched = True
            else:
                ontology_matched = False
        elif ontology_matched:
            if line.startswith("mode:"):
                mode = line.split(":")[1].strip()
            elif line.startswith("benchmark type:"):
                benchmark_type = line.split(":")[1].strip()
            elif line.startswith("fmax:") and mode == "full":
                fmax = float(line.split(":")[1].strip())
                key = (benchmark_type, ontology, model)
                result_data.setdefault(key, {})['fmax'] = fmax
            elif line.startswith("threshold giving fmax:") and mode == "full":
                threshold = float(line.split(":")[1].strip())
                key = (benchmark_type, ontology, model)
                result_data.setdefault(key, {})['threshold'] = threshold

    return result_data


# Iterate over all files in the output directory
output_dir = "output_files"
fmax_results = []
threshold_results = []

for file_name in os.listdir(output_dir):
    file_path = os.path.join(output_dir, file_name)

    # Parse the result file
    result_data = parse_result_file(file_path)

    # Store the extracted data in a list of dictionaries
    if result_data:
        fmaxes = {key: data['fmax'] for key, data in result_data.items() if 'threshold' in data}
        fmax_results.extend(fmaxes.items())

        # Extract the threshold data
        thresholds = {key: data['threshold'] for key, data in result_data.items() if 'threshold' in data}
        threshold_results.extend(thresholds.items())

# Create a DataFrame for fmax values
fmax_df = pd.DataFrame(fmax_results, columns=['benchmark_type_ontology_model', 'results'])
fmax_df[['benchmark_type', 'ontology', 'model']] = pd.DataFrame(fmax_df['benchmark_type_ontology_model'].tolist(), index=fmax_df.index)
fmax_df = fmax_df.drop('benchmark_type_ontology_model', axis=1)

# Pivot the fmax DataFrame
fmax_pivot = fmax_df.pivot(index='model', columns=['benchmark_type', 'ontology'], values='results')

# Create a DataFrame for thresholds
threshold_df = pd.DataFrame(threshold_results, columns=['benchmark_type_ontology_model', 'threshold'])
threshold_df[['benchmark_type', 'ontology', 'model']] = pd.DataFrame(threshold_df['benchmark_type_ontology_model'].tolist(), index=threshold_df.index)
threshold_df = threshold_df.drop('benchmark_type_ontology_model', axis=1)

# Pivot the threshold DataFrame
threshold_pivot = threshold_df.pivot(index='model', columns=['benchmark_type', 'ontology'], values='threshold')

# Print the fmax DataFrame
print("Fmax DataFrame:")
print(fmax_pivot)

# Print the threshold DataFrame
print("\nThreshold DataFrame:")
print(threshold_pivot)

Fmax DataFrame:
benchmark_type             NK        LK        NK        LK        NK  \
ontology                  BPO       BPO       CCO       CCO       MFO   
model                                                                   
files/esm2           0.320937  0.340213  0.575237  0.570138  0.504384   
files/protbert       0.260052  0.315406  0.538194  0.537135  0.398696   
files/proteinbert    0.311029  0.362275  0.571419  0.543442  0.511744   
files/prott5         0.324303  0.331968  0.588554  0.566115  0.531228   
files/seqvec         0.305571  0.292032  0.561114  0.540299  0.496337   
files/theirprotbert  0.297858  0.330947  0.567013  0.538607  0.472835   

benchmark_type             LK  
ontology                  MFO  
model                          
files/esm2           0.441647  
files/protbert       0.348051  
files/proteinbert    0.451957  
files/prott5         0.462393  
files/seqvec         0.421948  
files/theirprotbert  0.431698  

Threshold DataFrame:
benchmark_type  

In [51]:
results

[]

In [2]:
#this cell has been created with chatgpt
import os
import pandas as pd

def parse_result_file(file_path):
    result_data = {}
    ontology_matched = False
    model = file_path.split("_")[1]
    ontology = file_path.split("_")[-1].split(".")[0].upper()

    with open(file_path, 'r') as file:
        lines = file.readlines()

    mode = None
    benchmark_type = None

    for line in lines:
        line = line.strip()
        if line.startswith("ontology:"):
            current_ontology = line.split(":")[1].strip().upper()
            if current_ontology == ontology:
                ontology_matched = True
            else:
                ontology_matched = False
        elif ontology_matched:
            if line.startswith("mode:"):
                mode = line.split(":")[1].strip()
            elif line.startswith("benchmark type:"):
                benchmark_type = line.split(":")[1].strip()
            elif line.startswith("fmax:") and mode == "full":
                fmax = float(line.split(":")[1].strip())
                key = (benchmark_type, ontology, model)
                result_data.setdefault(key, {})['fmax'] = fmax
            elif line.startswith("threshold giving fmax:") and mode == "full":
                threshold = float(line.split(":")[1].strip())
                key = (benchmark_type, ontology, model)
                result_data.setdefault(key, {})['threshold'] = threshold

    return result_data


# Iterate over all files in the output directory
output_dir = "assessment-output-cosine"
fmax_results = []
threshold_results = []

for file_name in os.listdir(output_dir):
    file_path = os.path.join(output_dir, file_name)

    # Parse the result file
    result_data = parse_result_file(file_path)

    # Store the extracted data in a list of dictionaries
    if result_data:
        fmaxes = {key: data['fmax'] for key, data in result_data.items() if 'threshold' in data}
        fmax_results.extend(fmaxes.items())

        # Extract the threshold data
        thresholds = {key: data['threshold'] for key, data in result_data.items() if 'threshold' in data}
        threshold_results.extend(thresholds.items())

fmax_results
# # Create a DataFrame for fmax values
# fmax_df = pd.DataFrame(fmax_results, columns=['benchmark_type_ontology_model', 'results'])
# fmax_df[['benchmark_type', 'ontology', 'model']] = pd.DataFrame(fmax_df['benchmark_type_ontology_model'].tolist(), index=fmax_df.index)
# fmax_df = fmax_df.drop('benchmark_type_ontology_model', axis=1)

# # Pivot the fmax DataFrame
# fmax_pivot = fmax_df.pivot(index='model', columns=['benchmark_type', 'ontology'], values='results')

# # Create a DataFrame for thresholds
# threshold_df = pd.DataFrame(threshold_results, columns=['benchmark_type_ontology_model', 'threshold'])
# threshold_df[['benchmark_type', 'ontology', 'model']] = pd.DataFrame(threshold_df['benchmark_type_ontology_model'].tolist(), index=threshold_df.index)
# threshold_df = threshold_df.drop('benchmark_type_ontology_model', axis=1)

# # Pivot the threshold DataFrame
# threshold_pivot = threshold_df.pivot(index='model', columns=['benchmark_type', 'ontology'], values='threshold')

# # Print the fmax DataFrame
# print("Fmax DataFrame:")
# print(fmax_pivot)

# # Print the threshold DataFrame
# print("\nThreshold DataFrame:")
# print(threshold_pivot)

[(('NK', 'BPO', 'output'), 0.3191653498509426),
 (('LK', 'BPO', 'output'), 0.3431828174196418),
 (('NK', 'CCO', 'output'), 0.5839474233111744),
 (('LK', 'CCO', 'output'), 0.56327549723111),
 (('NK', 'MFO', 'output'), 0.5108162370889814),
 (('LK', 'MFO', 'output'), 0.4546069655838999),
 (('NK', 'BPO', 'output'), 0.26406124274401904),
 (('LK', 'BPO', 'output'), 0.3145956336279498),
 (('NK', 'CCO', 'output'), 0.5372396511878191),
 (('LK', 'CCO', 'output'), 0.5343917945928709),
 (('NK', 'MFO', 'output'), 0.3979736574652955),
 (('LK', 'MFO', 'output'), 0.3507041730403241),
 (('NK', 'BPO', 'output'), 0.3089987882973226),
 (('LK', 'BPO', 'output'), 0.36730118795227124),
 (('NK', 'CCO', 'output'), 0.5729976347154231),
 (('LK', 'CCO', 'output'), 0.5362689622488322),
 (('NK', 'MFO', 'output'), 0.52246873663165),
 (('LK', 'MFO', 'output'), 0.4449159105223092),
 (('NK', 'BPO', 'output'), 0.3219275394528631),
 (('LK', 'BPO', 'output'), 0.3328045729998566),
 (('NK', 'CCO', 'output'), 0.5861962463960