In [None]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

from dataloader import WikipediaDataModule

import numpy as np


# Value handed to the data module as `num_workers` below.
num_cpus = 8

# Data module over the `wiki_bio` dataset. Documents are tokenized for
# 'roberta-base' and profiles for 'google/tapas-base'.
# Each split is requested as '<split>[:100%]' — presumably the full split
# (dataset slicing syntax); confirm against WikipediaDataModule.
# NOTE(review): max_seq_length=128 here equals the default max_seq_length of
# ModelWrapper further down; keep the two in sync if either changes.
dm = WikipediaDataModule(
    document_model_name_or_path='roberta-base',
    profile_model_name_or_path='google/tapas-base',
    dataset_name='wiki_bio',
    dataset_train_split='train[:100%]',
    dataset_val_split='val[:100%]',
    dataset_test_split='test[:100%]',
    dataset_version='1.2.0',
    num_workers=num_cpus,
    train_batch_size=256,
    eval_batch_size=256,
    max_seq_length=128,
    sample_spans=False,
)
# NOTE(review): only the "fit" stage is set up, yet later cells read
# dm.test_dataset as well as dm.train_dataset / dm.val_dataset — presumably
# setup("fit") populates all three; confirm.
dm.setup("fit")

In [None]:
import torch
from utils import get_profile_embeddings_by_model_key

def get_profile_embeddings(model_key: str):
    """Return one tensor holding every profile embedding for ``model_key``.

    The per-split embeddings are fetched with
    ``get_profile_embeddings_by_model_key`` and concatenated along dim 0 in
    the fixed order test, val, train. Anything that maps a row index back to
    a person must use that same order.

    Args:
        model_key: Key forwarded to ``get_profile_embeddings_by_model_key``.

    Returns:
        The concatenated embeddings tensor (its shape is printed).
    """
    embeddings_by_split = get_profile_embeddings_by_model_key(model_key=model_key)

    print("concatenating train, val, and test profile embeddings")
    # Row order of the result: test rows first, then val, then train.
    split_order = ('test', 'val', 'train')
    stacked = torch.cat(
        tuple(embeddings_by_split[split] for split in split_order), dim=0
    )

    print("all_profile_embeddings:", stacked.shape)
    return stacked

# Profile-embedding banks for the two models compared in this notebook. The
# same two keys are used below to load the matching checkpoints.
# NOTE(review): "rt" / "rr" presumably abbreviate the document/profile encoder
# pairing of each model — confirm against model_cfg.
rt_profile_embeddings = get_profile_embeddings('model_3_2')
rr_profile_embeddings = get_profile_embeddings('model_3_3__placeholder')

In [None]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

from model import AbstractModel, CoordinateAscentModel

from model_cfg import model_paths_dict

# Load the two trained checkpoints. The keys match the ones used to fetch
# rt_profile_embeddings / rr_profile_embeddings, so each model is later paired
# with its own embedding bank.
rt = CoordinateAscentModel.load_from_checkpoint(model_paths_dict['model_3_2'])
rr = CoordinateAscentModel.load_from_checkpoint(model_paths_dict['model_3_3__placeholder'])

In [None]:
import torch
import transformers

from model import AbstractModel

