In [55]:
from __future__ import annotations

import argparse
import json
import logging
import os
import random
import re
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

try:  # optional dependency for grammar checking
    from language_tool_python import LanguageTool  # type: ignore
except ImportError:  # pragma: no cover - optional tool
    LanguageTool = None  # type: ignore

import nltk
from nltk.corpus import stopwords

import os

os.environ["JAVA_HOME"] = "C:\\Program Files\\Eclipse Adoptium\\jdk-21.0.8.9-hotspot"
os.environ["PATH"] += os.pathsep + os.path.join(os.environ["JAVA_HOME"], "bin")
print("JAVA_HOME:", os.environ.get("JAVA_HOME"))
print("PATH:", os.environ.get("PATH"))


logger = logging.getLogger("rl_finetune_gpt2")


JAVA_HOME: C:\Program Files\Eclipse Adoptium\jdk-21.0.8.9-hotspot
PATH: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9\libnvvp;;C:\Program Files\Microsoft SDKs\Azure\CLI2\wbin;C:\WINDOWS\system32;C:\WINDOWS;C:\WINDOWS\System32\Wbem;C:\WINDOWS\System32\WindowsPowerShell\v1.0\;C:\WINDOWS\System32\OpenSSH\;C:\Program Files\NVIDIA Corporation\NVIDIA app\NvDLISR;C:\Program Files (x86)\NVIDIA Corporation\PhysX\Common;C:\Program Files\LINQPad8;C:\Program Files\dotnet\;C:\Program Files\Microsoft SQL Server\150\Tools\Binn\;C:\Program Files\Microsoft SQL Server\Client SDK\ODBC\170\Tools\Binn\;C:\Program Files\Microsoft Service Fabric\bin\Fabric\Fabric.Code;C:\Program Files\Microsoft SDKs\Service Fabric\Tools\ServiceFabricLocalClusterManager;C:\Program Files\nodejs\;C:\Program Files\NVIDIA Corporation\Nsight Compute 2025.3.0\;C:\Program Files\Git\cmd;C:\Program Files\CMake\bin;C:\Users\moidhassan\AppData\Local\Microsoft\Window

In [31]:
default_reward_weights = {
        "length": 1.2,
        "politeness": 1.2,
        "sentiment": 0.7,
        "clarity": 0.6,
        "cta": 1.4,
        "personalization": 0.7,
        "grammar": 0.8,
        "value": 1.1,
        "spam": 0.8,
        "structure": 0.8,
        "instructional_tone": 0.8,
        "lexical_coherence": 1.2,
    }

PROMPT_TEMPLATES = [
    "Write a concise, professional sales email introducing this idea: \"{snippet}\"",
    "Compose a friendly B2B sales email based on this concept: \"{snippet}\"",
    "Generate a product outreach email using this information: \"{snippet}\"",
    "Write an enterprise sales email centered on: \"{snippet}\"",
    "Craft a formal product introduction email about: \"{snippet}\"",
]

data_path = "data/seller_emails_v3.json"
output_dir = ""
model_name = "gpt2"
tokenizer_name = "gpt2"
num_epochs = 2
batch_size = 2
max_new_tokens = 768
top_p = 1.0
temperature = 0.85
learning_rate = 1.41e-5
seed = 42
max_length = 1024
save_model = True

In [21]:
def load_seller_emails(file_path: Path) -> pd.DataFrame:
    """Load seller email corpus stored as JSON list of strings."""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Seller email JSON not found: {file_path}")

    with open(file_path, "r", encoding="utf-8") as fp:
        raw = json.load(fp)

    if not isinstance(raw, list) or not all(isinstance(entry, str) for entry in raw):
        raise ValueError("Expected the JSON file to contain a list of strings (email bodies).")

    df = pd.DataFrame({"email_text": raw})
    # adding email length feature
    df["len_email_text"] = df["email_text"].str.len()
    # clean text by removing extra whitespace
    df["email_text"] = df["email_text"].str.replace(r"\s+", " ", regex=True).str.strip()
    # remove duplicates
    df = df.drop_duplicates(subset=["email_text"]).reset_index(drop=True)
    logger.info("Loaded %d seller emails from %s", len(df), file_path)
    return df



In [24]:
df = load_seller_emails(data_path)
df.head()

Unnamed: 0,email_text,len_email_text
0,"Biodegradable packaging: lower footprint, same...",58
1,"Dear HR Director, I'm Jennifer from HealthFirs...",478
2,Learning platform lifts test scores 18%. Pilot?,47
3,"Hello, I work with boutique hotels to enhance ...",432
4,Cut cloud spend 30–40%. 15‑min chat next week?...,57


