Colab Owner: Darshil Modi [linkedin: https://www.linkedin.com/in/darshil3011] [email: darshil3011@gmail.com]

Built upon: https://github.com/fkodom/yet-another-retnet/

Paper link (Retentive Network: A Successor to Transformer for Large Language Models): https://arxiv.org/pdf/2307.08621.pdf

Check out my other projects on https://www.github.com/darshil3011

## Install dependencies and clone the GitHub repo

In [1]:
!pip install yet-another-retnet[train]



In [2]:
!git clone https://github.com/fkodom/yet-another-retnet.git



In [3]:
!pip install lightning==2.0.1
!pip install tiktoken



## Custom Data Generator for reading data from list

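I replaced the repository's data generator so the model trains on custom text that you put directly into a Python list.

1. Enter your own text in the `MY_TEXT_DATA` list, one document per entry.
2. The entries in `MY_TEXT_DATA` are used for training. `get_split_indices` holds out 10% of the entries for validation and 10% for test, rounding down, so you need at least 10 entries for the validation split to be non-empty. With the three sample entries below, all three go to training.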
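If your text lives in files instead, one option is to build `MY_TEXT_DATA` from a folder of `.txt` files. This is a minimal sketch assuming one UTF-8 document per file; `load_text_files` is a hypothetical helper, not part of yet-another-retnet.

```python
from pathlib import Path
from typing import List


def load_text_files(directory: str, pattern: str = "*.txt") -> List[str]:
    """Hypothetical helper: read every matching file in `directory` as one document."""
    texts: List[str] = []
    for path in sorted(Path(directory).glob(pattern)):
        text = path.read_text(encoding="utf-8").strip()
        if text:  # skip files that are empty or whitespace-only
            texts.append(text)
    return texts


# Usage (the folder name is a placeholder):
# MY_TEXT_DATA = load_text_files("my_corpus")
```

Each file becomes one entry, so the 10-entry minimum for a non-empty validation split applies to the number of files.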

In [4]:
cd yet-another-retnet

/content/yet-another-retnet


In [5]:
import random
from typing import Generator, List, Literal, Optional

from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


# Your own text data
MY_TEXT_DATA = [
    "AI’s influence on technology is due in part because of how it impacts computing. Through AI, computers have the ability to harness massive amounts of data and use their learned intelligence to make optimal decisions and discoveries in fractions of the time that it would take humans. AI has come a long way since 1951, when the first documented success of an AI computer program was written by Christopher Strachey, whose checkers program completed a whole game on the Ferranti Mark I computer at the University of Manchester. Since then, AI has been used to help sequence RNA for vaccines and model human speech, technologies that rely on model- and algorithm-based machine learning and increasingly focus on perception, reasoning and generalization. With innovations like these, AI has re-taken center stage like never before — and it won’t cede the spotlight anytime soon. ",
    "There’s virtually no major industry that modern AI — more specifically, “narrow AI,” which performs objective functions using data-trained models and often falls into the categories of deep learning or machine learning — hasn’t already affected. That’s especially true in the past few years, as data collection and analysis has ramped up considerably thanks to robust IoT connectivity, the proliferation of connected devices and ever-speedier computer processing.",
    "With companies spending billions of dollars on AI products and services annually, tech giants like Google, Apple, Microsoft and Amazon spending billions to create those products and services, universities making AI a more prominent part of their curricula and the U.S. Department of Defense upping its AI game, big things are bound to happen. ",
]

def get_split_indices(
    num_samples: int,
    split: Literal["train", "val", "test"],
    seed: int = 42,
    val_split: float = 0.1,
    test_split: float = 0.1,
) -> List[int]:
    indices = list(range(num_samples))
    random.seed(seed)
    random.shuffle(indices)

    num_val = int(num_samples * val_split)
    num_test = int(num_samples * test_split)

    if split == "train":
        out = indices[num_val + num_test :]
    elif split == "val":
        out = indices[:num_val]
    elif split == "test":
        out = indices[num_val : num_val + num_test]
    else:
        raise ValueError(f"Invalid split: {split}")

    return sorted(out)


class TextChunker(IterDataPipe[str]):
    def __init__(
        self,
        dp: IterDataPipe[str],
        chunk_size: int = 4096,
        step_size: Optional[int] = None,
        drop_last: bool = False,
    ):
        self.dp = dp
        self.chunk_size = chunk_size
        self.step_size = step_size or chunk_size
        self.drop_last = drop_last

    def __iter__(self) -> Generator[str, None, None]:
        for text in self.dp:
            for i in range(0, len(text), self.step_size):
                chunk = text[i : i + self.chunk_size]
                if self.drop_last and len(chunk) < self.chunk_size:
                    continue

                chunk = chunk.split(" ", maxsplit=1)[-1]  # leading partial words
                chunk = chunk.rsplit(" ", maxsplit=1)[0]  # trailing partial words
                chunk = chunk.strip()  # leading/trailing whitespace
                yield chunk


def my_own_text_datapipe(
    split: Literal["train", "val", "test"],
    chunk_size: int = 4096,
    step_size: Optional[int] = None,
    shuffle: bool = False,
    shuffle_buffer_size: int = 8192,
    drop_last: bool = True,
) -> IterDataPipe[str]:
    # Use your own text data instead of URLs
    # You can generate your own text or load it from files, databases, etc.
    indices = get_split_indices(len(MY_TEXT_DATA), split=split)
    if shuffle:
        random.shuffle(indices)

    # Iterable datapipe of your own text data
    pipe: IterDataPipe = IterableWrapper([MY_TEXT_DATA[i] for i in indices])
    pipe = TextChunker(
        pipe, chunk_size=chunk_size, step_size=step_size, drop_last=drop_last
    )
    if shuffle:
        pipe = pipe.shuffle(buffer_size=shuffle_buffer_size)

    return pipe
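For a quick look at what `TextChunker` yields without building a datapipe, the plain-Python function below mirrors its loop on a single string (`chunk_text` is a hypothetical name used only for illustration). Two consequences of the trimming show up in the example: the first word of every chunk is dropped even when the chunk already starts on a word boundary, and `chunk_size`/`step_size` count characters, not tokens.

```python
from typing import Iterator


def chunk_text(
    text: str, chunk_size: int, step_size: int, drop_last: bool = True
) -> Iterator[str]:
    """Plain-Python mirror of TextChunker.__iter__ for one string."""
    for i in range(0, len(text), step_size):
        chunk = text[i : i + chunk_size]
        if drop_last and len(chunk) < chunk_size:
            continue  # the short tail of the text is discarded
        chunk = chunk.split(" ", maxsplit=1)[-1]  # drop everything up to the first space
        chunk = chunk.rsplit(" ", maxsplit=1)[0]  # drop everything after the last space
        yield chunk.strip()


sentence = "the quick brown fox jumps over the lazy dog"
print(list(chunk_text(sentence, chunk_size=16, step_size=8)))
# ['quick brown', 'brown fox', 'jumps over', 'over the lazy']
```

Overlapping windows (`step_size < chunk_size`) repeat words across chunks, as with `brown` and `over` here.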

## Import dependencies


In [7]:
import os
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import tiktoken
import torch
from lightning import Fabric, seed_everything
from lightning.fabric.loggers import TensorBoardLogger
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

from yet_another_retnet.retnet import RetNet

## Define tokenizer
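We use the open-source GPT-2 byte-pair encoding from `tiktoken`. `tiktoken` also ships other encodings, such as `r50k_base` (used by the original GPT-3 models) and `cl100k_base`; they are downloaded and run locally, so no OpenAI API key or account balance is required. The model's vocabulary size is taken from `TOKENIZER.n_vocab`, so switching encodings also changes the model's vocabulary size.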

In [16]:
torch.set_float32_matmul_precision("medium")
TOKENIZER = tiktoken.get_encoding("gpt2")

In [17]:
def collate_fn(
    batch: List[str],
    max_length: int = 1024,
    device: Optional[Union[torch.device, str]] = None,
) -> Tuple[Tensor, Tensor]:
    x = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)
    y = torch.zeros(len(batch), max_length, device=device, dtype=torch.long)

    for i, text in enumerate(batch):
        encoding = torch.as_tensor(
            TOKENIZER.encode(text), device=device, dtype=torch.long
        )
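        # x gets tokens[0:n] and y gets tokens[1:n+1] (next-token targets);
        # remaining positions stay zero-padded. Clamp at 0 for empty chunks.
        seq_length = max(min(len(encoding) - 1, max_length), 0)
        x[i, :seq_length] = encoding[:seq_length]
        y[i, :seq_length] = encoding[1 : seq_length + 1]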

    return x, y
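`collate_fn` builds next-token-prediction pairs: `y` is `x` shifted left by one token, and both rows are zero-padded to `max_length`. The plain-Python sketch below reproduces that indexing on a list of token IDs so it can be checked without a tokenizer. `make_shifted_pair` is a hypothetical name, and the clamp to 0 is an added guard for empty inputs.

```python
from typing import List, Tuple


def make_shifted_pair(
    token_ids: List[int], max_length: int
) -> Tuple[List[int], List[int]]:
    """Build (inputs, targets) the way collate_fn does, for one sequence."""
    x = [0] * max_length  # zero-padded inputs
    y = [0] * max_length  # zero-padded next-token targets
    # Clamp at 0 so an empty sequence leaves both rows fully padded (added guard).
    seq_length = max(min(len(token_ids) - 1, max_length), 0)
    x[:seq_length] = token_ids[:seq_length]
    y[:seq_length] = token_ids[1 : seq_length + 1]
    return x, y


print(make_shifted_pair([5, 6, 7, 8], max_length=6))
# ([5, 6, 7, 0, 0, 0], [6, 7, 8, 0, 0, 0])
```

Note that 0 is also a real GPT-2 token ID (`!`), so padded positions look like genuine tokens unless the loss masks them.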

In [18]:
@dataclass
class TrainingState:
    fabric: Fabric
    model: RetNet
    optimizer: torch.optim.Optimizer
    callbacks: Sequence[Callable[["TrainingState", float], None]] = ()

    current_step: int = 0
    current_epoch: int = 0
    accumulate_grad_batches: int = 1
    monitor: str = "val_loss"
    monitor_mode: Literal["min", "max"] = "min"

In [19]:
@dataclass
class ModelCheckpoint:
    state_dict: Dict[str, Tensor]
    optimizer_state: Dict[str, Tensor]
    current_step: int
    current_epoch: int

    @classmethod
    def from_training_state(cls, state: TrainingState) -> "ModelCheckpoint":
        return cls(
            state_dict=state.model.state_dict(),
            optimizer_state=state.optimizer.state_dict(),
            current_step=state.current_step,
            current_epoch=state.current_epoch,
        )

    def to_dict(self) -> Dict[str, Any]:
        return {
            "state_dict": self.state_dict,
            "optimizer_state": self.optimizer_state,
            "current_step": self.current_step,
            "current_epoch": self.current_epoch,
        }

    def save(self, path: str) -> None:
        torch.save(self.to_dict(), path)

    @classmethod
    def load(cls, path: str) -> "ModelCheckpoint":
        # Load onto the CPU so GPU-trained checkpoints also open on CPU-only machines.
        checkpoint_dict = torch.load(path, map_location="cpu")
        return cls(**checkpoint_dict)

In [20]:
class CheckpointCallback:
    def __init__(
        self, save_dir: str, name: str = "checkpoint_epoch-{epoch:03d}.pt"
    ) -> None:
        self.save_dir = save_dir
        self.name = name
        self.best_path: Optional[str] = None
        self.best_loss: Optional[float] = None

    def __call__(self, state: TrainingState, loss: float) -> None:
        if self.best_loss is None:
            self.best_loss = loss

        fabric = state.fabric
        # 'local_rank == 0' means this only happens for the main process
        if fabric.local_rank == 0 and loss <= self.best_loss:
            checkpoint = ModelCheckpoint.from_training_state(state)
            self.best_loss = loss
            if self.best_path is not None:
                os.remove(self.best_path)
            self.best_path = os.path.join(
                self.save_dir, self.name.format(epoch=state.current_epoch)
            )
            # Save a plain dict (not the dataclass) so ModelCheckpoint.load can rebuild it.
            checkpoint.save(self.best_path)

        # All processes wait for main to finish saving the checkpoint.
        fabric.barrier()
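`CheckpointCallback` keeps only the best checkpoint: whenever the validation loss ties or beats the best seen so far, it writes a new file and deletes the previous one. The hypothetical function below replays that rule on a list of per-epoch losses without touching the filesystem. Epochs are numbered from 1 because `train_one_epoch` increments `current_epoch` before training.

```python
from typing import List, Optional, Tuple


def simulate_keep_best(val_losses: List[float]) -> Tuple[List[int], Optional[int]]:
    """Return (epochs that write a checkpoint, epoch whose file is left on disk)."""
    best_loss: Optional[float] = None
    best_epoch: Optional[int] = None  # stands in for CheckpointCallback.best_path
    written: List[int] = []
    for epoch, loss in enumerate(val_losses, start=1):
        if best_loss is None:
            best_loss = loss
        if loss <= best_loss:  # ties also replace the saved checkpoint
            best_loss = loss
            best_epoch = epoch  # the previous best file would be deleted here
            written.append(epoch)
    return written, best_epoch


print(simulate_keep_best([2.0, 1.5, 1.7, 1.2]))  # ([1, 2, 4], 4)
```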

In [21]:
def train_one_epoch(
    state: TrainingState,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    log_frequency: int = 25,
) -> None:
    state.current_epoch += 1
    fabric, model, optimizer = state.fabric, state.model, state.optimizer
    is_main_process = fabric.local_rank == 0
    is_training = model.training
    model.train()

    with tqdm(
        desc=f"Ep: {state.current_epoch}", disable=(not is_main_process)
    ) as progbar:
        train_loss, val_loss = 0.0, 0.0
        for x, y in train_dataloader:
            state.current_step += 1
            accumulating = state.current_step % state.accumulate_grad_batches != 0
            with fabric.no_backward_sync(model, enabled=accumulating):  # type: ignore
                loss = model.forward(inputs=x, labels=y)
                fabric.backward(loss)

            if not accumulating:
                optimizer.step()
                optimizer.zero_grad()

            if state.current_step % log_frequency == 0:
                fabric.log("loss", loss, step=state.current_step)
                train_loss = loss.item()
                progbar.set_postfix_str(f"loss={train_loss:.4f}", refresh=False)
            progbar.update(1)

        model.eval()
        val_progbar = tqdm(desc="val", position=1, leave=False)
        for i, (x, y) in enumerate(val_dataloader):
            with torch.inference_mode():
                loss = model.forward(inputs=x, labels=y)
            val_loss = (val_loss * i + loss.item()) / (i + 1)

            if i % log_frequency == 0:
                val_progbar.set_postfix_str(f"val_loss={val_loss:.4f}", refresh=False)
            val_progbar.update(1)
            progbar.update(1)

        fabric.log("val_loss", val_loss, step=state.current_step)
        val_progbar.close()
        progbar.set_postfix_str(
            f"loss={train_loss:.4f}, val_loss={val_loss:.4f}", refresh=False
        )

        for callback in state.callbacks:
            callback(state, val_loss)

        # Return model to its original training state
        model.train(mode=is_training)
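Two details of `train_one_epoch` are easy to miss. The optimizer steps (and, in distributed runs, gradients are synchronized) only when the global `current_step` is a multiple of `accumulate_grad_batches`, and because `current_step` lives in `TrainingState`, the count carries over between epochs. The validation loss is an incremental mean over batches. The hypothetical helpers below reproduce both rules in plain Python:

```python
from typing import Iterable, List


def optimizer_step_indices(
    num_batches: int, accumulate_grad_batches: int, start_step: int = 0
) -> List[int]:
    """Global steps at which optimizer.step() runs under train_one_epoch's rule."""
    steps: List[int] = []
    for step in range(start_step + 1, start_step + num_batches + 1):
        accumulating = step % accumulate_grad_batches != 0
        if not accumulating:
            steps.append(step)
    return steps


def running_mean(values: Iterable[float]) -> float:
    """Incremental mean used for val_loss: m_i = (m_{i-1} * i + x_i) / (i + 1)."""
    mean = 0.0
    for i, value in enumerate(values):
        mean = (mean * i + value) / (i + 1)
    return mean


print(optimizer_step_indices(6, accumulate_grad_batches=2))  # [2, 4, 6]
print(running_mean([1.0, 2.0, 3.0, 4.0]))  # 2.5
```

With an empty validation loader, the mean stays at its initial `0.0`, which is what the checkpoint callback then receives.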



In [22]:
def train(
    retnet: RetNet,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    accelerator: str = "auto",
    strategy: str = "auto",
    precision: Optional[str] = None,
    epochs: int = 1,
    lr: float = 3e-2,
    log_frequency: int = 25,
):
    if precision is None:
        if torch.cuda.is_available():
            # use bfloat16 if supported
            version, _ = torch.cuda.get_device_capability()
            precision = "bf16-mixed" if version >= 8 else "16-mixed"
        else:
            precision = "32-true"

    logger = TensorBoardLogger(root_dir="./")
    fabric = Fabric(
        accelerator=accelerator,
        strategy=strategy,
        precision=precision,  # type: ignore
        loggers=[logger],
    )
    fabric.launch()
    print(f"Experiment version: {logger.version}")
    print("-" * 40)

    # Setup with fabric.
    optimizer = torch.optim.AdamW(retnet.parameters(), lr=lr)
    retnet, optimizer = fabric.setup(retnet, optimizer)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(
        train_dataloader, val_dataloader
    )
    # Construct a training state and run the training loop.
    state = TrainingState(
        fabric=fabric,
        model=retnet,
        optimizer=optimizer,
        callbacks=[CheckpointCallback(save_dir=logger.log_dir)],
    )
    for _ in range(epochs):
        train_one_epoch(
            state=state,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            log_frequency=log_frequency,
        )
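When no precision is given, `train` picks bf16 mixed precision on GPUs whose compute-capability major version is 8 or higher (Ampere and newer), fp16 mixed precision on older GPUs, and full fp32 on CPU. A Colab T4 (compute capability 7.5), for example, falls back to `16-mixed`. The decision logic, pulled out into a hypothetical pure function so it can be checked without a GPU:

```python
def pick_precision(cuda_available: bool, capability_major: int = 0) -> str:
    """Mirror of the precision fallback in train() (hypothetical helper)."""
    if not cuda_available:
        return "32-true"
    # bf16 needs compute capability >= 8.0; older GPUs use fp16 autocast.
    return "bf16-mixed" if capability_major >= 8 else "16-mixed"


print(pick_precision(True, 7))  # 16-mixed
```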


In [23]:
def generate(
    retnet: RetNet,
    prompt: str,
    prompt_chunk_size: Optional[int] = None,
    max_new_tokens: int = 1024,
    stop_tokens: Sequence[str] = (),
    top_k: int = 5,
    temperature: float = 1.0,
    seed: int = 42,
) -> Iterator[str]:
    seed_everything(seed)
    device = next(iter(retnet.parameters())).device
    is_training = retnet.training
    retnet.eval()

    # Tokenize the prompt and convert to a tensor.
    tokenized = TOKENIZER.encode(prompt)
    x = torch.as_tensor(tokenized, dtype=torch.long, device=device).unsqueeze_(0)

    if not prompt_chunk_size:
        prompt_chunk_size = x.size(1)

    # Encode the prompt chunk by chunk, carrying the recurrent states forward.
    prev_states: List[Optional[Tensor]] = [None] * retnet.num_layers
    for start_idx in range(0, x.size(1), prompt_chunk_size):
        y, prev_states = retnet.forward_chunkwise(  # type: ignore
            x[:, start_idx : start_idx + prompt_chunk_size],
            start_idx=start_idx,
            prev_states=prev_states,
        )
    # Keep only the logits for the last prompt token.
    y = y[:, -1]
    # Absolute position of the next token; the prompt occupies positions 0..len-1.
    seq_idx: int = x.size(1)

    # Generate tokens until we reach the maximum number of tokens or a stop token.
    for i in range(max_new_tokens):
        probs: Tensor = torch.softmax(y.squeeze() / max(temperature, 1e-8), dim=-1)
        # Get top-k tokens, renormalize their probabilities, and weighted sample.
        tokens: Tensor  # for mypy
        probs, tokens = probs.topk(k=top_k, dim=-1)
        probs /= probs.sum()

        # Take weighted random sample from the top-k tokens.
        sampled_idx: int = torch.multinomial(probs, num_samples=1).item()  # type: ignore
        token: int = tokens[sampled_idx].item()  # type: ignore
        tokenized.append(token)
        yield TOKENIZER.decode(tokenized)

        token_str: str = TOKENIZER.decode([token])
        if token_str in stop_tokens:
            break
        elif i < (max_new_tokens - 1):
            # Feed the sampled token back in at its absolute position.
            x = torch.as_tensor([token], dtype=torch.long, device=device)
            y, prev_states = retnet.forward_recurrent(  # type: ignore
                x, seq_idx, prev_states=prev_states
            )
            seq_idx += 1

    # Restore the model's original training state.
    retnet.train(mode=is_training)
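The sampling step in `generate` is a temperature-scaled softmax followed by top-k filtering: keep the `top_k` most likely tokens, renormalize their probabilities, and draw one. A dependency-free sketch of the same procedure on a plain list of logits (`sample_top_k` is a hypothetical name; it uses Python's `random` module in place of `torch.multinomial`):

```python
import math
import random
from typing import List, Optional, Sequence


def sample_top_k(
    logits: Sequence[float],
    top_k: int,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Return a token index sampled from the top-k of softmax(logits / temperature)."""
    scaled = [value / max(temperature, 1e-8) for value in logits]
    # Numerically stable softmax: subtract the max before exponentiating.
    peak = max(scaled)
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the k most likely tokens, then renormalize their probabilities.
    top: List[int] = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    top_total = sum(probs[i] for i in top)
    weights = [probs[i] / top_total for i in top]
    rng = rng or random.Random()
    return rng.choices(top, weights=weights, k=1)[0]


print(sample_top_k([0.1, 3.0, 0.5], top_k=1))  # always 1, the argmax
```

Lower temperatures sharpen the distribution toward the argmax, and `top_k=1` makes decoding greedy.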


## Analyse custom data using data loader

Before training, we run the custom data pipeline and print a single chunk to see how the text is fed to the model.

Uncomment `collate_fn` to see token IDs instead of text.

In [26]:
train_dataloader = DataLoader(
    my_own_text_datapipe(
        split="train",
        chunk_size=64,
        step_size=2,
        shuffle=True,
        drop_last=True,
    ),
    # collate_fn=collate_fn,  # uncomment to get token IDs instead of text
    batch_size=1,
    drop_last=True,
)

In [27]:
for i, text_chunk in enumerate(train_dataloader):
    print(f"Chunk {i + 1}: {text_chunk}")
    break

Chunk 1: ['machine learning and increasingly focus on perception,']


## Main Function to train our model

In [29]:
def main(
    model_checkpoint: Optional[str] = None,
    accelerator: str = "auto",
    strategy: str = "auto",
    precision: Optional[str] = None,
    epochs: int = 1,
    batch_size: int = 4,
    lr: float = 3e-4,
    log_frequency: int = 25,
    seed: int = 42,
    eval_only: bool = False,
    eval_prompt: str = 'AI is a branch of computer science ',
    eval_max_tokens: int = 128,
):
    seed_everything(seed)
    # Create a (relatively small) model and dataloaders
    retnet = RetNet(
        num_tokens=TOKENIZER.n_vocab,
        d_model=768,
        nhead=8,
        num_layers=12,
    )
    if model_checkpoint is not None:
        retnet.load_state_dict(ModelCheckpoint.load(model_checkpoint).state_dict)

    if not eval_only:
        num_devices = torch.cuda.device_count()
        if num_devices > 0:
            # Lightning Fabric does not scale the batch size for distributed training.
            # In order to keep batch size the same, divide by the number of devices.
            if batch_size % num_devices != 0:
                raise ValueError(f"{batch_size=} must be divisible by {num_devices=}.")
            batch_size = batch_size // num_devices

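        # chunk_size and step_size are measured in characters, not tokens. Because
        # step_size (32) > chunk_size (28), the 4 characters between consecutive
        # chunks are never seen; use step_size <= chunk_size for full coverage.
        train_dataloader = DataLoader(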
            my_own_text_datapipe(
                split="train",
                chunk_size=28,
                step_size=32,
                shuffle=True,
                drop_last=True,
            ),
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=True,
        )
        val_dataloader = DataLoader(
            my_own_text_datapipe(
                split="val", chunk_size=28, step_size=32
            ),
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        train(
            retnet=retnet,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            accelerator=accelerator,
            strategy=strategy,
            precision=precision,
            epochs=epochs,
            lr=lr,
            log_frequency=log_frequency,
        )

    # Generate some text
    prev_output: str = ""
    for output in generate(retnet, eval_prompt, max_new_tokens=eval_max_tokens):
        # Return to the start of the line and print the output (no newline)
        print(output[len(prev_output) :], end="", flush=True)
        prev_output = output
    print()
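As the comment in `main` notes, Lightning Fabric does not scale the batch size for distributed training, so `main` divides the requested batch size by the number of CUDA devices to keep the global batch the same. A hypothetical standalone version of that check:

```python
def per_device_batch_size(batch_size: int, num_devices: int) -> int:
    """Split a global batch size evenly across CUDA devices, as main() does."""
    if num_devices == 0:  # CPU only: nothing to split
        return batch_size
    if batch_size % num_devices != 0:
        raise ValueError(f"{batch_size=} must be divisible by {num_devices=}.")
    return batch_size // num_devices


print(per_device_batch_size(4, 2))  # 2
```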


## Run Training / Inference
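To evaluate the model without training, call `main(eval_only=True)`. Unless you also pass `model_checkpoint` (the path to a checkpoint saved by an earlier training run), the weights are randomly initialized and the generated text will be meaningless.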

In [30]:
# Model will continue writing on the below prompt
EVAL_PROMPT = "Artificial intelligence is "

In [None]:
# Adjust the arguments to suit your requirements.

main(eval_prompt=EVAL_PROMPT)
#main(eval_prompt=EVAL_PROMPT, eval_only=True)

INFO: Global seed set to 42
INFO:lightning.fabric.utilities.seed:Global seed set to 42


Experiment version: 0
----------------------------------------


Ep: 1: 1it [01:18, 78.22s/it]