In [7]:
import os
import sys
sys.path.append("/home/gridsan/dbeneto/TFG/BCI")
from functools import partial
from tqdm import tqdm
from g2p_en import G2p

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

from models.phoneme_llm import PhonemeLLM
from utils.config_utils import DictConfig, update_config
from utils.data_utils import PhonemesFinetuneDataset, ft_pad_collate_fn, prepare_phonemes_data
from utils.eval_utils import word_error_count

In [8]:
# Prompt template string. The other cells visible in this notebook read
# `config.prompt` rather than this variable.
prompt = "phonemes: %% sentence:"

# Root directory that holds the checkpoint runs.
# A second root, "/n/home07/djimenezbeneto/lab/BCI/checkpoints", used to be
# assigned here first and was immediately overwritten; it is kept only as a
# reference (presumably the location on another machine -- confirm).
checkpoint_dir = "/home/gridsan/dbeneto/TFG/BCI/checkpoints"

# Run identifier (sub-directory of the checkpoint root) and the training step
# whose snapshot is evaluated.
savestring = "new_ft/eval_64-accum_4-rank_1-lr_1.e-4-gauss_0.0-spikes_0.6_2_0.8-norm_identity_1"
STEP = 36249

# From here on `checkpoint_dir` points at the run directory, and `load_dir`
# at the snapshot "STEP<step>" inside it.
checkpoint_dir = os.path.join(checkpoint_dir, savestring)
load_dir = os.path.join(checkpoint_dir, f"STEP{STEP}")

# The configuration stored next to the snapshot is wrapped in the project's
# DictConfig so it can be read with attribute access below.
config = DictConfig(torch.load(os.path.join(load_dir, "config.pth")))

# config = update_config(config, "iaifi_dirs.yaml")
# config = update_config(config, "configs/sc_dirs.yaml")
llama_dir = config.dirs.llm_dir
tokenizer_dir = config.dirs.tokenizer_dir

In [9]:
# Location of the adapter configuration inside the snapshot directory; the
# adapter is only loaded further down when this file is present.
adapter_file = os.path.join(load_dir, "adapter_config.json")

# Base causal language model, then the project wrapper built from it and the
# snapshot directory.
# NOTE(review): the output recorded for this cell ends with
# FileNotFoundError for the relative path 'configs/phoneme_coupler.yaml'.
# Presumably something executed here resolves that path against the current
# working directory -- confirm where the notebook has to be started from.
llm = AutoModelForCausalLM.from_pretrained(config.dirs.llm_dir)
model = PhonemeLLM(llm, load_dir)

if os.path.isfile(adapter_file):
    # Inference only: the adapter is loaded as non-trainable.
    model.load_adapter(load_dir, is_trainable=False)

# Last expression of the cell: moves the model to the GPU and displays it.
model.to("cuda")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


FileNotFoundError: [Errno 2] No such file or directory: 'configs/phoneme_coupler.yaml'

In [None]:
# Grapheme-to-phoneme converter, passed to the data preparation step.
g2p = G2p()

# Tokenizer that pads on the left and does not add BOS/EOS tokens by itself.
tokenizer = AutoTokenizer.from_pretrained(
    config.dirs.tokenizer_dir,
    padding_side='left',
    add_bos_token=False,
    add_eos_token=False,
)

# The EOS token id doubles as the padding id.
pad_id = tokenizer.eos_token_id

In [None]:
# Override the checkpointed setting: only the first 20 test examples are used.
config["trainer"]["test_len"] = 20
data = torch.load(os.path.join(config.dirs.data_dir, config.data_file))


def _first_items(split, limit):
    """Return a copy of `split` whose values are cut to their first `limit` items.

    A `limit` of -1 keeps every value unchanged.
    """
    if limit == -1:
        return {name: values for name, values in split.items()}
    return {name: values[:limit] for name, values in split.items()}


# Truncate each split, then convert it with the project helper.
train_data = _first_items(data["train"], config.trainer.train_len)
train_data = prepare_phonemes_data(train_data, tokenizer, g2p, config.prompt)
test_data = _first_items(data["test"], config.trainer.test_len)
test_data = prepare_phonemes_data(test_data, tokenizer, g2p, config.prompt)

train_dataset = PhonemesFinetuneDataset(train_data)
test_dataset = PhonemesFinetuneDataset(test_data)

