# Birthdays probing test

(Trying birthday (month and day) probing with the newest model.)


In [1]:
import sys
# Make the project's local modules (datamodule, model, model_cfg — imported in later cells)
# importable. NOTE(review): machine-specific absolute path; adjust when running elsewhere.
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

In [None]:
from datamodule import WikipediaDataModule
import os

# CPUs available to this process (affinity-aware). Not passed to the data module below,
# which fixes num_workers=1.
num_cpus = len(os.sched_getaffinity(0))

# wiki_bio data module: the document side is configured with roberta-base and the profile
# side with google/tapas-base. Its train/val datasets and dataloaders are used to build the
# birthday labels and to precompute profile embeddings.
dm = WikipediaDataModule(
    document_model_name_or_path="roberta-base",
    profile_model_name_or_path="google/tapas-base",
    max_seq_length=128,
    dataset_name='wiki_bio',
    dataset_train_split='train[:10%]', # used to train the birthday probe (dm.train_dataset is parsed and embedded below)
    dataset_val_split='val[:20%]',
    dataset_version='1.2.0',
    word_dropout_ratio=0.0,
    word_dropout_perc=0.0,
    num_workers=1,
    train_batch_size=64,
    eval_batch_size=64
)
# Prepare the splits; later cells read dm.train_dataset / dm.val_dataset.
dm.setup("fit")

In [None]:
from model import CoordinateAscentModel
from model_cfg import model_paths_dict

# Checkpoint whose profile encoder is probed; the key selects an entry in model_cfg.
checkpoint_path = model_paths_dict["model_8_1day"]
print(checkpoint_path)


# Only the profile side of this model (profile_model, profile_embed, forward_profile) is used
# below, to precompute embeddings. NOTE(review): the keyword arguments presumably
# supply/override hyperparameters stored in the checkpoint — confirm against
# CoordinateAscentModel's constructor.
model = CoordinateAscentModel.load_from_checkpoint(
    checkpoint_path,
    document_model_name_or_path="roberta-base",
    profile_model_name_or_path="google/tapas-base",
    learning_rate=1e-5,
    pretrained_profile_encoder=False,
    lr_scheduler_factor=0.5,
    lr_scheduler_patience=1,
    train_batch_size=1,
    num_workers=1,
    gradient_clip_val=10.0,
)

## Get the birthday data

In [5]:
import datetime

# Sanity check: "%d %B %Y" parses a lowercase month name, the form captured by the
# birth_date regex used below (recorded output: 17).
d = datetime.datetime.strptime('17 january 1943', "%d %B %Y")
d.day

17

In [7]:
from typing import List, Tuple

import datasets
from tqdm.notebook import tqdm

import datetime
import re


def process_dataset(_dataset: datasets.Dataset) -> List[Tuple[int, int, int]]:
    """Extract zero-based (month, day) birthday labels from wiki_bio profiles.

    Each example's ``'profile'`` string is searched for a
    ``birth_date || <day> <month-name> <year>`` entry. The captured text is
    parsed with ``"%d %B %Y"``. Examples with no match, or whose match does
    not parse, are skipped. For example, strptime's ``%Y`` expects a
    four-digit year, so years written with fewer digits are dropped.

    Args:
        _dataset: iterable of examples, each a mapping with a ``'profile'`` string.

    Returns:
        ``(example_index, month - 1, day - 1)`` tuples, so the month index is in
        [0, 11] and the day index in [0, 30].
    """
    birth_date_pattern = re.compile(r"birth_date \|\| ([\d]{1,4} [a-z]+ [\d]{1,4})")
    labels: List[Tuple[int, int, int]] = []
    for example_idx, example in enumerate(tqdm(_dataset, 'processing birthdays')):
        found = birth_date_pattern.search(example['profile'])
        if not found:
            continue
        try:
            birthday = datetime.datetime.strptime(found.group(1), "%d %B %Y")
        except ValueError:
            # Matched text is not a valid "%d %B %Y" date; skip this example.
            continue
        labels.append((example_idx, birthday.month - 1, birthday.day - 1))
    return labels

## Create birthday data module

In [8]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
import datasets

