## Setup

In [None]:
# Imports

import torch
from torch.nn import DataParallel
import torch.nn.functional as F
from dotenv import load_dotenv
import wandb
from accelerate import Accelerator, init_empty_weights, load_checkpoint_and_dispatch
from helpers.memory import check_memory

# Load environment variables from the local secrets file
load_dotenv('secrets.env')

# Device that tokenized inputs are sent to.
main_device = 'cuda:0'
# Later cells refer to this device as `device`; define the alias here so the
# notebook survives Restart Kernel -> Run All instead of raising NameError.
device = main_device

# Project helper from helpers.memory (presumably reports current memory usage - confirm)
check_memory()

In [None]:
# Test Multi-GPU Access

# Verify Pytorch can communicate with each GPU
def check_communication():
    """
    Smoke-test every CUDA device visible to PyTorch.

    For each GPU index, a small tensor is created on that device, a trivial
    computation is run on it, and the result is read back to the host. The
    outcome is printed per device; failures are reported rather than raised so
    that one bad GPU does not hide the status of the others.

    Returns:
        None. All results are printed.
    """
    n_devices = torch.cuda.device_count()
    if n_devices == 0:
        # Without this the function would print nothing at all, which is easy
        # to misread as "all good".
        print("No CUDA devices visible to PyTorch.")
        return

    for i in range(n_devices):
        device = torch.device(f"cuda:{i}")
        try:
            x = torch.tensor([1.0, 2.0, 3.0], device = device)
            # Actually compute on the device and copy the result back, so the
            # message below reflects a real computation and not just an allocation.
            total = (x * 2).sum().item()
            if total != 12.0:
                raise RuntimeError(f"unexpected result {total}")
            print(f"GPU {i}: Computation successful.")
        except Exception as e:
            print(f"GPU {i}: Computation failed. Error: {e}")

check_communication()

# Initialize accelerator: collect the settings first, then construct it
accelerator_config = dict(
    device_placement = True,
    mixed_precision = 'bf16',
    gradient_accumulation_steps = 1, # Temp
    split_batches = False
)
accelerator = Accelerator(**accelerator_config)

## Test Inference with Base HF Model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hub id shared by the tokenizer and the model
MODEL_ID = 'allenai/OLMoE-1B-7B-0924'

# Tokenizer: no automatic BOS/EOS insertion, padding on the right
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'right'
)

# Same 20GB memory cap for every visible GPU
per_gpu_memory = dict.fromkeys(range(torch.cuda.device_count()), "20GB")

# Load the weights in bfloat16 with automatic device placement under that cap
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map = 'auto',
    max_memory = per_gpu_memory,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True
)

check_memory()

In [None]:
# Inspect the expert modules of one decoder layer. Index the layer explicitly:
# the bare name `layer` is only bound by loops further down the notebook, so
# relying on it fails on a fresh kernel. [-1] matches the value that loop variable
# would hold after those loops finish.
model.model.layers[-1].mlp.experts

In [None]:
# Test with model.generate()

@torch.no_grad()
def eval_model_v1(model, tokenizer, prompt):
    """
    Greedy-decode up to 32 new tokens for `prompt` via `model.generate`.

    Params:
        model: Causal LM exposing `.device` and `.generate(...)`.
        tokenizer: Tokenizer matching `model`; must expose `eos_token_id` and `batch_decode`.
        prompt: Text to continue.

    Returns:
        The first sequence returned by `generate`, decoded to a string.
    """
    # Send the inputs to the model's own device rather than relying on a
    # notebook-level `device` global, which is not defined on a fresh kernel.
    tokens = tokenizer(prompt, return_tensors = "pt").to(model.device)
    res = model.generate(
        **tokens,
        max_new_tokens = 32,
        do_sample = False, # Greedy decoding
        eos_token_id = [tokenizer.eos_token_id]
    )
    print(res)
    return tokenizer.batch_decode(res)[0]

# Run the generate()-based evaluation on the sample prompt and show the decoded text
v1_prompt = 'I am a dog and I like to eat. My favorite food is'
v1_completion = eval_model_v1(model, tokenizer, v1_prompt)
print(v1_completion)

In [None]:
# Test with token-by-token generation
 
