In [None]:
# Authenticate the `kagglehub` client with Kaggle credentials.
# NOTE(review): login() may ask for credentials interactively, which would block an
# unattended "Run All" — confirm whether the public dataset download below needs it.
import kagglehub
kagglehub.login()


In [None]:
# Fetch the LangBridge Wazobia dataset through kagglehub and keep the returned local path.
# NOTE(review): the loading cells further down use hard-coded /kaggle/input/... paths
# rather than this variable.
henryalohfabian_langbridge_wazobia_dataset_path = kagglehub.dataset_download(
    'henryalohfabian/langbridge-wazobia-dataset'
)
print('Data source import complete.')


In [None]:
!pip install trl
!pip install accelerate
!pip install peft
!pip install bitsandbytes
!pip install datasets
!pip install evaluate
!pip install safetensors
!pip install -q trl peft accelerate bitsandbytes
!pip install -q "transformers>=4.39.3" "datasets>=2.18.0"

# 1. Core Hugging Face Libraries
!pip install -q -U transformers accelerate datasets peft bitsandbytes

# 2. Evaluation & Metrics (CRITICAL for Translation)
!pip install -q evaluate sacrebleu

# 3. Tokenizer Support (CRITICAL for NLLB)
!pip install -q sentencepiece protobuf

# 4. Monitoring (Optional but recommended for Kaggle)
!pip install -q wandb


In [None]:
# Upgrade transformers / accelerate / bitsandbytes to their latest releases.
# NOTE(review): this overlaps with the `-U` install in the dependency cell; it can likely be dropped.
!pip install -U transformers accelerate bitsandbytes

In [None]:
import os

# Expose only GPU 0 to this process. CUDA reads this variable when it initialises,
# so it has to be set before torch touches the GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
# Smoke test: confirm the installed transformers build exposes Seq2SeqTrainer, and report torch.
import torch
from transformers import Seq2SeqTrainer

print("Seq2SeqTrainer imported successfully!", f"Torch version: {torch.__version__}", sep="\n")

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# 1. Standard Library & Hardware Setup
# ─────────────────────────────────────────────────────────────────────────────
import os
import gc  # Garbage collection for memory management
import torch
import numpy as np
from colorama import Fore, Style

import random

# Seed every RNG the notebook may draw from, so repeated runs are comparable.
seed = 42
random.seed(seed)           # Python's built-in RNG
np.random.seed(seed)        # NumPy's legacy global RNG
torch.manual_seed(seed)     # torch: seeds the CPU and all CUDA devices

# Quick Hardware Check
print(f"{Fore.CYAN}--- Environment Status ---{Style.RESET_ALL}")
print(f"PyTorch Version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU Detected: {Fore.GREEN}{torch.cuda.get_device_name(0)}{Style.RESET_ALL}")
    # total_memory is in bytes; shown here in GB (1e9 bytes).
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print(f"{Fore.RED}WARNING: No GPU detected. Training will be extremely slow.{Style.RESET_ALL}")

# ─────────────────────────────────────────────────────────────────────────────
# 2. Hugging Face Datasets (Data Loading)
# ─────────────────────────────────────────────────────────────────────────────
# load_from_disk is critical for your pre-saved Arrow shards
from datasets import load_from_disk, load_dataset, DatasetDict

# ─────────────────────────────────────────────────────────────────────────────
# 3. Transformers (Model & Tokenizer)
# ─────────────────────────────────────────────────────────────────────────────
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,      # <--- STRICTLY for NLLB (Encoder-Decoder)
    BitsAndBytesConfig,         # For 4-bit Quantization (QLoRA)
    DataCollatorForSeq2Seq,     # Dynamic padding for translation batches
    Seq2SeqTrainer,             # The specialized trainer for Translation
    Seq2SeqTrainingArguments,   # Arguments specific to generation tasks
    TrainerCallback,            # For your custom printer
    GenerationConfig            # To control inference parameters during eval
)

# ─────────────────────────────────────────────────────────────────────────────
# 4. PEFT (Parameter-Efficient Fine-Tuning)
# ─────────────────────────────────────────────────────────────────────────────
from peft import (
    LoraConfig,
    TaskType,                   # To specify SEQ_2_SEQ_LM
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel                   # Useful for loading adapters later
)

# ─────────────────────────────────────────────────────────────────────────────
# 5. Evaluation Metrics
# ─────────────────────────────────────────────────────────────────────────────
import evaluate

# SacreBLEU is the standard for Machine Translation evaluation.
# evaluate.load("sacrebleu") needs the `sacrebleu` package installed.
try:
    metric = evaluate.load("sacrebleu")
    print(f"{Fore.GREEN}Metric 'sacrebleu' loaded successfully.{Style.RESET_ALL}")
