# Load pretrained models

The previous notebook coded transformer architectures.
Let us load pretrained model onto those architecture.
Most open-source model can be loaded through the HuggingFace Transformers library.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from dataclasses import dataclass

from llmtuto.model.transformer import TransformerConfig, CausalTransformer

model = "gpt2"  # Options are "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"

We start with GPT-2.
Let us load the model specifications.

In [3]:
from llmtuto.model.transformer import TransformerConfig


config_args = {
    'gpt2':   dict(emb_dim=768,  n_head=12, n_layer=12),  # 124M params
    'gpt2-medium':  dict(emb_dim=1024, n_head=16, n_layer=24),  # 350M params
    'gpt2-large':   dict(emb_dim=1280, n_head=20, n_layer=36),  # 774M params
    'gpt2-xl':      dict(emb_dim=1600, n_head=25, n_layer=48),  # 1558M params
}[model]
config_args = config_args | dict(
    vocab_size=50_257,
    pos_emb=True,
    seq_len=1024, 
    attn_bias=True,
    ffn_bias=True,
    norm_bias=True,
    activation="gelu",
    norm="layer",
    pre_norm=True,
    weight_tying=True,
)
config = TransformerConfig(**config_args)
gpt2 = CausalTransformer(config)

The weight can be downloaded from HuggingFace Transformers library.

In [4]:
from transformers import GPT2LMHeadModel

model_hf = GPT2LMHeadModel.from_pretrained(model)
print(model_hf)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


Some work is needed to cast the model into our own class.

In [5]:
gpt_state_dict = model_hf.state_dict()
local_state_dict = gpt2.state_dict()

correspondence = {
    "embeddings.token_emb.weight": "transformer.wte.weight",
    "embeddings.pos_emb.weight": "transformer.wpe.weight",
    "output.weight": "lm_head.weight",
    "output_norm.weight":   "transformer.ln_f.weight",
    "output_norm.bias":     "transformer.ln_f.bias",
}
transposed = []
special = {}
for layer in range(config.n_layer):
    correspondence = correspondence | {
        f"blocks.{layer}.norm_1.weight":       f"transformer.h.{layer}.ln_1.weight",
        f"blocks.{layer}.norm_1.bias":         f"transformer.h.{layer}.ln_1.bias",
        f"blocks.{layer}.attn.output.weight":  f"transformer.h.{layer}.attn.c_proj.weight",
        f"blocks.{layer}.attn.output.bias":    f"transformer.h.{layer}.attn.c_proj.bias",
        f"blocks.{layer}.norm_2.weight":       f"transformer.h.{layer}.ln_2.weight",
        f"blocks.{layer}.norm_2.bias":         f"transformer.h.{layer}.ln_2.bias",
        f"blocks.{layer}.ffn.fc1.weight":      f"transformer.h.{layer}.mlp.c_fc.weight",
        f"blocks.{layer}.ffn.fc1.bias":        f"transformer.h.{layer}.mlp.c_fc.bias",
        f"blocks.{layer}.ffn.fc2.weight":      f"transformer.h.{layer}.mlp.c_proj.weight",
        f"blocks.{layer}.ffn.fc2.bias":        f"transformer.h.{layer}.mlp.c_proj.bias",
    }
    transposed = transposed + [
        f"transformer.h.{layer}.attn.c_attn.weight",
        f"transformer.h.{layer}.attn.c_proj.weight",
        f"transformer.h.{layer}.mlp.c_fc.weight",
        f"transformer.h.{layer}.mlp.c_proj.weight",
    ]
    special = special | {
        f"blocks.{layer}.attn.query.weight": f"transformer.h.{layer}.attn.c_attn.weight",
        f"blocks.{layer}.attn.query.bias":   f"transformer.h.{layer}.attn.c_attn.bias",
        f"blocks.{layer}.attn.key.weight":   f"transformer.h.{layer}.attn.c_attn.weight",
        f"blocks.{layer}.attn.key.bias":     f"transformer.h.{layer}.attn.c_attn.bias",
        f"blocks.{layer}.attn.value.weight": f"transformer.h.{layer}.attn.c_attn.weight",
        f"blocks.{layer}.attn.value.bias":   f"transformer.h.{layer}.attn.c_attn.bias",
    }
