**Imports**

In [1]:
import os

import pandas as pd
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

**Model: GPT-2**

In [7]:
model_name = "gpt2"

# Load the tokenizer first: its EOS id is reused as the model's pad id below.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# GPT-2's tokenizer defines no dedicated pad token, so EOS doubles as padding;
# this avoids generate() falling back (with a warning) on every call.
model = GPT2LMHeadModel.from_pretrained(
    model_name,
    pad_token_id=tokenizer.eos_token_id,
)

**Utils**: dataset-size helper, generation prompts, task definitions, and the GPT-2 sampling function.

In [None]:
def count_req_rows(type, data_dir='../data/orig/processed/train'):
    """Return how many synthetic rows to generate per class for a task.

    The value is half the row count of the task's processed training CSV. The
    generation loop produces this many query rows and this many non-query rows,
    so together they match the size of the reference file.

    Args:
        type: task name, one of 'news', 'spam', 'sentiment'. The name shadows
            the builtin but is kept so existing keyword callers keep working.
        data_dir: directory holding the processed training CSVs. Forward-slash
            relative paths work on both Windows and POSIX.

    Returns:
        int: ``len(csv) // 2``.

    Raises:
        ValueError: if ``type`` is not a known task name. A silent fallback
            would quietly size the data from the wrong file.
    """
    files = {
        'news': 'news-data.csv',
        'sentiment': 'sentiment-data-mini.csv',
        'spam': 'spam-data-mini.csv',
    }
    if type not in files:
        raise ValueError(f"unknown task type {type!r}; expected one of {sorted(files)}")
    df = pd.read_csv(os.path.join(data_dir, files[type]))
    # Integer floor division is exact; int(x / 2) goes through a float.
    return len(df) // 2

# Seed phrases that GPT-2 continues; each task below pairs one "query" prompt
# with one "non-query" prompt from this table.
prompts = dict(
    spam='Text message with advertisement or offer',
    non_spam='Text message from a friend or family says',
    real_news='Recently published political news title',
    fake_news='Fake political news title',
    happy_tweet='Tweet as a happy person',
    sad_tweet='Tweet as a sad person',
)

# One entry per dataset to synthesize. Columns:
#   task name, prompt key for the query class (labelled y=1 later),
#   prompt key for the non-query class (y=0), progress/checkpoint interval.
# rows_per_category is computed eagerly here, which reads each task's CSV.
zsl_tasks = [
    {
        'name': task_name,
        'rows_per_category': count_req_rows(task_name),
        'query': prompts[query_key],
        'non_query': prompts[non_query_key],
        'print_count': every,
    }
    for task_name, query_key, non_query_key, every in (
        ('news', 'fake_news', 'real_news', 10),
        ('spam', 'spam', 'non_spam', 100),
        ('sentiment', 'happy_tweet', 'sad_tweet', 250),
    )
]

def generate_text_from_prompt(prompt, max_length=50, top_k=0):
    """Sample one GPT-2 continuation of ``prompt`` and return it as text.

    Uses the module-level ``tokenizer`` and ``model`` loaded in the model cell.

    Args:
        prompt: text for the model to continue. The returned string starts
            with it, because the decoded output includes the prompt tokens.
        max_length: cap on the total token count, prompt plus generated
            tokens. Longer prompts therefore leave room for fewer new tokens.
        top_k: top-k filter used while sampling. 0 disables the filter, so
            tokens are sampled from the full distribution.

    Returns:
        str: the decoded prompt plus continuation, with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,
        top_k=top_k,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


In [None]:
# example
# generate_text_from_prompt(prompt=zsl_tasks[0]['non_query'])

'Text message from a friend or family says that one of the missing items found near the Vancouver wildfires is alcohol.\n\nThe bride quickly removed her spectacles and agreed to send it across the border with a question.\n\nBut firefighters found the bag'

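**ZSP - Data Generation**

For each task in `zsl_tasks`, this cell does the following:
- Samples `rows_per_category` GPT-2 completions of the non-query prompt (label `y=0`) and the same number of the query prompt (label `y=1`).
- Shuffles the two sets together and writes the result to `../data/syn/zsl/auto-<task>-data.csv`.
- Checkpoints partial results to `../data/syn/mid/` while generating.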

In [None]:
# A fixed seed makes both GPT-2 sampling and the final shuffle reproducible
# under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

MID_DIR = "../data/syn/mid"   # partial checkpoints written while generating
ZSL_DIR = "../data/syn/zsl"   # final labelled, shuffled datasets
os.makedirs(MID_DIR, exist_ok=True)
os.makedirs(ZSL_DIR, exist_ok=True)


def _generate_texts(prompt, n_rows, print_count, checkpoint_path):
    """Sample ``n_rows`` completions of ``prompt`` and checkpoint them to CSV.

    Every ``print_count`` rows this prints the row index and rewrites
    ``checkpoint_path`` with the rows generated so far. It writes the file once
    more after the loop, so the checkpoint always ends up complete.

    Returns:
        list[str]: the generated texts, in generation order.
    """
    texts = []
    for i in range(n_rows):
        texts.append(generate_text_from_prompt(prompt=prompt))
        if i % print_count == 0:
            print(i, end=' ')
            pd.DataFrame({'text': texts}).to_csv(checkpoint_path, index=False)
    # The loop only checkpoints on multiples of print_count, so flush the tail.
    pd.DataFrame({'text': texts}).to_csv(checkpoint_path, index=False)
    return texts


df_main = pd.DataFrame()
for task in zsl_tasks:
    task_name = task['name']

    print("1. Starting task: ", task_name)
    print("NQ-Progress:", end=' ')
    # The non-query class is generated first (label y=0).
    df_2 = pd.DataFrame({'text': _generate_texts(
        task['non_query'],
        task['rows_per_category'],
        task['print_count'],
        f"{MID_DIR}/auto-{task_name}-nquery-data.csv",
    )})

    print("\n2. Generated non query df")
    print("Q-Progress:", end=' ')
    # The query class (label y=1).
    df_1 = pd.DataFrame({'text': _generate_texts(
        task['query'],
        task['rows_per_category'],
        task['print_count'],
        f"{MID_DIR}/auto-{task_name}-query-data.csv",
    )})

    print("\n3. Generated query df")

    df_1['y'] = 1
    df_2['y'] = 0

    print("4. Saving df")

    df = pd.concat([df_1, df_2], ignore_index=True)
    df = df.sample(frac=1, random_state=SEED).reset_index(drop=True)
    # NOTE(review): df_main is overwritten on every iteration, so after the loop
    # it holds only the last task's frame. Confirm whether it was meant to
    # accumulate all tasks.
    df_main = df
    df.to_csv(f"{ZSL_DIR}/auto-{task_name}-data.csv", index=False)

    print("5. Ending task: ", task_name)
    print()

1. Starting task:  spam
NQ-Progress: 0 20 40 60 80 100 120 140 160 180 200 220 240 260 280 300 320 340 360 380 400 420 440 460 480 500 520 540 560 580 600 
2. Generated non query df
Q-Progress: 0 20 40 60 80 100 120 140 160 180 200 220 240 260 280 300 320 340 360 380 400 420 440 460 480 500 520 540 560 580 600 
3. Generated query df
4. Saving df
5. Ending task:  spam
1. Starting task:  sentiment
NQ-Progress: 0 250 500 750 1000 1250 1500 1750 2000 2250 
2. Generated non query df
Q-Progress: 0 250 500 750 1000 1250 1500 1750 2000 2250 
3. Generated query df
4. Saving df
5. Ending task:  sentiment
