In [None]:
import argilla as rg 
import os
from getpass import getpass

# Connection settings for the Argilla server. The URL is not a secret and keeps
# its previous value as a default. The API key is read from the environment
# (or asked for interactively) so it is never stored in the notebook itself.
url = os.environ.get("ARGILLA_API_URL", "https://kamaljp-argillatest.hf.space/")
apikey = os.environ.get("ARGILLA_API_KEY") or getpass("Argilla API key: ")

rg.init(api_key=apikey,
        api_url=url)

In [None]:
# Argilla workspace used by every push_to_argilla / from_argilla call in this notebook.
localgilla = "hfgilla"
rg.set_workspace(localgilla)

In [None]:
# Prompt -> reply dataset: annotators read a prompt (field "prompt") and write a
# harmless reply (question "response").
# The question is deliberately NOT named "prompt". It holds the annotator's
# reply, and reusing the field's name makes `record.fields["prompt"]` and
# `response.values["prompt"]` easy to confuse. It may also clash when fields and
# questions are exported side by side — NOTE(review): confirm export behaviour.
dataset_fw = rg.FeedbackDataset(
    guidelines="Please read the prompt carefully",
    questions=[
        rg.TextQuestion(
            name="response",
            title="Please write a harmless reply",
            required=True,
        )
    ],
    fields=[
        rg.TextField(name="prompt", required=True)
    ],
)

In [None]:
# There are several ways to collect prompts for the dataset.

# The steps can include:
# (1) finding an open dataset that might contain prompts related to your use
#     case
# (2) performing exploratory data analysis and topic extraction to understand
#     the data
# (3) filtering and selecting prompts based on topic, quality,
#     text descriptiveness, etc.
# (4) asking humans to write prompts for your use case

In [None]:
# Alternative setup for collecting *instructions*: each record would show a
# writing topic (filled from your own list of topics) and the labeler writes a
# possible prompt/instruction for it.
# NOTE(review): `fields` and this `question` are not passed to any dataset in
# this notebook, and `question` is reassigned in a later cell.
writing_topic_field = rg.TextField(name="writing-topic", required=True)
fields = [writing_topic_field]

question = rg.TextQuestion(
    name="prompt",
    title="Imagine and write a possible instruction for the given topic:",
    required=True,
)

In [None]:
from datasets import load_dataset

# MT-Bench prompts: the prompt pool for both annotation datasets in this notebook.
prompts = load_dataset(path="HuggingFaceH4/mt_bench_prompts", split="train")

In [None]:
# Inspect one raw MT-Bench row; later cells index `row["prompt"][0]`, so check its structure here.
prompts[0]

In [None]:
# One Argilla record per MT-Bench row. Only element 0 of the row's "prompt"
# value is used (presumably the first turn of a multi-turn prompt — verify via
# the `prompts[0]` cell).
records = []
for item in prompts:
    first_turn = item['prompt'][0]
    records.append(rg.FeedbackRecord(fields={"prompt": first_turn}))
dataset_fw.add_records(records)

In [None]:
# Publish `dataset_fw` (its schema plus the records added to it) to Argilla as
# "rlhf_demo" in the notebook workspace; the return value is kept as
# `remote_dataset`.
remote_dataset = dataset_fw.push_to_argilla(name="rlhf_demo", workspace=localgilla)

In [None]:
# Read the published "rlhf_demo" dataset back from Argilla, e.g. after the
# workload has been distributed among several labelers in the shared workspace.
feedback_five = rg.FeedbackDataset.from_argilla(name="rlhf_demo", workspace=localgilla)

In [None]:
# Display the dataset object fetched from Argilla.
feedback_five

In [None]:
# Display only the records filtered by response status "submitted" (i.e. ones labelers have answered).
feedback_five.filter_by(response_status="submitted")

In [None]:
### Create the datasets to rank the responses

# Ranking question over the two generated responses. The option values are the
# names of the response fields ("response1", "response2"), so each ranked value
# can be resolved directly with `record.fields[value]`. The previous options
# ("Nice"/"Okay") did not correspond to any response.
# NOTE(review): `questions` is not passed to any dataset below; the comparison
# dataset uses the rating question instead.
questions = [
    rg.RankingQuestion(
        name="response_ranking",
        title="order the responses based on their accuracy & helpfulness",
        required=True,
        values={"response1": "Response 1", "response2": "Response 2"},
    )
]

In [None]:
# Pairwise accuracy judgement between the two generated responses, encoded as a
# rating:
#   2 -> response1 is more accurate
#   3 -> response2 is more accurate
#   1 -> both are equally accurate
# This keeps the original 1/2/3 encoding. The title now names the fields the
# annotators actually see; the old "(2) and (3)" did not match any field label.
question = [
    rg.RatingQuestion(
        name="rate_resp",
        title="Which response is more accurate? Select 2 for response1, 3 for response2, or 1 if they are equally accurate.",
        required=True,
        values=[1, 2, 3],
    )
]

In [None]:
# Local comparison dataset: one prompt and two model responses per record, judged
# with the question list held in `question`.
response_fields = [
    rg.TextField(name=field_name, required=True)
    for field_name in ("prompt1", "response1", "response2")
]
response_collect_ds = rg.FeedbackDataset(
    fields=response_fields,
    questions=question,
    guidelines="Please read prompt, its response below and provide feedback",
)

