In [None]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

from dataloader import WikipediaDataModule
from model import AbstractModel, CoordinateAscentModel
from utils import get_profile_embeddings_by_model_key

import argparse
import collections
import glob
import os
import re

import datasets
import pandas as pd
import torch
import transformers
from tqdm import tqdm


from model_cfg import model_paths_dict

# Restrict the `datasets` library's own logging to error-level messages.
datasets.utils.logging.set_verbosity_error()


# Number of CPUs available to this process (used as the dataloader worker count).
# os.sched_getaffinity only exists on some platforms (e.g. Linux); elsewhere fall
# back to the machine-wide CPU count, and to 1 if even that is unknown.
if hasattr(os, 'sched_getaffinity'):
    num_cpus = len(os.sched_getaffinity(0))
else:
    num_cpus = os.cpu_count() or 1


def get_profile_embeddings(model_key: str):
    """Return every profile embedding for ``model_key`` as a single tensor.

    The per-split embeddings are fetched with
    ``get_profile_embeddings_by_model_key`` and concatenated along dim 0 in the
    order test, val, train, so row indices into the result follow that order.
    """
    embeddings_by_split = get_profile_embeddings_by_model_key(model_key=model_key)

    print("concatenating train, val, and test profile embeddings")
    # Test rows come first: the concatenation order is test -> val -> train.
    split_order = ('test', 'val', 'train')
    stacked = torch.cat(
        tuple(embeddings_by_split[split] for split in split_order), dim=0
    )

    print("all_profile_embeddings:", stacked.shape)
    return stacked


def get_output_folder_by_model_key(model_key: str) -> str:
    """Return the ``adv_csvs_full/<model_key>`` output folder path.

    When ``__file__`` is defined (script execution) the folder is resolved as a
    sibling of the directory containing this file. Notebook kernels do not define
    ``__file__``, so in that case the current working directory is used as the
    anchor instead, the same way ``load_adv_csv`` locates ``adv_csvs_full``.

    Args:
        model_key: Name of the model; used as the final path component.

    Returns:
        The normalized ``adv_csvs_full`` folder joined with ``model_key``.
    """
    try:
        # <file>/.. is the directory that contains this file.
        anchor_dir = os.path.join(os.path.abspath(__file__), os.pardir)
    except NameError:
        # No __file__ inside a notebook; fall back to the working directory.
        anchor_dir = os.getcwd()

    adv_csvs_folder = os.path.normpath(
        os.path.join(anchor_dir, os.pardir, 'adv_csvs_full')
    )
    return os.path.join(adv_csvs_folder, model_key)

