In [None]:
'''
Per token Workflow
  Input_ids
  -> Hidden Dim
  -> q, k, v projs
  -> rotary pe
  -> self.n_repeats
  -> gqa attention
  -> out_proj
  ->ffn
  ->softmax
'''


'\nPer token Workflow\n  Input_ids\n  -> Hidden Dim\n  -> q, k, v projs\n  -> rotary pe\n  -> self.n_repeats\n  -> gqa attention\n  -> out_proj\n  ->ffn\n  ->softmax\n'

In [None]:
!pip install -U -q accelerate transformers[torch] datasets huggingface_hub

In [None]:
import gc
import sys
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from math import sqrt
from transformers import PretrainedConfig
import math
from typing import Tuple, Optional, List
from transformers import logging, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast  #used to compute output when passed previously computed key and value pairs for faster seqential decoding
#subclass of ModelOutput

In [None]:
def flush():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory."""
    # Order matters: collect unreachable tensors first so their memory can be released from the cache.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def count_parameters(model):
    """Return a display string with the number of trainable parameters of `model`, in millions."""
    trainable = 0
    for param in model.parameters():
        # frozen parameters (requires_grad=False) are not counted
        if param.requires_grad:
            trainable += param.numel()
    millions = trainable / 10 ** 6
    return f" Model size: {millions:.2f}M parameters"


In [None]:
class BRXConfig(PretrainedConfig):
    """Hyper-parameters for the BRX decoder-only language model.

    Each argument below is stored as an attribute of the same name.
    ``pad_token_id``, ``bos_token_id``, ``eos_token_id``, ``tie_word_embeddings``
    and any extra ``**kwargs`` are forwarded to ``PretrainedConfig.__init__``.

    Args:
      vocab_size: number of token ids (embedding table / LM head size).
      hidden_size: width of the hidden states.
      intermediate_size: inner width of the feed-forward block.
      num_hidden_layers: number of transformer blocks.
      num_attention_heads: number of query heads.
      num_key_value_heads: number of key/value heads (grouped-query attention).
      hidden_act: name of the activation function.
      max_position_embeddings: maximum context length.
      initializer_range: std of the normal distribution used to initialise weights.
      rms_norm_eps: epsilon used by the RMSNorm layers.
      use_cache: stored flag for key/value caching.
      pad_token_id, bos_token_id, eos_token_id: special token ids.
      pretraining_tp: tensor-parallelism degree used during pretraining
        (the meaning this name has in Hugging Face's Llama config).
      tie_word_embeddings: whether input embeddings and LM head share weights.
      rope_theta: base of the rotary-embedding frequencies.
      rope_scaling: stored rotary scaling configuration.
      attention_bias, use_bias, lm_head_bias: bias flags for the linear layers.
      attention_dropout, residual_dropout: dropout probabilities.
      device: device string used when the model creates tensors.
    """

    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
      self,
      vocab_size=32000,
      hidden_size=1024, # 2048 Tiny LLaMA
      intermediate_size=2048,
      num_hidden_layers=6,
      num_attention_heads=16, # 32 Tiny LLaMA
      num_key_value_heads=2,
      hidden_act="silu",
      max_position_embeddings=2048, # max context length
      initializer_range=0.02,  # std used to initialize the weights
      rms_norm_eps=1e-6,
      use_cache=False,
      pad_token_id=None,
      bos_token_id=1,
      eos_token_id=2,
      pretraining_tp=1, # tensor-parallelism degree used during pretraining
      tie_word_embeddings=False, # share one weight matrix between the input embeddings and the output layer
      rope_theta=10000.0,
      rope_scaling=None,
      attention_bias=False,
      attention_dropout=0.0,
      use_bias=False,
      lm_head_bias=False,
      residual_dropout=0.0,
      device='cpu',
      **kwargs,
    ):

      self.vocab_size = vocab_size
      self.max_position_embeddings = max_position_embeddings
      self.hidden_size = hidden_size
      self.intermediate_size = intermediate_size
      self.num_hidden_layers = num_hidden_layers
      self.num_attention_heads = num_attention_heads

      self.num_key_value_heads = num_key_value_heads
      self.hidden_act = hidden_act
      self.initializer_range = initializer_range
      self.rms_norm_eps = rms_norm_eps
      self.pretraining_tp = pretraining_tp
      self.use_cache = use_cache
      self.rope_theta = rope_theta
      self.rope_scaling = rope_scaling
      self.attention_bias = attention_bias
      self.attention_dropout = attention_dropout
      self.residual_dropout = residual_dropout
      self.use_bias = use_bias
      self.lm_head_bias = lm_head_bias
      self.device = device

      # pad_token_id and tie_word_embeddings must be forwarded explicitly: they are named
      # parameters here, so they never reach **kwargs and the base class would otherwise
      # fall back to its own defaults regardless of what the caller passed.
      super().__init__(
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        tie_word_embeddings=tie_word_embeddings,
        **kwargs,
      )

In [None]:
def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor:
  """Return a boolean lower-triangular mask of shape [1, 1, max_seq_length, max_seq_length].

  Entry (i, j) is True when j <= i.
  """
  full = torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=device)
  lower_triangle = full.tril()
  return lower_triangle[None, None]  # prepend two singleton dims

def repeat_kv(hidden_states:torch.Tensor, n_repeats:int):
  """Repeat every key/value head `n_repeats` times along the head dimension.

  Input is (batch, n_kv_heads, seq_len, head_dim); output is
  (batch, n_kv_heads * n_repeats, seq_len, head_dim), with the copies of each head
  adjacent to one another. With `n_repeats == 1` the input tensor itself is returned.
  """
  batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
  if n_repeats != 1:
    # insert a repeat axis after the head axis, broadcast it, then fold it into the head axis
    with_repeat_axis = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, n_repeats, seq_len, head_dim)
    hidden_states = with_repeat_axis.reshape(batch, n_kv_heads * n_repeats, seq_len, head_dim)
  return hidden_states

class RotaryPositionalEmbeddings(nn.Module):
  """Rotary positional embeddings (RoPE) backed by a cos/sin lookup table.

  ``forward`` returns the cos/sin tables for the first ``seq_len`` positions and
  ``apply_rope`` rotates a tensor with them.

  Args:
    dim: size of the last dimension the rotation is applied to (the head dim).
    max_position_embeddings: stored on the module; the tables are sized from the
      ``seq_len`` requested in ``forward``.
    base: base of the geometric progression of rotation frequencies.
    device: device on which ``inv_freq`` is created.
    scaling_factor: positions are divided by this value before angles are computed.
  """
  def __init__(self, dim, max_position_embeddings=2048, base=10000.0, device=None, scaling_factor=1.0):
    super().__init__()
    self.dim = dim
    self.max_position_embeddings = max_position_embeddings
    self.device=device
    self.scaling_factor = scaling_factor
    self.base = base
    self.max_seq_len_cached = 0  # no table built yet

    # all theta_i's: base ** (-2i / dim) for i = 0 .. dim/2 - 1
    inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))  #[dim/2]
    self.register_buffer("inv_freq", inv_freq, persistent=False)


  def _set_sin_cos_cache(self, seq_len, device, dtype):
    """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1`` in ``dtype``."""
    self.max_seq_len_cached = seq_len
    m = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) #[seq_len]
    m = m/self.scaling_factor

    freqs = torch.outer(m, self.inv_freq)  #[seq_len, dim/2]

    # the angles are duplicated so they line up with the two halves used in apply_rope
    emb = torch.cat((freqs, freqs), dim=-1) #[seq_len, dim]

    self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) # [seq_len, dim]
    self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # [seq_len, dim]

  @torch.no_grad()
  def forward(self, x, seq_len=None):
    """Return ``(cos, sin)`` tables of shape [seq_len, dim] in ``x``'s dtype.

    Args:
      x: tensor whose device/dtype the tables should match; when ``seq_len`` is
        omitted its second-to-last dimension is used as the sequence length.
      seq_len: number of positions to return.
    """
    if seq_len is None:
      # previously this fell through to torch.arange(None) and crashed
      seq_len = x.shape[-2]

    cos_cached = getattr(self, "cos_cached", None)
    # Rebuild only when the table is missing, too short, or was built for another device/dtype.
    if (
      cos_cached is None
      or seq_len > self.max_seq_len_cached
      or cos_cached.device != x.device
      or cos_cached.dtype != x.dtype
    ):
      self._set_sin_cos_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

    return (
      self.cos_cached[:seq_len].to(dtype=x.dtype),  # first seq_len positions, all dims
      self.sin_cached[:seq_len].to(dtype=x.dtype),
    )

  def apply_rope(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids, unsqueeze_dim=1) -> torch.Tensor:
    """Rotate ``x`` by the angles looked up at ``position_ids``.

    Args:
      x: tensor of shape [batch_size, n_heads, seq_len, head_dim].
      cos, sin: tables returned by ``forward``.
      position_ids: index into the first dimension of ``cos``/``sin``.
      unsqueeze_dim: dimension inserted into the gathered tables so they broadcast
        over the head dimension of ``x``.

    Returns:
      The rotated tensor, same shape and dtype as ``x``.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    half = x.shape[-1] // 2
    x1 = x[..., :half] # (B, nh, T, hs/2) first half of the last dimension
    x2 = x[..., half:] # (B, nh, T, hs/2) second half
    rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)