for k in correspondence:
    if correspondence[k] in transposed:
        local_state_dict[k] = gpt_state_dict[correspondence[k]].T
    else:
        local_state_dict[k] = gpt_state_dict[correspondence[k]]
for k in special:
    if 'query' in k:
        if 'bias' in k:
            local_state_dict[k] = gpt_state_dict[special[k]][:config.emb_dim]
        else:
            local_state_dict[k] = gpt_state_dict[special[k]][:, :config.emb_dim].T
    elif 'key' in k:
        if 'bias' in k:
            local_state_dict[k] = gpt_state_dict[special[k]][config.emb_dim:2*config.emb_dim]
        else:
            local_state_dict[k] = gpt_state_dict[special[k]][:, config.emb_dim:2*config.emb_dim].T
    elif 'value' in k:
        if 'bias' in k:
            local_state_dict[k] = gpt_state_dict[special[k]][2*config.emb_dim:]
        else:
            local_state_dict[k] = gpt_state_dict[special[k]][:, 2*config.emb_dim:].T
gpt2.load_state_dict(local_state_dict)

<All keys matched successfully>

Let us wrap this code in a function.

In [6]:
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from llmtuto.config import DATA_DIR
from llmtuto.model.pretrained_transformer import GPT2

logging.basicConfig(level=logging.INFO)

gpt2 = GPT2(model='small', save_dir=DATA_DIR / 'pretrained')

for model in ['small', 'medium', 'large', 'xl']:
    gpt2 = GPT2(model=model, save_dir=DATA_DIR / 'pretrained')

INFO:llmtuto.model.pretrained_transformer:Loading GPT2-small tokenizer from Tiktoken
INFO:llmtuto.model.pretrained_transformer:Loading GPT2 model from /private/home/vivc/code/arc-agi/libraries/LLMTutorial/data/pretrained/gpt2-small.pt
  self.model.load_state_dict(torch.load(save_path))
INFO:llmtuto.model.pretrained_transformer:Loading GPT2-small tokenizer from Tiktoken
INFO:llmtuto.model.pretrained_transformer:Loading GPT2 model from /private/home/vivc/code/arc-agi/libraries/LLMTutorial/data/pretrained/gpt2-small.pt
INFO:llmtuto.model.pretrained_transformer:Loading GPT2-medium tokenizer from Tiktoken
INFO:llmtuto.model.pretrained_transformer:Loading GPT2 model from /private/home/vivc/code/arc-agi/libraries/LLMTutorial/data/pretrained/gpt2-medium.pt
INFO:llmtuto.model.pretrained_transformer:Loading GPT2-large tokenizer from Tiktoken
INFO:llmtuto.model.pretrained_transformer:Loading GPT2 model from /private/home/vivc/code/arc-agi/libraries/LLMTutorial/data/pretrained/gpt2-large.pt
INFO:l

## Sentence generation

GPT models are useful to generate sentences, eventually conditioned on a prompt.
To turn the numbers into words, we need to load the tokenizer.
The tokenizer compatible with GPT-2 is available from OpenAI tiktoken library.

In [7]:
# model to generate from
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
transformer_model = gpt2
transformer_model.to(device)
transformer_model.eval();

In [8]:
import tiktoken

prompts = [
    "Machine Learning is ",
    "I am a big fan of large language models.",
    "Women are specially good at ",
]

# tokenize sentence
tokenizer = tiktoken.get_encoding('gpt2')

if isinstance(prompts, str):
    prompts = [prompts]
tokens = [tokenizer.encode(seq) for seq in prompts]

seq_len = 100
random = False
temperature = 1.

if random:
    def choice_token_from_logit(logits):
        logits /= temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)
else:
    def choice_token_from_logit(logits):
        return torch.argmax(logits, dim=-1, keepdim=True)

# handling different sentences lengths
lenghts = [len(seq) for seq in tokens]
nb_sentences, max_len, min_len = len(prompts), max(lenghts), min(lenghts)