def load_adv_csv(dm: "WikipediaDataModule") -> pd.DataFrame:
    """Collect adversarially redacted and baseline-redacted documents in one frame.

    Adversarial rows are read from
    ``../adv_csvs_full/<model>/results__b_1__k_1__n_1000.csv`` (relative to the
    current working directory) for each hard-coded model name. Only rows whose
    ``result_type`` is ``'Successful'`` are kept, and at most the first 100 of
    those per file. Baseline rows are the ``document_redact_lexical`` and
    ``document_redact_ner_bert`` fields of the first 1000 entries of
    ``dm.test_dataset``.

    Args:
        dm: Data module providing ``test_dataset`` and ``mask_token``.

    Returns:
        A DataFrame with columns ``perturbed_text``, ``model_name``, ``i`` and
        ``k``. ``i`` is the row's index within its source (CSV row index or
        dataset index). ``k`` is the string parsed from the CSV filename for
        adversarial rows and the integer ``0`` for baseline rows. The index is
        not reset, so index labels repeat across sources.

    Raises:
        ValueError: If a CSV filename contains no ``__k_<digits>__`` marker.
    """
    model_names = ['model_3_1', 'model_3_2', 'model_3_3__placeholder', 'model_3_4']
    csv_basename = 'results__b_1__k_1__n_1000.csv'
    kept_columns = ['perturbed_text', 'model_name', 'i', 'k']

    # The CSV root is the same for every model, so resolve (and report) it once.
    adv_csvs_folder = os.path.normpath(
        os.path.join(os.getcwd(), os.pardir, 'adv_csvs_full')
    )
    print('adv_csvs_folder', adv_csvs_folder)

    # Gather every piece in a list and concatenate once at the end.
    frames = []
    for model_name in model_names:
        csv_filenames = glob.glob(
            os.path.join(adv_csvs_folder, model_name, csv_basename)
        )
        print(model_name, csv_filenames)
        for filename in csv_filenames:
            # Parse k from the basename only, so the path separator and parent
            # directory names cannot affect the match.
            k_match = re.search(r'__k_(\d+)__', os.path.basename(filename))
            if k_match is None:
                raise ValueError(f"cannot parse k from csv filename {filename}")

            df = pd.read_csv(filename)
            # The file lives in <adv_csvs_folder>/<model_name>/, so the loop
            # variable is the model name.
            df['model_name'] = model_name
            df['k'] = k_match.group(1)
            # Record the original CSV row index before filtering.
            df['i'] = df.index

            successful = df.loc[df['result_type'] == 'Successful', kept_columns]
            frames.append(successful.iloc[:100])

    # Baseline redactions, taken from the test split.
    mini_test_dataset = dm.test_dataset[:1000]
    baselines = (
        ('lexical', 'document_redact_lexical'),
        ('named_entity', 'document_redact_ner_bert'),
    )
    for baseline_name, field in baselines:
        baseline_df = pd.DataFrame(
            columns=['perturbed_text'],
            data=mini_test_dataset[field]
        )
        baseline_df['model_name'] = baseline_name
        baseline_df['i'] = baseline_df.index
        baseline_df['k'] = 0
        frames.append(baseline_df)

    # Adversarial rows first, then lexical, then named-entity baselines.
    full_df = pd.concat(frames, axis=0)

    # Put newlines back, then standardize both mask spellings to dm.mask_token.
    text = full_df['perturbed_text']
    replacements = (
        ('<SPLIT>', '\n'),
        ('[MASK]', dm.mask_token),
        ('<mask>', dm.mask_token),
    )
    for old, new in replacements:
        text = text.str.replace(old, new, regex=False)
    full_df['perturbed_text'] = text

    return full_df


# Look up the checkpoint for the chosen model and restore it.
# NOTE(review): `model_key` is not assigned anywhere in the visible notebook, so
# this cell relies on it already existing in the kernel. Define it explicitly
# (e.g. in a config cell) so Restart & Run All works; confirm the intended key.
checkpoint_path = model_paths_dict[model_key]
assert isinstance(checkpoint_path, str), f"invalid checkpoint_path {checkpoint_path} for {model_key}"
print(f"running eval on {model_key} loaded from {checkpoint_path}")
model = CoordinateAscentModel.load_from_checkpoint(
    checkpoint_path
)

# Build the data module with the same document/profile model names as the
# loaded checkpoint. Only 256 train and 256 val examples are requested, while
# the whole test split is loaded.
print(f"loading data with {num_cpus} CPUs")
dm = WikipediaDataModule(
    document_model_name_or_path=model.document_model_name_or_path,
    profile_model_name_or_path=model.profile_model_name_or_path,
    dataset_name='wiki_bio',
    dataset_train_split='train[:256]',
    dataset_val_split='val[:256]',
    dataset_test_split='test[:100%]',
    dataset_version='1.2.0',
    num_workers=num_cpus,
    train_batch_size=256,
    eval_batch_size=256,
    max_seq_length=128,
    sample_spans=False,
)
# NOTE(review): later cells read dm.test_dataset; presumably setup("fit") also
# populates it — confirm against WikipediaDataModule.
dm.setup("fit")

# Profile embeddings for all splits (test, val, train order), moved to the GPU.
all_profile_embeddings = get_profile_embeddings(model_key=model_key).cuda()

# Put the document encoder and its embedding head in eval mode on the GPU.
model.document_model.eval()
model.document_model.cuda()
model.document_embed.eval()
model.document_embed.cuda()

# Adversarial + baseline redacted documents to score against the profiles.
adv_csv = load_adv_csv(dm=dm)

# Embed every redacted document in batches and record, per document, the 100
# highest-probability profiles and their probabilities.
topk_values = []
topk_idxs = []
batch_size = 256
i = 0
while i < len(adv_csv):
    # Positional slice, so the repeated index labels in adv_csv do not matter.
    ex = adv_csv.iloc[i:i+batch_size]
    # Tokenize the batch, padding/truncating every document to dm.max_seq_length.
    test_batch = dm.document_tokenizer.batch_encode_plus(
        ex['perturbed_text'].tolist(),
        max_length=dm.max_seq_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    # Prefix the tokenizer output keys with 'perturbed_text__'; this matches the
    # document_type passed to forward_document below.
    test_batch = {
        f'perturbed_text__{k}': v for k,v in test_batch.items()
    }
    with torch.no_grad():
        document_embeddings = model.forward_document(batch=test_batch, document_type='perturbed_text')
        # Dot-product score of each document against every profile embedding
        # (assumes document_embeddings is (batch, dim) — TODO confirm).
        document_to_profile_logits = document_embeddings @ all_profile_embeddings.T
        # Normalize over dim 1, i.e. across profiles for each document.
        document_to_profile_probs = document_to_profile_logits.softmax(dim=1)
        topk_100 = document_to_profile_probs.topk(100)
        topk_values.append(topk_100.values)
        topk_idxs.append(topk_100.indices)

    i += batch_size

# Store one list per row; indices refer to rows of all_profile_embeddings.
adv_csv['pred_topk_values'] = torch.cat(topk_values, dim=0).cpu().tolist()
adv_csv['pred_topk_idxs'] = torch.cat(topk_idxs, dim=0).cpu().tolist()


In [5]:
def truncate_text(text: str, max_length=128, tokenizer=None) -> str:
    """Round-trip ``text`` through the document tokenizer with truncation.

    The text is tokenized with truncation enabled, decoded back to a string, and
    cleaned up: spacing around ``<mask>`` is normalized, and the ``<s>`` / ``</s>``
    markers are removed.

    Args:
        text: The document text to truncate.
        max_length: Forwarded to the tokenizer as its ``max_length``.
        tokenizer: Tokenizer to use. Defaults to ``dm.document_tokenizer`` from
            the notebook's global data module.

    Returns:
        The decoded, cleaned-up text after truncation.
    """
    if tokenizer is None:
        tokenizer = dm.document_tokenizer

    # Honor the caller's max_length instead of a hard-coded 128.
    input_ids = tokenizer(text, truncation=True, max_length=max_length)['input_ids']
    reconstructed_text = (
        tokenizer
            .decode(input_ids)
            # Give every mask a space on both sides, then collapse the doubled
            # spaces this creates where a space was already present.
            .replace('<mask>', ' <mask> ')
            .replace('  <mask>', ' <mask>')
            .replace('<mask>  ', '<mask> ')
            .replace('<s>', '')
            .replace('</s>', '')
            .strip()
    )
    return reconstructed_text

In [13]:
# Redacted text that model_3_2 produced for the example with source index 9.
perturbed_text = adv_csv[(adv_csv['model_name'] == 'model_3_2') & (adv_csv['i'] == 9)]['perturbed_text'].iloc[0]
perturbed_text

"<mask> william <mask> <mask> ( 1 <mask> <mask> -- 14 august <mask> ) was <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> the <mask> navy , <mask> <mask> <mask> the <mask> <mask> to <mask> '' , a <mask> that is still <mask> <mask> .\nhe <mask> wrote numerous technical <mask> on naval technology and <mask> and was also noted for his articles concerning racial politics in the <mask> united states .\ndespite <mask> <mask> as a lawyer , <mask> had always preferred <mask> <mask> <mask> , <mask> his <mask> <mask> in 1876 and <mask> <mask> full-time <mask> in 1879 .\nfor the <mask> rendered in his career , clowes was knighted , awarded the gold medal of the united states naval institute and given a civil list pension .\nhe died in sussex in 1905 after years of ill-health .\n"

In [14]:
# The 'target_text' field of test example 9, for comparison with the redacted
# text selected above (same index, i == 9).
original_text = dm.test_dataset[9]['target_text']
original_text

"sir william laird clowes -lrb- 1 february 1856 -- 14 august 1905 -rrb- was a british journalist and historian whose principal work was `` the royal navy , a history from the earliest times to 1900 '' , a text that is still in print .\nhe also wrote numerous technical pieces on naval technology and strategy and was also noted for his articles concerning racial politics in the southern united states .\ndespite having trained as a lawyer , clowes had always preferred literature and writing , publishing his first work in 1876 and becoming a full-time journalist in 1879 .\nfor the services rendered in his career , clowes was knighted , awarded the gold medal of the united states naval institute and given a civil list pension .\nhe died in sussex in 1905 after years of ill-health .\n"

In [15]:
# Tokenize, truncate (default max_length=128) and decode the original text.
truncate_text(original_text) # this is what the first model sees during search

"sir william laird clowes -lrb- 1 february 1856 -- 14 august 1905 -rrb- was a british journalist and historian whose principal work was `` the royal navy, a history from the earliest times to 1900 '', a text that is still in print.\nhe also wrote numerous technical pieces on naval technology and strategy and was also noted for his articles concerning racial politics in the southern united states.\ndespite having trained as a lawyer, clowes had always preferred literature and writing, publishing his first work in 1876 and becoming a full-time journalist in 1879"

In [16]:
# Same round trip for the redacted text; compare where each output is cut off.
truncate_text(perturbed_text) # this is what the second model sees later

"<mask> william <mask> <mask> ( 1 <mask> <mask> -- 14 august <mask> ) was <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> the <mask> navy, <mask> <mask> <mask> the <mask> <mask> to <mask> '', a <mask> that is still <mask> <mask> .\nhe <mask> wrote numerous technical <mask> on naval technology and <mask> and was also noted for his articles concerning racial politics in the <mask> united states.\ndespite <mask> <mask> as a lawyer, <mask> had always preferred <mask> <mask> <mask> , <mask> his <mask> <mask> in 1876 and <mask> <mask> full-time <mask> in 1879.\nfor the <mask> rendered in his career, clowes was knighted,"

In [19]:
# Show the individual tokens the document tokenizer produces for the original
# text (no truncation is requested in this call).
dm.document_tokenizer.tokenize(original_text)

['s',
 'ir',
 'Ġwill',
 'iam',
 'Ġl',
 'aird',
 'Ġcl',
 'ow',
 'es',
 'Ġ-',
 'lr',
 'b',
 '-',
 'Ġ1',
 'Ġfe',
 'b',
 'ruary',
 'Ġ18',
 '56',
 'Ġ--',
 'Ġ14',
 'Ġaug',
 'ust',
 'Ġ1905',
 'Ġ-',
 'rr',
 'b',
 '-',
 'Ġwas',
 'Ġa',
 'Ġb',
 'rit',
 'ish',
 'Ġjournalist',
 'Ġand',
 'Ġhistorian',
 'Ġwhose',
 'Ġprincipal',
 'Ġwork',
 'Ġwas',
 'Ġ``',
 'Ġthe',
 'Ġroyal',
 'Ġnavy',
 'Ġ,',
 'Ġa',
 'Ġhistory',
 'Ġfrom',
 'Ġthe',
 'Ġearliest',
 'Ġtimes',
 'Ġto',
 'Ġ1900',
 "Ġ''",
 'Ġ,',
 'Ġa',
 'Ġtext',
 'Ġthat',
 'Ġis',
 'Ġstill',
 'Ġin',
 'Ġprint',
 'Ġ.',
 'Ċ',
 'he',
 'Ġalso',
 'Ġwrote',
 'Ġnumerous',
 'Ġtechnical',
 'Ġpieces',
 'Ġon',
 'Ġnaval',
 'Ġtechnology',
 'Ġand',
 'Ġstrategy',
 'Ġand',
 'Ġwas',
 'Ġalso',
 'Ġnoted',
 'Ġfor',
 'Ġhis',
 'Ġarticles',
 'Ġconcerning',
 'Ġracial',
 'Ġpolitics',
 'Ġin',
 'Ġthe',
 'Ġsouthern',
 'Ġunited',
 'Ġstates',
 'Ġ.',
 'Ċ',
 'despite',
 'Ġhaving',
 'Ġtrained',
 'Ġas',
 'Ġa',
 'Ġlawyer',
 'Ġ,',
 'Ġcl',
 'ow',
 'es',
 'Ġhad',
 'Ġalways',
 'Ġpreferred',
 'Ġlitera