In [None]:
# attention class
class BRXAttention(nn.Module):
  """Causal grouped-query self-attention with rotary position embeddings.

  Queries use ``config.num_attention_heads`` heads; keys and values use
  ``config.num_key_value_heads`` heads that are repeated to match the queries.

  Args:
    config: BRXConfig providing ``hidden_size``, ``num_attention_heads``,
      ``num_key_value_heads``, ``use_bias``, ``max_position_embeddings``,
      ``rope_theta``, ``attention_dropout`` and ``device``.

  Raises:
    ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads`` or
      ``num_attention_heads`` is not divisible by ``num_key_value_heads``.
  """
  def __init__(self, config:BRXConfig):
    super().__init__()
    self.config = config
    self.hidden_dim = hidden_dim = config.hidden_size
    self.n_heads = n_heads = config.num_attention_heads
    self.n_kv_heads = n_kv_heads = config.num_key_value_heads
    self.head_dim = head_dim = config.hidden_size // n_heads
    self.use_bias = config.use_bias

    if (head_dim * n_heads) != self.hidden_dim:
      raise ValueError(
        f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
        f" and `num_heads`: {self.n_heads})."
      )
    if n_heads % n_kv_heads != 0:
      # otherwise the floor division below would silently produce too few repeated heads
      raise ValueError(
        f"num_heads must be divisible by num_key_value_heads (got `num_heads`: {self.n_heads}"
        f" and `num_key_value_heads`: {self.n_kv_heads})."
      )

    self.repeats = self.n_heads // self.n_kv_heads # q_per_kv

    self.q_proj = nn.Linear(hidden_dim, n_heads * head_dim, bias=self.use_bias) # equals hidden_dim
    self.k_proj = nn.Linear(hidden_dim, n_kv_heads * head_dim, bias=self.use_bias)  # smaller than hidden_dim when n_kv_heads < n_heads
    self.v_proj = nn.Linear(hidden_dim, n_kv_heads * head_dim, bias=self.use_bias)
    self.o_proj = nn.Linear(n_heads * head_dim, hidden_dim, bias=self.use_bias)

    self.rotary_emb = RotaryPositionalEmbeddings(
        head_dim,
        max_position_embeddings=config.max_position_embeddings,
        device=config.device,
        base = self.config.rope_theta,
    )

  def _build_attn_mask(self, mask, seq_len, device):
    """Combine an optional user mask with the causal mask.

    Returns ``None`` when no mask was given (the caller then uses ``is_causal=True``),
    otherwise a boolean tensor broadcastable to (B, n_heads, T, T) where True means
    "may attend".
    """
    if mask is None:
      return None

    causal = build_mask_cache(seq_len, device)  # (1, 1, T, T), lower triangular
    mask = mask.to(device)

    if mask.dim() == 2:
      # (B, T) key-padding mask: non-zero marks real tokens.
      keep_key = (mask != 0)[:, None, None, :]  # (B, 1, 1, T)
      # Always let a position attend to itself so that no row ends up fully masked,
      # which can turn the softmax output into NaN.
      diagonal = torch.eye(seq_len, dtype=torch.bool, device=device)
      return causal & (keep_key | diagonal)

    if mask.is_floating_point():
      raise ValueError(
        "Additive floating-point attention masks are not supported; pass a boolean mask "
        "(True = attend) or a [batch, seq_len] padding mask."
      )
    return causal & (mask != 0)

  def forward(
      self,
      hidden_states :torch.Tensor,
      position_ids: Optional[torch.LongTensor] = None,
      mask: Optional[torch.BoolTensor] = None,
      ):
    """Apply causal self-attention.

    Args:
      hidden_states: (B, T, hidden_dim) input.
      position_ids: positions used for the rotary embedding, broadcastable to
        (B, T); defaults to ``0 .. T-1``.
      mask: optional mask. A 2-D tensor is read as a (B, T) padding mask
        (non-zero = real token); higher-rank boolean/integer masks must broadcast
        to (B, n_heads, T, T) with True/non-zero meaning "may attend". The causal
        constraint is applied in every case.

    Returns:
      Tensor of shape (B, T, hidden_dim).
    """
    B, T, _ = hidden_states.size() # bsz, seq_len, embed_dim
    queries = self.q_proj(hidden_states)
    keys = self.k_proj(hidden_states)
    values = self.v_proj(hidden_states)

    queries = queries.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # bsz, n_heads, seq_len, head_dim
    keys = keys.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # bsz, n_kv_heads, seq_len, head_dim
    values = values.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # bsz, n_kv_heads, seq_len, head_dim

    kv_seq_len = keys.shape[-2]

    if position_ids is None:
      # make the default explicit instead of relying on `cos[None]` indexing
      position_ids = torch.arange(T, dtype=torch.long, device=hidden_states.device).unsqueeze(0)  # [1, T]

    cos, sin = self.rotary_emb(values, seq_len=kv_seq_len)

    queries = self.rotary_emb.apply_rope(queries, cos, sin, position_ids) # bsz, n_heads, seq_len, head_dim
    keys = self.rotary_emb.apply_rope(keys, cos, sin, position_ids) # bsz, n_kv_heads, seq_len, head_dim

    keys = repeat_kv(keys, self.repeats) # (B, n_kv_heads * repeats, T, hs)
    values = repeat_kv(values, self.repeats)

    attn_mask = self._build_attn_mask(mask, T, queries.device)

    # store the tensors contiguously when an explicit mask is used on CUDA
    if queries.device.type == "cuda" and attn_mask is not None:
      queries = queries.contiguous()
      keys = keys.contiguous()
      values = values.contiguous()

    y = F.scaled_dot_product_attention(
      query=queries,
      key=keys,
      value=values,
      attn_mask=attn_mask,
      dropout_p=self.config.attention_dropout if self.training else 0.0,
      is_causal=attn_mask is None,  # attn_mask and is_causal must not both be set
    ) # (B, n_heads, T, head_dim)

    # bring the head axis next to head_dim before merging them; reshaping without the
    # transpose would interleave values from different tokens and heads
    y = y.transpose(1, 2).reshape(B, T, self.hidden_dim) # (B, T, hidden_dim)

    return self.o_proj(y) # (B, T, hidden_dim)

