<a href="https://colab.research.google.com/github/kimestelle/llm-chatbot/blob/main/aristotle_llama.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

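- The last sequence in the dataset is a randomly chosen copy of an earlier sequence, added so the data divides into even batches
- Vocab size: 1000
- Number of lines: 48780
- Sequence length: 383 (note: `MASTER_CONFIG` below uses `max_seq_len`/`seq_len` = 412)
- Batch size: 36 (note: `MASTER_CONFIG` below currently uses `batch_size` = 4)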

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math
import sentencepiece as spm
from collections import OrderedDict
import gc

from google.colab import drive
drive.mount('/content/drive')

# Seed PyTorch's RNGs (CPU and all CUDA devices) so weight initialisation and the
# random batch sampling in get_batches are reproducible on Restart & Run All.
SEED = 42
_ = torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cpu


# Model Parameters

In [None]:
# Hyper-parameters shared by the model, the batch sampler and the training loop.
# Earlier runs used batch sizes of 36 and 18; 4 is used here (presumably to fit
# in available memory — confirm).
MASTER_CONFIG = dict(
    dim=4096,                 # model / embedding width
    n_layers=16,              # number of transformer blocks
    n_heads=16,               # query heads per attention layer
    n_kv_heads=4,             # key/value heads (shared across query heads)
    vocab_size=1000,          # tokenizer vocabulary size
    multiple_of=256,          # FFN hidden-size rounding granularity
    ffn_dim_multiplier=None,  # optional FFN hidden-size scale
    norm_eps=1e-5,            # RMSNorm epsilon
    max_batch_size=4,         # KV-cache batch capacity
    batch_size=4,             # training batch size
    max_seq_len=412,          # RoPE table / KV-cache length
    seq_len=412,
    device=device,
)

print(MASTER_CONFIG["device"])

cpu


# Load the SentencePiece tokenizer

In [None]:
# Load the trained SentencePiece tokenizer from Google Drive.
spm_model = spm.SentencePieceProcessor()
spm_model.load('/content/drive/MyDrive/philosophy_data/a.model')

# Id of the padding piece; padded positions are dropped before decoding.
pad_id = spm_model.pad_id()

def decode_without_padding(sequence, spm_model):
    """Decode token ids to text after removing padding ids.

    Args:
        sequence: iterable of integer token ids (plain ints or 0-d tensors).
        spm_model: tokenizer providing ``pad_id()`` and ``decode_ids()``.

    Returns:
        The decoded string.
    """
    # Use the pad id of the tokenizer actually passed in rather than a
    # module-level global, so the helper works for any tokenizer.
    pad = spm_model.pad_id()
    ids = [int(token_id) for token_id in sequence]
    return spm_model.decode_ids([token_id for token_id in ids if token_id != pad])

# Tokenizer vocabulary size (MASTER_CONFIG sets its own "vocab_size" value).
vocab_size = spm_model.get_piece_size()

# Sanity check: 'hello world' should survive an encode/decode round trip.
print(spm_model.decode_ids(spm_model.encode_as_ids('hello world')))

hello world


In [None]:
# Safe to re-run: when the drive is already mounted Colab just reports it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# with open('/content/drive/MyDrive/philosophy_data/padded_processed_aristotle.txt', 'r') as file:
#     lines = file.readlines()

# data = [list(map(int, line.split())) for line in lines]

# dataset = torch.zeros((len(data), MASTER_CONFIG["seq_len"]), dtype=torch.int64)

# for i in range(len(data)):
#     seq = data[i]
#     dataset[i, :len(seq)] = torch.tensor(seq, dtype=torch.int64)

# print(dataset.shape)

# Generate Training Batches

In [None]:
def get_batches(data, split, batch_size, config=MASTER_CONFIG):
    """Sample a random batch of (input, target) token sequences from one split.

    Rows of ``data`` are split in order: first 80% train, next 10% val,
    last 10% test.

    Args:
        data: 2-D tensor of token ids, one padded sequence per row
            (assumed shape ``(n_sequences, seq_len)``).
        split: one of ``'train'``, ``'val'`` or ``'test'``.
        batch_size: number of rows to sample (with replacement).
        config: dict providing ``"device"``.

    Returns:
        ``(x, y)`` contiguous int64 tensors of shape ``(batch_size, seq_len - 1)``,
        where ``y`` is ``x`` shifted left by one token (next-token targets).

    Raises:
        ValueError: on an unknown split name or an empty split.
    """
    n_rows = len(data)
    bounds = {
        'train': (0, int(.8 * n_rows)),
        'val': (int(.8 * n_rows), int(.9 * n_rows)),
        'test': (int(.9 * n_rows), n_rows),
    }
    if split not in bounds:
        raise ValueError("Invalid split name. Choose from 'train', 'val', or 'test'.")

    start, stop = bounds[split]
    batch_data = data[start:stop]
    if batch_data.size(0) == 0:
        raise ValueError(f"Split '{split}' is empty for a dataset of {n_rows} rows.")

    # torch.randint's upper bound is exclusive: use size(0), not size(0) - 1,
    # so the last row can be sampled too.
    ix = torch.randint(0, batch_data.size(0), (batch_size,))
    rows = batch_data[ix].long()

    # Inputs and targets must be the same length: x = tokens[:-1], y = tokens[1:].
    # .contiguous() so callers can safely .view() the targets.
    x = rows[:, :-1].contiguous().to(config["device"])
    y = rows[:, 1:].contiguous().to(config["device"])
    return x, y

# Full Llama Model

In [None]:
class Llama(nn.Module):
    """Decoder-only transformer: token embedding -> N AttentionBlocks -> linear head.

    Args:
        config: dict with at least ``vocab_size``, ``dim``, ``n_layers``,
            ``n_heads``, ``max_seq_len`` and ``device`` (plus the keys that
            ``AttentionBlock`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config['vocab_size'], config['dim'])
        self.llama_blocks = nn.Sequential(
            OrderedDict([(f"llama_{i}", AttentionBlock(config)) for i in range(config['n_layers'])])
        )

        self.output_layer = nn.Linear(config['dim'], config['vocab_size'])
        # Rotary-embedding table for positions 0 .. max_seq_len - 1. A plain
        # attribute (not in state_dict), so existing checkpoints load unchanged.
        self.freqs_complex = precompute_theta_pos_frequencies(
            head_dim=config['dim'] // config['n_heads'],
            seq_len=config['max_seq_len'],
            device=config['device']
        )

    def forward(self, idx, targets=None, start_pos=0):
        """Run the model on a batch of token ids.

        Args:
            idx: int64 token ids, assumed shape ``(batch, seq_len)``.
            targets: optional next-token ids, assumed shape ``(batch, t)`` with
                ``t <= seq_len``.
            start_pos: position of ``idx[:, 0]``; forwarded to every block.

        Returns:
            ``logits`` of shape ``(batch, seq_len, vocab_size)`` when ``targets``
            is None; otherwise ``(logits, loss)`` where logits are truncated to
            ``t`` positions and ``loss`` is the mean cross-entropy.
        """
        x = self.embeddings(idx)

        # The table was built on config['device']; follow the input if the model
        # has since been moved (no-op when already on the same device).
        freqs_complex = self.freqs_complex.to(x.device)

        # No detach between blocks: detaching here cut the autograd graph, so the
        # embeddings and every block but the last never received gradients.
        for llama_block in self.llama_blocks:
            x = llama_block(x, start_pos, freqs_complex)

        logits = self.output_layer(x)

        if targets is None:
            return logits

        # Align logits with the (possibly shorter) target sequence.
        logits = logits[:, :targets.size(1), :].contiguous()
        # reshape (not view) so non-contiguous target slices are accepted.
        loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.reshape(-1))
        return logits, loss

# Rotary Positional Embeddings (RoPE)

In [None]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Build the rotary-embedding table of unit complex numbers.

    Entry ``[m, k]`` is ``exp(i * m * theta ** (-2k / head_dim))``: position ``m``
    rotated at the ``k``-th frequency.

    Args:
        head_dim: per-head feature size; must be even.
        seq_len: number of positions to precompute.
        device: device for the result.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim // 2)``.
    """
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"

    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = (1.0 / (theta ** exponents)).to(device)

    positions = torch.arange(seq_len, device=device)
    angles = torch.outer(positions, inv_freq).float()

    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate query/key features by position-dependent angles (RoPE).

    Adjacent feature pairs ``(x[..., 2k], x[..., 2k+1])`` are treated as complex
    numbers and multiplied by ``freqs_complex[pos, k]``.

    Args:
        x: tensor of shape ``(batch, seq_len, n_heads, head_dim)``; ``head_dim``
            must be even.
        freqs_complex: complex tensor of shape ``(>= seq_len, head_dim // 2)``;
            row ``i`` is applied to sequence position ``i`` of ``x``.
        device: device to place the result on.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``x`` is not 4-D, or ``freqs_complex`` has fewer rows
            than ``seq_len``. Such tables used to be silently tiled, giving
            later positions the angles of earlier ones.
    """
    if x.dim() != 4:
        raise ValueError(
            f"Expected x of shape (batch, seq_len, n_heads, head_dim), got {tuple(x.shape)}"
        )
    batch_size, seq_len, n_heads, head_dim = x.shape

    if freqs_complex.shape[0] < seq_len:
        raise ValueError(
            f"freqs_complex covers {freqs_complex.shape[0]} positions but the input has "
            f"{seq_len}; precompute the table for at least max_seq_len positions."
        )

    # (B, S, H, D) -> (B, S, H, D/2) complex view of adjacent feature pairs.
    x_complex = torch.view_as_complex(x.float().reshape(batch_size, seq_len, n_heads, -1, 2))
    # (S, D/2) -> (1, S, 1, D/2) so it broadcasts over batch and heads.
    freqs = freqs_complex[:seq_len].to(x_complex.device).unsqueeze(0).unsqueeze(2)

    x_out = torch.view_as_real(x_complex * freqs).reshape(batch_size, seq_len, n_heads, head_dim)
    return x_out.type_as(x).to(device)

# Transformer Block (Attention + Feed-Forward)

In [None]:
class AttentionBlock(nn.Module):
    """Pre-norm transformer block: ``h = x + Attn(RMSNorm(x))``, ``out = h + FFN(RMSNorm(h))``."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config["n_heads"]
        self.dim = config["dim"]
        self.head_dim = self.dim // self.n_heads
        self.attention = SelfAttention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config["dim"], eps=config["norm_eps"])
        self.ffn_norm = RMSNorm(config["dim"], eps=config["norm_eps"])

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Apply the block.

        Args:
            x: ``(batch, seq_len, dim)`` activations, or a per-head layout
                ``(batch, n_heads, seq_len, head_dim)`` that is merged first.
            start_pos: absolute position of the first token (passed to attention).
            freqs_complex: rotary table passed to attention.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.

        Raises:
            ValueError: if a 4-D input does not match ``n_heads``/``head_dim``.
        """
        if x.dim() == 4:
            batch_size, num_heads, seq_len, head_dim = x.shape
            if num_heads != self.n_heads or head_dim != self.head_dim:
                raise ValueError(f"Unexpected shape in attention block: {tuple(x.shape)}")
            # (B, H, S, D) -> (B, S, H, D) -> (B, S, H*D). The transpose is needed:
            # viewing (B, H, S, D) directly as (B, S, H*D) mixes positions and heads.
            x = x.transpose(1, 2).reshape(batch_size, seq_len, self.dim)

        # Residual around pre-normalised attention, then around pre-normalised FFN.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
        return h + self.feed_forward(self.ffn_norm(h))

## RMS Norm

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: size of the last dimension (length of the gain vector).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # x / sqrt(mean(x^2) + eps), computed via a multiply by rsqrt.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor):
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        return self.weight * self._norm(x.float()).type_as(x)

## Self-Attention

In [None]:
class SelfAttention(nn.Module):
    """Causal multi-head self-attention with shared K/V heads, RoPE and a KV cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; each K/V head
    is repeated ``n_heads // n_kv_heads`` times before the attention product.
    """

    def __init__(self, config):
        super().__init__()
        # Number of query heads (n_heads_q) and key/value heads (n_kv_heads).
        self.n_kv_heads = config["n_kv_heads"] if config["n_kv_heads"] is not None else config["n_heads"]
        self.n_heads_q = config["n_heads"]
        self.n_rep = self.n_heads_q // self.n_kv_heads  # how many times to repeat K and V heads
        self.head_dim = config["dim"] // self.n_heads_q

        # Linear projections for Q, K, V and the output.
        self.wq = nn.Linear(config["dim"], self.n_heads_q * self.head_dim, bias=False)
        self.wk = nn.Linear(config["dim"], self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(config["dim"], self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads_q * self.head_dim, config["dim"], bias=False)

        # Key/value cache read when start_pos > 0. Plain tensor attributes (not
        # parameters or registered buffers), so they are not part of state_dict.
        self.cache_k = torch.zeros((config["max_batch_size"], config["max_seq_len"], self.n_kv_heads, self.head_dim))
        self.cache_v = torch.zeros((config["max_batch_size"], config["max_seq_len"], self.n_kv_heads, self.head_dim))

    @staticmethod
    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat each K/V head ``n_rep`` times along the head axis."""
        batch_size, seq_len, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, :, None, :]  # (B, Seq_Len, N_KV_Heads, 1, Head_Dim)
            .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)  # (B, Seq_Len, N_KV_Heads, N_Rep, Head_Dim)
            .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)  # (B, Seq_Len, N_KV_Heads * N_Rep, Head_Dim)
        )

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Attend over the current tokens (and cached ones when ``start_pos > 0``).

        Args:
            x: ``(batch, seq_len, dim)`` activations.
            start_pos: absolute position of ``x[:, 0]``.
            freqs_complex: rotary table indexed by absolute position.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        batch_size, seq_len, _ = x.shape

        # Project and split into heads: (B, S, H, D).
        xq = self.wq(x).view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = self.wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # Rotate with the angles of absolute positions start_pos .. start_pos+seq_len-1;
        # taking rows from 0 would give every decoding step position 0's angles.
        freqs = freqs_complex[start_pos:start_pos + seq_len]
        xq = apply_rotary_embeddings(xq, freqs, device=x.device)
        xk = apply_rotary_embeddings(xk, freqs, device=x.device)

        # Move the cache to x's device before writing, and store *detached* K/V.
        # Writing grad-carrying tensors into this long-lived tensor would chain
        # every later forward to this step's autograd graph.
        self.cache_k = self.cache_k.to(x.device)
        self.cache_v = self.cache_v.to(x.device)
        self.cache_k[:batch_size, start_pos:start_pos + seq_len] = xk.detach()
        self.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv.detach()

        if start_pos == 0:
            # The whole prefix is in this call: use the fresh K/V so gradients
            # flow through wk/wv during training.
            keys, values = xk, xv
        else:
            keys = self.cache_k[:batch_size, :start_pos + seq_len]
            values = self.cache_v[:batch_size, :start_pos + seq_len]

        # Repeat K and V heads to match the number of query heads.
        keys = self.repeat_kv(keys, self.n_rep)
        values = self.repeat_kv(values, self.n_rep)

        # (B, S, H, D) -> (B, H, S, D) for batched matmul.
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Causal mask: the query at absolute position start_pos + i may only see keys
        # at positions <= start_pos + i. Without it, training lets each token attend
        # to the very tokens it is asked to predict.
        total_len = keys.size(2)
        causal_mask = torch.ones(
            seq_len, total_len, dtype=torch.bool, device=scores.device
        ).triu(diagonal=start_pos + 1)
        scores = scores.masked_fill(causal_mask, float("-inf"))

        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)

        # (B, H, S, D) -> (B, S, H*D) -> output projection.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)


## Feed Forward Neural Network

In [None]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1 x) * w3 x)``.

    The hidden width is ``int(2 * 4 * dim / 3)``. Only when ``ffn_dim_multiplier``
    is set is it scaled and then rounded up to a multiple of ``multiple_of``.
    NOTE(review): reference Llama implementations typically round to
    ``multiple_of`` unconditionally. Confirm before changing, because doing so
    alters layer shapes and breaks existing checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        dim = config["dim"]
        hidden_dim = int(2 * (4 * dim) / 3)

        multiplier = config["ffn_dim_multiplier"]
        if multiplier is not None:
            step = config["multiple_of"]
            scaled = int(multiplier * hidden_dim)
            hidden_dim = step * ((scaled + step - 1) // step)  # round up to a multiple of step

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gated branch (through SiLU)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # projection back to dim
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear branch

    def forward(self, x: torch.Tensor):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))

# Training Function

In [None]:
def train(model, optimizer, config=MASTER_CONFIG, epochs=10, print_logs=True, data=None, log_every=50):
    """Train ``model`` on randomly sampled next-token batches from the train split.

    Args:
        model: module called as ``model(xs, ys) -> (logits, loss)``.
        optimizer: optimizer over ``model``'s parameters.
        config: dict providing ``"batch_size"`` and ``"device"``; also passed
            to ``get_batches``.
        epochs: number of epochs.
        print_logs: print periodic step losses and the per-epoch mean loss.
        data: token tensor to sample from. Defaults to the global ``dataset``
            built by the data-loading cell, for backward compatibility.
        log_every: print the step loss every ``log_every`` steps (0 disables).

    Returns:
        List with the mean training loss of each epoch.
    """
    if data is None:
        data = dataset  # global defined by the data-loading cell
    batch_size = config['batch_size']

    # get_batches samples only from the first 80% of rows, so size an "epoch"
    # by that split rather than by the whole dataset.
    n_train = int(.8 * len(data))
    steps_per_epoch = max(1, n_train // batch_size)

    history = []
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for step in range(steps_per_epoch):
            xs, ys = get_batches(data, 'train', batch_size, config=config)
            logits, loss = model(xs, ys)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            total_loss += step_loss
            if print_logs and log_every and step % log_every == 0:
                print(f"Epoch {epoch+1}/{epochs} - Step {step}/{steps_per_epoch} - Loss: {step_loss:.4f}")

            # Drop references promptly; activations for this step can be freed.
            del xs, ys, logits, loss

        avg_loss = total_loss / steps_per_epoch
        history.append(avg_loss)
        if print_logs:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    return history

# Training Driver

In [None]:
# aristotle_llama_model = Llama(MASTER_CONFIG).to(device)
# optimizer = torch.optim.AdamW(aristotle_llama_model.parameters(), lr=3e-4)
# train(aristotle_llama_model, optimizer, config=MASTER_CONFIG)
# torch.save(aristotle_llama_model.state_dict(), 'aristotle_llama_model.pth')

# Inference

In [None]:
def generate_response(model, tokenizer, prompt, max_length=512, device='cpu'):
    """Greedily continue ``prompt`` with ``model`` and return the decoded text.

    Args:
        model: module with a ``config['max_seq_len']`` entry, called as
            ``model(ids) -> logits`` with logits shaped ``(batch, seq, vocab)``.
        tokenizer: SentencePiece-like object providing ``encode_as_ids``,
            ``pad_id``, ``eos_id`` and ``decode``.
        prompt: input text; truncated to ``max_seq_len`` tokens.
        max_length: maximum number of new tokens to generate.
        device: device to run on.

    Returns:
        Prompt plus continuation as text, with padding ids removed.

    Raises:
        ValueError: if the prompt encodes to zero tokens.
    """
    model.eval()
    model.to(device)

    max_seq_len = model.config['max_seq_len']
    input_tokens = tokenizer.encode_as_ids(prompt)[:max_seq_len]
    if not input_tokens:
        raise ValueError("Prompt encodes to no tokens; cannot start generation.")
    generated = torch.tensor(input_tokens, dtype=torch.long, device=device).unsqueeze(0)

    # Stop on padding (which ends sequences in this padded dataset) or on EOS.
    stop_ids = {tokenizer.pad_id(), tokenizer.eos_id()}

    with torch.no_grad():
        for _ in range(max_length):
            logits = model(generated)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat((generated, next_token), dim=1)

            if next_token.item() in stop_ids or generated.size(1) >= max_seq_len:
                break

    # Drop padding ids before decoding.
    pad = tokenizer.pad_id()
    output_tokens = [token for token in generated[0].tolist() if token != pad]
    return tokenizer.decode(output_tokens)


# Load the model and its trained weights. map_location lets a checkpoint that was
# saved on a GPU be restored on whichever `device` is available in this session.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Consider weights_only=True if the checkpoint holds only tensors and containers.
model = Llama(MASTER_CONFIG).to(device)
checkpoint = torch.load(
    '/content/drive/MyDrive/philosophy_data/checkpoint_step_340.pth',
    map_location=device,
)
model.load_state_dict(checkpoint['model_state_dict'])

# Greedy generation from a short prompt.
prompt = "Hi Aristotle"
response = generate_response(model, spm_model, prompt, max_length=512, device=device)
print(response)

  checkpoint = torch.load('/content/drive/MyDrive/philosophy_data/checkpoint_step_340.pth')


Hi Aristotleasant objection of injustice present advantageate soundstection strength cannot be indivisible line political friendshipstection strength cannot be indivisible line political friendshipstection strength into two footed him treatmentstexious temperance with regard to be indivisible line political friendshipally significal question whether So much for instance women feeling him treatment ashenderical question whether So much for instance women feeling him treatment ashenderical question whether So much for instance women feeling him treatment ashenderious temperance with regard to be indivisible line political friendshipally significal question whether So much for instance women feeling him treatment occurroundarent exercise Oneaminewoundarent exercise Oneaminewoundarent exercise Oneaminewoundarent exercise Oneaminewoundarent exercise Oneaminewoundarent exercise Oneaminewoundarent exercise Oneaminewoundarent exercise Oneaminewoundarent exercise Oneaminewound about next questi

# Precompute Rotary Embedding Frequencies

In [None]:
# Standalone RoPE table for inspection; Llama builds its own copy in __init__.
freqs_complex = precompute_theta_pos_frequencies(
    MASTER_CONFIG['dim'] // MASTER_CONFIG['n_heads'],  # head_dim
    MASTER_CONFIG['max_seq_len'],                       # seq_len
    MASTER_CONFIG['device'],                            # device
)