In [74]:
import torch
from scripts.transformer.modeling import TinyBertForSequenceClassification
from scripts.transformer.modeling import BertConfig as TBertConfig
from scripts.transformer.optimization import BertAdam
from scripts.transformer.file_utils import WEIGHTS_NAME, CONFIG_NAME
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from scripts.utils import *
import os

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

In [3]:
# Teacher model
tokenizer, teacher = create_model('bert-base-cased')
checkpoint = torch.load(os.path.join('artifacts', 'BERT-title-content-benchmark:v0', 'pytorch_model.bin'), map_location=torch.device(device))
teacher.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [4]:
teacher.config.output_attentions = True
teacher.config.output_hidden_states = True

In [79]:
# Student model
student_path = os.path.join('artifacts', '2nd_General_TinyBERT_4L_312D', 'config.json')
student_config = TBertConfig(student_path)
student_config.hidden_size = 768
student = TinyBertForSequenceClassification(student_config, num_labels = 1)

In [80]:
student.config

{
  "attention_probs_dropout_prob": 0.1,
  "cell": {},
  "emb_size": 312,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 1200,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 4,
  "pre_trained": "",
  "structure": [],
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [68]:
student.confi

{
  "attention_probs_dropout_prob": 0.1,
  "cell": {},
  "emb_size": 312,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 312,
  "initializer_range": 0.02,
  "intermediate_size": 1200,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 4,
  "pre_trained": "",
  "structure": [],
  "training": "",
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [6]:
train_data_loader = create_reliable_news_dataloader(
            os.path.join('data', 'nela_gt_2018_site_split', 'train.jsonl'),
            tokenizer,
            max_len = 512,
            batch_size = 8 * max(1, n_gpu),
            shuffle=True,
            sample = 16,
            title_only = False
        )

Max token length: 512 Batch size: 8 Shuffle: True Title only: False


In [7]:
model = teacher.eval()
with torch.no_grad():
    loop = tqdm(train_data_loader)
    for idx, batch in enumerate(loop):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch["labels"].to(device).unsqueeze(1)
        
        outputs = model(input_ids = input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids, labels = labels)
        break

  0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
outputs.keys()

odict_keys(['loss', 'logits', 'hidden_states', 'attentions'])

In [81]:
model = student.eval()
with torch.no_grad():
    loop = tqdm(train_data_loader)
    for idx, batch in enumerate(loop):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch["labels"].to(device).unsqueeze(1)
        
        logits, atts, reps = model(input_ids = input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids, labels = labels)
        break

  0%|          | 0/2 [00:00<?, ?it/s]

In [82]:
atts[0].shape

torch.Size([8, 12, 512, 512])

In [83]:
reps[1].shape

torch.Size([8, 512, 768])

In [84]:
teacher_layer_num = len(outputs['attentions'])
student_layer_num = len(atts)
teacher_layer_num % student_layer_num == 0

True

In [85]:
teacher_atts = outputs['attentions']
student_atts = atts
layers_per_block = int(teacher_layer_num / student_layer_num)
new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                    for i in range(student_layer_num)]

In [86]:
att_loss = 0
from torch.nn import CrossEntropyLoss, MSELoss
lossMse = MSELoss()
for student_att, teacher_att in zip(student_atts, new_teacher_atts):
    student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device),
                                student_att)
    teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device),
                                teacher_att)

    tmp_loss = lossMse(student_att, teacher_att)
    att_loss += tmp_loss

In [87]:
teacher_reps = outputs['hidden_states']
student_reps = reps
rep_loss = 0
new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num + 1)]
new_student_reps = student_reps

for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
    tmp_loss = lossMse(student_rep, teacher_rep)
    rep_loss += tmp_loss

In [88]:
att_loss

tensor(0.2547)

In [89]:
rep_loss

tensor(9.8297)