# Birthdays probing test

In [2]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

In [4]:
from model import DocumentProfileMatchingTransformer

import os

num_cpus = os.cpu_count()

model = DocumentProfileMatchingTransformer(
    document_model_name_or_path='roberta-base',
    profile_model_name_or_path='distilbert-base-uncased',
    num_workers=min(8, num_cpus),
    train_batch_size=64,
    eval_batch_size=64,
    learning_rate=1e-6,
    max_seq_length=256,
    pretrained_profile_encoder=False,
    word_dropout_ratio=0.0,
    word_dropout_perc=0.0,
    lr_scheduler_factor=0.5,
    lr_scheduler_patience=3,
    adversarial_mask_k_tokens=0,
    train_without_names=False,
)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are i

Initialized DocumentProfileMatchingTransformer with learning_rate = 1e-06


In [5]:
from dataloader import WikipediaDataModule
import os

num_cpus = os.cpu_count()

dm = WikipediaDataModule(
    mask_token=model.document_tokenizer.mask_token,
    dataset_name='wiki_bio',
    dataset_train_split='train[:100%]',
    dataset_val_split='val[:20%]',
    dataset_version='1.2.0',
    num_workers=min(8, num_cpus),
    train_batch_size=64,
    eval_batch_size=64,
)
dm.setup("fit")

Initializing WikipediaDataModule with num_workers = 8 and mask token `<mask>`
loading wiki_bio[1.2.0] split train[:100%]


Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


loading wiki_bio[1.2.0] split val[:20%]


Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-58e5e96e220311ed.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-778e9a6d1b0dfab7.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-3c4e94260fbd4dd3.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-9e279afc7bfb46f2.arrow
Loading cached processed dataset at /h

## Get the birthday data

In [6]:
import datetime

d = datetime.datetime.strptime('17 january 1943', "%d %B %Y")
d.day

17

In [7]:
from typing import List, Tuple

from tqdm.notebook import tqdm

import datetime
import re


def process_dataset(_dataset) -> List[Tuple[int, int]]:
    _processed_data = []
    for idx, d in enumerate(tqdm(_dataset, 'processing birthdays')):
        profile = d['profile']
        date_str_matches = re.search(r"birth_date \| ([\d]{1,4} [a-z]+ [\d]{1,4})", profile)
        if date_str_matches:
            date_str = date_str_matches.group(1)
            # print(date_str)
            # parse to datetime.datetime
            try:
                dt = datetime.datetime.strptime(date_str, "%d %B %Y")
            except ValueError as e:
                # print(e)
                continue
            day_class_num = (dt.month - 1) * 31 + (dt.day - 1)
            _processed_data.append((idx, day_class_num))
    return _processed_data

## Create birthday data module

In [8]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

num_cpus = os.cpu_count()

class BirthdayDataModule(LightningDataModule):
    train_dataset: List[Tuple[int, int]]
    val_dataset: List[Tuple[int, int]]
    batch_size: int
    def __init__(self, dm: WikipediaDataModule, batch_size: int = 64):
        super().__init__()
        self.train_dataset = process_dataset(dm.train_dataset)
        self.val_dataset = process_dataset(dm.val_dataset)
        self.batch_size = batch_size
        self.num_workers = min(4, num_cpus)

    def setup(self, stage: str) -> None:
        return

    def train_dataloader(self) -> process_dataset(dm.val_dataset):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False # Only shuffle for train
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False
        )


processing birthdays:   0%|          | 0/14566 [00:00<?, ?it/s]

In [9]:
birthday_dm = BirthdayDataModule(dm)

processing birthdays:   0%|          | 0/582659 [00:00<?, ?it/s]

processing birthdays:   0%|          | 0/14566 [00:00<?, ?it/s]

## Create birthday model

In [None]:
import numpy as np
import torch

