In [2]:
import os
from Bio import SeqIO

def _fasta_entry(record):
    """Format one parsed record as a two-line FASTA entry.

    The header is the second '|'-delimited field of ``record.id``
    (presumably the UniProt accession, e.g. ``sp|<accession>|<name>`` -- confirm
    against the input file); the body is the full sequence on a single line.
    """
    identifier = record.id.split("|", 2)[1]
    return f'>{identifier}\n{str(record.seq)}\n'


# Keep every parsed record in memory: later cells reuse `records`.
records = list(SeqIO.parse('/work/jiaqi/ProtRepr/data/swissprot/uniprot_sprot_10_1022.fasta', 'fasta'))
pid_seqs = list(map(_fasta_entry, records))
print(f'Number of sequences: {len(pid_seqs)}')

# Write the simplified-header copy of the dataset.
with open('/work/jiaqi/ProtRepr/data/swissprot/uniprot_sprot_10_1022_cleaned.fasta', 'w') as f:
    f.writelines(pid_seqs)

Number of sequences: 551965


In [3]:
import random
import numpy as np

# Split the cleaned FASTA entries into `num_batches` randomly assigned files.
#
# The shuffle is applied to a copy so that `pid_seqs` keeps its original order.
# Shuffling `pid_seqs` in place made this cell non-idempotent: running it a second
# time applied the seed-42 permutation to an already permuted list and silently
# wrote different batch files.
random.seed(42)
shuffled_seqs = list(pid_seqs)
random.shuffle(shuffled_seqs)

num_batches = 6
# Ceiling division: the first batches are full and the last one takes the remainder.
batch_size = int(np.ceil(len(shuffled_seqs) / num_batches))
for i in range(num_batches):
    # Slice once and reuse it for both the file contents and the report line.
    batch = shuffled_seqs[i*batch_size:(i+1)*batch_size]
    with open(f'/work/jiaqi/ProtRepr/data/swissprot/uniprot_sprot_10_1022_cleaned_batch{i}.fasta', 'w') as f:
        f.writelines(batch)
    print(f'Batch {i} has {len(batch)} sequences')

Batch 0 has 91995 sequences
Batch 1 has 91995 sequences
Batch 2 has 91995 sequences
Batch 3 has 91995 sequences
Batch 4 has 91995 sequences
Batch 5 has 91990 sequences


In [4]:
# Plain-string copy of every sequence, then the summed length over all of them.
seqs = [str(rec.seq) for rec in records]
total_aa = sum(map(len, seqs))
print(f'Total number of amino acids: {total_aa}')

Total number of amino acids: 177708347


In [5]:
# Back-of-the-envelope storage estimate in GiB:
#   177708347 values-per-row count (the amino-acid total printed by the previous cell)
#   x 1280 numbers each (presumably the per-residue embedding width -- confirm)
#   x 32 bits per number (presumably float32 -- confirm), i.e. 32 // 8 = 4 bytes,
# divided by 1024**3 bytes per GiB.
177708347 * 1280 * (32 // 8) / 1024 ** 3

847.3794317245483

In [1]:
import torch

# Load a single saved embedding file for inspection. The directory name suggests
# per-residue ESM-1b embeddings for one Swiss-Prot entry -- confirm with whoever
# produced the files.
# NOTE(review): torch.load unpickles the file; only use it on files you trust.
embedding_path = '../data/sprot_esm1b_emb_per_residue/A0A009IHW8.pt'
data = torch.load(embedding_path)

In [2]:
# Display the loaded object to inspect its keys and tensor contents.
data

{'label': 'A0A009IHW8',
 'representations': {33: tensor([[ 0.0346, -0.0252,  0.2082,  ..., -0.0231, -0.1055, -0.2171],
          [-0.0496,  0.0191,  0.1307,  ...,  0.0703,  0.1753,  0.1319],
          [ 0.0459,  0.1039, -0.1113,  ...,  0.0921,  0.0433,  0.2043],
          ...,
          [-0.1886,  0.1124, -0.2455,  ..., -0.2014,  0.2529,  0.1187],
          [ 0.1776,  0.0238, -0.1586,  ...,  0.0368,  0.0579,  0.1807],
          [ 0.0412, -0.0378,  0.0217,  ...,  0.0182,  0.0892,  0.0313]])},
 'mean_representations': {33: tensor([-0.0296,  0.0830, -0.1114,  ..., -0.0111, -0.0451,  0.1137])}}

In [3]:
# Shapes of the tensors stored under key 33 in the 'representations' and
# 'mean_representations' entries, shown as a pair in that order.
tuple(data[field][33].shape for field in ('representations', 'mean_representations'))

(torch.Size([269, 1280]), torch.Size([1280]))