# CPU count used to size BirthdayDataModule's worker pool. os.cpu_count() returns None when
# the count cannot be determined, which would make min(4, num_cpus) raise; fall back to 1.
num_cpus = os.cpu_count() or 1


class BirthdayDataModule(LightningDataModule):
    """Lightning data module serving parsed wiki_bio birthday labels.

    Holds the ``(example_index, month_idx, day_idx)`` tuples produced by
    ``process_dataset`` for the train and val splits of a ``WikipediaDataModule``.
    Parsing happens once in ``__init__``. With the DataLoader's default collate,
    each batch becomes three integer tensors (indices, months, days).
    """
    train_dataset: List[Tuple[int, int, int]]
    val_dataset: List[Tuple[int, int, int]]
    batch_size: int

    def __init__(self, dm: WikipediaDataModule, batch_size: int = 128):
        super().__init__()
        self.train_dataset = process_dataset(dm.train_dataset)
        self.val_dataset = process_dataset(dm.val_dataset)
        # Read by the dataloader methods at call time, so it may be changed after construction.
        self.batch_size = batch_size
        self.num_workers = min(4, num_cpus)

    def setup(self, stage: str) -> None:
        # All data is prepared in __init__; nothing to do per stage.
        return

    def train_dataloader(self) -> DataLoader:
        # Shuffle training examples: the parsed list is in dataset-index order, and SGD on
        # the probe should not see a fixed ordering every epoch.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader:
        # Validation order does not matter; keep it deterministic.
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )


In [9]:
# Parses birthdays for both splits at construction time (one progress bar per split).
birthday_dm = BirthdayDataModule(dm)

processing birthdays:   0%|          | 0/58266 [00:00<?, ?it/s]

processing birthdays:   0%|          | 0/14566 [00:00<?, ?it/s]

In [10]:
# Preview of parsed (example_index, month_idx, day_idx) tuples; month/day indices are 0-based.
birthday_dm.train_dataset

