# Code 7

In [None]:
!pip install transformers datasets einops pytorch_lightning tensorboard

Collecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m65.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m49.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytorch_lightning
  Downloading pytorch_lightning-2.0.9.post0-py3-none-any.whl (727 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m727.7/727.7 kB[0m [31m55.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from datasets import load_dataset
import math
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import multiprocessing

import pytorch_lightning as pl
from einops import rearrange # einstein operation

In [None]:
from google.colab import drive

# Attach Google Drive to the Colab filesystem; later cells read and write
# paths underneath this mount point.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
class InferenceParams(nn.Module):
    """Plain container for every hyperparameter used in this notebook.

    The object holds optimizer settings, regularisation rates, and the model
    and data dimensions. It is instantiated once below as the module-level
    ``config`` and passed to the model classes.
    """

    def __init__(self):
        # nn.Module keeps its bookkeeping (_parameters, _modules, ...) in
        # attributes created by its own __init__; without this call methods
        # such as parameters(), repr() or .to() raise AttributeError.
        super().__init__()

        # --- optimizer ---
        self.weight_decay = 0.01
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.98
        self.adam_epsilon = 1e-7

        # --- normalisation ---
        self.layer_norm_epsilon = 1e-05

        # --- dropout probabilities (embedding / residual / attention) ---
        self.embd_pdrop = 0.0
        self.resid_pdrop = 0.0
        self.attention_pdrop = 0.0

        # Name of a class in torch.nn; the MLP looks it up with getattr.
        self.activation_function = "GELU"

        # --- training schedule ---
        self.n_epochs = 10
        self.initializer_range = 0.02  # std of the normal weight init
        self.learning_rate = 5e-4  # alternative noted by the author: 3e-4

        # --- model shape ---
        self.rotary_dim = 10  # leading features of each head that get rotary encoding
        self.n_layer = 4
        self.hidden_size = None  # MLP inner width; None means 4 * n_embd
        self.n_head = 8
        self.n_embd = 280
        self.vocab_size = 50257

        # --- data shape ---
        self.max_sequence_len = 256  # alternative noted by the author: 512
        self.max_batch_size = 32


config = InferenceParams()


In [None]:
# Number of stories taken from each split; validation is a tenth of training.
sample_train = 10 ** 5
sample_val = int(0.1 * sample_train)

dataset = load_dataset("roneneldan/TinyStories")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token


def _tokenize_fixed_length(texts):
    """Tokenize `texts` into PyTorch tensors padded/truncated to config.max_sequence_len."""
    return tokenizer(
        texts,
        return_tensors='pt',
        padding='max_length',  # pad every sequence up to max_length
        truncation=True,  # cut sequences longer than max_length
        max_length=config.max_sequence_len,
    )


train_subset = dataset['train'][:sample_train]['text']
val_subset = dataset['validation'][:sample_val]['text']

tokenized_trainset = _tokenize_fixed_length(train_subset)
tokenized_valset = _tokenize_fixed_length(val_subset)

Downloading readme:   0%|          | 0.00/946 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/249M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over a tokenizer output.

    `tokenized_data` must support ``tokenized_data[key][idx]`` for the keys
    'input_ids' and 'attention_mask'; each item is a dict with those two keys.
    """

    # Fields copied into every item, in this order.
    _FIELDS = ('input_ids', 'attention_mask')

    def __init__(self, tokenized_data):
        self.data = tokenized_data

    def __len__(self):
        # One example per row of input_ids.
        return len(self.data['input_ids'])

    def __getitem__(self, idx):
        return {field: self.data[field][idx] for field in self._FIELDS}

# Wrap the tokenized splits so the DataLoader can index individual examples.
train_data = CustomDataset(tokenized_trainset)
val_data = CustomDataset(tokenized_valset)

# One loader worker process per CPU core reported by the OS.
cpu_count = multiprocessing.cpu_count()

# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=config.max_batch_size,
    shuffle=True,
    num_workers=cpu_count,
)
val_loader = DataLoader(
    dataset=val_data,
    batch_size=config.max_batch_size,
    num_workers=cpu_count,
)

In [None]:
class Embedding(nn.Module):
    """Token embedding lookup followed by dropout.

    Reads `vocab_size`, `n_embd` and `embd_pdrop` from `config`.
    """

    def __init__(self, config):
        super().__init__()

        self.wte = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.n_embd)
        self.drop = nn.Dropout(p=config.embd_pdrop)

    def forward(self, input_ids):
        """Return embeddings of shape (rows, seq_len, n_embd) for integer `input_ids`."""
        # Collapse any leading dimensions so ids are always (rows, seq_len).
        seq_len = input_ids.shape[-1]
        flat_ids = input_ids.view(-1, seq_len)

        return self.drop(self.wte(flat_ids))

In [None]:
class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding applied to the queries and keys of a packed qkv tensor.

    Only the first `config.rotary_dim` features of every head are rotated; the
    remaining features and the values pass through unchanged.

    Args:
        config: object exposing `rotary_dim`.
        base: base of the geometric progression of rotation frequencies.
    """

    def __init__(self, config, base = 10000):
        super().__init__()
        self.dim = config.rotary_dim

        # One inverse frequency per rotated feature pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Most recently computed cos/sin tables (recomputed on every forward).
        self.cos_cache = None
        self.sin_cache = None

    def forward(self, qkv):
        """Rotate queries and keys by their position.

        Args:
            qkv: 5-D tensor indexed as (batch, seq, 3, heads, head_dim), where
                index 0/1/2 on the third axis selects queries/keys/values.

        Returns:
            Tensor of the same shape with the rotary features of q and k rotated.
        """
        seqlen = qkv.shape[1]

        # Rotation angle for (position, feature pair) = position * inverse frequency.
        positions = torch.arange(seqlen, device=qkv.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(positions, self.inv_freq)

        self.cos_cache = torch.cos(freqs)
        self.sin_cache = torch.sin(freqs)

        # Each frequency rotates a pair of features, hence twice as many features.
        rotary_dim = 2 * self.cos_cache.shape[1]

        q_rot = qkv[:, :, 0, :, :rotary_dim]
        q_pass = qkv[:, :, 0, :, rotary_dim:]

        k_rot = qkv[:, :, 1, :, :rotary_dim]
        k_pass = qkv[:, :, 1, :, rotary_dim:]

        # Pair feature i of the first half with feature i of the second half.
        q1, q2 = q_rot.chunk(2, dim=-1)
        k1, k2 = k_rot.chunk(2, dim=-1)

        # (seq, pairs) -> (seq, 1, pairs) so the tables broadcast over heads.
        c = self.cos_cache.unsqueeze(1)
        s = self.sin_cache.unsqueeze(1)

        # 2-D rotation of every pair: (x1, x2) -> (x1*c - x2*s, x1*s + x2*c).
        # The second component previously used `- x2*c`, which is not a
        # rotation: it flipped the sign at position 0, did not preserve norms
        # and broke the dependence of q.k on relative position only.
        # NOTE(review): checkpoints trained with the old formula were fitted to
        # a different transform and may need fine-tuning or retraining.
        q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], dim=-1)
        k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], dim=-1)

        return torch.cat(
            [
                torch.cat([q_rot, q_pass], dim=-1).unsqueeze(2),
                torch.cat([k_rot, k_pass], dim=-1).unsqueeze(2),
                qkv[:, :, 2:3, :, :]
            ],
            dim=2
        )

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    The inner width is `config.hidden_size`, or 4 * `config.n_embd` when that
    is None. The activation is the `torch.nn` class named by
    `config.activation_function`, instantiated without arguments.
    """

    def __init__(self, config):
        super().__init__()
        inner_width = config.hidden_size
        if inner_width is None:
            inner_width = 4 * config.n_embd

        self.fc1 = nn.Linear(config.n_embd, inner_width)
        self.fc2 = nn.Linear(inner_width, config.n_embd)

        activation_cls = getattr(torch.nn, config.activation_function)
        self.act = activation_cls()

    def forward(self, hidden_states):
        """Apply fc1, the activation and fc2 to the last dimension of `hidden_states`."""
        return self.fc2(self.act(self.fc1(hidden_states)))

In [None]:
class SelfAttention(nn.Module):
    """Causal scaled dot-product attention over a packed qkv tensor.

    Args:
        attention_dropout: dropout probability applied to the attention
            weights. When omitted, the value is read from the module-level
            `config.attention_pdrop`, which keeps `SelfAttention()` working
            exactly as before.
    """

    def __init__(self, attention_dropout=None):
        super().__init__()
        if attention_dropout is None:
            # Backward-compatible default: fall back to the notebook-level config.
            attention_dropout = config.attention_pdrop
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, attention_mask = None):
        """Attend causally over the sequence.

        Args:
            qkv: tensor indexed as (batch, seq, 3, heads, head_dim); the third
                axis holds queries, keys and values in that order.
            attention_mask: optional (batch, seq) mask; truthy entries are
                real tokens, falsy entries are padding that no query may
                attend to. Any dtype convertible to bool is accepted.

        Returns:
            Tensor of shape (batch, seq, heads, head_dim).
        """
        batch_size, seq_len = qkv.shape[0], qkv.shape[1]
        q, k, v = qkv.unbind(2)

        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd, bshd -> bhts", q, k * softmax_scale)

        if attention_mask is not None:
            # Additive mask: 0 for real tokens, a large negative value for padding.
            padding_mask = torch.full(
                (batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device
            )
            # masked_fill_ requires a bool mask; cast so int masks coming
            # straight from a tokenizer do not raise.
            keep = attention_mask.to(device=scores.device, dtype=torch.bool)
            padding_mask.masked_fill_(keep, 0.0)

            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads and queries.
            scores = scores + padding_mask[:, None, None, :]

        # Strictly upper-triangular penalty: a query cannot see later keys.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), -10000.0, dtype=scores.dtype, device=scores.device),
            diagonal=1,
        )
        scores = scores + causal_mask

        attention = torch.softmax(scores, dim=-1)
        attention = self.drop(attention)

        output = torch.einsum("bhts, bshd -> bthd", attention, v)

        return output

In [None]:
class MHA(nn.Module):
    """Multi-head attention: fused qkv projection, rotary embedding, causal
    self-attention and an output projection.

    Reads `n_embd` and `n_head` from `config`; the per-head width is
    `n_embd // n_head`.
    """

    def __init__(self, config):
        super().__init__()
        self.rotary_emb = RotaryPositionEmbedding(config)

        embed_dim = config.n_embd
        self.head_dim = embed_dim // config.n_head
        # Total width across heads (can be below n_embd when it is not divisible by n_head).
        inner_dim = self.head_dim * config.n_head

        self.Wqkv = nn.Linear(embed_dim, 3 * inner_dim)
        self.out_proj = nn.Linear(inner_dim, embed_dim)

        self.inner_attn = SelfAttention()

    def forward(self, x, attention_mask = None):
        """Apply attention to `x`; `attention_mask`, if given, is cast to bool first."""
        # Split the fused projection into separate q/k/v and head axes.
        packed = rearrange(self.Wqkv(x), 'b t (three h d) -> b t three h d', three=3, d=self.head_dim)
        packed = self.rotary_emb(packed)

        bool_mask = None
        if attention_mask is not None:
            bool_mask = attention_mask.bool().to(packed.device)

        context = self.inner_attn(packed, bool_mask)

        # Merge the head axis back into the feature axis before projecting out.
        merged = rearrange(context, "... h d -> ... (h d)")
        return self.out_proj(merged)

In [None]:
class Block(nn.Module):
    """Transformer block with parallel attention and feed-forward branches.

    Both branches read the same layer-normalised input; their dropped-out
    outputs are added to the un-normalised input.
    """

    def __init__(self, config):
        super().__init__()

        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.attn = MHA(config)
        self.ffwd = MLP(config)

    def forward(self, hidden_states, attention_mask = None):
        """Return attn(ln(x)) + ffwd(ln(x)) + x, with dropout on each branch."""
        normed = self.ln(hidden_states)

        # Compute both branches before applying dropout, as the original did.
        attn_branch = self.attn(normed, attention_mask)
        ffwd_branch = self.ffwd(normed)

        return self.resid_dropout(attn_branch) + self.resid_dropout(ffwd_branch) + hidden_states

In [None]:
class LMHead(nn.Module):
    """Final layer norm followed by a linear projection to vocabulary logits.

    Reads `n_embd`, `layer_norm_epsilon` and `vocab_size` from `config`.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd

        self.ln = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.linear = nn.Linear(width, config.vocab_size)

    def forward(self, output):
        """Map hidden states (..., n_embd) to logits (..., vocab_size)."""
        return self.linear(self.ln(output))