seq_idx = torch.zeros((nb_sentences, max_len), dtype=torch.long, device=device)
mask = torch.zeros((nb_sentences, max_len - min_len), dtype=torch.bool, device=device)
for i, seq in enumerate(tokens):
    seq_idx[i, :len(seq)] = torch.tensor(seq, dtype=torch.long, device=device)
    mask[i, :len(seq) - min_len] = 1

for i in range(min_len, max_len):
    logits = transformer_model(seq_idx[:, :i])[:, -1, :]
    next_token = choice_token_from_logit(logits).squeeze()
    torch.where(mask[:, i - min_len], seq_idx[:, i], next_token, out=seq_idx[:, i])

# generation of new tokens
with torch.no_grad():
    for i in range(seq_len):
        logits = transformer_model(seq_idx)[:, -1, :]
        next_token = choice_token_from_logit(logits)
        seq_idx = torch.cat([seq_idx, next_token], dim=-1)

sentences = [tokenizer.decode(list(seq)) for seq in seq_idx]
[print(sen, end='<EOS>\n') for sen in sentences];

Machine Learning is  a branch of computer science that deals with the problem of learning from data.  It is a branch of computer science that deals with the problem of learning from data.  It is a branch of computer science that deals with the problem of learning from data.  It is a branch of computer science that deals with the problem of learning from data.  It is a branch of computer science that deals with the problem of learning from data.  It is a branch of computer science that deals with the problem<EOS>
I am a big fan of large language models. I think they are a great way to model complex systems. I also think that they are a great way to model the human mind.

I have been working on a large language model for a while now. It is a model of the human mind. It is a model of the human brain. It is a model of the human mind-computer interface. It is a model of the human mind-brain interface. It is a model of the human mind-brain interface. It is a<EOS>
Women are specially good at 

Note that the inner loop refills all tokens in the sentence in every loop, which does not utilize efficiently the regressive nature of causal transformer.
We will see how to do better in a following notebook.

Let us wrap this in a function.

In [9]:
from llmtuto.sample import language_generator

prompts = [
    "Hey there, tell me everything about large language models.",
    "I am a big fan of large language models.",
    "I am a big fan of large language models.",
]

transformer_model = transformer_model.to(device).eval()

sentences = language_generator(prompts, tokenizer, transformer_model, seq_len=100, random=True, temperature=1., device=device)

for sentence in sentences:
    print(sentence)
    print()

Hey there, tell me everything about large language models.

(v0.1.0, last commit: May 05, 2012)

Describes the subclass names, their string representations in the Int32 and their Scala.Vase representation. (Thanks Aron Bréchon for contributions.)

(Note: We'll be gradually phasing out those older versions, but that's down the road.)

Represents class array indices and subarrays, not the array objects themselves.

Specifies a special band

I am a big fan of large language models. They give you a machine-legality for the complex semantics that comes to mind during a human introduction. You can then write software that represents that, and departures from that are a lot easier than when your semantic is so fancy. I'm thinking about one that would have "UB" as an derivation of V and "GAU" as a transcription of GC. In software people like working with this kind of space, especially with UB as a future formal logic, it allows complex semantics to

I am a big fan of large language models. The

In [10]:
del transformer_model, sentence, logits, next_token, seq_idx, gpt2, tokens

## Loading Mistral 7B

Let us redo everything from scratch with the Mistral 7B model.
We will load the weight directly from Mistal website.
Note that we could equally retake their code, which is available from their GitHub.

In [15]:
from pathlib import Path

import torch

from llmtuto.model.transformer import TransformerConfig, CausalTransformer
from llmtuto.config import DATA_DIR

# run on bash or with subprocess
# cd DATA_DIR / 'pretrained'
# wget https://files.mistral-7b-v0-1.mistral.ai/mistral-7B-v0.1.tar
# tar -xf mistral-7B-v0.1.tar

model_path = DATA_DIR / 'mistral-7B-v0.1'
model_path = Path('/private/home/vivc/models/mistral-7B-v0.1')

