In [None]:
from transformers import BitsAndBytesConfig, TrainingArguments, AutoTokenizer, Trainer, AutoModelForCausalLM
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import huggingface_hub
from datasets import Dataset
import numpy as np

In [None]:
# bitsandbytes 4-bit quantization settings, read by the BitsAndBytesConfig cell.
bnb_4bit_quant_type = "nf4"         # 4-bit quantization data type
bnb_4bit_compute_dtype = "float16"  # compute dtype, given by its string name
use_4bit = True                     # load weights in 4-bit
use_nested_quant = False            # double (nested) quantization disabled

In [None]:
# Quantization config assembled from the flags defined in the previous cell.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_use_double_quant=use_nested_quant,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
)


In [None]:
# Load the base model and tokenizer from a local checkout.
# TODO: make this path configurable instead of an absolute user directory.
model_path = "/u/jas644/gemma-2-2b-it"

# Tokenizers are not placed on devices, so no device_map here.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Pass the BitsAndBytesConfig built above. Without `quantization_config` that
# config was constructed but never used, and the model loaded unquantized
# despite `use_4bit = True`.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map='auto',
)

# Only fall back to EOS when the tokenizer has no pad token. Otherwise pad and
# EOS would share an id, and any collator that masks padding in the labels
# would also mask EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def extract_reply(api_response):
    """Return the message text of the first choice in a decoded API result record.

    Reads ``api_response['response']['body']['choices'][0]['message']['content']``.
    A missing key raises KeyError, and an empty ``choices`` list raises IndexError.
    """
    body = api_response['response']['body']
    first_choice = body['choices'][0]
    return first_choice['message']['content']
def restructure(r):
    """Render a JSON entity-change record as the markdown used as a training target.

    Args:
        r: JSON text. It must decode to an object with ``changed_entities``
           (a list of objects holding ``original_entity``, ``new_entity`` and
           ``explanation``) and ``modified_prompt``.

    Returns:
        A ``# Changes:`` section with one
        ``## **original** changed to **new**`` heading and its explanation per
        change, followed by a ``# New Prompt:`` section holding the modified
        prompt.
    """
    record = json.loads(r)
    pieces = ["\n# Changes:\n\n"]
    for change in record['changed_entities']:
        pieces.append(
            "## **" + change['original_entity']
            + "** changed to **" + change['new_entity'] + "**\n\n"
            + change['explanation'] + "\n\n"
        )
    pieces.append("# New Prompt:\n\n" + record['modified_prompt'])
    return "".join(pieces)

In [None]:
# Instruction text placed at the start of every training example's user turn.
# The exact wording is part of the training data.
# NOTE(review): contains typos ("unnecesary"; "effect" should be "affect").
# They are left verbatim because any prompt used at inference time presumably
# has to match this text. Fix both together.
system_prompt = """You are ReplaceGPT, an entity replacement model. Your task is to take an input, and output a transformed response that replaces all of the entities specified.

The goal is to minimize the impact of changing the entities. The user should not be able to tell this transformation happened.

The user will provide input of the original text, and a list of the entities that must be changed.

You will output the modified text.

Do not make any unnecesary changes that effect the semantic quality of the text, the meaning should stay the same.

Only the entities themselves should change, not the meaning."""

In [None]:
import regex as re
def restructure_again(r):
    """Turn a ``restructure``-formatted reply back into a list of replacements.

    Scans the text between the ``# Changes`` and ``# New Prompt`` headers for
    ``**bold**`` spans. The spans are paired in order as (original, new), which
    is the order ``restructure`` writes them in:
    ``## **original** changed to **new**``. Each pair whose two entities differ
    becomes one line `` - new -> original\\n``.

    NOTE(review): the arrow runs from the new entity to the original one, the
    reverse of the "changed to" heading. This matches the previous
    implementation; confirm it is the direction the training task intends.

    Args:
        r: Markdown text as produced by ``restructure``.

    Returns:
        The concatenated replacement lines. Returns "" when the text has no
        ``# Changes`` header or the bold spans cannot be paired up.
    """
    try:
        changes_section = r.split('# Changes')[1].split('# New Prompt')[0]
    except IndexError:
        # No '# Changes' header, so there is nothing to extract.
        return ""

    # An empty entity renders as '****'. The regex needs at least one character
    # between markers, so splitting it into '** **' lets it be captured as ' '.
    bold_spans = [
        match.group()[2:-2]
        for match in re.finditer(r'\*\*.+?\*\*', changes_section.replace('****', '** **'))
    ]
    if len(bold_spans) % 2 != 0:
        # An unpaired span means the (original, new) alignment is unknown.
        # Emit nothing rather than risk mismatched pairs; print for diagnosis.
        print(r)
        return ""

    lines = []
    for original, new in zip(bold_spans[0::2], bold_spans[1::2]):
        if original == new:
            continue  # not actually a replacement
        lines.append(" - " + new + " -> " + original + "\n")
    return "".join(lines)

In [None]:
# Loading dataset
#
# The three JSONL files below are not shipped with the repository. They can be
# regenerated from replacement_data.json in the GitHub repo. Line i of each
# file is assumed to describe the same example.

import json

TRAIN_FRACTION = 0.75  # first 75% of usable examples train; the rest are held out


