In [3]:
import torch
import evaluate
import pandas as pd
from datasets import Dataset, Audio
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [4]:

# ===================================================================================
# --- Step 1: Configuration - UPDATE THESE PATHS ---
# ===================================================================================
# The original model you fine-tuned
MODEL_NAME = "openai/whisper-large-v3"

# Path to the LoRA adapter you trained and saved
# This is the directory created by `trainer.save_model()`
ADAPTER_PATH = "/ocean/projects/cis250085p/shared/A_track/whisper-large-v3-lora-streaming/checkpoint-4000"

# Path to your test data metadata file
TEST_DATA_CSV = "/ocean/projects/cis250085p/shared/A_track/dev_test.json" # <-- IMPORTANT: Change this!

# The base path where your raw audio files are stored, same as in training
RAW_AUDIO_BASE_PATH = "/ocean/projects/cis250085p/shared/A_track/"

# Configuration for the prediction run
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2 # Adjust based on your GPU memory


In [6]:

# ===================================================================================
# --- Step 2: Load Your Fine-Tuned Model and Processor ---
# ===================================================================================
print("Loading model and processor...")

# Load the processor
processor = WhisperProcessor.from_pretrained(MODEL_NAME)

processor.tokenizer.set_prefix_tokens(language="kinyarwanda", task="transcribe")

# Load the base model in float16 for faster inference
base_model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16,
    cache_dir = "/ocean/projects/cis250085p/shared/A_track"
).to(DEVICE)


# # # model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).to(DEVICE)
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH, is_trainable=True)

# model.enable_adapter() 

model.eval() # Set the model to evaluation mode

base_model.eval()
# --- OPTIONAL BUT RECOMMENDED: Apply optimizations from training ---
# 1. Use greedy search for maximum speed
# model.generation_config.num_beams = 1
# model.generation_config.do_sample = False
# 2. Compile the model if using PyTorch 2.0+
# model = torch.compile(model)

print("✅ Model loaded and configured for inference.")


Loading model and processor...


