In [None]:
# Standard library
import pickle  # was `import pickle1`, which is not a real module; later cells call pickle.load

# Third-party
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Project-local
from brainaudio.inference.eval_metrics import _cer_and_wer, clean_string
from llm_utils import cer_with_gpt2_decoder, gpt2_lm_decode

In [23]:
# Dataset selector. Later cells branch on this value ("b2t_25" or "b2t_24")
# to pick the model precision and the rescoring hyperparameters.
dataset = "b2t_25"

In [24]:
# Load the causal LM (and its tokenizer) that rescores the n-best hypotheses.
model_name = "facebook/opt-6.7b"

# Load tokenizer
llm_tokenizer = AutoTokenizer.from_pretrained(model_name)

if dataset == "b2t_25":

    print("Loading model in 16-bit with automatic device placement...")
    llm = AutoModelForCausalLM.from_pretrained(
        model_name,
        # `torch_dtype` is deprecated in the installed transformers release
        # (see the warning this cell used to emit); `dtype` is its replacement.
        dtype=torch.float16,
        device_map="auto"
    )

elif dataset == "b2t_24":

    # Load model in 8-bit with automatic device placement
    # NOTE(review): recent transformers releases prefer
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True) over the bare
    # flag - confirm against the installed version.
    llm = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_8bit=True,
        device_map="auto"
    )

else:
    # Without this branch an unknown dataset leaves `llm` undefined and the
    # failure only surfaces later as a NameError in the rescoring cell.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected 'b2t_25' or 'b2t_24'."
    )

# Example: Generate from a prompt (quick sanity check that the model runs)
inputs = llm_tokenizer("The future of AI is", return_tensors="pt").to(llm.device)
outputs = llm.generate(**inputs, max_new_tokens=50)
print(llm_tokenizer.decode(outputs[0], skip_special_tokens=True))

`torch_dtype` is deprecated! Use `dtype` instead!
Skipping import of cpp extensions due to incompatible torch version 2.10.0+cu128 for torchao version 0.15.0             Please see https://github.com/pytorch/ao/issues/2919 for more info


Loading model in 16-bit with automatic device placement...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The future of AI is in the hands of the people

The future of AI is in the hands of the people

The future of AI is in the hands of the people

The future of AI is in the hands of the people

The future of


In [None]:
# Inputs for rescoring: the WFST n-best list, an optional saved model-output
# file, and the dataset-specific rescoring weights.
# NOTE(review): this path points at the b2t_25 outputs regardless of `dataset`.
nbest_path = "/home/ebrahim/data2/brain2text/b2t_25/wfst_outputs/pretrained_RNN/nbest_wfst_rescore_True_acoustic_scale_0.325.pkl"
# When None, the rescoring cell takes its per-trial gpt2_lm_decode branch.
model_outputs_path = None

# Unpickling can execute arbitrary code - only load files you produced yourself.
# pd.read_pickle is the loader the other cells already use, and it closes the
# file itself.
nbest = pd.read_pickle(nbest_path)

if dataset == 'b2t_25':
    acoustic_scale = 0.325
    ground_truth = pd.read_pickle("/home/ebrahim/data2/brain2text/b2t_25/transcripts_val_cleaned.pkl")
    llm_weight = 0.55
else:
    # NOTE(review): `ground_truth` is not defined on this path, so the WER
    # cells below raise NameError for any dataset other than 'b2t_25'.
    acoustic_scale = 0.5
    llm_weight = 0.5
    

In [19]:
# Second n-best list, iterated trial by trial in the rescoring cell below.
# Judging by the sample displayed further down, each trial is a list of
# (text, score, score) tuples - the meaning of the two scores is not shown here.
nbest_nejm = pd.read_pickle("/home/ebrahim/data2/brain2text/b2t_25/wfst_outputs/t15_pretrained_rnn_baseline_val_nbest_formatted.pkl")