@torch.no_grad()
def eval_model_v2(model, tokenizer, prompt, max_new_tokens = 1):
    """
    Greedy token-by-token generation using raw forward passes (no `generate`).

    Params:
        model: Causal LM exposing `.device`; calling it with a (1, N) tensor of
            token ids must return a mapping with a 'logits' entry indexable as
            [batch, position, vocab].
        tokenizer: Tokenizer matching `model`.
        prompt: Text to continue (a single sequence).
        max_new_tokens: Maximum number of tokens to append. Defaults to 1, which
            matches the previous hard-coded single step.

    Returns:
        The decoded text of the prompt plus the generated tokens.
    """
    # Use the model's own device; a notebook-level `device` global is not defined on a fresh kernel
    tokens = tokenizer(prompt, return_tensors = 'pt').to(model.device)['input_ids']

    # Resolve the stop ids once; drop None in case a stop token is not in the vocabulary
    stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")} - {None}

    for _ in range(max_new_tokens):
        logits = model(tokens)['logits']
        # Index the last position explicitly instead of squeeze(), which collapses the
        # sequence axis for one-token prompts. Softmax is monotonic, so argmax over the
        # raw logits selects the same token. Move the result next to `tokens` because a
        # sharded model may emit logits on a different device.
        output_token = torch.argmax(logits[0, -1, :]).to(tokens.device)
        print(output_token)
        tokens = torch.cat((tokens, output_token.view(1, 1)), dim = 1)

        if output_token.item() in stop_ids:
            break

    print(tokens)
    return tokenizer.batch_decode(tokens)[0]

# Run the manual forward-pass evaluation on the sample prompt and show the decoded text
v2_prompt = 'I am a dog and I like to eat. My favorite food is'
v2_completion = eval_model_v2(model, tokenizer, v2_prompt)
print(v2_completion)

## Reverse Engineer the Class

In [None]:
prompt = 'I am a dog and I like to eat. My favorite food is' # Correct next token output is 'steak'
# Place the inputs on the model's own device; a bare `device` global is not
# defined anywhere above, so using it fails on a fresh kernel.
inputs = tokenizer(prompt, return_tensors = 'pt').to(model.device)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

In [None]:
# Run the decoder stack and the LM head as two separate steps
with torch.no_grad():
    # Transformer body: everything before the LM head
    decoder_output = model.model(input_ids, attention_mask)['last_hidden_state']
    # LM head applied to the final hidden states
    output_logits = model.lm_head(decoder_output)

# Highest-scoring token id at every position of the first sequence; decode the last one
output_ids = output_logits[0].argmax(dim = 1)
print(tokenizer.decode(output_ids[-1]))

In [None]:
# Split off LM head
# NOTE(review): this cell performs the same computation as the preceding
# "Split off LM head" cell; one of the two can likely be removed.
with torch.no_grad():

    # Everything before the LM head
    body_outputs = model.model(input_ids, attention_mask)
    decoder_output = body_outputs['last_hidden_state']

    # The LM head
    output_logits = model.lm_head(decoder_output)


# Per-position argmax over the vocabulary for the first sequence, then decode the final position
first_sequence_logits = output_logits[0]
output_ids = torch.argmax(first_sequence_logits, dim = -1)
print(tokenizer.decode(output_ids[-1]))

In [None]:
# Split out decoder object into component transformer layers
with torch.no_grad():

    decoder = model.model

    # Token embeddings; unpack the three dimensions of the result
    embeds_output = decoder.embed_tokens(input_ids)
    B, N, D = embeds_output.shape

    # Positions 0..N-1, with a leading batch axis added for position_ids
    cache_position = torch.arange(N, device = embeds_output.device)
    position_ids = cache_position.unsqueeze(0)

    # Mask produced by the model's own helper from the tokenizer's attention mask
    causal_mask = decoder._update_causal_mask(attention_mask, embeds_output, cache_position, None, False)

    # Rotary position embeddings, computed once and reused by every decoder layer
    position_embeddings = decoder.rotary_emb(embeds_output, position_ids)

    # Feed the hidden state through each decoder layer in turn.
    # Caching / intermediate-output arguments are left at their defaults;
    # only the first element of each layer's output is carried forward.
    hidden_state = embeds_output
    for layer in decoder.layers:
        hidden_state = layer(
            hidden_state,
            causal_mask,
            position_ids,
            position_embeddings = position_embeddings
        )[0]

    # Final norm, then the LM head
    hidden_state = decoder.norm(hidden_state)
    output_logits = model.lm_head(hidden_state)


# Per-position argmax for the first sequence; decode the prediction at the last position
output_ids = output_logits[0].argmax(dim = 1)
print(tokenizer.decode(output_ids[-1]))


In [None]:
# Display the model object (its repr) as the cell output
model

In [None]:
# Further split out the transformer layers into self-attention/MLP components
with torch.no_grad():

    # Embedding layer
    embeds_output = model.model.embed_tokens(input_ids)
    B, N, D = embeds_output.shape

    # Positions 0..N-1, with a leading batch axis added for position_ids
    cache_position = torch.arange(0, N, device = embeds_output.device)
    position_ids = cache_position.unsqueeze(0)

    causal_mask = model.model._update_causal_mask(attention_mask, embeds_output, cache_position, None, False)

    position_embeddings = model.model.rotary_emb(embeds_output, position_ids) # Position embeddings to be shared across the decoder layers

    hidden_state = embeds_output

    # Now iterate through the layers
    for i, layer in enumerate(model.model.layers):

        ### SA ###
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        # Only the first returned element (the attention output) is needed. Index it
        # instead of unpacking a fixed-length tuple so the cell does not depend on how
        # many extra values self_attn returns.
        hidden_state = layer.self_attn(
            hidden_state,
            attention_mask = causal_mask,
            position_ids = position_ids,
            position_embeddings = position_embeddings
        )[0]
        # The layer may live on a different device than the residual stream (the model is
        # loaded with device_map = 'auto'); move the output next to the residual before
        # adding. This avoids the undefined notebook-level `device` name.
        hidden_state = residual + hidden_state.to(residual.device)

        ### MLP ###
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        hidden_state, router_logits = layer.mlp(hidden_state)
        hidden_state = residual + hidden_state.to(residual.device)

    # RMS Norm
    hidden_state = model.model.norm(hidden_state)

    # The LM head
    output_logits = model.lm_head(hidden_state)


