# Marathi GPT-2 Checkpoint Test

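This notebook loads a saved checkpoint of the Marathi GPT-2 model, runs forward passes on random token IDs and on a tokenized Marathi sentence, and tries greedy text generation from Marathi prompts. Note: with this checkpoint the generated continuations decode to Unicode replacement characters (`�`). In other words, the model emits byte-level tokens that do not form valid UTF-8.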

In [1]:
# Import required libraries
import torch
from marathi_gpt2 import GPT2Config, GPT2Model

In [4]:
# Load the checkpoint.
# The file is handed straight to `load_state_dict`, so it must be a plain state
# dict of tensors. `weights_only=True` therefore costs nothing and stops
# `torch.load` from unpickling arbitrary Python objects, which could execute code
# from an untrusted checkpoint file. (Requires torch >= 1.13.)
ckpt_path = "marathi_gpt2_step_1_100000.pt"  # Change to your checkpoint file if needed
# NOTE(review): default GPT2Config must match the config used for training — confirm.
config = GPT2Config()
model = GPT2Model(config)
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode (affects layers such as dropout, if the model has any)
print("Checkpoint loaded.")

Checkpoint loaded.


In [5]:
# Smoke-test the forward pass on random token IDs.
# A dedicated, seeded generator makes the sampled IDs reproducible on
# Restart & Run All without touching torch's global RNG state.
SEED = 0
seq_len = 10
batch_size = 1
rng = torch.Generator().manual_seed(SEED)
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), generator=rng)
with torch.no_grad():
    logits = model(input_ids)
print("Input shape:", input_ids.shape)
print("Logits shape:", logits.shape)
# Expect one vocabulary-sized score vector per input position.
assert logits.shape == (batch_size, seq_len, config.vocab_size), logits.shape

Input shape: torch.Size([1, 10])
Logits shape: torch.Size([1, 10, 32000])


In [6]:
# Example: encode Marathi text and run it through the model.
# `os` is imported here explicitly. Previously this cell only worked because the
# directory-listing cell had been executed first, so it failed on Restart & Run All.
import os

from tokenizers import ByteLevelBPETokenizer

# Load tokenizer
TOKENIZER_DIR = "data/marathi_bpe_tokenizer"
tokenizer = ByteLevelBPETokenizer(
    os.path.join(TOKENIZER_DIR, "vocab.json"),
    os.path.join(TOKENIZER_DIR, "merges.txt")
)

marathi_prompt = "माझं नाव गीता आहे."
input_ids = torch.tensor([tokenizer.encode(marathi_prompt).ids])
# NOTE(review): the recorded output has 46 IDs for this prompt, which is exactly its
# UTF-8 byte length, so no BPE merges appear to be applied. Check that merges.txt
# is the file produced when the tokenizer was trained.
assert input_ids.numel() > 0, "prompt encoded to zero tokens"
# IDs presumably index the model's embedding table, so they must fit its vocabulary.
assert input_ids.max().item() < config.vocab_size, "token ID exceeds model vocab_size"
with torch.no_grad():
    logits = model(input_ids)
print("Prompt:", marathi_prompt)
print("Input IDs:", input_ids.tolist())
print("Logits shape:", logits.shape)

Prompt: माझं नाव गीता आहे.
Input IDs: [[161, 102, 111, 161, 102, 127, 161, 102, 256, 161, 102, 229, 225, 161, 102, 106, 161, 102, 127, 161, 102, 118, 225, 161, 102, 250, 161, 103, 227, 161, 102, 102, 161, 102, 127, 225, 161, 102, 233, 161, 102, 122, 161, 103, 234, 18]]
Logits shape: torch.Size([1, 46, 32000])


In [3]:
# List all available checkpoint files in the current directory.
# Sorted, because os.listdir returns entries in arbitrary, filesystem-dependent
# order, and the displayed output should be stable across runs.
import os

checkpoint_files = sorted(
    fn for fn in os.listdir(".")
    if fn.startswith("marathi_gpt2") and fn.endswith(".pt")
)
checkpoint_files

['marathi_gpt2_step_1_100000.pt']

In [None]:
# Simple autoregressive text generation example
import os
import torch
from tokenizers import ByteLevelBPETokenizer

# (Re)load the byte-level BPE tokenizer from its vocab + merges files so this
# cell does not depend on earlier cells having run.
TOKENIZER_DIR = "data/marathi_bpe_tokenizer"
tokenizer = ByteLevelBPETokenizer(
    *(os.path.join(TOKENIZER_DIR, fname) for fname in ("vocab.json", "merges.txt"))
)

def generate_text(model, tokenizer, prompt, max_new_tokens=30,
                  eos_token="</s>", max_context=None):
    """Greedily extend ``prompt`` by up to ``max_new_tokens`` tokens.

    At each step the model runs on the current token sequence. The token with
    the highest score at the last position (argmax of ``logits[0, -1, :]``) is
    appended. Generation stops early once ``eos_token`` is produced.

    Args:
        model: Callable that maps a ``(1, T)`` LongTensor of token IDs to logits
            indexable as ``logits[0, -1, :]``. It is switched to eval mode.
        tokenizer: Object providing ``encode(text).ids``, ``decode(ids)`` and
            ``token_to_id(token)`` (e.g. a ``tokenizers`` BPE tokenizer).
        prompt: Text to continue. It must encode to at least one token.
        max_new_tokens: Maximum number of tokens to generate.
        eos_token: Token string that ends generation. If ``token_to_id`` returns
            ``None`` for it, generation always runs ``max_new_tokens`` steps.
        max_context: If given, only the last ``max_context`` tokens are fed to
            the model at each step. Set it to the model's maximum sequence
            length, if it has one. ``None`` feeds the whole sequence.

    Returns:
        ``tokenizer.decode`` of the prompt IDs plus the generated IDs. The EOS
        ID, if generated, is included in the IDs passed to ``decode``.

    Raises:
        ValueError: If the prompt encodes to no tokens, or ``max_context`` < 1.
    """
    if max_context is not None and max_context < 1:
        raise ValueError("max_context must be a positive integer or None")
    model.eval()
    # Copy so the list returned by `.ids` is never mutated in place.
    input_ids = list(tokenizer.encode(prompt).ids)
    if not input_ids:
        # `logits[0, -1, :]` below would fail on an empty sequence.
        raise ValueError("prompt must encode to at least one token")
    # Loop-invariant: look the EOS id up once rather than on every step.
    eos_id = tokenizer.token_to_id(eos_token)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = input_ids if max_context is None else input_ids[-max_context:]
            input_tensor = torch.tensor([context], dtype=torch.long)
            logits = model(input_tensor)
            next_token_id = torch.argmax(logits[0, -1, :]).item()
            input_ids.append(next_token_id)
            if eos_id is not None and next_token_id == eos_id:
                break
    return tokenizer.decode(input_ids)

# Example usage: greedy continuation of the same prompt used in the encoding
# example above. The prompt without the name is tried in the next cell; repeating
# it here made this cell a duplicate and left its recorded output stale.
# NOTE(review): recorded generations end in U+FFFD replacement characters, i.e.
# the generated byte-level tokens do not form valid UTF-8. The checkpoint may be
# undertrained — confirm.
prompt = "माझं नाव गीता आहे."
generated = generate_text(model, tokenizer, prompt, max_new_tokens=30)
print("Prompt:", prompt)
print("Generated:", generated)

Prompt: माझं नाव गीता आहे.
Generated: माझं नाव गीता आहे.������������������������������


In [9]:
# Greedy continuation of a second, shorter prompt.
prompt = "माझं नाव आहे."
generated = generate_text(model, tokenizer, prompt, max_new_tokens=30)
for label, text in (("Prompt:", prompt), ("Generated:", generated)):
    print(label, text)

Prompt: माझं नाव आहे.
Generated: माझं नाव आहे.������������������������������
