# Birthdays probing test - finetuning

what if we *fine-tune* distilbert on the birthday task? what's its performance look like then?

In [1]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

In [2]:
from model import DocumentProfileMatchingTransformer

import os

num_cpus = os.cpu_count()

model = DocumentProfileMatchingTransformer(
    document_model_name_or_path='roberta-base',
    profile_model_name_or_path='distilbert-base-uncased',
    num_workers=min(8, num_cpus),
    train_batch_size=64,
    eval_batch_size=64,
    learning_rate=1e-6,
    max_seq_length=256,
    pretrained_profile_encoder=False,
    word_dropout_ratio=0.0,
    word_dropout_perc=0.0,
    lr_scheduler_factor=0.5,
    lr_scheduler_patience=3,
    adversarial_mask_k_tokens=0,
    train_without_names=False,
)
del model.document_model

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are i

Initialized DocumentProfileMatchingTransformer with learning_rate = 1e-06


In [3]:
from dataloader import WikipediaDataModule
import os

num_cpus = os.cpu_count()

dm = WikipediaDataModule(
    mask_token=model.document_tokenizer.mask_token,
    dataset_name='wiki_bio',
    dataset_train_split='train[:100%]',
    dataset_val_split='val[:20%]',
    dataset_version='1.2.0',
    num_workers=min(8, num_cpus),
    train_batch_size=64,
    eval_batch_size=64,
)
dm.setup("fit")

Initializing WikipediaDataModule with num_workers = 8 and mask token `<mask>`
loading wiki_bio[1.2.0] split train[:100%]


Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


loading wiki_bio[1.2.0] split val[:20%]


Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-58e5e96e220311ed.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-778e9a6d1b0dfab7.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-3c4e94260fbd4dd3.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-9e279afc7bfb46f2.arrow
Loading cached processed dataset at /h

## Get the birthday data

In [4]:
import datetime

d = datetime.datetime.strptime('17 january 1943', "%d %B %Y")
d.day

17

In [5]:
from typing import List, Tuple

from tqdm.notebook import tqdm

import datetime
import re


def process_dataset(_dataset) -> List[Tuple[int, int]]:
    _processed_data = []
    for idx, d in enumerate(tqdm(_dataset, 'processing birthdays')):
        profile = d['profile']
        date_str_matches = re.search(r"birth_date \| ([\d]{1,4} [a-z]+ [\d]{1,4})", profile)
        if date_str_matches:
            date_str = date_str_matches.group(1)
            # print(date_str)
            # parse to datetime.datetime
            try:
                dt = datetime.datetime.strptime(date_str, "%d %B %Y")
            except ValueError as e:
                # print(e)
                continue
            # day_class_num = (dt.month - 1) * 31 + (dt.day - 1)
            # _processed_data.append((idx, day_class_num))
            _processed_data.append((idx, dt.month-1, dt.day-1))
    return _processed_data

## Create birthday data module

In [7]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

num_cpus = os.cpu_count()

class BirthdayDataModule(LightningDataModule):
    train_dataset: List[Tuple[int, int]]
    val_dataset: List[Tuple[int, int]]
    batch_size: int
    def __init__(self, dm: WikipediaDataModule, batch_size: int = 64):
        super().__init__()
        self.train_dataset = process_dataset(dm.train_dataset)
        self.val_dataset = process_dataset(dm.val_dataset)
        self.batch_size = batch_size
        self.num_workers = min(4, num_cpus)

    def setup(self, stage: str) -> None:
        return

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False # Only shuffle for train
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False
        )


In [8]:
birthday_dm = BirthdayDataModule(dm)

processing birthdays:   0%|          | 0/582659 [00:00<?, ?it/s]

processing birthdays:   0%|          | 0/14566 [00:00<?, ?it/s]

## Create birthday model

In [14]:
from typing import Dict

import numpy as np
import torch
import torchmetrics
import transformers

from pytorch_lightning import LightningModule
from transformers import AdamW

