In [None]:
from src import utils

utils.set_seed(42)
utils.clear_memory()

In [None]:
import pandas as pd

df = pd.read_json("hf://datasets/allenai/real-toxicity-prompts/prompts.jsonl", lines=True)

input_dicts = df["prompt"].values.flatten().tolist()

inputs = [d["text"] for d in input_dicts]

In [None]:
inputs[:5]

In [None]:
import matplotlib.pyplot as plt
from nltk.tokenize import word_tokenize

word_counts = [len(word_tokenize(i)) for i in inputs]

plt.title("Number of words")
plt.hist(word_counts, bins=range(min(word_counts), max(word_counts) + 2), edgecolor="black")
plt.show()

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

token_counts = [len(tokenizer.tokenize(i)) for i in inputs]

plt.hist(token_counts, bins=range(min(token_counts), max(token_counts) + 2), edgecolor="black")
plt.title("Number of ALL Tokens")
plt.show()

In [None]:
from src.rating.detoxify import DetoxifyRater

rater_backend = DetoxifyRater(model_type="original", amp=True)

In [None]:
from src.rating.base import *
import matplotlib.pyplot as plt

rating_runner = RatingRunner(rater_backend)
rating_results = rating_runner.rate_stram(inputs[-200:], batch_size=100)
rating_results = list(rating_results)

toxicity_scores = [r.scores["toxicity"] for r in rating_results]
_ = plt.hist(toxicity_scores, bins=20)
plt.show()

print("Rating results Example:")
print(rating_results[:5])

# Print examples of toxicity > 0.5 and < 0.5
toxic_examples = [r for r in rating_results if r.scores["toxicity"] > 0.5]
non_toxic_examples = [r for r in rating_results if r.scores["toxicity"] <= 0.5]

print("Toxic examples:")
for r in toxic_examples[:5]:
    print(r.text)

print("Non-toxic examples:")
for r in non_toxic_examples[:5]:
    print(r.text)

In [None]:
import torch
import torch_tensorrt

torch._dynamo.list_backends()

In [None]:
from nltk import sent_tokenize
text = "What is the capital of France? The capital of France is Paris. What is the capital of Italy? The capital of Italy is Rome."
split_text = sent_tokenize(text)
print(split_text)

In [None]:
from src.generation.chat_model import ChatGenerator
from src.generation.vanilla_model import VanillaGenerator

hf_key = utils.api_key_from_file("HF_KEY.txt")
generator_backend = ChatGenerator(
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    hub_token=hf_key,
    max_input_tokens=40,
    max_output_tokens=30,
    torch_dtype=torch.bfloat16,
)

In [None]:
from transformers import CompileConfig
from src.generation.base import GenerationRunner

generation_runner = GenerationRunner(generator_backend)
generated_texts = generation_runner.generate_stream(
    inputs[:5000],
    batch_size=300,
    # padding="max_length",
    # cache_implementation="static",
    # compile_config=CompileConfig(
    #     fullgraph=True,
    #     dynamic=True,
    #     backend="inductor",
    #     mode="reduce-overhead",
    # ),
)
generated_list = list(generated_texts)

In [None]:
# sample 10 random examples

import random

random.seed(42)
random_examples = random.sample(generated_list, 10)

for example in random_examples:
    print(example)

In [None]:
from functools import partial
from src.survival_runner import (
    SurvivalRunner,
    SurvivalResult,
    default_toxicity_func,
    default_text_prep_func,
)

survival_runner = SurvivalRunner(
    generator=generator_backend,
    rater=rater_backend,
    max_attempts=40,
    toxicity_func=partial(default_toxicity_func, thresh=0.1),
    text_prep_func=lambda gen: gen.prompt + " " + gen.output,
    conserve_memory=True,
)

survival_results = survival_runner.generate(
    prompts=inputs[-400:],
    batch_size=500,
)

In [None]:
survival_list: list[SurvivalResult] = []
for res in survival_results:
    survival_list.append(res)
    print(res)

In [None]:
import matplotlib.pyplot as plt

survival_time = [r.num_attempts for r in survival_list]
_ = plt.hist(survival_time, bins=20)
plt.show()