def precompute_embeddings(model: DocumentProfileMatchingTransformer, datamodule: WikipediaDataModule):
    model.profile_model.cuda()
    model.profile_model.eval()
    print('Precomputing profile embeddings before first epoch...')
    
    model.train_profile_embeddings = np.zeros((len(datamodule.train_dataset), model.profile_embedding_dim))
    for train_batch in tqdm(datamodule.train_dataloader(), desc="[1/2] Precomputing train embeddings - profile", colour="cyan", leave=False):
        with torch.no_grad():
            profile_embeddings = model.forward_profile_text(text=train_batch["profile"])
        model.train_profile_embeddings[train_batch["text_key_id"]] = profile_embeddings.cpu()
    model.train_profile_embeddings = torch.tensor(model.train_profile_embeddings, dtype=torch.float32)
    
    model.val_profile_embeddings = np.zeros((len(datamodule.val_dataset), model.profile_embedding_dim))
    for val_batch in tqdm(datamodule.val_dataloader(), desc="[2/2] Precomputing val embeddings - profile", colour="green", leave=False):
        with torch.no_grad():
            profile_embeddings = model.forward_profile_text(text=val_batch["profile"])
        model.val_profile_embeddings[val_batch["text_key_id"]] = profile_embeddings.cpu()
    model.val_profile_embeddings = torch.tensor(model.val_profile_embeddings, dtype=torch.float32)
    
    
    model.profile_model.train()

precompute_embeddings(model, dm)

Precomputing profile embeddings before first epoch...


[1/2] Precomputing train embeddings - profile:   0%|          | 0/9105 [00:00<?, ?it/s]

In [None]:
from typing import Dict

import torch
import torchmetrics
import transformers

from pytorch_lightning import LightningModule
from transformers import AdamW

class BirthdayModel(LightningModule):
    """Probes the PROFILE for birthday info."""
    profile_embeddings: torch.Tensor
    classifier: torch.nn.Module
    learning_rate: float
    
    def __init__(self, model: DocumentProfileMatchingTransformer, learning_rate: float):
        super().__init__()
        # We can pre-calculate these embeddings bc
        self.train_profile_embeddings = torch.tensor(model.train_profile_embeddings.cpu())
        self.val_profile_embeddings = torch.tensor(model.val_profile_embeddings.cpu())
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(model.profile_embedding_dim, 64),
            torch.nn.Dropout(p=0.01),
            # 12 * 31 possible outputs
            torch.nn.Linear(64, 12*31),
        )
        self.learning_rate = learning_rate
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy   = torchmetrics.Accuracy()
        self.loss_criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch: Tuple[int, int], batch_idx: int) -> torch.Tensor:
        profile_idxs, birthday_idxs = batch
        assert ((0 <= profile_idxs) & (birthday_idxs < len(self.train_profile_embeddings))).all()
        assert ((0 <= birthday_idxs) & (birthday_idxs < 12*31)).all()
        # print('profile_idxs, birthday_idxs =', profile_idxs, birthday_idxs)
        clf_device = next(self.classifier.parameters()).device
        with torch.no_grad():
            embedding = self.train_profile_embeddings[profile_idxs].to(clf_device)
        birthday_logits = self.classifier(embedding)
        # loss = torch.nn.functional.cross_entropy(
        #     birthday_logits, birthday_idxs
        # )
        # if batch_idx == 0: breakpoint()
        self.log('train_accuracy', self.train_accuracy(birthday_logits, birthday_idxs))
        if batch_idx % 300 == 0: print('train accuracy:', self.train_accuracy(birthday_logits, birthday_idxs))
        return self.loss_criterion(birthday_logits, birthday_idxs)
    
    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int):
        profile_idxs, birthday_idxs = batch
        assert ((0 <= profile_idxs) & (profile_idxs < len(self.val_profile_embeddings))).all()
        assert ((0 <= birthday_idxs) & (birthday_idxs < 12*31)).all()
        # print('profile_idxs, birthday_idxs =', profile_idxs, birthday_idxs)
        clf_device = next(self.classifier.parameters()).device
        with torch.no_grad():
            embedding = self.val_profile_embeddings[profile_idxs].to(clf_device)
        # print('emebdding.shape:', embedding.shape)
        birthday_logits = self.classifier(embedding)
        # print('birthday_logits.shape:', birthday_logits.shape)
        loss = torch.nn.functional.cross_entropy(
            birthday_logits, birthday_idxs
        )
        if batch_idx == 0: self.log('val_accuracy', self.val_accuracy(birthday_logits, birthday_idxs))
        return self.loss_criterion(birthday_logits, birthday_idxs)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        optimizer = AdamW(
            list(self.classifier.parameters()), lr=self.learning_rate
        )
        return optimizer
            

## Train it

In [None]:
from pytorch_lightning import Trainer, seed_everything

seed_everything(42)

num_validations_per_epoch = 4

In [None]:
birthday_model = BirthdayModel(model, 1e-3)
birthday_dm.batch_size = 512

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
trainer = Trainer(
    default_root_dir=f"saves/jup/birthday_probing",
    val_check_interval=1.0,
    max_epochs=25,
    log_every_n_steps=50,
    gpus=torch.cuda.device_count(),
)

