<a href="https://colab.research.google.com/github/erl-j/surprise/blob/main/Nicolas_SUPRISE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
# Install Hugging Face `transformers` (provides GPTNeoForCausalLM / GPT2Tokenizer imported below).
!pip -q install transformers

[K     |████████████████████████████████| 3.1 MB 4.3 MB/s 
[K     |████████████████████████████████| 895 kB 37.6 MB/s 
[K     |████████████████████████████████| 59 kB 6.5 MB/s 
[K     |████████████████████████████████| 3.3 MB 37.0 MB/s 
[K     |████████████████████████████████| 596 kB 42.5 MB/s 
[?25h

In [2]:
import os
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch
from IPython.display import display, HTML
import json
from tqdm.notebook import tqdm
import random
import pickle

SEED=0
# Seed every RNG the notebook draws from, in the same order as before:
# Python's `random` (used for the dataset shuffle), PyTorch, then NumPy.
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
  _seed_fn(SEED)

In [3]:
# GPT-Neo checkpoint size; also used in the names of the processed output files.
NEO_VERSION="1.3B"

# load the causal LM and its tokenizer from the same hub checkpoint
MODEL_ID = f"EleutherAI/gpt-neo-{NEO_VERSION}"
model = GPTNeoForCausalLM.from_pretrained(MODEL_ID)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)

Downloading:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/200 [00:00<?, ?B/s]

In [4]:
# Show the tokenizer's BOS token string; the scoring helpers below prepend it literally as "<|endoftext|> ".
print(tokenizer.bos_token)

<|endoftext|>


# Prepare dataset

In [5]:
# Prompt tags (as written inside "[ .. ]" at the start of a prompt, e.g. "[ WP ]")
# mapped to a human-readable prompt type.
tag2type = dict(
    WP="Writing Prompt",
    SP="Simple Prompt",
    EU="Established Universe",
    CW="Constrained Writing",
    TT="Theme Thursday",
    PM="Prompt Me",
    MP="Media Prompt",
    IP="Image Prompt",
    PI="Prompt Inspired",
    OT="Off Topic",
    RF="Reality Fiction",
)


In [None]:
# download dataset if not in local environment
# (the guard checks for the extracted train prompts file, so the archive is only
# fetched and unpacked when writingPrompts/train.wp_source is missing)
if not os.path.isfile("writingPrompts/train.wp_source"):
  !wget https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz
  !tar -xf writingPrompts.tar.gz

--2021-11-07 12:03:07--  https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.74.142, 104.22.75.142, 172.67.9.4, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.74.142|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 381314604 (364M) [application/gzip]
Saving to: ‘writingPrompts.tar.gz’

writingPrompts.tar.   8%[>                   ]  31.27M  9.59MB/s    eta 50s    

In [None]:
# NOTE(review): duplicate of the earlier BOS-token check cell.
print(tokenizer.bos_token)

In [None]:
class WritingPromptDataset(torch.utils.data.Dataset):
  """Writing prompts dataset built from a one-prompt-per-line text file.

  Only lines starting with the "[ WP ]" tag are kept (tag removed). Each item is
  a dict {"text": str, "tokens": 1-D tensor}; "tokens" holds the BOS id, the
  prompt's token ids, then EOS ids padding every item to the same length
  (1 + the longest tokenized prompt).
  """

  def __init__(self, filepath,tokenizer):
    """
    Args:
        filepath (string): Path to file with one prompt per line
            (e.g. "writingPrompts/train.wp_source").
        tokenizer: callable as tokenizer(text, return_tensors="pt") returning an
            object with `.input_ids`, and exposing bos_token_id / eos_token_id.
    """
    # read dataset from the given path (the argument used to be ignored in favour
    # of a hard-coded train split); `with` closes the file
    with open(filepath, "r", encoding="utf-8") as f:
      texts = [line.strip() for line in f.read().splitlines()]

    # shuffle (global, seeded `random` state) and print 10 sequences
    random.shuffle(texts)
    print(texts[:10])

    # keep only texts that start with the tag; drop the tag and surrounding spaces
    tag="[ WP ]"
    texts = [text[len(tag):].strip() for text in texts if text.startswith(tag)]
    print(texts[:10])

    print(len(texts))

    # add tokenization; each input_ids tensor has shape (1, n_tokens)
    samples = [{"text":text,"tokens":tokenizer(text,return_tensors="pt").input_ids} for text in tqdm(texts)]

    # longest tokenized prompt (0 when no line carried the tag)
    max_seq_len = max((sample["tokens"].shape[-1] for sample in samples), default=0)

    print(f"max_seq_len {max_seq_len}")

    # add beginning of sentence token + pad to max len with eos token
    self.samples = [
        {**sample, "tokens": self._bos_and_pad(sample["tokens"], max_seq_len, tokenizer)}
        for sample in samples
    ]

  @staticmethod
  def _bos_and_pad(input_ids, max_seq_len, tokenizer):
    """Return [bos] + input_ids + [eos] * (max_seq_len - n_tokens) as a 1-D tensor."""
    # build the extra ids with input_ids' dtype so torch.cat needs no dtype promotion
    bos = torch.full((1, 1), tokenizer.bos_token_id, dtype=input_ids.dtype)
    padding = torch.full((1, max_seq_len - input_ids.shape[-1]), tokenizer.eos_token_id, dtype=input_ids.dtype)
    return torch.cat([bos, input_ids, padding], dim=-1).reshape(-1)

  def __len__(self):
      return len(self.samples)

  def __getitem__(self, idx):
      if torch.is_tensor(idx):
          idx = idx.tolist()
      sample = self.samples[idx]
      return sample

# Build the dataset from the training prompts (its shuffle uses the seeded global `random` state).
trn_ds=WritingPromptDataset("writingPrompts/train.wp_source",tokenizer)

In [None]:
# Output directory for the pickled result shards; unlike `!mkdir`, this is safe to re-run.
os.makedirs("processed", exist_ok=True)

trn_dl = torch.utils.data.DataLoader(trn_ds,batch_size=1,shuffle=False)

# Run on the first GPU when available, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

model.to(device)

print(dev)

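# Index alignment for one sequence of T tokens; the first token is the BOS id,
# which for this tokenizer is <|endoftext|> (the same id as EOS):
# tokens:      [bos, w0, w1, w2]   positions 0..T-1
# logits:      one prediction per token position 0..T-1; the last one is dropped
#              below because no token follows it
# event probs: positions 0..T-2; position t is P(tokens[t+1] | tokens[:t+1])
# entropy:     positions 0..T-2; entropy of that same predictive distribution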


def colorize(words, color_array):
    """Display `words` as inline HTML spans, each shaded with the PuRd colormap.

    Args:
        words: sequence of strings (e.g. decoded tokens), shown in order.
        color_array: numbers, one per word, fed to the colormap, which expects
            values in [0, 1].
    """
    import html  # escape token text: "<|endoftext|>" would otherwise be parsed as a tag

    # pyplot.get_cmap works across matplotlib versions (matplotlib.cm.get_cmap was
    # deprecated and later removed)
    cmap = plt.get_cmap('PuRd')
    template = '<span class="barcode" style="color: black; background-color: {}; font-size:20px;">{}</span>'
    spans = []
    for word, value in zip(words, color_array):
        background = matplotlib.colors.rgb2hex(cmap(value)[:3])
        # pad each token with a non-breaking space on both sides
        spans.append(template.format(background, '&nbsp;' + html.escape(word) + '&nbsp;'))
    # to display in ipython notebook
    display(HTML(''.join(spans)))

SAVE_INTERVAL=10000
e=0  # number of samples processed so far
LN2 = float(np.log(2.0))  # converts nats to bits


def _save_shard(data, n_processed):
  """Pickle `data` to processed/<NEO_VERSION>_fair_wp_<n_processed>.pkl and verify the round trip.

  The file name carries the running sample count so successive shards no longer
  overwrite one another (previously every shard went to ..._{SAVE_INTERVAL}.pkl).
  """
  path = f'processed/{NEO_VERSION}_fair_wp_{n_processed}.pkl'
  with open(path,'wb') as f:
    pickle.dump(data,f)

  # test that it worked; assert_array_equal treats NaN == NaN, unlike a plain `==`
  with open(path,'rb') as f:
    reloaded = pickle.load(f)
  np.testing.assert_array_equal(reloaded[0]["surprise"], data[0]["surprise"])


processed_data=[]
# A single pass over the loader (the former `while True:` re-processed the same
# unshuffled data forever). Inference only, so no autograd graph is built.
with torch.no_grad():
  for batch in tqdm(trn_dl):
    tokens = batch["tokens"].to(device)
    # logits[:, t] predicts tokens[:, t+1]; the last position has no target
    logits = model(tokens).logits[:, :-1]
    # log_softmax avoids 0 * log(0) = nan when a probability underflows to 0
    log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1) / LN2  # bits

    # log-probability of the token that actually came next (replaces a Python double loop)
    targets = tokens[:, 1:].long()
    event_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    event_probs = event_log_probs.exp()
    self_information = -event_log_probs / LN2  # bits

    surprise = self_information / entropy

    # one record per sequence in the batch (previously built outside this loop)
    for seq_index in range(tokens.shape[0]):
      token_sequence = [tokenizer.decode(token_id) for token_id in tokens[seq_index].tolist()]
      processed_data.append({
          "text": batch["text"][seq_index],
          "tokens": token_sequence[1:],
          "surprise": surprise[seq_index].cpu().numpy(),
          "event_self_information": self_information[seq_index].cpu().numpy(),
          "event_probabilities": event_probs[seq_index].cpu().numpy(),
          "token_ids": tokens[seq_index, 1:].cpu().numpy(),
          "entropy": entropy[seq_index].cpu().numpy(),
      })

      e += 1
      if e % SAVE_INTERVAL == 0:
        _save_shard(processed_data, e)
        processed_data = []

# flush the final, partial shard
if processed_data:
  _save_shard(processed_data, e)
  processed_data = []
    

          

# Define surprise and related functions

In [None]:
 
 def plot_words(tokens,values):
     """Bar plot of `values`, one bar per token, labelled with the token text.

     Repeated tokens get an occurrence suffix ("_0", "_1", ...) so seaborn keeps
     them as separate bars; the tick labels are then reset to the plain tokens.
     The seaborn theme (with a 50x3 figure size) is applied globally on each call.
     """
     occurrences={}
     categories=[]
     for token in tokens:
         n_seen=occurrences.get(token,0)
         categories.append(token+"_"+str(n_seen))
         occurrences[token]=n_seen+1
     sns.set(rc={'figure.figsize':(50,3)})
     ax=sns.barplot(x=categories,y=values)
     ax.set_xticklabels(tokens)
     plt.show()

def colorize(words, color_array):
    """Display `words` as inline HTML spans, each shaded with the PuRd colormap.

    Args:
        words: sequence of strings (e.g. decoded tokens), shown in order.
        color_array: numbers, one per word, fed to the colormap, which expects
            values in [0, 1].
    """
    import html  # escape token text: "<|endoftext|>" would otherwise be parsed as a tag

    # pyplot.get_cmap works across matplotlib versions (matplotlib.cm.get_cmap was
    # deprecated and later removed)
    cmap = plt.get_cmap('PuRd')
    template = '<span class="barcode" style="color: black; background-color: {}; font-size:20px;">{}</span>'
    spans = []
    for word, value in zip(words, color_array):
        background = matplotlib.colors.rgb2hex(cmap(value)[:3])
        # pad each token with a non-breaking space on both sides
        spans.append(template.format(background, '&nbsp;' + html.escape(word) + '&nbsp;'))
    # to display in ipython notebook
    display(HTML(''.join(spans)))

def get_tokens_probabilities_surprises(prompt):
  """Score each token of `prompt` under the language model.

  "<|endoftext|> " is prepended so the first word of `prompt` is also predicted.
  For every predicted token, the function reports the decoded token, the
  probability the model gave it, and its surprise: the self-information -log p
  divided by the entropy of the predictive distribution. Both use natural logs,
  so the ratio is unit-free.

  Returns:
    tokens, probabilities, surprises: three Python lists of the same length.
  """
  prompt="<|endoftext|> "+prompt
  # keep the inputs on the model's device (the model may have been moved to a GPU)
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
  with torch.no_grad():  # inference only: don't build an autograd graph
    logits = model(input_ids).logits
  # position t predicts token t+1, so the final position has no target.
  # log_softmax avoids 0 * log(0) = nan when a probability underflows to 0.
  log_probs = torch.nn.functional.log_softmax(logits[0, :-1].float(), dim=-1)
  entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
  targets = input_ids[0, 1:]
  target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

  tokens = [tokenizer.decode(token_id) for token_id in targets.tolist()]
  probabilities = target_log_probs.exp().cpu().tolist()
  surprises = (-target_log_probs / entropy).cpu().tolist()
  return tokens, probabilities, surprises

def get_tokens_probabilities_surprises_entropy(prompt):
  """Score each token of `prompt` and also report the predictive entropy.

  "<|endoftext|> " is prepended so the first word of `prompt` is also predicted.
  For every predicted token t, the function reports the decoded token, its model
  probability, its surprise (-log p divided by the entropy), and the entropy
  (nats) of the distribution that predicted it.

  Returns:
    tokens, probabilities, surprises, entropies: four Python lists of equal length.
  """
  prompt="<|endoftext|> "+prompt
  # keep the inputs on the model's device (the model may have been moved to a GPU)
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
  with torch.no_grad():  # inference only: don't build an autograd graph
    logits = model(input_ids).logits
  # position t predicts token t+1, so the final position has no target.
  # log_softmax avoids 0 * log(0) = nan when a probability underflows to 0.
  log_probs = torch.nn.functional.log_softmax(logits[0, :-1].float(), dim=-1)
  entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
  targets = input_ids[0, 1:]
  target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

  tokens = [tokenizer.decode(token_id) for token_id in targets.tolist()]
  probabilities = target_log_probs.exp().cpu().tolist()
  surprises = (-target_log_probs / entropy).cpu().tolist()
  # .cpu() first: .numpy()/.tolist() on the raw tensor fails when it lives on a GPU
  return tokens, probabilities, surprises, entropy.cpu().tolist()

def plot_surprise(prompt):
  """Bar plot of the surprise of each token of `prompt`."""
  tokens, _probabilities, surprises = get_tokens_probabilities_surprises(prompt)
  plot_words(tokens, surprises)

def plot_probability(prompt):
  """Bar plot of the model probability of each token of `prompt`."""
  tokens, probabilities, _surprises = get_tokens_probabilities_surprises(prompt)
  plot_words(tokens, np.array(probabilities))

def plot_surprise_map(prompt):
  """Show `prompt` as coloured HTML, shading each token by its surprise / 6.

  The divisor 6 presumably maps typical surprises into the colormap's [0, 1] input.
  """
  tokens, _probabilities, surprises = get_tokens_probabilities_surprises(prompt)
  colorize(tokens, np.array(surprises) / 6)

def plot_probability_map(prompt):
  """Show `prompt` as coloured HTML, shading each token by its model probability."""
  tokens, probabilities, _surprises = get_tokens_probabilities_surprises(prompt)
  colorize(tokens, np.array(probabilities))

def plot_norm_surprise(prompt):
  """Bar plot of each token's surprise in `prompt`, divided by the largest one."""
  tokens, _probabilities, surprises = get_tokens_probabilities_surprises(prompt)
  # surprises is a Python list: convert before dividing (list / float raised TypeError)
  surprises = np.asarray(surprises)
  plot_words(tokens, surprises / np.max(surprises))

def plot_entropy(prompt):
  """Bar plot of the predictive entropy at each token position of `prompt`."""
  tokens, _, _, entropy = get_tokens_probabilities_surprises_entropy(prompt)
  plot_words(tokens, entropy)

def plot_information(prompt):
  """Bar plot of each token's self-information -log p, divided by the largest one."""
  tokens, probabilities, _ = get_tokens_probabilities_surprises(prompt)
  information = -np.log(probabilities)
  plot_words(tokens, information / np.max(information))

  
# surprise sampling

def surprise_sample(prompt,target_surprise):
  """Pick the next token whose surprise is closest to `target_surprise`.

  The surprise of a candidate token is -log p(token) / H, where H is the entropy
  (nats) of the model's next-token distribution after "<|endoftext|> " + prompt.
  The choice is deterministic (argmin; no random sampling).

  Returns:
    The decoded token string.
  """
  prompt="<|endoftext|> "+prompt
  # keep the inputs on the model's device (the model may have been moved to a GPU)
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
  with torch.no_grad():
    next_logits = model(input_ids).logits[0,-1,:]
  # log_softmax: a probability that underflows to 0 would otherwise make H nan,
  # turning every surprise into nan and the argmin into token 0
  log_probs = torch.nn.functional.log_softmax(next_logits.float(), dim=-1)
  entropy = -torch.sum(log_probs.exp()*log_probs)
  surprises = (-log_probs/entropy).cpu().numpy()

  word=tokenizer.decode(int(np.argmin(np.abs(target_surprise-surprises))))
  return word

def generate_from_surprise_contour(prompt, surprise_contour):
  """Extend `prompt` by one token per entry of `surprise_contour`.

  For each target value, `surprise_sample` chooses the next token whose surprise
  is closest to it, and that decoded token is appended to the running text.

  Returns:
    The prompt followed by all appended tokens.
  """
  text = prompt
  for target in surprise_contour:
    text = text + surprise_sample(text, target)
  return text


def probability_sample(prompt,target_probability):
  """Pick the next token whose probability is closest to `target_probability`.

  Closeness is measured in negative-log space: |-log p(token) - (-log target)|,
  over the model's next-token distribution after "<|endoftext|> " + prompt.

  Returns:
    The decoded token string.
  """
  prompt="<|endoftext|> "+prompt
  # keep the inputs on the model's device (the model may have been moved to a GPU)
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
  with torch.no_grad():
    next_logits = model(input_ids).logits[0,-1,:]
  # negative log-probability of every candidate next token
  neg_log_probs = -torch.nn.functional.log_softmax(next_logits.float(), dim=-1).cpu().numpy()
  target_neg_log_prob = -np.log(target_probability)
  # compare like with like: the raw probability used to be compared against -log p
  word_idx = int(np.argmin(np.abs(target_neg_log_prob - neg_log_probs)))
  word=tokenizer.decode(word_idx)
  return word

def generate_from_probability_contour(prompt, surprise_contour):
  """Extend `prompt` by one token per entry of a target-probability contour.

  Despite its name, `surprise_contour` holds target probabilities: each one is
  passed to `probability_sample`, and the chosen decoded token is appended.

  Returns:
    The prompt followed by all appended tokens.
  """
  text = prompt
  for target_probability in surprise_contour:
    text = text + probability_sample(text, target_probability)
  return text

# Example: colour each token of a sample prompt by its surprise.
plot_surprise_map("Write an alt history where another sentient race evolved alongside humans , and now we live in harmony/conflict with them")

In [None]:
# SURPRISE TRANSFER # generate new text with human surprise contour:

source_prompts=["The thing about that is that it's not as easy as it seems","An apple doesn't spoil if you cherish it.","You have to walk before you can fly.","One two three four five four three two one."]
target_prompts=["Sam looked up at the clock.", "Sam looked up at the clouds."]

for source_prompt in source_prompts:
  # after this loop, human_probability/human_surprise belong to the LAST source prompt
  t, human_probability, human_surprise = get_tokens_probabilities_surprises(source_prompt)

  print("Source sequence:")
  plot_surprise_map(source_prompt)
  print("\n")

  for target_prompt in target_prompts:
    print("Target prompt:")
    print(target_prompt)
    print("\n")

    # generate one token per source token, then drop the prompt itself
    continuation = generate_from_surprise_contour(target_prompt, human_surprise)
    result = continuation[len(target_prompt):]

    print("Surprise map of result (shown without prompt)")
    plot_surprise_map(result)
    print("\n")

print("PROBABILITY CONTOUR SAMPLING")
for start_text in ("Sam looked up at the clock.", "Sam looked up at the sky."):
  plt.plot(human_probability)
  plt.show()
  result = generate_from_probability_contour(start_text, human_probability)
  plot_probability_map(result)


In [None]:
# ENTROPY PLOT

# NOTE(review): the first prompt appears twice in this list.
prompts=["You have to walk before we can run.","One two three four five six seven eight nine ten eleven twelve","You have to walk before we can run.","A B C D E F G H K L N O P"]

print("ENTROPY PLOTS")

# (label printed before the plot, plotting function), shown in this order for every prompt
panels = (("information:", plot_information), ("entropy:", plot_entropy), ("surprise:", plot_surprise))

for prompt in prompts:
  print(f"PROMPT : {prompt}")
  for label, plot_fn in panels:
    print(label)
    plot_fn(prompt)
  





# Test sampling strategy. 

## 1 - Data

- Compute human statistics over human text.
- Load corpus
- Compute logits and words over corpus
- save as dataset

## 2 - Compute surprise across the dataset and try to model the surprise contour.

- First, stateless.
- Then with state (probably more accurate).

## 3 - Sample text from a generated surprise contour and compare with other sampling strategies.


In [None]:

# Two hand-made surprise contours: low surprise with one or more spikes.
for surprise_contour in ([0.5,0.5,0.5,0.5,5.5,0.5,0.5,0.5], [0.5,2.5,0.5,0.5,5.5,0.5,2.5,0.5]):
  plt.plot(surprise_contour)
  plt.show()
  for prompt in ["Who","What","Why","When"]:
    # extend the one-word prompt by one token per contour entry
    prompt = generate_from_surprise_contour(prompt, surprise_contour)
    plot_surprise_map(prompt)

In [None]:
# plot text examples from https://arxiv.org/pdf/1904.09751.pdf and compare with surprise

beamsearch = """We wish to provide an overview of the current state-of-the-art in the field of computer vision and machine learning, and to provide an overview of the current state-of-the-art in the field of computer vision and machine learning, and to provide an overview of the current state-of-the-art in the field of computer vision and machine learning, and to provide an overview of the current state-of-the-art in the field of computer vision and machine learning, and"""
human = """This grant increased life span and three years warranty. The Antec HCG series consists of five models with capacities spanning from 400W to 900W. Here we should note that we have already tested the HCG-620 in a previous review and were quite satisfied with its performance. In today's review we will rigorously test the Antec HCG-520, which as its model number implies, has 520W capacity and contrary to Antec's strong beliefs in multi-rail PSUs is equipped"""
t,beamsearch_probability,beamsearch_surprise=get_tokens_probabilities_surprises(beamsearch)
t,human_probability,human_surprise=get_tokens_probabilities_surprises(human)

# One figure per statistic, both texts overlaid. Each figure gets a legend and axis
# labels; the surprise figure previously had no legend and drew through pyplot state.
for ylabel, beam_values, human_values in (
    ("token probability", beamsearch_probability, human_probability),
    ("token surprise", beamsearch_surprise, human_surprise),
):
  fig, ax = plt.subplots(figsize=(20, 6))
  ax.plot(beam_values, label="beam search", linewidth=4)
  ax.plot(human_values, label="human", linewidth=4)
  ax.legend()
  ax.set_xlabel("token position")
  ax.set_ylabel(ylabel)
plt.show()



In [None]:

# Regenerate text from the human surprise contour computed above.
# generate_from_surprise_contour is deterministic (surprise_sample takes an argmin,
# no random sampling), so, assuming the model runs without dropout, repeating it
# with the same empty prompt and contour reproduces the same text. Run it once
# instead of four identical times.
plt.plot(human_surprise)
plt.show()
result=generate_from_surprise_contour("",human_surprise)
plot_surprise_map(result)

In [None]:
# Compare max-normalised surprise with max-normalised negative log-probability
# for two nearly identical sentences.
for prompt in ["My name is Carl.", "My name was Carl."]:
  print("\n")
  print(prompt)

  print("surprise")
  plot_norm_surprise(prompt)

  # `plot_norm_logprob` was never defined; plot_information plots -log p divided by its maximum
  print("neg log prob")
  plot_information(prompt)

In [None]:

prompt = ""

print("\n")
print(prompt)

print("surprise")
plot_norm_surprise(prompt)

# `plot_norm_logprob` was never defined; plot_information plots -log p divided by its maximum
print("neg log prob")
plot_information(prompt)

In [None]:
prompt = "We need to walk before we can fly."

print("\n")
print(prompt)

print("surprise")
plot_norm_surprise(prompt)

# `plot_norm_logprob` was never defined; plot_information plots -log p divided by its maximum
print("neg log prob")
plot_information(prompt)

In [None]:
# normalized surprise sampling

def norm_surprise_sample(prompt,target_surprise):
  """Pick the next token whose max-normalised surprise is closest to `target_surprise`.

  Each candidate's surprise is -log p(token) / H over the next-token distribution
  after "<|endoftext|> " + prompt. The surprises are divided by their maximum over
  the vocabulary, so they lie in [0, 1] and `target_surprise` is on that scale.

  Returns:
    The decoded token string.
  """
  prompt="<|endoftext|> "+prompt
  # keep the inputs on the model's device (the model may have been moved to a GPU)
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
  with torch.no_grad():
    next_logits = model(input_ids).logits[0,-1,:]
  # log_softmax avoids a nan entropy when a probability underflows to 0
  log_probs = torch.nn.functional.log_softmax(next_logits.float(), dim=-1)
  entropy = -torch.sum(log_probs.exp()*log_probs)
  surprises = (-log_probs/entropy).cpu().numpy()

  surprises = surprises/np.max(surprises)

  word=tokenizer.decode(int(np.argmin(np.abs(target_surprise-surprises))))
  return word


# Two normalised surprise contours (values on the [0, 1] scale of norm_surprise_sample).
for surprise_contour in ([0.1,0.1,0.1,0.1,0.5,0.1,0.2,0.1], [0.1,0.2,0.1,0.1,0.5,0.1,0.2,0.1]):
  plt.plot(surprise_contour)
  plt.show()
  for prompt in ["Who","What","Why","When"]:
    # extend the one-word prompt by one token per contour entry
    for target_surprise in surprise_contour:
      prompt = prompt + norm_surprise_sample(prompt, target_surprise)
    plot_surprise_map(prompt)

# Explore

In [None]:
# Surprise maps for plausible and implausible completions of "Nice to ... you."
for sentence in ("Nice to cancer you.", "Nice to break you.", "Nice to meet you.", "Nice to find you.", "Meet to nice you."):
  plot_surprise_map(sentence)

In [None]:
nice_to_sentences = ["Nice to cancer you.", "Nice to break you.", "Nice to meet you.", "Nice to find you."]

for sentence in nice_to_sentences:
  plot_surprise(sentence)

# `plot_probabilities` was never defined; the per-token probability bar plot is `plot_probability`
for sentence in nice_to_sentences:
  plot_probability(sentence)


In [None]:
# Surprise of plausible names versus common nouns after "My name is".
for sentence in ("My name is Lars.", "My name is Dan.", "My name is Greg.", "My name is bus.", "My name is lake.", "My name is wall."):
  plot_surprise(sentence)

In [None]:
# Inspect the ten most likely next tokens after each prefix of a short prompt.
prompt="Nice to meet"

# no "<|endoftext|> " prefix here, unlike the scoring helpers; inputs go to the model's device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

with torch.no_grad():
  logits = model.forward(input_ids).logits
probs= torch.nn.functional.softmax(logits,dim=-1)

# top-10 token ids per position, most likely first (avoids a full argsort of the vocabulary)
top_ids = torch.topk(probs[0], k=10, dim=-1).indices

# decode each candidate separately so the ten tokens stay distinguishable
# (batch_decode joined each position's ten tokens into a single string)
decoded = [[tokenizer.decode(token_id) for token_id in row] for row in top_ids.tolist()]

print(decoded)

# Joke dataset 2

In [None]:
!wget https://raw.githubusercontent.com/amoudgl/short-jokes-dataset/master/data/reddit-cleanjokes.csv

In [None]:
import pandas

jokes = pandas.read_csv("reddit-cleanjokes.csv")["Joke"].to_numpy()


def _split_setup_punchline(joke):
  """Split a joke at its first "?": the setup keeps the "?", the punchline is everything after it.

  split("?", 1) keeps later "?" inside the punchline; split("?")[1] used to cut
  the punchline at the second "?".
  """
  setup, punchline = joke.split("?", 1)
  return {"full_joke": joke, "setup": setup + "?", "punchline": punchline.strip()}


# skip non-string cells (e.g. missing values read as NaN) and jokes without a "?"
setup_punchline = [_split_setup_punchline(joke) for joke in jokes if isinstance(joke, str) and "?" in joke]

print(setup_punchline[0])

len(setup_punchline)

In [None]:
import tqdm
for i in tqdm.tqdm(range(len(setup_punchline))):
  tokens,probabilities,surprises = get_tokens_probabilities_surprises(setup_punchline[i]["full_joke"])
  
  question_mark_idx=0
  for token_idx,token in enumerate(tokens):
    if "?" in token:
      question_mark_idx=token_idx

  setup_punchline[i]["punchline_probabilities"]=probabilities[question_mark_idx:]
  setup_punchline[i]["punchline_tokens"]=tokens[question_mark_idx:]
  setup_punchline[i]["punchline_surprises"]=surprises[question_mark_idx:]

In [None]:
for i in tqdm.tqdm(range(len(setup_punchline))):
  setup_punchline[i]["mean_punchline_surprise"]=np.mean(setup_punchline[i]["punchline_surprises"])

In [None]:
# order jokes in place from least to most surprising punchline (stable sort)
setup_punchline[:] = sorted(setup_punchline, key=lambda joke: joke["mean_punchline_surprise"])

In [None]:
# Print the 10 jokes at each end of the ranking (assumes setup_punchline was sorted
# ascending by mean punchline surprise in the previous cell).
for header, selected_jokes in (
    ("#### Least mean punchline surprise: #####", setup_punchline[:10]),
    ("##### Highest mean punchline surprise: ####", setup_punchline[-10:]),
):
  print(header)
  for j in selected_jokes:
    print(j["mean_punchline_surprise"])
    print(j["setup"])
    print(j["punchline"])
    print("\n")