def _read_jsonl(path):
    """Return the decoded records of a JSONL file, skipping blank lines.

    Unlike ``f.read().split('\\n')[:-1]``, this keeps the last record even
    when the file does not end with a newline.
    """
    with open(path, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_conversation(structured_response, prompt, gold):
    """Build one user/assistant chat example for supervised fine-tuning."""
    return [
        {
            "role": "user",
            "content": (
                system_prompt + '\n\n'
                + "# Entities to be replaced: \n\n"
                + restructure_again(structured_response)
                + "\n\n# Text To Modify:\n\n"
                + prompt
            ),
        },
        {"role": "assistant", "content": gold},
    ]


prompts = [extract_reply(r) for r in _read_jsonl('newmodifiedresponses.jsonl')]
responses = [restructure(extract_reply(r)) for r in _read_jsonl('newstructuredresponses.jsonl')]

# Keep one slot per record (None on failure) so indices stay aligned with
# `prompts` and `responses`. Skipping a bad record instead would shift every
# later gold output onto the wrong example when the lists are zipped.
gold_by_index = []
for r in _read_jsonl('newreplacements.jsonl'):
    try:
        gold_by_index.append(json.loads(extract_reply(r))['modified_output'])
    except (ValueError, KeyError, IndexError, TypeError):
        print(r)
        gold_by_index.append(None)
replacement_gold = [g for g in gold_by_index if g is not None]

assert len(prompts) == len(responses) == len(gold_by_index), (
    "JSONL inputs are not line-aligned: "
    f"{len(prompts)} prompts, {len(responses)} responses, {len(gold_by_index)} gold records"
)

conversations = [
    _build_conversation(r, p, g)
    for r, p, g in zip(responses, prompts, gold_by_index)
    if g is not None
]

# Split on the number of usable examples, not len(prompts), which also counts
# records dropped above.
n_train = int(len(conversations) * TRAIN_FRACTION)
evals = conversations[n_train:]
train_conversations = conversations[:n_train]

# Render straight to text. Tokenizing and decoding back can alter the text
# (e.g. tokenization-space cleanup) before SFTTrainer re-tokenizes it.
# NOTE(review): if the chat template already emits a BOS token, confirm that
# SFTTrainer does not prepend a second one when tokenizing this text.
inputs = tokenizer.apply_chat_template(train_conversations, tokenize=False)
dataset = Dataset.from_dict({"input": inputs, "labels": inputs})

In [None]:
# --- Training hyperparameters --------------------------------------------
output_dir = "./results"            # checkpoints and predictions are written here

num_train_epochs = 5
max_steps = -1                      # cap on training steps; -1 means no explicit cap

# Mixed precision is off. Set bf16 = True on an A100 (e.g. Colab) to use bfloat16.
fp16 = False
bf16 = False

per_device_train_batch_size = 2
per_device_eval_batch_size = 2      # currently unused: no eval dataset is given to the trainer
gradient_accumulation_steps = 1     # batches accumulated per optimizer update

learning_rate = 5e-4
weight_decay = 0.001
max_grad_norm = 0.3                 # gradient-norm clipping threshold
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
warmup_ratio = 0.03                 # fraction of steps used for linear warmup

seed = 15132135                     # RNG seed for reproducibility

group_by_length = True              # batch sequences of similar length together

save_steps = 500                    # checkpoint every N update steps
logging_steps = 100                 # log every N update steps

In [None]:
from transformers import Conv1D
import torch

def get_specific_layer_names(model, depth=4):
    """Collect one dotted-path component per Linear/Embedding/Conv2d/Conv1D submodule.

    Each matching module's name from ``model.named_modules()`` is split on '.'
    and the component at index ``depth`` is kept. With the default depth of 4,
    for example, ``model.layers.0.self_attn.q_proj`` becomes ``q_proj``. Names
    with ``depth`` or fewer components, such as ``lm_head``, yield ''.

    Args:
        model: Any object exposing ``named_modules()``.
        depth: Index of the path component to keep (default 4, the original
            hard-coded value).

    Returns:
        A list with one entry per matching module, duplicates included.
    """
    layer_types = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, Conv1D)
    layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            parts = name.split('.')
            layer_names.append(parts[depth] if len(parts) > depth else '')
    return layer_names

# Drop '' (modules shallower than the parsing depth) explicitly and sort for a
# deterministic order. The previous `list(set(...))[1:]` removed whichever name
# the set happened to yield first. String hashing is randomised per process, so
# that could silently drop a real projection layer while keeping ''.
modules = sorted({name for name in get_specific_layer_names(model) if name})
assert modules, "no LoRA target modules found"

In [None]:
# LoRA adapter hyperparameters.
lora_r = 64          # adapter rank
lora_alpha = 16      # LoRA scaling parameter
lora_dropout = 0.1   # LoRA dropout probability

# Maximum sequence length. This matches the 2512 that the trainer cell passes to
# SFTTrainer. The previous value, None, was never read and misstated the
# effective limit.
max_seq_length = 2512

packing = False      # passed to SFTTrainer as `packing`


# LoRA adapter configuration for causal-LM fine-tuning, targeting the layer
# names collected in `modules`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=modules,
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
)

In [None]:
# Training arguments assembled from the hyperparameter cell.
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    max_steps=max_steps,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    max_grad_norm=max_grad_norm,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    fp16=fp16,
    bf16=bf16,
    group_by_length=group_by_length,
    save_steps=save_steps,
    logging_steps=logging_steps,
    # Pass the configured seed. Without it TrainingArguments uses its own
    # default, and the `seed` constant has no effect.
    seed=seed,
)

# SFT trainer: trains LoRA adapters from `peft_config` on top of `model`,
# using the text in the dataset's "input" column.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="input",
    max_seq_length=2512,  # maximum tokenized length per example
    packing=packing,
    tokenizer=tokenizer,
)

# Start training
trainer.train()
