In [1]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m66.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.3 tokenizers-0.13.2 transformers-4.27.4


In [122]:
"""Reward function in conversation will be made up of three components:
- r_c = congruence reward: how likely is the agent to have said what the respondent said (negative KL divergence of the next token probabilities)
- r_s = sentiment reward: how positive was the sentiment of the respondent (use a pre-existing sentiment model)
- r_a = affection reward: how much does the agent like the respondent (use discounted sum of previous rewards)

Here, we build the congruence reward.
"""

import os
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from google.colab import drive
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def get_congruence_reward(comment_ids, response_ids_trunc, model=None):
  """Score how likely the agent would have been to say what the respondent said.

  The comment and the response are concatenated and passed through the model once.
  For every response token, the log-probability the model assigns to that token given
  everything before it is read off, and the mean over the whole response is returned.
  This log-probability is the negative cross-entropy against the observed token (equal
  to the negative KL divergence from a one-hot target), so values are <= 0 and values
  closer to 0 mean a more congruent response.

  A single forward pass is only equivalent to feeding the tokens in one at a time
  when the model is a causal language model, i.e. the logits at a position depend only
  on the tokens up to that position.

  Args:
    comment_ids (torch tensor): IDs of agent's comment, including original query.
      Expected shape (1, n_comment_tokens) with at least one token.
    response_ids_trunc (torch tensor): IDs of respondent's response, not including
      original comment or query. Expected shape (1, n_response_tokens).
    model (callable, optional): Language model used for scoring; called with a tensor of
      token IDs and must return an object with a `logits` attribute. Defaults to the
      notebook-level `agent`.

  Returns:
    float: Congruence reward value (mean per-token log-probability), or nan when the
      response contains no tokens.
  """

  # Fall back to the notebook-level agent so existing two-argument calls keep working.
  scorer = agent if model is None else model

  response_tokens = response_ids_trunc[0]
  if response_tokens.numel() == 0:
    # Nothing to score; return nan explicitly rather than via a warning from np.mean([]).
    return float('nan')

  context_len = comment_ids.shape[1]
  full_ids = torch.cat((comment_ids, response_ids_trunc), dim=1)

  # The reward is a plain number, so no autograd graph is needed.
  with torch.no_grad():
    # The last token is dropped from the input: its logits would predict a token that
    # comes after the response. Position (context_len - 1) predicts the first response
    # token, the next position predicts the second, and so on.
    logits = scorer(full_ids[:, :-1]).logits[0, context_len - 1:]

  log_probs = F.log_softmax(logits, dim=-1)
  # Pick out, for each position, the log-probability of the token the respondent used.
  token_log_probs = log_probs.gather(1, response_tokens.unsqueeze(1)).squeeze(1)

  congruence_reward = np.mean(token_log_probs.tolist())

  return congruence_reward


# Mount Google Drive and build the project directory from the mount point, so the path
# stays valid even when the working directory is not /content.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)
project_path = os.path.join(drive_mount_point, 'MyDrive/Colab Notebooks/GPT_community/')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
# Load in prompts: one prompt per line. Blank lines are skipped so that an empty
# prompt can never be drawn later on.
prompts_file = os.path.join(project_path, 'data/brighton_philosophy_prompts.txt')
with open(prompts_file, encoding='utf-8') as file:
    prompts = [line.rstrip() for line in file if line.strip()]

# Create models
# Use the GPU when one is available; otherwise fall back to the CPU instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Reuse the end-of-text token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
pad_token_id = tokenizer.eos_token_id
agent = GPT2LMHeadModel.from_pretrained('Linus4Lyf/Kant_Metaphysics_Of_Morals').to(device)
respondent = GPT2LMHeadModel.from_pretrained('Linus4Lyf/Hume_A_Treatise_Of_Human_Nature').to(device)

In [131]:
questioner_name = 'Socrates'
agent_name = 'Kant'
respondent_name = 'Hume'

# Sampling settings shared by both speakers.
generation_kwargs = dict(
    do_sample=True,
    temperature=0.9,
    max_new_tokens=200,
    pad_token_id=pad_token_id,
    eos_token_id=pad_token_id,
)
separator = '--------------------------------------------------------------------------------------'

# Get query from questions list
query_text = f"{questioner_name}: " + np.random.choice(prompts)
print(query_text, '\n')
query_text += f"\n{agent_name}: "

# Encode query and get comment from agent (inputs are placed on the model's own device)
query_ids = tokenizer.encode(query_text, return_tensors='pt').to(agent.device)
agent_output_ids = agent.generate(query_ids, **generation_kwargs)
comment_text = tokenizer.batch_decode(agent_output_ids)[0]
print(separator)
print(comment_text, '\n')
comment_text += f"\n{respondent_name}: "

# Get response from respondent; the prompt is the agent's comment plus the respondent's name tag
comment_ids = tokenizer.encode(comment_text, return_tensors='pt').to(respondent.device)
response_ids = respondent.generate(comment_ids, **generation_kwargs)
response_text = tokenizer.batch_decode(response_ids)[0]
print(separator)
print(response_text, '\n')

# Remove original query and comment from the response. generate() returns the prompt IDs
# followed by the newly generated IDs, so slicing by the prompt length isolates exactly
# the tokens the respondent produced. Stripping the prompt from the decoded text and
# re-encoding is avoided: it fails if decoding does not reproduce the prompt text exactly
# and can yield different tokens from the ones that were generated.
response_ids_trunc = response_ids[:, comment_ids.shape[1]:]
response_text_trunc = tokenizer.decode(response_ids_trunc[0])
print(separator)
print(response_text_trunc, '\n')

# Get congruence reward value for response
reward = get_congruence_reward(comment_ids, response_ids_trunc)
print(f"Congruence reward = {reward}")

Socrates: What is your personal philosophy? 

--------------------------------------------------------------------------------------
Socrates: What is your personal philosophy?
Kant: 

Philosophy has been taught the world for many years, and philosophers have had a considerable sway over the world in this regard. Their main work is that of generalising philosophical ideas and the practical application of them. This is one of the most important duties they are to carry out. This is a duty that is naturally not to be performed on a regular basis throughout society. They do not want to find a sufficient ground for their philosophical ideas, they wish to obtain something which can be carried out upon a special basis. In the end most philosophers are of limited practical ability, not of great utility, or virtue, and their use is to employ them in such a way that it becomes their duty to make use of them as well as to develop them in any possible way. This duty is not so much for practical p