In [1]:
import torch
import numpy as np
import random
import yaml
from argparse import Namespace
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import (
    AutoTokenizer, 
    AutoModelForMaskedLM
)
from tasks.dataset import FewShotDataset
from tasks.processors import compute_metrics_mapping

In [13]:
# Global seed: fed to the torch / numpy / random seeding calls and also
# interpolated into the k-shot data directory name (".../16-{seed}").
seed = 100

In [14]:
# Candidate prompt strings kept side by side for comparison.
# NOTE(review): the numeric suffixes presumably refer to the seed each prompt
# was learned with -- confirm against the training runs.
s2_prompt_100 = ' totally really really really really'

prompt_100 = 'absolutely seriously Absolutely Simply Simply'
# Alternative value for prompt_100 (disabled):
# prompt_100 = 'charact charact utterly absolutely absolutely'
prompt_87 = 'absolutely unequivocallyliterally unequivocally unequivocally'
# The prompt actually selected; it is only consumed by the (currently
# commented-out) dataset.set_learned_prompt call further down.
prompt_string = s2_prompt_100
# Disabled: rebuild the string from a whitespace-separated token sequence.
# prompt_string = tokenizer.convert_tokens_to_string(prompt.split())
prompt_string

' totally really really really really'

In [15]:
# Seed every RNG source used by this notebook (torch, numpy, stdlib random)
# with the same value. A loop is used so the cell produces no output.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)

In [16]:
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly) on a
# machine without CUDA instead of failing when the model is moved to the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Tokenizer and masked-LM come from the same checkpoint so token ids line up.
tokenizer = AutoTokenizer.from_pretrained('roberta-large')
generator = AutoModelForMaskedLM.from_pretrained('roberta-large').to(device)

In [17]:
# Read the project config and keep only its translation -> classification
# section, exposed as an attribute-style Namespace that later cells mutate.
# NOTE(review): yaml.FullLoader is acceptable for this trusted local file;
# prefer yaml.safe_load if the config could ever come from an untrusted source.
with open("configs/config.yaml", "r") as f:
    config = Namespace(
        **yaml.load(f, Loader=yaml.FullLoader)['translation']['classification']
    )