In [None]:
class SequentialForLM(nn.Module):
    """Language model assembled as Embedding -> n_layer Blocks -> LMHead.

    All Linear, Embedding and LayerNorm submodules are re-initialised through
    `_init_weights` after construction.
    """

    def __init__(self, config):
        super().__init__()

        self.initializer_range = config.initializer_range

        stack = [
            Embedding(config),
            *(Block(config) for _ in range(config.n_layer)),
            LMHead(config),
        ]
        self.layers = nn.Sequential(*stack)

        self.apply(self._init_weights)

    def forward(self, input_ids, attention_mask = None):
        """Return vocabulary logits for `input_ids`.

        Without a mask the whole stack runs as a plain nn.Sequential. With a
        mask, the blocks are called one by one so the mask can be forwarded
        to each of them.
        """
        if attention_mask is None:
            return self.layers(input_ids)

        if self.training:
            print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")

        embed, *blocks, head = self.layers
        hidden = embed(input_ids)
        for block in blocks:
            hidden = block(hidden, attention_mask=attention_mask)

        return head(hidden)

    def _init_weights(self, module):
        """Initialise one submodule; called for every submodule via `apply`."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Weights of projections and embeddings: N(0, initializer_range).
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # Layer norm starts out as the identity transform.
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

In [None]:
class LMLoss(nn.Module):
    """Next-token cross-entropy: logits at position t are scored against the label at t + 1."""

    def __init__(self):
        super().__init__()

        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the mean cross-entropy over all shifted positions.

        Args:
            logits: (..., seq, vocab) scores.
            labels: (..., seq) integer token ids.
        """
        vocab_size = logits.size(-1)

        # Drop the last prediction and the first label so the two line up.
        predictions = logits[..., :-1, :].reshape(-1, vocab_size)
        targets = labels[..., 1:].reshape(-1)

        return self.loss_fct(predictions, targets)

