In this notebook, we'll explore few-shot learning with GPT-2. GPT-2 is a less expressive model than GPT-3, and hence not as good a few-shot learner, but it fits within the memory and processing constraints of a laptop and is openly available. Can you create a new classification task and design prompts to differentiate between the classes within it?

In [1]:
import torch
from torch.nn import functional as F

In [2]:
from transformers import pipeline

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

In [5]:
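def classify_with_prompt(prompt, labels):
    # Encode the prompt and keep the logits for the token that would follow it
    inputs = tokenizer.encode(prompt, return_tensors='pt')
    with torch.no_grad():  # inference only; no gradients needed
        completion_layer = model(inputs).logits[:, -1, :]
    probabilities = F.softmax(completion_layer, dim=-1)[0]

    # Most likely next token over the whole vocabulary
    pred_idx = torch.argmax(probabilities)
    pred_token = tokenizer.decode(pred_idx.tolist())

    # Score each label by the probability of its first token.
    # Note: labels are encoded without a leading space, whereas after "A:" or "Y:"
    # GPT-2 usually predicts the space-prefixed token (e.g. " positive"), which
    # likely explains why the printed label probabilities are so small.
    label_ids = []
    for label in labels:
        token_ids = tokenizer.encode(label)
        label_ids.append(token_ids[0])

    # Print the labels from most to least probable
    sorted_args = torch.argsort(probabilities[label_ids], descending=True).tolist()
    for arg in sorted_args:
        print("%.6f\t%s" % (probabilities[label_ids[arg]], labels[arg]))

    print("\nCompletion with highest probability:\n")
    print(prompt + pred_token)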
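`classify_with_prompt` scores each label by the probability of its first token only, encoded without a leading space. Because GPT-2's BPE vocabulary treats `positive` and ` positive` as different tokens, a more faithful score may come from prepending the space and summing log-probabilities over every token of the label. The sketch below assumes a Hugging Face-style causal LM whose output exposes `.logits` and a tokenizer whose `encode()` returns a list of ids; `label_log_prob` and `rank_labels` are hypothetical helpers, not part of `transformers`.

```python
import torch
from torch.nn import functional as F


def label_log_prob(model, tokenizer, prompt, label):
    """Summed log-probability of every token of " " + label following prompt.

    Hypothetical helper. Assumes model(input_ids).logits has shape
    [batch, seq_len, vocab] and tokenizer.encode() returns a list of ids.
    """
    prompt_ids = tokenizer.encode(prompt)
    # Prepend a space: after "A:" GPT-2 normally predicts the space-prefixed token
    label_ids = tokenizer.encode(" " + label)
    input_ids = torch.tensor([prompt_ids + label_ids])
    with torch.no_grad():
        logits = model(input_ids).logits[0]  # [seq_len, vocab]
    log_probs = F.log_softmax(logits, dim=-1)
    total = 0.0
    for i, token_id in enumerate(label_ids):
        # The logits at position p score the token at position p + 1
        position = len(prompt_ids) + i - 1
        total += log_probs[position, token_id].item()
    return total


def rank_labels(model, tokenizer, prompt, labels):
    """Return (label, log-probability) pairs, most likely label first."""
    scores = [(label, label_log_prob(model, tokenizer, prompt, label)) for label in labels]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)
```

With the notebook's objects loaded, `rank_labels(model, tokenizer, prompt, ["Times", "Journal"])` would return both labels with their summed log-probabilities. Summing penalizes labels that split into more tokens; averaging per token is one alternative when label lengths differ.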
    
    

In [6]:
prompt = """X: I love this movie
Y: positive

X: I hate the movie
Y: negative

X: I kind of like the movie
Y: positive

X: This is one of the best movies I've ever seen
Y:"""

classify_with_prompt(prompt, ["positive", "negative"])

0.000231	positive
0.000102	negative

Completion with highest probability:

X: I love this movie
Y: positive

X: I hate the movie
Y: negative

X: I kind of like the movie
Y: positive

X: This is one of the best movies I've ever seen
Y: positive


In [7]:
prompt = """X: Vampires take over the planet during an eclipse
Y: Horror

X: Two friends switch bodies and live each other's lives
Y: Comedy

X: John turns into a werewolf during a full moon
Y: Horror

X: John is a werewolf who plays basketball
Y: Comedy

X: A court sentences George to be Jerry's butler
Y: Comedy

X: A virus outbreak turns everyone into zombies
Y:"""

classify_with_prompt(prompt, ["Horror", "Comedy"])

0.000030	Horror
0.000005	Comedy

Completion with highest probability:

X: Vampires take over the planet during an eclipse
Y: Horror

X: Two friends switch bodies and live each other's lives
Y: Comedy

X: John turns into a werewolf during a full moon
Y: Horror

X: John is a werewolf who plays basketball
Y: Comedy

X: A court sentences George to be Jerry's butler
Y: Comedy

X: A virus outbreak turns everyone into zombies
Y: Horror


In [8]:
prompt = """Q: This is a text
A: English

Q: Nel mezzo del cammin' di nostra vita
A: Italian

Q: Je ne sais pas
A:"""

classify_with_prompt(prompt, ["English", "Italian", "French", "Spanish", "Japanese"])


0.000223	Spanish
0.000058	French
0.000022	English
0.000020	Italian
0.000008	Japanese

Completion with highest probability:

Q: This is a text
A: English

Q: Nel mezzo del cammin' di nostra vita
A: Italian

Q: Je ne sais pas
A: Spanish


**Q1**. Your job is to create a new classification task using prompt design (as in the examples above). You are free to consider binary or multiclass classification; keep in mind that you have ~1000 tokens to use as a prompt for GPT-2, so be sure to provide enough answered prompts for each class. (It is not a requirement that your model perform *well*, since we want to assess what is, and isn't, learnable, but give it every opportunity to do so.) Create 5 test examples to assess whether GPT-2 is able to recognize the class given your fixed prompt. To take the language ID task above, one test example corresponds to one prediction you make for the same set of answered prompts; the following constitutes two test examples for that task:

1.)

```
prompt = """Q: This is a text
A: English

Q: Nel mezzo del cammin' di nostra vita
A: Italian

Q: Je ne sais pas
A:"""
```

2.)

```
prompt = """Q: This is a text
A: English

Q: Nel mezzo del cammin' di nostra vita
A: Italian

Q: Non lo so
A:"""
```
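Since every test example reuses the same answered prompts and changes only the final query, building the prompts programmatically keeps them identical across test examples. A minimal sketch, with `build_prompt` as a hypothetical helper that reproduces the `Q:`/`A:` format used above:

```python
def build_prompt(examples, query, q_tag="Q", a_tag="A"):
    """Format (text, label) pairs as answered prompts, then leave the
    final answer open for the model to complete."""
    blocks = [f"{q_tag}: {text}\n{a_tag}: {label}" for text, label in examples]
    blocks.append(f"{q_tag}: {query}\n{a_tag}:")
    return "\n\n".join(blocks)


examples = [
    ("This is a text", "English"),
    ("Nel mezzo del cammin' di nostra vita", "Italian"),
]

# Two test examples: same answered prompts, different final query
prompt_1 = build_prompt(examples, "Je ne sais pas")
prompt_2 = build_prompt(examples, "Non lo so")
print(prompt_1)
```

GPT-2's context window is 1024 tokens, so checking `len(tokenizer.encode(prompt))` is a quick way to confirm a prompt fits before classifying it.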

In [26]:
labels = ["Times", "Journal"]
prompt = """Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Kidnappers in Haiti Demand $17 Million to Free Missionary Group
A:"""

classify_with_prompt(prompt, labels)
# Correct: NYT headline

0.000203	Times
0.000011	Journal

Completion with highest probability:

Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Kidnappers in Haiti Demand $17 Million to Free Missionary Group
A: Times


In [27]:
labels = ["Times", "Journal"]
prompt = """Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Xi Faces Resistance to Property Tax Planned to Combat China’s Real-Estate Bubble
A:"""

classify_with_prompt(prompt, labels)
# Incorrect: WSJ headline

0.000405	Times
0.000007	Journal

Completion with highest probability:

Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Xi Faces Resistance to Property Tax Planned to Combat China’s Real-Estate Bubble
A: Times


In [28]:
labels = ["Times", "Journal"]
prompt = """Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Christian Schools Are Booming as U.S. Faces Covid and Curriculum Fights
A:"""

classify_with_prompt(prompt, labels)
# Correct: NYT headline

0.000275	Times
0.000013	Journal

Completion with highest probability:

Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Christian Schools Are Booming as U.S. Faces Covid and Curriculum Fights
A: Times


In [29]:
labels = ["Times", "Journal"]
prompt = """Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Democrats Try to Salvage IRS Bank-Account Reporting With Scaled-Back Plan
A:"""

classify_with_prompt(prompt, labels)
# Incorrect: WSJ headline

0.000333	Times
0.000011	Journal

Completion with highest probability:

Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Democrats Try to Salvage IRS Bank-Account Reporting With Scaled-Back Plan
A: Times


In [30]:
labels = ["Times", "Journal"]
prompt = """Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Congress Is Losing Patience With Big-Tech Resistance, Klobuchar Says
A:"""

classify_with_prompt(prompt, labels)
# Incorrect: WSJ headline

0.000228	Times
0.000018	Journal

Completion with highest probability:

Q: Popular Investing Apps Might Hold a Tax Surprise
A: Journal

Q: More Companies Mandate Vaccine Ahead of Deadline
A: Journal

Q: Is Brexit Hurting the U.K. Economy? Trade Data Flashes a Warning
A: Journal

Q: What’s Behind the Worker Shortage Slowing the Economic Rebound
A: Times

Q: Democrats, Facing Republican Barrage, Scale Back I.R.S. Enforcement Plan
A: Times

Q: Brazilian Leader’s Pandemic Handling Draws Explosive Allegation: Homicide
A: Times

Q: Congress Is Losing Patience With Big-Tech Resistance, Klobuchar Says
A: Times


Final accuracy: 2/5 correct

My goal was to classify headlines from two major newspapers, the New York Times (`Times`) and the Wall Street Journal (`Journal`). Since the NYT is seen as a more "liberal" paper focused on politics and social issues, and the WSJ as a more "conservative" paper focused on economics, I assumed their headline language would differ enough for GPT-2 to tell them apart.

To do this, I picked the top headlines from both papers from today (10/19) and used them to build the prompts. Interestingly, the model classified every test headline as from the Times, so its 2/5 accuracy is exactly what always answering `Times` would score (2 of the 5 test headlines were from the NYT). The prompts clearly need adjusting, or replacing, before this works as a classifier.
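Scoring the recorded predictions against a constant baseline makes the result easier to interpret. The sketch below uses the five test headlines, their true sources (from the cell comments above), and the label GPT-2 ranked first for each; `accuracy` is a hypothetical helper, and "always `Times`" is simply the cheapest classifier to compare against.

```python
# (headline, true source, label GPT-2 ranked first), taken from the cells above
results = [
    ("Kidnappers in Haiti Demand $17 Million to Free Missionary Group", "Times", "Times"),
    ("Xi Faces Resistance to Property Tax Planned to Combat China’s Real-Estate Bubble", "Journal", "Times"),
    ("Christian Schools Are Booming as U.S. Faces Covid and Curriculum Fights", "Times", "Times"),
    ("Democrats Try to Salvage IRS Bank-Account Reporting With Scaled-Back Plan", "Journal", "Times"),
    ("Congress Is Losing Patience With Big-Tech Resistance, Klobuchar Says", "Journal", "Times"),
]


def accuracy(pairs):
    """Fraction of (gold, predicted) pairs that match."""
    return sum(gold == pred for gold, pred in pairs) / len(pairs)


model_acc = accuracy([(gold, pred) for _, gold, pred in results])
# Baseline that ignores the headline entirely and always answers "Times"
always_times_acc = accuracy([(gold, "Times") for _, gold, _ in results])

print(f"GPT-2: {model_acc:.0%}, always 'Times': {always_times_acc:.0%}")
# GPT-2: 40%, always 'Times': 40%
```

Matching the constant baseline suggests the prompt has not yet given GPT-2 a usable signal for separating the two papers.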