In [2]:
!which python

/st2/myung/anaconda3/envs/pytorch/bin/python


In [3]:
from pathlib import Path
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from tqdm import tqdm
import seaborn as sns
import json
import codecs
from pathlib import Path
from collections import defaultdict
%cd ..  

/v6/myung/iclr/purify_lm


In [4]:
def read_sentiment_results(models_dict, max_gens=None):
    res = {}
    for model in tqdm(models_dict):
        df = pd.read_json(models_dict[model], lines=True)[:max_gens]
        sentiment_labels = df.generations.apply(lambda x: [y['label'] for y in x])
        positive_proportion = sentiment_labels.apply(lambda x: np.sum([1 for y in x if y == 'POSITIVE'])/len(x))
        res[model] = {
            'positive_proportion': positive_proportion.mean()
        }
        # read automatic evaluation
        """
        with open(Path(os.path.dirname(models_dict[model])) / 'eval_results.txt', 'r') as fo:
            for i, line in enumerate(fo):
                if i < 3:
                    print(line)
                    dist_n = float(line.rstrip().replace(f'dist-{i+1} = ', ''))
                    res[model][f'dist-{i+1}'] = dist_n
                elif i == 3:
                    print(line)
                    ppl = float(line.replace('perplexity = ', '').strip('\n').strip())
                    res[model]['perplexity'] = ppl
        """
    return res

In [5]:
def weighted_average(neutral_prompts_res, adversarial_prompts_res, key):
    """
    return weighted average of dist-n or perplexity value across neural prompts (5k) and adversarial prompts (2.5k)
    """
    return np.average([neutral_prompts_res[model][key], adversarial_prompts_res[model][key]], weights=[2,1])

## positive steering

In [63]:
# results corresponding to the top half of Table 3

NEUTRAL_DIR = Path('generations/sentiment/neutral_prompts/')
NEG_DIR = Path('generations/sentiment/negative_prompts/')
'''
    'GPT-2': {
        'neutral_path': NEUTRAL_DIR / 'gpt2/prompted_gens_gpt2.jsonl',
        'neg_path': NEG_DIR / 'gpt2/prompted_gens_gpt2.jsonl',
    },
    'PPLM': {
        'neutral_path': NEUTRAL_DIR / 'pplm/positive/prompted_gens_pplm.jsonl',
        'neg_path': NEG_DIR / 'pplm/prompted_gens_pplm.jsonl'
    },
    'DAPT': {
        'neutral_path': NEUTRAL_DIR / 'dapt/positive/prompted_gens_gpt2.jsonl',
        'neg_path': NEG_DIR / 'dapt/prompted_gens_gpt2.jsonl',
    },
    'GeDi': {
        'neutral_path': NEUTRAL_DIR / 'gedi/positive/prompted_gens_gedi.jsonl',
        'neg_path': NEG_DIR / 'gedi/prompted_gens_gedi.jsonl'
    },
    'CTRL': {
        'neutral_path': NEUTRAL_DIR / 'ctrl/positive/prompted_gens_ctrl.jsonl',
        'neg_path': NEG_DIR / 'ctrl/prompted_gens_ctrl.jsonl'
    },
    'Expert': {
        'neutral_path': NEUTRAL_DIR / 'expert/positive/prompted_gens_gpt2.jsonl',
        'neg_path': NEG_DIR / 'expert/prompted_gens_gpt2.jsonl',
    },
    #'DExperts (anti-only)': {
    #    'neutral_path': NEUTRAL_DIR / 'dexperts_anti-only/a-2.0/prompted_gens_dexperts.jsonl',
    #    'neg_path': NEG_DIR / 'dexperts_anti-only/a-2.0/prompted_gens_dexperts.jsonl',
    #},
    'DExperts (small)': {
        'neutral_path': NEUTRAL_DIR / 'dexperts/small_experts/positive/prompted_gens_dexperts.jsonl',
        'neg_path': NEG_DIR / 'dexperts/small_experts/prompted_gens_dexperts.jsonl'
    },
    #'DExperts (medium)': {
    #    'neutral_path': NEUTRAL_DIR / 'dexperts/medium_experts/positive/prompted_gens_dexperts.jsonl',
    #    'neg_path': NEG_DIR / 'dexperts/small_experts/prompted_gens_dexperts.jsonl'
    #},

    'fuse_style_ep10': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_style_ep10/positive/prompted_gens_style-gpt2-none.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_style_ep10/positive/prompted_gens_style-gpt2-none.jsonl'
    },
    
    'fuse_style_ep20': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_style_ep20/positive/prompted_gens_style-gpt2-none.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_style_ep20/positive/prompted_gens_style-gpt2-none.jsonl'
    },
    'fuse_style_ep30': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_style_ep30/positive/prompted_gens_style-gpt2-none.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_style_ep30/positive/prompted_gens_style-gpt2-none.jsonl'
    },
    'fuse_rev_style_contrast0.25_ep100': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_rev_style_contrast0.25_ep100/positive/prompted_gens_style-gpt2-none.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_rev_style_contrast0.25_ep100/positive/prompted_gens_style-gpt2-none.jsonl'
    },
    'fuse_rev_style_pred_contrast1.0_ep30': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_rev_style_pred_contrast1.0_ep30/positive/prompted_gens_style-gpt2-attr.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_rev_style_pred_contrast1.0_ep30/positive/prompted_gens_style-gpt2-attr.jsonl'
    },
    'fuse_rev_style_pred_contrast1.0_ep50': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_rev_style_pred_contrast1.0_ep50/positive/prompted_gens_style-gpt2-attr.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_rev_style_pred_contrast1.0_ep50/positive/prompted_gens_style-gpt2-attr.jsonl'
    },
    'fuse_rev_style_pred_contrast1.0_ep30_past': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_rev_style_pred_contrast1.0_ep30/positive/prompted_gens_style-gpt2-attr_past.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_rev_style_pred_contrast1.0_ep30/positive/prompted_gens_style-gpt2-attr_past.jsonl'
    },
    'fuse_rev_style_pred_contrast1.0_ep50_past': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_rev_style_pred_contrast1.0_ep50/positive/prompted_gens_style-gpt2-attr_past.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_rev_style_pred_contrast1.0_ep50/positive/prompted_gens_style-gpt2-attr_past.jsonl'
    },
'''
models = {
    'fuse_rev_style_pred_contrast1.0_ep30_with_project': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_rev_style_pred_contrast1.0_ep30_with_project/positive/prompted_gens_style-gpt2-attr.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_rev_style_pred_contrast1.0_ep30_with_project/positive/prompted_gens_style-gpt2-attr.jsonl'
    },
    'fuse_rev_style_pred_contrast1.0_ep30_with_project_past': {
        'neutral_path':'our_generations/prompted_sentiment-10k/neutral_prompts/fuse_rev_style_pred_contrast1.0_ep30_with_project/positive/prompted_gens_style-gpt2-attr_past.jsonl',
        'neg_path': 'our_generations/prompted_sentiment-10k/negative_prompts/fuse_rev_style_pred_contrast1.0_ep30_with_project/positive/prompted_gens_style-gpt2-attr_past.jsonl'
    },

}


In [64]:
# read sentiment control results
neutral_prompts_res = read_sentiment_results({m: p['neutral_path'] for m,p in models.items()})
neg_prompts_res = read_sentiment_results({m: p['neg_path'] for m,p in models.items()})

100%|██████████| 2/2 [00:00<00:00,  2.17it/s]
100%|██████████| 2/2 [00:00<00:00,  5.05it/s]


In [65]:
positive_steering_res = {}
#assert set(neutral_prompts_res.keys()) == set(neg_prompts_res.keys())
#print(neutral_prompts_res)
print(neg_prompts_res)
for model in neg_prompts_res.keys():
    positive_steering_res[model] = {
        'neutral_prompts': neutral_prompts_res[model]['positive_proportion']*100,
        'neg_prompts': neg_prompts_res[model]['positive_proportion']*100,
        #'dist-1': weighted_average(neutral_prompts_res, neg_prompts_res, 'dist-1'),
        #'dist-2': weighted_average(neutral_prompts_res, neg_prompts_res, 'dist-2'),
        #'dist-3': weighted_average(neutral_prompts_res, neg_prompts_res, 'dist-3'),
        #'perplexity': weighted_average(neutral_prompts_res, neg_prompts_res, 'perplexity'),
    }
print(positive_steering_res)

{'fuse_rev_style_pred_contrast1.0_ep30_with_project': {'positive_proportion': 0.5632}, 'fuse_rev_style_pred_contrast1.0_ep30_with_project_past': {'positive_proportion': 0.3881280000000019}}
{'fuse_rev_style_pred_contrast1.0_ep30_with_project': {'neutral_prompts': 75.78, 'neg_prompts': 56.32}, 'fuse_rev_style_pred_contrast1.0_ep30_with_project_past': {'neutral_prompts': 70.50880000000045, 'neg_prompts': 38.812800000000195}}


In [66]:
pd.DataFrame(positive_steering_res).transpose().sort_values(by='neg_prompts', ascending=True).round(2)

Unnamed: 0,neutral_prompts,neg_prompts
fuse_rev_style_pred_contrast1.0_ep30_with_project_past,70.51,38.81
fuse_rev_style_pred_contrast1.0_ep30_with_project,75.78,56.32


# negative steering 

In [None]:
# results corresponding to the bottom  half of Table 3

NEUTRAL_DIR = Path('generations/sentiment/neutral_prompts/')
POS_DIR = Path('generations/sentiment/positive_prompts/')

models = {
    'GPT-2': {
        'neutral_path': NEUTRAL_DIR / 'gpt2/prompted_gens_gpt2.jsonl',
        'pos_path': POS_DIR / 'gpt2/prompted_gens_gpt2.jsonl',
    },
    'PPLM': {
        'neutral_path': NEUTRAL_DIR / 'pplm/negative/prompted_gens_pplm.jsonl',
        'pos_path': POS_DIR / 'pplm/prompted_gens_pplm.jsonl'
    },
    'DAPT': {
        'neutral_path': NEUTRAL_DIR / 'dapt/negative/prompted_gens_gpt2.jsonl',
        'pos_path': POS_DIR / 'dapt/prompted_gens_gpt2.jsonl',
    },
    'GeDi': {
        'neutral_path': NEUTRAL_DIR / 'gedi/negative/prompted_gens_gedi.jsonl',
        'pos_path': POS_DIR / 'gedi/prompted_gens_gedi.jsonl'
    },
    'CTRL': {
        'neutral_path': NEUTRAL_DIR / 'ctrl/negative/prompted_gens_ctrl.jsonl',
        'pos_path': POS_DIR / 'ctrl/prompted_gens_ctrl.jsonl'
    },
    'Expert': {
        'neutral_path': NEUTRAL_DIR / 'expert/negative/prompted_gens_gpt2.jsonl',
        'pos_path': POS_DIR / 'expert/prompted_gens_gpt2.jsonl',
    },
    'DExperts (anti-only)': {
        'neutral_path': NEUTRAL_DIR / 'dexperts_anti-only/a--2.0/prompted_gens_dexperts.jsonl',
        'pos_path': POS_DIR / 'dexperts_anti-only/a--2.0/prompted_gens_dexperts.jsonl',
    },
    'DExperts (large)': {
        'neutral_path': NEUTRAL_DIR / 'dexperts/large_experts/negative/prompted_gens_dexperts.jsonl',
        'pos_path': POS_DIR / 'dexperts/large_experts/prompted_gens_dexperts.jsonl'
    },
    'DExperts (medium)': {
        'neutral_path': NEUTRAL_DIR / 'dexperts/medium_experts/negative/prompted_gens_dexperts.jsonl',
        'pos_path': POS_DIR / 'dexperts/medium_experts/prompted_gens_dexperts.jsonl'
    },
    'DExperts (small)': {
        'neutral_path': NEUTRAL_DIR / 'dexperts/small_experts/negative/prompted_gens_dexperts.jsonl',
        'pos_path': POS_DIR / 'dexperts/small_experts/prompted_gens_dexperts.jsonl'
    }
}

In [None]:
# read sentiment control results
neutral_prompts_res = read_sentiment_results({m: p['neutral_path'] for m,p in models.items()})
pos_prompts_res = read_sentiment_results({m: p['pos_path'] for m,p in models.items()})

In [None]:
negative_steering_res = {}
assert set(neutral_prompts_res.keys()) == set(pos_prompts_res.keys())
for model in neutral_prompts_res.keys():
    negative_steering_res[model] = {
        'neutral_prompts': neutral_prompts_res[model]['positive_proportion']*100,
        'pos_prompts': pos_prompts_res[model]['positive_proportion']*100,
        'dist-1': weighted_average(neutral_prompts_res, pos_prompts_res, 'dist-1'),
        'dist-2': weighted_average(neutral_prompts_res, pos_prompts_res, 'dist-2'),
        'dist-3': weighted_average(neutral_prompts_res, pos_prompts_res, 'dist-3'),
        'perplexity': weighted_average(neutral_prompts_res, pos_prompts_res, 'perplexity'),
    }

In [None]:
pd.DataFrame(negative_steering_res).transpose().sort_values(by='neutral_prompts', ascending=False).round(2)

## hyperparameter search

In [None]:
GENS_DIR = Path('generations/sentiment/neutral_prompts/')
sizes = ['large', 'medium', 'small']
size_dict = {size: defaultdict(dict) for size in sizes}

for size in sizes:
    dexperts_dir = Path(f'dexperts/{size}_experts')

    for folder in os.listdir(GENS_DIR / dexperts_dir):
        splits = folder.split('-')
        if len(splits) >= 2:
            a = '-' + splits[-1] if len(splits) == 3 else splits[-1]
            a = float(a)
            size_dict[size][a] = GENS_DIR / dexperts_dir / f'{folder}/prompted_gens_dexperts.jsonl'

In [None]:
small_res = read_sentiment_results(size_dict['small'], max_gens=1000)
small_res_df = pd.DataFrame(small_res).transpose()
medium_res = read_sentiment_results(size_dict['medium'], max_gens=1000)
medium_res_df = pd.DataFrame(medium_res).transpose()
large_res = read_sentiment_results(size_dict['large'], max_gens=1000)
large_res_df = pd.DataFrame(large_res).transpose()

In [None]:
def plot(res_df, label, axes):
    ax = axes[0]
    pos_hyperparam_res_df = res_df.loc[res_df.index >= 2.0]
    sns.lineplot(x=pos_hyperparam_res_df['perplexity'], y=pos_hyperparam_res_df['positive_proportion'], ax=ax, marker='o', dashes=False, label=label)
    for i, row in pos_hyperparam_res_df.iterrows():
        x = row['perplexity']
        y = row['positive_proportion']
        ax.text(x=x, y=y, s=i)
    
    ax = axes[1]
    neg_hyperparam_res_df = res_df.loc[res_df.index <= -2.0]
    sns.lineplot(x=neg_hyperparam_res_df['perplexity'], y=neg_hyperparam_res_df['positive_proportion'], ax=ax, marker='o', dashes=False, label=label)
    for i, row in neg_hyperparam_res_df.iterrows():
        x = row['perplexity']
        y = row['positive_proportion']
        ax.text(x=x, y=y, s=i)

In [None]:
plt.figure()
fig, axes = plt.subplots(2, 1, figsize=(8,6))
plt.style.use('seaborn-white')
fig.tight_layout()

plot(small_res_df, label='Small experts', axes=axes)
plot(medium_res_df, label='Medium experts', axes=axes)
plot(large_res_df, label='Large experts', axes=axes)

ax = axes[0]
ax.set_xlabel(' ')
ax.set_ylabel('% Positive', fontsize=13)
ax.set_title('Positive steering')
ax.legend()

ax = axes[1]
ax.set_xlabel('Perplexity', fontsize=13)
ax.set_ylabel('% Positive', fontsize=13)
ax.set_title('Negative steering')

plt.tight_layout()
plt.savefig('figures/sentiment_v_perplexity.png')