except ImportError:
    print(f"{Fore.RED}Error: 'sacrebleu' library missing. Run: pip install sacrebleu{Style.RESET_ALL}")
    # compute_metrics depends on `metric`. Stop here instead of leaving it undefined,
    # which would only surface as a NameError at the first evaluation, many training steps later.
    raise

# ─────────────────────────────────────────────────────────────────────────────
# 6. Weights & Biases (Optional but Recommended for Logging)
# ─────────────────────────────────────────────────────────────────────────────
# Ensures you can visualize loss curves even if Kaggle disconnects
import wandb
# wandb.login(key="YOUR_KEY_HERE") # Uncomment if you have a key

print(f"{Fore.CYAN}--- Import Complete ---{Style.RESET_ALL}")

In [None]:
# 4-bit NF4 quantization for QLoRA.
# The dtype used for the de-quantized matmuls should match the training precision:
# bf16 needs an Ampere-or-newer GPU (compute capability >= 8.0). Older cards such as
# the Kaggle T4/P100 lack native bf16, so fall back to fp16 there.
_bnb_supports_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                 # store linear-layer weights in 4 bits
    bnb_4bit_quant_type="nf4",         # NormalFloat4 data type (the QLoRA default)
    bnb_4bit_compute_dtype=torch.bfloat16 if _bnb_supports_bf16 else torch.float16,
    bnb_4bit_use_double_quant=True,    # also quantize the quantization constants to save memory
)

In [None]:
from peft import LoraConfig, TaskType

# LoRA adapter settings for NLLB, a sequence-to-sequence (encoder-decoder) model.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,          # adapters are created trainable
    bias="none",                   # bias terms are not trained
    r=32,                          # adapter rank
    lora_alpha=64,                 # adapter output is scaled by lora_alpha / r (= 2 here)
    lora_dropout=0.05,
    # Attention projection modules that receive adapters (NLLB/M2M100 module names).
    target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
)

In [None]:
model_name = "facebook/nllb-200-distilled-600M"

# Place the entire model ("" = root module) on the currently selected CUDA device.
device_map = {"": torch.cuda.current_device()}

# Load the base NLLB checkpoint with the 4-bit quantization settings defined earlier.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map=device_map,
)

In [None]:

# Standard QLoRA preparation before attaching adapters.
# prepare_model_for_kbit_training upcasts the non-quantized fp16/bf16 layers (e.g. layer
# norms) to fp32 for numerical stability. With use_gradient_checkpointing=True it also
# enables gradient checkpointing, which trades some speed for lower activation memory.
# Set it to False if the batch fits in memory without it.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [None]:
from transformers import DataCollatorForSeq2Seq
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Dynamically pad each batch to its longest example. Labels are padded with -100,
# so padded positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    padding=True,
)

In [None]:
from datasets import load_from_disk
import os


dataset_path = "/kaggle/input/langbridge-wazobia-dataset/nllb-training-data-merged/content/drive/MyDrive/data/NLLB_FINAL_TRAIN_DATA"

# Fail fast: later cells use `train_dataset` unconditionally. A missing directory
# should stop the notebook here, with the path in the message, rather than
# reappearing later as a NameError.
if not os.path.exists(dataset_path):
    raise FileNotFoundError(f"Directory not found: {dataset_path}")

print("Files in this folder:", os.listdir(dataset_path))
train_dataset = load_from_disk(dataset_path)
print("Success! Dataset loaded.")

# 5. METRICS & CALLBACKS

In [None]:

def compute_metrics(eval_preds):
    """Decode generated and reference token ids and score them with SacreBLEU.

    Args:
        eval_preds: ``(predictions, label_ids)`` as passed by the Trainer.
            ``predictions`` may itself be a tuple whose first item holds the token ids.

    Returns:
        ``{"bleu": <corpus SacreBLEU score>}``.

    Uses the module-level ``tokenizer`` and ``metric`` objects.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # When the Trainer concatenates generated sequences from different eval batches,
    # it pads them with -100. That is not a valid token id and makes batch_decode
    # fail, so map it to the pad token (which skip_special_tokens drops).
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Labels use -100 for positions ignored by the loss; treat them the same way.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = [pred.strip() for pred in decoded_preds]
    # SacreBLEU expects a list of reference lists: one reference per prediction here.
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

class PrinterCallback(TrainerCallback):
    """After every evaluation, print sample translations of the first validation rows.

    Reads the module-level ``val_dataset`` (rows must contain ``input_ids``) and
    ``tokenizer`` when ``on_evaluate`` fires, so both must exist before the first eval.

    Args:
        num_samples: how many leading validation rows to translate (default 2).
        max_new_tokens: cap on generated tokens per sample, so generation always stops.
    """

    def __init__(self, num_samples=2, max_new_tokens=60):
        self.num_samples = num_samples
        self.max_new_tokens = max_new_tokens

    def on_evaluate(self, args, state, control, model, **kwargs):
        print(f"\n--- SAMPLE TRANSLATION (Step {state.global_step}) ---")
        from torch.nn.utils.rnn import pad_sequence

        # Never index past the end of a small validation set.
        n_samples = min(self.num_samples, len(val_dataset))
        if n_samples == 0:
            return
        device = model.device

        seqs = [torch.tensor(val_dataset[i]["input_ids"]) for i in range(n_samples)]
        input_ids = pad_sequence(
            seqs, batch_first=True, padding_value=tokenizer.pad_token_id
        ).to(device)
        # Explicit mask so the encoder ignores the right-padding added above,
        # instead of relying on generate() to infer it from pad_token_id.
        attention_mask = pad_sequence(
            [torch.ones_like(s) for s in seqs], batch_first=True, padding_value=0
        ).to(device)

        with torch.no_grad():
            gen_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.max_new_tokens,  # hard cap prevents runaway generation
            )

        in_text = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        out_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

        for src, hyp in zip(in_text, out_text):
            print(f"In : {src}")
            print(f"Out: {hyp}")
            print("---")

In [None]:

val_path = "/kaggle/input/langbridge-wazobia-dataset/igbo_test_val-20251130T173420Z-1-001/igbo_test_val"

# Fail fast: the report lines below and the trainer both need `val_dataset`.
# Continuing without it would only end in a NameError.
if not os.path.exists(val_path):
    raise FileNotFoundError(f"Specific val set not found: {val_path}")

val_dataset = load_from_disk(val_path)
print(f"Validation set loaded from: {val_path}")

print(f"Training on {len(train_dataset)} examples.")
print(f"Validating on {len(val_dataset)} examples (Igbo Only).")

In [None]:
import inspect

OUTPUT_DIR = "/kaggle/working/langbridge_AI"

# Pick mixed precision from the GPU: bf16 needs compute capability >= 8 (Ampere+).
# Older cards (e.g. Kaggle T4/P100) fall back to fp16. Keep `bnb_4bit_compute_dtype`
# in the quantization config consistent with this choice.
USE_BF16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    push_to_hub=True,
    hub_model_id="coded-by-49/Lang_bridge_AI",
    hub_strategy="checkpoint",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,   # effective batch = 16 * 2 per device
    learning_rate=2e-4,
    num_train_epochs=1,
    weight_decay=0.01,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    logging_steps=100,
    bf16=USE_BF16,
    fp16=not USE_BF16,
    remove_unused_columns=False,     # keep dataset columns as-is for the collator
    report_to="none",
    predict_with_generate=True,      # eval runs generate() so BLEU can be computed
    generation_max_length=60,
    save_total_limit=2,
)

# Recent transformers releases replaced Trainer's `tokenizer=` argument with
# `processing_class=` (the old name is deprecated). The install cells upgrade
# transformers, so pass the tokenizer under whichever name this version accepts.
_trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[PrinterCallback()],
    **{_tokenizer_kwarg: tokenizer},
)



# --- 7. EXECUTION (CRASH PROOF WRAPPER) ---
import traceback

print("Starting Training...")

training_completed = False
try:
    trainer.train()
    training_completed = True

except Exception as e:
    print(f"\n\nCRITICAL ERROR DURING TRAINING: {e}")
    print(traceback.format_exc())
    print("\nAttempting to save latest checkpoint before exiting...")

    # Best effort: if saving also fails (e.g. the same GPU error that killed training),
    # report it and continue, so the wrapper never re-crashes the cell.
    salvage_path = os.path.join(OUTPUT_DIR, "crash_salvage_adapter")
    try:
        trainer.save_model(salvage_path)
        tokenizer.save_pretrained(salvage_path)
        print("Salvage successful. Exiting gracefully so Kaggle saves the output.")
    except Exception as salvage_error:
        print(f"Salvage failed: {salvage_error}")

# --- 8. NORMAL SAVING ---
# This runs regardless of success or failure
print("Saving final artifacts...")
if not training_completed:
    print("WARNING: training did not finish; 'final_adapter' contains the in-memory "
          "weights from the time of the error, not a fully trained model.")
final_save_path = os.path.join(OUTPUT_DIR, "final_adapter")
trainer.model.save_pretrained(final_save_path)
tokenizer.save_pretrained(final_save_path)
print(f"DONE! Model saved to: {final_save_path}")