In [1]:
## Load Library
import os, sys
import json
import pandas as pd
from tqdm import tqdm
import random
from typing import List
import numpy as np
from collections import defaultdict, Counter
import pickle
import hashlib

# Weights & Biases
import wandb

# Pytorch modules
import torch
from torch.nn import functional as F
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, IterableDataset, TensorDataset, DataLoader, random_split

from transformers import AutoModel, AutoModelForQuestionAnswering, AutoTokenizer
from transformers import WEIGHTS_NAME, CONFIG_NAME

torch.manual_seed(1234)

<torch._C.Generator at 0x7f3ba5130e70>

In [2]:
class InputFeature():
    """A single set of features of data"""

    def __init__(self,
                 id,
                 input_ids,
                 token_type_ids,
                 attention_mask,
                 start_positions,
                 end_positions,
                 offset_mapping,
                 source=None
                 ):
        self.id = id
        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_positions = start_positions
        self.end_positions = end_positions
        self.offset_mapping = offset_mapping
        self.source = source


def read_file(file_path, has_src=False):
    with open(file_path) as f:
        data = json.load(f)

    qa_data = defaultdict(list)
    for doc in data['data']:
        if has_src:
            source = doc['source']
        else:
            source = "all"
        for paragraph in doc['paragraphs']:
            context = paragraph['context'].replace('\u200b', '')
            for question_and_answers in paragraph['qas']:
                is_impossible = question_and_answers['is_impossible'] if 'is_impossible' in question_and_answers else None
                if not is_impossible:
                    question = question_and_answers['question']
                    answers = question_and_answers['answers']
                    for answer in answers:
                        id = question + context
                        id = hashlib.shake_256(id.encode()).hexdigest(5)
                        qa_data['id'].append(id)
                        qa_data['context'].append(context)
                        qa_data['question'].append(question)
                        qa_data['answers'].append(answer)
                        if has_src:
                            qa_data['source'].append(source)
    return qa_data


def convert_to_features(qa_data, tokenizer, max_len, has_src=False):
    encodings = []
    for idx, (id, context, question, answers) in tqdm(enumerate(zip(qa_data['id'],
                                                                    qa_data['context'],
                                                                    qa_data['question'],
                                                                    qa_data['answers'])), total=len(qa_data['context'])):
        encoding = tokenizer(
            question,
            context,
            truncation=True,
            max_length=max_len,
            return_offsets_mapping=True,
            padding="max_length"
        )
        encoding['id'] = id
        if has_src:
            encoding['source'] = qa_data['source'][idx]
        else:
            encoding['source'] = None
        offset_mapping = encoding.pop("offset_mapping")
        encoding['offset_mapping'] = offset_mapping

        input_ids = encoding['input_ids']
        sequence_ids = encoding.sequence_ids(0)

        start_char = answers['answer_start']
        end_char = start_char + len(answers['text'])

        token_start_index = 0
        while sequence_ids[token_start_index] != 1:
            token_start_index += 1

        token_end_index = len(input_ids) - 1
        while sequence_ids[token_end_index] != 1:
            token_end_index -= 1

        if offset_mapping[token_start_index][0] <= start_char and offset_mapping[token_end_index][1] >= end_char:
            while token_start_index < len(offset_mapping) and offset_mapping[token_start_index][0] <= start_char:
                token_start_index += 1
            encoding["start_positions"] = token_start_index - 1
            while offset_mapping[token_end_index][1] >= end_char:
                token_end_index -= 1
            encoding["end_positions"] = token_end_index + 1
            if encoding['start_positions'] < encoding['end_positions']:
                encodings.append(encoding)

    return [InputFeature(enc['id'],
                         enc['input_ids'],
                         enc['token_type_ids'],
                         enc['attention_mask'],
                         enc['start_positions'],
                         enc['end_positions'],
                         enc['offset_mapping'],
                         enc['source']) for enc in encodings]


In [3]:
tokenizer = AutoTokenizer.from_pretrained("klue/bert-base")

In [4]:
data_path = '/home/ubuntu/workspace/kaist.ir/qa/data'
squad_train_data = read_file(os.path.join(data_path, 'korquad/KorQuAD_v1.0_train.json'))
squad_valid_data = read_file(os.path.join(data_path, 'korquad/KorQuAD_v1.0_dev.json'))
train_hub_data = read_file(os.path.join(data_path, 'newsqa/news_train_all_10.json'), has_src=True)
valid_hub_data = read_file(os.path.join(data_path, 'newsqa/news_test_all_10.json'), has_src=True)

In [25]:
print('korquad train:', len(squad_train_data['context']))
print('korquad test:', len(squad_valid_data['context']))
print('news_hub train:', len(train_hub_data['context']))
print('news_hub test:', len(valid_hub_data['context']))

korquad train: 60407
korquad test: 5774
news_hub train: 220256
news_hub test: 23169


In [5]:
train_squad_features = convert_to_features(squad_train_data, tokenizer, max_len=512)
valid_squad_features = convert_to_features(squad_valid_data, tokenizer, max_len=512)

100%|██████████| 60407/60407 [01:03<00:00, 953.25it/s]


In [27]:
train_hub_features = convert_to_features(train_hub_data, tokenizer, max_len=512, has_src=True)
valid_hub_features = convert_to_features(valid_hub_data, tokenizer, max_len=512, has_src=True)

100%|██████████| 220256/220256 [06:54<00:00, 531.96it/s]
100%|██████████| 23169/23169 [00:56<00:00, 409.78it/s]


In [8]:
print('korquad train features:', len(train_squad_features))
print('korquad test features:', len(valid_squad_features))
print('news_hub train features:', len(train_hub_features))
print('news_hub test features:', len(valid_hub_features))

korquad train features: 60141
korquad test features: 5717
news_hub train features: 206859
news_hub test features: 21761


In [28]:
train_sources = [f.source for f in train_hub_features]
valid_sources = [f.source for f in valid_hub_features]
print(Counter(train_sources))
print(Counter(valid_sources))

Counter({1: 32762, 4: 25466, 3: 25444, 7: 22456, 5: 22356, 8: 21861, 6: 19996, 9: 19663, 2: 16855})
Counter({1: 3360, 4: 2713, 3: 2695, 7: 2403, 5: 2371, 8: 2245, 9: 2137, 6: 2090, 2: 1747})


In [6]:
# with open('../data/pkl/train_squad_features.pkl', 'wb') as f:
#     pickle.dump(train_squad_features, f)

# with open('../data/pkl/valid_squad_features.pkl', 'wb') as f:
#     pickle.dump(valid_squad_features, f)

# with open('../data/pkl/train_hub_features.pkl', 'wb') as f:
#     pickle.dump(train_hub_features, f)

# with open('../data/pkl/valid_hub_features.pkl', 'wb') as f:
#     pickle.dump(valid_hub_features, f)

In [4]:
with open('../data/pkl/train_squad_features.pkl', 'rb') as f:
    train_squad_features = pickle.load(f)

with open('../data/pkl/valid_squad_features.pkl', 'rb') as f:
    valid_squad_features = pickle.load(f)

# with open('../data/pkl/train_hub_features.pkl', 'rb') as f:
#     train_hub_features = pickle.load(f)

# with open('../data/pkl/valid_hub_features.pkl', 'rb') as f:
#     valid_hub_features = pickle.load(f)

# Random Select Data

In [5]:
random.seed(1234)
ft_train_hub_features = random.sample(train_hub_features, k=len(train_squad_features))
ft_valid_hub_features = random.sample(valid_hub_features, k=len(valid_squad_features))

In [13]:
# with open('../data/pkl/ft_train_hub_features.pkl', 'wb') as f:
#     pickle.dump(ft_train_hub_features, f)

# with open('../data/pkl/ft_valid_hub_features.pkl', 'wb') as f:
#     pickle.dump(ft_valid_hub_features, f)

In [5]:
# with open('../data/pkl/ft_train_hub_features.pkl', 'rb') as f:
#     ft_train_hub_features = pickle.load(f)

# with open('../data/pkl/ft_valid_hub_features.pkl', 'rb') as f:
#     ft_valid_hub_features = pickle.load(f)

- Counter({1: 9471, 3: 7412, 4: 7350, 5: 6571, 7: 6497, 8: 6278, 6: 5898, 9: 5804, 2: 4860})
- Counter({1: 908, 4: 722, 3: 690, 5: 618, 9: 611, 7: 608, 8: 578, 6: 538, 2: 444})

In [5]:
source = 9
with open(f'../data/pkl/train_hub_features{source}.pkl', 'rb') as f:
    ft_train_hub_features = pickle.load(f)

with open(f'../data/pkl/valid_hub_features{source}.pkl', 'rb') as f:
    ft_valid_hub_features = pickle.load(f)

# ft_valid_hub_features = [f for f in valid_hub_features if f.source == source]

In [6]:
print(len(ft_train_hub_features))
print(len(ft_valid_hub_features))

19663
2137


In [6]:
train_sources = [f.source for f in ft_train_hub_features]
valid_sources = [f.source for f in ft_valid_hub_features]
print(Counter(train_sources))
print(Counter(valid_sources))

Counter({1: 9471, 3: 7412, 4: 7350, 5: 6571, 7: 6497, 8: 6278, 6: 5898, 9: 5804, 2: 4860})
Counter({1: 908, 4: 722, 3: 690, 5: 618, 9: 611, 7: 608, 8: 578, 6: 538, 2: 444})


In [5]:
with open('../data/pkl/ft_train_book_features.pkl', 'rb') as f:
    ft_train_book_features = pickle.load(f)

with open('../data/pkl/ft_valid_book_features.pkl', 'rb') as f:
    ft_valid_book_features = pickle.load(f)

In [7]:
train_features = train_squad_features + ft_train_hub_features
# train_features = train_squad_features + ft_train_book_features
# Make torch
torch_input_ids = [torch.tensor(f.input_ids, dtype=torch.long) for f in train_features]
all_input_ids = torch.cat([ii.unsqueeze(0) for ii in torch_input_ids], dim=0)
torch_token_type_ids = [torch.tensor(f.token_type_ids, dtype=torch.long) for f in train_features]
all_token_type_ids = torch.cat([tti.unsqueeze(0) for tti in torch_token_type_ids], dim=0)
torch_attention_mask = [torch.tensor(f.attention_mask, dtype=torch.long) for f in train_features]
all_attention_mask = torch.cat([am.unsqueeze(0) for am in torch_attention_mask], dim=0)
torch_start_positions = [torch.tensor(f.start_positions, dtype=torch.long) for f in train_features]
all_start_positions = torch.cat([sp.unsqueeze(0) for sp in torch_start_positions], dim=0)
torch_end_positions = [torch.tensor(f.end_positions, dtype=torch.long) for f in train_features]
all_end_positions = torch.cat([ep.unsqueeze(0) for ep in torch_end_positions], dim=0)
# train_dataset = TensorDataset(all_input_ids,
#                                 all_token_type_ids,
#                                 all_attention_mask,
#                                 all_start_positions,
#                                 all_end_positions)
input_type = torch.tensor([0] * len(train_squad_features) + [1] * len(ft_train_hub_features), dtype=torch.long)
train_dataset = TensorDataset(all_input_ids,
                                all_token_type_ids,
                                all_attention_mask,
                                all_start_positions,
                                all_end_positions,
                                input_type)

In [8]:
valid_features = valid_squad_features + ft_valid_hub_features
# valid_features = valid_squad_features + ft_valid_book_features
#Make Torch
torch_input_ids = [torch.tensor(f.input_ids, dtype=torch.long) for f in valid_features]
all_input_ids = torch.cat([ii.unsqueeze(0) for ii in torch_input_ids], dim=0)
torch_token_type_ids = [torch.tensor(f.token_type_ids, dtype=torch.long) for f in valid_features]
all_token_type_ids = torch.cat([tti.unsqueeze(0) for tti in torch_token_type_ids], dim=0)
torch_attention_mask = [torch.tensor(f.attention_mask, dtype=torch.long) for f in valid_features]
all_attention_mask = torch.cat([am.unsqueeze(0) for am in torch_attention_mask], dim=0)
torch_start_positions = [torch.tensor(f.start_positions, dtype=torch.long) for f in valid_features]
all_start_positions = torch.cat([sp.unsqueeze(0) for sp in torch_start_positions], dim=0)
torch_end_positions = [torch.tensor(f.end_positions, dtype=torch.long) for f in valid_features]
all_end_positions = torch.cat([ep.unsqueeze(0) for ep in torch_end_positions], dim=0)
valid_dataset = TensorDataset(all_input_ids,
                                all_token_type_ids,
                                all_attention_mask,
                                all_start_positions,
                                all_end_positions)