In [None]:
class ModelForVisualization(pl.LightningModule):
    """Lightning wrapper that trains and validates a SequentialForLM.

    Args:
        config: hyperparameter object (optimizer settings and model shape).
        is_load_state_dict: when true, model weights are loaded from
            `model_path` at construction time.
        model_path: file the model state dict is loaded from and saved to.
    """

    def __init__(self, config, is_load_state_dict, model_path):
        super().__init__()

        self.model_path = model_path

        # Optimizer hyperparameters, consumed by configure_optimizers.
        self.weight_decay = config.weight_decay
        self.betas = (config.adam_beta1, config.adam_beta2)
        self.epsilon = config.adam_epsilon
        self.learning_rate = config.learning_rate

        self.model = SequentialForLM(config)
        if is_load_state_dict:
            # Load onto CPU regardless of the device the checkpoint was saved from.
            state = torch.load(self.model_path,  map_location=torch.device('cpu'))
            self.model.load_state_dict(state)

        self.loss = LMLoss()

    def forward(self, input_ids, attention_mask=None):
        """Return the logits of the wrapped model."""
        return self.model(input_ids, attention_mask)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; the attention mask is not used here."""
        token_ids = batch['input_ids']
        loss = self.loss(self(token_ids), token_ids)

        self.log("train loss", loss, prog_bar=True, on_step=True, on_epoch=True)

        # Overwrite the checkpoint file after every 1000th batch of an epoch.
        batches_done = batch_idx + 1
        if batches_done % 1000 == 0:
            torch.save(self.model.state_dict(), self.model_path)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss, passing the attention mask to the model."""
        token_ids = batch['input_ids']
        mask = batch['attention_mask']

        loss = self.loss(self(token_ids, mask), token_ids)

        self.log("valid loss", loss, prog_bar=True, on_step=True)

        return loss

    def configure_optimizers(self):
        """Build the Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(
            params=self.model.parameters(),
            lr=self.learning_rate,
            betas=self.betas,
            eps=self.epsilon,
            weight_decay=self.weight_decay,
        )

In [None]:
# Checkpoint file on the mounted Drive. The model below is built with
# is_load_state_dict=True, so this file must already exist.
model_path = '/content/drive/MyDrive/microsof_phi15_model.pth'
model = ModelForVisualization(config, is_load_state_dict=True, model_path=model_path)

# TensorBoard event files are written under the Drive root.
logger = pl.loggers.TensorBoardLogger(save_dir='/content/drive/MyDrive/')

trainer = pl.Trainer(
    logger=logger,
    max_epochs=config.n_epochs,
    log_every_n_steps=1,  # log metrics on every optimisation step
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type            | Params
------------------------------------------
0 | model | SequentialForLM | 32.0 M
1 | loss  | LMLoss          | 0     
------------------------------------------
32.0 M    Trainable params
0         Non-trainable params
32.0 M    Total params
127.881   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# (Re)load the TensorBoard notebook extension and open the dashboard on port 6007.
# NOTE(review): the log directory is presumably where the TensorBoardLogger
# created in the training cell writes its runs — confirm the path matches.
%reload_ext tensorboard
%tensorboard --logdir "/content/drive/MyDrive/lightning_logs" --port 6007