In [2]:
import torch
import cv2
import os
import argparse
from torch.utils.data import DataLoader, Dataset

import numpy as np
from PIL import Image
from typing import List, Tuple, Optional
from matplotlib import cm
import torch.nn.functional as F


In [3]:
import sys
module_path = "/d1/jinakim/permutation-learning/src/"
sys.path.append(module_path)

from utils import batch_to
from models import hf_bert as bert

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
TASK_TO_VALID = {
    "cola": "validation",
    "mnli": "validation_matched",
    "mrpc": "test",
    "qnli": "validation",
    "qqp": "validation",
    "rte": "validation",
    "sst2": "validation",
    "stsb": "validation",
    "wnli": "validation",
    "bert": "validation",
}

TASK_TO_KEYS = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

In [5]:
from models import hf_bert as berts
import transformers

def get_base_model(dataset, only_tokenizer=False):
    checkpoint = {
        "cola": "textattack/bert-base-uncased-CoLA",
        "mnli": "yoshitomo-matsubara/bert-base-uncased-mnli",
        "mrpc": "textattack/bert-base-uncased-MRPC",
        # "mrpc": "M-FAC/bert-tiny-finetuned-mrpc",
        "qnli": "textattack/bert-base-uncased-QNLI",
        "qqp": "textattack/bert-base-uncased-QQP",
        "rte": "textattack/bert-base-uncased-RTE",
        "sst2": "textattack/bert-base-uncased-SST-2",
        "stsb": "textattack/bert-base-uncased-STS-B",
        "wnli": "textattack/bert-base-uncased-WNLI",
        "bert": "bert-base-uncased",
    }[dataset]

    # NOTE(HJ): this bert models has special hooks
    model = {
        "cola": berts.BertForSequenceClassification,
        "mnli": berts.BertForSequenceClassification,
        "mrpc": berts.BertForSequenceClassification,
        "qnli": berts.BertForSequenceClassification,
        "qqp": berts.BertForSequenceClassification,
        "rte": berts.BertForSequenceClassification,
        "sst2": berts.BertForSequenceClassification,
        "stsb": berts.BertForSequenceClassification,
        "wnli": berts.BertForSequenceClassification,
        "bert": berts.BertForSequenceClassification,
    }[dataset]
    
    tokenizer = transformers.BertTokenizerFast.from_pretrained(checkpoint)
    if only_tokenizer:
        return None, tokenizer
    
    bert = model.from_pretrained(checkpoint, cache_dir='./cache/huggingface/')
    return bert, tokenizer

In [6]:
from datasets import load_dataset, load_metric
import random, torch

def get_dataloader(subset, tokenizer, batch_size, split='train', encode_batch_size=384):
    if subset == 'bert':
        subset = "cola" #return dummy set
    
    dataset = load_dataset('glue', subset, split=split, cache_dir='./cache/datasets')
    
    sentence1_key, sentence2_key = TASK_TO_KEYS[subset]

    def encode(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=True, max_length=256, truncation=True)
        # result = tokenizer(*args, padding="max_length", max_length=512, truncation=True)
        # Map labels to IDs (not necessary for GLUE tasks)
        # if label_to_id is not None and "label" in examples:
        #     result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        return result
    
    if split.startswith('train'): #shuffle when train set
        dataset = dataset.sort('label')
        dataset = dataset.shuffle(seed=random.randint(0, 10000))
    dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True, batch_size=encode_batch_size)
    dataset = dataset.map(encode, batched=True, batch_size=encode_batch_size)
    dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

    dataloader = torch.utils.data.DataLoader(
        dataset, 
        batch_size=batch_size, 
        num_workers=0,
    )
    return dataloader

In [7]:
def gather_fixed_batch(dataloader: DataLoader, batch_size: int):
    items = [
        dataloader.dataset.__getitem__(i * (len(dataloader.dataset) // batch_size))
        for i in range(batch_size)
    ]
    max_len = max([it['input_ids'].shape[0] for it in items])
    for it in items:
        it['input_ids'] = F.pad(it['input_ids'], (0, max_len-len(it['input_ids'])))
        it['attention_mask'] = F.pad(it['attention_mask'], (0, max_len-len(it['attention_mask'])))
        it['token_type_ids'] = F.pad(it['token_type_ids'], (0, max_len-len(it['token_type_ids'])))
    # print([[(k, v.shape) for k, v in it.items()] for it in items])
    return dataloader.collate_fn(items)

In [9]:
import tqdm
BF16 = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

N, H, T, T_M =16, 12, 203, 128
subset = "mnli"

base_model, tokenizer = get_base_model(subset)
base_model.to(device=0)

loader = get_dataloader(subset, tokenizer, N, split=TASK_TO_VALID[subset])
batch = gather_fixed_batch(loader, N)
batch = batch_to(batch, device=0)

base_model.eval()

with torch.no_grad():
    base_model(**batch)

for module in base_model.modules():
    if isinstance(module, bert.BertSelfAttention):
        teacher_score = module.perlin_last_attention_score
        teacher_probs = module.perlin_last_attention_prob
        teacher_context_layer = module.perlin_last_context_layer

Found cached dataset glue (/d1/JINA_COLSELkim/permutation-learning/src/poc/cat/cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Loading cached processed dataset at /d1/JINA_COLSELkim/permutation-learning/src/poc/cat/cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e435b23c87ec8670.arrow
Loading cached processed dataset at /d1/JINA_COLSELkim/permutation-learning/src/poc/cat/cache/datasets/glue/mnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7d8b16f65baf29cc.arrow


In [None]:
teacher_score

NameError: name 'teacher_score' is not defined