In [None]:
%pip install --upgrade pip
%pip install --disable-pip-version-check \
    torch==1.13.1. \
    torchdata==0.5.1 --quiet
%pip install  \
    transformers==4.27.2 \
    datasets==2.11.0 \
    evaluate==0.4.0 \
    rouge_score==0.1.2 \
    # loralib==0.1.1 \
    peft==0.3.0 /
    

In [None]:
# TRL supplies the PPOTrainer, PPOConfig and value-head model classes imported further down.
%pip install trl==0.4.4 --quiet

In [None]:
from transformers import pipeline,AutoTokenizer,AutoModelForSequenceClassification,AutoModelForSeq2SeqLM,GenerationConfig
from datasets import load_dataset
from peft import PeftModel,PeftConfig,LoraConfig,TaskType,get_peft_model

from trl import PPOTrainer,PPOConfig,AutoModelForSeq2SeqLMWithValueHead
from trl import create_reference_model
from trl.core import LengthSampler

import torch
import evaluate
import numpy as np
import pandas as pd

from tqdm import tqdm
tqdm.pandas()

In [None]:
# Device selector: CUDA device index 0 when a GPU is visible, otherwise the string "cpu".
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"

In [None]:
# Hub identifier of the dialogue-summarisation dataset; reused later when building the PPO dataset.
huggingface_dataset_name = "knkarthick/dialogsum"

# Fetch the dataset through `datasets.load_dataset` (no split argument is given).
dataset_original = load_dataset(huggingface_dataset_name)

# Last expression of the cell, so the notebook displays the loaded object.
dataset_original

In [None]:
# NOTE(review): AutoModelForCausalLM is imported here but not used in the cells shown.
from transformers import AutoModelForCausalLM

# Checkpoint name shared by the base model and its tokenizer.
model_name = "google/flan-t5-base"

# Load the seq2seq checkpoint, requesting bfloat16 weights.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

In [None]:
# Tokenizer matching `model_name`. `device_map` is a model-placement option and has no meaning
# for a tokenizer (tokenizers hold no weights), so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def limiting(dataset, min_length, max_length):
    """Keep only examples whose "dialogue" length is within [min_length, max_length].

    The length is measured with ``len(example["dialogue"])`` and both bounds are
    inclusive. The input object is not modified; the value returned by
    ``dataset.filter`` is handed back to the caller.
    """
    def _within_bounds(example):
        dialogue_length = len(example["dialogue"])
        return min_length <= dialogue_length <= max_length

    return dataset.filter(_within_bounds)
def build_dataset(model_name,dataset_name,min_length,max_length):
    """Build the tokenized prompt dataset used for PPO training.

    Args:
        model_name: checkpoint name used to load the tokenizer.
        dataset_name: dataset identifier passed to ``load_dataset`` (the "train" split is used).
        min_length: inclusive lower bound on ``len(example["dialogue"])``.
        max_length: inclusive upper bound on ``len(example["dialogue"])``.

    Returns:
        The result of ``train_test_split(test_size=0.2, ...)`` on the filtered,
        tokenized dataset. Each example gains an ``input_ids`` field (the encoded
        prompt) and a ``query`` field (the prompt decoded back to text).
    """
    dataset = load_dataset(dataset_name, split="train")
    # Fix: `filter` returns a new dataset, so the result must be re-assigned. Previously the
    # filtered dataset was discarded and the length bounds had no effect.
    dataset = limiting(dataset, min_length, max_length)
    # Tokenizers take no device placement, so no `device_map` is passed.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize(sample):
        """Wrap the dialogue in the summarisation prompt and attach its token ids and text."""
        prompt=f"""
Summarize the following:
{sample["dialogue"]}
Summary:
"""
        sample["input_ids"] = tokenizer.encode(prompt)
        # Decoding the ids back gives the text form of the query as the model sees it.
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    dataset = dataset.map(tokenize, batched=False)
    dataset.set_format(type="torch")

    # shuffle=False keeps the split order-based; the seed is passed through unchanged.
    dataset_splits = dataset.train_test_split(test_size=0.2, seed=42, shuffle=False)
    return dataset_splits
# Build the PPO prompt dataset: dialogues between 200 and 1000 characters long.
dataset = build_dataset(
    model_name=model_name,
    dataset_name=huggingface_dataset_name,
    min_length=200,
    max_length=1000,
)

In [None]:
# LoRA adapter configuration applied to the FLAN-T5 base model.
lora_config = LoraConfig(
    r=32,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    # Fix: PEFT expects the lowercase string "none" ("none" | "all" | "lora_only");
    # "None" is not an accepted value.
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,  # FLAN-T5
)

# Fix: wrap the base model with a fresh LoRA adapter built from `lora_config`.
# `PeftModel.from_pretrained` loads an already-trained adapter from `model_id`; pointing it at
# the base checkpoint name (which carries no adapter files) cannot work, and the
# `LoraConfig=` keyword it was given is not a parameter it consumes.
# If a saved adapter directory exists, load it instead with
# `PeftModel.from_pretrained(model, <adapter_path>, is_trainable=True)`.
peft_model = get_peft_model(model, lora_config)

In [None]:
# A model that is going to be trained with RL is loaded this way, through the
# ...WithValueHead wrapper class.
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
    peft_model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)

# Reference model derived from the PPO model via TRL's helper.
ref_model = create_reference_model(ppo_model)

In [None]:
# Hate-speech classifier used as the reward model.
toxic_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"

# Tokenizers hold no weights, so `device_map` is only passed to the model load below.
toxic_tokenizer = AutoTokenizer.from_pretrained(toxic_model_name)
toxic_model = AutoModelForSequenceClassification.from_pretrained(toxic_model_name, device_map="auto")

