This colab demonstrates the main TRICE algorithm from the paper "Training Chain-of-Thought via Latent-Variable Inference" ([openreview](https://openreview.net/forum?id=a147pIS2Co)).

![trice](https://drive.google.com/uc?id=1rC6gJV1j8t9PwZpN3vysVN92WUqszjY1)

In [None]:
import jax
import jax.numpy as jnp
import optax
import requests
from tensorflow_probability.substrates import jax as tfp

# 

The code in this colab interfaces with an LLM backend via four functions: `sample`, `log_prob`, `grad`, and `init`. Signatures for these functions are below; to run this colab, you will need to provide callables that implement these functions.

In [None]:
# pylint:disable=unused-argument
def sample(params, context, num_steps=256, temperature=1.0, *, seed):
  """Draws a sample continuation.

  This is a placeholder for the LLM backend: replace it with a real
  implementation before running the rest of the colab. Other cells in this
  colab also call it with `params=None` (to build the few-shot prompt and the
  initial memory) and with `temperature=0.0` (for evaluation).

  Args:
    params: a PyTree of parameters.
    context: a Python string or a list of Python strings.
    num_steps: The maximum number of tokens to generate.
    temperature: The temperature to use.
    seed: The random seed (keyword-only).

  Returns:
    A Python string or a list of Python strings.

  Raises:
    NotImplementedError: always, until a backend implementation is supplied.
  """
  raise NotImplementedError(
      "`sample` is a placeholder; provide an LLM backend implementation.")

def log_prob(params, context, continuation):
  """Computes the log-probability of generating continuation.

  This is a placeholder for the LLM backend: replace it with a real
  implementation before running the rest of the colab.

  Args:
    params: a PyTree of parameters.
    context: a Python string or a list of Python strings.
    continuation: a target string or a list of target strings to get log
      probability of.

  Returns:
    A scalar or a vector Array.

  Raises:
    NotImplementedError: always, until a backend implementation is supplied.
  """
  raise NotImplementedError(
      "`log_prob` is a placeholder; provide an LLM backend implementation.")

def grad(params, context, continuation):
  """Computes the gradient of log-probability w.r.t. parameters.

  This is a placeholder for the LLM backend: replace it with a real
  implementation before running the rest of the colab.

  Args:
    params: a PyTree of parameters.
    context: a Python string or a list of Python strings.
    continuation: a target string or a list of target strings to get log
      probability of.

  Returns:
    A PyTree of parameters corresponding to the gradient of log-probability
    w.r.t. parameters. If a list of string inputs are provided, the gradient
    will have a corresponding batch dimension on the left.

  Raises:
    NotImplementedError: always, until a backend implementation is supplied.
  """
  raise NotImplementedError(
      "`grad` is a placeholder; provide an LLM backend implementation.")

def init(soft_prompt_init):
  """Returns an initial value for a PyTree of parameters.

  This is a placeholder for the LLM backend: replace it with a real
  implementation before running the rest of the colab.

  Args:
    soft_prompt_init: a Python string to generate soft-prompt embedding to be
      used as the initial PyTree of parameters.

  Returns:
    A PyTree of parameters.

  Raises:
    NotImplementedError: always, until a backend implementation is supplied.
  """
  raise NotImplementedError(
      "`init` is a placeholder; provide an LLM backend implementation.")
# pylint:enable=unused-argument

def get_dataset(seed=None):
  """Downloads a BIG-Bench Hard task and splits it into train and test sets.

  Args:
    seed: Optional JAX PRNG key. When given, the complete list of examples is
      shuffled with it *before* the train/test split (so it affects both
      splits). When None, the examples keep the order of the downloaded file.

  Returns:
    train_questions: A list of strings corresponding to training questions
      (the first 150 examples).
    train_answers: A list of strings corresponding to training answers.
    test_questions: A list of strings corresponding to testing questions
      (all remaining examples).
    test_answers: A list of strings corresponding to testing answers.

  Raises:
    requests.RequestException: if the download times out, the connection
      fails, or the server answers with an HTTP error status.
  """
  task = "logical_deduction_three_objects"
  num_train = 150
  data_url = f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/bbh/{task}.json"
  # Bound the wait and surface HTTP errors directly, instead of hanging forever
  # or failing later with an unrelated JSON/KeyError on an error page.
  response = requests.get(data_url, timeout=60)
  response.raise_for_status()
  examples = response.json()["examples"]
  if seed is not None:
    permutation = jax.random.permutation(seed, len(examples))
    examples = [examples[int(i)] for i in permutation]
  questions = [ex["input"] for ex in examples]
  answers = [ex["target"] for ex in examples]
  train_questions, train_answers = questions[:num_train], answers[:num_train]
  test_questions, test_answers = questions[num_train:], answers[num_train:]
  return train_questions, train_answers, test_questions, test_answers

# 

In [None]:
# Prompt asking the model to reason step by step about `question`.
RATIONALE_TEMPLATE = (
    "Question: {question}\n"
    "Answer: Let's think step by step.\n"
)
# Same prompt, but with the known `answer` revealed before the reasoning.
GUIDE_TEMPLATE = (
    "Question: {question}\n"
    "Answer: The answer is {answer}. Let's think step by step.\n"
)

# Integer seed of the root PRNG key.
SEED = 0
# Initial value handed to the cosine decay learning-rate schedule.
LEARNING_RATE = 1.0
# Number of optimization steps run by `fit` (also the schedule's decay length).
TRAIN_STEPS = 100
# Maximum number of training questions per mini-batch.
BATCH_SIZE = 64
# Number of rationales resampled per step for the gradient estimate.
GRADIENT_SUBSAMPLE_SIZE = 64
# `fit` logs and evaluates every LOG_EVERY steps.
LOG_EVERY = 1

In [None]:
def is_correct(rationale, answer):
  """Returns True iff `rationale` ends with "the answer is <answer>.".

  The match is exact: it is case-sensitive and requires the final period with
  nothing (not even whitespace) after it.
  """
  expected_ending = f"the answer is {answer}."
  return rationale.endswith(expected_ending)

In [None]:
def make_cot_prompt(questions, answers, seed, num_shots=3):
  """Generates the COT soft prompt from the questions and answers.

  Walks through the (question, answer) pairs in order, samples one rationale
  per question with `params=None`, and keeps the first `num_shots` rationales
  that pass `is_correct`. Sampling stops as soon as enough shots are found.

  Args:
    questions: a list of question strings.
    answers: a list of answer strings aligned with `questions`.
    seed: a JAX PRNG key; a per-question key is derived with `fold_in`.
    num_shots: number of correct (question, rationale) examples to collect.

  Returns:
    A string made of the collected examples, each formatted as
    `RATIONALE_TEMPLATE` followed by the rationale and a blank line.

  Raises:
    ValueError: if fewer than `num_shots` correct rationales were sampled.
  """
  shots = []
  for i, (q, a) in enumerate(zip(questions, answers)):
    seed_i = jax.random.fold_in(seed, i)
    r = sample(params=None, context=RATIONALE_TEMPLATE.format(question=q), seed=seed_i)
    if is_correct(r, a):
      shots.append((q, r))
      if len(shots) == num_shots:
        break
  # An explicit raise (rather than `assert`) keeps the check under `python -O`.
  if len(shots) != num_shots:
    raise ValueError("Not enough examples to construct cot prompt.")
  return "".join(RATIONALE_TEMPLATE.format(question=q) + r + "\n\n" for q, r in shots)

In [None]:
def get_init_memory(fewshots_prompt, train_questions, train_answers, seed):
  """Samples the initial memory using an answer-revealing few-shot prompt.

  The few-shot prompt is split into examples on blank lines, each example is
  re-rendered with `GUIDE_TEMPLATE` (which states the answer up front), and the
  resulting guide prompt is prepended to every training question before
  sampling with `params=None`.

  Args:
    fewshots_prompt: string of examples separated by blank lines, each holding
      "Question: ...", the "Let's think step by step." line and a rationale.
    train_questions: a list of training question strings.
    train_answers: a list of answer strings aligned with `train_questions`.
    seed: the random seed forwarded to `sample`.

  Returns:
    Whatever `sample` returns for the list of guided contexts.
  """
  cot_marker = "\nAnswer: Let's think step by step.\n"
  guide_shots = []
  for shot in fewshots_prompt.strip().split("\n\n"):
    question_part, rationale = shot.split(cot_marker)
    question = question_part.split("Question: ")[-1]
    # Text after the last "the answer is ", minus its final character (the
    # trailing "." when the rationale ends with "the answer is <answer>.").
    answer = rationale.split("the answer is ")[-1][:-1]
    guide_shots.append(
        GUIDE_TEMPLATE.format(question=question, answer=answer) + rationale + "\n\n")
  guide_prompt = "".join(guide_shots)
  print("GUIDE PROMPT:", guide_prompt)

  guided_contexts = []
  for question, answer in zip(train_questions, train_answers):
    guided_contexts.append(
        guide_prompt + GUIDE_TEMPLATE.format(question=question, answer=answer))
  return sample(params=None, context=guided_contexts, seed=seed)

In [None]:
def evaluate(params, test_questions, test_answers):
  """Prints the fraction of test questions answered correctly.

  One rationale per question is sampled at temperature 0.0 with a fixed PRNG
  key, and a rationale counts as correct when `is_correct` accepts it.

  Args:
    params: a PyTree of parameters forwarded to `sample`.
    test_questions: a list of question strings.
    test_answers: a list of answer strings aligned with `test_questions`.
  """
  prompts = []
  for question in test_questions:
    prompts.append(RATIONALE_TEMPLATE.format(question=question))
  rationales = sample(
      params, context=prompts, temperature=0.0, seed=jax.random.PRNGKey(0))
  num_correct = sum(
      1 for rationale, answer in zip(rationales, test_answers)
      if is_correct(rationale, answer))
  print("Val accuracy:", num_correct / len(test_answers), flush=True)

In [None]:
def trice_loss(params, memory, seed, questions, answers):
  """Computes TRICE objective and its gradients for one mini-batch.

  Args:
    params: a PyTree of parameters, forwarded to `sample`, `log_prob`, `grad`.
    memory: a list of rationale strings currently stored for `questions`.
    seed: a JAX PRNG key; it is split into one key for drawing proposals and
      one for subsampling the rationales used in the gradient estimate.
    questions: a list of question strings.
    answers: a list of answer strings aligned with `questions`.

  Returns:
    A pair `(aux, params_grad)` where `aux` is the tuple
    `(loss, new_memory, proposal_acc, mask_mean)`:
      loss: the negated mean of the weighted log-probabilities of the
        subsampled rationales.
      new_memory: `memory` with every entry replaced by its proposal whenever
        that proposal is correct.
      proposal_acc: fraction of proposals accepted by `is_correct`.
      mask_mean: fraction of examples whose proposal or memory entry is
        correct.
    and `params_grad` is a PyTree with the weighted average of the
    per-rationale gradients returned by `grad` (negated, to match `loss`).
  """
  subsample_seed, sample_seed = jax.random.split(seed)
  context = [RATIONALE_TEMPLATE.format(question=q) for q in questions]
  proposal = sample(params, context=context, seed=sample_seed)
  is_proposal_correct = jnp.stack([is_correct(r, a) for r, a in zip(proposal, answers)])
  is_memory_correct = jnp.stack([is_correct(m, a) for m, a in zip(memory, answers)])
  # A proposal replaces the stored rationale only when it is correct.
  new_memory = [r if accept else m for r, m, accept in zip(proposal, memory, is_proposal_correct)]
  # Examples for which at least one correct rationale is available.
  mask = is_proposal_correct | is_memory_correct
  # Leave-one-out estimate: correct proposals among the *other* examples,
  # divided by (number of masked-in examples - 1).
  correlation_est = (is_proposal_correct.sum() - is_proposal_correct) / (mask.sum() - 1 + 1e-10)

  # compute weight contributions of rationales from both memory and proposal.
  weights_memory = mask * (1 - correlation_est * is_proposal_correct)
  weights_proposal = mask * correlation_est * (1 - is_proposal_correct)
  flat_weights = jnp.concatenate([weights_memory, weights_proposal])
  flat_rationales = new_memory + proposal
  flat_contexts = context + context
  # Instead of using [jnp.ones_like(mask), -jnp.ones_like(mask)]) here, we use
  # [mask, -mask] to mask out the contributions from weight=0 rationales (which
  # has a very small chance of happening due to the clipping below).
  flat_signs = jnp.concatenate([mask, -1 * mask])
  # Lower-bound the weights so that their logarithm below stays finite.
  # `jnp.maximum` is equivalent to `jnp.clip(..., a_min=1e-10)` but does not
  # rely on the `a_min` keyword, which recent JAX releases deprecate.
  flat_weights = jnp.maximum(flat_weights, 1e-10)
  weights_mean = flat_weights.sum() / (mask.sum() + 1e-10)
  # Note: to compute the loss without subsampling, we can set
  # subsampled_indices = jnp.arange(2 * len(questions))
  # per_item_weights = flat_signs * flat_weights
  #
  # Rationales must be drawn with probability proportional to their weight for
  # `signs * weights_mean` to be the matching importance-weighted estimate, so
  # the resampler gets the normalized log of the weights.
  # (`log_softmax(flat_weights)` would instead sample proportionally to
  # exp(weight), giving clipped ~0 weights a sizeable probability.)
  subsampled_indices = tfp.experimental.mcmc.resample_systematic(
      jax.nn.log_softmax(jnp.log(flat_weights)), GRADIENT_SUBSAMPLE_SIZE, (), subsample_seed)
  subsampled_signs = flat_signs[subsampled_indices]
  per_item_weights = subsampled_signs * weights_mean

  # Plain Python ints for indexing the Python lists of strings.
  index_list = [int(i) for i in subsampled_indices]
  subsampled_rationales = [flat_rationales[i] for i in index_list]
  subsampled_contexts = [flat_contexts[i] for i in index_list]
  subsampled_log_probs = log_prob(params, subsampled_contexts, subsampled_rationales)
  subsampled_grads = grad(params, subsampled_contexts, subsampled_rationales)
  loss = -(subsampled_log_probs * per_item_weights).mean()
  # `grad` documents a leading batch axis; move it last so that it broadcasts
  # against the per-item weights, then average it away.
  params_grad = jax.tree_util.tree_map(
      lambda g: -(jnp.moveaxis(g, 0, -1) * per_item_weights).mean(-1), subsampled_grads)
  return (loss, new_memory, is_proposal_correct.mean(), mask.mean()), params_grad

In [None]:
def fit(optimizer, params, memory, train_questions, train_answers, test_questions, test_answers, seed):
  """Runs the optimization loop.

  Each of the `TRAIN_STEPS` steps picks a mini-batch, calls `trice_loss`,
  writes the updated rationales back into the memory, and applies one
  optimizer update. Every `LOG_EVERY` steps the metrics are printed and
  `evaluate` is run on the test set.

  Args:
    optimizer: an optax gradient transformation (provides `init` / `update`).
    params: the initial PyTree of parameters.
    memory: a list with one rationale per training question; it is copied, so
      the caller's list is left untouched.
    train_questions: a list of training question strings.
    train_answers: a list of answer strings aligned with `train_questions`.
    test_questions: a list of test question strings, used for evaluation.
    test_answers: a list of answer strings aligned with `test_questions`.
    seed: a JAX PRNG key, split into a permutation key and a loss key.

  Returns:
    A pair `(params, memory)` with the final parameters and the final memory.
  """
  opt_state = optimizer.init(params)
  permute_seed, loss_seed = jax.random.split(seed)
  # At least one batch per epoch: without the guard a training set smaller
  # than BATCH_SIZE gives 0 here and the divisions below raise
  # ZeroDivisionError.
  num_batches_per_epoch = max(1, len(train_questions) // BATCH_SIZE)
  memory = memory.copy()
  for step in range(TRAIN_STEPS):
    # Get a mini-batch of questions and answers. One permutation is shared by
    # all steps of an epoch; batch `b` takes every `num_batches_per_epoch`-th
    # entry starting at offset `b`, truncated to BATCH_SIZE.
    epoch, batch_in_epoch = divmod(step, num_batches_per_epoch)
    permute_seed_i = jax.random.fold_in(permute_seed, epoch)
    permutation = jax.random.permutation(permute_seed_i, len(train_questions))
    batch_indices = [
        int(idx)
        for idx in permutation[batch_in_epoch::num_batches_per_epoch][:BATCH_SIZE]]
    batch_questions = [train_questions[idx] for idx in batch_indices]
    batch_answers = [train_answers[idx] for idx in batch_indices]
    batch_memory = [memory[idx] for idx in batch_indices]

    loss_seed_i = jax.random.fold_in(loss_seed, step)
    (loss, batch_memory, proposal_acc, memory_acc), grads = trice_loss(
        params, batch_memory, loss_seed_i, batch_questions, batch_answers)
    # Write the (possibly updated) rationales back to their global positions.
    for j, idx in enumerate(batch_indices):
      memory[idx] = batch_memory[j]
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    if (step + 1) % LOG_EVERY == 0:
      print(f"Step {step + 1}: loss {loss:.4f} | proposal_acc {proposal_acc:.4f} "
            f"| memory_acc {memory_acc:.4f}", flush=True)
      evaluate(params, test_questions, test_answers)
  return params, memory

In [None]:
# One PRNG key per stage, all derived from SEED. `params_seed` is split off
# but not consumed by any cell shown in this colab.
(data_seed, cot_seed, params_seed,
 memory_seed, fit_seed) = jax.random.split(jax.random.PRNGKey(SEED), 5)

# Shuffled train/test split of the task.
(full_train_questions, full_train_answers,
 full_test_questions, full_test_answers) = get_dataset(seed=data_seed)

# Few-shot chain-of-thought prompt; it initializes both the parameters and the
# memory of rationales.
cot_prompt = make_cot_prompt(
    questions=full_train_questions, answers=full_train_answers, seed=cot_seed)
init_params = init(cot_prompt)
init_memory = get_init_memory(
    fewshots_prompt=cot_prompt,
    train_questions=full_train_questions,
    train_answers=full_train_answers,
    seed=memory_seed)

GUIDE PROMPT: Question: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: cantaloupes, oranges, and watermelons. The oranges are the most expensive. The cantaloupes are more expensive than the watermelons.
Options:
(A) The cantaloupes are the cheapest
(B) The oranges are the cheapest
(C) The watermelons are the cheapest
Answer: The answer is (C). Let's think step by step.
The oranges are the most expensive. The cantaloupes are more expensive than the watermelons. So, the fruit that is the cheapest here is the watermelon (as compared to orange and cantaloupe, the watermelon is cheapest).
Thus, the answer is (C).

Question: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: kiwis, loquats, and cantaloupes. The kiwis are l

In [None]:
# Adam whose learning rate follows a cosine decay schedule built from
# LEARNING_RATE, TRAIN_STEPS and 0.1.
optax_optimizer = optax.adam(
    optax.cosine_decay_schedule(LEARNING_RATE, TRAIN_STEPS, 0.1))
# Train, keeping the final parameters and the final memory of rationales.
last_params, last_memory = fit(
    optimizer=optax_optimizer,
    params=init_params,
    memory=init_memory,
    train_questions=full_train_questions,
    train_answers=full_train_answers,
    test_questions=full_test_questions,
    test_answers=full_test_answers,
    seed=fit_seed)

Step 1: loss 1.2668 | proposal_acc 0.4688 | memory_acc 1.0000
Val accuracy: 0.5
Step 2: loss 6.6766 | proposal_acc 0.4219 | memory_acc 1.0000
Val accuracy: 0.58


Step 3: loss 2.7884 | proposal_acc 0.5625 | memory_acc 1.0000
Val accuracy: 0.64
Step 4: loss 0.6468 | proposal_acc 0.6250 | memory_acc 1.0000
Val accuracy: 0.67


Step 5: loss 4.4672 | proposal_acc 0.5781 | memory_acc 1.0000
Val accuracy: 0.66
Step 6: loss 4.5081 | proposal_acc 0.5625 | memory_acc 1.0000
Val accuracy: 0.65
Step 7: loss 2.0394 | proposal_acc 0.6094 | memory_acc 1.0000
Val accuracy: 0.66


Step 8: loss 1.4886 | proposal_acc 0.5625 | memory_acc 1.0000
Val accuracy: 0.69
Step 9: loss 1.9745 | proposal_acc 0.6094 | memory_acc 1.0000
Val accuracy: 0.71


Step 10: loss 0.5645 | proposal_acc 0.6406 | memory_acc 1.0000
Val accuracy: 0.7
Step 11: loss 3.4187 | proposal_acc 0.5312 | memory_acc 1.0000
Val accuracy: 0.69
Step 12: loss 2.1898 | proposal_acc 0.6250 | memory_acc 1.0000
Val accuracy: 0.71


Step 13: loss 3.2672 | proposal_acc 0.6719 | memory_acc 1.0000
Val accuracy: 0.73
Step 14: loss 3.6095 | proposal_acc 0.5938 | memory_acc 1.0000
Val accuracy: 0.69
Step 15: loss 1.6829 | proposal_acc 0.7031 | memory_acc 1.0000


Val accuracy: 0.69
Step 16: loss 2.8023 | proposal_acc 0.5938 | memory_acc 1.0000
Val accuracy: 0.67
Step 17: loss 2.5703 | proposal_acc 0.6875 | memory_acc 1.0000
Val accuracy: 0.65


Step 18: loss 1.9247 | proposal_acc 0.6406 | memory_acc 1.0000
Val accuracy: 0.64
Step 19: loss 3.4522 | proposal_acc 0.5781 | memory_acc 1.0000
Val accuracy: 0.64
Step 20: loss 5.3423 | proposal_acc 0.6562 | memory_acc 1.0000


Val accuracy: 0.67
Step 21: loss 0.7300 | proposal_acc 0.7500 | memory_acc 1.0000
Val accuracy: 0.69
Step 22: loss 0.6918 | proposal_acc 0.6875 | memory_acc 1.0000
Val accuracy: 0.71


Step 23: loss 1.0869 | proposal_acc 0.6875 | memory_acc 1.0000
Val accuracy: 0.72
Step 24: loss 3.3732 | proposal_acc 0.7344 | memory_acc 1.0000
Val accuracy: 0.71
Step 25: loss 1.8571 | proposal_acc 0.7500 | memory_acc 1.0000


Val accuracy: 0.69
Step 26: loss 1.8002 | proposal_acc 0.7500 | memory_acc 1.0000
Val accuracy: 0.69
Step 27: loss 0.3136 | proposal_acc 0.7656 | memory_acc 1.0000
Val accuracy: 0.7


Step 28: loss -0.4348 | proposal_acc 0.7344 | memory_acc 1.0000
Val accuracy: 0.73
Step 29: loss 0.6007 | proposal_acc 0.7656 | memory_acc 1.0000
Val accuracy: 0.75


Step 30: loss -0.2545 | proposal_acc 0.7500 | memory_acc 1.0000
Val accuracy: 0.75
Step 31: loss -0.3139 | proposal_acc 0.8750 | memory_acc 1.0000
Val accuracy: 0.76
Step 32: loss 0.3814 | proposal_acc 0.8438 | memory_acc 1.0000


Val accuracy: 0.76
Step 33: loss 1.5325 | proposal_acc 0.7188 | memory_acc 1.0000
Val accuracy: 0.75
Step 34: loss 0.3954 | proposal_acc 0.7969 | memory_acc 1.0000
Val accuracy: 0.76


Step 35: loss 0.2625 | proposal_acc 0.7344 | memory_acc 1.0000
Val accuracy: 0.78
Step 36: loss 0.1536 | proposal_acc 0.8438 | memory_acc 1.0000
Val accuracy: 0.77


Step 37: loss -0.0219 | proposal_acc 0.7969 | memory_acc 1.0000
Val accuracy: 0.78
Step 38: loss 0.3998 | proposal_acc 0.8125 | memory_acc 1.0000
Val accuracy: 0.78
Step 39: loss 1.6071 | proposal_acc 0.7969 | memory_acc 1.0000
Val accuracy: 0.76


Step 40: loss 0.6734 | proposal_acc 0.7656 | memory_acc 1.0000
Val accuracy: 0.74
Step 41: loss -0.4060 | proposal_acc 0.8750 | memory_acc 1.0000
Val accuracy: 0.76
Step 42: loss 0.3061 | proposal_acc 0.7969 | memory_acc 1.0000


Val accuracy: 0.76
Step 43: loss 0.2787 | proposal_acc 0.8125 | memory_acc 1.0000
Val accuracy: 0.76
Step 44: loss 0.9522 | proposal_acc 0.8281 | memory_acc 1.0000
Val accuracy: 0.78


Step 45: loss 0.2043 | proposal_acc 0.8594 | memory_acc 1.0000
Val accuracy: 0.8
Step 46: loss 0.3309 | proposal_acc 0.8750 | memory_acc 1.0000
Val accuracy: 0.8


Step 47: loss 0.4043 | proposal_acc 0.8906 | memory_acc 1.0000
Val accuracy: 0.8
Step 48: loss 0.2161 | proposal_acc 0.8594 | memory_acc 1.0000
Val accuracy: 0.8
Step 49: loss 0.6909 | proposal_acc 0.7344 | memory_acc 1.0000
Val accuracy: 0.79


Step 50: loss -0.0147 | proposal_acc 0.7656 | memory_acc 1.0000
Val accuracy: 0.79
Step 51: loss 0.0737 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.78
Step 52: loss 0.4943 | proposal_acc 0.8125 | memory_acc 1.0000


Val accuracy: 0.81
Step 53: loss 0.8976 | proposal_acc 0.8438 | memory_acc 1.0000
Val accuracy: 0.8
Step 54: loss 0.2485 | proposal_acc 0.8906 | memory_acc 1.0000
Val accuracy: 0.81


Step 55: loss 0.2710 | proposal_acc 0.8750 | memory_acc 1.0000
Val accuracy: 0.82
Step 56: loss 0.2302 | proposal_acc 0.8906 | memory_acc 1.0000
Val accuracy: 0.82
Step 57: loss 0.0981 | proposal_acc 0.8906 | memory_acc 1.0000


Val accuracy: 0.84
Step 58: loss 0.8216 | proposal_acc 0.8438 | memory_acc 1.0000
Val accuracy: 0.84
Step 59: loss 0.0988 | proposal_acc 0.8906 | memory_acc 1.0000
Val accuracy: 0.83


Step 60: loss 1.1516 | proposal_acc 0.8438 | memory_acc 1.0000
Val accuracy: 0.83
Step 61: loss -0.1295 | proposal_acc 0.8125 | memory_acc 1.0000
Val accuracy: 0.85
Step 62: loss 0.5059 | proposal_acc 0.8906 | memory_acc 1.0000


Val accuracy: 0.84
Step 63: loss -0.0203 | proposal_acc 0.9375 | memory_acc 1.0000
Val accuracy: 0.83
Step 64: loss 0.3703 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.86
Step 65: loss 0.4292 | proposal_acc 0.8906 | memory_acc 1.0000


Val accuracy: 0.85
Step 66: loss -0.3071 | proposal_acc 0.9062 | memory_acc 1.0000
Val accuracy: 0.87
Step 67: loss 0.0239 | proposal_acc 0.9375 | memory_acc 1.0000
Val accuracy: 0.87


Step 68: loss 0.0374 | proposal_acc 0.8750 | memory_acc 1.0000
Val accuracy: 0.87
Step 69: loss 0.2615 | proposal_acc 0.9375 | memory_acc 1.0000
Val accuracy: 0.87
Step 70: loss 0.0776 | proposal_acc 0.9375 | memory_acc 1.0000


Val accuracy: 0.87
Step 71: loss -0.1293 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.85
Step 72: loss -0.1358 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.87


Step 73: loss -0.0691 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.88
Step 74: loss 0.0730 | proposal_acc 0.9688 | memory_acc 1.0000
Val accuracy: 0.88
Step 75: loss 0.0903 | proposal_acc 0.9062 | memory_acc 1.0000


Val accuracy: 0.88
Step 76: loss -0.0191 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.87
Step 77: loss -0.1886 | proposal_acc 0.8906 | memory_acc 1.0000
Val accuracy: 0.87


Step 78: loss 0.1393 | proposal_acc 0.9062 | memory_acc 1.0000
Val accuracy: 0.88
Step 79: loss 0.1810 | proposal_acc 0.9375 | memory_acc 1.0000
Val accuracy: 0.87
Step 80: loss -0.0811 | proposal_acc 0.9531 | memory_acc 1.0000
Val accuracy: 0.88


Step 81: loss 0.1608 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.88
Step 82: loss -0.3453 | proposal_acc 0.9531 | memory_acc 1.0000
Val accuracy: 0.89
Step 83: loss 0.4232 | proposal_acc 0.8750 | memory_acc 1.0000


Val accuracy: 0.89
Step 84: loss 0.1070 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.88
Step 85: loss 0.1431 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.88


Step 86: loss 0.5932 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.88
Step 87: loss 0.0786 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.88
Step 88: loss -0.0000 | proposal_acc 1.0000 | memory_acc 1.0000
Val accuracy: 0.88


Step 89: loss -0.0293 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.87
Step 90: loss -0.3033 | proposal_acc 0.9531 | memory_acc 1.0000
Val accuracy: 0.88
Step 91: loss -0.1331 | proposal_acc 0.9062 | memory_acc 1.0000


Val accuracy: 0.88
Step 92: loss 0.0357 | proposal_acc 0.9844 | memory_acc 1.0000
Val accuracy: 0.86
Step 93: loss 0.3205 | proposal_acc 0.9219 | memory_acc 1.0000
Val accuracy: 0.87


Step 94: loss 0.2064 | proposal_acc 0.9531 | memory_acc 1.0000
Val accuracy: 0.86
Step 95: loss 0.0000 | proposal_acc 1.0000 | memory_acc 1.0000
Val accuracy: 0.88
Step 96: loss 0.1271 | proposal_acc 0.9375 | memory_acc 1.0000


Val accuracy: 0.87
Step 97: loss 0.2188 | proposal_acc 0.9688 | memory_acc 1.0000
Val accuracy: 0.86
Step 98: loss 0.6120 | proposal_acc 0.9062 | memory_acc 1.0000
Val accuracy: 0.87


Step 99: loss -0.0757 | proposal_acc 0.9062 | memory_acc 1.0000
Val accuracy: 0.87
Step 100: loss 0.1217 | proposal_acc 0.9531 | memory_acc 1.0000
Val accuracy: 0.86


Let's inspect the last memory and the final params.

In [None]:
# Show the initial and final memory entries side by side for the first five
# training questions (`range(5)` caps the zip at five rows).
for k, q_, a_, init_m, last_m in zip(
    range(5), full_train_questions, full_train_answers, init_memory, last_memory):
  print(
      f"Question {k}: {q_}\n"
      f"Answer: {a_}\n"
      f"Init memory: {init_m}\n"
      f"Last memory: {last_m}\n"
      "=====")

Question 0: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. In a golf tournament, there were three golfers: Amy, Eli, and Eve. Eve finished above Amy. Eli finished below Amy.
Options:
(A) Amy finished second
(B) Eli finished second
(C) Eve finished second
Answer: (A)
Init memory: Amy finishing above Eve implies Amy finished second.
So, the answer is (A).
Last memory: Eve finished above Amy. Eli finished below Amy. Amy finished second in the golf tournament.
So, the answer is (A).
=====
Question 1: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. A fruit stand sells three fruits: mangoes, watermelons, and kiwis. The watermelons are less expensive than the kiwis. The kiwis are the second-most expensive.
Options:
(A) The mangoes are the most expensive
(B) The watermelons are the most e

In [None]:
# Compare rationales sampled at temperature 0.0 (same PRNG key) from the
# initial and the trained parameters on the first five test questions.
example_contexts = [
    RATIONALE_TEMPLATE.format(question=question)
    for question in full_test_questions[:5]]
init_rationales = sample(
    params=init_params, context=example_contexts, temperature=0.0,
    seed=jax.random.PRNGKey(0))
last_rationales = sample(
    params=last_params, context=example_contexts, temperature=0.0,
    seed=jax.random.PRNGKey(0))
for k, q_, a_, init_r, last_r in zip(
    range(5), full_test_questions, full_test_answers, init_rationales, last_rationales):
  print(
      f"Question {k}: {q_}\n"
      f"Answer: {a_}\n"
      f"Rationale (init params): {init_r}\n"
      f"Rationale (last params): {last_r}\n"
      "=====")

Question 0: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a shelf, there are three books: a green book, a brown book, and an orange book. The brown book is to the left of the orange book. The green book is to the left of the brown book.
Options:
(A) The green book is the leftmost
(B) The brown book is the leftmost
(C) The orange book is the leftmost
Answer: (A)
Rationale (init params): The green book is to the left of the brown book. The brown book is to the left of the orange book. The green book is the leftmost.
So, the answer is (A).
Rationale (last params): The green book is to the left of the brown book (which is to the left of the orange book). The green book is the leftmost.
So, the answer is (A).
=====
Question 1: The following paragraphs each describe a set of three objects arranged in a fixed order. The statements are logically consistent within each paragraph. On a br