In [13]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

from dataloader import WikipediaDataModule
from model import AbstractModel, CoordinateAscentModel
from utils import get_profile_embeddings_by_model_key

import argparse
import collections
import glob
import os
import re

import datasets
import pandas as pd
import torch
import transformers
from tqdm import tqdm


from model_cfg import model_paths_dict

datasets.utils.logging.set_verbosity_error()


num_cpus = len(os.sched_getaffinity(0))


def get_profile_embeddings(model_key: str):
    profile_embeddings = get_profile_embeddings_by_model_key(model_key=model_key)

    print("concatenating train, val, and test profile embeddings")
    all_profile_embeddings = torch.cat(
        (profile_embeddings['test'], profile_embeddings['val'], profile_embeddings['train']), dim=0
    )

    print("all_profile_embeddings:", all_profile_embeddings.shape)
    return all_profile_embeddings


def get_output_folder_by_model_key(model_key: str) -> str:
    adv_csvs_folder = os.path.normpath(
        os.path.join(
            os.path.abspath(__file__), os.pardir, os.pardir, 'adv_csvs_full'
        )
    )
    return os.path.join(adv_csvs_folder, model_key)

def load_adv_csv(dm: WikipediaDataModule) -> pd.DataFrame:
    # Load all the stuff
    adv_df = None
    for model_name in ['model_3_1', 'model_3_2', 'model_3_3__placeholder', 'model_3_4']:
        adv_csvs_folder = os.path.normpath(
            os.path.join(
                os.getcwd(), os.pardir, 'adv_csvs_full'
            )
        )
        print('adv_csvs_folder', adv_csvs_folder)
        csv_filenames = glob.glob(
            os.path.join(
                adv_csvs_folder,
                f'{model_name}/results__b_1__k_1__n_1000.csv'
            )
        )
        print(model_name, csv_filenames)
        for filename in csv_filenames:
            df = pd.read_csv(filename)
            df['model_name'] = re.search(r'adv_csvs_full/(model_\d.+)/.+.csv', filename).group(1)
            df['k'] = re.search(r'adv_csvs_full/.+/.+__k_(\d+)__.+.csv', filename).group(1)
            df['i'] = df.index

            df = df[df['result_type'] == 'Successful']

            mini_df = df[['perturbed_text', 'model_name', 'i', 'k']]
            mini_df = mini_df.iloc[:100]
            
            if adv_df is None:
                adv_df = mini_df
            else:
                adv_df = pd.concat((adv_df, mini_df), axis=0)
    
    # Load baseline redacted data
    mini_val_dataset = dm.test_dataset[:1000]
    ner_df = pd.DataFrame(
        columns=['perturbed_text'],
        data=mini_val_dataset['document_redact_ner_bert']
    )
    ner_df['model_name'] = 'named_entity'
    ner_df['i'] = ner_df.index
        
    lex_df = pd.DataFrame(
        columns=['perturbed_text'],
        data=mini_val_dataset['document_redact_lexical']
    )
    lex_df['model_name'] = 'lexical'
    lex_df['i'] = lex_df.index

    # Combine both adversarial and baseline redacted data
    baseline_df = pd.concat((lex_df, ner_df), axis=0)
    baseline_df['k'] = 0
    full_df = pd.concat((adv_df, baseline_df), axis=0)

    # Put newlines back
    full_df['perturbed_text'] = full_df['perturbed_text'].apply(lambda s: s.replace('<SPLIT>', '\n'))

    # Standardize mask tokens
    full_df['perturbed_text'] = full_df['perturbed_text'].apply(lambda s: s.replace('[MASK]', dm.mask_token))
    full_df['perturbed_text'] = full_df['perturbed_text'].apply(lambda s: s.replace('<mask>', dm.mask_token))

    return full_df


