In [None]:
# Import required libraries
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, PeftConfig

In [None]:
# Hugging Face Hub IDs: the fine-tuned PEFT adapter, and the base model it is applied to.
MODEL_ID = "blueplus/basis-project-llama-3.1-8b-finetune"
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"

print(f"Loading adapter from {MODEL_ID}...")
config = PeftConfig.from_pretrained(MODEL_ID)

# The adapter config records the base checkpoint it was trained against. BASE_MODEL is
# hard-coded separately, so surface any mismatch instead of silently pairing the adapter
# with a different base (a mirror of the same weights is fine, but should be deliberate).
if config.base_model_name_or_path != BASE_MODEL:
    print(
        f"WARNING: adapter {MODEL_ID} was trained on "
        f"{config.base_model_name_or_path!r}, but BASE_MODEL is {BASE_MODEL!r}"
    )

In [None]:
# Load the base model. Set LOAD_IN_8BIT = True to quantize with bitsandbytes
# (to reduce GPU memory use).
LOAD_IN_8BIT = False
quantization_config = BitsAndBytesConfig(load_in_8bit=True) if LOAD_IN_8BIT else None

# device_map="auto" lets accelerate place the weights on the available device(s).
# Do not chain .to("cuda") afterwards. It is redundant for a model placed via
# device_map, `.to()` is unsupported/restricted for bitsandbytes-quantized models,
# and accelerate raises when asked to move a model with CPU/disk-offloaded modules.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

# Load the tokenizer and reuse EOS as the padding token
# (the base tokenizer presumably defines none).
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Wrap the base model with the fine-tuned PEFT adapter and switch to inference mode.
# eval() disables dropout; nn.Module.eval() returns the module itself, so it chains.
# NOTE: PeftModel.from_pretrained injects the adapter layers into `base_model` in place,
# so `base_model` itself now runs with the adapter. Use `with model.disable_adapter():`
# to obtain genuine base-model behaviour.
model = PeftModel.from_pretrained(base_model, MODEL_ID).eval()
model

