# Marathi GPT-2 Checkpoint Test

This notebook demonstrates how to load a saved checkpoint of the Marathi GPT-2 model and test it on sample input.

In [3]:
# Import required libraries
import torch
from marathi_gpt2 import GPT2Config, GPT2Model

In [8]:
# Load the checkpoint
ckpt_path = "/home/lamrin/Projects/BolBot/marathi_gpt2_step_1_200000.pt"  # Change to your checkpoint file if needed
config = GPT2Config()
model = GPT2Model(config)
state = torch.load(ckpt_path, map_location="cpu")
missing, unexpected = model.load_state_dict(state, strict=False)
model.eval()
print("Checkpoint loaded.")
if missing:
    print("Missing keys:", missing)
if unexpected:
    print("Unexpected keys:", unexpected)

Checkpoint loaded.
Missing keys: ['h.0.norm1.weight', 'h.0.norm1.bias', 'h.0.self_attn.in_proj_weight', 'h.0.self_attn.in_proj_bias', 'h.0.self_attn.out_proj.weight', 'h.0.self_attn.out_proj.bias', 'h.0.norm2.weight', 'h.0.norm2.bias', 'h.0.linear1.weight', 'h.0.linear1.bias', 'h.0.linear2.weight', 'h.0.linear2.bias', 'h.1.norm1.weight', 'h.1.norm1.bias', 'h.1.self_attn.in_proj_weight', 'h.1.self_attn.in_proj_bias', 'h.1.self_attn.out_proj.weight', 'h.1.self_attn.out_proj.bias', 'h.1.norm2.weight', 'h.1.norm2.bias', 'h.1.linear1.weight', 'h.1.linear1.bias', 'h.1.linear2.weight', 'h.1.linear2.bias', 'h.2.norm1.weight', 'h.2.norm1.bias', 'h.2.self_attn.in_proj_weight', 'h.2.self_attn.in_proj_bias', 'h.2.self_attn.out_proj.weight', 'h.2.self_attn.out_proj.bias', 'h.2.norm2.weight', 'h.2.norm2.bias', 'h.2.linear1.weight', 'h.2.linear1.bias', 'h.2.linear2.weight', 'h.2.linear2.bias', 'h.3.norm1.weight', 'h.3.norm1.bias', 'h.3.self_attn.in_proj_weight', 'h.3.self_attn.in_proj_bias', 'h.3.sel

In [7]:
# Test the loaded checkpoint on a sample Marathi prompt
from tokenizers import ByteLevelBPETokenizer
import os
import torch
TOKENIZER_DIR = "data/marathi_bpe_tokenizer"
tokenizer = ByteLevelBPETokenizer(
    os.path.join(TOKENIZER_DIR, "vocab.json"),
    os.path.join(TOKENIZER_DIR, "merges.txt")
)
marathi_prompt = "माझं नाव गीता आहे."
input_ids = torch.tensor([tokenizer.encode(marathi_prompt).ids])
with torch.no_grad():
    logits = model(input_ids)
print("Prompt:", marathi_prompt)
print("Input IDs:", input_ids.tolist())
print("Logits shape:", logits.shape)
# Get the most probable next token
next_token_id = torch.argmax(logits[0, -1]).item()
print("Next token ID:", next_token_id)
print("Next token (decoded):", tokenizer.id_to_token(next_token_id) if hasattr(tokenizer, 'id_to_token') else '<no id_to_token>')

Prompt: माझं नाव गीता आहे.
Input IDs: [[161, 102, 111, 161, 102, 127, 161, 102, 256, 161, 102, 229, 225, 161, 102, 106, 161, 102, 127, 161, 102, 118, 225, 161, 102, 250, 161, 103, 227, 161, 102, 102, 161, 102, 127, 225, 161, 102, 233, 161, 102, 122, 161, 103, 234, 18]]
Logits shape: torch.Size([1, 46, 32000])
Next token ID: 17862
Next token (decoded): None


In [None]:
# Test the model on a random input tensor
import torch
seq_len = 10
batch_size = 1
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
with torch.no_grad():
    logits = model(input_ids)
print("Input shape:", input_ids.shape)
print("Logits shape:", logits.shape)

Input shape: torch.Size([1, 10])
Logits shape: torch.Size([1, 10, 32000])


In [None]:
# Example: Encode Marathi text, run through model, and decode
from tokenizers import ByteLevelBPETokenizer

# Load tokenizer
TOKENIZER_DIR = "data/marathi_bpe_tokenizer"
tokenizer = ByteLevelBPETokenizer(
    os.path.join(TOKENIZER_DIR, "vocab.json"),
    os.path.join(TOKENIZER_DIR, "merges.txt")
)

marathi_prompt = "माझं नाव गीता आहे."
input_ids = torch.tensor([tokenizer.encode(marathi_prompt).ids])
with torch.no_grad():
    logits = model(input_ids)
print("Prompt:", marathi_prompt)
print("Input IDs:", input_ids.tolist())
print("Logits shape:", logits.shape)

Prompt: माझं नाव गीता आहे.
Input IDs: [[161, 102, 111, 161, 102, 127, 161, 102, 256, 161, 102, 229, 225, 161, 102, 106, 161, 102, 127, 161, 102, 118, 225, 161, 102, 250, 161, 103, 227, 161, 102, 102, 161, 102, 127, 225, 161, 102, 233, 161, 102, 122, 161, 103, 234, 18]]
Logits shape: torch.Size([1, 46, 32000])


In [None]:
# List all available checkpoint files in the current directory
import os
[fn for fn in os.listdir(".") if fn.startswith("marathi_gpt2") and fn.endswith(".pt")]

['marathi_gpt2_step_1_100000.pt']

In [None]:
# Simple autoregressive text generation example
import torch
import os
from tokenizers import ByteLevelBPETokenizer

# Load tokenizer
TOKENIZER_DIR = "data/marathi_bpe_tokenizer"
tokenizer = ByteLevelBPETokenizer(
    os.path.join(TOKENIZER_DIR, "vocab.json"),
    os.path.join(TOKENIZER_DIR, "merges.txt")
)

def generate_text(model, tokenizer, prompt, max_new_tokens=30):
    model.eval()
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(input_tensor)
            next_token_logits = logits[0, -1, :]
            next_token_id = torch.argmax(next_token_logits).item()
        input_ids.append(next_token_id)
        input_tensor = torch.tensor([input_ids], dtype=torch.long)
        # Stop if end-of-text token (if defined) is generated
        if next_token_id == tokenizer.token_to_id("</s>"):
            break
    return tokenizer.decode(input_ids)

# Example usage
prompt = "माझं नाव आहे."
generated = generate_text(model, tokenizer, prompt, max_new_tokens=30)
print("Prompt:", prompt)
print("Generated:", generated)

Prompt: माझं नाव गीता आहे.
Generated: माझं नाव गीता आहे.������������������������������


In [None]:
prompt = """कारखान्यात काम करणाऱ्या कामगारांच्या कामाच्या वेळेची मर्यादा आता दिवसाला 9 तासांवरुन 12 तास करण्याच्या तरतुदीला """
generated = generate_text(model, tokenizer, prompt, max_new_tokens=30)
print("Prompt:", prompt)
print("Generated:", generated)

Prompt: कारखान्यात काम करणाऱ्या कामगारांच्या कामाच्या वेळेची मर्यादा आता दिवसाला 9 तासांवरुन 12 तास करण्याच्या तरतुदीला 
Generated: कारखान्यात काम करणाऱ्या कामगारांच्या कामाच्या वेळेची मर्यादा आता दिवसाला 9 तासांवरुन 12 तास करण्याच्या तरतुदीला ������������������������������
