# Play around with a trained model from `outputs/`

In [1]:
# Import necessary libraries
import os
import json
import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import math
from tqdm.notebook import tqdm
from transformers import AutoTokenizer

# Import our local modules
import sys
sys.path.append(".")  # Add the root directory to path
from lmtraining.config import Config, ModelConfig
from lmtraining.models.transformer import TransformerModel
from lmtraining.data.dataset import create_dataloaders

In [2]:
# Pick the compute device once: GPU when CUDA is usable, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [11]:
# Load the saved model
def load_model_from_checkpoint(checkpoint_dir):
    """Load a trained model and its configuration from a checkpoint directory.

    The directory must contain ``config.json`` (parsed with
    ``ModelConfig.from_json``) and ``pytorch_model.bin`` (a state dict for
    ``TransformerModel``). If ``training_args.json`` is present it is loaded
    and printed for reference only; it does not influence the returned model.

    Args:
        checkpoint_dir: Path to the checkpoint directory.

    Returns:
        A ``(model, model_config)`` tuple. The model has been moved to the
        notebook-level ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If the directory, ``config.json`` or
            ``pytorch_model.bin`` does not exist.
    """
    if not os.path.exists(checkpoint_dir):
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")

    # Validate every required file up front so a missing weights file is
    # reported before any time is spent parsing config or building the model.
    config_path = os.path.join(checkpoint_dir, "config.json")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    weights_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights file not found: {weights_path}")

    model_config = ModelConfig.from_json(config_path)

    # Training arguments are optional and purely informational.
    training_args_path = os.path.join(checkpoint_dir, "training_args.json")
    if os.path.exists(training_args_path):
        with open(training_args_path, 'r', encoding="utf-8") as f:
            training_args = json.load(f)
        print(f"Loaded training arguments: {training_args}")

    model = TransformerModel(model_config)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()  # inference mode: disables dropout and similar train-only layers

    return model, model_config

In [4]:
# Load training logs to analyze training progress
def load_and_parse_logs(output_dir):
    """Collect the recorded metric of every ``checkpoint-<step>`` directory.

    For each step-numbered checkpoint under ``output_dir`` the function reads
    ``optimizer.pt`` and, if it contains a ``"best_metric"`` entry, records
    that value as ``loss`` together with ``exp(loss)`` as ``perplexity``
    (reported as ``inf`` once the metric reaches 20, to avoid huge values).

    Entries matching ``checkpoint-*`` whose suffix is not an integer (for
    example ``checkpoint-best``) are skipped instead of aborting the scan.

    Args:
        output_dir: Directory that contains the ``checkpoint-*`` folders.

    Returns:
        A ``pandas.DataFrame`` with columns ``step``, ``loss`` and
        ``perplexity`` ordered by step, or ``None`` (after printing a
        message) when no checkpoint provided a metric.
    """
    # Pair each checkpoint path with its numeric step, ignoring anything
    # that is not step-numbered.
    stepped_dirs = []
    for path in glob.glob(os.path.join(output_dir, "checkpoint-*")):
        suffix = os.path.basename(path).rsplit("-", 1)[-1]
        try:
            stepped_dirs.append((int(suffix), path))
        except ValueError:
            continue
    stepped_dirs.sort(key=lambda item: item[0])

    training_data = []
    for step, checkpoint_dir in stepped_dirs:
        optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt")
        if not os.path.exists(optimizer_path):
            continue

        # NOTE(review): torch.load unpickles the file and deserializes the
        # whole optimizer state just to read one scalar; only point this at
        # checkpoints you trust.
        optimizer_data = torch.load(optimizer_path, map_location="cpu")
        if "best_metric" not in optimizer_data:
            continue

        best_metric = optimizer_data["best_metric"]
        perplexity = math.exp(best_metric) if best_metric < 20 else float('inf')
        training_data.append({
            "step": step,
            "loss": best_metric,
            "perplexity": perplexity
        })

    if not training_data:
        print("No training logs found.")
        return None
    return pd.DataFrame(training_data)

In [5]:
# Function to evaluate the model on text
def evaluate_perplexity(model, tokenizer, text, context_length=512):
    """Estimate the model's perplexity on ``text``.

    The text is tokenized and scored in windows of at most ``context_length``
    tokens that start every ``context_length // 2`` tokens. Each window's
    loss is weighted by the window length, and the weighted sum is divided by
    the total number of tokens scored, so overlapping windows do not inflate
    the average. Texts shorter than one stride are scored as a single window.

    Args:
        model: Model called as ``model(input_ids=..., labels=...)`` that
            returns either a dict with a ``"loss"`` entry or a sequence whose
            first element is the loss.
        tokenizer: Object whose ``encode(text)`` returns a list of token ids.
        text: Text to score.
        context_length: Maximum number of tokens per window (at least 2).

    Returns:
        The perplexity as a float, or ``inf`` when the text has fewer than
        two tokens and therefore nothing to predict.

    Raises:
        ValueError: If ``context_length`` is smaller than 2.
    """
    if context_length < 2:
        raise ValueError("context_length must be at least 2")

    model.eval()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    tokens = tokenizer.encode(text)
    if len(tokens) < 2:
        return float('inf')

    stride = context_length // 2
    total_nll = 0.0
    scored_tokens = 0

    # max(..., 1) guarantees one window even when the text is shorter than
    # a stride; otherwise short texts would never be scored.
    for start in range(0, max(len(tokens) - stride, 1), stride):
        chunk = tokens[start:start + context_length]
        input_ids = torch.tensor([chunk]).to(model_device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, labels=input_ids)

        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # assumes the returned loss is the mean per-token NLL of the window
        # -- TODO confirm against TransformerModel
        total_nll += loss.item() * len(chunk)
        scored_tokens += len(chunk)

    return math.exp(total_nll / scored_tokens)