class BirthdayModel(LightningModule):
    """Probes the PROFILE for birthday info."""
    model: DocumentProfileMatchingTransformer
    
    train_profiles: np.array
    val_profiles: np.array
    
    classifier: torch.nn.Module
    learning_rate: float
    
    def __init__(self, model: DocumentProfileMatchingTransformer, dm: WikipediaDataModule, learning_rate: float):
        super().__init__()
        # We can pre-calculate these embeddings bc
        self.model = model
        self.train_profiles = np.array(dm.train_dataset['profile'])
        self.val_profiles = np.array(dm.val_dataset['profile'])
        self.month_classifier = torch.nn.Sequential(
            torch.nn.Linear(model.profile_embedding_dim, 64),
            # torch.nn.Dropout(p=0.01),
            torch.nn.Linear(64, 12),
        )
        self.day_classifier = torch.nn.Sequential(
            torch.nn.Linear(model.profile_embedding_dim, 64),
            # torch.nn.Dropout(p=0.01),
            torch.nn.Linear(64, 31),
        )
        self.learning_rate = learning_rate
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy   = torchmetrics.Accuracy()
        self.loss_criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch: Tuple[int, int], batch_idx: int) -> torch.Tensor:
        profile_idxs, months, days = batch
        assert ((0 <= profile_idxs) & (profile_idxs < len(self.train_profiles))).all()
        assert ((0 <= months) & (months < 12)).all()
        assert ((0 <= days) & (days < 31)).all()
        
        profiles = self.train_profiles[profile_idxs.cpu()].tolist()
        embedding = self.model.forward_profile_text(text=profiles)
        
        month_logits = self.month_classifier(embedding)
        day_logits = self.day_classifier(embedding)
        
        
        month_loss = torch.nn.functional.cross_entropy(month_logits, months)
        day_loss = torch.nn.functional.cross_entropy(day_logits, days)
        
        self.log('train_acc_month', self.train_accuracy(month_logits, months))
        self.log('train_acc_day', self.train_accuracy(day_logits, days))
        
        if batch_idx == 0:
            print(
                'train_acc_month', self.train_accuracy(month_logits, months), 
                '; train_acc_day', self.train_accuracy(day_logits, days)
            )
        
        return (month_loss + day_loss)
    
    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int):
        profile_idxs, months, days = batch
        assert ((0 <= profile_idxs) & (profile_idxs < len(self.val_profiles))).all()
        assert ((0 <= months) & (months < 12)).all()
        assert ((0 <= days) & (days < 31)).all()
        
        profiles = self.val_profiles[profile_idxs.cpu()].tolist()
        embedding = self.model.forward_profile_text(text=profiles)
        
        month_logits = self.month_classifier(embedding)
        day_logits = self.day_classifier(embedding)
        
        
        month_loss = torch.nn.functional.cross_entropy(month_logits, months)
        day_loss = torch.nn.functional.cross_entropy(day_logits, days)
        
        self.log('val_acc_month', self.val_accuracy(month_logits, months))
        self.log('val_acc_day', self.val_accuracy(day_logits, days))
        
        if (batch_idx == 0) or (batch_idx == 1855):
            print(
                batch_idx,
                'val_acc_month', self.val_accuracy(month_logits, months),
                '; val_acc_day', self.val_accuracy(day_logits, days)
            )

        return (month_loss + day_loss)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        params = (
            list(self.day_classifier.parameters()) + 
            list(self.month_classifier.parameters()) +
            list(self.model.profile_model.parameters())
        )
        optimizer = AdamW(
            params, lr=self.learning_rate
        )
        return optimizer
            

## Train it

In [15]:
from pytorch_lightning import Trainer, seed_everything

seed_everything(42)

num_validations_per_epoch = 4

Global seed set to 42


In [16]:
birthday_model = BirthdayModel(model=model, dm=dm, learning_rate=5e-5)
birthday_dm.batch_size = 128

val_per_epoch = 4

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
trainer = Trainer(
    default_root_dir=f"saves/jup/birthday_probing",
    val_check_interval=1/val_per_epoch,
    max_epochs=5,
    log_every_n_steps=50,
    gpus=torch.cuda.device_count(),
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(birthday_model, birthday_dm)

  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name             | Type                               | Params
------------------------------------------------------------------------
0 | model            | DocumentProfileMatchingTransformer | 67.0 M
1 | month_classifier | Sequential                         | 50.0 K
2 | day_classifier   | Sequential                         | 51.2 K
3 | train_accuracy   | Accuracy                           | 0     
4 | val_accuracy     | Accuracy                           | 0     
5 | loss_criterion   | CrossEntropyLoss                   | 0     
------------------------------------------------------------------------
67.1 M    Trainable params
0         Non-trainable params
67.1 M    Total params
268.219   Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

0 val_acc_month tensor(0.0859, device='cuda:0') ; val_acc_day tensor(0.0391, device='cuda:0')


Global seed set to 42


Training: 0it [00:00, ?it/s]

train_acc_month tensor(0.0938, device='cuda:0') ; train_acc_day tensor(0.0547, device='cuda:0')


Validating: 0it [00:00, ?it/s]

0 val_acc_month tensor(0.9922, device='cuda:0') ; val_acc_day tensor(1., device='cuda:0')


Validating: 0it [00:00, ?it/s]

0 val_acc_month tensor(0.9922, device='cuda:0') ; val_acc_day tensor(0.9922, device='cuda:0')
