### Data Load

In [None]:
import json
# Rejection responses from LLaMA for positive examples.
# Only records whose 'safety' field equals 'safe' contribute their 'response'.
with open('data/llama_predictions_LAT.json', 'r') as file:
    preds_llama = json.load(file)
safe_llama = [d['response'] for d in preds_llama if d['safety'] == 'safe']

In [None]:
# Rejection responses merged with COVID fake news dataset for positive examples.
# Each record in the file is reduced to its 'text' field.
with open('data/merged_COVID_fake.json', 'r') as file:
    safe_llama_merged = [d['text'] for d in json.load(file)]

In [None]:
import random

# Wikipedia summary dataset for negative examples
with open('data/wiki_summary_sampled.json', 'r') as file:
    wiki_summary_sampled = json.load(file)

# Fixed seed so the same subset is drawn on every run; the subset size
# matches the number of LLaMA rejection responses in `safe_llama`.
random.seed(123)
wiki_summary_llama = random.sample(wiki_summary_sampled, k=len(safe_llama))

In [None]:
# COVID fake news dataset (not merged) for negative examples.
# Keep only the leading entries, as many as there are merged positive examples.
with open('data/covid_fakes.json', 'r') as file:
    fakes_515 = json.load(file)[:len(safe_llama_merged)]

### Classifier Train

In [None]:
from transformers import AutoTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
import torch
from torch.utils.data import Dataset
import random

class TextDataset(Dataset):
    """Map-style dataset pairing pre-tokenized texts with their labels.

    All texts are tokenized once in the constructor; indexing converts the
    stored per-sample values to tensors on demand.

    Args:
        texts: Sequence of input strings passed to ``tokenizer`` in one call.
        labels: Sequence of labels, indexable in parallel with ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(texts, truncation=True, padding=True, max_length=max_len)``
            whose result exposes ``.items()`` yielding ``(key, per-sample values)``.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.labels = labels
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_len,
        )

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict of tensors for sample ``idx``: every encoding field plus 'labels'."""
        item = {}
        for key, val in self.encodings.items():
            item[key] = torch.tensor(val[idx])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

In [None]:
def set_seed(seed_value=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed_value: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed_value)


set_seed(122)

In [None]:
# Positive examples (label 1) come first, negative examples (label 0) after,
# so the label list lines up index-for-index with `texts`.
texts = safe_llama + safe_llama_merged + wiki_summary_llama + fakes_515
labels = (
    [1] * (len(safe_llama) + len(safe_llama_merged))
    + [0] * (len(wiki_summary_llama) + len(fakes_515))
)
(len(texts), len(labels))

In [None]:
# First split: hold out 20% of all samples as the test set.
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)
# Second split: 25% of the remaining samples become the validation set.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, test_size=0.25, random_state=42
)

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Build the three datasets in train / val / test order with the shared tokenizer.
train_dataset, val_dataset, test_dataset = (
    TextDataset(split_texts, split_labels, tokenizer)
    for split_texts, split_labels in (
        (train_texts, train_labels),
        (val_texts, val_labels),
        (test_texts, test_labels),
    )
)

In [None]:
# Pretrained BERT with a 2-class classification head.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

training_args = TrainingArguments(
    # Output locations
    output_dir='./results_adv',
    logging_dir='./logs',
    report_to="none",
    # Optimisation
    num_train_epochs=5,
    learning_rate=1e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=512,
    # Evaluate and checkpoint on the same 200-step cadence
    evaluation_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    logging_steps=10,
    # Restore the checkpoint with the best "f1" once training finishes
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

def compute_metrics(p):
    """Compute accuracy and binary precision/recall/F1 for an evaluation pass.

    Args:
        p: Object that unpacks into ``(predictions, labels)``; the class with
            the highest score along axis 1 of ``predictions`` is taken as the
            predicted label.

    Returns:
        Dict with the keys "accuracy", "f1", "precision" and "recall".
    """
    scores, gold = p
    predicted = np.argmax(scores, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        gold, predicted, average='binary'
    )
    return {
        "accuracy": accuracy_score(gold, predicted),
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }

# Trainer wiring: optimise on the training split, evaluate periodically on
# the validation split, and report the metrics from `compute_metrics`.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
# Fine-tune the classifier
trainer.train()

# Score the held-out test split, which was not used for training or validation
test_results = trainer.evaluate(eval_dataset=test_dataset)
print("Test Results:", test_results)

In [None]:
# Persist the trained model: the whole module object is serialized,
# not just its state_dict.
torch.save(obj=model, f='llama_classifier_adv')

In [None]:
# Load saved model.
# The file was written with torch.save(model, ...), i.e. it contains the whole
# pickled module rather than a state_dict. Starting with PyTorch 2.6,
# torch.load defaults to weights_only=True and refuses to unpickle such
# objects, so full unpickling has to be requested explicitly.
# NOTE(security): full unpickling can execute arbitrary code; only load
# checkpoint files that you created yourself.
model = torch.load('llama_classifier_adv', weights_only=False)

# Rebuild the Trainer around the reloaded model so that later evaluation
# calls use it with the same arguments, datasets and metrics as before.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

### Inspection

In [None]:
# Adjust the following path so that it points to the data to inspect
inspected_data_path = "[YOUR DATA PATH]"

# Each record in the file contributes its 'text' field.
with open(inspected_data_path, "r") as file:
    rc_llama = [d['text'] for d in json.load(file)]

# Every inspected sample is labelled 1, so the reported recall is the share
# of samples the classifier assigns to class 1.
rc_dataset = TextDataset(rc_llama, [1] * len(rc_llama), tokenizer)

test_results = trainer.evaluate(eval_dataset=rc_dataset)
print(f"Test Results (Recall): {100*test_results['eval_recall']}")