# Generate text with zero-shot, one-shot, and few-shot inference

In [None]:
import psutil

notebook_memory = psutil.virtual_memory()
print(notebook_memory)

# Heuristic check for an m5.2xlarge-class instance: require at least 32 GB
# (decimal, 32 * 10**9 bytes) of total RAM.
# The flag is always defined (True or False) so later cells can test it
# without hitting a NameError on an undersized instance.
correct_instance_type = notebook_memory.total >= 32 * 1000 * 1000 * 1000
if not correct_instance_type:
    print('*******************************************')
    print('YOU ARE NOT USING THE CORRECT INSTANCE TYPE')
    print('PLEASE CHANGE INSTANCE TYPE TO  m5.2xlarge ')
    print('*******************************************')

In [None]:
# Restore `setup_dependencies_passed`, which the previous (setup) notebook is
# expected to have saved with %store.
%store -r setup_dependencies_passed

In [None]:
# Guard: `setup_dependencies_passed` only exists in this namespace if the
# `%store -r` above found it, i.e. the previous notebook was run.
if 'setup_dependencies_passed' not in globals():
    for banner_line in (
        "++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] YOU HAVE TO RUN THE PREVIOUS NOTEBOOK ",
        "You did not install the required libraries.   ",
        "++++++++++++++++++++++++++++++++++++++++++++++",
    ):
        print(banner_line)

In [None]:
# Hugging Face Hub identifiers used throughout this notebook.
model_checkpoint, huggingface_dataset_name = (
    "google/flan-t5-base",   # FLAN-T5 base checkpoint
    "knkarthick/dialogsum",  # DialogSum dialogue-summarization dataset
)

# Load the Summarization Dataset

In [None]:
from datasets import load_dataset
# Load every split of the dataset by its Hub name; the 'test' split is the
# one indexed by the cells below.
dataset = load_dataset(path=huggingface_dataset_name)

In [None]:
# Row positions in the test split. Position 0 is the dialogue to be
# summarized; the remaining positions supply in-context examples for
# make_prompt.
example_indices = [40, 70, 80, 160]

target_example = dataset['test'][example_indices[0]]
print('Example Input Dialogue:')
print(target_example['dialogue'])
print()
print('Example Output Summary:')
print(target_example['summary'])

# Create prompts for zero-shot, one-shot, and few-shot inference on sample data

In [None]:
# Prompt template pieces. Each example is rendered as
# start_prompt + dialogue + end_prompt, followed by the summary for solved
# examples; solved examples are terminated by a stop_sequence line.
start_prompt = "Summarize the following conversation.\n\n"
end_prompt = "\n\nSummary: "
stop_sequence = "---"

In [None]:
def make_prompt(num_shots):
    """Build a summarization prompt with ``num_shots`` in-context examples.

    Solved examples come from ``dataset['test']`` at positions
    ``example_indices[1]`` .. ``example_indices[num_shots]``. Each one is
    rendered as ``start_prompt + dialogue + end_prompt + summary`` followed
    by a ``stop_sequence`` line. The dialogue to be summarized
    (``example_indices[0]``) comes last, without its summary, so the model
    has to complete it.

    Args:
        num_shots: Number of solved examples to prepend (0 = zero-shot).

    Returns:
        The assembled prompt string.

    Raises:
        ValueError: If ``num_shots`` is negative.
        IndexError: If ``example_indices`` has fewer than ``num_shots + 1``
            entries.
    """
    if num_shots < 0:
        raise ValueError(f'num_shots must be non-negative, got {num_shots}')
    if num_shots + 1 > len(example_indices):
        raise IndexError(
            f'num_shots={num_shots} needs {num_shots + 1} example indices, '
            f'but example_indices has only {len(example_indices)}'
        )

    test_split = dataset['test']
    parts = []
    # Shots start at position 1, so the target at position 0 is not reused
    # as one of its own demonstrations.
    for index in example_indices[1:num_shots + 1]:
        example = test_split[index]
        parts.append(
            f"{start_prompt}{example['dialogue']}{end_prompt}"
            f"{example['summary']}\n{stop_sequence}\n"
        )
    # Target dialogue last; the summary slot is left open for the model.
    target = test_split[example_indices[0]]
    parts.append(f"{start_prompt}{target['dialogue']}{end_prompt}")
    return ''.join(parts)

In [None]:
# Zero-shot prompt: the target dialogue alone, with no solved examples.
zero_shot_prompt = make_prompt(num_shots=0)
print(zero_shot_prompt)

In [None]:
# One-shot prompt: one solved example, then the target dialogue.
one_shot_prompt = make_prompt(num_shots=1)
print(one_shot_prompt)

In [None]:
# Few-shot prompt: two solved examples, then the target dialogue.
few_shot_prompt = make_prompt(num_shots=2)
print(few_shot_prompt)

# Perform zero-shot, one-shot, and few-shot inference BEFORE fine-tuning

In [None]:
from transformers import AutoTokenizer
    
# Fast tokenizer that matches the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
    use_fast=True,
)

In [None]:
from transformers import T5ForConditionalGeneration

# Pretrained seq2seq model for the same checkpoint, used before any fine-tuning.
model = T5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
)

# Zero-shot

In [None]:
# Zero-shot inference: the prompt holds only the target dialogue.
inputs = tokenizer(zero_shot_prompt, return_tensors='pt')
# Generate at most 50 new tokens using the model's default generation
# settings. To sample instead, pass do_sample=True with top_k / top_p.
generated_ids = model.generate(inputs["input_ids"], max_new_tokens=50)
# Decode the single returned sequence, dropping pad/EOS tokens.
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f'ZERO SHOT RESPONSE: {output}')

summary = dataset['test'][example_indices[0]]['summary']
print(f'EXPECTED RESPONSE: {summary}')

# One-shot

In [None]:
# One-shot inference: one solved example precedes the target dialogue.
inputs = tokenizer(one_shot_prompt, return_tensors='pt')
# Generate at most 50 new tokens with the model's default generation settings.
generated_ids = model.generate(inputs["input_ids"], max_new_tokens=50)
# Decode the single returned sequence, dropping pad/EOS tokens.
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(f'ONE SHOT RESPONSE: {output}')

summary = dataset['test'][example_indices[0]]['summary']
print(f'EXPECTED RESPONSE: {summary}')

# Few-shot

In [None]:
# Few-shot inference: two solved examples precede the target dialogue.
inputs = tokenizer(few_shot_prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)
# Label must say FEW SHOT: this cell decodes the few-shot prompt.
print(f'FEW SHOT RESPONSE: {output}')
summary = dataset['test'][example_indices[0]]['summary']
print(f'EXPECTED RESPONSE: {summary}')

## Store Variables

In [None]:
# Persist `model_checkpoint` with IPython's %store so another notebook can
# restore it with `%store -r model_checkpoint`.
%store model_checkpoint