In [1]:
import json
import os
import subprocess
import torch
# import transformers
from transformers import PreTrainedModel
import re
from standalone_hyenadna import HyenaDNAModel
from standalone_hyenadna import CharacterTokenizer


In [2]:
# helper 1
def inject_substring(orig_str):
    """Hack to handle matching keys between models trained with and without
    gradient checkpointing."""

    # modify for mixer keys
    pattern = r"\.mixer"
    injection = ".mixer.layer"

    modified_string = re.sub(pattern, injection, orig_str)

    # modify for mlp keys
    pattern = r"\.mlp"
    injection = ".mlp.layer"

    modified_string = re.sub(pattern, injection, modified_string)

    return modified_string

# helper 2
def load_weights(scratch_dict, pretrained_dict, checkpointing=False):
    """Loads pretrained (backbone only) weights into the scratch state dict."""

    # loop thru state dict of scratch
    # find the corresponding weights in the loaded model, and set it

    # need to do some state dict "surgery"
    for key, value in scratch_dict.items():
        if 'backbone' in key:
            # the state dicts differ by one prefix, '.model', so we add that
            key_loaded = 'model.' + key
            # breakpoint()
            # need to add an extra ".layer" in key
            if checkpointing:
                key_loaded = inject_substring(key_loaded)
            try:
                scratch_dict[key] = pretrained_dict[key_loaded]
            except:
                raise Exception('key mismatch in the state dicts!')

    # scratch_dict has been updated
    return scratch_dict


In [3]:
class HyenaDNAPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    base_model_prefix = "hyenadna"

    def __init__(self, config):
        pass

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    @classmethod
    def from_pretrained(cls,
                        path,
                        model_name,
                        download=False,
                        config=None,
                        device='cpu',
                        use_head=False,
                        n_classes=2,
                      ):
        # first check if it is a local path
        pretrained_model_name_or_path = os.path.join(path, model_name)
        if os.path.isdir(pretrained_model_name_or_path) and download == False:
            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json')))
        else:
            hf_url = f'https://huggingface.co/LongSafari/{model_name}'

            subprocess.run(f'rm -rf {pretrained_model_name_or_path}', shell=True)
            command = f'mkdir -p {path} && cd {path} && git lfs install && git clone {hf_url}'
            subprocess.run(command, shell=True)

            if config is None:
                config = json.load(open(os.path.join(pretrained_model_name_or_path, 'config.json')))

        scratch_model = HyenaDNAModel(**config, use_head=use_head, n_classes=n_classes)  # the new model format
        loaded_ckpt = torch.load(
            os.path.join(pretrained_model_name_or_path, 'weights.ckpt'),
            map_location=torch.device(device)
        )

        # need to load weights slightly different if using gradient checkpointing
        if config.get("checkpoint_mixer", False):
            checkpointing = config["checkpoint_mixer"] == True or config["checkpoint_mixer"] == True
        else:
            checkpointing = False

        # grab state dict from both and load weights
        state_dict = load_weights(scratch_model.state_dict(), loaded_ckpt['state_dict'], checkpointing=checkpointing)

        # scratch model has now been updated
        scratch_model.load_state_dict(state_dict)
        print("Loaded pretrained weights ok!")
        return scratch_model

# Inference (450k to 1M tokens)!

If all you're interested in is getting embeddings on long DNA sequences
(inference), then we can do that right here in Colab!