# Per-position argmax for the first sequence; decode the prediction at the last position
output_ids = torch.argmax(output_logits[0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))

In [None]:
# Further split out the self-attention & MLP components
from helpers.olmoe import apply_rotary_pos_emb
import math

with torch.no_grad():

    # Embedding layer
    embeds_output = model.model.embed_tokens(input_ids)
    B, N, D = embeds_output.shape

    # Positions 0..N-1, with a leading batch axis added for position_ids
    cache_position = torch.arange(0, N, device = embeds_output.device)
    position_ids = cache_position.unsqueeze(0)

    causal_mask = model.model._update_causal_mask(attention_mask, embeds_output, cache_position, None, False)
    if causal_mask is None or causal_mask.dim() != 4:
        # NOTE(review): depending on the attention backend the helper may not hand back an
        # explicit 4-D additive mask - confirm against the installed transformers version.
        # Fall back to a plain causal mask: large negative values strictly above the diagonal.
        # This ignores padding, which is fine for the single unpadded prompt used here.
        mask_value = torch.finfo(embeds_output.dtype).min
        causal_mask = torch.full((N, N), mask_value, device = embeds_output.device).triu(diagonal = 1)
        causal_mask = causal_mask.to(embeds_output.dtype)[None, None, :, :] # 1 x 1 x N x N

    position_embeddings = model.model.rotary_emb(embeds_output, position_ids) # Position embeddings to be shared across the decoder layers

    hidden_state = embeds_output

    # Now iterate through the layers
    for i, layer in enumerate(model.model.layers):

        residual = hidden_state

        hidden_state = layer.input_layernorm(hidden_state)

        ### SA - replicates SDPA vers, not flash attention ###
        H = layer.self_attn.num_heads # Number of attention heads
        Dh = D // H # Dimensions per head

        # Project, then split the feature axis into heads: B x N x D -> B x N x H x Dh -> B x H x N x Dh
        query_state = layer.self_attn.q_norm(layer.self_attn.q_proj(hidden_state)).view(B, N, H, Dh).transpose(1, 2)
        key_state = layer.self_attn.k_norm(layer.self_attn.k_proj(hidden_state)).view(B, N, H, Dh).transpose(1, 2)
        value_state = layer.self_attn.v_proj(hidden_state).view(B, N, H, Dh).transpose(1, 2)

        # The rotary terms were computed once on the embedding device; move them next to this layer's tensors
        cos, sin = position_embeddings
        query_state, key_state = apply_rotary_pos_emb(query_state, key_state, cos.to(query_state.device), sin.to(query_state.device))

        attn_weights = torch.matmul(query_state, key_state.transpose(2, 3))/math.sqrt(Dh)  # Should be shape B x H x N x N
        # Add the causal mask, NOT the tokenizer's padding `attention_mask`: the latter only
        # shifts every score by a constant, so future positions would remain visible.
        attn_weights = attn_weights + causal_mask.to(attn_weights.device)

        # Softmax in float32, then cast back to the working dtype
        attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(query_state.dtype)

        sa_output = torch.matmul(attn_weights, value_state) # B x H x N x Dh
        sa_output = sa_output.transpose(1, 2).contiguous() # Reorder into B x N x H x Dh
        sa_output = sa_output.reshape(B, N, D) # Merge the heads back into B x N x D

        # Final post-concatenation linear layer
        sa_output = layer.self_attn.o_proj(sa_output)

        # Move the output next to the residual before adding (layers may sit on different
        # devices under device_map = 'auto'); avoids the undefined `device` global.
        hidden_state = residual + sa_output.to(residual.device)

        ### MLP ###
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        hidden_state, router_logits = layer.mlp(hidden_state)
        hidden_state = residual + hidden_state.to(residual.device)

    # RMS Norm
    hidden_state = model.model.norm(hidden_state)

    # The LM head
    output_logits = model.lm_head(hidden_state)


# Per-position argmax for the first sequence; decode the prediction at the last position
output_ids = torch.argmax(output_logits[0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))

In [None]:
# Display the rotary position embeddings left over from the last decomposition cell that ran
position_embeddings

In [None]:
# Show which device the embedding output tensor lives on
embeds_output.device