In [None]:
# Sanity-check the reward model on a clearly non-toxic sentence.
test_text = "so beautiful so elegant just looking like a wow"

# Tokenize and move the ids to wherever the classifier's weights live.
prepared_test_text = toxic_tokenizer(test_text, return_tensors="pt").input_ids.to(toxic_model.device)

# Inference only, so gradients are not needed.
with torch.no_grad():
    logits = toxic_model(input_ids=prepared_test_text).logits

# Fix: torch tensors expose `.tolist()`; `.to_list()` does not exist.
probabilites = logits.softmax(dim=-1).tolist()

# Logits are returned in the order [nothate, hate]; the nothate logit is used as the reward.
# Fix: index the class axis (`[:, 0]`). `logits[:2]` sliced the batch axis and returned both logits.
not_hate_index = 0
nothate_reward = logits[:, not_hate_index].tolist()

In [None]:
# Alternate method: instead of tokenizing and calling the model by hand, a pipeline does both.
# Fix: build the pipeline from the checkpoint name so it loads its own tokenizer (a bare model
# object without a tokenizer cannot be used), and pass the target device through `device`.
# The previous `auto_map="device"` was not a valid way to select a device.
sentiment_pipeline = pipeline("sentiment-analysis", model=toxic_model_name, device=device)

# Default call: only the top label, with the model's default activation applied to its score.
print(sentiment_pipeline(test_text))

# Keyword arguments control what the pipeline returns.
logits_kwargs = {
    "top_k": None,  # return every label, not just the best one
    # Fix: the string "none" disables the activation and yields raw logits.
    # Python `None` means "use the default activation", which does not return logits.
    "function_to_apply": "none",
    "batch_size": 16,
}
probab_kwargs = {
    "top_k": None,
    "function_to_apply": "softmax",
    # batch_size is how many inputs the pipeline groups per forward pass;
    # it affects throughput, not the scores.
    "batch_size": 16,
}

# Raw logits for every label (matches the manual forward pass done earlier).
print(sentiment_pipeline(test_text, **logits_kwargs))
# Softmax probabilities for every label.
sentiment_pipeline(test_text, **probab_kwargs)

In [None]:
# Fix: the second positional argument of `evaluate.load` is a configuration name; the toxicity
# measurement interprets it as the checkpoint name to load, so it must be the model *name*
# (a string), not the already-loaded model object.
toxicity_calculator = evaluate.load(
    "toxicity",
    toxic_model_name,
    module_type="measurement",
)

# `toxic_label` is an argument of the measurement's `compute()` call (which label of the
# classifier counts as toxic), which is why it is not listed among `evaluate.load` parameters.
toxicity_score = toxicity_calculator.compute(predictions=[test_text], toxic_label="hate")
print(toxicity_score["toxicity"])

In [None]:
# PPO hyper-parameters.
config = PPOConfig(
    model_name=model_name,
    learning_rate=1.41e-5,
    # Fix: the PPOConfig field is `ppo_epochs`; `max_ppo_epochs` is not an accepted argument.
    ppo_epochs=1,
    batch_size=16,
    mini_batch_size=4,
)
def collator(data):
    """Turn a list of example dicts into one dict of lists.

    The keys are taken from the first example; every example is expected to
    provide each of those keys. Values are collected in input order.
    """
    first_example = data[0]
    return {field: [example[field] for example in data] for field in first_example}
# Wire the policy, its reference copy, the tokenizer and the training prompts into the trainer.
# `collator` batches examples as a dict of lists.
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset["train"],
    data_collator=collator,
)

In [None]:
# Bounds for the number of new tokens sampled per generated summary.
output_min_length = 100
output_max_length = 400
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Sampling settings forwarded to `ppo_trainer.generate`; `max_new_tokens` is filled in per prompt.
generation_kwargs = {
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
}

# Settings forwarded to the reward pipeline.
reward_kwargs = {
    "top_k": None,  # return scores for every label
    # Fix: the option is spelled `function_to_apply`. With the misspelled key the pipeline
    # never received it, so the rewards were not the raw logits this setting asks for.
    "function_to_apply": "none",
    "batch_size": 16,
}

# Number of PPO batches to run before stopping.
max_ppo_steps = 10

# Label of the reward model whose score is used as the reward.
NOT_HATE_LABEL = "nothate"

# Wrap the dataloader itself so tqdm can see its length.
for step, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    if step >= max_ppo_steps:
        break

    prompt_tensors = batch["input_ids"]

    # Generate one summary per prompt.
    summary_tensors = []
    # Fix: use a distinct loop variable. Re-using `prompt_tensors` overwrote the list with its
    # last element, so `ppo_trainer.step` received a single tensor instead of the batch.
    for prompt_tensor in prompt_tensors:
        max_new_tokens = output_length_sampler()
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        # Keep at most the last `max_new_tokens` generated tokens.
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    batch["response"] = [tokenizer.decode(r.squeeze()) for r in summary_tensors]

    # Score query + response with the reward pipeline.
    query_response_pairs = [q + r for q, r in zip(batch["query"], batch["response"])]
    rewards = sentiment_pipeline(query_response_pairs, **reward_kwargs)
    # Fix: with `top_k=None` the pipeline returns the labels sorted by score, so position 0 is
    # the highest-scoring label, not necessarily "nothate". Select the score by label instead.
    reward_tensors = [
        torch.tensor(next(item["score"] for item in reward if item["label"] == NOT_HATE_LABEL))
        for reward in rewards
    ]

    # One PPO optimisation step on this batch, then log and print key statistics.
    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)
    print(f'objective/kl:{stats["objective/kl"]}')
    print(f'ppo/returns/mean:{stats["ppo/returns/mean"]}')
    print(f'ppo/policy/advantages_mean:{stats["ppo/policy/advantages_mean"]}')