In [None]:
import numpy as np
import webdataset as wds
import glob
import os
import torch
import pandas as pd
import yaml

# Estimate the uncompressed float32 storage needed to keep the embeddings of
# every (annotation file, model) pair, one row per pair.

# Model -> embedding width, and '<task>_length' -> embeddings per sample.
# Context managers guarantee the config file handles are closed after parsing.
with open('../conf/datadims/embedding_dims.yaml', 'r') as dims_file:
    embedding_dims = yaml.safe_load(dims_file)
with open('../conf/datadims/seq_lengths.yaml', 'r') as lengths_file:
    sequence_lengths = yaml.safe_load(lengths_file)

DATA_DIR = '../data'

# Size of a single float32 value in bytes; loop-invariant, so computed once.
FLOAT32_BYTES = np.float32().itemsize
# The 'GB' reported below is the binary unit (1024 ** 3 bytes, i.e. GiB).
BYTES_PER_GB = 1024 ** 3

# Collect plain records and build the DataFrame once at the end, keeping the
# list and the final frame under different names.
memory_records = []

for length_key, n_embeddings in sequence_lengths.items():

    task = length_key.replace('_length', '')  # remove '_length' suffix if present
    task_dir = os.path.join(DATA_DIR, task)

    # os.listdir returns entries in arbitrary order; sort for a reproducible row order.
    for file in sorted(os.listdir(task_dir)):
        # Only annotation files named '<task>*.bed' are counted.
        if not file.startswith(task) or not file.endswith('.bed'):
            continue

        # One row per sample. read_csv consumes the first line as a header here;
        # NOTE(review): assumes the .bed files carry a header row - confirm,
        # otherwise n_samples is short by one.
        task_annotations = pd.read_csv(os.path.join(task_dir, file), sep='\t', engine='python')
        n_samples = len(task_annotations)

        for model, embedding_dim in embedding_dims.items():
            # bytes per value * values per embedding * embeddings per sample * samples
            requested_memory = (FLOAT32_BYTES * embedding_dim * n_embeddings * n_samples) / BYTES_PER_GB

            memory_records.append({
                'task': file.split('.')[0],
                'model': model,
                'embedding_dim': embedding_dim,
                'n_embeddings': n_embeddings,
                'n_samples': n_samples,
                'requested_memory_GB': requested_memory
            })

requested_memory_df = pd.DataFrame(memory_records)
print(len(requested_memory_df), 'total rows')

147 total rows


In [135]:
# Models whose storage footprint is compared in the table below.
selected_models = [
    'awdlstm',
    'nt_transformer_ms',
    'nt_transformer_1000g',
    'nt_transformer_human_ref',
    'hyenadna-tiny-1k',
    'hyenadna-large-1m',
    'resnetlm',
    'dnabert2',
]

# Keep only the selected models, then order rows by task (ascending) and,
# within each task, by requested memory (largest first) with a fresh 0..n-1 index.
selected_models_df = (
    requested_memory_df
    .loc[lambda frame: frame['model'].isin(selected_models)]
    .sort_values(
        by=['task', 'requested_memory_GB'],
        ascending=[True, False],
        ignore_index=True,
    )
)
selected_models_df

Unnamed: 0,task,model,embedding_dim,n_embeddings,n_samples,requested_memory_GB
0,chromatin_accessibility,nt_transformer_ms,2560,512,2062129,10068.989258
1,chromatin_accessibility,nt_transformer_1000g,2560,512,2062129,10068.989258
2,chromatin_accessibility,nt_transformer_human_ref,1280,512,2062129,5034.494629
3,chromatin_accessibility,dnabert2,768,512,2062129,3020.696777
4,chromatin_accessibility,resnetlm,256,512,2062129,1006.898926
5,chromatin_accessibility,hyenadna-large-1m,256,512,2062129,1006.898926
6,chromatin_accessibility,hyenadna-tiny-1k,128,512,2062129,503.449463
7,chromatin_accessibility,awdlstm,64,512,2062129,251.724731
8,cpg_methylation,nt_transformer_ms,2560,512,959039,4682.807617
9,cpg_methylation,nt_transformer_1000g,2560,512,959039,4682.807617


In [136]:
# Sum the estimate over every row of requested_memory_df (all models, not only
# the selected ones) and divide by 1024 to express it in TB.
total_requested_gb = requested_memory_df['requested_memory_GB'].sum()
total_requested_tb = total_requested_gb / 1024
print('Total memory required for storing embeddings (not compressed): ', total_requested_tb, 'TB')

Total memory required for storing embeddings (not compressed):  100.42015554659883 TB
