# ByT5 Zero-Shot Morphological Inflection
This notebook demonstrates zero-shot morphological inflection using the ByT5 model. It loads the model once, reads test data, performs batched predictions, and saves results.

## 1. Load Required Libraries and ByT5 Model
This cell loads PyTorch, Hugging Face Transformers, tqdm, and the ByT5 model/tokenizer. The model is loaded once and moved to GPU if available.

In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm import tqdm

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (more slowly) instead of aborting the kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    print("CUDA is not available. Falling back to CPU; inference will be slow.")

model_name = "google/byt5-small"
# Load the tokenizer and model once; the prediction cell below reuses
# `tokenizer`, `model`, and `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.to(device)
print(f"Model loaded on device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Model loaded on device: cuda


## 2. Read and Preview Test Data
This cell reads the test set (e.g., eng.tst) and previews a few examples.

In [None]:
# Location of the English test split.
test_path = "../baseline/data/eng.tst"

# Collect every non-blank line, with surrounding whitespace removed.
test_lines = []
with open(test_path, "r", encoding="utf-8") as f:
    test_lines.extend(text for text in map(str.strip, f) if text)

# Print the first five entries as a quick sanity check.
for example in test_lines[:5]:
    print(example)

## 3. Batched Zero-Shot Prediction
This cell runs ByT5 over the test set in batches using a task-specific prompt, shows a progress bar while predicting, and then prints the total elapsed time and the first five predictions.

In [None]:
# `time` is not imported by the setup cell, so import it here; without it the
# timing calls below raise NameError.
import time

batch_size = 16
results = []  # (prompt, prediction) pairs, in input order
skipped = 0   # lines that lack a tab-separated lemma and feature field
start_time = time.perf_counter()  # monotonic clock, suited to measuring elapsed time
num_examples = len(test_lines)
for i in tqdm(range(0, num_examples, batch_size), desc="Predicting", unit="batch"):
    batch_lines = test_lines[i:i + batch_size]
    batch_inputs = []
    for line in batch_lines:
        # Each usable line is "<lemma>\t<features>[\t...]"; only the first two
        # fields are used to build the prompt.
        parts = line.split("\t")
        if len(parts) < 2:
            skipped += 1
            continue
        lemma, features = parts[0], parts[1]
        batch_inputs.append(f"Inflect the following verb: {lemma} {features}")
    if not batch_inputs:
        # Every line in this slice was malformed; nothing to send to the model.
        continue
    # Tokenize the whole batch at once (padded so it forms a single tensor)
    # and move it to the same device as the model.
    inputs = tokenizer(batch_inputs, padding=True, return_tensors="pt").to(device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=32)
    predictions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    results.extend(zip(batch_inputs, predictions))
elapsed = time.perf_counter() - start_time
print(f"Total time: {elapsed:.2f} seconds")
if skipped:
    # Surface dropped lines so a gap between inputs and predictions is not silent.
    print(f"Skipped {skipped} malformed line(s) with fewer than 2 tab-separated fields")
# Show first 5 predictions
for inp, pred in results[:5]:
    print(f"Input: {inp}\tPrediction: {pred}")

## 4. Save Predictions to File
This cell saves all predictions to a text file for further analysis.

In [None]:
from pathlib import Path

output_path = "output/predictions_eng.txt"
# open(..., "w") does not create missing parent directories, so make sure
# "output/" exists before writing.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# One line per example: "<prompt>\t<prediction>".
with open(output_path, "w", encoding="utf-8") as f:
    f.writelines(f"{inp}\t{pred}\n" for inp, pred in results)
print(f"Predictions saved to {output_path}")