In [1]:
from pathlib import Path
import os
import sys
import gzip
import wget

curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))

from src.utils.data import read_fasta

In [None]:
link = 'https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz'
filename = 'uniprot_sprot.dat.gz'
unzipped_filename = 'uniprot_sprot.dat'

# Download the file from the web
wget.download(link, filename)

# Unzip the downloaded file
with gzip.open(filename, 'rb') as f_in:
    with open(unzipped_filename, 'wb') as f_out:
        f_out.write(f_in.read())

print(f"File {filename} has been downloaded and unzipped to {unzipped_filename}.")

In [3]:
import pandas as pd
from Bio import SwissProt

# Extract data from SwissProt records
data = []
# See https://biopython.org/docs/1.75/api/Bio.SwissProt.html and https://web.expasy.org/docs/userman.html
with open('../data/swissprot/uniprot_sprot.dat', 'r') as f:
    records = SwissProt.parse(f)
    for record in records:
        # Extract sequence ID
        seq_id = record.accessions[0]
        
        # Extract sequence
        sequence = record.sequence

        # Extract GO ids
        go_ids = [ref[1] for ref in record.cross_references if ref[0] == "GO" and len(ref) > 0]
        
        # Extract free-text description
        description = record.description

        # Extract organism and organism classification
        organism = record.organism
        organism_classification = record.organism_classification

        # Extract organelle
        organelle = record.organelle
        
        # Extract CC line as a dictionary
        cc = {}
        for comment in record.comments:
            key, value = comment.split(": ", 1)
            cc[key] = value
        
        data.append([seq_id, sequence, go_ids, description, organism, organism_classification, organelle, cc])

Bad pipe message: %s [b'l\xa5L\xb7\xef\xae\x18\x07pJ\xf0\x87\xbe\xa0\xb6\r', b' ~q\xe3*\xc8\n\xb3v\xa5&']
Bad pipe message: %s [b'~\xa8\xe3\xc2\x81\x1a\xed>Ne\x91n\xfb\x10F\xc9\x1c\x1b \xf0}\xae\xc2\x98\xe1U\xe1\xf2\xcak\xd9\x1b\xf0a\xc8\x91\x88\xbb\x199\x0e\xf4\xf7\xd4o-\x1cE\xf6\r\xec\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00']
Bad pipe message: %s [b'\x0c\x00\x00\t127.0.0.1']
Bad pipe message: %s [b'3)\xb4\x7f\xe7\x8f\xb3S\xb3\xe9(\xf2;\x9d\xf0\xc0', b"\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x00", b'\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00<\x005\x00/\x00\x9a\x00\x99\xc0\x07\xc0\x11\x00\x96\x00\x05\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c\x00\x00']
Bad pipe message: %s [b'\x9c\xbb\xfc\xae\xb9\x02\xde\x

In [2]:
print(data[0])

['Q6GZX4', 'MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL', ['GO:0046782'], 'RecName: Full=Putative transcription factor 001R;', 'Frog virus 3 (isolate Goorha) (FV-3).', ['Viruses', 'Varidnaviria', 'Bamfordvirae', 'Nucleocytoviricota', 'Megaviricetes', 'Pimascovirales', 'Iridoviridae', 'Alphairidovirinae', 'Ranavirus'], '', {'FUNCTION': 'Transcription activation. {ECO:0000305}.'}]


In [3]:
# Convert data into a pandas DataFrame
df_2023 = pd.DataFrame(data, columns=["seq_id", "sequence", "go_ids", "description", "organism", "organism_classification", "organelle", "cc"])

# Create a new column with the subcellular location
df_2023['subcellular_location'] = df_2023.cc.apply(lambda x: x['SUBCELLULAR LOCATION'] if 'SUBCELLULAR LOCATION' in x else None)

# import sequence embeddings from ../data/embeddings/frozen_proteinfer_sequence_embeddings.pkl
import pickle

# Load the sequence embeddings from the file
with open('../data/embeddings/frozen_proteinfer_sequence_embeddings.pkl', 'rb') as f:
    sequence_embeddings = pickle.load(f)

# Make a set of the sequence strings
sequence_strings_2019 = set(sequence_embeddings.keys())

In [4]:
# Find sequence ids  that are in df but not in sequence_strings
df_2023['in_ProteInfer_dataset'] = df_2023.seq_id.apply(lambda x: x in sequence_strings_2019)

# Print 5 example sequences from df.sequence
print(df_2023.seq_id.head())

# Print 5 example sequences from sequence_strings
print(list(sequence_strings_2019)[:5])

# Count the number of sequences that are in df but not in sequence_strings
print(f"Number of sequences in df_2023 but not in ProteInfer dataset: {df_2023.in_ProteInfer_dataset.value_counts()[False]}")
print(f"Number of sequences in df_2023: {len(df_2023)}")
print(f"Number of sequences in ProteInfer dataset: {len(sequence_strings_2019)}")

0    Q6GZX4
1    Q6GZX3
2    Q197F8
3    Q197F7
4    Q6GZX2
Name: seq_id, dtype: object
['Q8SS29', 'A4QKE2', 'B0BVP3', 'Q55724', 'Q7W2N9']
Number of sequences in df_2023 but not in ProteInfer dataset: 47493
Number of sequences in df_2023: 569793
Number of sequences in ProteInfer dataset: 522607


In [5]:
# Import label embeddings from ../data/embeddings/frozen_proteinfer_label_embeddings.pkl
import pickle

# Load the label embeddings from the file
with open('../data/embeddings/frozen_PubMedBERT_label_embeddings.pkl', 'rb') as f:
    label_embeddings_2019 = pickle.load(f)

# Make a set of the GO labels from the label embeddings
label_ids_2019 = set(label_embeddings_2019.keys())
print(len(label_ids_2019))

# Make a set from all the GO labels that occur in the data
label_ids_2023 = set([item for sublist in df_2023.go_ids for item in sublist])
print(len(label_ids_2023))

47401
29283


In [6]:
# Find GO labels that are in go_label_strings but not in label_strings
print(f"Number of GO labels in go_label_strings but not in label_strings: {len(label_ids_2023 - label_ids_2019)}")

# Print out 10 examples of GO labels that are in go_label_strings but not in label_strings
print(list(label_ids_2023 - label_ids_2019)[:10])

Number of GO labels in go_label_strings but not in label_strings: 666
['GO:0140752', 'GO:0140499', 'GO:0140947', 'GO:0140961', 'GO:0140831', 'GO:0106223', 'GO:0140455', 'GO:0106370', 'GO:0120216', 'GO:0120283']


In [7]:
# Find added labels
new_go_labels = label_ids_2023 - label_ids_2019

# Find protein sequences with added labels
df_2023['new_labels'] = df_2023.go_ids.apply(lambda x: set(x) & new_go_labels)

# Count how many rows have 'in_Proteinfer_dataset' == False
print(f"Number of rows with 'in_ProteInfer_dataset' == False: {len(df_2023[df_2023.in_ProteInfer_dataset == False])}")

# Count how many rows have 'in_Proteinfer_dataset' == False and 'new_labels' != set()
print(f"Number of rows with 'in_ProteInfer_dataset' == False and 'new_labels' != set(): {len(df_2023[(df_2023.in_ProteInfer_dataset == False) & (df_2023.new_labels != set())])}")

# Create a new dataframe out of those that meet that criteria
df_2023_new_sequences_and_labels = df_2023[(df_2023.in_ProteInfer_dataset == False) & (df_2023.new_labels != set())]

Number of rows with 'in_ProteInfer_dataset' == False: 47493
Number of rows with 'in_ProteInfer_dataset' == False and 'new_labels' != set(): 917


In [55]:
# Create a new dataframe containiners seq_id, sequence, and go_ids
filtered_df = df_2023_new_sequences_and_labels[['seq_id', 'sequence', 'new_labels']]

# Set of 20 common amino acids
common_amino_acids = set(['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y'])

# Check which rows contain amino acids other than the 20 common ones
filtered_df['non_common_amino_acids'] = filtered_df.sequence.apply(lambda x: set(x) - common_amino_acids)

# Filter to only contain rows that contain common amino acids
SwissProt_2023_unseen_sequences_and_labels = filtered_df[filtered_df.non_common_amino_acids == set()]

# Rename "new_ids" to "go_ids"
SwissProt_2023_unseen_sequences_and_labels.rename(columns={'new_labels': 'go_ids'}, inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df['non_common_amino_acids'] = filtered_df.sequence.apply(lambda x: set(x) - common_amino_acids)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  SwissProt_2023_unseen_sequences_and_labels.rename(columns={'new_labels': 'go_ids'}, inplace=True)


In [56]:
# Save the dataframe to a pickle file
SwissProt_2023_unseen_sequences_and_labels.to_pickle('../data/zero_shot/SwissProt_2023_unseen_sequences_and_labels.pkl')

In [57]:
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

df = SwissProt_2023_unseen_sequences_and_labels

# Convert dataframe to FASTA format and save to a file
records = []
for _, row in df.iterrows():
    seq_record = SeqRecord(Seq(row['sequence']),
                           id=row['seq_id'],
                           description=" ".join(row['go_ids']))
    records.append(seq_record)

# Save to FASTA file
fasta_file = "../data/zero_shot/SwissProt_2023_unseen_sequences_and_labels.fasta"
SeqIO.write(records, fasta_file, "fasta")

855

In [58]:
from src.data.datasets import ProteinDataset
from torch.utils.data import DataLoader

%load_ext autoreload
%autoreload 2

# Paths for the dataset
paths = {
    "data_path": fasta_file,
}

# Create the dataset
protein_dataset = ProteinDataset(paths)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [64]:
from src.data.datasets import collate_variable_sequence_length

# Create the DataLoader
batch_size = 1  # You can adjust this value as needed
protein_dataloader = DataLoader(protein_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_variable_sequence_length)

# Now you can iterate over the DataLoader to get batches of data
for batch in protein_dataloader:
    # Unpack the batch
    sequence_ids, sequence_onehots, label_multihots, sequence_lengths = batch
    print("Original ID 1: ", protein_dataset.int2sequence_id[sequence_ids[0].item()])
    print(f"Sequence IDs: {sequence_ids}")
    print(f"Sequence onehots: {sequence_onehots}")
    print(f"Label multihots: {label_multihots}")
    print(f"Sequence lengths: {sequence_lengths}")
    break

Original ID 1:  A0A443HJY8
Sequence IDs: tensor([89])
Sequence onehots: tensor([[[0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
Label multihots: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
 