In [None]:
# Publish the (still empty) "response_collect" dataset; records are generated and added in later cells.
# NOTE(review): Argilla appears to reject a second push_to_argilla with the same name/workspace,
# so the later publish of this dataset must append to the remote copy rather than push again — confirm.
response_collect_ds.push_to_argilla(name="response_collect", workspace=localgilla)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# Load the instruction-tuned Gemma model directly onto the GPU in bfloat16.
#
# The dtype must be set here. `pipeline(...)` only applies `torch_dtype` when it
# loads the model itself; for an already-instantiated model it is ignored, so
# without this the weights would be loaded in the default float32.
#
# `resume_download` is dropped: huggingface_hub deprecates it because downloads
# resume by default.
MODEL_ID = "google/gemma-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
# Text-generation pipeline around the already-loaded `model` and `tokenizer`.
# Device and dtype are fixed when the model is loaded. `torch_dtype` and
# `device_map` are only honoured by `pipeline()` when it loads the model from a
# name, so passing them here had no effect and they are omitted.
gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

In [None]:
# Print the first turn of the first MT-Bench prompt (prints nothing if the split is empty).
first_item = next(iter(prompts), None)
if first_item is not None:
    print(first_item['prompt'][0])

In [None]:
# Generate two candidate answers per MT-Bench prompt and collect them as records
# for `response_collect_ds` (fields: prompt1, response1, response2).
#
# NOTE(review): gemma-2b-it is a chat model. Wrapping `prompt` with
# `tokenizer.apply_chat_template(..., add_generation_prompt=True)` would likely
# give better answers.
records = []
for item in prompts:
    prompt = item["prompt"][0]  # element 0 of the row's prompt value (first turn)
    outputs = gen_pipeline(
        prompt,
        max_new_tokens=512,       # budget for the answer only; max_length also counted the prompt tokens
        do_sample=True,           # sampling, so the two returned sequences can differ
        top_k=10,
        num_return_sequences=2,   # two candidates to compare (was 1, so responses[1] never existed)
        return_full_text=False,   # keep only the continuation, not the echoed prompt
        eos_token_id=tokenizer.eos_token_id,
    )
    # Per the TextGenerationPipeline docs, a single string input yields a flat
    # list with one {"generated_text": ...} dict per returned sequence.
    responses = [output["generated_text"] for output in outputs]
    if len(responses) < 2 or not all(text.strip() for text in responses):
        print(f"Skipping prompt {prompt!r}: expected two non-empty responses, got {responses!r}")
        continue
    records.append(
        rg.FeedbackRecord(fields={
            "prompt1": prompt,  # the full prompt string (previously prompt[0], its first character)
            "response1": responses[0],
            "response2": responses[1],
        })
    )

# Add all generated records to the local dataset in one call.
response_collect_ds.add_records(records)

In [None]:
# Publish the generated records to the "response_collect" dataset on Argilla.
#
# The (empty) dataset may already exist on the server, because it can be pushed
# before generation. Argilla's push_to_argilla refuses a name that is already
# taken in the workspace. So when the dataset exists, the local records are
# appended to the remote copy; otherwise the dataset is created with its records.
#
# Re-running this cell appends the records again.
try:
    to_rem_ds = rg.FeedbackDataset.from_argilla(name="response_collect",
                                                workspace=localgilla)
except ValueError:
    # Not on the server yet: create it together with its records.
    to_rem_ds = response_collect_ds.push_to_argilla(name="response_collect",
                                                    workspace=localgilla)
else:
    to_rem_ds.add_records(response_collect_ds.records)

In [None]:
# Fetch the "response_collect" dataset back from Argilla, including whatever responses labelers have given.
feedback_ds = rg.FeedbackDataset.from_argilla(name="response_collect", workspace=localgilla)

In [None]:
# Build (prompt, preferred response, least preferred response) triplets from the
# submitted answers to the "rate_resp" question.
#
# NOTE(review): this assumes the rating encoding stated in the question title:
#   2 -> response1 is more accurate
#   3 -> response2 is more accurate
#   1 -> tie
# Ties (and any unexpected value) give no preference pair and are skipped.
# Confirm against the annotation guidelines.
RATING_TO_PREFERENCE = {
    2: ("response1", "response2"),
    3: ("response2", "response1"),
}

triplets = []
for record in feedback_ds.records:
    # Only submitted responses carry a judgement. Use the first one, instead of
    # giving up when the first response happens to be a draft/discarded one.
    submitted = [resp for resp in (record.responses or []) if resp.status == "submitted"]
    if not submitted:
        continue

    rating = submitted[0].values["rate_resp"].value
    preference = RATING_TO_PREFERENCE.get(rating)
    if preference is None:
        continue

    preferred_field, least_preferred_field = preference
    triplets.append({
        "prompt": record.fields["prompt1"],
        "preferred_response": record.fields[preferred_field],
        "least_preferred_response": record.fields[least_preferred_field],
    })

# `triplets` is now a list of dicts, each holding a prompt with its preferred and
# least preferred response texts (ties excluded).

In [None]:
# Display the extracted triplets.
triplets