In [25]:
def generate_prompts_from_emails(df: pd.DataFrame, *, text_column: str = "email_text") -> List[str]:
    texts = df[text_column].dropna().tolist()
    prompts: List[str] = []
    for email in texts:
        email = email.strip()
        sentences = nltk.sent_tokenize(email)
        core = [
            s for s in sentences if not re.match(r"^(hi|hello|dear|regards|thank|best)", s.strip().lower())
        ]
        snippet = ""
        for sentence in core:
            if len(sentence.split()) >= 6:
                snippet = sentence.strip()
                break
        if not snippet:
            snippet = email[:120]
        template = random.choice(PROMPT_TEMPLATES)
        prompts.append(template.format(snippet=snippet))
    logger.info("Generated %d prompts from dataset", len(prompts))
    return prompts


In [29]:
prompts = generate_prompts_from_emails(df)
print(len(prompts))
prompts

38


['Compose a friendly B2B sales email based on this concept: "Biodegradable packaging: lower footprint, same cost."',
 'Generate a product outreach email using this information: "Our corporate wellness programs have helped over 200 companies reduce healthcare costs by an average of $450 per employee annually while boosting morale."',
 'Write an enterprise sales email centered on: "Learning platform lifts test scores 18%."',
 'Write an enterprise sales email centered on: "Properties using our platform have seen an average 23% increase in positive reviews and a significant boost in repeat bookings."',
 'Write an enterprise sales email centered on: "Cut cloud spend 30–40%. 15‑min chat next week? – DataSync"',
 'Craft a formal product introduction email about: "Our clients typically see 15-20% increase in conversion rates within 90 days."',
 'Craft a formal product introduction email about: "Logistics costs down 20–30%. Explore fulfillment?"',
 'Compose a friendly B2B sales email based on t

In [40]:
def load_model_and_tokenizer(tokenizer_name, model_name, device):
    logger.info("Loading tokenizer: %s", tokenizer_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.padding_side = "left"
    print(f"Tokenizer pad token before adjustment: {tokenizer.pad_token} (id={tokenizer.pad_token_id})")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Set pad token to eos token: {tokenizer.pad_token} (id={tokenizer.pad_token_id})")

    logger.info("Loading base model for PPO value head: %s", model_name)
    ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
    ppo_model = ppo_model.to(device)
    logger.info("Model loaded on %s", device)
    return ppo_model, tokenizer


In [41]:
model, tokenizer = load_model_and_tokenizer(tokenizer_name, model_name, "cuda:0" if torch.cuda.is_available() else "cpu")



Tokenizer pad token before adjustment: None (id=None)
Set pad token to eos token: <|endoftext|> (id=50256)


  state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs)


In [42]:
def create_ppo_trainer(
    model: AutoModelForCausalLMWithValueHead,
    tokenizer: AutoTokenizer,
    learning_rate: float,
    batch_size: int
) -> PPOTrainer:
    logger.info("Initialising PPO trainer")
    ppo_config = PPOConfig(
        model_name=None,
        learning_rate=learning_rate,
        batch_size=batch_size,
        mini_batch_size=batch_size,
        optimize_cuda_cache=True,
    )
    trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer)
    return trainer


In [44]:
trainer = create_ppo_trainer(model, tokenizer, learning_rate, batch_size)
trainer

<trl.trainer.ppo_trainer.PPOTrainer at 0x2b02bd383d0>

In [45]:
def load_sentiment_analyzer(use_gpu: bool = True):
    device = 0 if use_gpu and torch.cuda.is_available() else -1
    logger.info("Loading sentiment analyzer (%s)...", "GPU" if device == 0 else "CPU")
    return pipeline(
        "sentiment-analysis",
        model="distilbert-base-uncased-finetuned-sst-2-english",
        device=device,
    )

In [51]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
sentiment_analyzer = load_sentiment_analyzer(use_gpu=device == "cuda:0")

Using device: cuda:0


In [52]:
def init_language_tool():
    if LanguageTool is None:
        logger.warning("language_tool_python not installed; grammar rewards disabled.")
        return None
    try:  # pragma: no cover - performs network access on first run
        tool = LanguageTool("en-US")
        logger.info("Initialized LanguageTool for grammar checking.")
        return tool
    except Exception as exc:  # pragma: no cover
        logger.warning("Failed to initialize LanguageTool: %s", exc)
        return None


In [56]:
grammar_tool = init_language_tool()

Downloading LanguageTool latest: 100%|██████████| 254M/254M [00:31<00:00, 7.97MB/s] 
Unzipping C:\Users\MOIDHA~1\AppData\Local\Temp\tmp9qa7lnvx.zip to C:\Users\moidhassan\.cache\language_tool_python.
Downloaded https://internal1.languagetool.org/snapshots/LanguageTool-latest-snapshot.zip to C:\Users\moidhassan\.cache\language_tool_python.
Downloading LanguageTool latest: 100%|██████████| 254M/254M [00:18<00:00, 13.5MB/s] 
Unzipping C:\Users\MOIDHA~1\AppData\Local\Temp\tmp60inuwer.zip to C:\Users\moidhassan\.cache\language_tool_python.
Downloaded https://internal1.languagetool.org/snapshots/LanguageTool-latest-snapshot.zip to C:\Users\moidhassan\.cache\language_tool_python.
Downloading LanguageTool latest: 100%|██████████| 254M/254M [00:19<00:00, 13.3MB/s] 
Unzipping C:\Users\MOIDHA~1\AppData\Local\Temp\tmpwwcfqi6v.zip to C:\Users\moidhassan\.cache\language_tool_python.
Downloaded https://internal1.languagetool.org/snapshots/LanguageTool-latest-snapshot.zip to C:\Users\moidhassan\.cache