In [12]:
config_args = dict(
    vocab_size= 32_000,
    emb_dim=4096,
    pos_emb=False,
    seq_len=4096,
    n_head=32,
    attn_bias=False,
    rope=True,
    rope_theta=10_000,
    activation="swiglu",
    ffn_dim=14336,
    ffn_bias=False,
    norm="rms",
    norm_bias=False,
    pre_norm=True,
    n_layer=32,
    attn_downsampling=4,
    norm_eps = 1e-3
)

config = TransformerConfig(**config_args)
model = CausalTransformer(config)

In [16]:
correspondence = {
    "embeddings.token_emb.weight": "tok_embeddings.weight",
    "output_norm.weight":          "norm.weight",
    "output.weight":               "output.weight",
}
special = []
special_keys = ["attention.wq.weight", "attention.wk.weight", "attention.wv.weight"]
for layer in range(config.n_layer):
    correspondence = correspondence | {
        f"blocks.{layer}.norm_1.weight":         f"layers.{layer}.attention_norm.weight",
        f"blocks.{layer}.attn.query.weight":     f"layers.{layer}.attention.wq.weight",
        f"blocks.{layer}.attn.key.weight":       f"layers.{layer}.attention.wk.weight",
        f"blocks.{layer}.attn.value.weight":     f"layers.{layer}.attention.wv.weight",
        f"blocks.{layer}.attn.output.weight":    f"layers.{layer}.attention.wo.weight",
        f"blocks.{layer}.norm_2.weight":         f"layers.{layer}.ffn_norm.weight",
        f"blocks.{layer}.ffn.fc1.weight":        f"layers.{layer}.feed_forward.w1.weight",
        f"blocks.{layer}.ffn.fc2.weight":        f"layers.{layer}.feed_forward.w2.weight",
        f"blocks.{layer}.ffn.swiglu_mat.weight": f"layers.{layer}.feed_forward.w3.weight",
    }

local_state_dict = model.state_dict()
mistral_state_dict = torch.load(model_path / 'consolidated.00.pth')
for k in local_state_dict:
    if k in correspondence:
        local_state_dict[k] = mistral_state_dict[correspondence[k]]

model.load_state_dict(local_state_dict)

  mistral_state_dict = torch.load(model_path / 'consolidated.00.pth')


<All keys matched successfully>

Let us load the model tokenizer, which is based on `sentencepiece`.

In [17]:
from sentencepiece import SentencePieceProcessor

tokenizer = SentencePieceProcessor(model_file=str(model_path / 'tokenizer.model'))

In [18]:
device = 'cuda:0'

seq_idx = torch.tensor(tokenizer.encode('Hey there, tell me everything about large language models.'), device=device).unsqueeze(0)
model = model.to(device, torch.float16)

for i in range(50):
    logits = model(seq_idx)[:, -1, :]
    next_token = torch.argmax(logits, dim=-1, keepdim=True)
    seq_idx = torch.cat([seq_idx, next_token], dim=-1)