In [None]:
# Smoke test: run one attention layer on random activations and check the output shape.
config = BRXConfig()
d_model = config.hidden_size
sequence_length = config.max_position_embeddings  # max number of tokens (2048 with the defaults)
batch_size = 5

# The random input is drawn before the layer is built, so RNG consumption order is unchanged.
input_data = torch.rand(batch_size, sequence_length, d_model, device=config.device)
position_ids = torch.arange(sequence_length, dtype=torch.long, device=config.device)[None, :]  # [1, sequence_length]

attn = BRXAttention(config)
attn(input_data, position_ids).size()

torch.Size([5, 2048, 1024])

In [None]:
count_parameters(attn)

' Model size: 2.36M parameters'

In [None]:
class FeedForwardBlock(nn.Module):
    """Gated feed-forward block computing ``out_proj(SiLU(linear_1(x)) * linear_2(x))``.

    Args:
        hidden_dim: size of the input and output features.
        intermediate_dim: size of the two parallel inner projections.
    """
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        # Sub-modules are created in the same order as before so seeded initialisation is unchanged.
        self.linear_1 = nn.Linear(hidden_dim, intermediate_dim)  # branch that goes through the activation
        self.linear_2 = nn.Linear(hidden_dim, intermediate_dim)  # branch multiplied element-wise with it
        self.activation_fn = nn.SiLU()
        self.out_proj = nn.Linear(intermediate_dim, hidden_dim)  # back to hidden_dim

    def forward(self, hidden_states):
        gate = self.activation_fn(self.linear_1(hidden_states))
        gated = gate * self.linear_2(hidden_states)
        return self.out_proj(gated)