[(1, 7, 15),
 (2, 3, 13),
 (3, 0, 18),
 (4, 4, 15),
 (5, 7, 15),
 (6, 2, 21),
 (7, 11, 28),
 (12, 7, 20),
 (14, 6, 2),
 (15, 5, 15),
 (17, 3, 3),
 (19, 1, 26),
 (20, 2, 27),
 (21, 6, 0),
 (23, 3, 16),
 (25, 4, 13),
 (26, 3, 26),
 (27, 5, 25),
 (28, 5, 2),
 (29, 8, 2),
 (30, 5, 19),
 (31, 8, 1),
 (33, 4, 12),
 (36, 0, 29),
 (37, 1, 17),
 (38, 7, 11),
 (39, 4, 14),
 (40, 7, 8),
 (41, 8, 18),
 (43, 7, 28),
 (44, 8, 16),
 (45, 9, 0),
 (46, 4, 22),
 (47, 3, 8),
 (48, 1, 9),
 (50, 11, 5),
 (51, 11, 6),
 (53, 8, 3),
 (54, 2, 1),
 (55, 4, 1),
 (56, 1, 22),
 (57, 7, 0),
 (58, 4, 20),
 (59, 10, 21),
 (60, 5, 25),
 (61, 0, 23),
 (62, 8, 11),
 (63, 1, 8),
 (64, 10, 3),
 (65, 5, 25),
 (67, 9, 9),
 (68, 9, 3),
 (69, 9, 1),
 (70, 2, 1),
 (72, 5, 21),
 (73, 6, 8),
 (74, 4, 4),
 (75, 1, 21),
 (77, 8, 6),
 (78, 9, 5),
 (79, 10, 18),
 (80, 1, 13),
 (81, 1, 18),
 (82, 4, 28),
 (83, 8, 17),
 (85, 4, 5),
 (86, 8, 24),
 (87, 0, 21),
 (88, 11, 25),
 (89, 2, 29),
 (90, 10, 10),
 (92, 8, 11),
 (94, 5, 3),
 (95,

In [11]:
# Number of train / val examples with a parseable birth date.
print(len(birthday_dm.train_dataset), len(birthday_dm.val_dataset))

43032 10761


In [None]:
# Peek at one collated batch (expected: a list of three tensors — example indices, months, days).
print(next(iter(birthday_dm.train_dataloader())))

## Create birthday model

In [None]:
import numpy as np
import torch

def _embed_profiles(dataloader, num_rows, desc, colour):
    """Embed every batch of ``dataloader`` with ``model.forward_profile``.

    Returns a float32 CPU tensor of shape ``(num_rows, model.shared_embedding_dim)``.
    Rows are written at each batch's ``text_key_id``; rows never referenced stay zero.
    Relies on the notebook global ``model``.
    """
    # Allocate float32 directly (instead of a float64 numpy buffer converted afterwards):
    # half the peak memory and no extra full-table copy.
    embeddings = torch.zeros((num_rows, model.shared_embedding_dim), dtype=torch.float32)
    for batch in tqdm(dataloader, desc=desc, colour=colour, leave=False):
        with torch.no_grad():
            batch_embeddings = model.forward_profile(batch=batch)
        # Tensor index-assignment requires matching dtypes, so cast explicitly.
        embeddings[batch["text_key_id"]] = batch_embeddings.detach().cpu().float()
    return embeddings


def precompute_profile_embeddings():
    """Embed all train and val profiles once with the frozen profile encoder.

    Moves ``model.profile_model`` and ``model.profile_embed`` to CUDA in eval mode.
    The results are stored as float32 CPU tensors on
    ``model.train_profile_embeddings`` and ``model.val_profile_embeddings``,
    row-aligned with ``text_key_id``. Relies on the notebook globals ``model`` and ``dm``.
    """
    model.profile_model.cuda()
    model.profile_model.eval()
    model.profile_embed.cuda()
    model.profile_embed.eval()

    model.train_profile_embeddings = _embed_profiles(
        dm.train_dataloader(),
        len(dm.train_dataset),
        desc="Precomputing train embeddings",
        colour="teal",
    )
    # dm.val_dataloader() returns an indexable collection of loaders; only the first is used.
    model.val_profile_embeddings = _embed_profiles(
        dm.val_dataloader()[0],
        len(dm.val_dataset),
        desc="Precomputing val embeddings",
        colour="green",
    )

precompute_profile_embeddings()

In [None]:
# Should equal len(dm.val_dataset): one precomputed embedding row per val example.
len(model.val_profile_embeddings)

In [None]:
from typing import Dict

import torch
import torchmetrics
import transformers

from pytorch_lightning import LightningModule
from transformers import AdamW

class BirthdayModel(LightningModule):
    """Probes precomputed PROFILE embeddings for birthday month and day.

    Two heads read rows of the profile-embedding tables copied from the
    coordinate-ascent model:
    - ``month_classifier`` has 12 classes.
    - ``day_classifier`` has 31 classes.

    Batches are ``(profile_idxs, months, days)`` with 0-based month/day labels.
    The loss per step is the sum of the two cross-entropy losses.
    """
    train_profile_embeddings: torch.Tensor
    val_profile_embeddings: torch.Tensor
    month_classifier: torch.nn.Module
    day_classifier: torch.nn.Module
    learning_rate: float

    def __init__(self, model: CoordinateAscentModel, learning_rate: float):
        super().__init__()
        # The profile encoder is frozen for probing, so its embeddings were precomputed once and
        # are copied here. `.detach().clone()` copies without the warning `torch.tensor(tensor)`
        # emits. These are plain attributes (not buffers): they stay on the CPU and are not saved
        # in checkpoints.
        self.train_profile_embeddings = model.train_profile_embeddings.detach().cpu().clone()
        self.val_profile_embeddings = model.val_profile_embeddings.detach().cpu().clone()
        # NOTE: with no nonlinearity between them, each pair of Linear layers composes to a single
        # linear map of rank <= 64, i.e. these are low-rank linear probes.
        self.month_classifier = torch.nn.Sequential(
            torch.nn.Linear(model.shared_embedding_dim, 64),
            # torch.nn.Dropout(p=0.01),
            torch.nn.Linear(64, 12),
        )
        self.day_classifier = torch.nn.Sequential(
            torch.nn.Linear(model.shared_embedding_dim, 64),
            # torch.nn.Dropout(p=0.01),
            torch.nn.Linear(64, 31),
        )
        self.learning_rate = learning_rate
        # One metric object per head and split: torchmetrics metrics accumulate state across
        # calls, so sharing one object between the month and day heads would mix their statistics.
        self.train_accuracy = torchmetrics.Accuracy()  # month head
        self.val_accuracy = torchmetrics.Accuracy()  # month head
        self.train_day_accuracy = torchmetrics.Accuracy()
        self.val_day_accuracy = torchmetrics.Accuracy()
        # Kept for compatibility; the steps call torch.nn.functional.cross_entropy directly.
        self.loss_criterion = torch.nn.CrossEntropyLoss()

    def _shared_step(
        self,
        batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        batch_idx: int,
        split: str,
    ) -> torch.Tensor:
        """Score one batch for ``split`` ("train" or "val") and log its metrics.

        Logs per-head accuracy and the loss, and returns the summed
        month + day cross-entropy loss.
        """
        profile_idxs, months, days = batch
        if split == "train":
            table = self.train_profile_embeddings
            month_metric, day_metric = self.train_accuracy, self.train_day_accuracy
        else:
            table = self.val_profile_embeddings
            month_metric, day_metric = self.val_accuracy, self.val_day_accuracy

        assert ((0 <= profile_idxs) & (profile_idxs < len(table))).all()
        assert ((0 <= months) & (months < 12)).all()
        assert ((0 <= days) & (days < 31)).all()

        clf_device = next(self.month_classifier.parameters()).device
        with torch.no_grad():
            # The table lives on the CPU while Lightning moves the batch to the model's device.
            # Indexing a CPU tensor with CUDA indices raises, so move the indices to the table first.
            embedding = table[profile_idxs.to(table.device)].to(clf_device)

        month_logits = self.month_classifier(embedding)
        day_logits = self.day_classifier(embedding)

        month_loss = torch.nn.functional.cross_entropy(month_logits, months)
        day_loss = torch.nn.functional.cross_entropy(day_logits, days)
        loss = month_loss + day_loss

        # Compute each batch accuracy once and reuse it. Calling the metric again for printing
        # would update its accumulated state twice for the same batch.
        month_acc = month_metric(month_logits, months)
        day_acc = day_metric(day_logits, days)
        self.log(f"{split}_acc_month", month_acc)
        self.log(f"{split}_acc_day", day_acc)
        self.log(f"{split}_loss", loss)

        if batch_idx == 0:
            print(f"{split}_acc_month", month_acc)
            print(f"{split}_acc_day", day_acc)

        return loss

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        return self._shared_step(batch, batch_idx, "train")

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int
    ):
        return self._shared_step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Return an AdamW optimizer over all probe parameters at ``self.learning_rate``.

        No learning-rate schedule is used.
        """
        optimizer = AdamW(
            list(self.parameters()), lr=self.learning_rate
        )
        return optimizer
            

## Train it

In [None]:
from pytorch_lightning import Trainer, seed_everything

# Fix RNG seeds (including the shuffled train dataloader) for reproducible probe training.
seed_everything(42)

# NOTE(review): not referenced by the Trainer below, which uses val_check_interval=1.0.
num_validations_per_epoch = 4

In [None]:
# Probe with learning rate 1e-4 on the precomputed profile embeddings.
birthday_model = BirthdayModel(model, 1e-4)
# The dataloader methods read self.batch_size when called, so this override takes effect.
birthday_dm.batch_size = 2048

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# NOTE(review): with 43032 train examples at batch size 2048 (~22 steps/epoch),
# log_every_n_steps=500 logs training metrics only about every 23 epochs — confirm intended.
# `gpus=` is the older pytorch_lightning Trainer argument; here it requests all visible GPUs.
trainer = Trainer(
    default_root_dir=f"saves/jup/birthday_probing",
    val_check_interval=1.0,
    max_epochs=100,
    log_every_n_steps=500,
    gpus=torch.cuda.device_count(),
)

In [None]:
# Train the probe on the birthday labels.
trainer.fit(birthday_model, birthday_dm)

In [4]:
# Requires the Trainer cell above in the same kernel session. The recorded NameError came from
# running this cell without it.
trainer.logged_metrics

NameError: name 'trainer' is not defined

In [None]:
# Expected (len(dm.train_dataset), model.shared_embedding_dim).
model.train_profile_embeddings.shape