In [71]:
CTA_PHRASES = [
    "call",
    "reply",
    "schedule",
    "meet",
    "connect",
    "reach out",
    "get in touch",
    "book a demo",
    "set up a meeting",
    "schedule a call",
    "contact us",
]

POLITE_TERMS = [
    "thank",
    "appreciate",
    "please",
    "hope",
    "kindly",
    "regards",
    "grateful",
    "welcome",
    "would you",
    "could you",
]

UNCLEAR_TERMS = [
    "utilize",
    "leverage",
    "synergy",
    "paradigm",
    "bandwidth",
    "ecosystem",
    "turnkey",
    "disruptive",
]

STOP_WORDS = set(stopwords.words("english"))

VALUE_TERMS_PATTERN = re.compile(
    r"\b(save|reduce|increase|boost|improve|growth|roi|cost|revenue|profit)\b",
    re.IGNORECASE,
)

SPAM_TERMS = [
    "free",
    "winner",
    "click here",
    "urgent",
    "act now",
    "limited time",
    "guarantee",
]

INSTRUCTIONAL_PATTERN = re.compile(
    r"(write (an|a) email|here'?s how|this guide|please follow)",
    re.IGNORECASE,
)


def has_cta_phrase(text: str) -> Tuple[bool, List[str]]:
    lower = text.lower()
    matches = []
    for phrase in CTA_PHRASES:
        if " " in phrase:
            if phrase in lower:
                matches.append(phrase)
        elif re.search(rf"\b{re.escape(phrase)}\b", lower):
            matches.append(phrase)
    return len(matches) > 0, matches


def evaluate_structure_and_intent(text: str) -> Tuple[float, str]:
    lower = text.lower()
    matched: List[str] = []
    greeting_match = re.findall(r"\b(hi|hello|dear)\b", lower)
    closing_match = re.findall(r"\b(regards|sincerely|thanks|best)\b", lower)
    product_match = re.findall(r"\b(product|solution|platform|offer|launch|introduc)\w*", lower)
    if greeting_match:
        matched.extend(greeting_match)
    if closing_match:
        matched.extend(closing_match)
    if product_match:
        matched.extend(product_match)
    has_greeting = bool(greeting_match)
    has_closing = bool(closing_match)
    has_product = bool(product_match)
    if has_greeting and has_closing and has_product:
        return 1.0, f"Proper structure detected: {matched}"
    if has_greeting or has_product:
        return 0.5, f"Partial email structure: {matched}"
    return -0.5, "Missing greeting/closing/product context"


def penalize_instructional_tone(text: str) -> Tuple[float, str]:
    matches = INSTRUCTIONAL_PATTERN.findall(text.lower())
    if matches:
        flat = [m[0] if isinstance(m, tuple) else m for m in matches]
        return -1.0, f"Instructional tone detected: {flat}"
    return 0.0, "No instructional tone detected"


def lexical_coherence(text: str) -> Tuple[float, str]:
    sentences = [s.strip().lower() for s in re.split(r"[.!?]", text) if s.strip()]
    if len(sentences) < 2:
        return 0.5, "Too short to assess coherence"
    overlaps: List[float] = []
    for idx in range(len(sentences) - 1):
        s1 = {w for w in re.findall(r"\b[a-z]+\b", sentences[idx]) if w not in STOP_WORDS}
        s2 = {w for w in re.findall(r"\b[a-z]+\b", sentences[idx + 1]) if w not in STOP_WORDS}
        if not s1 or not s2:
            continue
        overlap = len(s1 & s2) / len(s1 | s2)
        overlaps.append(overlap)
    if not overlaps:
        return 0.3, "Insufficient overlapping vocabulary"
    avg_overlap = float(np.mean(overlaps))
    if avg_overlap > 0.4:
        return 1.0, f"Good coherence (avg overlap={avg_overlap:.2f})"
    if avg_overlap > 0.2:
        return 0.5, f"Partial coherence (avg overlap={avg_overlap:.2f})"
    return -0.5, f"Low coherence (avg overlap={avg_overlap:.2f})"