def get_adv_predictions(model_key: str):
    checkpoint_path = model_paths_dict[model_key]
    assert isinstance(checkpoint_path, str), f"invalid checkpoint_path {checkpoint_path} for {model_key}"
    print(f"running eval on {model_key} loaded from {checkpoint_path}")
    model = CoordinateAscentModel.load_from_checkpoint(
        checkpoint_path
    )

    print(f"loading data with {num_cpus} CPUs")
    dm = WikipediaDataModule(
        document_model_name_or_path=model.document_model_name_or_path,
        profile_model_name_or_path=model.profile_model_name_or_path,
        dataset_name='wiki_bio',
        dataset_train_split='train[:256]',
        dataset_val_split='val[:256]',
        dataset_test_split='test[:100%]',
        dataset_version='1.2.0',
        num_workers=num_cpus,
        train_batch_size=256,
        eval_batch_size=256,
        max_seq_length=128,
        sample_spans=False,
    )
    dm.setup("fit")

    all_profile_embeddings = get_profile_embeddings(model_key=model_key).cuda()

    model.document_model.eval()
    model.document_model.cuda()
    model.document_embed.eval()
    model.document_embed.cuda()

    adv_csv = load_adv_csv(dm=dm)

    topk_values = []
    topk_idxs = []
    batch_size = 256
    i = 0
    while i < len(adv_csv):
        ex = adv_csv.iloc[i:i+batch_size]
        test_batch = dm.document_tokenizer.batch_encode_plus(
            ex['perturbed_text'].tolist(),
            max_length=dm.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        test_batch = {
            f'perturbed_text__{k}': v for k,v in test_batch.items()
        }
        with torch.no_grad():
            document_embeddings = model.forward_document(batch=test_batch, document_type='perturbed_text')
            document_to_profile_logits = document_embeddings @ all_profile_embeddings.T
            document_to_profile_probs = document_to_profile_logits.softmax(dim=1)
            topk_100 = document_to_profile_probs.topk(100)
            topk_values.append(topk_100.values)
            topk_idxs.append(topk_100.indices)

        i += batch_size
    
    adv_csv['pred_topk_values'] = torch.cat(topk_values, dim=0).cpu().tolist()
    adv_csv['pred_topk_idxs'] = torch.cat(topk_idxs, dim=0).cpu().tolist()
    return adv_csv



In [14]:
roberta_roberta_predictions = get_adv_predictions(model_key='model_3_3__placeholder')

running attack on model_3_3__placeholder loaded from /home/jxm3/research/deidentification/unsupervised-deidentification/saves/ca__roberta__dropout_0.5_1.0_0.0__e3072__ls0.1/deid-wikibio-4_default/1c9464tp_750/checkpoints/epoch=58-step=134342-idf_total.ckpt


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaMod

Initialized model with learning_rate = 0.0001 and patience 6
loading data with 8 CPUs
Initializing WikipediaDataModule with num_workers = 8 and mask token `<mask>`
loading wiki_bio[1.2.0] split train[:256]
loading wiki_bio[1.2.0] split val[:256]
loading wiki_bio[1.2.0] split test[:100%]
                        >> loaded 582659 train embeddings from /home/jxm3/research/deidentification/unsupervised-deidentification/embeddings/profile/model_3_3__placeholder/train.pkl
>> loaded 72831 val embeddings from /home/jxm3/research/deidentification/unsupervised-deidentification/embeddings/profile/model_3_3__placeholder/val.pkl
>> loaded 72831 test embeddings from /home/jxm3/research/deidentification/unsupervised-deidentification/embeddings/profile/model_3_3__placeholder/test.pkl
concatenating train, val, and test profile embeddings
all_profile_embeddings: torch.Size([728321, 3072])
adv_csvs_folder /home/jxm3/research/deidentification/unsupervised-deidentification/adv_csvs_full
model_3_1 ['/home/jx

In [15]:
roberta_tapas_predictions = get_adv_predictions(model_key='model_3_2')

running attack on model_3_2 loaded from /home/jxm3/research/deidentification/unsupervised-deidentification/saves/ca__roberta__tapas__dropout_-1.0_1.0_0.0__e3072__ls0.1/deid-wikibio-4_lightning_logs/ojgxa1tf_6/checkpoints/epoch=65-step=150282-idf_total.ckpt


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Initialized model with learning_rate = 0.0001 and patience 6
loading data with 8 CPUs
Initializing WikipediaDataModule with num_workers = 8 and mask token `<mask>`
loading wiki_bio[1.2.0] split train[:256]
loading wiki_bio[1.2.0] split val[:256]
loading wiki_bio[1.2.0] split test[:100%]
                        >> loaded 582659 train embeddings from /home/jxm3/research/deidentification/unsupervised-deidentification/embeddings/profile/model_3_2/train.pkl
>> loaded 72831 val embeddings from /home/jxm3/research/deidentification/unsupervised-deidentification/embeddings/profile/model_3_2/val.pkl
>> loaded 72831 test embeddings from /home/jxm3/research/deidentification/unsupervised-deidentification/embeddings/profile/model_3_2/test.pkl
concatenating train, val, and test profile embeddings
all_profile_embeddings: torch.Size([728321, 3072])
adv_csvs_folder /home/jxm3/research/deidentification/unsupervised-deidentification/adv_csvs_full
model_3_1 ['/home/jxm3/research/deidentification/unsupervis

In [18]:
roberta_roberta_predictions[roberta_roberta_predictions['model_name'] == 'model_3_3__placeholder']

Unnamed: 0,perturbed_text,model_name,i,k,pred_topk_values,pred_topk_idxs
0,"<mask> shenoff <mask> ( born february <mask> ,...",model_3_3__placeholder,0,1,"[0.11052840203046799, 0.05406518280506134, 0.0...","[467718, 158460, 530731, 401411, 685027, 68917..."
1,<mask> <mask> ( born <mask> <mask> <mask> in r...,model_3_3__placeholder,1,1,"[0.5141785740852356, 0.09889905899763107, 0.04...","[39633, 1, 627677, 467415, 168807, 305245, 404..."
2,<mask> <mask> ( born <mask> <mask> <mask> <mas...,model_3_3__placeholder,2,1,"[0.04162168875336647, 0.017952967435121536, 0....","[411731, 469648, 333439, 265113, 178437, 21793..."
3,john <mask> jack '' <mask> ( 21 february 1869 ...,model_3_3__placeholder,3,1,"[0.18395081162452698, 0.15015146136283875, 0.1...","[382939, 455178, 360496, 613292, 3, 296495, 20..."
4,"<mask> <mask> <mask> , ( born 7th july 1979 ) ...",model_3_3__placeholder,4,1,"[0.2873842716217041, 0.20317216217517853, 0.14...","[248452, 4, 445867, 331079, 646469, 428111, 13..."
...,...,...,...,...,...,...
97,<mask> <mask> hildebert <mask> ( ; born 6 marc...,model_3_3__placeholder,97,1,"[0.1861652284860611, 0.17227914929389954, 0.03...","[46645, 97, 649379, 524477, 200807, 658420, 27..."
98,<mask> <mask> ( born <mask> <mask> <mask> ) is...,model_3_3__placeholder,98,1,"[0.07223951816558838, 0.06788275390863419, 0.0...","[292621, 452271, 643379, 637442, 278902, 98, 4..."
99,<mask> bosisio ( born <mask> <mask> <mask> ) i...,model_3_3__placeholder,99,1,"[0.47312384843826294, 0.4663839340209961, 0.01...","[478090, 99, 131467, 307770, 492295, 345943, 4..."
100,"<mask> <mask> ( born <mask> <mask> , <mask> ) ...",model_3_3__placeholder,100,1,"[0.20568141341209412, 0.11684911698102951, 0.0...","[428736, 610597, 429695, 292739, 512155, 100, ..."


In [19]:
roberta_roberta_predictions.groupby('model_name').count()

Unnamed: 0_level_0,perturbed_text,i,k,pred_topk_values,pred_topk_idxs
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
lexical,1000,1000,1000,1000,1000
model_3_1,100,100,100,100,100
model_3_2,100,100,100,100,100
model_3_3__placeholder,100,100,100,100,100
model_3_4,100,100,100,100,100
named_entity,1000,1000,1000,1000,1000


In [27]:
new_model_name = {
    'model_3_1': 'roberta_tapas__no_masking',
    'model_3_2': 'roberta_tapas',
    'model_3_3__placeholder': 'roberta_roberta',
    'model_3_4': 'pmlm_tapas'
}
roberta_roberta_predictions['model_name'] = roberta_roberta_predictions['model_name'].apply(lambda s: new_model_name.get(s,s))

In [28]:
out_df = roberta_roberta_predictions.rename(columns={'pred_topk_values': 'roberta_roberta__pred_topk_values', 'pred_topk_idxs': 'roberta_roberta__pred_topk_idxs'})
out_df['roberta_tapas__pred_topk_values'] = roberta_tapas_predictions['pred_topk_values']
out_df['roberta_tapas__pred_topk_idxs'] = roberta_tapas_predictions['pred_topk_idxs']
out_df.head()

Unnamed: 0,perturbed_text,model_name,i,k,roberta_roberta__pred_topk_values,roberta_roberta__pred_topk_idxs,roberta_tapas__pred_topk_values,roberta_tapas__pred_topk_idxs
0,"leonard shenoff <mask> ( born february 12 , <m...",roberta_tapas__no_masking,0,1,"[0.9138374924659729, 0.019291497766971588, 0.0...","[0, 578457, 120349, 267136, 532648, 719788, 49...","[0.8143362402915955, 0.022000230848789215, 0.0...","[0, 5473, 685461, 144197, 267136, 724119, 1798..."
1,<mask> <mask> ( born 25 august <mask> in rhège...,roberta_tapas__no_masking,1,1,"[0.9647166728973389, 0.0008567320764996111, 0....","[1, 100573, 39633, 569950, 176364, 255263, 395...","[0.18824410438537598, 0.14044655859470367, 0.1...","[627677, 467415, 1, 39633, 193135, 646267, 728..."
2,<mask> <mask> ( born 14 june <mask> in dvůr kr...,roberta_tapas__no_masking,2,1,"[0.9932199120521545, 0.0012878905981779099, 0....","[2, 333439, 654853, 507896, 122539, 224290, 50...","[0.5528936982154846, 0.2653157114982605, 0.033...","[2, 456329, 333439, 507309, 534720, 4831, 4452..."
3,john `` <mask> '' <mask> ( 21 february <mask> ...,roberta_tapas__no_masking,3,1,"[0.7518554925918579, 0.04164363816380501, 0.00...","[3, 455178, 382939, 202957, 360496, 296495, 45...","[0.8678763508796692, 0.027283597737550735, 0.0...","[3, 202957, 382939, 296495, 516160, 34822, 345..."
4,"william <mask> <mask> , ( born 7th july 1979 )...",roberta_tapas__no_masking,4,1,"[0.41385799646377563, 0.09564604610204697, 0.0...","[4, 248452, 254701, 391864, 331079, 428111, 44...","[0.8090968728065491, 0.08909827470779419, 0.00...","[4, 520062, 254701, 128242, 445867, 61930, 365..."


In [32]:
out_df['roberta_roberta__was_correct'] = out_df.apply(lambda row: row['i'] == row['roberta_roberta__pred_topk_idxs'][0], axis=1)

In [33]:
out_df['roberta_tapas__was_correct'] = out_df.apply(lambda row: row['i'] == row['roberta_tapas__pred_topk_idxs'][0], axis=1)

In [34]:
out_df.groupby('model_name').mean()

Unnamed: 0_level_0,i,roberta_roberta__was_correct,roberta_tapas__was_correct
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
lexical,499.5,0.164,0.224
named_entity,499.5,0.674,0.557
pmlm_tapas,50.37,0.81,0.5
roberta_roberta,50.56,0.0,0.51
roberta_tapas,51.09,0.77,0.0
roberta_tapas__no_masking,50.17,0.89,0.78


In [35]:
out_df.to_csv('model_comparison_predictions.csv')

In [36]:
!realpath model_comparison_predictions.csv

/home/jxm3/research/deidentification/unsupervised-deidentification/notebooks/model_comparison_predictions.csv


In [45]:
import numpy as np
roberta_tapas__predictions_for__pmlm_tapas = np.vstack(out_df[out_df['model_name'] == 'pmlm_tapas']['roberta_tapas__pred_topk_idxs'])
for i in range(len(roberta_tapas__predictions_for__pmlm_tapas)):
    print(i,'\t', roberta_tapas__predictions_for__pmlm_tapas[i,:5])

0 	 [     0 157382 202669 453258 647283]
1 	 [627677 467415      1  39633 193135]
2 	 [456329      2 333439 507309 445282]
3 	 [     3 382939 345455 176246 202957]
4 	 [     4 520062 254701 128242 445867]
5 	 [     5 399211 357441 406264 553956]
6 	 [396740 447046 170143 210733 284738]
7 	 [     7 395200 191185 320731 367417]
8 	 [601929 461665 176324 354936  30302]
9 	 [     9 317732 235567 286404  18112]
10 	 [    10 711653 708314 148504 408799]
11 	 [530409  41750 380123 512320 668322]
12 	 [569603 338498     12  19054 285513]
13 	 [451541     13 348456 285891 123096]
14 	 [    14 229972  19055 398523 323485]
15 	 [    15 308063 604486 480493 408304]
16 	 [    16 448666 269221 379587 252302]
17 	 [383229 571143 374161     17   8379]
18 	 [225774 256563 698969 670183     18]
19 	 [    19 436791  58874  87961 426805]
20 	 [    20 484930 160605 483058 524858]
21 	 [441863     21 567151 614879 547715]
22 	 [492087 139891 105389 126565 612756]
23 	 [    23 384675  18320 653767 111297]
24

In [46]:
import datasets

d = datasets.load_dataset("wiki_bio")

  0%|          | 0/3 [00:00<?, ?it/s]