In [None]:
class RMSNorm(nn.Module):
  """Root-mean-square normalisation over the last dimension with a learnable scale.

  Args:
    hidden_size: size of the last dimension of the inputs.
    eps: value added to the mean square before the reciprocal square root.
  """
  def __init__(self, hidden_size, eps=1e-6):
    super().__init__()
    self.variance_epsilon = eps
    # trainable per-feature gain, initialised to 1
    self.weight = nn.Parameter(torch.ones(hidden_size))

  def forward(self, hidden_states):
    # hidden_states: [B, T, hs]; the statistics are computed in float32
    input_dtype = hidden_states.dtype
    as_float = hidden_states.to(torch.float32)
    mean_square = as_float.pow(2).mean(-1, keepdim=True)  # (1/n) * sum(x_i^2)
    inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
    normalised = (as_float * inv_rms).to(input_dtype)
    return self.weight * normalised

In [None]:
class BRXBLock(nn.Module):
  """One pre-norm transformer block: attention and feed-forward, each with a residual connection.

  Args:
    config: configuration providing ``hidden_size``, ``intermediate_size`` and
      ``rms_norm_eps`` (plus whatever ``BRXAttention`` reads from it).
  """
  def __init__(self, config):
    super().__init__()
    self.config = config

    self.hidden_dim = hidden_dim = config.hidden_size
    self.intermediate_dim = intermediate_dim = config.intermediate_size

    self.attn = BRXAttention(config)

    self.ffn = FeedForwardBlock(hidden_dim, intermediate_dim)
    self.input_norm = RMSNorm(hidden_dim, eps=config.rms_norm_eps)
    self.post_attention_norm = RMSNorm(hidden_dim, eps=config.rms_norm_eps)

  def forward(
        self,
        hidden_states,
        position_ids: Optional[torch.LongTensor] = None,
        mask: Optional[torch.Tensor] = None,
    ):
        """Return the block output, same shape as ``hidden_states``.

        Args:
          hidden_states: input activations.
          position_ids: forwarded to the attention layer as its ``position_ids``.
          mask: forwarded to the attention layer as its ``mask``.
        """
        # Keyword arguments are used on purpose: BRXAttention.forward takes
        # (hidden_states, position_ids, mask), and passing (mask, position_ids)
        # positionally swapped the two.
        r = self.attn(self.input_norm(hidden_states), position_ids=position_ids, mask=mask)
        h = hidden_states + r
        r = self.ffn(self.post_attention_norm(h))
        out = h + r
        return out


