In [None]:
# Reload edited modules automatically before executing each cell.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
import torch

from promptopt import models
from promptopt import datasets
from promptopt import rlhf
from promptopt import embed
from promptopt import interrogator

from matplotlib import pyplot as plt
import matplotlib as mpl

In [None]:
# Embedding model shared by the rest of the notebook: its dimensionality sizes
# PrefModel's input, and it embeds prompts for training, scoring and search.
embedding_model = embed.CLIP()

In [None]:
# Seed torch and numpy before building the preference model so that any random
# weight initialisation (and everything downstream of it) is reproducible
# across Restart & Run All. Change SEED to explore other initialisations.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Architecture options forwarded to PrefModel
# (presumably hidden-layer count and width — TODO confirm against models.PrefModel).
net_arch_kwargs = {
  'n_layers': 2,
  'layer_size': 256
}
pref_model = models.PrefModel(embedding_model.n_embedding_dims, net_arch_kwargs=net_arch_kwargs)

In [None]:
# Start from the library's default training settings and only switch on verbose output.
pref_model_train_config = rlhf.get_default_config()
setattr(pref_model_train_config, 'verbose', True)

In [None]:
# Toy prompt corpus; the integer indices in pref_data presumably refer to positions in this list.
prompts = ['a quick', 'brown fox', 'jumped over', 'the moon']

In [None]:
# Pairwise preference triples over `prompts`, written as (index_a, index_b, label).
# NOTE(review): presumably label 0/1 selects which prompt of the pair is preferred — confirm against PrefDataset.
pref_data = [(0, 1, 0), (1, 2, 1), (1, 3, 1), (2, 0, 0), (3, 0, 0)]

In [None]:
# Embed the labelled prompts; these embeddings back the datasets built below.
embeddings = embedding_model.embed_strings(prompts)

In [None]:
# Wrap the prompt embeddings in a dataset. list() turns the embed_strings result into
# a Python list (assumes it iterates one embedding per prompt — TODO confirm).
embedding_dataset = datasets.EmbeddingDataset(embeddings=list(embeddings))

In [None]:
# Pair the embedding dataset with the pairwise preference triples in pref_data.
pref_dataset = datasets.PrefDataset(embedding_dataset, pref_data=pref_data)

In [None]:
# RLHF driver bundling the preference model, its training config and both datasets.
# NOTE(review): no explicit training call is visible in this notebook — confirm whether
# the constructor trains pref_model; otherwise the scores below come from untrained weights.
optimizer = rlhf.RLHF(pref_model, pref_model_train_config, embedding_dataset, pref_dataset)

In [None]:
# Prompts to rank: two new combinations first, followed by every original training prompt.
candidate_prompts = ['quick fox', 'jumping moon', *prompts]

In [None]:
# Embed every candidate (new prompts + originals) so that each score computed below
# lines up with the prompt at the same position in candidate_prompts.
candidate_embeddings = embedding_model.embed_strings(candidate_prompts)

In [None]:
# Score each candidate embedding with the preference model and collect the scores in a NumPy array.
candidate_scores = np.array(pref_model.score(candidate_embeddings))

In [None]:
# zip() silently truncates to the shorter input, so first check there is exactly one
# score per candidate prompt; a mismatch means the wrong strings were embedded/scored.
assert len(candidate_prompts) == len(candidate_scores), (
  'expected one score per candidate prompt, got {} scores for {} prompts'
  .format(len(candidate_scores), len(candidate_prompts)))
scored_candidates = list(zip(candidate_prompts, candidate_scores))

In [None]:
# Rank (prompt, score) pairs from highest to lowest score; the sort is stable, so ties keep their input order.
sorted_candidates = list(scored_candidates)
sorted_candidates.sort(key=lambda pair: pair[1], reverse=True)
sorted_candidates

In [None]:
def score_func(x):
  """Score embeddings with pref_model and return the scores as a torch tensor.

  Used as Gator's scoring callback. The intermediate np.array presumably collapses
  whatever pref_model.score returns into one array before tensor conversion — TODO confirm.
  """
  return torch.tensor(np.array(pref_model.score(x)))


gator = interrogator.Gator(embedding_model=embedding_model, score_func=score_func)

In [None]:
# Start the search from the highest-scoring candidate (sorted_candidates is ordered by
# descending score), not from whichever prompt happens to be listed first.
init_prompt, _ = sorted_candidates[0]
init_prompt

In [None]:
# Run Gator's search starting from init_prompt and display the prompt it returns.
best_prompt = gator.search(init_prompt)
best_prompt

In [None]:
# Embed the single prompt returned by the search (embed_string, not the list-based embed_strings).
best_prompt_embedding = embedding_model.embed_string(best_prompt)

In [None]:
# Display the optimizer's preference prediction for the searched prompt.
# NOTE(review): confirm whether predict_prefs expects a single embedding or a batch,
# and what its output represents.
optimizer.predict_prefs(best_prompt_embedding)