In [None]:
trainer.fit(birthday_model, birthday_dm)

In [None]:
import torchmetrics

In [None]:
val_batch = next(iter(birthday_dm.val_dataloader()))

def do_validation_batch(batch, batch_idx):
    profile_idxs, birthday_idxs = batch
    clf_device = next(birthday_model.classifier.parameters()).device
    embedding = birthday_model.val_profile_embeddings[profile_idxs].to(clf_device)
    print('emebdding.shape:', embedding.shape)
    birthday_logits = birthday_model.classifier(embedding)
    print('birthday_logits.shape:', birthday_logits.shape)
    loss = torch.nn.functional.cross_entropy(
        birthday_logits, birthday_idxs
    )
    # self.log('val_accuracy', self.val_accuracy(birthday_logits, birthday_idxs))
    print('loss:', loss)

do_validation_batch(val_batch, 0)

In [None]:
val_batch # last element: idx 85, birthday 55

In [None]:
55 % 31 # february 24th

dm.val_dataset[85]

In [137]:
birthday_model.val_profile_embeddings[85][:5]

tensor([ 0.0936,  0.3824,  0.6874,  0.7990, -0.5991])

In [138]:
model.eval()
model.forward_profile_text(text=[dm.val_dataset[85]['profile']])[0, :5]

tensor([ 0.0936,  0.3824,  0.6874,  0.7990, -0.5991], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [129]:
train_batch = next(iter(birthday_dm.train_dataloader()))
train_batch # 682, 78
78 % 31 # 16 -> this is march 17th

16

In [131]:
dm.train_dataset[682]['profile']

'nationalgoals | 12\nfullname | jes√∫s candelas rodrigo\nmanagerclubs | netherlands assistant -rrb- iran netherlands malta thailand -lrb- assistant -rrb- hong kong malaysia netherlands -lrb-\nname | victor hermans\narticle_title | victor hermans\nnationalyears | 1977 -- 1989\nposition | manager -lrb- association football -rrb-\ncurrentclub | thailand national futsal team -lrb- head coach -rrb-\nclubs | mvv maastricht k.s.k. tongeren\nnationalteam | netherlands -lrb- futsal -rrb-\nbirth_place | maastricht , netherlands\nbirth_date | 17 march 1953\nnationalcaps | 50\nmanageryears | 1990 2000 2001 2001-2007 2009 -- 2011 2012 -- -- 1992 1992 -- 1996 1996 1997 --\nheight | 1.72'

In [132]:
birthday_model.train_profile_embeddings[682][:5]

tensor([-0.3974,  0.4090,  0.3919,  1.2626, -0.1960])

In [136]:
model.eval()
model.forward_profile_text(text=[dm.train_dataset[682]['profile']])[0, :5]

tensor([-0.3974,  0.4090,  0.3919,  1.2626, -0.1960], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [140]:
list(birthday_model.named_parameters())

[('0.weight',
  Parameter containing:
  tensor([[-0.0123,  0.0198, -0.0286,  ...,  0.0052, -0.0202,  0.0360],
          [-0.0107,  0.0232,  0.0180,  ...,  0.0116, -0.0154, -0.0274],
          [-0.0222, -0.0221,  0.0122,  ...,  0.0234,  0.0198,  0.0023],
          ...,
          [ 0.0127,  0.0177, -0.0266,  ..., -0.0159, -0.0071,  0.0111],
          [-0.0245,  0.0075,  0.0298,  ..., -0.0179, -0.0173,  0.0030],
          [ 0.0115,  0.0255,  0.0330,  ..., -0.0075, -0.0049, -0.0297]],
         device='cuda:0', requires_grad=True)),
 ('0.bias',
  Parameter containing:
  tensor([-0.0013,  0.0079,  0.0005, -0.0231,  0.0133, -0.0023,  0.0213, -0.0355,
          -0.0328, -0.0144, -0.0042,  0.0066, -0.0263, -0.0157,  0.0100,  0.0275,
           0.0136, -0.0305, -0.0026, -0.0168,  0.0358, -0.0242,  0.0104,  0.0301,
           0.0180,  0.0171,  0.0291,  0.0126,  0.0347,  0.0225,  0.0016, -0.0308,
           0.0349, -0.0179, -0.0320,  0.0195, -0.0254,  0.0104,  0.0150, -0.0162,
           0.0283, -