<a href="https://github.com/gretelai/public_research/blob/main/gretel-gpt-sentiment-swap/gretel-gpt-sentiment-swap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🎭 Gretel-GPT Sentiment Swap

In this notebook, we demonstrate how to use [Gretel GPT](https://docs.gretel.ai/reference/synthetics/models/gretel-gpt) to fine-tune and prompt a large language model (LLM) to swap the sentiment of product reviews. Given a product review with a particular sentiment (positive or negative), our fine-tuned model generates a new review with the opposite sentiment.

### 💾 Install the Gretel SDK

In [None]:
!pip install gretel-client

In [None]:
import pandas as pd
import gretel_client as gretel

### ⚙️ Main Project Settings

- The data subset can be `Video_Games_v1_00` or `Apparel_v1_00` from the [Amazon Customer Reviews](https://huggingface.co/datasets/amazon_us_reviews) dataset.

- The pair selection metric can be `helpful_votes` or `cos_sim`.

In [None]:
DATA_SUBSET = "Video_Games_v1_00"
PROJECT_NAME = "gretel-gpt-sentiment-swap"
PAIR_SELECTION_METRIC = "helpful_votes"
DATA_BASE_PATH = "https://github.com/gretelai/public_research/raw/main/gretel-gpt-sentiment-swap/data/"

### 🛜 Configure Gretel Session

- You will need a Gretel API key for this step. 

- If you haven't already, get your API key by signing up for free at [gretel.ai](https://console.gretel.ai/login/).

In [None]:
gretel.configure_session(
    api_key="prompt",
    endpoint="https://api.gretel.cloud",
    validate=True,
    cache="yes",
)

### 📊 Fetch data from GitHub

- The datasets were created by this [create_dataset.py](https://github.com/gretelai/public_research/blob/main/gretel-gpt-sentiment-swap/create_dataset.py) script.

- The training dataset consists of product review pairs.

- Each record in the conditional dataset contains the first review in a review pair. The model's job is to generate the second review.

In [None]:
dataset_label = f"{DATA_SUBSET}_{PAIR_SELECTION_METRIC}"

df_train = pd.read_csv(DATA_BASE_PATH + f"training_review_pairs-{dataset_label}.csv.gz")
df_prompts = pd.read_csv(DATA_BASE_PATH + f"conditional_prompts-{DATA_SUBSET}.csv.gz")
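
The file names above follow a fixed pattern, so a small helper can validate the settings before any download is attempted. The sketch below is illustrative only: `build_dataset_urls` and the `ALLOWED_*` constants are hypothetical names that are not part of the Gretel SDK, and the allowed values are simply the ones listed under the main project settings.

```python
# Hypothetical helper for illustration; not part of the Gretel SDK.
ALLOWED_SUBSETS = {"Video_Games_v1_00", "Apparel_v1_00"}
ALLOWED_METRICS = {"helpful_votes", "cos_sim"}
BASE = "https://github.com/gretelai/public_research/raw/main/gretel-gpt-sentiment-swap/data/"


def build_dataset_urls(subset, metric, base_path=BASE):
    """Return (training_url, prompts_url) for a validated subset/metric pair."""
    if subset not in ALLOWED_SUBSETS:
        raise ValueError(f"Unknown data subset: {subset!r}")
    if metric not in ALLOWED_METRICS:
        raise ValueError(f"Unknown pair selection metric: {metric!r}")
    # The training file depends on both settings; the prompt file only on the subset.
    label = f"{subset}_{metric}"
    train_url = f"{base_path}training_review_pairs-{label}.csv.gz"
    prompts_url = f"{base_path}conditional_prompts-{subset}.csv.gz"
    return train_url, prompts_url


train_url, prompts_url = build_dataset_urls("Video_Games_v1_00", "helpful_votes")
print(train_url)
print(prompts_url)
```

Failing early on a misspelled setting gives a clearer error than a failed download.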

### Print random examples from the fine-tuning dataset

- Run the cell below multiple times to see different product review pairs.

In [None]:
print(df_train.sample(1).iloc[0]["text"])

### 🎛️ Fine-tune an LLM with Gretel GPT

- This will take a few hours to complete. Feel free to grab some coffee ☕️

- You can also monitor the progress of the training job in the [Gretel Console](https://console.gretel.cloud).

In [None]:
print(f"Creating or fetching Gretel project with name {PROJECT_NAME}")
project = gretel.projects.get_project(name=PROJECT_NAME, display_name=PROJECT_NAME, create=True)

config = {
    "schema_version": 1,
    "models": [
        {
            "gpt_x": {
                "data_source": "__",
                "pretrained_model": "gretelai/mpt-7b",
                "batch_size": 16,
                "epochs": 4,
                "weight_decay": 0.01,
                "warmup_steps": 100,
                "lr_scheduler": "linear",
                "learning_rate": 0.0005,
                "validation": None,
                "generate": {"num_records": 100, "maximum_text_length": 500},
            }
        }
    ],
}

print("Creating model object")
model = project.create_model_obj(model_config=config)
model.data_source = df_train
model.name = f"{PROJECT_NAME}_{dataset_label}"

print(f"Submitting fine-tuning job to Gretel Cloud with data subset {dataset_label}")
model.submit_cloud()

gretel.helpers.poll(model, verbose=False)
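
The `learning_rate`, `warmup_steps`, and `lr_scheduler` settings together describe a learning-rate schedule. The sketch below shows one common interpretation, linear warmup followed by linear decay to zero. That Gretel GPT applies the schedule in exactly this way is an assumption, and `total_steps` is a hypothetical value chosen for illustration (the real number depends on dataset size, batch size, and epochs).

```python
def linear_schedule_lr(step, peak_lr=0.0005, warmup_steps=100, total_steps=1000):
    """Learning rate at `step` under linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to peak_lr during warmup.
        return peak_lr * step / warmup_steps
    # After warmup, decay linearly so the rate reaches 0 at total_steps.
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / (total_steps - warmup_steps)


for step in (0, 50, 100, 550, 1000):
    print(step, linear_schedule_lr(step))
```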

### 🤖 Generate sentiment-swapped reviews

In [None]:
# Fetch the latest model from the project.
model = [m for m in project.search_models(model_name=DATA_SUBSET) if m.status == "completed"][-1]

# Create a record handler with the conditional prompts as seed data.
record_handler = model.create_record_handler_obj(
    params={"maximum_text_length": 200, "temperature": 1.2}, 
    data_source=df_prompts
)

# Submit the record handler to the Gretel Cloud for generation.
record_handler.submit_cloud()
gretel.helpers.poll(record_handler, verbose=False)

# Fetch the generated data from the Gretel Cloud.
df_generations = pd.read_csv(record_handler.get_artifact_link("data"), compression='gzip')
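
The `temperature` parameter controls how random the sampled text is. The sketch below illustrates the standard temperature-scaled softmax on made-up logits. It shows the general technique rather than Gretel's internal implementation, and the function name and `toy_logits` values are hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing them by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.1]  # made-up scores for three candidate tokens
for t in (0.7, 1.0, 1.2):
    probs = softmax_with_temperature(toy_logits, t)
    print(t, [round(p, 3) for p in probs])
```

Temperatures above 1 flatten the distribution, so less likely tokens are sampled more often. A value of 1.2 therefore favors more varied output at some cost in coherence.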

### 🖨️ Print example sentiment-swapped review pairs

- The first review in each pair is sampled from the conditional prompt dataset.

- The second review in each pair is the sentiment-swapped review from our model. 

In [None]:
num_samples = 5

samples = df_prompts.sample(num_samples)
for idx, prompt in samples.itertuples():
    generation = df_generations.loc[idx, "text"]
    print(f"{prompt} \033[1;30;46m{generation}\033[0;0m\n-----\n")
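
The escape sequences in the `print` call are ANSI codes that highlight the generated review in the terminal or notebook output. The sketch below names them for readability. `HIGHLIGHT`, `RESET`, and `format_pair` are hypothetical names, and the toy strings stand in for real prompts and generations.

```python
HIGHLIGHT = "\033[1;30;46m"  # bold, black text on a cyan background
RESET = "\033[0m"            # restore the default text style


def format_pair(prompt, generation):
    """Return the prompt followed by the highlighted generation."""
    return f"{prompt} {HIGHLIGHT}{generation}{RESET}"


toy_prompt = "Original review text."
toy_generation = "Sentiment-swapped review text."
print(format_pair(toy_prompt, toy_generation))
```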