# Marathi GPT-2 Prompt-Response Testing

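This notebook loads the most recent training checkpoint of the Marathi GPT-2 model together with its byte-level BPE tokenizer. It then generates a greedy (argmax) continuation of a given Marathi prompt.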

In [1]:
import os
import torch
from marathi_gpt2 import GPT2Config, GPT2Model
from tokenizers import ByteLevelBPETokenizer

def get_latest_checkpoint(prefix="marathi_gpt2_step_"):
    """Return the filename of the newest step checkpoint in the current directory.

    A checkpoint is any entry in "." whose name starts with ``prefix`` and ends
    with ``.pt``. Candidates are ranked by the integers embedded in the part of
    the name after ``prefix`` (e.g. ``marathi_gpt2_step_1_300000.pt`` ->
    ``(1, 300000)``). A plain lexicographic sort would compare digit by digit and
    wrongly rank ``..._50000.pt`` above ``..._300000.pt``.

    Args:
        prefix: Filename prefix that identifies step checkpoints.

    Returns:
        The newest matching filename (relative to the current directory), or
        ``None`` when no file matches.
    """
    import re

    suffix = ".pt"
    ckpts = [
        f for f in os.listdir(".")
        if f.startswith(prefix) and f.endswith(suffix)
    ]
    if not ckpts:
        return None

    def _step_key(name):
        # Numeric components between the prefix and ".pt"; the full name
        # breaks ties so the choice is deterministic.
        stem = name[len(prefix):len(name) - len(suffix)]
        return tuple(int(d) for d in re.findall(r"\d+", stem)), name

    return max(ckpts, key=_step_key)

In [2]:
# Pick the newest step checkpoint, falling back to the final export file,
# then restore the model weights and switch to inference mode.
ckpt_path = get_latest_checkpoint()
if not ckpt_path:
    ckpt_path = "marathi_gpt2.pt"
print(f"Loading checkpoint: {ckpt_path}")

config = GPT2Config()
model = GPT2Model(config)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
model.eval()

# The byte-level BPE tokenizer is stored as a vocab/merges pair in one directory.
TOKENIZER_DIR = "data/marathi_bpe_tokenizer"
vocab_file, merges_file = (
    os.path.join(TOKENIZER_DIR, fname) for fname in ("vocab.json", "merges.txt")
)
tokenizer = ByteLevelBPETokenizer(vocab_file, merges_file)

Loading checkpoint: marathi_gpt2_step_1_300000.pt


In [6]:
# Text generation function
def _pick_next_token(logits, temperature, top_k):
    """Choose one token id from a 1-D logits vector.

    Uses greedy argmax when ``temperature <= 0``. Otherwise it samples from the
    softmax of ``logits / temperature``, first restricted to the ``top_k``
    highest-scoring tokens when ``top_k`` is a positive int.
    """
    if temperature <= 0:
        return int(torch.argmax(logits).item())
    scaled = logits / temperature
    if top_k is not None and top_k > 0:
        k = min(int(top_k), scaled.size(-1))
        kth_value = torch.topk(scaled, k).values[-1]
        # Mask everything below the k-th largest score so it gets zero probability.
        scaled = scaled.masked_fill(scaled < kth_value, float("-inf"))
    probs = torch.softmax(scaled.float(), dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


def generate_text(model, tokenizer, prompt, max_new_tokens=30, eos_token="</s>",
                  temperature=0.0, top_k=None, max_context=None):
    """Autoregressively extend ``prompt`` by up to ``max_new_tokens`` tokens.

    Args:
        model: Causal LM; ``model(ids)`` must return logits indexable as
            ``[batch, position, vocab]`` (as the indexing below requires).
            Its first parameter's device is used for the input tensor.
        tokenizer: Object with ``encode(text).ids``, ``token_to_id(token)`` and
            ``decode(ids)`` (e.g. a ``tokenizers`` BPE tokenizer).
        prompt: Text to continue; must encode to at least one token.
        max_new_tokens: Upper bound on generated tokens.
        eos_token: Generation stops right after this token is produced. If the
            tokenizer has no id for it, generation always runs the full length.
        temperature: ``<= 0`` (default) means greedy decoding, which matches the
            original behaviour. ``> 0`` samples at that temperature.
        top_k: When sampling, restrict choices to the ``top_k`` best tokens.
        max_context: If given, only the last ``max_context`` tokens are fed to
            the model each step (e.g. the model's positional-embedding limit).

    Returns:
        The decoded prompt plus generated tokens, including the stop token if
        it was produced.

    Raises:
        ValueError: if the prompt encodes to no tokens or ``max_context < 1``.
    """
    if max_context is not None and max_context < 1:
        raise ValueError("max_context must be a positive integer or None")
    model.eval()
    input_ids = list(tokenizer.encode(prompt).ids)
    if not input_ids:
        raise ValueError("prompt must encode to at least one token")

    # Resolve the stop id once; token_to_id yields None for unknown tokens,
    # which never equals a generated id.
    eos_id = tokenizer.token_to_id(eos_token)

    # Build inputs on the same device as the model's weights.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = input_ids if max_context is None else input_ids[-max_context:]
            input_tensor = torch.tensor([context], dtype=torch.long, device=device)
            logits = model(input_tensor)
            next_token_id = _pick_next_token(logits[0, -1, :], temperature, top_k)
            input_ids.append(next_token_id)
            if eos_id is not None and next_token_id == eos_id:
                break
    return tokenizer.decode(input_ids)

# Try the model: replace the Marathi text below with any prompt you like.
prompt = "माझं नाव आहे."
generated = generate_text(model, tokenizer, prompt, max_new_tokens=30)
for label, text in (("Prompt:", prompt), ("Response:", generated)):
    print(label, text)

Prompt: माझं नाव आहे.
Response: माझं नाव आहे.ात�तातातात�
