In [1]:
import itertools
import os
import pickle
import re
import string

import pandas as pd

In [2]:
# Full experiment grid.
# NOTE(review): every name below is reassigned by the later configuration
# cell before the processing loop runs, so in a top-to-bottom run these
# values are not the ones the loop iterates over.
gen_methods = ['para', 'gen']
datasets = [
    'ag_news',
    'news_topic',
    'trec',
    'mnli',
    'yahoo',
    'tweet_eval_sent',
    'yelp',
]
seeds = [str(s) for s in range(3)]
icl_strategies = ['synth_dis', 'cos_sim']
methods = ['only', 'uniform']

In [3]:
def pkl_to_dataset(path_to_file, path_to_res_csv, random_state=None):
    """Convert pickled model generations into a shuffled (text, label) CSV.

    Each element of the pickle is indexed as ``response[0]['generated_text']``
    (the raw generated string) and ``response[1]`` (the label attached to
    every sentence extracted from that generation). Only the text after the
    last ``[/INST]`` marker is kept. It is split into lines and lower-cased.
    Enumeration prefixes such as ``"1:"`` / ``"12."`` are blanked. Lines that
    contain ``=>`` or are shorter than 15 characters are dropped as potential
    label leakage. One leading punctuation character is removed, then
    surrounding whitespace.

    Parameters
    ----------
    path_to_file : str
        Path to the input pickle. ``pickle.load`` can execute arbitrary code,
        so only pass trusted files.
    path_to_res_csv : str
        Destination CSV, written with columns ``text`` and ``label`` and
        without the index.
    random_state : int, optional
        Seed for the row shuffle. ``None`` (default) keeps the unseeded,
        non-deterministic shuffle.
    """
    marker = '[/INST]'
    # Enumeration prefixes. The dot MUST be escaped: an unescaped '.' matches
    # any character, which would erase every digit pair (e.g. "2020").
    # The lookahead leaves decimals such as "3.5" intact.
    enum_patterns = (re.compile(r'\d+:'), re.compile(r'\d+\.(?!\d)'))

    # Security: unpickling can run arbitrary code; input must be trusted.
    with open(path_to_file, 'rb') as fh:
        responses = pickle.load(fh)

    texts, labels = [], []
    for response in responses:
        generated = response[0]['generated_text']
        label = response[1]
        pos = generated.rfind(marker)
        # Skip the marker plus one separator char after it. If the marker is
        # absent, keep the whole text instead of blindly dropping 7 chars.
        gen_text = generated[pos + len(marker) + 1:] if pos != -1 else generated

        for line in gen_text.split('\n'):
            line = line.lower()
            for pat in enum_patterns:
                line = pat.sub(' ', line)
            # '=>' lines and very short lines are likely label leakage.
            if '=>' in line or len(line) < 15:
                continue
            # Drop a single leading bullet/punctuation char (line is non-empty).
            if line[0] in string.punctuation:
                line = line[1:]
            texts.append(line.strip())
            labels.append(label)

    # Text is already lower-cased above, so no further normalisation needed.
    df = pd.DataFrame({'text': texts, 'label': labels})
    shuffled = df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    shuffled.to_csv(path_to_res_csv, index=False)

In [4]:
def pkl_to_dataset_mnli(path_to_file, path_to_res_csv, random_state=None):
    """Convert pickled MNLI-style generations into a shuffled CSV.

    Each element of the pickle is indexed as ``response[0]['generated_text']``
    (the raw generated string) and ``response[1]``, a mapping whose
    ``'label'`` and ``'premise'`` values are attached to every hypothesis
    extracted from that generation. Only the text after the last ``[/INST]``
    marker is kept. It is split into lines and lower-cased, and the word
    ``hypothesis`` is removed. Enumeration prefixes (``"1:"``, ``"12."``) are
    blanked. Lines that mention ``premise``, contain ``=>``, or are shorter
    than 15 characters are dropped. One leading punctuation character is
    removed, then surrounding whitespace.

    Parameters
    ----------
    path_to_file : str
        Path to the input pickle. ``pickle.load`` can execute arbitrary code,
        so only pass trusted files.
    path_to_res_csv : str
        Destination CSV, written with columns ``premise``, ``label`` and
        ``hypothesis`` and without the index.
    random_state : int, optional
        Seed for the row shuffle. ``None`` (default) keeps the unseeded,
        non-deterministic shuffle.
    """
    marker = '[/INST]'
    # Enumeration prefixes. The dot MUST be escaped: an unescaped '.' matches
    # any character, which would erase every digit pair (e.g. "10 pm").
    # The lookahead leaves decimals such as "3.5" intact.
    enum_patterns = (re.compile(r'\d+:'), re.compile(r'\d+\.(?!\d)'))

    # Security: unpickling can run arbitrary code; input must be trusted.
    with open(path_to_file, 'rb') as fh:
        responses = pickle.load(fh)

    premises, labels, hypotheses = [], [], []
    for response in responses:
        generated = response[0]['generated_text']
        label = response[1]['label']
        premise = response[1]['premise']
        pos = generated.rfind(marker)
        # Skip the marker plus one separator char after it. If the marker is
        # absent, keep the whole text instead of blindly dropping 7 chars.
        gen_text = generated[pos + len(marker) + 1:] if pos != -1 else generated

        for line in gen_text.split('\n'):
            line = line.lower().replace('hypothesis', '')
            for pat in enum_patterns:
                line = pat.sub(' ', line)
            # Echoed premises, '=>' lines and very short lines are dropped
            # (the latter two as likely label leakage).
            if 'premise' in line or '=>' in line or len(line) < 15:
                continue
            # Drop a single leading bullet/punctuation char (line is non-empty).
            if line[0] in string.punctuation:
                line = line[1:]
            hypotheses.append(line.strip())
            labels.append(label)
            premises.append(premise)

    # Hypotheses are already lower-cased above.
    df = pd.DataFrame({'premise': premises, 'label': labels, 'hypothesis': hypotheses})
    shuffled = df.sample(frac=1, random_state=random_state).reset_index(drop=True)
    shuffled.to_csv(path_to_res_csv, index=False)

In [5]:
# Active experiment grid consumed by the processing loop below. It reassigns
# every name from the earlier configuration cell.
# Disabled full dataset list:
#   ['ag_news', 'clinc150', 'news_topic', 'trec', 'snips', 'yahoo',
#    'tweet_eval_sent', 'yelp']
gen_methods = ['para', 'gen']
datasets = ['tweet_eval_sent']
seeds = [str(s) for s in range(3)]
icl_strategies = [
    'random',
    'baseline',
    'synth_dis',
]
methods = ['only']

In [5]:
# Convert the pickle of every (generation method, dataset, seed, ICL strategy,
# method) combination into a CSV written alongside it. MNLI records carry a
# premise/label mapping, so they go through the dedicated MNLI converter.
for gen_method, dataset, seed, icl_strategy, method in itertools.product(
        gen_methods, datasets, seeds, icl_strategies, methods):
    run_dir = os.path.join(dataset, 'collected_data_100', seed, icl_strategy)
    # Input and output share one stem so the two paths cannot drift apart.
    stem = method + '_' + gen_method
    path_to_file = os.path.join(run_dir, stem + '.pkl')
    path_to_res_csv = os.path.join(run_dir, stem + '.csv')
    convert = pkl_to_dataset_mnli if dataset == 'mnli' else pkl_to_dataset
    convert(path_to_file, path_to_res_csv)