In [6]:
# Function to generate text from the model
def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0, top_k=50, top_p=0.95):
    """Generate a continuation of ``prompt`` by sampling one token at a time.

    Args:
        model: Model called as ``model(input_ids=...)`` that returns either a
            dict with a ``"logits"`` entry or a sequence whose first element
            is the logits; the last position's logits are used for sampling.
        tokenizer: Tokenizer providing ``encode(prompt, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``.
        prompt: Text to continue (a single sequence).
        max_length: Maximum number of NEW tokens to generate.
        temperature: Divisor applied to the logits before sampling. A value
            of 0 or less selects greedy decoding (argmax) instead of sampling.
        top_k: Keep only the ``top_k`` highest logits; 0 or less disables the
            filter. Values above the vocabulary size are clamped.
        top_p: Nucleus threshold; 1.0 or more disables the filter.

    Returns:
        The decoded text of the prompt followed by the generated tokens.
    """
    model.eval()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    generated = tokenizer.encode(prompt, return_tensors="pt").to(model_device)
    greedy = temperature <= 0  # dividing by a non-positive temperature is meaningless

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids=generated)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]
            next_token_logits = logits[:, -1, :]

            if greedy:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                next_token_logits = next_token_logits / temperature

                # Top-k: drop everything below the k-th best logit.
                if top_k > 0:
                    k = min(top_k, next_token_logits.size(-1))  # topk fails if k > vocab
                    kth_best = torch.topk(next_token_logits, k)[0][..., -1, None]
                    next_token_logits = next_token_logits.masked_fill(
                        next_token_logits < kth_best, -float('Inf')
                    )

                # Top-p (nucleus): keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    sorted_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold
                    # is kept, and always keep the most likely token.
                    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
                    sorted_remove[..., 0] = False

                    # Map the mask from sorted order back to vocabulary order
                    # (row-wise, so it is not tied to batch index 0).
                    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
                    next_token_logits = next_token_logits.masked_fill(remove, -float('Inf'))

                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat((generated, next_token), dim=1)

            # Stop once the end-of-sequence token is produced
            # (.item() assumes a single sequence).
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated[0], skip_special_tokens=True)

In [7]:
# Plot training progress
def plot_training_progress(df):
    """Draw loss and perplexity against training step in two side-by-side panels.

    Args:
        df: Table with ``step``, ``loss`` and ``perplexity`` columns. When it
            is ``None`` or empty, a message is printed and nothing is drawn.
    """
    has_rows = df is not None and len(df) != 0
    if not has_rows:
        print("No data to plot.")
        return

    figure, axes = plt.subplots(1, 2, figsize=(18, 6))

    # (column to plot, panel title, y-axis label) for the left and right panel
    panels = (
        ("loss", "Training Loss Over Time", "Loss"),
        ("perplexity", "Perplexity Over Time", "Perplexity"),
    )
    for axis, (column, title, y_label) in zip(axes, panels):
        sns.lineplot(x="step", y=column, data=df, ax=axis)
        axis.set_title(title)
        axis.set_xlabel("Training Step")
        axis.set_ylabel(y_label)
        axis.grid(True)

    plt.tight_layout()
    plt.show()

In [12]:
# Example usage
# Paths default to the original training run but can be overridden with the
# LM_MODEL_PATH / LM_LOGS_DIR environment variables, so the notebook can run
# on another machine without editing this cell.
model_path = os.environ.get("LM_MODEL_PATH", "/home/marshall/lmodel/outputs/lm-wikitext/best_model")
logs_dir = os.environ.get("LM_LOGS_DIR", "outputs/lm-wikitext")  # This should be the parent directory of all checkpoints

# Load tokenizer
tokenizer_name = "gpt2"  # Update this to match what you used during training
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if tokenizer.pad_token is None:
    # No pad token is defined for this tokenizer; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

# Load model
print(f"Loading model from {model_path}...")
model, model_config = load_model_from_checkpoint(model_path)
print(f"Model loaded: {model_config.hidden_size} hidden size, {model_config.num_hidden_layers} layers")

Loading model from /home/marshall/lmodel/outputs/lm-wikitext/best_model...
Loaded training arguments: {'output_dir': 'outputs/lm-wikitext', 'seed': 42, 'train_batch_size': 8, 'eval_batch_size': 8, 'gradient_accumulation_steps': 4, 'learning_rate': 5e-05, 'weight_decay': 0.01, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 3, 'max_steps': -1, 'warmup_steps': 0, 'warmup_ratio': 0.1, 'logging_steps': 100, 'save_steps': 1000, 'save_total_limit': 3, 'evaluation_strategy': 'steps', 'eval_steps': 500, 'use_amp': True}
Model loaded: 256 hidden size, 6 layers


  model.load_state_dict(torch.load(weights_path, map_location=device))


In [13]:
# Parse the per-checkpoint metrics and preview them when any were found
print(f"Loading training logs from {logs_dir}...")
training_df = load_and_parse_logs(logs_dir)
if training_df is not None:
    print(
        "Training data loaded. Here are the first few rows:",
        training_df.head(),
        sep="\n",
    )

Loading training logs from outputs/lm-wikitext...
No training logs found.
