In [None]:
import pickle
import pandas as pd
# Load the pickled copy-task results into `dat`.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only run this cell against a file from a trusted source.
with open('./data/t15_copyTask.pkl', 'rb') as f:
    dat = pickle.load(f)

# Report what was loaded: the available keys for a dict, otherwise the object type.
if isinstance(dat, dict):
    print("Keys in checkpoint:", dat.keys())
else:
    print("Loaded object type:", type(dat))

# dat["decoded_sentence"] is already a LIST of sentences
preds = dat["decoded_sentence"]

# BLEU requires list[list[str]]: wrap each reference sentence in its own list
refs = [[r] for r in dat["cue_sentence"]]

# Print a count plus a short preview only; dumping the full lists floods the
# cell output and bloats the saved notebook.
PREVIEW_COUNT = 5
print(f"{len(preds)} predictions, first {PREVIEW_COUNT}: {preds[:PREVIEW_COUNT]}")
print(f"{len(refs)} references, first {PREVIEW_COUNT}: {refs[:PREVIEW_COUNT]}")

# Per-block description table that accompanies the pickle; pandas truncates the
# printed frame on its own, so printing it directly stays compact.
b2txt_csv_df = pd.read_csv("data/t15_copyTaskData_description.csv")
print(b2txt_csv_df)

Keys in checkpoint: dict_keys(['post_implant_day', 'vocab_size', 'cue_sentence', 'cue_sentence_phonemes', 'decoded_logits', 'decoded_phonemes_raw', 'decoded_sentence', 'decoded_sentence_phonemes', 'speech_duration_s'])
           Date  Post-implant day  Block number  Number of sentences  \
0    2023-08-11                25             2                   20   
1    2023-08-11                25             3                   30   
2    2023-08-11                25             4                   40   
3    2023-08-11                25             5                   50   
4    2023-08-11                25             6                   50   
..          ...               ...           ...                  ...   
260  2025-03-30               622            12                   49   
261  2025-03-30               622            13                   48   
262  2025-04-13               636             2                   20   
263  2025-04-13               636             7              

In [7]:
import pickle
import evaluate
import numpy as np

# WER calculation function
def calculate_aggregate_error_rate(references, hypotheses):
    """
    Calculate the aggregate (corpus-level) Word Error Rate (WER).

    Every reference/hypothesis pair is split on whitespace, the word-level
    edit distance is computed for the pair, and the summed distance is
    divided by the total number of reference words.

    Args:
        references: iterable of reference sentences (ground truth), each a str
        hypotheses: iterable of hypothesis sentences (predictions), each a str,
            aligned one-to-one with ``references``

    Returns:
        WER as a float (0.0 to 1.0+). Returns 0.0 when the references
        contain no words at all.

    Raises:
        ValueError: if the number of references and hypotheses differ.
    """
    def levenshtein_distance(ref, hyp):
        """Calculate the edit distance between two word sequences."""
        # Only the previous row of the distance table is needed to build the
        # next one, so keep a rolling row instead of the full matrix.
        previous = list(range(len(hyp) + 1))
        for i, ref_word in enumerate(ref, start=1):
            current = [i]
            for j, hyp_word in enumerate(hyp, start=1):
                cost = 0 if ref_word == hyp_word else 1
                current.append(min(
                    previous[j] + 1,        # deletion
                    current[j - 1] + 1,     # insertion
                    previous[j - 1] + cost  # substitution
                ))
            previous = current
        return previous[-1]

    # Materialize so that lengths can be compared even for generators.
    references = list(references)
    hypotheses = list(hypotheses)

    # zip() would silently drop the unmatched tail and report a misleading score.
    if len(references) != len(hypotheses):
        raise ValueError(
            f"references and hypotheses must have the same length "
            f"(got {len(references)} and {len(hypotheses)})"
        )

    total_distance = 0
    total_ref_words = 0

    for ref, hyp in zip(references, hypotheses):
        ref_words = ref.split()
        hyp_words = hyp.split()

        total_distance += levenshtein_distance(ref_words, hyp_words)
        total_ref_words += len(ref_words)

    # The rate is undefined without reference words; report 0.0 in that case.
    if total_ref_words == 0:
        return 0.0

    return total_distance / total_ref_words


# BLEU metric object from the `evaluate` package.
bleu = evaluate.load("bleu")

# `dat` is not reloaded in this cell: it is the object unpickled from
# './data/t15_copyTask.pkl' by the loading cell earlier in the notebook.

# Decoded sentences, already stored as a list (list[str]).
preds = dat["decoded_sentence"]

# Cue sentences as a flat list[str], the layout the WER helper takes.
refs_flat = dat["cue_sentence"]

# BLEU takes list[list[str]], so each cue sentence gets its own one-element list.
refs = [[cue] for cue in refs_flat]

# Corpus BLEU over all prediction/reference pairs.
bleu_result = bleu.compute(predictions=preds, references=refs)
print(f"BLEU score: {bleu_result['bleu']:.4f}")

# Aggregate word error rate over the same sentences, shown as a percentage.
wer_score = calculate_aggregate_error_rate(refs_flat, preds)
print(f"WER score: {wer_score*100:.2f}%")


BLEU score: 0.9179
WER score: 4.34%