def compute_reward(
    text: str,
    *,
    sentiment_analyzer=None,
    tool=None,
    weights: Optional[Dict[str, float]] = None,
    detailed: bool = False,
) -> Tuple[float, Dict[str, float]] | float:
    """Compute composite reward for a generated email."""
    text = (text or "").strip()
    lower = text.lower()
    w = weights or default_reward_weights()

    # Length
    length = len(text)
    if 100 <= length <= 300:
        length_r = 1.0
        length_reason = f"Ideal length ({length} chars)"
    elif 300 < length <= 450:
        length_r = 0.2
        length_reason = f"Slightly long ({length} chars)"
    elif length > 450:
        length_r = -0.7
        length_reason = f"Too long ({length} chars)"
    else:
        length_r = -0.7
        length_reason = f"Too short ({length} chars)"

    # Politeness
    polite_hits = [term for term in POLITE_TERMS if re.search(rf"\b{re.escape(term)}\b", lower)]
    polite_r = 0.5 * len(polite_hits)
    polite_reason = (
        f"Polite terms: {', '.join(polite_hits)}"
        if polite_hits
        else "No polite phrasing detected"
    )

    # Sentiment
    sentiment_r = 0.0
    sentiment_reason = "Sentiment analysis skipped"
    if sentiment_analyzer is not None:
        try:
            output = sentiment_analyzer(text[:512])
            if output and isinstance(output, list):
                label = output[0].get("label", "").upper()
                score = output[0].get("score", 0.0)
                if label.startswith("POS"):
                    sentiment_r = 1.0
                    sentiment_reason = f"Positive tone (score={score:.2f})"
                elif label.startswith("NEG"):
                    sentiment_r = -0.3
                    sentiment_reason = f"Negative tone (score={score:.2f})"
                else:
                    sentiment_reason = f"Neutral tone (score={score:.2f})"
        except Exception as exc:  # pragma: no cover
            sentiment_reason = f"Sentiment analysis failed: {exc}"

    # Clarity
    unclear_hits = [term for term in UNCLEAR_TERMS if re.search(rf"\b{term}\b", lower)]
    clarity_r = -0.3 * len(unclear_hits) if unclear_hits else 0.7
    clarity_reason = (
        f"Unclear buzzwords: {', '.join(unclear_hits)}"
        if unclear_hits
        else "Clear, accessible language"
    )

    # CTA
    has_cta, cta_hits = has_cta_phrase(text)
    cta_r = 1.0 if has_cta else -0.8
    cta_reason = (
        f"CTA detected: {', '.join(cta_hits)}"
        if has_cta
        else "Missing explicit call-to-action"
    )

    # Personalization
    personalization_terms = ["you", "your team", "your company", "dear", "hello"]
    personalized = any(term in lower for term in personalization_terms)
    personalization_r = 0.7 if personalized else -0.2
    personalization_reason = (
        "Personalized tone"
        if personalized
        else "No personalization markers"
    )

    # Grammar (optional)
    grammar_r = 0.0
    grammar_reason = "Grammar check skipped"
    if tool is not None:
        try:
            matches = tool.check(text)
            n_errors = len(matches)
            if n_errors == 0:
                grammar_r = 1.0
                grammar_reason = "No grammar issues"
            elif n_errors < 4:
                grammar_r = 0.5
                grammar_reason = f"Minor grammar issues ({n_errors})"
            else:
                grammar_r = -0.3
                grammar_reason = f"Significant grammar issues ({n_errors})"
        except Exception as exc:  # pragma: no cover
            grammar_reason = f"Grammar check failed: {exc}"

    # Value proposition
    value_hits = VALUE_TERMS_PATTERN.findall(text)
    value_r = min(len(value_hits) * 0.5, 1.0)
    value_reason = (
        f"Value props detected: {', '.join(set(map(str.lower, value_hits)))}"
        if value_hits
        else "No value proposition terms"
    )

    # Spam avoidance
    spam_hits = [term for term in SPAM_TERMS if re.search(rf"\b{term}\b", lower)]
    spam_r = -0.8 * len(spam_hits) if spam_hits else 0.5
    spam_reason = (
        f"Spammy terms present: {', '.join(spam_hits)}"
        if spam_hits
        else "No spam terms detected"
    )

    # Structure
    structure_r, structure_reason = evaluate_structure_and_intent(text)

    # Instructional tone penalty
    instruction_r, instruction_reason = penalize_instructional_tone(text)

    # Lexical coherence
    coherence_r, coherence_reason = lexical_coherence(text)

    total_reward = (
        w["length"] * length_r
        + w["politeness"] * polite_r
        + w["sentiment"] * sentiment_r
        + w["clarity"] * clarity_r
        + w["cta"] * cta_r
        + w["personalization"] * personalization_r
        + w["grammar"] * grammar_r
        + w["value"] * value_r
        + w["spam"] * spam_r
        + w["structure"] * structure_r
        + w["instructional_tone"] * instruction_r
        + w["lexical_coherence"] * coherence_r
    )
    total_reward = float(np.clip(total_reward, -4.0, 6.0))

    if not detailed:
        return total_reward

    breakdown = {
        "length": length_r,
        "politeness": polite_r,
        "sentiment": sentiment_r,
        "clarity": clarity_r,
        "cta": cta_r,
        "personalization": personalization_r,
        "grammar": grammar_r,
        "value": value_r,
        "spam": spam_r,
        "structure": structure_r,
        "instructional_tone": instruction_r,
        "lexical_coherence": coherence_r,
        "reasons": {
            "length": length_reason,
            "politeness": polite_reason,
            "sentiment": sentiment_reason,
            "clarity": clarity_reason,
            "cta": cta_reason,
            "personalization": personalization_reason,
            "grammar": grammar_reason,
            "value": value_reason,
            "spam": spam_reason,
            "structure": structure_reason,
            "instructional_tone": instruction_reason,
            "lexical_coherence": coherence_reason,
        },
    }
    return total_reward, breakdown