class ModelWrapper:
    """Callable that scores raw text against a fixed bank of profile embeddings.

    Calling an instance with a list of strings tokenizes them, embeds them
    with the wrapped model's ``forward_document`` and returns the dot product
    of every document embedding with every stored profile embedding.

    Attributes:
        model: The wrapped model, switched to eval mode on construction.
        document_tokenizer: Tokenizer used to encode the input text.
        profile_embeddings: Detached copy of the profile embeddings, one row
            per candidate profile.
        max_seq_length: Length every input is padded / truncated to.
    """
    # Project / third-party annotations are quoted so they are not evaluated
    # when the class is defined.
    document_tokenizer: "transformers.AutoTokenizer"
    profile_embeddings: torch.Tensor
    max_seq_length: int

    def __init__(self,
            model: "AbstractModel",
            document_tokenizer: "transformers.AutoTokenizer",
            profile_embeddings: torch.Tensor,
            max_seq_length: int = 128
        ):
        """Store the model (in eval mode), tokenizer and a private copy of the embeddings."""
        self.model = model
        self.model.eval()
        self.document_tokenizer = document_tokenizer
        # clone + detach: the wrapper never shares storage or autograd history
        # with the tensor the caller passed in.
        self.profile_embeddings = profile_embeddings.clone().detach()
        self.max_seq_length = max_seq_length

    def to(self, device):
        """Move the model and the profile embeddings to ``device``; returns ``self``."""
        self.model.to(device)
        self.profile_embeddings = self.profile_embeddings.to(device)
        return self # so semantics `model = MyModelWrapper().to('cuda')` works properly

    def __call__(self, text_input_list):
        """Return document-to-profile logits for a batch of texts.

        Args:
            text_input_list: List of input strings.

        Returns:
            Tensor of shape ``(len(text_input_list), len(self.profile_embeddings))``.
        """
        model_device = next(self.model.parameters()).device

        tokenized_documents = self.document_tokenizer.batch_encode_plus(
            text_input_list,
            max_length=self.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # The tokenizer builds its tensors independently of the model, so put
        # them on the model's device before the forward pass. This is a no-op
        # when they are already there.
        batch = {
            f"document__{k}": v.to(model_device)
            for k, v in tokenized_documents.items()
        }

        with torch.no_grad():
            document_embeddings = self.model.forward_document(batch=batch, document_type='document')
            document_to_profile_logits = document_embeddings @ (self.profile_embeddings.T)
        assert document_to_profile_logits.shape == (len(text_input_list), len(self.profile_embeddings))
        return document_to_profile_logits

    
# Wrap each model together with its own profile embeddings and move both
# wrappers to the GPU. The same document tokenizer (from `dm`) is used for both.
# NOTE(review): 'cuda' is hard-coded, so this cell needs a CUDA device.
rt_mw = ModelWrapper(rt, dm.document_tokenizer, rt_profile_embeddings).to('cuda')
rr_mw = ModelWrapper(rr, dm.document_tokenizer, rr_profile_embeddings).to('cuda')


In [2]:
from typing import List, Tuple

from IPython.display import HTML, display
import html

def wrap_th(s):
    """Return ``s`` wrapped in a ``<th>`` element."""
    return '<th>{}</th>'.format(s)


def wrap_td(s):
    """Return ``s`` wrapped in a ``<td>`` element."""
    return '<td>{}</td>'.format(s)

def get_person(i: int):
    """Look up a person by index into the concatenation test + val + train.

    Indices below ``len(dm.test_dataset)`` resolve to the test split, the next
    ``len(dm.val_dataset)`` indices to the val split, and everything after
    that to the train split. No range check is done here: whatever the chosen
    split does for the resulting offset is what the caller gets.
    """
    test_size = len(dm.test_dataset)
    if i < test_size:
        return dm.test_dataset[i]

    # Offset of `i` once the test split has been skipped.
    past_test = i - test_size
    val_size = len(dm.val_dataset)
    if past_test < val_size:
        return dm.val_dataset[past_test]

    return dm.train_dataset[past_test - val_size]

def table_from_table_rows(rows_str: str) -> List[List[str]]:
    """Parse a newline-separated profile string into rows of stripped cells.

    Every non-blank line is split on the literal ``||`` separator and each
    resulting cell is stripped of surrounding whitespace. Blank lines (for
    example the one produced by a trailing newline) are skipped, so an empty
    string yields an empty list.

    Args:
        rows_str: Profile text, one row per line, cells separated by ``||``.

    Returns:
        A list of rows, each a list of cell strings. A row has one cell more
        than the number of ``||`` separators on its line.
    """
    return [
        [cell.strip() for cell in line.split('||')]
        for line in rows_str.split('\n')
        if line.strip()  # drop blank lines instead of emitting a [''] row
    ]

def make_prof_html(profile: str) -> str:
    """Render a ``key || value`` profile string as an HTML table.

    Each row becomes ``<tr><th><b>key</b></th><td>value</td></tr>``. Keys and
    values are HTML-escaped so that characters such as ``<`` or ``&`` in the
    profile text are shown literally instead of being interpreted as markup.

    Rows that are not an exact key/value pair are tolerated rather than
    raising on unpacking:
      * blank rows are skipped,
      * a row with no ``||`` gets an empty value cell,
      * extra ``||``-separated cells are re-joined into the value.

    Args:
        profile: Profile text, one ``key || value`` row per line.

    Returns:
        The HTML markup of the table as a single string.
    """
    pieces = ['<table style="border: 1px solid black"><tbody>']
    for row in table_from_table_rows(profile):
        if not any(row):
            continue  # blank line: nothing to render
        rkey, *value_cells = row
        rval = ' || '.join(value_cells)
        pieces.append(
            '<tr>'
            f'<th><b>{html.escape(rkey)}</b></th>'
            f'<td>{html.escape(rval)}</td>'
            '</tr>'
        )
    pieces.append('</tbody></table>')
    return ''.join(pieces)

def display_profile_by_index(idx: int):
    """Display the profile of the person at global index ``idx`` as an HTML table.

    The person is resolved with ``get_person`` and its ``'profile'`` field is
    rendered through ``make_prof_html``.
    """
    person = get_person(idx)
    profile_markup = make_prof_html(person['profile'])
    display(HTML(profile_markup))


In [None]:
# Person names in the order test, val, train — the same order in which the
# profile embeddings are concatenated, so a profile index maps to its name.
label_names = np.array([
    name
    for split in (dm.test_dataset, dm.val_dataset, dm.train_dataset)
    for name in split['name']
])

In [None]:
def _predict_top_profile(wrapper, s: str):
    """Score ``s`` with ``wrapper`` and return its most probable profile.

    The wrapper is called on the single-element batch ``[s]`` and the logits
    are softmaxed over dim 1. The arg-max is computed once and converted to a
    plain ``int`` before it is used to index ``label_names``, so the NumPy
    lookup never receives a tensor.

    Returns:
        ``(name, index, probability)`` of the top-scoring profile.
    """
    probs = wrapper([s]).softmax(1)
    top_idx = probs.argmax().item()
    return (label_names[top_idx], top_idx, probs.max().item())


def call_rr(s: str):
    """Return ``(name, index, probability)`` of the top profile for ``s`` under ``rr_mw``."""
    return _predict_top_profile(rr_mw, s)


def call_rt(s: str):
    """Return ``(name, index, probability)`` of the top profile for ``s`` under ``rt_mw``."""
    return _predict_top_profile(rt_mw, s)