[source](https://github.com/HazyResearch/hyena-dna/blob/main/huggingface.py)

In [4]:
import json
import os
import subprocess
# import transformers
from transformers import PreTrainedModel

In [5]:
'''
this selects which backbone to use, and grabs weights/ config from HF
4 options:
  'hyenadna-tiny-1k-seqlen'   # fine-tune on colab ok
  'hyenadna-small-32k-seqlen'
  'hyenadna-medium-160k-seqlen'  # inference only on colab
  'hyenadna-medium-450k-seqlen'  # inference only on colab
  'hyenadna-large-1m-seqlen'  # inference only on colab
'''

# you only need to select which model to use here, we'll do the rest!
pretrained_model_name = 'hyenadna-small-32k-seqlen'

max_lengths = {
    'hyenadna-tiny-1k-seqlen': 1024,
    'hyenadna-small-32k-seqlen': 32768,
    'hyenadna-medium-160k-seqlen': 160000,
    'hyenadna-medium-450k-seqlen': 450000,  # T4 up to here
    'hyenadna-large-1m-seqlen': 1_000_000,  # only A100 (paid tier)
}

max_length = max_lengths[pretrained_model_name]  # auto selects

In [6]:
## Configs

# data settings:
use_padding = True
rc_aug = False  # reverse complement augmentation
add_eos = False  # add end of sentence token

# we need these for the decoder head, if using
use_head = False
n_classes = 2  # not used for embeddings only

# you can override with your own backbone config here if you want,
# otherwise we'll load the HF one in None
backbone_cfg = None

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cpu


### Load Model

In [7]:
model = HyenaDNAPreTrainedModel.from_pretrained(
    './checkpoints',
    pretrained_model_name,
    download=True,
    config=backbone_cfg,
    device=device,
    use_head=use_head,
    n_classes=n_classes,
)


Updated Git hooks.
Git LFS initialized.


Cloning into 'hyenadna-small-32k-seqlen'...


Loaded pretrained weights ok!


### Load Tokenizer

In [8]:
# create tokenizer
tokenizer = CharacterTokenizer(
    characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters
    model_max_length=max_length,
)

### Load Data

In [41]:
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, TensorDataset

In [38]:
data_path = '/Users/djemec/data/mammal_ref/ref_genome/'
example_ref = data_path + 'GCF_017591425.1_NIBS_Ocur_1.0_genomic.fna.gz'

out_path = '/Users/djemec/data/gptomics/mam_predict/'
raw_examples = out_path + 'raw_examples/'
bpe_path = out_path + 'bpe_mammal/'

In [11]:
def str_from_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:  # Open the file
        data = file.read()  # Read all the contents into a string
    return data.strip().replace('\n','')

def load_csvs_from_folder(folder_path):
    all_files = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.csv')]
    df_list = [pd.read_csv(file) for file in all_files]
    combined_df = pd.concat(df_list, ignore_index=True)
    return combined_df

In [20]:
mam_df = load_csvs_from_folder(raw_examples)

print(f'number of examples {len(mam_df)}')
mam_df.head()

number of examples 170446


Unnamed: 0,filename,section,example,n_per
0,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",TGGAAATGACAATTTTCTGGGGTTTAATAATTTAAAGTAAGTTTTA...,0.0
1,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",GTGTGCACCAGCATAACTGGCTAAAAGTCTTCTGGGATGCAGATGC...,0.0
2,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",ttgctttttttctttctttctactttattttgagaaaaagcCCAGT...,0.0
3,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",CACACTAAAGATGCTCAAGCTAGCCACACTCTATAACATACTCTGT...,0.0375
4,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",GAAAAGTTCCAGTTCTTAAGCAGATTTTTGATCAATCAATCTAAAC...,0.0


In [21]:
def extract_species(string):
    spec = string
    # remove GCF prefix
    spec = re.sub(r'GCF_\d+\.\d+_', '', spec)
    # remove suffix
    spec = re.sub(r'.fna.gz', '', spec)
    spec = re.sub(r'_genomic', '', spec)
    spec = re.sub(r'_genome_assembly', '', spec)

    #readability
    spec = re.sub(r'_', ' ', spec)
    spec = re.sub(r'-', ' ', spec)

    return spec

mam_df['species'] = mam_df.filename.apply(extract_species)

def extract_contig(string):
    #by string 
    match = re.search(r'(chromosome|chr)\s([a-zA-Z0-9]+),', string, re.IGNORECASE)
    if match:
        return match.group(0)[:-1]
    
    base = string.split(',')[1].strip()
    if base == 'complete genome':
        return 'complete genome'
    base = string.split(',')[1].strip().split(' ')
    if len(base) >= 2:
        return str(base[1])
    else:
        return ''

mam_df['contig'] = mam_df.section.apply(lambda i: extract_contig(i))


mam_df.head()

Unnamed: 0,filename,section,example,n_per,species,contig
0,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",TGGAAATGACAATTTTCTGGGGTTTAATAATTTAAAGTAAGTTTTA...,0.0,CAROLI EIJ v1.1,chromosome 1
1,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",GTGTGCACCAGCATAACTGGCTAAAAGTCTTCTGGGATGCAGATGC...,0.0,CAROLI EIJ v1.1,chromosome 1
2,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",ttgctttttttctttctttctactttattttgagaaaaagcCCAGT...,0.0,CAROLI EIJ v1.1,chromosome 1
3,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",CACACTAAAGATGCTCAAGCTAGCCACACTCTATAACATACTCTGT...,0.0375,CAROLI EIJ v1.1,chromosome 1
4,GCF_900094665.2_CAROLI_EIJ_v1.1_genomic.fna.gz,">NC_034570.1 Mus caroli chromosome 1, CAROLI_E...",GAAAAGTTCCAGTTCTTAAGCAGATTTTTGATCAATCAATCTAAAC...,0.0,CAROLI EIJ v1.1,chromosome 1


In [22]:
n_per_threshold = 0.050
print(f'Dropping {len(mam_df[mam_df.n_per >= n_per_threshold])} rows for having N content over {n_per_threshold}.')

# drop rows and only copy two over
ex_df = mam_df[mam_df.n_per < n_per_threshold][['species','example']].copy()

# drop mam_df from memory
del mam_df

# cleanup the example column
ex_df['example'] = ex_df.example.str.strip()

# rename columns
ex_df.columns = ['labels_raw','sequence']

ex_df.head()

Dropping 4516 rows for having N content over 0.05.


Unnamed: 0,labels_raw,sequence
0,CAROLI EIJ v1.1,TGGAAATGACAATTTTCTGGGGTTTAATAATTTAAAGTAAGTTTTA...
1,CAROLI EIJ v1.1,GTGTGCACCAGCATAACTGGCTAAAAGTCTTCTGGGATGCAGATGC...
2,CAROLI EIJ v1.1,ttgctttttttctttctttctactttattttgagaaaaagcCCAGT...
3,CAROLI EIJ v1.1,CACACTAAAGATGCTCAAGCTAGCCACACTCTATAACATACTCTGT...
4,CAROLI EIJ v1.1,GAAAAGTTCCAGTTCTTAAGCAGATTTTTGATCAATCAATCTAAAC...


In [23]:
uniq_lbl = sorted(set(ex_df.labels_raw))
label_to_index = {value: idx for idx, value in enumerate(uniq_lbl)}
print(label_to_index)
ex_df['labels'] = ex_df.labels_raw.apply(lambda l: label_to_index[l])

ex_df.head()

{'Asia NLE v1': 0, 'CAROLI EIJ v1.1': 1, 'GRCh38.p14': 2, 'GRCm39': 3, 'HMol V2': 4, 'Kamilah GGO v0': 5, 'Mhudiblu PPA v0': 6, 'NHGRI mPanTro3 v1.1 hic.freeze pri': 7, 'NHGRI mPonAbe1 v1.1 hic.freeze pri': 8, 'NHGRI mPonPyg2 v1.1 hic.freeze pri': 9, 'NIBS Ocur 1.0': 10, 'PAHARI EIJ v1.1': 11, 'Rrattus CSIRO v1': 12, 'mRatBN7.2': 13}


Unnamed: 0,labels_raw,sequence,labels
0,CAROLI EIJ v1.1,TGGAAATGACAATTTTCTGGGGTTTAATAATTTAAAGTAAGTTTTA...,1
1,CAROLI EIJ v1.1,GTGTGCACCAGCATAACTGGCTAAAAGTCTTCTGGGATGCAGATGC...,1
2,CAROLI EIJ v1.1,ttgctttttttctttctttctactttattttgagaaaaagcCCAGT...,1
3,CAROLI EIJ v1.1,CACACTAAAGATGCTCAAGCTAGCCACACTCTATAACATACTCTGT...,1
4,CAROLI EIJ v1.1,GAAAAGTTCCAGTTCTTAAGCAGATTTTTGATCAATCAATCTAAAC...,1


### Test a sequence

In [31]:
# pull out the first example
dna_seq = ex_df.sequence[0]
dna_seq

'TGGAAATGACAATTTTCTGGGGTTTAATAATTTAAAGTAAGTTTTAGACCCACTCTAGGAAAGAAAGCCTGATGCATCATATAAATGGACAGCAGTGTGTAAgccaaaaatagaaaagaaatttaattttttcttttaagaaacagAGCCTTCTATTTAACAGGATAAAGAGGCTTACTAGTGAACTTTGTGAATGTTAAAGTTGACAAAAATGATACCCATatgtttgttctggttttgtctttATGAAGTTAAGATTTATCGAACTAAAATTCTGGTTGGAAAGAATGTTAGTTACTTGCTTTGTCCCTgaatgacaaataaaataaaataaaataaaggaggaCATGACCGCTCCCAAAGTACAGAGAGTTTTATTGCAGATACAAGGGAGAGACCAGGCTGACTATGGCCCCAGACTAAACTGAACTGTGAgaggagtgagagggagggagagcaagagaggagccaGGGAACAAGAGGGGGCAGCCTGGGCTAAGAGCCAGCCTGGGTAAAGTAGCCAAAATGGCTGGGTTCTATAGGGatcaggctgggggagggaaggggaagcctAGCCCCTGGGGGTGGAGGCTGGAGTGAGGGGGGCTGTGCTGGGAGTAGGTTCTGAGGCATGCTGGGAGAGGCCCCAGGTACTGCGTGAGACCAGCCTCAGGGATGAGACCAAAAAGTCCCGACTCTCTATGTCCTACCTGGCActctgcacagacacacagacacacacacacagagaaatgaacaACTCAGAAACACCACTCTTAAGCTTATAGGCTTAATTTTCAGGAACCCACTCTGAATTAACTCTGAGAGATAATTATATTTCActttattaaaaaagaattcGCATTGGTGGTTCTTACACAGCTGACTAAAAAACCTAAAGACAGGCCTCCGGAAGAAACTGGAGGCAACTTCCCACTGACAAAGTTCTAAAACAGgctgagggaaggggagggagcagagggagaggggagggggagg

In [52]:
def hyena_embed(sequence):
    # tokenize
    tok_seq = tokenizer(sequence)
    tok_seq = tok_seq["input_ids"]  # grab ids

    # place on device, convert to tensor
    tok_seq = torch.LongTensor(tok_seq).unsqueeze(0)  # unsqueeze for batch dim
    tok_seq = tok_seq.to(device)
    
    with torch.inference_mode():
        embeddings = model(tok_seq)

    return embeddings

In [53]:
# prep model and forward
model.to(device)
model.eval()
hyena_embed(dna_seq)

tensor([[[-0.7127,  0.6789,  0.2612,  ..., -2.1606,  0.5249,  1.0987],
         [-0.2605,  0.6501, -0.9796,  ..., -1.1980,  1.0538,  2.1562],
         [ 0.0924,  0.8900, -0.5685,  ..., -1.3519,  1.1949,  2.5591],
         ...,
         [-0.3191,  0.4715, -0.6948,  ..., -1.4375,  1.6862,  2.3110],
         [-0.9599,  0.6181, -0.3071,  ..., -1.3796,  1.7013,  1.7353],
         [-0.5913,  1.0730, -1.3255,  ..., -1.2581,  2.0096,  1.5729]]])

In [54]:
sampled_df = ex_df.sample(n=1000, random_state=42).copy()

In [68]:
# prep model and forward
model.to(device)
model.eval()

tqdm.pandas()

sampled_df['hyena'] = sampled_df.sequence.progress_apply(lambda i: hyena_embed(i).numpy())

sampled_df.head()

  0%|          | 0/1000 [00:00<?, ?it/s]

Unnamed: 0,labels_raw,sequence,labels,hyena
91595,NHGRI mPonAbe1 v1.1 hic.freeze pri,AAATTGTCATGATCAAGTATAAACTATTAGAAGGAAACAAATGAAA...,8,"[[[-0.7126597, 0.6789453, 0.2611758, -0.923249..."
100542,NHGRI mPonAbe1 v1.1 hic.freeze pri,cacatcgtaaaggagtttctgacaatgcttctgttttgtttttatg...,8,"[[[-0.7126602, 0.6789465, 0.26118058, -0.92324..."
114581,PAHARI EIJ v1.1,ACAGGCCTTCTGTTTGAAGCAGTTTTATCCACCTGGTTCTGGAGGG...,11,"[[[-0.7126611, 0.67894554, 0.26117802, -0.9232..."
87817,Mhudiblu PPA v0,aggtgtagatgtttccatcttcatacagatagcggtgtagatatta...,6,"[[[-0.7126601, 0.6789436, 0.26117998, -0.92324..."
143956,NIBS Ocur 1.0,ctgcttccttcctgttcctCCCACCGACTAGGGGAGGAAAGCAGAC...,10,"[[[-0.71266127, 0.67894554, 0.2611779, -0.9232..."


### Classify Mammals

In [72]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix, accuracy_score
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import train_test_split

import seaborn as sns
import numpy as np

In [73]:
def get_scores_mc(y,yhat):
    print('accuracy: ', round(accuracy_score(y,yhat),4))
    print('precision: ', round(precision_score(y,yhat,average='macro'),4))
    print('recall: ', round(recall_score(y,yhat,average='macro'),4))
    print('f1: ', round(f1_score(y,yhat,average='macro'),4))
    print('confusion matrix:\n', confusion_matrix(y,yhat))

In [74]:
y = sampled_df.labels
X_bow = np.vstack(sampled_df.hyena)

print(X_bow.shape,y.shape)

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 2002 and the array at index 7 has size 1472

In [None]:
X_train,X_test,y_train,y_test = train_test_split(X_bow, y, test_size=0.2)

In [None]:
model = RandomForestClassifier(n_jobs=-1)
model.fit(X_train,y_train)

yhat = model.predict(X_test)

get_scores_mc(y_test,yhat)

In [None]:
cm = confusion_matrix(y_test,yhat)

In [None]:
class_names = [label for label, pos in sorted(label_to_index.items(), key=lambda item: item[1])]

# Normalize the rows of the confusion matrix
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Create a heatmap
plt.figure(figsize=(20, 10))
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names)

# Add labels and title
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')

# Rotate x-axis labels
plt.xticks(rotation=90)

# Show the plot
plt.show()