In [None]:
import numpy as np
import pytorch_lightning as pl
import torch
from lightning.pytorch import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader, Dataset, random_split

from model import mamba_4chan
from model_config import mamba_4chan_370m_config

In [None]:
# Load the base Mamba weights into a freshly constructed 370m model.
model = mamba_4chan(mamba_4chan_370m_config())
# map_location="cpu": the checkpoint loads even if it was saved from a device
# that is not present on this machine; the Trainer moves the model afterwards.
# weights_only=True: restrict unpickling to tensors/plain containers so a
# tampered checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load("path_to.bin", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Alternative: continue from an already fine-tuned Lightning checkpoint.
# model = mamba_4chan.load_from_checkpoint("path_to.ckpt")

# NOTE(review): presumably read by the model's optimizer setup in model.py -- confirm.
model.learning_rate = 1e-4

In [None]:
class pol_dataset(Dataset):
    """Sliding-window dataset over a flat file of ``uint16`` token ids.

    The file is memory-mapped (read-only) and exposed as windows of
    ``context_size`` tokens, where window ``i`` starts at token ``i * stride``.
    A final window that runs past the end of the file is right-padded with
    ``eos_token`` so that trailing tokens are never dropped.

    Args:
        memmap_path: Path to the raw ``uint16`` token file.
        context_size: Number of tokens in every returned window.
        eos_token: Token id used to pad the last, partial window.
        stride: Distance in tokens between the starts of consecutive windows.

    Raises:
        ValueError: If ``context_size`` or ``stride`` is not positive.
    """

    def __init__(
        self,
        memmap_path: str,
        context_size: int = 2048,
        eos_token: int = 0,
        stride: int = 2047,
    ):
        if context_size <= 0:
            raise ValueError(f"context_size must be positive, got {context_size}")
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        # mode="r": the data is only ever read, so do not require write access
        # (numpy's default mode is "r+").
        self.memmap = np.memmap(memmap_path, dtype="uint16", mode="r")
        self.context_size = context_size
        self.eos_token = eos_token
        self.stride = stride

    def __len__(self) -> int:
        """Return the number of windows needed to cover every token once."""
        total = len(self.memmap)
        if total == 0:
            return 0

        # Windows required until one of them reaches the end of the data
        # (ceil division, so a partial tail window is counted).
        beyond_first = max(0, total - self.context_size)
        n_cover = -(-beyond_first // self.stride) + 1
        # Never count a window that would start at or after the end of the
        # data (only relevant when stride > context_size).
        n_start = -(-total // self.stride)
        return min(n_cover, n_start)

    def __getitem__(self, idx):
        """Return window ``idx`` as a 1-D ``torch.long`` tensor of length ``context_size``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_windows = len(self)
        if idx < 0:
            idx += n_windows
        if not 0 <= idx < n_windows:
            raise IndexError(f"index out of range for dataset of length {n_windows}")

        start = idx * self.stride
        end = start + self.context_size

        # Slicing clamps at the end of the file; copying to int64 detaches the
        # window from the read-only memmap and avoids handing uint16 to torch.
        window = np.array(self.memmap[start:end], dtype=np.int64)

        missing = self.context_size - len(window)
        if missing > 0:
            window = np.concatenate(
                (window, np.full(missing, self.eos_token, dtype=np.int64))
            )

        return torch.from_numpy(window)

In [None]:
class data_module(pl.LightningDataModule):
    """LightningDataModule that serves train/val splits of a ``pol_dataset``.

    Args:
        memmap_path: Path to the token file passed to ``pol_dataset``.
        batch_size: Batch size used by both dataloaders.
        context_size: Window length passed to ``pol_dataset``.
        eos_token: Padding token id passed to ``pol_dataset``.
        stride: Window stride passed to ``pol_dataset``.
        train_val_ratio: Fraction of windows assigned to the training split;
            the remainder becomes the validation split.
        num_workers: Worker processes used by each dataloader.
        seed: Seed for the train/val split, so the same windows land in the
            same split on every ``setup`` call and on every run (for example
            when training is resumed from a checkpoint).

    Raises:
        ValueError: If ``train_val_ratio`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        memmap_path: str,
        batch_size: int,
        context_size: int = 2048,
        eos_token: int = 0,
        stride: int = 2047,
        train_val_ratio: float = 0.95,
        num_workers: int = 8,
        seed: int = 42,
    ):
        super().__init__()
        if not 0.0 <= train_val_ratio <= 1.0:
            raise ValueError(
                f"train_val_ratio must be within [0, 1], got {train_val_ratio}"
            )
        self.save_hyperparameters()
        self.memmap_path = memmap_path
        self.batch_size = batch_size
        self.context_size = context_size
        self.eos_token = eos_token
        self.stride = stride
        self.train_val_ratio = train_val_ratio
        self.num_workers = num_workers
        self.seed = seed

    def setup(self, stage: str = None):
        """Build the dataset and split it into ``train_set`` and ``val_set``."""
        dataset = pol_dataset(
            self.memmap_path, self.context_size, self.eos_token, self.stride
        )

        train_size = int(len(dataset) * self.train_val_ratio)
        val_size = len(dataset) - train_size

        # A dedicated, seeded generator keeps the partition stable across
        # repeated setup() calls and separate runs; an unseeded split would
        # move validation windows into the training set after a restart.
        split_rng = torch.Generator().manual_seed(self.seed)
        self.train_set, self.val_set = random_split(
            dataset, [train_size, val_size], generator=split_rng
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return the (unshuffled) dataloader over the validation split."""
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )


# Data module over the token file; every other option keeps its default.
data = data_module(
    batch_size=4,
    memmap_path="dataset.dat",
)

In [None]:
# Take the logger from the same package as the Trainer and the callbacks
# (`pytorch_lightning`). `lightning.pytorch` is a separate namespace with its
# own class hierarchy, and mixing the two in one Trainer is fragile.
wandb_logger = pl.loggers.WandbLogger(project="Mamba 4chan 370m", name="Fine-tuning")

trainer = pl.Trainer(
    callbacks=[
        # save_top_k=-1 keeps every checkpoint instead of only the best ones.
        ModelCheckpoint(
            dirpath="models/",
            save_top_k=-1,
        ),
        # Log the learning rate at every optimizer step.
        LearningRateMonitor(logging_interval="step"),
    ],
    logger=wandb_logger,
    precision="bf16-mixed",
    max_epochs=1,
    # Gradients are accumulated over 250 batches before each optimizer step.
    accumulate_grad_batches=250,
)

trainer.fit(model, data)
# trainer.fit(model, data, ckpt_path="ckpt to resume training")