In [None]:
# Smoke test for a single transformer block on the random activations created above.
seq_len = input_data.size(1)  # follow the input instead of repeating the literal 2048
block = BRXBLock(config)
# Position ids need a leading batch axis ([1, seq_len]) to broadcast against [B, n_heads, seq_len, head_dim].
position_ids = torch.arange(seq_len, dtype=torch.long, device=config.device).unsqueeze(0)
block(input_data, position_ids)

tensor([[[ 0.0251,  0.6258, -0.5950,  ...,  1.2009,  0.3768,  0.1205],
         [ 0.8227,  0.6238, -0.1216,  ...,  1.1480,  0.4677,  0.0830],
         [ 0.8946,  0.8118, -0.0104,  ...,  0.1553,  0.4919,  0.4164],
         ...,
         [ 0.7332,  0.3191,  0.5046,  ...,  0.6981,  0.4027,  0.9353],
         [ 0.6341, -0.2303,  0.3370,  ...,  0.5877,  1.1383,  0.4717],
         [ 0.8888,  0.2992,  0.0108,  ...,  0.6900,  0.6446,  1.2076]],

        [[ 0.9524,  0.6950, -0.4568,  ...,  0.5980,  0.6820,  0.2848],
         [ 0.4866,  0.6649, -0.0090,  ...,  0.8755,  0.6309, -0.2794],
         [ 0.4561,  0.7406, -0.1167,  ...,  0.8202,  0.1092, -0.0469],
         ...,
         [ 0.7212, -0.1806,  0.4116,  ...,  0.2203,  1.2673,  1.0953],
         [ 1.0061,  0.5808,  0.2973,  ...,  0.5449,  0.4272,  0.9056],
         [ 0.6391,  0.3757,  0.5734,  ...,  0.2448,  0.9701,  1.0828]],

        [[ 0.2343,  0.5440, -0.3211,  ...,  0.4665,  0.8679,  0.4817],
         [ 0.0765,  0.3270, -0.8029,  ...,  0

In [None]:
class BRXModel(nn.Module):
  """Token embedding, a stack of ``BRXBLock`` layers and a final RMSNorm.

  Args:
    config: BRXConfig providing ``hidden_size``, ``vocab_size``,
      ``num_hidden_layers`` and ``rms_norm_eps``.
  """
  def __init__(self, config:BRXConfig):
    super().__init__()
    self.config = config
    self.hidden_dim = hidden_dim = config.hidden_size
    self.vocab_size = vocab_size = config.vocab_size
    assert self.vocab_size > 0
    self.num_hidden_layers = num_hidden_layers = config.num_hidden_layers

    self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)  # learned during training as well
    self.blocks = nn.ModuleList(
      [BRXBLock(config) for _ in range(num_hidden_layers)]
    )
    self.norm = RMSNorm(hidden_dim, eps=config.rms_norm_eps)

  def forward(
        self,
        hidden_states: torch.Tensor,  # token ids, [batch_size, seq_len]
        position_ids = None,
        mask: Optional[torch.Tensor] = None,
    ):
      """Encode token ids into final hidden states.

      Args:
        hidden_states: token ids of shape [batch_size, seq_len] (the name is kept
          for backward compatibility with keyword callers).
        position_ids: optional positions passed to every block; defaults to
          ``0 .. seq_len-1`` with shape [1, seq_len].
        mask: optional mask passed to every block.

      Returns:
        Tensor of shape [batch_size, seq_len, hidden_size].
      """
      x = self.embed_tokens(hidden_states) # [batch_size, seq_len, hidden_size]

      seq_len = hidden_states.size(1)
      if position_ids is None:
          # Build the default on the input's device; config.device can be stale once
          # the model has been moved (e.g. by a training framework).
          position_ids = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device).unsqueeze(0) #[1, seq_len]

      for b in self.blocks:
          x = b(x, position_ids=position_ids, mask=mask)

      return self.norm(x)  #[batch_size, seq_len, hidden_size]

In [None]:
class BRXPreTrainedModel(PreTrainedModel):
    """Base class that ties BRX models to ``BRXConfig`` and defines weight initialisation."""
    config_class = BRXConfig
    base_model_prefix = "brx"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"  # key excluded from automatic device placement

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, initializer_range); zero biases and the padding row."""
        std = self.config.initializer_range  # 0.02 with the default config
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        nn.init.normal_(module.weight, mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif module.padding_idx is not None:
            # the embedding of the padding token stays at zero
            with torch.no_grad():
                module.weight[module.padding_idx].zero_()

class BRXForCausalLM(BRXPreTrainedModel):
    """BRX backbone with a linear language-modelling head and optional next-token loss."""
    def __init__(self, config):
        super().__init__(config)
        self.model = BRXModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=config.lm_head_bias)
        self.post_init()

    def forward(
        self,
        input_ids = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ):
        """Compute logits and, when ``labels`` is given, the shifted cross-entropy loss.

        Args:
          input_ids: token ids, [batch_size, seq_len].
          attention_mask: passed to the backbone as ``mask``.
          position_ids: passed to the backbone unchanged.
          labels: target ids aligned with ``input_ids``; position ``t`` is scored
            against the logits at position ``t - 1``.

        Returns:
          ``CausalLMOutputWithPast`` with ``loss`` (or None) and float32 ``logits``
          of shape [batch_size, seq_len, vocab_size].
        """
        hidden = self.model(
            hidden_states=input_ids,
            mask=attention_mask,
            position_ids=position_ids,
        )  # [batch_size, seq_len, hidden_dim]
        logits = self.lm_head(hidden).float()  # [batch_size, seq_len, vocab_size]

        loss = None
        if labels is not None:
            # tokens < n predict token n: drop the last logit and the first label
            flat_logits = logits[..., :-1, :].reshape(-1, self.config.vocab_size)
            flat_labels = labels[..., 1:].reshape(-1).to(flat_logits.device)
            loss = F.cross_entropy(flat_logits, flat_labels)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,  # raw scores, no softmax applied
        )


In [None]:
# Build the full causal LM on CPU with the default BRXConfig hyper-parameters.
device = "cpu"
config = BRXConfig(device=device)
batch_size = 5
sequence_length = config.max_position_embeddings
brx = BRXForCausalLM(config)
# Dummy batch: random token ids drawn from [1, vocab_size), shape [batch_size, sequence_length].
input_ids = torch.randint(1, config.vocab_size,  (batch_size, sequence_length), device = config.device)

In [None]:
outputs = brx(input_ids)
outputs

CausalLMOutputWithPast(loss=None, logits=tensor([[[-0.6471,  1.1560,  0.3832,  ..., -0.5267, -0.8781, -0.1881],
         [ 0.3825,  1.2757,  0.7336,  ...,  0.5424, -0.0634, -0.1192],
         [ 0.7989, -0.1757, -0.1991,  ..., -0.5755,  0.5808, -0.1110],
         ...,
         [-0.1600, -0.3447, -0.3126,  ..., -0.5578,  1.1859,  2.0270],
         [-0.2300,  0.0533, -0.8319,  ..., -0.4022, -0.6401,  0.9655],
         [-0.5259,  0.7973,  0.3103,  ..., -0.8209, -0.3896,  1.3625]],

        [[ 1.0229,  0.3054, -1.7574,  ..., -0.6747,  0.4738,  0.4320],
         [-0.7087,  0.6187,  0.0734,  ...,  0.4521, -0.4249, -1.1668],
         [ 1.3756,  0.2116, -0.0051,  ...,  0.0120, -0.9926,  0.3367],
         ...,
         [-0.7921,  0.7095,  1.0256,  ..., -0.6763, -0.4153,  1.0393],
         [-0.6158, -0.0083,  0.0076,  ...,  0.7838, -0.2367,  0.6458],
         [-0.6262,  0.1442,  0.1058,  ..., -0.2478, -0.3756,  0.5328]],

        [[-0.4676, -0.3840, -0.1927,  ..., -1.3687, -0.3806,  0.9396],
    

In [None]:
outputs[0].shape

torch.Size([5, 2048, 32000])

In [None]:
from datasets import load_dataset

# Seed PyTorch's RNG. NOTE(review): `brx` was already created in an earlier cell,
# so this seed does not influence its initial weights.
torch.manual_seed(64)

# Only the first 1% of the training split is used.
train_dataset = load_dataset("huggingface-course/codeparrot-ds-train", split="train[:1%]")
# No `split` is passed here — presumably this returns all splits rather than a single
# Dataset (TODO confirm). `val_dataset` is not referenced again in the visible cells.
val_dataset = load_dataset("huggingface-course/codeparrot-ds-valid")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading data:   0%|          | 0.00/8.25G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data:   0%|          | 0.00/46.1M [00:00<?, ?B/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [None]:
from datasets import Dataset, DatasetDict
# Rebind `train_dataset` to a DatasetDict with a single 'train' split;
# later cells therefore index it as train_dataset['train'].
train_dataset= DatasetDict({'train': train_dataset})
train_dataset

DatasetDict({
    train: Dataset({
        features: ['repo_name', 'path', 'copies', 'size', 'content', 'license'],
        num_rows: 6067
    })
})

In [None]:
from transformers import AutoTokenizer

# Use the Mistral tokenizer, which is byte-pair encoded.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    add_eos_token=True,  # append </s> to encoded sequences (the option was misspelled `add_eos_roken` and silently ignored)
    padding_side='right',  # pad on the right instead of the tokenizer's default side
    )

In [None]:
tokenizer

LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [None]:
# The tokenizer's special tokens shown above contain no pad token, so <unk> is reused for padding.
tokenizer.pad_token = tokenizer.unk_token

In [None]:
text = "Hi I'm Brijesh and rij"
tokenizer(text).tokens()

['<s>', '▁Hi', '▁I', "'", 'm', '▁B', 'rij', 'esh', '▁and', '▁ri', 'j']

In [None]:
# Repeat the second training sample 2**5 = 32 times (five successive doublings).
# NOTE(review): `text` is not what the next visible cell passes to the tokenizer.
text = train_dataset['train'][1]['content'] * (2 ** 5)

In [None]:
# Tokenize one training sample into chunks of at most `context_length` tokens.
context_length = 1024  # maximum number of tokens per chunk
tokens = tokenizer(
  train_dataset['train'][1]["content"],
  padding=True,  # pad the chunks to a common length
  truncation=True,
  max_length=context_length,
  return_overflowing_tokens=True,  # keep the text beyond max_length as further chunks
  return_length=True,  # also report each chunk's length under tokens['length']
)

# print(f"Input IDs length: {len(tokens['input_ids'])}")
# print(f"Input chunk lengths: {(tokens['length'])}")
# print(f"Chunk mapping: {tokens['overflow_to_sample_mapping']}")

In [None]:
import numpy as np

# Put the chunked input ids into a NumPy array to inspect them.
a = np.array(tokens['input_ids'])
a.shape  # not displayed: only the last expression of a cell is echoed
a[2][1021:]  # tail of the third chunk, from index 1021 onwards

array([0, 0, 0])

In [None]:
print(f"Input IDs length: {len(tokens['input_ids'])}")
print(f"Input chunk lengths: {(tokens['length'])}")

Input IDs length: 229
Input chunk lengths: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]


In [None]:
from transformers import DataCollatorForLanguageModeling  #collates different data together

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="pt") # mlm=False: plain causal-LM batches; mlm=True would randomly replace tokens with [MASK]
# Collate the single tokenized sample into tensors (a list with one element is passed in).
sample = data_collator.torch_call([tokens])

In [None]:
sample['attention_mask']

tensor([[[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0]]])

In [None]:
sample['input_ids']

tensor([[[    1,  3504,    13,  ..., 28750,  2974, 28705],
         [    1,   422,   297,  ...,   548,    13, 17422],
         [    1,   277,  2022,  ...,     0,     0,     0]]])

In [None]:
outputs = brx(input_ids=sample['input_ids'][0], labels=sample['labels'][0])

In [None]:
outputs

CausalLMOutputWithPast(loss=tensor(10.5116, grad_fn=<NllLossBackward0>), logits=tensor([[[-0.7529,  0.2396, -0.2563,  ..., -1.2040, -1.3086, -1.0972],
         [-1.1731,  0.3528,  0.1746,  ..., -1.1790, -1.5978, -0.8979],
         [-0.9041, -0.1104, -0.0425,  ..., -1.3452, -1.4894, -0.5970],
         ...,
         [ 0.1771,  0.5808,  0.1633,  ..., -0.7305,  0.4999, -0.3825],
         [ 0.1788,  0.2870,  0.4717,  ..., -0.6304, -0.0235, -0.0608],
         [ 0.2922,  0.2619,  0.2344,  ..., -0.4923,  0.2149, -0.0949]],

        [[-0.7464, -0.1366, -0.6549,  ..., -0.5066, -0.4534, -0.3477],
         [-0.7560, -0.0853, -0.1236,  ..., -0.1410, -0.3382, -0.4144],
         [-0.4240, -0.0217, -0.1214,  ...,  0.0272, -0.7891, -0.3101],
         ...,
         [ 0.0941, -0.8053,  0.5980,  ..., -0.2144,  0.4841, -0.0680],
         [ 0.0788, -1.0478,  0.9114,  ..., -0.2046,  0.8424, -0.4675],
         [ 0.4607, -0.1349,  0.1804,  ...,  0.0939,  0.3921, -0.3445]],

        [[ 0.3538,  0.7373, -0.8192,

In [None]:
outputs.logits.shape  #batch_size, seq_len, vocab_size

torch.Size([3, 1024, 32000])

In [None]:
F.softmax(outputs.logits, dim=-1).shape

torch.Size([3, 1024, 32000])

In [None]:
# Greedy token ids. Softmax is monotonic, so the argmax of the raw logits selects the same
# token without materialising a full [batch, seq_len, vocab] probability tensor.
ids = torch.argmax(outputs.logits, dim=-1)

In [None]:
tokenizer.batch_decode(ids)

In [None]:
def tokenize(element):
    """Tokenize a batch of files into training chunks of exactly `context_length` tokens.

    Args:
        element: a batch from the dataset; its "content" column holds the source text.

    Returns:
        Dict with a single "input_ids" list. Chunks shorter than `context_length`
        (typically the tail of each file) are dropped so every example has the same length.
    """
    encoded = tokenizer(
        element["content"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    return {
        "input_ids": [
            chunk_ids
            for chunk_len, chunk_ids in zip(encoded["length"], encoded["input_ids"])
            if chunk_len == context_length
        ]
    }


# `train_dataset` is a DatasetDict, whose `.column_names` is a dict keyed by split;
# `remove_columns` needs the plain list of columns, taken here from the 'train' split.
# All original columns are removed because chunking changes the number of rows.
tokenized_train_dataset = train_dataset.map(
    tokenize, batched=True, remove_columns=train_dataset["train"].column_names
)

In [None]:
from transformers import TrainingArguments, Trainer

# Training hyper-parameters; checkpoints are written to ./brx.
args = TrainingArguments(
    "./brx",
    per_device_train_batch_size=32,
    max_steps=2000,  # when positive, this takes precedence over num_train_epochs
    num_train_epochs=2,
    logging_steps=10,
    gradient_accumulation_steps=2,
    weight_decay=0.1,
    warmup_steps= 1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=500,
    fp16=torch.cuda.is_available(),  # fp16 mixed precision needs an accelerator; fall back to fp32 on CPU
    push_to_hub=False,
)

In [None]:
# `tokenized_train_dataset` is a DatasetDict; the Trainer needs the actual Dataset,
# i.e. its 'train' split, not the dict of splits.
trainer = Trainer(
    model=brx,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset["train"]
)

In [None]:
trainer.train()

In [None]:
# Sampling settings: set at most one of top_k / top_p (top_k wins if both are set).
temperature = 1
top_k = None
top_p = None

brx.eval()  # inference mode for the model
input_ids = input_ids.to(brx.device)  # the model may have been moved to another device by training

# Generate the tokens one by one; no gradients are needed for sampling.
with torch.no_grad():
    for _ in range(10):
        # Logits of the last position only: [batch, vocab_size]
        outputs = brx(input_ids)
        logits = outputs.logits[:, -1, :]

        # Apply temperature scaling
        logits = logits / temperature

        if top_k is not None:
            # Keep the vocabulary axis intact and mask everything below the k-th best logit,
            # so that the sampled index is a real token id (not a rank within the top-k).
            k = min(top_k, logits.size(-1))
            kth_best = logits.topk(k, dim=-1)[0][:, -1:]
            logits = logits.masked_fill(logits < kth_best, -float('inf'))
        elif top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is still kept
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = 0
            # Map the removal flags from sorted order back to vocabulary order
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits = logits.masked_fill(indices_to_remove, -float('inf'))

        # Sample the next token from the (possibly filtered) distribution
        next_token_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

        # Append the new token to the running sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

# Decode the first sequence of the batch and show it
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
generated_text