In [15]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, BertForSequenceClassification, BertTokenizer, BertModel

from datasets import load_dataset
import evaluate
from transformers import TrainingArguments, Trainer

os.environ["WANDB_PROJECT"] = "<my-amazing-project>"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints

# Init Generator and Detector

In [2]:
#GEN_PATH = "microsoft/phi-2"
#GEN_PATH = "openai-community/gpt2"
GEN_PATH = "Qwen/Qwen1.5-0.5B-Chat"
BERT_PATH = "bert-base-uncased"
device = "cuda" if torch.cuda.is_available() else "cpu"


class GPTGenerator(nn.Module):
  def __init__(self, gpt_model, tokenizer):
    super().__init__()

    # gpt should already be trained
    self.gpt = gpt_model
    self.tokenizer = tokenizer

  def forward(self, text, max_length=512, temperature=1, top_k=50, top_p=0.9, repetition_penalty=1):
    # tokenize text using the tokenizer
    input_ids = self.tokenizer.encode(text, return_tensors="pt")

    # generate text using the gpt model
    output_ids = self.gpt.generate(input_ids, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p)

    # optional, remove input_ids from output_ids
    #output_ids = [output_id[len(input_ids):] for input_id, output_id in zip(input_ids, output_ids)]

    # decode the generated text
    decoded_output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    return decoded_output

class BertClassifier(nn.Module):
  def __init__(self, bert_model, tokenizer, num_classes):
    super().__init__()

    self.tokenizer = tokenizer

    # bert should already be trained
    self.bert = bert_model

    # set num_classes
    self.num_classes = num_classes

  def forward(self, text):

    # tokenize text using the tokenizer
    output = self.tokenizer(text, return_tensors="pt")
    input_ids = output["input_ids"]
    logits = self.bert(input_ids)["logits"]

    # apply sigmoid to get probabilities of each class
    output = torch.sigmoid(logits)
    return output
        

In [3]:
torch.set_default_device("cuda")

gen_model = AutoModelForCausalLM.from_pretrained(GEN_PATH, torch_dtype="auto", trust_remote_code=True).to(device)
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_PATH, trust_remote_code=True)
generator = GPTGenerator(gen_model, gen_tokenizer)

text_input = '''def print_prime(n):
   """
   Print all primes between 1 and n
   """'''

output = generator(text_input)
print(output)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


def print_prime(n):
   """
   Print all primes between 1 and n
   """    
   for i in range(2, int(n**0.5)+1):
       if n % i == 0:
           break    
   return [i] + print_prime(i+1)

print_prime(2)
print_prime(3)


In [4]:
detector_model = BertForSequenceClassification.from_pretrained(BERT_PATH).to(device)
bert_tokenizer = BertTokenizer.from_pretrained(BERT_PATH)
detector = BertClassifier(detector_model, bert_tokenizer, 2)

text = "def print_prime(n):\n   \"\"\"\n   Print all primes between 1 and n\n   \"\"\""

logits = detector(text)
fake = logits.argmax().item()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Load dataset of instructions and output with gen

In [5]:
dataset_path = "databricks/databricks-dolly-15k"
dataset = load_dataset(dataset_path)

dataset

DatasetDict({
    train: Dataset({
        features: ['instruction', 'context', 'response', 'category'],
        num_rows: 15011
    })
})

In [6]:
# test output with first instruction
text = dataset["train"][0]
text


{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [7]:
text_instruction = f"Context: {text["context"]} \n Question: {text["instruction"]}"
output = generator(text_instruction)
print("Question: ", text_instruction)
#print()
print("Generated answer: ", output)
#print()
print("Real human answer: ", text["response"])


Question:  Context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. 
 Question: When did Virgin Australia start operating?
Generated answer:  Context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in Septem

In [8]:
output = generator("What is the capital of France?")
output

'What is the capital of France?'

In [9]:
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"}
]
text = gen_tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

In [10]:
output = generator(text)
print(output)

system
You are a helpful assistant.
user
What is the capital of France?
assistant
Paris


In [11]:
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": f"{text_instruction}"},
]
text = gen_tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

output = generator(text)
print(output)

system
You are a helpful assistant.
user
Context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney. 
 Question: When did Virgin Australia start operating?
assistant
On 31 August 2000, Virgin Australia was started as Virgin Blue.


# Training Detector

In [19]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_ratio=0.1,
    weight_decay=0.01,
    learning_rate=1e-3,
    logging_steps=5,
    logging_dir="./logs",
    report_to="wandb"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=human_fake_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

NameError: name 'model' is not defined

In [None]:
trainer.train()
wandb.finish()

# DPO training BERT