In [9]:
class WarmupLinearSchedule(LambdaLR):
    """ Linear warmup and then linear decay.
        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
        Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

In [10]:
def train_log(loss, example_ct, epoch, lr):
    # Where the magic happens
    wandb.log({"epoch": epoch, "loss": loss, "lr": lr}, step=example_ct)
    print(f"Loss after " + str(example_ct).zfill(5) +
          f" examples: {loss:.3f}" + f" with lr: {lr}")


def train_batch(step, batch, model, optimizer, scheduler, beta, sigma, device, config):
    for param_group in optimizer.param_groups:
        lr = param_group['lr']

    input_ids = batch[0].to(device)
    token_type_ids = batch[1].to(device)
    attention_mask = batch[2].to(device)
    start_positions = batch[3].to(device)
    end_positions = batch[4].to(device)
    if config.method == 'contrastive':
        input_type = batch[5].to(device)
        # Forward pass ➡
        outputs = model(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions,
                        input_type=input_type,
                        beta=beta,
                        sigma=sigma)
    else:
        outputs = model(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
    loss = outputs[0]

    if config.gradient_accumulation_steps > 1:
        loss = loss / config.gradient_accumulation_steps

    # Backward pass ⬅
    loss.backward()

    if (step + 1) % config.gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        # Step with optimizer
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    return loss, lr


def validation(valid_loader, model, device):
    model.eval()
    with torch.no_grad():
        val_losses, acc = [], []
        valid_example_ct = 0
        for _, batch in tqdm(enumerate(valid_loader), total=len(valid_loader)):
            valid_example_ct += batch[0].shape[0]
            input_ids = batch[0].to(device)
            attention_mask = batch[2].to(device)
            start_positions = batch[3].to(device)
            end_positions = batch[4].to(device)

            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            start_positions=start_positions,
                            end_positions=end_positions)

            loss = outputs[0].item()
            val_losses.append(loss)
            start_pred = torch.argmax(outputs[1], dim=1)
            end_pred = torch.argmax(outputs[2], dim=1)
            acc.append(
                ((start_pred == start_positions).sum()/len(start_pred)).item())
            acc.append(((end_pred == end_positions).sum()/len(end_pred)).item())
        avg_loss = sum(val_losses)/len(val_losses)
        avg_acc = sum(acc)/len(acc)

        print(f"Accuracy of the model on the {valid_example_ct} " +
              f"test samples: {100 * avg_acc}% with valid loss: {avg_loss}")

        wandb.log({"valid_loss": avg_loss, "accuracy": avg_acc})
    return avg_loss, avg_acc


def save_model(model, tokenizer, model_path, metric):
    print(f'saving model to {model_path}')
    os.makedirs(model_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(model_path, metric))
    # model.config.to_json_file(os.path.join(model_path, CONFIG_NAME))
    # tokenizer.save_pretrained(model_path)


In [11]:
def make(config, train_dataset, valid_dataset, model):
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,)
    valid_loader = DataLoader(valid_dataset,
                            batch_size=config.batch_size,
                            shuffle=True,)
    # Make the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    t_total = len(
        train_loader) // config.gradient_accumulation_steps * config.epochs
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': config.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=config.learning_rate, eps=config.adam_epsilon)
    scheduler = WarmupLinearSchedule(
        optimizer, warmup_steps=t_total*config.warmup_proportion, t_total=t_total)

    return train_loader, valid_loader, criterion, optimizer, scheduler


In [12]:
def model_pipeline(config, train_dataset, valid_dataset, tokenizer, model, device):

    # config = wandb.config

    model = model.to(device)
    train_loader, valid_loader, criterion, optimizer, scheduler= make(config, train_dataset, valid_dataset, model)
    # wandb.watch(model, criterion, log="all", log_freq=10)

    example_ct = 0  # number of examples seen
    batch_ct = 0
    best_acc, best_loss = 0, 100.0
    
    for epoch in tqdm(range(config.epochs)):
        for step, batch in enumerate(train_loader):
            loss, lr = train_batch(step, batch, model, optimizer, scheduler, config.beta, config.sigma, device, config)
            example_ct +=  batch[0].shape[0]
            batch_ct += 1

            if ((batch_ct + 1) % config.log_interval) == 0:
                train_log(loss, example_ct, epoch, lr)
        # scheduler.step()
        avg_loss, avg_acc = validation(valid_loader, model, device)
        
        if config.metric == 'loss' or config.metric == "all":
            if best_loss > avg_loss:
                best_loss = avg_loss
                print(f'best loss changed to {best_loss}')
                save_model(model, tokenizer, config.model_path, "pytorch_model.loss.bin")
        if config.metric == 'accuracy' or config.metric == "all":
            if best_acc < avg_acc:
                best_acc = avg_acc
                print(f'best accuracy changed to {best_acc}')
                save_model(model, tokenizer, config.model_path, "pytorch_model.acc.bin")
            
        model.train()

    # return model

In [13]:
import logging
import math
import os
import warnings

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from transformers import PreTrainedModel, BertPreTrainedModel
from transformers import BertModel
#from transformers.modeling_bert import BertEncoder, BertPooler

In [14]:
class CustomizedBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, sigma=0.0):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        inputs_embeds = inputs_embeds + sigma * torch.randn_like(inputs_embeds, device=inputs_embeds.device)
        
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class CustomizedBertModel(BertModel):
    
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = CustomizedBertEmbeddings(config)
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        sigma=0.01,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, sigma=sigma
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class CustomizedBertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = CustomizedBertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        n_samples = int(source.size()[0])+int(target.size()[0])
        total = torch.cat([source, target], dim=0)
    
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        l2_distance = ((total0-total1)**2).sum(2)
    
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.nansum(l2_distance.data) / (n_samples**2-n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]

        kernel_val = [torch.exp(-l2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)
    
    def mmd(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        batch_size = int(source.size()[0])
        kernels = self.guassian_kernel(source, target, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        loss = torch.mean(XX) + torch.mean(YY) - torch.mean(XY) - torch.mean(YX)
        return loss

    # @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    # @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
    def forward(
        self,
        input_type=None,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        beta=0.01,
        sigma=0.01,
    ):

        if input_type is not None:
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                sigma=sigma,
            )

            sequence_output = outputs[0]

            logits = self.qa_outputs(sequence_output)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1)
            end_logits = end_logits.squeeze(-1)

            outputs = (start_logits, end_logits,) + outputs[2:]
            if start_positions is not None and end_positions is not None:
                # If we are on multi-GPU, split add a dimension
                if len(start_positions.size()) > 1:
                    start_positions = start_positions.squeeze(-1)
                if len(end_positions.size()) > 1:
                    end_positions = end_positions.squeeze(-1)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions.clamp_(0, ignored_index)
                end_positions.clamp_(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)

                # find answer, context and question position mask
                a_mask_1 = torch.zeros(token_type_ids.shape[0], token_type_ids.shape[1]+1).to(token_type_ids.device)
                a_mask_1[torch.arange(a_mask_1.shape[0]), start_positions] = 1
                a_mask_1 = a_mask_1.cumsum(dim=1)[:, :-1]
                a_mask_2 = torch.zeros(token_type_ids.shape[0], token_type_ids.shape[1]+1).to(token_type_ids.device)
                a_mask_2[torch.arange(a_mask_2.shape[0]), end_positions+1] = 1
                a_mask_2 = a_mask_2.cumsum(dim=1)[:, :-1]
                a_mask = a_mask_1 * (1 - a_mask_2)
                    
                splits = (input_ids == 102) * torch.arange(input_ids.shape[1], 0, -1).to(input_ids.device)
                _, splits = torch.sort(splits, -1, descending=True)
                splits = splits[:, :2]
                # splits = (input_ids == 102).nonzero()[:, 1].reshape(input_ids.size(0),-1)
                c_mask = (token_type_ids == 1) * attention_mask
                c_mask[torch.arange(c_mask.size(0)), splits[:, 0]] = 0
                c_mask[torch.arange(c_mask.size(0)), splits[:, 1]] = 0
                c_mask = c_mask * (1 - a_mask)

                q_mask = (token_type_ids == 0) * attention_mask
                q_mask[torch.arange(q_mask.size(0)), splits[:, 0]] = 0
                q_mask[:, 0] = 0
                
                a_rep = (sequence_output * a_mask.unsqueeze(-1)).sum(1) / a_mask.sum(-1).unsqueeze(-1)
                cq_mask = ((c_mask + q_mask) > 0) * 1.0
                cq_rep = (sequence_output * cq_mask.unsqueeze(-1)).sum(1) / cq_mask.sum(-1).unsqueeze(-1)

                can_loss = -self.mmd(cq_rep, a_rep)
                
                if len((input_type==0).nonzero()[:, 0]) != 0 and len((input_type==1).nonzero()[:, 0]) != 0:
                    a_rep_source = a_rep[(input_type==0).nonzero()[:, 0]].view(-1, a_rep.size(1))
                    a_rep_target = a_rep[(input_type==1).nonzero()[:, 0]].view(-1, a_rep.size(1))
                    cq_rep_source = cq_rep[(input_type==0).nonzero()[:, 0]].view(-1, cq_rep.size(1))
                    cq_rep_target = cq_rep[(input_type==1).nonzero()[:, 0]].view(-1, cq_rep.size(1))

                    can_loss += self.mmd(a_rep_source, a_rep_target) + self.mmd(cq_rep_source, cq_rep_target)
                
                total_loss = (start_loss + end_loss) / 2 + beta * can_loss
                if torch.isnan(total_loss).any():
                    print(start_loss, end_loss, beta, can_loss)
                outputs = (total_loss,) + outputs

            return outputs
        else:
            outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            sigma=0.0
        )

            sequence_output = outputs[0]

            logits = self.qa_outputs(sequence_output)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1).contiguous()
            end_logits = end_logits.squeeze(-1).contiguous()

            outputs = (start_logits, end_logits,) + outputs[2:]

            total_loss = None
            if start_positions is not None and end_positions is not None:
                # If we are on multi-GPU, split add a dimension
                if len(start_positions.size()) > 1:
                    start_positions = start_positions.squeeze(-1)
                if len(end_positions.size()) > 1:
                    end_positions = end_positions.squeeze(-1)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions = start_positions.clamp(0, ignored_index)
                end_positions = end_positions.clamp(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)
                total_loss = (start_loss + end_loss) / 2
        return ((total_loss,) + outputs) if total_loss is not None else outputs

In [15]:
class DotDict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [16]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

config = dict(
    epochs=3,
    batch_size=8,
    learning_rate=5e-5,
    model_path='/home/ubuntu/workspace/kaist.ir/qa/model/squad_news9.contra', #change
    method='contrastive', #change
    weight_decay=0.01,
    adam_epsilon=1e-6,
    warmup_proportion=0.1,
    gradient_accumulation_steps=1,
    max_grad_norm=1.0,
    log_interval=100,
    beta=0.001,
    sigma=0.001,
    metric='all')

In [17]:
model = CustomizedBertForQuestionAnswering.from_pretrained('klue/bert-base')
# model = AutoModelForQuestionAnswering.from_pretrained('klue/bert-base')

Some weights of the model checkpoint at klue/bert-base were not used when initializing CustomizedBertForQuestionAnswering: ['bert.embeddings.position_ids', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing CustomizedBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CustomizedBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Cus

In [18]:
wandb.init(project="qa_contrastive", config=config, name='kor+news9_contra_v1.0')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhannabros[0m (use `wandb login --relogin` to force relogin)


In [19]:
config = DotDict(config)

In [20]:
model_pipeline(config, train_dataset, valid_dataset, tokenizer, model, device)

  0%|          | 0/3 [00:00<?, ?it/s]

Loss after 00792 examples: 6.079 with lr: 1.637262763966854e-06
Loss after 01592 examples: 4.571 with lr: 3.307939053728949e-06
Loss after 02392 examples: 2.207 with lr: 4.978615343491045e-06
Loss after 03192 examples: 1.987 with lr: 6.6492916332531405e-06
Loss after 03992 examples: 1.159 with lr: 8.319967923015236e-06
Loss after 04792 examples: 2.277 with lr: 9.990644212777332e-06
Loss after 05592 examples: 1.511 with lr: 1.1661320502539428e-05
Loss after 06392 examples: 0.980 with lr: 1.3331996792301523e-05
Loss after 07192 examples: 0.888 with lr: 1.500267308206362e-05
Loss after 07992 examples: 0.455 with lr: 1.6673349371825715e-05
Loss after 08792 examples: 0.949 with lr: 1.834402566158781e-05
Loss after 09592 examples: 0.782 with lr: 2.0014701951349906e-05
Loss after 10392 examples: 0.267 with lr: 2.1685378241112e-05
Loss after 11192 examples: 1.445 with lr: 2.3356054530874097e-05
Loss after 11992 examples: 0.654 with lr: 2.502673082063619e-05
Loss after 12792 examples: 1.095 wit

100%|██████████| 982/982 [01:21<00:00, 12.03it/s]


Accuracy of the model on the 7854 test samples: 78.22895451864255% with valid loss: 0.8009827233634722
best loss changed to 0.8009827233634722
saving model to /home/ubuntu/workspace/kaist.ir/qa/model/squad_news9.contra
best accuracy changed to 0.7822895451864255
saving model to /home/ubuntu/workspace/kaist.ir/qa/model/squad_news9.contra


 33%|███▎      | 1/3 [48:43<1:37:27, 2923.65s/it]

Loss after 79988 examples: 0.663 with lr: 3.69961982832873e-05
Loss after 80788 examples: 0.437 with lr: 3.681056758442484e-05
Loss after 81588 examples: 0.054 with lr: 3.662493688556239e-05
Loss after 82388 examples: 0.834 with lr: 3.643930618669993e-05
Loss after 83188 examples: 0.459 with lr: 3.625367548783748e-05
Loss after 83988 examples: 0.006 with lr: 3.606804478897503e-05
Loss after 84788 examples: 0.655 with lr: 3.588241409011257e-05
Loss after 85588 examples: 0.932 with lr: 3.569678339125011e-05
Loss after 86388 examples: 0.366 with lr: 3.551115269238766e-05
Loss after 87188 examples: 0.661 with lr: 3.53255219935252e-05
Loss after 87988 examples: 0.163 with lr: 3.5139891294662746e-05
Loss after 88788 examples: 0.206 with lr: 3.4954260595800294e-05
Loss after 89588 examples: 0.210 with lr: 3.4768629896937835e-05
Loss after 90388 examples: 1.317 with lr: 3.458299919807538e-05
Loss after 91188 examples: 0.811 with lr: 3.439736849921293e-05
Loss after 91988 examples: 0.038 with l

100%|██████████| 982/982 [01:21<00:00, 12.04it/s]


Accuracy of the model on the 7854 test samples: 80.10225730005448% with valid loss: 0.7655609007969844
best loss changed to 0.7655609007969844
saving model to /home/ubuntu/workspace/kaist.ir/qa/model/squad_news9.contra
best accuracy changed to 0.8010225730005447
saving model to /home/ubuntu/workspace/kaist.ir/qa/model/squad_news9.contra


 67%|██████▋   | 2/3 [1:37:48<48:56, 2936.26s/it]

Loss after 159984 examples: 0.275 with lr: 1.843312839704179e-05
Loss after 160784 examples: -0.003 with lr: 1.8247497698179334e-05
Loss after 161584 examples: 0.423 with lr: 1.806186699931688e-05
Loss after 162384 examples: 0.175 with lr: 1.7876236300454423e-05
Loss after 163184 examples: 0.015 with lr: 1.769060560159197e-05
Loss after 163984 examples: 0.203 with lr: 1.7504974902729516e-05
Loss after 164784 examples: 0.117 with lr: 1.7319344203867057e-05
Loss after 165584 examples: 0.004 with lr: 1.7133713505004605e-05
Loss after 166384 examples: 0.031 with lr: 1.694808280614215e-05
Loss after 167184 examples: 0.202 with lr: 1.6762452107279694e-05
Loss after 167984 examples: 0.703 with lr: 1.6576821408417238e-05
Loss after 168784 examples: 0.090 with lr: 1.6391190709554783e-05
Loss after 169584 examples: 0.104 with lr: 1.6205560010692327e-05
Loss after 170384 examples: 0.945 with lr: 1.6019929311829872e-05
Loss after 171184 examples: 0.024 with lr: 1.583429861296742e-05
Loss after 171

100%|██████████| 982/982 [01:21<00:00, 12.02it/s]
100%|██████████| 3/3 [2:26:46<00:00, 2935.33s/it]

Accuracy of the model on the 7854 test samples: 79.7479633401222% with valid loss: 0.9962342525576289





In [21]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
accuracy,▁█▇
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅█████████████
loss,▇▄▄▄▃▄▄▂▂▁▁▂▃▃▂▁▄▁▆█▁▁▁▂▁▁▁▁▁▆▂▁▂▁▁▄▂▂▁▁
lr,▂▃▅▇███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
valid_loss,▂▁█

0,1
accuracy,0.79748
epoch,2.0
loss,0.0129
lr,0.0
valid_loss,0.99623


# Test

In [23]:
train_loader = DataLoader(train_dataset,
                              batch_size=8,
                              shuffle=True,)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CustomizedBertForQuestionAnswering.from_pretrained('klue/bert-base')
model = model.to(device)

In [None]:
t_total = len(train_loader) // config.gradient_accumulation_steps * config.epochs
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': config.weight_decay},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate, eps=config.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=t_total*config.warmup_proportion, t_total=t_total)

In [24]:
class CustomizedBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, sigma=0.0):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        inputs_embeds = inputs_embeds + sigma * torch.randn_like(inputs_embeds, device=inputs_embeds.device)
        
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class CustomizedBertModel(BertModel):
    
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = CustomizedBertEmbeddings(config)
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        sigma=0.01,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, sigma=sigma
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class CustomizedBertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = CustomizedBertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        n_samples = int(source.size()[0])+int(target.size()[0])
        total = torch.cat([source, target], dim=0)
    
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        l2_distance = ((total0-total1)**2).sum(2)
    
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.nansum(l2_distance.data) / (n_samples**2-n_samples)
            # bandwidth = torch.sum(l2_distance.data) / (n_samples**2-n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]

        kernel_val = [torch.exp(-l2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        for val in kernel_val:
            if torch.isnan(val).any():
                print(kernel_val)
        return sum(kernel_val)
    
    def mmd(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        batch_size = int(source.size()[0])
        kernels = self.guassian_kernel(source, target, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        loss = torch.mean(XX) + torch.mean(YY) - torch.mean(XY) - torch.mean(YX)
        if torch.isnan(loss).any():
            print(loss)
        return loss

    # @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    # @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
    def forward(
        self,
        input_type=None,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        beta=0.01,
        sigma=0.01,
    ):

        if input_type is not None:
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                sigma=sigma,
            )

            sequence_output = outputs[0]

            logits = self.qa_outputs(sequence_output)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1)
            end_logits = end_logits.squeeze(-1)

            outputs = (start_logits, end_logits,) + outputs[2:]
            if start_positions is not None and end_positions is not None:
                # If we are on multi-GPU, split add a dimension
                if len(start_positions.size()) > 1:
                    start_positions = start_positions.squeeze(-1)
                if len(end_positions.size()) > 1:
                    end_positions = end_positions.squeeze(-1)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions.clamp_(0, ignored_index)
                end_positions.clamp_(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)

                # find answer, context and question position mask
                a_mask_1 = torch.zeros(token_type_ids.shape[0], token_type_ids.shape[1]+1).to(token_type_ids.device)
                a_mask_1[torch.arange(a_mask_1.shape[0]), start_positions] = 1
                a_mask_1 = a_mask_1.cumsum(dim=1)[:, :-1]
                a_mask_2 = torch.zeros(token_type_ids.shape[0], token_type_ids.shape[1]+1).to(token_type_ids.device)
                a_mask_2[torch.arange(a_mask_2.shape[0]), end_positions+1] = 1
                a_mask_2 = a_mask_2.cumsum(dim=1)[:, :-1]
                a_mask = a_mask_1 * (1 - a_mask_2)
                    
                splits = (input_ids == 102) * torch.arange(input_ids.shape[1], 0, -1).to(input_ids.device)
                _, splits = torch.sort(splits, -1, descending=True)
                splits = splits[:, :2]
                # splits = (input_ids == 102).nonzero()[:, 1].reshape(input_ids.size(0),-1)
                c_mask = (token_type_ids == 1) * attention_mask
                c_mask[torch.arange(c_mask.size(0)), splits[:, 0]] = 0
                c_mask[torch.arange(c_mask.size(0)), splits[:, 1]] = 0
                c_mask = c_mask * (1 - a_mask)

                q_mask = (token_type_ids == 0) * attention_mask
                q_mask[torch.arange(q_mask.size(0)), splits[:, 0]] = 0
                q_mask[:, 0] = 0
                
                a_rep = (sequence_output * a_mask.unsqueeze(-1)).sum(1) / a_mask.sum(-1).unsqueeze(-1)
                cq_mask = ((c_mask + q_mask) > 0) * 1.0
                cq_rep = (sequence_output * cq_mask.unsqueeze(-1)).sum(1) / cq_mask.sum(-1).unsqueeze(-1)
                
                if torch.isnan(cq_rep).any() or torch.isnan(a_rep).any():
                    print(cq_rep)
                    print(a_rep)
                can_loss = -self.mmd(cq_rep, a_rep)
                if torch.isnan(can_loss).any():
                    print(can_loss)
                
                if len((input_type==0).nonzero()[:, 0]) != 0 and len((input_type==1).nonzero()[:, 0]) != 0:
                    a_rep_source = a_rep[(input_type==0).nonzero()[:, 0]].view(-1, a_rep.size(1))
                    a_rep_target = a_rep[(input_type==1).nonzero()[:, 0]].view(-1, a_rep.size(1))
                    cq_rep_source = cq_rep[(input_type==0).nonzero()[:, 0]].view(-1, cq_rep.size(1))
                    cq_rep_target = cq_rep[(input_type==1).nonzero()[:, 0]].view(-1, cq_rep.size(1))

                    if torch.isnan(a_rep_source).any() or torch.isnan(a_rep_target).any() or torch.isnan(cq_rep_source).any() or torch.isnan(cq_rep_target).any():
                        print(a_rep_source)
                        print(a_rep_target)
                        print(cq_rep_source)
                        print(cq_rep_target)
                    can_loss += self.mmd(a_rep_source, a_rep_target) + self.mmd(cq_rep_source, cq_rep_target)
                    if torch.isnan(can_loss).any():
                        print(can_loss)
                
                total_loss = (start_loss + end_loss) / 2 + beta * can_loss
                if torch.isnan(total_loss).any():
                    print(start_loss, end_loss, beta, can_loss)
                outputs = (total_loss,) + outputs

            return outputs
        else:
            outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            sigma=0.0
        )

            sequence_output = outputs[0]

            logits = self.qa_outputs(sequence_output)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1).contiguous()
            end_logits = end_logits.squeeze(-1).contiguous()

            outputs = (start_logits, end_logits,) + outputs[2:]

            total_loss = None
            if start_positions is not None and end_positions is not None:
                # If we are on multi-GPU, split add a dimension
                if len(start_positions.size()) > 1:
                    start_positions = start_positions.squeeze(-1)
                if len(end_positions.size()) > 1:
                    end_positions = end_positions.squeeze(-1)
                # sometimes the start/end positions are outside our model inputs, we ignore these terms
                ignored_index = start_logits.size(1)
                start_positions = start_positions.clamp(0, ignored_index)
                end_positions = end_positions.clamp(0, ignored_index)

                loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
                start_loss = loss_fct(start_logits, start_positions)
                end_loss = loss_fct(end_logits, end_positions)
                total_loss = (start_loss + end_loss) / 2
        return ((total_loss,) + outputs) if total_loss is not None else outputs

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CustomizedBertForQuestionAnswering.from_pretrained('klue/bert-base')
model = model.to(device)

t_total = len(train_loader) // config.gradient_accumulation_steps * config.epochs
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': config.weight_decay},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate, eps=config.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=t_total*config.warmup_proportion, t_total=t_total)

for idx, d in tqdm(enumerate(train_loader), total=len(train_loader)):
    input_ids = d[0].to(device)
    token_type_ids = d[1].to(device)
    attention_mask = d[2].to(device)
    start_positions = d[3].to(device)
    end_positions = d[4].to(device)
    input_type = d[5].to(device)

    outputs = model(input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions,
                    input_type=input_type,
                    beta=0.01,
                    sigma=0.01)

    loss = outputs[0]
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # Step with optimizer
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

    if idx % 500 == 0:
        print(loss)
    if torch.isnan(loss).any():
        print(input_ids, token_type_ids, attention_mask, start_positions, end_positions, input_type)
        break

Some weights of the model checkpoint at klue/bert-base were not used when initializing CustomizedBertForQuestionAnswering: ['bert.embeddings.position_ids', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing CustomizedBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CustomizedBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Cus

tensor(6.3301, device='cuda:0', grad_fn=<AddBackward0>)


  3%|▎         | 501/15036 [02:22<1:09:10,  3.50it/s]

tensor(2.7735, device='cuda:0', grad_fn=<AddBackward0>)


  7%|▋         | 1001/15036 [04:44<1:06:23,  3.52it/s]

tensor(2.1346, device='cuda:0', grad_fn=<AddBackward0>)


 10%|▉         | 1501/15036 [07:07<1:04:45,  3.48it/s]

tensor(0.8991, device='cuda:0', grad_fn=<AddBackward0>)


 13%|█▎        | 2001/15036 [09:29<1:01:35,  3.53it/s]

tensor(0.7134, device='cuda:0', grad_fn=<AddBackward0>)


 17%|█▋        | 2501/15036 [11:51<59:19,  3.52it/s]

tensor(0.8837, device='cuda:0', grad_fn=<AddBackward0>)


 20%|█▉        | 3001/15036 [14:13<57:00,  3.52it/s]

tensor(0.5410, device='cuda:0', grad_fn=<AddBackward0>)


 23%|██▎       | 3501/15036 [16:35<55:20,  3.47it/s]

tensor(0.1924, device='cuda:0', grad_fn=<AddBackward0>)


 27%|██▋       | 4001/15036 [18:58<52:50,  3.48it/s]

tensor(0.7466, device='cuda:0', grad_fn=<AddBackward0>)


 30%|██▉       | 4501/15036 [21:21<49:53,  3.52it/s]

tensor(1.7122, device='cuda:0', grad_fn=<AddBackward0>)


 33%|███▎      | 5001/15036 [23:44<48:14,  3.47it/s]

tensor(0.9319, device='cuda:0', grad_fn=<AddBackward0>)


 37%|███▋      | 5501/15036 [26:07<45:11,  3.52it/s]

tensor(0.1161, device='cuda:0', grad_fn=<AddBackward0>)


 40%|███▉      | 6001/15036 [28:29<43:07,  3.49it/s]

tensor(0.3111, device='cuda:0', grad_fn=<AddBackward0>)


 43%|████▎     | 6501/15036 [30:53<41:43,  3.41it/s]

tensor(1.2169, device='cuda:0', grad_fn=<AddBackward0>)


 47%|████▋     | 7001/15036 [33:17<38:24,  3.49it/s]

tensor(0.3835, device='cuda:0', grad_fn=<AddBackward0>)


 50%|████▉     | 7501/15036 [35:41<35:58,  3.49it/s]

tensor(0.8966, device='cuda:0', grad_fn=<AddBackward0>)


 53%|█████▎    | 8001/15036 [38:04<33:44,  3.48it/s]

tensor(0.4420, device='cuda:0', grad_fn=<AddBackward0>)


 57%|█████▋    | 8501/15036 [40:27<30:51,  3.53it/s]

tensor(0.8870, device='cuda:0', grad_fn=<AddBackward0>)


 60%|█████▉    | 9001/15036 [42:51<28:37,  3.51it/s]

tensor(1.5651, device='cuda:0', grad_fn=<AddBackward0>)


 63%|██████▎   | 9501/15036 [45:14<26:31,  3.48it/s]

tensor(0.4761, device='cuda:0', grad_fn=<AddBackward0>)


 67%|██████▋   | 10001/15036 [47:37<23:48,  3.52it/s]

tensor(0.4985, device='cuda:0', grad_fn=<AddBackward0>)


 70%|██████▉   | 10501/15036 [49:59<21:40,  3.49it/s]

tensor(0.7667, device='cuda:0', grad_fn=<AddBackward0>)


 73%|███████▎  | 11001/15036 [52:22<19:19,  3.48it/s]

tensor(0.4789, device='cuda:0', grad_fn=<AddBackward0>)


 76%|███████▋  | 11501/15036 [54:48<16:53,  3.49it/s]

tensor(2.1643, device='cuda:0', grad_fn=<AddBackward0>)


 80%|███████▉  | 12001/15036 [57:13<14:29,  3.49it/s]

tensor(0.2914, device='cuda:0', grad_fn=<AddBackward0>)


 83%|████████▎ | 12501/15036 [59:36<12:05,  3.50it/s]

tensor(0.6527, device='cuda:0', grad_fn=<AddBackward0>)


 86%|████████▋ | 13001/15036 [1:01:59<09:39,  3.51it/s]

tensor(0.2810, device='cuda:0', grad_fn=<AddBackward0>)


 90%|████████▉ | 13501/15036 [1:04:21<07:20,  3.49it/s]

tensor(0.5107, device='cuda:0', grad_fn=<AddBackward0>)


 93%|█████████▎| 14001/15036 [1:06:44<04:56,  3.49it/s]

tensor(1.1934, device='cuda:0', grad_fn=<AddBackward0>)


 96%|█████████▋| 14501/15036 [1:09:08<02:31,  3.52it/s]

tensor(1.6701, device='cuda:0', grad_fn=<AddBackward0>)


100%|█████████▉| 15001/15036 [1:11:31<00:09,  3.53it/s]

tensor(0.8176, device='cuda:0', grad_fn=<AddBackward0>)


100%|██████████| 15036/15036 [1:11:41<00:00,  3.50it/s]


In [30]:

t_outputs = model(input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions,
                    input_type=input_type,
                    beta=0.01,
                    sigma=0.01)

tensor(nan, device='cuda:0', grad_fn=<NllLossBackward0>) tensor(nan, device='cuda:0', grad_fn=<NllLossBackward0>) 0.01 tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