In [151]:
def prepare_queries(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> List[torch.Tensor]:
    prompt_lengths = attention_mask.sum(dim=1)
    queries: List[torch.Tensor] = []
    for row, length in zip(input_ids, prompt_lengths):
        length_int = int(length.item())
        start_idx = row.size(0) - length_int  # account for left padding
        queries.append(row[start_idx:].clone())
    return queries

def slice_responses(
    generated_ids: torch.Tensor,
    prompt_token_count: int,
    prompt_lengths: Sequence[int],
) -> List[torch.Tensor]:
    responses: List[torch.Tensor] = []
    for idx in range(generated_ids.size(0)):
        start = prompt_token_count
        resp_ids = generated_ids[idx, start:]
        if resp_ids.numel() == 0:
            resp_ids = generated_ids.new_tensor([generated_ids[idx, -1].item()])
        responses.append(resp_ids)
    return responses


def run_ppo_training(
    device: str,
    top_p: float,
    temperature: float,
    max_new_tokens: int,
    num_epochs: int,
    batch_size: int,
    max_length: int,
    reward_weights: Dict[str, float],
    save_model: bool,
    output_dir: Path,
    model: AutoModelForCausalLMWithValueHead,
    tokenizer: AutoTokenizer,
    trainer: PPOTrainer,
    prompts: Sequence[str],
    sentiment_analyzer,
    grammar_tool,
) -> None:
    device = torch.device(device)
    pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        pad_token_id=pad_token_id,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )

    logger.info("Starting PPO fine-tuning for %d epochs", num_epochs)
    logging.info("Using batch size: %d", batch_size)
    for epoch in range(num_epochs):
        logger.info("Epoch %d/%d", epoch + 1, num_epochs)
        progress = tqdm(range(0, len(prompts), batch_size), desc=f"epoch {epoch+1}")
        for idx in progress:
            batch_prompts = list(prompts[idx : idx + batch_size])
            print(f"number of prompts in a batch - {len(batch_prompts)}")
            enc = tokenizer(
                batch_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            input_ids = enc.input_ids.to(device)
            print(f"shape of input_ids - {input_ids.shape}")
            attention_mask = enc.attention_mask.to(device)
            print(f"shape of attention_mask - {attention_mask.shape}")

            queries = prepare_queries(input_ids, attention_mask)
            print(f"queries shape: {[query.shape for query in queries]}")

            with torch.no_grad():
                generated = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **gen_kwargs,
                )
            print(f"shape of generated: {generated.shape}")
            prompt_token_count = input_ids.size(1)
            print(f"prompt_token_count: {prompt_token_count}")
            prompt_lengths = attention_mask.sum(dim=1).tolist()
            print(f"prompt_lengths: {prompt_lengths}")
            responses = slice_responses(generated, prompt_token_count, prompt_lengths)
            #print(f"responses: {responses}")
            print(f"shape of responses: {[resp.shape for resp in responses]}")
            response_texts = [
                tokenizer.decode(resp, skip_special_tokens=True) for resp in responses
            ]
            print(f"response_texts: {response_texts}")

            rewards: List[torch.Tensor] = []
            for text, resp_ids in zip(response_texts, responses):
                total_reward = compute_reward(
                    text,
                    sentiment_analyzer=sentiment_analyzer,
                    tool=grammar_tool,
                    weights=reward_weights,
                )
                reward_tensor = torch.tensor(
                    float(total_reward),
                    device=device,
                    dtype=torch.float32,
                )
                rewards.append(reward_tensor)

            print({"query": batch_prompts, "response": response_texts})
            print(f"queries shape: {[query.shape for query in queries]}")
            print(f"responses shape: {[resp.shape for resp in responses]}")
            print(f"rewards shape: {[reward.shape for reward in rewards]}")
            stats = trainer.step(queries, responses, rewards)

            rewards_for_log = torch.stack(rewards).to(device)
            batch_log = {"query": batch_prompts, "response": response_texts}
            print({"query": batch_prompts, "response": response_texts})
            trainer.log_stats(stats, batch_log, rewards_for_log)

        logger.info("Completed epoch %d", epoch + 1)

    if save_model:
        save_dir = output_dir / "ppo_gpt2_epoch2_batch8"
        save_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Saving fine-tuned model to %s", save_dir)
        trainer.model.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)