In [20]:
# Peek at one trial's hypotheses to check the entry format.
nbest_nejm[2]

[('not too controversial', -429.5383605957031, -22.741165161132812),
 ('not too crucial', -438.0696105957031, -26.61922264099121),
 ('not to controversial', -429.5383605957031, -30.495746612548828),
 ('not to crucial', -438.1420254516602, -28.84305947303772),
 ('not two controversial', -429.5383605957031, -31.660655975341797),
 ('not two crucial', -438.1420254516602, -29.43133870124817)]

In [15]:
# Number of trials in the list.
len(nbest_nejm)

1426

In [15]:
# Baseline before LLM rescoring: take element [0][0] of every trial in `nbest`
# (presumably the text of the top-ranked hypothesis - confirm against the
# pickle's layout), clean it, and score it against the references.
top_hypotheses = (nbest[idx][0][0] for idx in range(len(nbest)))
lm_preds = list(map(clean_string, top_hypotheses))
metrics = _cer_and_wer(lm_preds, ground_truth)
# Index 1 is presumably the WER (the helper's name suggests a (cer, wer) pair) - TODO confirm.
print(metrics[1])

0.07326627845420858


In [21]:
# Rescore the n-best hypotheses with the LLM. Which helper runs depends on
# whether a saved model-output file was configured above.
if model_outputs_path is None:

    # No saved model outputs: decode every trial of `nbest_nejm` on its own and
    # collect one result per trial (later passed to clean_string, so presumably
    # the chosen hypothesis text).
    best_hyp_all = [
        gpt2_lm_decode(
            llm,
            llm_tokenizer,
            trial_hyps,
            acoustic_scale,
            # NOTE(review): spelled differently from `lengthPenalty` in the
            # cer_with_gpt2_decoder call below. This presumably mirrors the
            # parameter name in llm_utils.gpt2_lm_decode - confirm against its
            # signature before "correcting" the spelling.
            lengthPenlaty=0,
            alpha=llm_weight,
            returnConfidence=False,
        )
        for trial_hyps in nbest_nejm
    ]

else:

    model_outputs = np.load(model_outputs_path, allow_pickle=True)

    # Re-encode each transcription as an array of character code points with a
    # trailing 0 appended.
    n_transcripts = len(model_outputs['transcriptions'])
    for idx in range(n_transcripts):
        code_points = [ord(ch) for ch in model_outputs['transcriptions'][idx]]
        code_points.append(0)
        model_outputs['transcriptions'][idx] = np.array(code_points)

    # Rescore nbest outputs with LLM
    llm_out = cer_with_gpt2_decoder(
        llm,
        llm_tokenizer,
        nbest[:],
        acoustic_scale,
        model_outputs,
        outputType="speech_sil",
        returnCI=True,
        lengthPenalty=0,
        alpha=llm_weight,
    )

NameError: name 'llm' is not defined

In [13]:
# Apply the same clean_string step used for the baseline predictions so both
# sets of hypotheses are scored on equal footing.
best_hyp_all_cleaned = list(map(clean_string, best_hyp_all))

In [14]:
# Score the LLM-rescored hypotheses against the reference transcripts.
# Note: this rebinds `metrics`, replacing the baseline result computed earlier.
metrics = _cer_and_wer(best_hyp_all_cleaned, ground_truth)
# Index 1 is presumably the WER (the helper's name suggests a (cer, wer) pair) - TODO confirm.
print(metrics[1])

0.06500794070937004


In [None]:
import pandas as pd

# One row per rescored hypothesis; `id` is the hypothesis' position in
# best_hyp_all_cleaned (0-based), `text` is the cleaned string.
df = (
    pd.DataFrame({'text': best_hyp_all_cleaned})
    .rename_axis('id')
    .reset_index()
)

# Written without the DataFrame index, so the file has exactly the columns id,text.
df.to_csv('/home/ebrahim/brainaudio/results/best_hyp_all_cleaned_output.csv', index=False)

: 