print(tokenizer.decode([a.item() for a in seq_idx[0]]), flush=True)

  return t.to(


Hey there, tell me everything about large language models.

## What are large language models?

Large language models are a type of machine learning model that is trained on a large amount of text data. These models are able to generate new text that is similar in style and content to the text


In [19]:
del model, seq_idx, next_token, logits
torch.cuda.empty_cache()

## Multi-GPU fast inference

Even in half precision, Mistral is quite big and requires too much memory for long sentence generation on my single 16GB Quadro GPU.

Yet, I have a second GPU, which is helpful to share the model weights on two GPUs, and be able to generate longer sentences.

#### Naive solution: Cut the model in half at the middle

Let us load a model, and cast the first layers on GPU 0, and the last layers on GPU 1.

In [None]:
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from llmtuto.model.pretrained_transformer import GPT2, Mistral 
from llmtuto.config import DATA_DIR

logging.basicConfig(level=logging.INFO)

# module = GPT2(model='small', save_dir=DATA_DIR / 'pretrained')
module = Mistral(save_dir=DATA_DIR)
model = module.model
tokenizer = module.tokenizer
print(model)
vocab_size = model.embeddings.token_emb.weight.size(0)

INFO:mathllm.model.pretrained_transformer:Loading Mistral model from mistral checkpoint
INFO:mathllm.model.pretrained_transformer:Loading Mistral tokenizer from mistral checkpoint


CausalTransformer(
  (embeddings): Embedding(
    (token_emb): Embedding(32000, 4096)
  )
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (norm_1): RMSNorm()
      (attn): SelfAttention(
        (query): Linear(in_features=4096, out_features=4096, bias=False)
        (key): Linear(in_features=4096, out_features=1024, bias=False)
        (value): Linear(in_features=4096, out_features=1024, bias=False)
        (output): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (norm_2): RMSNorm()
      (ffn): FeedForward(
        (fc1): Linear(in_features=4096, out_features=14336, bias=False)
        (fc2): Linear(in_features=14336, out_features=4096, bias=False)
        (swiglu_mat): Linear(in_features=4096, out_features=14336, bias=False)
      )
    )
  )
  (output_norm): RMSNorm()
  (output): Linear(in_features=4096, out_features=32000, bias=False)
)