In [152]:
run_ppo_training(
    device=device,
    top_p=top_p,
    temperature=temperature,
    max_new_tokens=max_new_tokens,
    num_epochs=2,
    batch_size=8,
    max_length=max_length,
    reward_weights=default_reward_weights,
    save_model=save_model,
    output_dir=Path(output_dir or "."),
    model=model,
    tokenizer=tokenizer,
    trainer=trainer,
    prompts=prompts,
    sentiment_analyzer=sentiment_analyzer,
    grammar_tool=grammar_tool,
)

epoch 1:   0%|          | 0/5 [00:00<?, ?it/s]

number of prompts in a batch - 8
shape of input_ids - torch.Size([8, 36])
shape of attention_mask - torch.Size([8, 36])
queries shape: [torch.Size([27]), torch.Size([36]), torch.Size([17]), torch.Size([33]), torch.Size([27]), torch.Size([25]), torch.Size([20]), torch.Size([33])]
shape of generated: torch.Size([8, 804])
prompt_token_count: 36
prompt_lengths: [27, 36, 17, 33, 27, 25, 20, 33]
shape of responses: [torch.Size([768]), torch.Size([768]), torch.Size([768]), torch.Size([768]), torch.Size([768]), torch.Size([768]), torch.Size([768]), torch.Size([768])]
response_texts: ['\n\n\nSellout: 1% of 1% of the 1% by 1.\n\n\nSaleoff: 1% of 1% by 1.\n\n\nSale: 1 of 1% by 1.\n\n\nSale: 10% 10%\n\n\nSale: 1% of 1% of 1% by 10%\n\n\nSale: 1% of 1% by 10%\n\n\nSale: 1% by 1% by 10%\n\n\nPayscale, by market share of the\n\nSale as used by the seller, by the selling buyer,\n\nSale, buyer of\n\n\nSale, seller of\n\n\nA sales, a,\n\n\nSale: 1% of 1% of 1% by 1.\n\n\nSale: 1% of 1% by 1.\n\n\nSale: 

Downloading LanguageTool latest: 100%|██████████| 254M/254M [00:15<00:00, 16.5MB/s]
Unzipping C:\Users\MOIDHA~1\AppData\Local\Temp\tmp97jof6mz.zip to C:\Users\moidhassan\.cache\language_tool_python.
Downloaded https://internal1.languagetool.org/snapshots/LanguageTool-latest-snapshot.zip to C:\Users\moidhassan\.cache\language_tool_python.


{'query': ['Compose a friendly B2B sales email based on this concept: "Biodegradable packaging: lower footprint, same cost."', 'Generate a product outreach email using this information: "Our corporate wellness programs have helped over 200 companies reduce healthcare costs by an average of $450 per employee annually while boosting morale."', 'Write an enterprise sales email centered on: "Learning platform lifts test scores 18%."', 'Write an enterprise sales email centered on: "Properties using our platform have seen an average 23% increase in positive reviews and a significant boost in repeat bookings."', 'Write an enterprise sales email centered on: "Cut cloud spend 30–40%. 15‑min chat next week? – DataSync"', 'Craft a formal product introduction email about: "Our clients typically see 15-20% increase in conversion rates within 90 days."', 'Craft a formal product introduction email about: "Logistics costs down 20–30%. Explore fulfillment?"', 'Compose a friendly B2B sales email based o

ValueError: Batch size (2) does not match number of examples - but got 8 for: queries

In [126]:
def prepare_queries(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> List[torch.Tensor]:
    prompt_lengths = attention_mask.sum(dim=1)
    print(f"prompt_lengths inside prepare_queries: {prompt_lengths}")
    queries: List[torch.Tensor] = []
    for row, length in zip(input_ids, prompt_lengths):
        length_int = int(length.item())
        print(f"length_int inside prepare_queries loop: {length_int}")
        start_idx = row.size(0) - length_int  # account for left padding
        print(f"start_idx inside prepare_queries loop: {start_idx}")
        queries.append(row[start_idx:].clone())
    return queries

def slice_responses(
    generated_ids: torch.Tensor,
    prompt_token_count: int,
    prompt_lengths: Sequence[int],
) -> List[torch.Tensor]:
    responses: List[torch.Tensor] = []
    print(f"shape of generated: {generated_ids.shape}")
    print(f"prompt_token_count: {prompt_token_count}")
    print(f"prompt_lengths: {prompt_lengths}")
    for idx in range(generated_ids.size(0)):
        start = prompt_token_count
        print(f"start index for response slicing: {start}")
        resp_ids = generated_ids[idx, start:]
        print(f"shape of resp_ids: {resp_ids.shape}")
        if resp_ids.numel() == 0:
            print("resp_ids is empty, using last token as response")
            resp_ids = generated_ids.new_tensor([generated_ids[idx, -1].item()])
            print(f"shape of resp_ids after using last token: {resp_ids.shape}")
        responses.append(resp_ids)
    return responses

In [138]:
idx = 4
batch_size = 2
batch_prompts = list(prompts[idx : idx + batch_size])
print(batch_prompts)
print(f"number of prompts in a batch - {len(batch_prompts)}\n")

print(f"max_length: {max_length}\n")
enc = tokenizer(
    batch_prompts,
    return_tensors="pt",
    padding=True,
    truncation=False,
    #max_length=max_length,
)
input_ids = enc.input_ids.to(device)
print(f"input_ids: {input_ids}\n")
print(f"shape of input_ids - {input_ids.shape}\n")
attention_mask = enc.attention_mask.to(device)
print(f"shape of attention_mask - {attention_mask.shape}\n")

queries = prepare_queries(input_ids, attention_mask)
print(f"queries shape: {[query.shape for query in queries]}")

#decode queries to get actual prompt text
decoded_queries = [tokenizer.decode(query) for query in queries]
for i in range(len(decoded_queries)):
    if decoded_queries[i] != batch_prompts[i]:
        print(f"Decoded query {i} does not match original prompt:")
        print(f"Decoded: {decoded_queries[i]}")
        print(f"Original: {batch_prompts[i]}\n")

pad_token_id = tokenizer.pad_token_id
print(f"max_new_tokens: {max_new_tokens}, pad_token_id: {pad_token_id}, top_p: {top_p}, temperature: {temperature}\n")
gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        pad_token_id=pad_token_id,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )


#print(**gen_kwargs)
with torch.no_grad():
    generated = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    **gen_kwargs,
                            )
print(f"shape of generated: {generated.shape}")


['Write an enterprise sales email centered on: "Cut cloud spend 30–40%. 15‑min chat next week? – DataSync"', 'Craft a formal product introduction email about: "Our clients typically see 15-20% increase in conversion rates within 90 days."']
number of prompts in a batch - 2

max_length: 1024

input_ids: tensor([[16594,   281, 13953,  4200,  3053, 19254,   319,    25,   366, 26254,
          6279,  4341,  1542,  1906,  1821,  7225,  1315, 20977,  1084,  8537,
          1306,  1285,    30,   784,  6060, 28985,     1],
        [50256, 50256, 14467,   257,  8766,  1720,  9793,  3053,   546,    25,
           366,  5122,  7534,  6032,   766,  1315,    12,  1238,     4,  2620,
           287, 11315,  3965,  1626,  4101,  1528,   526]], device='cuda:0')

shape of input_ids - torch.Size([2, 27])

shape of attention_mask - torch.Size([2, 27])

prompt_lengths inside prepare_queries: tensor([27, 25], device='cuda:0')
length_int inside prepare_queries loop: 27
start_idx inside prepare_queries loop:

In [131]:
print(f"shape of generated: {generated.shape}")
prompt_token_count = input_ids.shape[1]
print(f"prompt_token_count: {prompt_token_count}")
prompt_lengths = attention_mask.sum(dim=1).tolist()
print(f"prompt_lengths: {prompt_lengths}")

responses = slice_responses(generated, prompt_token_count, prompt_lengths)
print(f"shape of responses: {[resp.shape for resp in responses]}")

response_texts = [tokenizer.decode(resp, skip_special_tokens=True) for resp in responses]
#for i, text in enumerate(response_texts):
#    print(f"Response {i} text: {text}\n")
            

shape of generated: torch.Size([2, 795])
prompt_token_count: 27
prompt_lengths: [27, 25]
shape of generated: torch.Size([2, 795])
prompt_token_count: 27
prompt_lengths: [27, 25]
start index for response slicing: 27
shape of resp_ids: torch.Size([768])
start index for response slicing: 27
shape of resp_ids: torch.Size([768])
shape of responses: [torch.Size([768]), torch.Size([768])]


In [133]:
response_texts[1]

'\n\n1x\n\n...\n\n\n2...\n\n\n.\n\n\n,\n\n,\n\n,\n\n,\n,\n\n,\n\n\n,\n,\n\n\n,\n\n.\n\n"I think many people assume that people on a website have something they would like to see from all the time in that can come some things in their free time.But many people don\'t realise that people are most likely to see the new products in the free time.I think many people think most people will see the new products in the free time.I think many people in the industry want to see the new products in free time. They think most people of one sort will see the new products in free time,but many people don\'t realise that many people will see the new products in free time.\n\n"I think, I see the new products that, many new things. I think people who are most likely to see the new products in free time,I think that the free of, I think people have, to see the new products in free time, in a. I think most people in the industry won\'t see the products in free time.Some, some are for not new, some one fo

In [135]:
print(default_reward_weights)

{'length': 1.2, 'politeness': 1.2, 'sentiment': 0.7, 'clarity': 0.6, 'cta': 1.4, 'personalization': 0.7, 'grammar': 0.8, 'value': 1.1, 'spam': 0.8, 'structure': 0.8, 'instructional_tone': 0.8, 'lexical_coherence': 1.2}


In [140]:
rewards: List[torch.Tensor] = []
for text in response_texts:
    total_reward = compute_reward(text,
                    sentiment_analyzer=sentiment_analyzer,
                    tool=grammar_tool,
                    weights=default_reward_weights
                    )
    print(f"Total reward: {total_reward}")
    reward_tensor = torch.tensor(float(total_reward),
                    device=device,
                    dtype=torch.float32,
                )
    print(f"Reward tensor: {reward_tensor}")
    rewards.append(reward_tensor)

Total reward: -1.94
Reward tensor: -1.940000057220459
Total reward: -0.5400000000000003
Reward tensor: -0.5400000214576721


In [141]:
rewards

[tensor(-1.9400, device='cuda:0'), tensor(-0.5400, device='cuda:0')]

In [139]:
print(f"queries shape: {[query.shape for query in queries]}")
print(f"responses shape: {[resp.shape for resp in responses]}")
print(f"rewards shape: {[reward.shape for reward in rewards]}")
            

queries shape: [torch.Size([27]), torch.Size([25])]
responses shape: [torch.Size([768]), torch.Size([768])]
rewards shape: [torch.Size([]), torch.Size([])]


In [145]:
# writing a inference function to generate email with the fine-tuned model
def generate_email(
    model,
    tokenizer,
    prompt: str,
    device: str = "cuda:0",
    max_length: int = 1024,
    max_new_tokens: int = 768,
    top_p: float = 1.0,
    temperature: float = 0.85,
) -> str:
    model = model.to(device)
    enc = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        #max_length=max_length,
    )
    input_ids = enc.input_ids.to(device)
    attention_mask = enc.attention_mask.to(device)

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )

    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    prompt_token_count = input_ids.size(1)
    prompt_lengths = attention_mask.sum(dim=1).tolist()
    responses = slice_responses(generated, prompt_token_count, prompt_lengths)
    response_text = tokenizer.decode(responses[0], skip_special_tokens=True)
    return response_text