# Both loaders yield one example per batch and share the same collate setup.
# NOTE(review): the training loader also uses the "test" collate mode --
# confirm that this is intended.
collate = partial(ft_pad_collate_fn, config.noise, config.mask, pad_id, "test")
loader_options = dict(collate_fn=collate, batch_size=1, pin_memory=True)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

# Iterators for pulling single batches by hand.
train_iter = iter(train_dataloader)
test_iter = iter(test_dataloader)


In [None]:
 # Bare expression as the only statement of the cell: displays the repr of
 # the loaded model for inspection.
 model

In [None]:
# Number of beams; the same number of sequences is returned per input.
beams = 1

# Keyword arguments passed to `model.predict` in the evaluation loop; the
# names follow the Hugging Face `generate` / `GenerationConfig` parameters.
gen_config = {
    "max_new_tokens": 20,
    "do_sample": False, "temperature": 1.0,  #"top_p": 1.0,
    "num_beams": beams,
    # 1.0 is the neutral value ("no penalty"). The previous 0.0 is outside
    # the strictly positive range accepted by transformers.
    "repetition_penalty": 1.0,
    "length_penalty": 0.0, "no_repeat_ngram_size": None,
    "renormalize_logits": True,
    "low_memory": True,
    "num_return_sequences": beams, "output_scores": True, "return_dict_in_generate": True,
    "pad_token_id": pad_id,
}

# Diverse (group) beam search needs more than one group, so the group
# settings are only added when several beams are requested. With a single
# beam they are left at their defaults (one group, no diversity penalty).
if beams > 1:
    gen_config["num_beam_groups"] = beams
    gen_config["diversity_penalty"] = 1.2

In [None]:
from time import perf_counter

# Per-example results, saved together at the end of the cell.
all_pairs = []      # (normalized score, decoded text) tuples, best first
all_sentences = []  # reference sentences as yielded by the dataloader
all_errors = []     # first value returned by word_error_count
all_words = []      # second value returned by word_error_count
all_scores = []     # normalized scores in generation order

# Accumulated wall-clock seconds: generation (time_b) and everything after
# generation (time_c).
time_b = 0.
time_c = 0.

# tqdm wraps the dataloader itself so the progress bar knows the total.
for model_inputs, prompt_inputs, sentence, true_ph, pred_ph in tqdm(test_dataloader):
    t_start = perf_counter()

    # Move tensors (and lists of tensors) to the GPU.
    prompt_inputs = {
        k: v.to("cuda") if isinstance(v, torch.Tensor) else [sub_v.to("cuda") for sub_v in v]
        for k, v in prompt_inputs.items()
    }
    # No gradients are needed for evaluation.
    with torch.no_grad():
        preds = model.predict(**prompt_inputs, **gen_config, synced_gpus=None)

    t_generated = perf_counter()
    time_b += t_generated - t_start

    dec_preds = [
        tokenizer.decode(seq.detach().cpu().squeeze(), skip_special_tokens=True)
        for seq in preds.sequences
    ]
    print(dec_preds)

    # Per-sequence scores are only part of the output for beam-based
    # decoding; fall back to neutral scores when they are missing.
    scores = getattr(preds, "sequences_scores", None)
    if scores is None:
        scores = torch.zeros(len(dec_preds))

    # Standardize the scores across candidates. With a single candidate or
    # identical scores the standard deviation is undefined or zero, which
    # would turn every score into NaN, so only centering is applied then.
    centered = scores - scores.mean()
    if scores.numel() > 1 and scores.std() > 0:
        scores = centered / scores.std()
    else:
        scores = centered

    # Candidates ordered from highest to lowest score.
    pairs = sorted(
        [(score.item(), text) for score, text in zip(scores, dec_preds)],
        key=lambda pair: -pair[0],
    )
    # print(sentence)
    # for pair in pairs:
    #     print(pair[0], pair[1])

    # Only the best-scoring candidate is compared against the reference.
    errors, words = word_error_count(pairs[0][1], sentence)
    all_errors.append(errors)
    all_words.append(words)
    all_pairs.append(pairs)
    all_sentences.append(sentence)
    all_scores.append(scores.tolist())

    t_done = perf_counter()
    time_c += t_done - t_generated


print(time_b, time_c)
# Each `sentence` is unpacked as a one-element list (batch_size is 1 in the
# dataloader cell).
torch.save(
    {
        "errors": all_errors,
        "words": all_words,
        "pairs": all_pairs,
        "sentences": [s for [s] in all_sentences],
        "scores": all_scores,
    },
    "data.pth",
)