In [1]:
from peft import PeftModel

import torch

import os
import sys
sys.path.append(os.getcwd()+"/../..")

from src import paths

from src import paths
from src.utils import load_ms_data, prepare_ms_data, get_DataLoader, load_model_and_tokenizer, perform_inference, check_gpu_memory

from tqdm import tqdm

import argparse

In [None]:
# Run test-set inference for each fine-tuned PEFT adapter on top of a single
# 4-bit quantized llama2 base model and store the results per adapter.

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Adapter checkpoints to evaluate. The loop below parses two fields out of
# each name (split on "_"): the last field is the truncation size and the
# third-from-last field is the PEFT type.
PEFT_MODEL_NAMES = ["ms-diag_llama2_4bit_lora_augmented_256", "ms-diag_llama2_4bit_lora_augmented_512"]

# Check GPU Memory
check_gpu_memory()

# Load Model and Tokenizer (loaded once and shared by every adapter below)
model, tokenizer = load_model_and_tokenizer(model_name="llama2", quantization="4bit", num_labels=3)

print("Loaded Model and Tokenizer")

# Create the output directory up front so a missing folder is detected
# before the expensive inference pass rather than at save time.
# Assumes paths.RESULTS_PATH is a pathlib.Path (it is combined with "/").
results_dir = paths.RESULTS_PATH/"ms-diag"
results_dir.mkdir(parents=True, exist_ok=True)

# Inference
for peft_model_name in PEFT_MODEL_NAMES:

    print(f"Starting Inference for PEFT Model: {peft_model_name}")

    # Load PEFT Model
    # NOTE(review): every adapter is attached to the same `model` object.
    # PEFT may modify the base model in place when injecting an adapter --
    # confirm that loading a second adapter onto it behaves as intended.
    peft_model = PeftModel.from_pretrained(model, paths.MODEL_PATH/peft_model_name).to(device)

    # Load Data in format matching the PEFT Model configuration
    df = load_ms_data(data="original")

    # Prepare Data: recover the configuration encoded in the checkpoint name
    name_fields = peft_model_name.split("_")
    truncation_size = int(name_fields[-1])
    peft_type = name_fields[-3]

    # Prompt-style adapters carry virtual tokens; their count is forwarded
    # to the data preparation step. Other adapter types use 0.
    if peft_type in ["prompt", "ptune", "prefix"]:
        is_prompt_tuning = True
        num_virtual_tokens = peft_model.peft_config["default"].num_virtual_tokens
    else:
        is_prompt_tuning = False
        num_virtual_tokens = 0

    encoded_dataset = prepare_ms_data(df, tokenizer, is_prompt_tuning = is_prompt_tuning, num_virtual_tokens = num_virtual_tokens, truncation_size=truncation_size)

    # Get DataLoaders (no shuffling, so the sampling order is not randomized)
    dataloader = get_DataLoader(encoded_dataset["test"], tokenizer, batch_size=2, shuffle=False)

    # Perform Inference
    inference_results = perform_inference(peft_model, dataloader, device)

    # Save Inference Results (one file per adapter, named after the adapter)
    torch.save(inference_results, results_dir/peft_model_name)