Could not load bitsandbytes native library: /lib64/libc.so.6: version `GLIBC_2.34' not found (required by /ocean/projects/cis250085p/shared/envPreproces/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cuda126.so)
Traceback (most recent call last):
  File "/ocean/projects/cis250085p/shared/envPreproces/lib/python3.11/site-packages/bitsandbytes/cextension.py", line 85, in <module>
    lib = get_native_library()
          ^^^^^^^^^^^^^^^^^^^^
  File "/ocean/projects/cis250085p/shared/envPreproces/lib/python3.11/site-packages/bitsandbytes/cextension.py", line 72, in get_native_library
    dll = ct.cdll.LoadLibrary(str(binary_path))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ocean/projects/cis250085p/shared/envPreproces/lib/python3.11/ctypes/__init__.py", line 454, in LoadLibrary
    return self._dlltype(name)
           ^^^^^^^^^^^^^^^^^^^
  File "/ocean/projects/cis250085p/shared/envPreproces/lib/python3.11/ctypes/__init__.py", line 376, in __init__
    self._handl

✅ Model loaded and configured for inference.




In [2]:
from safetensors.torch import load_file 
# Load the saved weights into memory
saved_adapter_weights = torch.load( f"{ADAPTER_PATH}/adapter_model.bin", map_location="GPU")

# Create a new dictionary to hold the corrected keys
new_state_dict = {}
# Loop through the saved keys and rename them to match the current model structure
for key, value in saved_adapter_weights.items():
    # This replaces the erroneous '...model.model...' with the correct '...model...'
    new_key = key.replace("base_model.model.model.", "base_model.model.", 1)
    new_state_dict[new_key] = value

# Load our corrected weights into the model.
# `strict=False` ignores any non-matching keys, which is what we want.
print("Loading corrected weights into the model...")
model.load_state_dict(new_state_dict, strict=False)


# --- 4. Finalize the Model for Training/Inference ---
model.eval() # Or 

NameError: name 'ADAPTER_PATH' is not defined

In [None]:
import torch

# Path to your saved adapter weights file
adapter_weights_path = "./A_track/whisper-large-v3-lora-streaming/checkpoint-4000/adapter_model.safetensors" 

print(f"Loading weights from: {adapter_weights_path}")
adapter_weights = torch.load(adapter_weights_path, weights_only=False)

# Let's inspect a key from one of the deeper layers in the decoder
# This key should have been trained and have non-zero values
key_to_check = "base_model.model.model.decoder.layers.15.self_attn.q_proj.lora_A.default.weight"

if key_to_check in adapter_weights:
    weights = adapter_weights[key_to_check]
    print(f"\nSuccessfully found key: {key_to_check}")
    print("Shape of weights tensor:", weights.shape)
    print("A small sample of the weights:")
    print(weights)
    print(f"\nMean of absolute values: {weights.abs().float().mean()}")
else:
    print(f"\n❌ ERROR: Could not find the key '{key_to_check}' in the adapter file.")
    print("This indicates a serious problem with the saved checkpoint.")

Loading weights from: ./A_track/whisper-large-v3-lora-streaming/checkpoint-4000/adapter_model.safetensors


UnpicklingError: invalid load key, '\xec'.

In [None]:
!pwd

/ocean/projects/cis250085p/shared


In [None]:
TD_df= pd.read_json(TEST_DATA_CSV).T
# TD_df["file_path"] = "processed/"+ TD_df["audio_path"] +".mel.pt"

In [None]:
TD_df["audio_path"] = TD_df.audio_path.str.replace("audio/","/ocean/projects/cis250085p/shared/track_a_audio_files/")  +".wav"

In [None]:
TD_df.audio_path.iloc[0]

'/ocean/projects/cis250085p/shared/track_a_audio_files/1739532284-OogTF7X5UsTPNsR9q4GLZYcJiKB2.wav'

In [None]:
import torchaudio

In [None]:

# ===================================================================================
# --- Step 3: Load and Prepare the Test Dataset ---
# ===================================================================================
print(f"Loading test dataset from: {TEST_DATA_CSV}")

# Load the metadata
test_dataset = Dataset.from_pandas(TD_df)

# Function to prepare a single example for the model
# It loads the audio and converts it into the 'input_features' the model expects
def prepare_dataset(example):
    audio_path = example["audio_path"]
    try:
        # --- THIS IS THE NEW, MORE ROBUST METHOD ---
        # 1. Load audio directly with torchaudio
        waveform, sample_rate = torchaudio.load(audio_path)

        # 2. The Whisper processor will handle resampling to 16kHz automatically
        # We pass the raw waveform and its original sample rate
        input_features = processor(
            waveform.squeeze(), # Remove channel dimension
            sampling_rate=sample_rate
        ).input_features[0]
        # --- END OF NEW METHOD ---

        # 3. Convert to float16 to match the model
        example["input_features"] = torch.from_numpy(input_features).to(torch.float16)

    except Exception as e:
        # This will now catch any errors from torchaudio or the processor
        print(f"Error processing {audio_path}: {e}")
        example["input_features"] = None # Mark as None to filter later
        
    return example

# Apply the preparation function
test_dataset = test_dataset.map(prepare_dataset, num_proc=4)
# Filter out any samples that failed to load
# test_dataset = test_dataset.filter(lambda example: example["input_features"] is not None)

print(f"✅ Test dataset prepared with {len(test_dataset)} samples.")


Loading test dataset from: /ocean/projects/cis250085p/shared/A_track/dev_test.json


Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4632/4632 [06:42<00:00, 11.52 examples/s]
Filter: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4632/4632 [13:30<00:00,  5.72 examples/s]

✅ Test dataset prepared with 4632 samples.





In [None]:

# ===================================================================================
# --- Step 4: Run the Prediction Loop ---
# ===================================================================================
# The data collator just needs to organize the batch
def collate_fn(features):
    input_features = [{"input_features": feature["input_features"]} for feature in features]
    batch = processor.feature_extractor.pad(input_features, return_tensors="pt")
    return batch

# Create a DataLoader for efficient batching
test_dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)

predictions = []
references = test_dataset["transcription"] # Get all ground truth transcriptions

print("Running predictions on the test set...")
# Loop through the test data with a progress bar
for batch in tqdm(test_dataloader):
    # Move batch to the GPU
    inputs = batch["input_features"].to(DEVICE)

    # Run prediction
    with torch.no_grad(): # Disable gradient calculation for inference
        predicted_ids = model.generate(inputs)

    # Decode the predicted IDs to text
    transcriptions = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)
    predictions.extend(transcriptions)

print("✅ Prediction loop complete.")


Running predictions on the test set...


  0%|                                                                                                                                                                          | 0/290 [00:00<?, ?it/s]

  1%|█                                                                                                                                                               | 2/290 [00:45<1:48:36, 22.63s/it]


KeyboardInterrupt: 

In [None]:
references[:32]

["Akajerekani gateretse hasi ku butaka gapfundikiye neza, gapfundiki umuvuniko w'umweru n'akajerekani karasa n'umweru, akajerekani karimo amata.",
 'Uburyo bwo kwishyura hakoreshejwe ikoranabuhanga, kubitsa, kubikuza, kohereza amafaranga, ukoresheje terefone ngendanwa.',
 "Hano hari abantu batatu babiri muri bo ni ab'igitsina gabo, undi umwe usigaye ni uw'igitsina gore, bicaye ku ntebe z'ibara ry'umweru ndetse imbere yabo bari kunywa amata mu birahure.",
 "Abantu benshi bahagaze imbere y'inyubako ifite amarangi y'umutuku n'umweru, yanditseho amagambo ari mu rurimi rw'icyongereza ari mu ibara ry'umutuku, abagabo barimo bambaye imyenda itandukanye, hari bambaye amashati y'umweru ndetse n'amapantaro y'umukara, hari n'abambaye amakote na karavate, hari n'uwambaye ingofero na rinete.",
 "Ahantu heza hafite isuku ndetse n'amacumbi wasohokera ukaruhuka, hari amazi meza akorerwamo imyidagaduro ndetse na siporo ziruhura imiruho.",
 "Abantu batwara amapikipiki by'umwuga bahetse abagore ku mapiki

In [None]:
predictions

['akajerekani ketereze aasi kutaka kepfundi tieneza kepfundi tiju mufundiko u nyeru na akajerekani klasa na u nyeru akajerekani karimo amata',
 'uwurijobu kweishira hakureishishge ikorana wuhanga kubi ita kubi ikuza kuhiriza ama faranga ukureishishge telefonu jendan',
 'hano hari awa nu batatu babiri muri vo na vijitsi nagavo undi ume usigaye nu vijitsi nagori bicha iku newe zivarariji ume eru ndete imbere ya vo barikungwa amata mwira huri',
 'aba anu benshi wa gaze imbere iinyu baku fite ya maranji yumu tuku nu mngeru ya anditseo ama gamba rimu rimu ligu icho njereza harimu ngibara li jumu tuku aba gawa warimo wa mbae imienda ita ndokanya rabamba ya mashati yu mngeru dete na mapano yumu kaina mbae ama kote na klavate hainu wa mbae ingofero na rinete',
 'anoheza haftisukumdece na machombe wasoche rao karuhu kari amazmeza kwerikomu imetafatru ndetze na sporozi ruhura umuliza',
 'aba anu watuara ama pichi pichi pichumunga wa heze abagode kuma pichi pichi yao vose waagaze iruhande wa hind

In [None]:

# ===================================================================================
# --- Step 5: Calculate and Display the Final WER ---
# ===================================================================================
print("Calculating final Word Error Rate (WER)...")
wer_metric = evaluate.load("wer")

final_wer = wer_metric.compute(predictions=predictions, references=references[:32])

print(f"\n🎉 Final Test WER: {final_wer:.4f} 🎉")


Calculating final Word Error Rate (WER)...

🎉 Final Test WER: 1.4067 🎉


In [None]:

# ===================================================================================
# --- Step 5: Calculate and Display the Final WER ---
# ===================================================================================
print("Calculating final Word Error Rate (WER)...")
wer_metric = evaluate.load("wer")

final_wer = wer_metric.compute(predictions=predictions, references=references[:32])

print(f"\n🎉 Final Test WER: {final_wer:.4f} 🎉")


Calculating final Word Error Rate (WER)...

🎉 Final Test WER: 1.4067 🎉


In [None]:

# ===================================================================================
# --- Step 6 (Optional): Save Results to a File ---
# ===================================================================================
print("Saving results to 'test_predictions.csv'...")
results_df = pd.DataFrame({
    "Reference": references[:32],
    "Prediction": predictions
})
results_df["wer"] = results_df.apply(
    lambda row: wer_metric.compute(predictions=[row.Prediction], references=[row.Reference]), axis=1
)

# results_df.to_csv("test_predictions_6k_steps__.csv", index=False)
print("✅ Results saved.")

Saving results to 'test_predictions.csv'...
✅ Results saved.


In [None]:
results_df.sort_values(by="wer", ascending=True, inplace=True)

In [None]:
results_df

Unnamed: 0,Reference,Prediction,wer
0,Akajerekani gateretse hasi ku butaka gapfundik...,akajerekani ketereze aasi kutaka kepfundi tien...,1.125
1,Uburyo bwo kwishyura hakoreshejwe ikoranabuhan...,uwurijobu kweishira hakureishishge ikorana wuh...,1.25
2,Hano hari abantu batatu babiri muri bo ni ab'i...,hano hari awa nu batatu babiri muri vo na viji...,0.793103
3,Abantu benshi bahagaze imbere y'inyubako ifite...,aba anu benshi wa gaze imbere iinyu baku fite ...,1.292683
4,Ahantu heza hafite isuku ndetse n'amacumbi was...,anoheza haftisukumdece na machombe wasoche rao...,0.944444
5,Abantu batwara amapikipiki by'umwuga bahetse a...,aba anu watuara ama pichi pichi pichumunga wa ...,1.421053
6,"Abantu benshi cyane, bafite uruhu rwera, bamba...",awano vwenti chani wafite urvurguwela wambaye ...,0.954545
7,"Abagabo babiri bicaye ahantu hamwe, umwe yamba...",aba gawa wabilibicha ya hanu hamge umge ya amb...,0.964286
8,Inshange ya emutiyeni irimo uburyo bwo kwishyu...,mshange ya mtn ilimu uulijewa kushira mafarang...,1.0
9,Waba ufite amafaranga y'amanyamahanga? Ni byiz...,mfita mafranga ya maja mahanga nibijiza kukana...,1.0
