In [1]:
import torch
import os
import pandas as pd
import numpy as np
from peft import PeftConfig, PeftModel
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import LlamaTokenizer, LlamaForSequenceClassification, DataCollatorWithPadding
from tqdm import tqdm
import pickle
from utils.eval_utils import cls_metrics_logits

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
gpu_device_num=1
torch.cuda.set_device(gpu_device_num)
torch.cuda.current_device()
device = torch.device(f"cuda:{gpu_device_num}" if torch.cuda.is_available() else "cpu")

In [3]:
checkpoint_id = "experiments/7b-512-4-2e-05-right-April-14-14-52/checkpoint-177144"

In [4]:
config = PeftConfig.from_pretrained(checkpoint_id)

In [5]:
inference_model = LlamaForSequenceClassification.from_pretrained(config.base_model_name_or_path,
                                                       num_labels=738,
                                                       load_in_8bit=True,
                                                       torch_dtype=torch.float16,
                                                       cache_dir="/data/mn27889/.cache/huggingface")

Loading checkpoint shards: 100%|██████████| 33/33 [00:16<00:00,  1.98it/s]
Some weights of the model checkpoint at baffo32/decapoda-research-llama-7b-hf were not used when initializing LlamaForSequenceClassification: ['lm_head.weight']
- This IS expected if you are initializing LlamaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at baffo32/decapoda-research-llama-7b-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predic

In [6]:
inference_model = PeftModel.from_pretrained(inference_model, checkpoint_id)

In [7]:
tokenizer = LlamaTokenizer.from_pretrained(config.base_model_name_or_path,
                                           model_max_length=512,
                                           cache_dir="/data/mn27889/.cache/huggingface")
tokenizer.pad_token_id = 0

In [8]:
train_data_path = "data/new_train.csv"
test_data_path =  "data/new_test.csv"

In [9]:
train_data = load_dataset("csv", data_files=train_data_path, split=f'train[:{100}%]')
test_data = load_dataset("csv", data_files=test_data_path, split=f'train[:{100}%]')

Found cached dataset csv (/data/mn27889/.cache/huggingface/datasets/csv/default-fb99937b6bf61de5/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
Found cached dataset csv (/data/mn27889/.cache/huggingface/datasets/csv/default-cac9e052575a0588/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


In [10]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

In [11]:
tokenized_test = test_data.map(preprocess_function, batched=True, remove_columns=['text']).rename_column("label", "labels")

Loading cached processed dataset at /data/mn27889/.cache/huggingface/datasets/csv/default-cac9e052575a0588/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-622d605646289699.arrow


In [12]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [13]:
test_dataloder = DataLoader(tokenized_test, shuffle=True, collate_fn=data_collator, batch_size=32)

In [14]:
inference_model = inference_model.to(device)
inference_model.eval()

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): LlamaForSequenceClassification(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): Linear8bitLt(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear8bitLt(
                in_features=

In [15]:
predictions_all = np.array([], dtype=int)
labels_all = np.array([], dtype=int)

with torch.no_grad():
    for step, batch in enumerate(tqdm(test_dataloder)):
        batch.to(device)
        outputs = inference_model(**batch)
    
        predictions = outputs.logits
        labels = batch['labels']
        predictions_all = np.append(predictions_all, predictions.detach().cpu().numpy())
        labels_all = np.append(labels_all, labels.detach().cpu().numpy())
        del batch

100%|██████████| 821/821 [52:17<00:00,  3.82s/it]


In [16]:
with open("predictions_all_third.pkl", "wb") as f:
    pickle.dump(predictions_all, f)

In [17]:
with open("labels_all_third.pkl", "wb") as f:
    pickle.dump(labels_all, f)

In [18]:
with open("predictions_all_third.pkl", "rb") as f:
    predictions_all = pickle.load(f)

In [19]:
with open("labels_all_third.pkl", "rb") as f:
    labels_all = pickle.load(f)

In [20]:
y_pred = predictions_all.reshape(-1, 738)
y_true = labels_all

In [21]:
cls_metrics_logits(y_pred, y_true, 738)

{'microF1': 0.48754929792139007,
 'macroF1': 0.301450220386843,
 'microAUC': 0.9813528007058248,
 'macroAUC': 0.9582404125407084,
 'labels': 723,
 'count': 26244,
 'acc10': 0.9014631915866483,
 'acc5': 0.8321902149062643,
 'acc': 0.4874256973022405}