In [None]:
class Block1(nn.Module):
    """
    Cast first half of a model on a device
    
    Parameters
    ----------
    model : nn.Module
        neural network object
    device : torch.device
        device for computation

    See Also
    --------
    CausalTranformer
    """
    def __init__(self, model):
        super().__init__()

        # embeddings
        self.embeddings = model.embeddings

        # first blocks
        n_layer = len(model.blocks)
        self.blocks = nn.ModuleList([
            model.blocks[i] for i in range(n_layer // 2)
        ])
        self.dropout = model.dropout

    def forward(self, x):
        out = x
        out = self.embeddings(out)
        for block in self.blocks:
            out = block(out)
        return out

class Block2(nn.Module):
    """
    Cast second half of a model on a device
    
    Parameters
    ----------
    model : nn.Module
        neural network object
    device : torch.device
        device for computation

    See Also
    --------
    CausalTranformer
    """
    def __init__(self, model):
        super().__init__()

        # last blocks
        n_layer = len(model.blocks)
        self.blocks = nn.ModuleList([
            model.blocks[i] for i in range(n_layer // 2, n_layer)
        ])

        # output layer
        self.output_norm = model.output_norm
        self.output = model.output

        self.dropout = model.dropout

    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        out = self.output_norm(out)
        out = F.dropout(x, p=self.dropout, training=self.training)
        out = self.output(out)
        return out

In [None]:
device1 = "cuda:0"
device2 = "cuda:1"
dtype = torch.float16
model_1 = Block1(model).to(device=device1, dtype=dtype)
model_2 = Block2(model).to(device=device2, dtype=dtype)
print(model_1)
print(model_2)

  return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)


Block1(
  (embeddings): Embedding(
    (token_emb): Embedding(32000, 4096)
  )
  (blocks): ModuleList(
    (0-15): 16 x TransformerBlock(
      (norm_1): RMSNorm()
      (attn): SelfAttention(
        (query): Linear(in_features=4096, out_features=4096, bias=False)
        (key): Linear(in_features=4096, out_features=1024, bias=False)
        (value): Linear(in_features=4096, out_features=1024, bias=False)
        (output): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (norm_2): RMSNorm()
      (ffn): FeedForward(
        (fc1): Linear(in_features=4096, out_features=14336, bias=False)
        (fc2): Linear(in_features=14336, out_features=4096, bias=False)
        (swiglu_mat): Linear(in_features=4096, out_features=14336, bias=False)
      )
    )
  )
)
Block2(
  (blocks): ModuleList(
    (0-15): 16 x TransformerBlock(
      (norm_1): RMSNorm()
      (attn): SelfAttention(
        (query): Linear(in_features=4096, out_features=4096, bias=False)
        (key): Lin

In [None]:
x = torch.randint(0, vocab_size, (2, 10)).to(device=device1)
with torch.no_grad():
    h = model_1(x)
    y = model_2(h.to(device2))

In [None]:
del model_1, model_2, x, h, y
torch.cuda.empty_cache()

#### Decreasing GPU idleness with model weight copy

In the previous setup, the second GPU is idle while the first one is doing the forward pass on the first part of the network.
However, we could copy the first layer of the model (now saved on the first GPU) to this second GPU, perform of the first layer forward pass on some other prompt in parallel. Then discard the first layer, and copy the second layer to the second GPU, and so on.

This quite generic idea is the backbone idea of [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed) original idea called ZeRO, for [zero redundancy optimizer](https://arxiv.org/abs/1910.02054).

Let us implement this for inference with the `asyncio` library Python.
First, let us write a basic example to get use to the library.

In [None]:
import asyncio

async def f():
    # Simulate some asynchronous task
    await asyncio.sleep(2)
    print("Function f is done")
    return "That was function f"

async def g():
    # Simulate another asynchronous task
    await asyncio.sleep(1)
    print("Function g is done")
    return "That was function g"

async def my_print(string, t):
    await asyncio.sleep(t)
    print(string)

# Create tasks for functions f and g
task_f = asyncio.create_task(f())
task_g = asyncio.create_task(g())

# Wait for both tasks to complete
result_f = await task_f
result_g = await task_g

task_f = asyncio.create_task(my_print(result_f, 1))
task_g = asyncio.create_task(my_print(result_g, .5))

await task_f
await task_g

Function g is done
Function f is done
That was function g
That was function f


For simplicity, let us distribute weights to GPU from CPU (rather than from GPU, which might be beneficial in terms of communication speed in practice).

In [None]:
model_1_cpu = Block1(model).to(device='cpu', dtype=dtype)
model_2_cpu = Block2(model).to(device='cpu', dtype=dtype)

In [None]:
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

time_sleep = 0
async def fsdp_inference(x, device):
    await asyncio.sleep(time_sleep)
    out = x.to(device)
    logger.info(f"first forward on {device}")
    with torch.no_grad():
        model = model_1_cpu.to(device)
        out = model(out)
        del model
        torch.cuda.empty_cache()
    await asyncio.sleep(time_sleep)
    logger.info(f"second forward on {device}")
    with torch.no_grad():
        model = model_2_cpu.to(device)
        out = model(out)
        del model
        torch.cuda.empty_cache()
    return out

vocab_size = 32_000
x1 = torch.randint(0, vocab_size, (2, 20)).to(device=device1)
x2 = torch.randint(0, vocab_size, (2, 20)).to(device=device2)

inf1 = asyncio.create_task(fsdp_inference(x1, device1))
inf2 = asyncio.create_task(fsdp_inference(x2, device2))

out1, out2 = await asyncio.gather(inf1, inf2)

del fsdp_inference
del inf1, inf2
del out1, out2
torch.cuda.empty_cache()

INFO:__main__:first forward on cuda:0
INFO:__main__:first forward on cuda:1
INFO:__main__:second forward on cuda:0
INFO:__main__:second forward on cuda:1


#### Native tools

The simple ideas presented above for the forward pass are the basis of multi-GPU processing.
However, coding it for foward, backward and optimization state is a serious engineering project, which hopefully, is been tackled by several open-source solution.

Native PyTorch tools implement these ideas in efficient way, and generalize them for the backward and the optimizer steps.
Those tools are part of the `torch.distributed` module.
At the time of writing, it has two main mode of operation: `ddp` (for distributed data parallel, which consists in processing data in parallel on different machines which all have a full copy of the model) and `fsdp` (fully sharded data parallel, which assimilates to what we have just seen).

- TODO: implement FSDP for mistral 7B inference (restart the notebook there)

#### Faster generation with key-value caching

The usefulness of causal transformer is their auto-regressive properties, which allow to generate sequences one token at a time.

However, we are currently not leveraging this auto-regressive property at inference time.
In particular, when a new token has such been generated, the generation of the next one only requires to compute the attention of the new token with the previous ones: there is no need to recompute the attention key and value for the previous tokens.

- TODO: do it