In [149]:
finetuned_model_dir = "ppo_gpt2"
finetuned_model = AutoModelForCausalLMWithValueHead.from_pretrained(finetuned_model_dir)
finetuned_tokenizer = AutoTokenizer.from_pretrained(finetuned_model_dir)
test_prompt1 = "Write a concise, professional sales email introducing this idea: \"Our new AI-powered analytics platform can help businesses unlock insights from their data faster and more accurately.\""
test_prompt2 = "Compose a friendly B2B sales email based on this concept: \"We are launching a cloud-based collaboration tool that enhances team productivity and communication.\""
test_prompts = [test_prompt1, test_prompt2]
generated_emails = []
for test_prompt in test_prompts:
    generated_email = generate_email(
        model=finetuned_model,
        tokenizer=finetuned_tokenizer,
        prompt=test_prompt,
        max_length=256,
        max_new_tokens=256,
        top_p=0.9,
        temperature=0.7,
    )
    print(generated_email)
    generated_emails.append(generated_email)


Some weights of the model checkpoint at ppo_gpt2 were not used when initializing GPT2LMHeadModel: ['v_head.summary.bias', 'v_head.summary.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


shape of generated: torch.Size([1, 288])
prompt_token_count: 32
prompt_lengths: [32]
start index for response slicing: 32
shape of resp_ids: torch.Size([256])


To test this idea, we created a new, fast-paced app, called "Tune" (tune it to your preferences). Tune is an app that lets you send your data to Google Analytics for the most accurate predictions on your data, and then use it to generate, analyze, and analyze your data to optimize your business.

We analyzed the data that we wanted to analyze, and then used it to create a new app, called "Tune", which was optimized to generate, analyze, and analyze your data.

Now, let's analyze the data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, and analyze your data, 

In [148]:
generated_email

'\n\nWe\'ll let you know if you want to send us the news and features you like in the comments below.\n\nIn the meantime, we\'ll keep you posted on the latest developments in the market and the latest news from the market.\n\nIf you enjoy the website, we recommend you to check out our blog.\n\nRead also:\n\nHow to build a custom app with Android, iOS and Windows 8\n\nHow to build a custom app with iOS, Android and Windows 8\n\nWhat you need to know about the new API\n\nThe new API is called "SockApp", and it contains:\n\nA simple way to interact with the data\n\nA way to track the amount of data\n\nA way to track the number of requests\n\nA way to track the number of requests\n\nHow to implement the API\n\nIn the first part of the post, we will introduce the API for the API that we have set up in the past, so you can easily learn more about it.\n\nThe API will be written using a simple, low level API.\n\nThe first thing you need to do is to create a database in the database, which will