In [1]:
import torch

from datasets import Dataset
from data_processing import util
from model_utils import evaluate
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
DATA_TYPE = "mbpt_0_top"
MAX_SEQ_LENGTH = 8192
CACHE_DIR = "/nlp/scr/neigbe/.cache"
model_idx = 1
MODEL_NAME = ["meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct"][model_idx]
model = ["llama3-8b-instruct", "llama3-70b-instruct"][model_idx]
MODEL_PATH = util.get_most_recent_model_path(model, DATA_TYPE)

In [3]:
_, _, test_df = util.get_data_splits(DATA_TYPE, .975, .5)

In [4]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    use_cache=False,
    cache_dir=CACHE_DIR,
    attn_implementation = "flash_attention_2",
    device_map="auto"
)

tkr = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, model_max_length=MAX_SEQ_LENGTH)
tkr.pad_token_id = tkr.eos_token_id

model.resize_token_embeddings(len(tkr))
model.config.pad_token_id = tkr.pad_token_id

model = PeftModel.from_pretrained(model, MODEL_PATH, device_map="auto")



Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
from tqdm.notebook import tqdm_notebook as tqdm
preds = []
labels = []

import warnings
warnings.filterwarnings('ignore')

for row in tqdm(list(test_df.iloc)):
    input_ids = tkr.apply_chat_template(util.row_to_msg(row), add_generation_prompt=True, return_tensors="pt").to(model.device)
    terminators = [tkr.eos_token_id, tkr.convert_tokens_to_ids("<|eot_id|>")]
    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]
    preds.append(tkr.decode(response, skip_special_tokens=True))
    labels.append(row.label)

  0%|          | 0/1892 [00:00<?, ?it/s]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end gene

In [6]:
filt_preds = []
filt_labels = []

for pdt, lb in zip(preds, labels):
    if pdt in ["E", "I"]:
        filt_preds.append(pdt)
        filt_labels.append(lb)

In [10]:
from sklearn.preprocessing import LabelEncoder

label_enc = LabelEncoder()

final_labels = label_enc.fit_transform(filt_labels)
final_preds = label_enc.transform(filt_preds)

## results!

overall metrics

In [11]:
evaluate.get_overall_metrics(final_preds, final_labels)

{'f1': 0.6445306963197337,
 'recall': 0.6497917984218842,
 'precision': 0.6565108237411776}

per class metrics

In [12]:
evaluate.get_class_metrics(final_preds, final_labels, DATA_TYPE).style.hide(axis="index")

label,f1,recall,precision
E,0.611046,0.542797,0.698925
I,0.678016,0.756786,0.614097
