# Notes

## Dataloader (pretraining)
* [HyenaDNA HG38 dataloader](https://github.com/HazyResearch/hyena-dna/blob/main/src/dataloaders/datasets/hg38_dataset.py)
* HyenaDNA used the training/validation intervals from the paper *Effective gene expression prediction from sequence by integrating long-range interactions*.
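Here is a minimal sketch of what an interval-based pretraining dataset could look like. It loosely follows the HyenaDNA idea of fixed genomic intervals, character tokens and next-token targets, but it is not a port of their code. The vocabulary, the `IntervalDataset` name and the `(chrom, start, end)` interval format are assumptions, and the sequences are assumed to be already loaded into a dict of strings.

```python
import torch
from torch.utils.data import Dataset

# Hypothetical character-level vocabulary (assumption, not HyenaDNA's actual ids)
VOCAB = {"[PAD]": 0, "[UNK]": 1, "A": 2, "C": 3, "G": 4, "T": 5, "N": 6}


class IntervalDataset(Dataset):
    """Next-token pretraining samples cut from genomic intervals (sketch)."""

    def __init__(self, sequences, intervals, max_len):
        # sequences: dict chrom -> str; intervals: list of (chrom, start, end), 0-based, end exclusive
        self.sequences = sequences
        self.intervals = intervals
        self.max_len = max_len

    def __len__(self):
        return len(self.intervals)

    def __getitem__(self, idx):
        chrom, start, end = self.intervals[idx]
        # Take max_len + 1 bases so inputs and shifted targets both have length max_len
        seq = self.sequences[chrom][start:end].upper()[: self.max_len + 1]
        ids = [VOCAB.get(ch, VOCAB["[UNK]"]) for ch in seq]
        ids += [VOCAB["[PAD]"]] * (self.max_len + 1 - len(ids))  # right-pad short intervals
        ids = torch.tensor(ids, dtype=torch.long)
        # Shift by one: the model predicts token t+1 from tokens <= t
        return ids[:-1], ids[1:]


# Toy example with a made-up "chromosome"
seqs = {"chr1": "ACGTNacgtACGT"}
ds = IntervalDataset(seqs, [("chr1", 0, 8), ("chr1", 4, 13)], max_len=8)
x, y = ds[0]
```

If the loss uses `ignore_index` equal to the padding id, the padded tail of short intervals does not contribute to training.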

## Tokenizer?
* Need to check HyenaDNA; I think their Jupyter notebook contains code for their tokenizer.
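Until the HyenaDNA code is checked, a plausible starting point is a character-level tokenizer (one token per base). The sketch below is an assumption, not HyenaDNA's tokenizer: the special tokens, their ids and the `CharTokenizer` name are hypothetical.

```python
class CharTokenizer:
    """Character-level DNA tokenizer sketch: one token per nucleotide."""

    SPECIALS = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]  # hypothetical special tokens

    def __init__(self, alphabet="ACGTN"):
        self.itos = self.SPECIALS + list(alphabet)            # id -> token
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}  # token -> id
        self.unk_id = self.stoi["[UNK]"]

    @property
    def vocab_size(self):
        return len(self.itos)

    def encode(self, seq, add_special_tokens=True):
        # Upper-case first so lower-case bases (e.g. soft-masked regions) share ids with upper-case ones
        ids = [self.stoi.get(ch, self.unk_id) for ch in seq.upper()]
        if add_special_tokens:
            ids = [self.stoi["[BOS]"]] + ids + [self.stoi["[EOS]"]]
        return ids

    def decode(self, ids):
        # Drop special tokens (ids below len(SPECIALS)), keep nucleotides
        return "".join(self.itos[i] for i in ids if i >= len(self.SPECIALS))


tok = CharTokenizer()
ids = tok.encode("acgtR")  # 'R' (IUPAC ambiguity code) is not in the alphabet -> [UNK]
```

Keeping the alphabet to `ACGTN` and mapping everything else to `[UNK]` keeps the vocabulary tiny; the output layer of the model then only needs `tok.vocab_size` logits.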

## Model
* [Original MAMBA repo](https://github.com/state-spaces/mamba)
    * [benchmark_generation_mamba_simple.py](https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py)
    * Uses [MambaLMHeadModel](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L173) from `mixer_seq_simple.py`
    * Uses [MixerModel](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L83)
    * Uses [create_block](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L21)
    * Uses the [Block](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L298) and [Mamba](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L34) classes (from `mamba_simple.py`)
        * Actual MAMBA operation: [mamba_inner_fn](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/ops/selective_scan_interface.py#L155)
* [Mamba small benchmark repo](https://github.com/apapiu/mamba_small_bench)
* [SimplerMambaSSM Jupyter Notebook](./SimplerMambaSSM.ipynb)
    * Uses the `mamba-ssm` library
    * See class `BigNeuralNetwork`
* [MAMBA chat](https://github.com/havenhq/mamba-chat/blob/main/train_mamba.py)
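To get a feel for what `mamba_inner_fn` computes, here is a naive, sequential version of the selective-scan recurrence at its core, in plain PyTorch. It is a readability sketch modeled on the non-CUDA reference path in the repo as I understand it, not the repo's code. It leaves out the input/output projections, the causal conv1d and the `delta` bias/softplus, and it assumes a real-valued `A` plus input-dependent `B` and `C` of shape `(batch, d_state, seq_len)`. `selective_scan_naive` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def selective_scan_naive(u, delta, A, B, C, D, z=None):
    """Sequential selective scan, for readability rather than speed.

    Shapes (b = batch, d = channels, n = state size, l = sequence length):
      u, delta: (b, d, l)   A: (d, n)   B, C: (b, n, l)   D: (d,)   z: (b, d, l) or None
    """
    batch, dim, seqlen = u.shape
    # Discretize per time step: A_bar = exp(delta * A), B_bar * u approximated by delta * B * u
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    h = u.new_zeros(batch, dim, A.shape[1])  # one n-dim hidden state per channel
    ys = []
    for t in range(seqlen):
        h = deltaA[:, :, t] * h + deltaB_u[:, :, t]           # h_t = A_bar h_{t-1} + B_bar u_t
        ys.append(torch.einsum("bdn,bn->bd", h, C[:, :, t]))  # y_t = C_t h_t
    y = torch.stack(ys, dim=-1) + u * D[None, :, None]         # skip connection D * u
    if z is not None:
        y = y * F.silu(z)                                      # optional output gate
    return y


# Small random example (shapes only, no trained weights)
torch.manual_seed(0)
b, d, n, l = 2, 3, 4, 5
u = torch.randn(b, d, l)
delta = F.softplus(torch.randn(b, d, l))  # step sizes must be positive
A = -torch.rand(d, n)                     # negative A -> the state decays over time
B = torch.randn(b, n, l)
C = torch.randn(b, n, l)
D = torch.randn(d)
y = selective_scan_naive(u, delta, A, B, C, D)
```

The Python loop runs one step per position, so this is slow for long sequences; it is only meant for checking shapes and semantics (e.g. causality) on small inputs.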

```python
from zipfile import ZipFile
from io import BytesIO
import requests
import os
from pathlib import Path

import torch
import torch.nn as nn
from mamba_ssm import Mamba
```

# Download genetic data

```python
# Dataset URLs (NCBI Datasets API, v2alpha)
hg38_url = 'https://api.ncbi.nlm.nih.gov/datasets/v2alpha/genome/accession/GCF_000001405.40/download'  # GRCh38.p14
t2t_url = 'https://api.ncbi.nlm.nih.gov/datasets/v2alpha/genome/accession/GCF_009914755.1/download'   # T2T assembly
dataset_url = hg38_url
data_dir_path = 'dataset'

print("download started...")
# Request only the genome FASTA; the whole zip archive is held in memory
response = requests.get(dataset_url, params={'include_annotation_type': 'GENOME_FASTA'})
response.raise_for_status()  # stop on HTTP errors instead of silently skipping extraction
os.makedirs(data_dir_path, exist_ok=True)
with ZipFile(BytesIO(response.content)) as zip_file:
    zip_file.extractall(path=data_dir_path)
print("dataset ready")

hg38_fasta = Path(data_dir_path) / 'ncbi_dataset/data/GCF_000001405.40/GCF_000001405.40_GRCh38.p14_genomic.fna'

print("FASTA files:")
fpaths = list(Path(data_dir_path).rglob('*.fna'))
for fpath in fpaths:
    print(fpath)

data_path = fpaths[0]
```

```
download started...
dataset ready
FASTA files:
dataset/ncbi_dataset/data/GCF_000001405.40/GCF_000001405.40_GRCh38.p14_genomic.fna
```


# Dataloader
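A first step for the dataloader could be reading the downloaded `.fna` file into memory. Below is a plain-Python FASTA parser; `read_fasta` is a hypothetical helper, not taken from HyenaDNA. The `keep` argument loads only selected records, since the whole assembly takes several GB as Python strings; an indexed reader such as pyfaidx would avoid loading everything.

```python
def read_fasta(path, keep=None):
    """Parse a FASTA file into {record_id: sequence}.

    keep: optional set of record ids to retain; other records are skipped.
    """
    def wanted(name):
        return name is not None and (keep is None or name in keep)

    records, name, chunks = {}, None, []
    with open(path) as fh:
        for line in fh:
            line = line.rstrip()
            if line.startswith(">"):
                if wanted(name):
                    records[name] = "".join(chunks)  # flush the previous record
                name, chunks = line[1:].split()[0], []  # record id = first word of the header
            elif wanted(name):
                chunks.append(line)  # sequence lines are concatenated as-is (case preserved)
    if wanted(name):
        records[name] = "".join(chunks)  # flush the last record
    return records
```

In this file the record ids are RefSeq accessions (e.g. `NC_000001.11` for chromosome 1) rather than UCSC names like `chr1`, so interval files that use `chr1` would need a name mapping. For example: `genome = read_fasta(data_path, keep={"NC_000001.11"})`.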

# Model

```python
print("CUDA is available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
gpu = 4
torch.cuda.set_device(gpu)
print("Current GPU:", torch.cuda.get_device_name())
```

```
CUDA is available: True
GPU count: 8
Current GPU: NVIDIA RTX 6000 Ada Generation
```


## Need to check model
- Allocate inference cache? E.g., [here](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L142)
- Init weights
- **PARAMETERS**
- Loss function? CrossEntropyLoss?
- Optimizer? Adam?
    - Note: HyenaDNA used AdamW
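A sketch of the loss/optimizer combination these questions point to: `CrossEntropyLoss` over next-token targets with padding ignored, and AdamW as in HyenaDNA. A tiny embedding + linear model stands in for `MambaDNA` so the snippet runs on CPU. `VOCAB_SIZE`, `PAD_ID`, the learning rate and the weight decay are placeholder assumptions, not tuned values.

```python
import torch
import torch.nn as nn

VOCAB_SIZE, PAD_ID = 16, 0  # placeholder vocabulary size and padding id (assumptions)

torch.manual_seed(0)
# Stand-in for MambaDNA so the sketch runs without the mamba-ssm CUDA kernels
model = nn.Sequential(nn.Embedding(VOCAB_SIZE, 32), nn.Linear(32, VOCAB_SIZE))

# Next-token prediction: cross-entropy over the vocabulary; padded targets are ignored
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
# AdamW (decoupled weight decay); lr and weight_decay are placeholders
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)


def train_step(x, y):
    """x, y: (batch, seq_len) token ids, where y is x shifted left by one position."""
    model.train()
    logits = model(x)  # (batch, seq_len, vocab)
    # CrossEntropyLoss wants (N, C) logits and (N,) targets, so flatten batch and time
    loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy run: repeatedly fit one random batch, so the loss should go down
tokens = torch.randint(1, VOCAB_SIZE, (4, 65))  # ids >= 1, i.e. no padding here
x, y = tokens[:, :-1], tokens[:, 1:]
losses = [train_step(x, y) for _ in range(20)]
```

Flattening to `(batch * seq_len, vocab)` is what lets the standard classification loss score every position at once.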
### MambaTower
- ~~no embedding like [here](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L149)? -> In MambaDNA~~
- ~~put nn.Linear behind? As in [MambaLMHeadModel](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L224) -> In MambaDNA~~
### MambaBlock
- [Residual?](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/models/mixer_seq_simple.py#L152)
- Different order: Original: Add -> LN -> MAMBA ([see](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L334-L350)); Here: MAMBA -> Add -> LN
    - Does this remove the need for a separate residual stream? Yes
    - Is this equivalent?
        - Close: MAMBA block is: `x = self.mamba(self.norm(x)) + x`
        - Does it matter if residual connection does not include LN?
    - Fused normalization (with add) is used for higher performance ([see](https://github.com/state-spaces/mamba/blob/bae8d1a42fec58f4cdd300bf3b987d05eab22ed0/mamba_ssm/modules/mamba_simple.py#L311))
    - Alternative: put LN in front (pre-norm, `x = self.mamba(self.norm(x)) + x`); this would also need a final normalization after the last block
- Normalization: LayerNorm or RMSNorm?
- Fuse normalization with add
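The pre-norm alternative from these notes, written out: each block computes `x + mixer(norm(x))`, and a final norm is applied after the last block (as far as I can tell, this matches what `MixerModel` does with its `norm_f`, where the add and norm are fused for speed). The mixer is an injected module standing in for `Mamba(...)` so the sketch runs without the CUDA kernels. The `RMSNorm` here is a plain reimplementation (recent PyTorch versions also ship `torch.nn.RMSNorm`), and `PreNormBlock`/`PreNormTower` are hypothetical names.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Rescale by the root-mean-square of the features (no mean subtraction, no bias)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class PreNormBlock(nn.Module):
    """Pre-norm residual block: x + mixer(norm(x)). `mixer` stands in for Mamba(d_model=dim, ...)."""

    def __init__(self, dim, mixer, norm_cls=RMSNorm):
        super().__init__()
        self.norm = norm_cls(dim)
        self.mixer = mixer

    def forward(self, x):
        return x + self.mixer(self.norm(x))  # the residual path itself is never normalized


class PreNormTower(nn.Module):
    def __init__(self, dim, n_layers, make_mixer, norm_cls=RMSNorm):
        super().__init__()
        self.blocks = nn.ModuleList([PreNormBlock(dim, make_mixer(), norm_cls) for _ in range(n_layers)])
        # Final norm: without it, the output of the last residual add would stay un-normalized
        self.norm_f = norm_cls(dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.norm_f(x)


torch.manual_seed(0)
tower = PreNormTower(dim=8, n_layers=2, make_mixer=lambda: nn.Linear(8, 8))
out = tower(torch.randn(2, 5, 8))  # (batch, seq_len, dim)
```

Swapping `norm_cls=nn.LayerNorm` gives the LayerNorm variant, which makes the "LayerNorm or RMSNorm?" question a one-argument experiment.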

```python
# Adapted from https://github.com/apapiu/mamba_small_bench
class MambaBlock(nn.Module):
    """Post-norm block: LayerNorm(Mamba(x) + x), followed by dropout."""

    def __init__(self, embed_dim, dropout_level=0.0):
        super().__init__()

        self.mamba = Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout_level)

    def forward(self, x):
        x = self.norm(self.mamba(x) + x)
        return self.dropout(x)


class MambaTower(nn.Module):
    def __init__(self, embed_dim, n_layers, dropout=0.0, global_pool=False):
        super().__init__()
        self.blocks = nn.Sequential(*[MambaBlock(embed_dim, dropout) for _ in range(n_layers)])
        self.global_pool = global_pool  # mean-pool over the sequence for classification / supervised tasks

    def forward(self, x):
        # input (bs, n, d) -> output (bs, n, d), or (bs, d) if global_pool
        out = self.blocks(x)
        return out.mean(dim=1) if self.global_pool else out


class MambaDNA(nn.Module):
    def __init__(self, vocab_size, embed_dim, n_layers, dropout=0.0):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.tower = MambaTower(embed_dim, n_layers, dropout=dropout, global_pool=False)
        self.out_proj = nn.Sequential(nn.LayerNorm(embed_dim),
                                      nn.Linear(embed_dim, vocab_size))

    def forward(self, x):
        # x: (bs, n) token ids -> logits (bs, n, vocab_size)
        x = self.tower(self.embed(x))
        return self.out_proj(x)
```