# Notes

## Dataloader (pretraining)
* [HyenaDNA HG38 dataloader](https://github.com/HazyResearch/hyena-dna/blob/main/src/dataloaders/datasets/hg38_dataset.py)
* HyenaDNA used training/validation intervals from *Effective gene expression prediction from sequence by integrating long-range interactions.* paper.

## Tokenizer?
* Need to check HyenaDNA; I think their Jupyter contained some code of their tokenizer

## Model
* [Original MAMBA repo](https://github.com/state-spaces/mamba)
    * [benchmark_generation_mamba_simple.py](https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py)
    * Uses [mambaLMHeadModel](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L173) form `mixer_seq_simple.py`
    * Uses [MixerModel](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L83)
    * Uses [create_block](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L21)
    * Uses [Block](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L298) and [MAMBA](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L34) classes (from `mamba_simple.py`)
        * Actual MAMBA operation: [mamba_inner_fn](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/ops/selective_scan_interface.py#L155)
* [Mamba small benchmark repo](https://github.com/apapiu/mamba_small_bench)
* [SimplerMambaSSM Jupyter Notebook](./SimplerMambaSSM.ipynb)
    * Use mamba-ssm library
    * See class BigNeuralNetwork
* [MAMBA chat](https://github.com/havenhq/mamba-chat/blob/main/train_mamba.py)

# Required python packages
1. PyTorch (with CUDA)
2. mamba-ssm
3. transformers==4.26.1 *(for tokenizer)*

In [33]:
from zipfile import ZipFile
from io import BytesIO
import requests
import os
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np
from mamba_ssm import Mamba

from pyfaidx import Fasta

# Download genetic data

In [21]:
# datasets
hg38_url = 'https://api.ncbi.nlm.nih.gov/datasets/v2alpha/genome/accession/GCF_000001405.40/download'
t2t_url = 'https://api.ncbi.nlm.nih.gov/datasets/v2alpha/genome/accession/GCF_009914755.1/download'
dataset_url = t2t_url

print("download started...")
response = requests.get(dataset_url, params={'include_annotation_type': 'GENOME_FASTA'})
if response.status_code == 200:
    data_dir_path = 'dataset'
    os.makedirs(data_dir_path, exist_ok=True)
    with BytesIO(response.content) as zip_buffer:
        ZipFile(zip_buffer, 'r').extractall(path=data_dir_path)
    print("dataset ready")

gh38_fasta = 'dataset/ncbi_dataset/data/GCF_000001405.40/GCF_000001405.40_GRCh38.p14_genomic.fna'

print("FASTA files:")
fpaths = list(Path('dataset').rglob('*.fna'))
for fpath in fpaths:
    print(fpath)

fasta_path = fpaths[0]

download started...
dataset ready
FASTA files:
dataset/ncbi_dataset/data/GCF_009914755.1/GCF_009914755.1_T2T-CHM13v2.0_genomic.fna


# Tokenizer

In [193]:
# HyenaDNA tokenizer; code from their jupyter notebook
"""
Just a simple character level tokenizer.

From: https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py

CharacterTokenzier for Hugging Face Transformers.
This is heavily inspired from CanineTokenizer in transformers package.
"""
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Union

from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer


class CharacterTokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        self.characters = characters
        self.model_max_length = model_max_length
        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def convert_token_vector_to_string(self, ivector):
        out_str = ""
        for i in ivector:
            out_str = out_str + self._convert_id_to_token(i.item())
        return out_str

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        result = cls + token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CharacterTokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)

# Dataloader

In [276]:
def complement(in_seq):
    out_seq = ""
    for idx, c in enumerate(in_seq):
        oc = "X"
        if c == 'A':
            oc = 'T'
        elif c == 'T':
            oc = 'A'
        elif c == 'C':
            oc = 'G'
        elif c == 'G':
            oc = 'C'
        elif c == 'N':
            oc = 'N'
        else:
            assert True == False
        out_seq = out_seq + oc
    return out_seq
    

class GenomeDataset(torch.utils.data.Dataset):
    def __init__(self, fasta_path, ds_entries):
        assert Path(fasta_path).exists
        self.fasta = Fasta(fasta_path, one_based_attributes=False)

        dtype = np.dtype([('key', 'U20'), ('start', 'int_'), ('end', 'int_')])
        self.entry_ranges = np.empty(len(ds_entries), dtype=dtype)

        # only append entries of dataset
        count = 0
        for idx, k in enumerate(ds_entries):
            assert k in self.fasta.keys(), \
                "FASTA file does not contain an entry with key {}".format(k)
            seq_len = len(self.fasta[k])
            self.entry_ranges[idx] = np.array([(k, count, count + seq_len)], dtype=dtype)
            count = count + seq_len

        # for e in self.entry_ranges:
        #     print(e)

    def config(self, tokenizer, seq_len):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
             
    def __len__(self):
        # first range forward idices, second range reverse complement
        return self.entry_ranges[-1]['end'] * 2

    def __getitem__(self, idx):
        assert idx >= 0
        assert idx < self.__len__()

        # return reverse complement?
        rev_compl = (idx >= self.entry_ranges[-1]['end'])
        idx = idx % self.entry_ranges[-1]['end']

        # locate FASTA entry of global idx
        key = None
        local_idx = -1
        for e in self.entry_ranges:
            if e['start'] <= idx < e['end']:
                key = e['key']
                local_idx = idx - e['start']

        assert key != None
        assert local_idx != -1
        # print("local_idx: {}, rev_compl: {}".format(local_idx, rev_compl))

        left_bound = 0
        right_bound = len(self.fasta[key])

        # print(self.fasta[key][-20:], len(self.fasta[key][-20:]))

        seq = None
        if not(rev_compl):
            seq = self.fasta[key][:local_idx + 1][-self.seq_len:]
        else:
            seq = self.fasta[key][local_idx:][::-1][-self.seq_len:]
        assert seq != None

        # capitalize all nucleotides
        seq_str = str(seq).upper()
        # print("seq_str: {}".format(seq_str))

        # use complement when reverse
        if rev_compl:
            seq_str = complement(seq_str)
        # print(seq_str, len(seq_str))
        assert len(seq_str) <= self.seq_len

        tokens = self.tokenizer(seq_str, add_special_tokens=False, padding="max_length",
                                max_length=self.seq_len, truncation=True)
        input = torch.LongTensor(tokens["input_ids"]).clone()
        # print("input: {}".format(input))

        # mask
        target = input[-1].clone()
        input[-1] = self.tokenizer._vocab_str_to_int['[MASK]']
        # print(input, target)
        return input, target

def get_T2T_datasets():
    T2T_path = "dataset/ncbi_dataset/data/GCF_009914755.1/GCF_009914755.1_T2T-CHM13v2.0_genomic.fna"
    training_entries = ['NC_060925.1', 'NC_060926.1', 'NC_060927.1', 'NC_060928.1', 'NC_060929.1',
                        'NC_060931.1', 'NC_060932.1', 'NC_060933.1', 'NC_060934.1', 'NC_060935.1',
                        'NC_060936.1', 'NC_060937.1', 'NC_060938.1', 'NC_060939.1', 'NC_060941.1',
                        'NC_060942.1', 'NC_060943.1', 'NC_060944.1', 'NC_060945.1', 'NC_060946.1',
                        'NC_060947.1', 'NC_060948.1']
    test_entries = ['NC_060930.1', 'NC_060940.1']

    # check training and test dataset do not contain the same entries
    assert set(training_entries).isdisjoint(set(test_entries)) == True
    
    train_dataset = GenomeDataset(T2T_path, training_entries)
    test_dataset = GenomeDataset(T2T_path, test_entries)
    return train_dataset, test_dataset

# train_ds, test_ds = get_T2T_datasets()

In [293]:
# validate results of T2T dataset;
# check if correct tokens are returned for predefined indices
def validate_T2T_ds():
    train_ds, _ = get_T2T_datasets()
    
    token_len = 30
    tokenizer = CharacterTokenizer(
        characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
        model_max_length=token_len,  # to account for special tokens, like EOS
        add_special_tokens=False,  # we handle special tokens elsewhere
        padding_side='left', # since HyenaDNA is causal, we pad on the left
    )

    train_ds.config(tokenizer, token_len)
    print(train_ds.entry_ranges)
    
    def check(idx, expct_inpt, expct_trgt):
        inpt, trgt = train_ds.__getitem__(idx)
        actual_trgt = tokenizer._convert_id_to_token(trgt.item())
        actual_inpt = tokenizer.convert_token_vector_to_string(inpt)

        # Does not recognize "PAD" and "MASK" token
        # expct_inpt_len = len(tokenizer(expct_inpt, add_special_tokens=False, max_length=token_len,
        #                            truncation=False)["input_ids"])
        # assert expct_inpt_len == token_len, \
        #     "Unexpected token length of expct_inpt ({})".format(expct_inpt_len)
        
        assert expct_trgt == actual_trgt, \
            "Target tokens do not match; expected: {}, actual: {}".format(expct_trgt, actual_trgt)
        assert expct_inpt == actual_inpt, \
            "Input tokens do not match; expected: {} (len {}), actual: {} (len {})".format(expct_inpt, len(expct_inpt), actual_inpt, len(actual_inpt))
        print("tokens match!")

    # forward
    check(5, "[PAD]"*24 + "CACCC" + "[MASK]", "T")
    check(248387322, "GGGTTAGGGTTAGGGTTAGGGTTAGGGTT" + "[MASK]", "A")
    check(1067810436, "[PAD]"*5 + "CCTAACCCTAACCCTAACCCCTAA" + "[MASK]", "C")
    check(1228377838, "GGGTTAGGGTTAGGGGTTAGGGTTAGGGT" + "[MASK]", "T")
    check(2786358510, "CTAACCCTAACCCTAACCCTAACCCTAAC" + "[MASK]", "C")
    check(2848818478, "AGGGTTAGGGTTAGGGTTAGGGTTAGGGT" + "[MASK]", "T")
    # reverse complement
    offset = 2848818499
    check((5 + offset), "GTTAGGGTTAGGGTTAGGGGTTAGGGTTT" + "[MASK]", "A")
    check(248387322 + offset, "[PAD]"*24 + "AACCC" + "[MASK]", "T")
    check(1067810436 + offset, "GTTAGGAGGGTTAGGGGATTAGGGTTAGG" + "[MASK]", "G")
    check(1228377838 + offset, "[PAD]"*28 + "T" + "[MASK]", "A")
    check(2786358510 + offset, "GTTAGGGTTAGGGTTAGGGTTAGGGTTAG" + "[MASK]", "G")
    print(len("GATTGGGATTGGGATTGGGA")
    check(2848818478 + offset, "[PAD]"*10 + "CTAACCCTAACCCTAACCCT" + "[MASK]", "A")

    print("tests completed successfully!")
validate_T2T_ds()

[('NC_060925.1',          0,  248387328)
 ('NC_060926.1',  248387328,  491084080)
 ('NC_060927.1',  491084080,  692190028)
 ('NC_060928.1',  692190028,  885764973)
 ('NC_060929.1',  885764973, 1067810412)
 ('NC_060931.1', 1067810412, 1228377840)
 ('NC_060932.1', 1228377840, 1374637171)
 ('NC_060933.1', 1374637171, 1525254418)
 ('NC_060934.1', 1525254418, 1660012552)
 ('NC_060935.1', 1660012552, 1795140321)
 ('NC_060936.1', 1795140321, 1928464869)
 ('NC_060937.1', 1928464869, 2042031555)
 ('NC_060938.1', 2042031555, 2143193047)
 ('NC_060939.1', 2143193047, 2242946242)
 ('NC_060941.1', 2242946242, 2327223139)
 ('NC_060942.1', 2327223139, 2407765677)
 ('NC_060943.1', 2407765677, 2469473041)
 ('NC_060944.1', 2469473041, 2535683296)
 ('NC_060945.1', 2535683296, 2580773978)
 ('NC_060946.1', 2580773978, 2632098904)
 ('NC_060947.1', 2632098904, 2786358470)
 ('NC_060948.1', 2786358470, 2848818499)]
tokens match!
tokens match!
tokens match!
tokens match!
tokens match!
tokens match!
tokens match!

AssertionError: Input tokens do not match; expected: [PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]CTAACCCTAACCCTAACCCT[MASK] (len 76), actual: [PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]CTAACCCTAACCCTAACCCT[MASK] (len 71)

In [178]:
tokenizer = CharacterTokenizer(
    characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
    model_max_length=max_length + 2,  # to account for special tokens, like EOS
    add_special_tokens=False,  # we handle special tokens elsewhere
    padding_side='left', # since HyenaDNA is causal, we pad on the left
)

dataset = GenomeDataset(fasta_path, tokenizer, 10, split="train")
dataset.__getitem__(5)
dataset.__getitem__(dataset.__len__()//2 + 5)
dataset.__getitem__(248387317)
dataset.__getitem__(dataset.__len__()//2 + 248387317)
# dataset.__getitem__(248387338)
# dataset.__getitem__(2786358460)
# dataset.__getitem__(1928464869)

# tokenizer("ATGAAGGAAGG", 
#                 add_special_tokens=False, 
#                 padding="max_length",
#                 max_length=10,
#                 truncation=True,
#             ) 

local_idx: 5
ttagggttagggttagggtt 20
CACCCT 6
tensor([4, 4, 4, 4, 8, 7, 8, 8, 8, 3]) tensor(10)
local_idx: 5
ttagggttagggttagggtt 20
TTAGGGTTTA 10
tensor([10, 10,  7,  9,  9,  9, 10, 10, 10,  3]) tensor(7)
local_idx: 248387323
ttagggttagggttagggtt 20
TTAGGGTTAG 10
tensor([10, 10,  7,  9,  9,  9, 10, 10,  7,  3]) tensor(9)
local_idx: 248387323
ttagggttagggttagggtt 20
AACCC 5
tensor([4, 4, 4, 4, 4, 7, 7, 8, 8, 3]) tensor(8)


(tensor([4, 4, 4, 4, 4, 7, 7, 8, 8, 3]), tensor(8))

# Model

In [21]:
print("CUDA is available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
gpu = 4
torch.cuda.set_device(gpu)
print("Current GPU:", torch.cuda.get_device_name())

CUDA is available: True
GPU count: 8
Current GPU: NVIDIA RTX 6000 Ada Generation


## Investigations
- Allocate inference cache? E.g., [here](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L142)
- Init weights
- **PARAMETERS**
- Loss function? CrossEntropyLoss?
- Optimizer? Adam?
    - Note: HyenaDNA used AdamW
### MambaTower
- ~~no embedding like [here](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L149)? -> In MambaDNA~~
- Do I need an embedding, for my small set of tokens?
- ~~put nn.Linear behind? As in [MambaLMHeadModel](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L224) -> In MambaDNA~~
### MambaBlock
- [residual?](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L152)?
- Different order: Original: Add -> LN -> MAMBA ([see](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L334-L350)); Here: MAMBA -> Add -> LN
    - Does this remove the need for residual? Yes
    - Is this equivalent?
        - Close: MAMBA block is: `x = self.mamba(self.norm(x)) + x`
        - Does it matter if residual connection does not include LN?
    - Fuzed normalization (with add) used for higher performance ([see](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L311))
    - Put LN in front (`x = self.mamba(self.norm(x)) + x`); would need to add normalization behind as well
- Normalization: LayerNorm or RMSNorm?
- Fuse normalization with add

In [4]:
# code from https://github.com/apapiu/mamba_small_bench
class MambaBlock(nn.Module):
    def __init__(self, embed_dim, dropout_level=0):
        super().__init__()

        self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout_level)

    def forward(self, x):
        x = self.norm(self.mamba(x) + x)
        return self.dropout(x)


class MambaTower(nn.Module):
    def __init__(self, embed_dim, n_layers, seq_len=None, global_pool=False):
        super().__init__()
        self.blocks = nn.Sequential(*[MambaBlock(embed_dim) for _ in range(n_layers)])
        self.global_pool = global_pool #for classification or other supervised learning.

    def forward(self, x):
        #for input (bs, n, d) it returns either (bs, n, d) or (bs, d) is global_pool
        out = self.blocks(x) if not self.global_pool else torch.mean(self.blocks(x),1)
        return out


class MambaDNA(nn.Module):
    def __init__(self, embed_dim, seq_len, n_layers, dropout):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.tower = MambaTower(embed_dim, n_layers, seq_len=seq_len, global_pool=False)
        self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
                                      nn.Linear(embed_dim, vocab_size))

    def forward(self, x):
        x = self.tower(self.embed(x))
        return self.out_proj(x)

# Pretraining

In [4]:
max_length = 1000


# create tokenizer
tokenizer = CharacterTokenizer(
    characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
    model_max_length=max_length + 2,  # to account for special tokens, like EOS
    add_special_tokens=False,  # we handle special tokens elsewhere
    padding_side='left', # since HyenaDNA is causal, we pad on the left
)


train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
