<a href="https://colab.research.google.com/github/mayasrikanth/CS-236G-Project/blob/main/Generate_Story_Prompts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Working on code for in-context prompt generation

In [1]:
!pip install transformers


Import transformers library and classes for loading model/tokenizer

In [2]:
import transformers # import transformers library 
# Use Hugging Face GPT-Neo (an open-source alternative to GPT-3; GPT-J is a larger option) to generate story prompts
from transformers import AutoModelForCausalLM, AutoTokenizer

Load model (currently EleutherAI's GPT-Neo with 1.3 billion parameters)

In [3]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")  

Downloading:   0%|          | 0.00/1.32k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/200 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Warning: loading GPT-J (as shown below) requires at least ~12GB of CPU RAM (it's huge!)

In [None]:
# For loading GPT-J 
# from transformers import GPTJForCausalLM
# import torch

# model = GPTJForCausalLM.from_pretrained(
#     "EleutherAI/gpt-j-6B",
#         revision="float16",
#         torch_dtype=torch.float16,
#         low_cpu_mem_usage=True
# )
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
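
The ~12GB figure in the warning above can be sanity-checked with simple arithmetic: the memory needed for the weights is roughly the parameter count times the bytes per parameter. The sketch below assumes nominal parameter counts (6 billion for GPT-J, 1.3 billion for GPT-Neo) and ignores activations and framework overhead. `estimate_weight_memory_gb` is a hypothetical helper written for this note, not part of transformers.

```python
# Rough estimate of the memory needed just to hold model weights.
# Assumption: nominal parameter counts; activations and overhead are ignored.
BYTES_PER_PARAM = {"float32": 4, "float16": 2}

def estimate_weight_memory_gb(num_params, dtype="float16"):
    """Return approximate weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

print("GPT-J 6B, float16:", estimate_weight_memory_gb(6e9, "float16"), "GB")
print("GPT-Neo 1.3B, float32:", estimate_weight_memory_gb(1.3e9, "float32"), "GB")
```

Under these assumptions, GPT-J in float16 needs about 12 GB for its weights alone, which is why the half-precision revision is the one loaded in the cell above.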

In [4]:
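# Move model to GPU if one is available, otherwise fall back to CPU
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)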

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 2048)
    (wpe): Embedding(2048, 2048)
    (drop): Dropout(p=0, inplace=False)
    (h): ModuleList(
      (0): GPTNeoBlock(
        (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0, inplace=False)
            (resid_dropout): Dropout(p=0, inplace=False)
            (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
        )
        (ln_2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=2048, out_features=8192, bias=True)
          (c_proj): Linear(in_fea

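PromptTune: a function that takes a prompt and returns text generated by a transformer decoder model. The default max length is 100 tokens (not characters), and the prompt tokens count toward that limit. We are leveraging the model's in-context learning capabilities (few-shot where the prompt includes examples, zero-shot otherwise).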

In [5]:
def PromptTune(prompt, model, tokenizer, temp=0.9, max_len=100):
  ''' Takes a prompt as input, passes it to the model, and returns the decoded 
      output, limited to max_len tokens (the prompt tokens count toward this limit). 
    Inputs: 
        - prompt: a string (or a list containing one string) holding the input sequence 
        - model: the specified transformer-based model 
        - tokenizer: the tokenizer matching the model 
        - temp: temperature (higher temperature gives more diverse outputs, default=0.9)
        - max_len: max length in tokens of prompt plus generated output (default=100)
    Output:
        - gen_text: list of decoded strings; each still begins with the input prompt 
  '''
  # Tokenize the prompt and move the tensors to the same device as the model
  tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device)
  gen_tokens = model.generate(
    **tokenized_input,
    do_sample=True,
    temperature=temp,
    max_length=max_len,
    ) 
  gen_text = tokenizer.batch_decode(gen_tokens) # decode output tokens
  output = gen_text[0]
  print('output: ', output)
  # Return all decoded sequences (the input prompt is not stripped)
  return gen_text
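
The `temp` argument controls how sampling probability is spread across candidate next tokens. As a minimal sketch of the idea (illustrative only, not the library's internal code), temperature scaling divides the raw scores by the temperature before the softmax, so lower values sharpen the distribution and higher values flatten it. The scores below are made up for three hypothetical candidate tokens.

```python
import math

def softmax_with_temperature(logits, temp=1.0):
    """Convert raw scores to probabilities after dividing them by temp."""
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate tokens
for t in (0.5, 0.9, 1.5):
    probs = softmax_with_temperature(logits, temp=t)
    print(t, [round(p, 3) for p in probs])
```

The top-scoring token receives a smaller share of the probability mass as the temperature rises, which is why higher values produce more varied (and less predictable) generations.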

Prompt Type 1: Paraphrase with entity-linking

*   The objective of this prompt type is to paraphrase a multi-sentence passage into a single sentence that a text-to-image model can easily digest. Ideally, this step resolves each pronoun to the entity it refers to, and does so consistently. We can think of this step as summarizing a single page of a simple children's storybook.
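
One way to assemble prompts of this shape is from a list of example pairs rather than by hand. The sketch below is an assumption about how that could be done, not code from this notebook: `build_few_shot_prompt` and its parameter names are hypothetical, and the "Passage"/"Summary" labels simply mirror the format used in the cells that follow.

```python
def build_few_shot_prompt(examples, query, in_label="Passage", out_label="Summary"):
    """Join (input, output) example pairs and a final query into one prompt string."""
    lines = []
    for source, target in examples:
        lines.append(f"{in_label}: {source} \n")
        lines.append(f"{out_label}: {target} \n")
    # End with an open output label so the model completes it
    lines.append(f"{in_label}: {query} \n")
    lines.append(f"{out_label}: ")
    return "".join(lines)

examples = [
    ("A lily flower was floating in a pond. A parrot landed on it.",
     "A parrot lands on a lily flower in a pond."),
    ("An elephant roams the safari. He drinks water at the waterhole.",
     "An elephant drinks water from the waterhole in the safari."),
]
few_shot_prompt = build_few_shot_prompt(
    examples, "A dolphin swims under water. He is swimming next to coral.")
print(few_shot_prompt)
```

Keeping the examples in a list makes it easy to vary their number and order, both of which can change what the model generates.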

In [48]:
# prompt = ["Passage: All was calm under the deep sea. Suddenly, a dolphin appeared. He began swimming towards shallower waters. \n"

#           "Summary: A dolphin swims in shallow sea water. \n"
          
#           # "Passage: The sun rose, casting shadows over the mountain top. Sally the bear climbed the mountain top. She stood at the very tip of the peak. \n "
#           # "Summary: A bear stands on a mountaintop at sunrise. \n "

#           #"Passage: In the light of the moon, a little egg lay on a leaf. The egg hatched, and suddenly a tiny caterpillar appeared. \n ", # inspired by "The Very Hungry Caterpillar"
#           #"Summary: \n ",
#           # "Passage: A baby turtle with a purple shell was walking on the beach. He was so close to the ocean water. He wanted to plunge into the water for a cold swim. \n" 
#           "Passage: A baby turtle with a purple shell was walking on the beach. He was approaching the ocean."
#           "Summary: A baby turtle with a purple shell walks on the beach towards the ocean. \n"
#           "Passage: "
#           # "Summary: \n"         
#]

In [6]:
prompt = ["Passage: A baby turtle with a purple shell was walking on the beach. He was approaching the ocean. \n"
"Summary: A baby turtle with a purple shell walks on the beach towards the ocean. \n"
"Passage: A lily flower was floating in a pond. A parrot landed on it. \n"
"Summary: A parrot lands on a lily flower in a pond. \n"
"Passage: A rabbit runs across the field. He runs quickly, through many flowers. \n"
"Summary: A rabbit runs across a field, through many flowers. \n"
"Passage: A cat climbs to the top of the mountain. It's sunrise when she reaches the mountain top. \n"
"Summary: A cat is on the top of a mountain at sunrise. \n"
"Passage: A bunny forages for food in the human's garden. To her delight, she finds a carrot and eats it. \n"
"Summary: A bunny eats a carrot in a garden. \n"
"Passage: A parrot sits on a tree. He lets out a loud chirp to let his friends know where is he is. \n"
"Summary: A chirping parrot sits on a tree. \n"
"Passage: A deer runs across the forest. She stops to catch her breath near a deciduous tree. \n"
"Summary: A deer in a forest standing near a tree. \n"
"Passage: An elephant roams the safari. He drinks water at the waterhole. \n"
"Summary: An elephant drinks water from the waterhole in the safari. \n"
"Passage: A dolphin swims under water. He is swimming next to coral.  \n"
"Summary: "]


In [9]:
# prompt = ["Passage: All was calm under the deep sea. Suddenly, a dolphin appeared. He began swimming towards shallower waters. \n"

#           "Summary: A dolphin swims in shallow sea water. \n"
          
prompt =  ["Passage: An elephant roams the safari. He drinks water at the waterhole. \n"
           "Summary: An elephant drinks water from the waterhole in the safari. \n"
           "Passage: A baby turtle with a purple shell was walking on the beach. He was approaching the ocean. \n"
          "Summary: A baby turtle with a purple shell walks on the beach towards the ocean. \n"
          "Passage: A lily flower was floating in a pond. A parrot landed on it. \n"
          "Summary: A parrot lands on a lily flower in a pond. \n"
          "Passage: A rabbit runs across the field. He runs quickly, through many flowers. \n"
          "Summary: A rabbit runs across a field, through many flowers. \n"
          "Passage: A dolphin swims under water. He is swimming next to coral.  \n"
          "Summary: "]


          #"Passage: All was calm under the deep sea. Suddenly, a dolphin appeared. He began swimming towards shallower waters. \n" ]
          # "Summary: \n"

In [65]:
output = PromptTune(prompt, model, tokenizer, temp=0.9, max_len=1000)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


output:  Passage: A baby turtle with a purple shell was walking on the beach. He was approaching the ocean. 
Summary: A baby turtle with a purple shell walks on the beach towards the ocean. 
Passage: A lily flower was floating in a pond. A parrot landed on it. 
Summary: A parrot lands on a lily flower in a pond. 
Passage: A rabbit runs across the field. He runs quickly, through many flowers. 
Summary: A rabbit runs across a field, through many flowers. 
Passage: A cat climbs to the top of the mountain. It's sunrise when she reaches the mountain top. 
Summary: A cat is on the top of a mountain at sunrise. 
Passage: A bunny forages for food in the human's garden. To her delight, she finds a carrot and eats it. 
Summary: A bunny eats a carrot in a garden. 
Passage: A parrot sits on a tree. He lets out a loud chirp to let his friends know where is he is. 
Summary: A chirping parrot sits on a tree. 
Passage: A deer runs across the forest. She stops to catch her breath near a deciduous tree.

In [12]:
# Second try 
output = PromptTune(prompt, model, tokenizer, temp=0.9, max_len=250)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


output:  Passage: An elephant roams the safari. He drinks water at the waterhole. 
Summary: An elephant drinks water from the waterhole in the safari. 
Passage: A baby turtle with a purple shell was walking on the beach. He was approaching the ocean. 
Summary: A baby turtle with a purple shell walks on the beach towards the ocean. 
Passage: A lily flower was floating in a pond. A parrot landed on it. 
Summary: A parrot lands on a lily flower in a pond. 
Passage: A rabbit runs across the field. He runs quickly, through many flowers. 
Summary: A rabbit runs across a field, through many flowers. 
Passage: A dolphin swims under water. He is swimming next to coral.  
Summary:  A dolphin swims under water, next to coral. 
Passage: A lizard came to the beach. He was sitting on a stone. He was looking back at the sea. 
Summary: A lizard comes to the beach, on a stone, looking back at the sea.  
Passage: A peacock came to the beach
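
In the output above, the wanted summary appears right after the prompt, and the model then keeps inventing new passages until it reaches the length limit. A minimal sketch for keeping only the first generated line is shown below. It assumes the decoded text begins with the exact prompt string; `extract_first_completion` is a hypothetical helper, not part of this notebook or of transformers.

```python
def extract_first_completion(generated, prompt):
    """Return the first non-empty line the model produced after the prompt."""
    # Assumption: decoding reproduces the prompt verbatim at the start
    if not generated.startswith(prompt):
        raise ValueError("generated text does not start with the prompt")
    continuation = generated[len(prompt):]
    for line in continuation.splitlines():
        if line.strip():
            return line.strip()
    return ""

# Toy strings shaped like the run above (no model call needed)
toy_prompt = "Passage: A dolphin swims under water. He is swimming next to coral.  \nSummary: "
toy_generated = (toy_prompt
                 + " A dolphin swims under water, next to coral. \n"
                 + "Passage: A lizard came to the beach. \n")
print(extract_first_completion(toy_generated, toy_prompt))
```

With the variables used in this notebook, the call would look like `extract_first_completion(output[0], prompt[0])`, since `prompt` is a one-element list and `PromptTune` returns a list of decoded strings.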


Prompt Type 2: Add Visual Descriptors

In [31]:
prompt = ("Sentence: A black cat sits next to a pumpkin \n" 
"Image: A cat with fluffy black fur sits on a stone wall in the backyard next to a carved Halloween pumpkin \n" 
"Sentence: A parrot lands on a tree \n"
"Image: A rainbow parrot lands on a large, deciduous tree \n"
"Sentence: A starfish lies on a rock under water \n"
"Image: A starfish lies on a purple rock under green waters \n"
"Sentence: A dolphin swims under water, next to coral \n")

In [32]:
output = PromptTune(prompt, model, tokenizer, temp=1.0, max_len=150)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


output:  Sentence: A black cat sits next to a pumpkin 
Image: A cat with fluffy black fur sits on a stone wall in the backyard next to a carved Halloween pumpkin 
Sentence: A parrot lands on a tree 
Image: A rainbow parrot lands on a large, deciduous tree 
Sentence: A starfish lies on a rock under water 
Image: A starfish lies on a purple rock under green waters 
Sentence: A dolphin swims under water, next to coral 
Image: A dolphin swims next to white sand beaches 
Sentence: A yellow bird lies on the ground under red clouds 
Image: A yellow bird lies on a red carpet next to


In [33]:
prompt = ["Sentence: A black cat sits next to a pumpkin \n" 
"Image: A cat with fluffy black fur sits on a stone wall in the backyard next to a carved Halloween pumpkin \n" 
"Sentence: A parrot lands on a tree \n"
"Image: A rainbow parrot lands on a large, deciduous tree \n"
"Sentence: A starfish lies on a rock under water \n"
"Image: A starfish lies on a purple rock under turquoise waters \n"
"Sentence: A dog walks up to a fire hydrant next to a school \n"
"Image: A golden retriever trots up to a yellow fire hydrant next to a rainbow school \n"
"Sentence: A dolphin swims under water, next to coral \n"]

In [35]:
output = PromptTune(prompt, model, tokenizer, temp=1.0, max_len=200)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


output:  Sentence: A black cat sits next to a pumpkin 
Image: A cat with fluffy black fur sits on a stone wall in the backyard next to a carved Halloween pumpkin 
Sentence: A parrot lands on a tree 
Image: A rainbow parrot lands on a large, deciduous tree 
Sentence: A starfish lies on a rock under water 
Image: A starfish lies on a purple rock under turquoise waters 
Sentence: A dog walks up to a fire hydrant next to a school 
Image: A golden retriever trots up to a yellow fire hydrant next to a rainbow school 
Sentence: A dolphin swims under water, next to coral 
Image: A fish swims through a coral reef 
Sentence: A dog sleeps on a pink, rainbow bed 
Image: A rainbow dog sleeps on a pink, rainbow bed 
Sentence: Children are excited to get


Prompt Type 3: Open-Ended (Plot Generation)

In [41]:
# for a creative transition, where we can utilize in-painting
# taking input=
prompt = ('Write a creative story: \n A playful dolphin swims through turquoise waters next to a clear blue coral')

In [42]:
output = PromptTune(prompt, model, tokenizer, temp=0.95, max_len=100)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


output:  Write a creative story: 
 A playful dolphin swims through turquoise waters next to a clear blue coral reef.
 The dolphins are attracted to a school of fish and swim up on the fish.

This playful dolphin was found swimming between reefs of turquoise water and coral coral.

The dolphin was captured from the Caribbean Sea in 2012. They were about 15 months old.

It is likely the dolphin was being used by fishermen as bait when they caught them for


In [43]:
prompt = ('A white tiger lays on the ground next to a dark jungle')

In [44]:
output = PromptTune(prompt, model, tokenizer, temp=0.95, max_len=100)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


output:  A white tiger lays on the ground next to a dark jungle. His fur is dappled with pink and purple. He is a large man with a white snout and deep green eyes. His large paw is spread wide to rub his face against his muzzle. He is in the process of grooming himself.

This white tiger has been roaming between the lush foliage of the jungle in an attempt to defile the jungle and the beautiful pink and purple birds. The tiger is trying to get through


In [46]:
prompt = ('A dolphin spins in the ocean next to sea urchins on the ocean floor')

In [47]:
output = PromptTune(prompt, model, tokenizer, temp=0.95, max_len=100)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


output:  A dolphin spins in the ocean next to sea urchins on the ocean floor. / Courtesy of the NOAA Fisheries Service.

“I was fascinated by all the sea urchin things that have gone extinct on the sea floor,” she said. “The problem is, I would have to go in there and bring in a lot of equipment … and I don’t think there’s an easy way to do that.”

So she set


OLD CODE (will remove later)

In [13]:
prompt = (
    "Sentence: A black cat sits next to a pumpkin \n" 
    "Image: A cat with fluffy black fur sits on a stone wall in the backyard next to a carved Halloween pumpkin \n"
    "Sentence: A swan on a lake \n"
) # generate prompt 

In [14]:
tokenized_input = tokenizer(prompt, return_tensors="pt").to(model.device) # tokenize prompt and move it to the model's device


In [18]:
gen_tokens = model.generate(
    **tokenized_input,
    do_sample=True,
    temperature=0.9,
    max_length=100,
) # generate output tokens 

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [19]:
gen_text = tokenizer.batch_decode(gen_tokens)[0] # decode output tokens

In [22]:
gen_text = tokenizer.batch_decode(gen_tokens)

In [26]:
prompt

'Sentence: A black cat sits next to a pumpkin \nImage: A cat with fluffy black fur sits on a stone wall in the backyard next to a carved Halloween pumpkin \nSentence: A swan on a lake \n'

In [27]:
print(gen_text[0][len(prompt):])

Image: A little girl sits on the ground next to her mother who is wearing a red and white striped apron. It is a scene similar to a fairy tale 
Sentence: A man in tight trousers and a polka-dot shirt sits before