In [50]:
# Task selection: SST-2. Every task-selection cell overwrites the same fields
# on `config`, so only the most recently executed one is in effect.
task_name = "SST-2"
# Single-segment alternative template (disabled):
# config.template = "*cls**sent_0*_It_was*mask*.*sep+*"
vars(config).update(
    task_name=task_name.lower(),
    # Query segment plus two extra segments carrying *label_0* / *label_1*.
    template="*cls**sent_0*_It_was*mask*.*sep**sent_1*_It_was*label_0*.*sep**sent_2*_It_was*label_1*.*sep+*",
    mapping="{'0':'terrible','1':'great'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/SST-2/16-13', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'0':'terrible','1':'great'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='sst-2', template='*cls**sent_0*_It_was*mask*.*sep**sent_1*_It_was*label_0*.*sep**sent_2*_It_was*label_1*.*sep+*', template_list=None, use_demo=False)

In [16]:
# Task selection: sst-5 (five label words). Overwrites the task fields on the
# shared `config`; the last task cell executed wins.
task_name = "sst-5"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent_0*_It_was*mask*.*sep+*",
    mapping="{0:'terrible',1:'bad',2:'okay',3:'good',4:'great'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/sst-5/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'terrible',1:'bad',2:'okay',3:'good',4:'great'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='sst-5', template='*cls**sent_0*_It_was*mask*.*sep+*', template_list=None, use_demo=False)

In [18]:
# Task selection: MRPC (sentence-pair template, label words No/Yes).
# Overwrites the task fields on the shared `config`.
task_name = "MRPC"
# Single-pair alternative template (disabled):
# config.template = "*cls**sent_0**mask*,*+sentl_1**sep+*"
vars(config).update(
    task_name=task_name.lower(),
    # Query pair followed by two extra pairs carrying *label_0* / *label_1*.
    template="*cls**sent_0**mask*,*+sentl_1**sep**sent_2**label_0*,*+sentl_3**sep**sent_4**label_1*,*+sentl_5**sep*",
    mapping="{'0':'No','1':'Yes'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/MRPC/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'0':'No','1':'Yes'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='mrpc', template='*cls**sent_0**mask*,*+sentl_1**sep**sent_2**label_0*,*+sentl_3**sep**sent_4**label_1*,*+sentl_5**sep*', template_list=None, use_demo=False)

In [21]:
# Task selection: mr (label words terrible/great). Overwrites the task fields
# on the shared `config`.
task_name = "mr"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent_0*_It_was*mask*.*sep+*",
    mapping="{0:'terrible',1:'great'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/mr/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'terrible',1:'great'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='mr', template='*cls**sent_0*_It_was*mask*.*sep+*', template_list=None, use_demo=False)

In [26]:
# Task selection: cr (label words terrible/great). Overwrites the task fields
# on the shared `config`.
task_name = "cr"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent_0*_It_was*mask*.*sep+*",
    mapping="{0:'terrible',1:'great'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/cr/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'terrible',1:'great'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='cr', template='*cls**sent_0*_It_was*mask*.*sep+*', template_list=None, use_demo=False)

In [32]:
# Task selection: mpqa (label words terrible/great). Overwrites the task
# fields on the shared `config`.
task_name = "mpqa"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent_0*_It_was*mask*.*sep+*",
    mapping="{0:'terrible',1:'great'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/mpqa/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'terrible',1:'great'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='mpqa', template='*cls**sent_0*_It_was*mask*.*sep+*', template_list=None, use_demo=False)

In [37]:
# Task selection: subj (label words subjective/objective). Overwrites the task
# fields on the shared `config`.
task_name = "subj"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent_0*_This_is*mask*.*sep+*",
    mapping="{0:'subjective',1:'objective'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/subj/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'subjective',1:'objective'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='subj', template='*cls**sent_0*_This_is*mask*.*sep+*', template_list=None, use_demo=False)

In [42]:
# Task selection: trec (six label words; the *mask* slot precedes the
# sentence). Overwrites the task fields on the shared `config`.
task_name = "trec"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**mask*:*+sent_0**sep+*",
    mapping="{0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/trec/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='trec', template='*cls**mask*:*+sent_0**sep+*', template_list=None, use_demo=False)

In [47]:
# Task selection: CoLA (label words incorrect/correct). Overwrites the task
# fields on the shared `config`.
task_name = "CoLA"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent_0*_This_is*mask*.*sep+*",
    mapping="{'0':'incorrect','1':'correct'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/CoLA/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'0':'incorrect','1':'correct'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='cola', template='*cls**sent_0*_This_is*mask*.*sep+*', template_list=None, use_demo=False)

In [7]:
# Task selection: MNLI (sentence-pair template, label words No/Yes/Maybe).
# Overwrites the task fields on the shared `config`.
task_name = "MNLI"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent-_0*?*mask*,*+sentl_1**sep+*",
    mapping="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/MNLI/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='mnli', template='*cls**sent-_0*?*mask*,*+sentl_1**sep+*', template_list=None, use_demo=False)

In [12]:
# Task selection: QNLI (sentence-pair template, label words No/Yes).
# Overwrites the task fields on the shared `config`.
task_name = "QNLI"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent-_0*?*mask*,*+sentl_1**sep+*",
    mapping="{'not_entailment':'No','entailment':'Yes'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/QNLI/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'not_entailment':'No','entailment':'Yes'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='qnli', template='*cls**sent-_0*?*mask*,*+sentl_1**sep+*', template_list=None, use_demo=False)

In [22]:
# Task selection: SNLI (sentence-pair template, label words No/Yes/Maybe).
# Overwrites the task fields on the shared `config`.
task_name = "SNLI"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent-_0*?*mask*,*+sentl_1**sep+*",
    mapping="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/SNLI/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='snli', template='*cls**sent-_0*?*mask*,*+sentl_1**sep+*', template_list=None, use_demo=False)

In [67]:
# Task selection: QQP (sentence-pair template, label words No/Yes).
# Overwrites the task fields on the shared `config`.
task_name = "QQP"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent_0**mask*,*+sentl_1**sep+*",
    mapping="{'0':'No','1':'Yes'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/QQP/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'0':'No','1':'Yes'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='qqp', template='*cls**sent_0**mask*,*+sentl_1**sep+*', template_list=None, use_demo=False)

In [72]:
# Task selection: RTE (sentence-pair template, label words No/Yes).
# Overwrites the task fields on the shared `config`.
task_name = "RTE"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**sent-_0*?*mask*,*+sentl_1**sep+*",
    mapping="{'not_entailment':'No','entailment':'Yes'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/RTE/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'not_entailment':'No','entailment':'Yes'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='rte', template='*cls**sent-_0*?*mask*,*+sentl_1**sep+*', template_list=None, use_demo=False)

In [27]:
# Task selection: yelp-2 (label words terrible/great). Overwrites the task
# fields on the shared `config`.
task_name = "yelp-2"
# Single-segment alternative template (disabled):
# config.template = "*cls**sent_0*_It_was*mask*.*sep+*"
vars(config).update(
    task_name=task_name.lower(),
    # Query segment plus two extra segments carrying *label_0* / *label_1*.
    template="*cls**sent_0*_It_was*mask*.*sep**sent_1*_It_was*label_0*.*sep**sent_2*_It_was*label_1*.*sep+*",
    mapping="{0:'terrible',1:'great'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/yelp-2/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, double_demo=False, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'terrible',1:'great'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='yelp-2', template='*cls**sent_0*_It_was*mask*.*sep**sent_1*_It_was*label_0*.*sep**sent_2*_It_was*label_1*.*sep+*', template_list=None, truncate_head=False, use_demo=True)

In [47]:
# Task selection: agnews (four label words; the *mask* slot precedes the
# sentence). Overwrites the task fields on the shared `config`.
task_name = "agnews"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls**mask*_News:*sent_0*.*sep+*",
    mapping="{0:'World',1:'Sports',2:'Business',3:'Tech'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/agnews/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'World',1:'Sports',2:'Business',3:'Tech'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='agnews', template='*cls**mask*_News:*sent_0*.*sep+*', template_list=None, use_demo=False)

In [7]:
# Task selection: dbpedia (fourteen label words). Unlike the other task cells
# this one also sets `skip_space`. Overwrites the task fields on the shared
# `config`.
task_name = "dbpedia"
vars(config).update(
    task_name=task_name.lower(),
    template="*cls*[Category: *mask*]*sent_0**sep+*",
    mapping="{0:'Company',1:'Education',2:'Artist',3:'Athlete',4:'Office',5:'Transportation',6:'Building',7:'Natural',8:'Village',9:'Animal',10:'Plant',11:'Album',12:'Film',13:'Written'}",
    data_dir=f"/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/{task_name}/16-{seed}",
    skip_space=True,
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/dbpedia/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{0:'Company',1:'Education',2:'Artist',3:'Athlete',4:'Office',5:'Transportation',6:'Building',7:'Natural',8:'Village',9:'Animal',10:'Plant',11:'Album',12:'Film',13:'Written'}", max_seq_length=512, num_sample=16, other_sent_limit=None, overwrite_cache=None, prompt=True, skip_space=True, task_name='dbpedia', template='*cls*[Category: *mask*]*sent_0**sep+*', template_list=None, use_demo=False)

In [19]:
# Demonstration-related switches, applied on top of whichever task cell was
# run last: enable use_demo, disable double_demo / truncate_head, and set
# num_sample to 1.
vars(config).update(
    use_demo=True,
    double_demo=False,
    truncate_head=False,
    num_sample=1,
)
config

Namespace(data_dir='/home/c2hsieh/soft-Q-learning-for-text-generation/tasks/k-shot/MRPC/16-100', debug_mode=False, demo_filter=False, demo_filter_model=None, demo_filter_rate=0.5, double_demo=False, first_sent_limit=None, gpt3_in_context_head=False, gpt3_in_context_num=32, gpt3_in_context_tail=False, mapping="{'0':'No','1':'Yes'}", max_seq_length=512, num_sample=1, other_sent_limit=None, overwrite_cache=None, prompt=True, task_name='mrpc', template='*cls**sent_0**mask*,*+sentl_1**sep**sent_2**label_0*,*+sentl_3**sep**sent_4**label_1*,*+sentl_5**sep*', template_list=None, truncate_head=False, use_demo=True)

In [20]:
# Build the dataset in "test" mode from the currently configured task.
dataset = FewShotDataset(config, tokenizer=tokenizer, mode="test")
# Disabled: inject the selected prompt string into the dataset.
# dataset.use_learned_prompt = True
# dataset.set_learned_prompt([prompt_string])
# Look up the metric function registered under the active (lower-cased) task name.
metrics_fn = compute_metrics_mapping[config.task_name]

In [21]:
# Sanity check: show the vocabulary token behind each label-word id.
list(map(tokenizer.convert_ids_to_tokens, dataset.get_labels_tok()))

['ĠNo', 'ĠYes']

In [22]:
# Decode the first example's input ids back to text to eyeball how the
# template was filled in.
f = dataset[0]
tokenizer.decode(f.input_ids)

'<s>He said the foodservice pie business doesn \'t fit the company\'s long-term growth strategy.<mask>, " The foodservice pie business does not fit our long-term growth strategy.</s>Last week, Prime Minister Atal Bihari Vajpayee ended an 18-month chill in relations by ordering normalisation of diplomatic ties and restoration of air services with Pakistan. No, the move follows a recent proposal by Mr Vajpayee, whoended an 18-month chill in relations by ordering normalisation of diplomatic links and restoration of air services with Pakistan.</s>The 39-year-old Luster initially gave police a false name, but later revealed his true identity. Yes, barrera said Luster gave police a false name immediately after his arrest Wednesday but later revealed his true identity.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

In [23]:
# Evaluate the masked LM as a classifier: for every example, read the logits
# at its <mask> position, restrict them to the label-word token ids, and take
# the argmax as the predicted class index.
dataloader = DataLoader(dataset, collate_fn=dataset.collate_fn, num_workers=4, pin_memory=True, batch_size=16, shuffle=False)
# Loop-invariant: the label-word token ids do not change between batches.
label_token_ids = dataset.get_labels_tok()
pred_labels, true_labels = [], []
for batch in tqdm(dataloader):
    # reshape(-1) instead of squeeze(): squeeze() would also drop the batch
    # dimension of a final batch of size 1, turning .tolist() into a scalar and
    # breaking the `+=` list extension below.
    # NOTE(review): assumes exactly one mask position per example -- confirm.
    mask_pos = batch['mask_pos'].reshape(-1)
    with torch.no_grad():
        logits = generator(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
        ).logits

    # Select the mask row on the model's device first, so only a
    # (batch, vocab) slice is copied to the CPU rather than the full
    # (batch, seq, vocab) tensor.
    rows = torch.arange(logits.shape[0], device=logits.device)
    logits = logits[rows, mask_pos.to(logits.device)].cpu()
    pred_labels += logits[:, label_token_ids].argmax(1).tolist()
    true_labels += batch['labels'].reshape(-1).tolist()

metrics = metrics_fn(config.task_name, np.array(pred_labels), np.array(true_labels))
metrics

  0%|          | 0/26 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av



{'acc': 0.5122549019607843,
 'f1': 0.6209523809523808,
 'acc_and_f1': 0.5666036414565825}