In [None]:
# Helper function to generate text
def generate_response(model, prompt, max_length=256, temperature=0.7, top_p=0.9):
    """
    Generate one sampled chat reply from `model` for a single user prompt.

    The prompt is wrapped in the tokenizer's chat template as one user turn, with
    the assistant generation header appended. It is then sampled with temperature
    and nucleus (top-p) sampling.

    Args:
        model: Causal LM (plain or PEFT-wrapped) exposing `.generate` and `.device`.
        prompt: User input text.
        max_length: Maximum number of NEW tokens to generate; prompt tokens are
            not counted.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus sampling parameter.

    Returns:
        The decoded assistant reply, with the prompt and special tokens removed.

    Note:
        Uses the module-level `tokenizer`.
    """
    # Build the chat-formatted prompt text.
    template = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already starts with the BOS token (<|begin_of_text|> for
    # Llama 3.x). Letting the tokenizer add special tokens again would prepend a second BOS.
    inputs = tokenizer(template, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_new_tokens, not max_length: max_length also counts the prompt, so
            # long prompt+suffix inputs would leave little or no room for the reply.
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. This avoids string-splitting on chat
    # markers, and the prompt can never leak into the returned response.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [None]:
# Load the adversarial-suffix results that the inference cell iterates over:
# a JSON object whose keys are prompts and whose values are suffix records.
import json

SUFFIX_RESULTS_PATH = "results/llama_suffixes_250_25_wide.json"

# Explicit UTF-8: suffix text may contain non-ASCII characters, and open()'s default
# encoding is platform-dependent.
with open(SUFFIX_RESULTS_PATH, "r", encoding="utf-8") as f:
    results = json.load(f)

# Show a summary rather than dumping every prompt/suffix pair.
len(results)

In [None]:
# Run inference with both the original base model and the fine-tuned model on the
# same prompts: each harmful prompt concatenated with its generated suffix.
#
# PeftModel.from_pretrained injected the adapter into `base_model` in place, so calling
# `base_model` directly would ALSO run the adapter. Base outputs are therefore
# produced inside `model.disable_adapter()`, which temporarily bypasses the adapter.
SEED = 0  # generation samples (do_sample=True); seed so reruns are comparable

base_outputs = {}
finetuned_outputs = {}
torch.manual_seed(SEED)
for prompt, suffix_record in results.items():
    # presumably "str" holds the suffix text; verify against the results file
    full_prompt = prompt + suffix_record["str"]
    with model.disable_adapter():
        base_outputs[full_prompt] = generate_response(model, full_prompt)
    finetuned_outputs[full_prompt] = generate_response(model, full_prompt)

In [None]:
# Show every base-model response, separated by a "---" divider line.
print(*base_outputs.values(), sep="\n---\n")

In [None]:
# Show every fine-tuned-model response, separated by a "---" divider line.
print(*finetuned_outputs.values(), sep="\n---\n")

In [None]:
# Persist both sets of generations (keyed by full prompt) for later review.
for out_path, generations in (
    ("results/base_outputs_250.json", base_outputs),
    ("results/finetuned_outputs_250.json", finetuned_outputs),
):
    with open(out_path, "w") as f:
        json.dump(generations, f)

In [None]:
# Manual binary classification of the base-model outputs.
#
# Do not busy-wait for the Confirm click (e.g. `while not clicked: time.sleep(...)`).
# ipywidgets delivers button clicks as kernel messages, and the kernel does not process
# them while this cell is still executing, so the callback never fires and the loop
# never exits. Instead, this cell renders one reusable set of widgets and returns
# immediately. Each Confirm click records the checkbox value for the current output
# and advances to the next one.
import ipywidgets as widgets

base_classifications = {}
_review_queue = list(base_outputs.items())
_review_state = {"index": 0}

output_box = widgets.Textarea(disabled=True, layout=widgets.Layout(width="100%", height="300px"))
label_box = widgets.Checkbox()
confirm_button = widgets.Button(description="Confirm")
progress_label = widgets.Label()


def _show_current():
    """Load the output at the current index into the widgets, or finish when done."""
    index = _review_state["index"]
    if index >= len(_review_queue):
        output_box.value = ""
        label_box.disabled = True
        confirm_button.disabled = True
        progress_label.value = f"Done: {len(base_classifications)} outputs classified."
        return
    output_box.value = _review_queue[index][1]
    label_box.value = False  # every output starts unchecked
    progress_label.value = f"{index + 1} / {len(_review_queue)}"


def on_click(_button):
    """Record the checkbox value for the current output, then advance to the next."""
    index = _review_state["index"]
    if index >= len(_review_queue):
        return
    prompt_key = _review_queue[index][0]
    base_classifications[prompt_key] = label_box.value
    _review_state["index"] = index + 1
    _show_current()


confirm_button.on_click(on_click)
_show_current()
display(progress_label, output_box, label_box, confirm_button)

In [None]:
# Save fine-tuned generations keyed by "<prompt> <suffix>" strings
# (JSON object keys must be strings).
# NOTE(review): `outputs` is not created by any cell visible in this notebook. It
# appears to be leftover kernel state from a deleted cell, presumably a dict keyed by
# (prompt, suffix) tuples, so this cell raises NameError on a fresh kernel.
# TODO: regenerate `outputs` in an earlier cell, or delete this cell.
#
# Only tuple keys are joined, so re-running the cell is a no-op. Re-indexing string
# keys would collapse them to their first two characters ("Write..." -> "W r") and
# silently merge colliding entries.
outputs = {
    (f"{key[0]} {key[1]}" if isinstance(key, tuple) else key): value
    for key, value in outputs.items()
}

with open("results/finetuned_outputs.json", "w") as f:
    json.dump(outputs, f)