In [None]:
import os
import sys
import torch
from transformers import BertTokenizer,squad_convert_examples_to_features, AutoConfig, AutoModelForQuestionAnswering
from transformers.data.processors.squad import SquadV2Processor
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
#from ProbeAttention import *

In [None]:
# --- Model / data --------------------------------------------------------------
model_dir = 'bert-base-uncased'      # NOTE(review): not referenced elsewhere in this notebook
model_prefix = 'bert-base-uncased'   # checkpoint for tokenizer + QA model; also names results/<model_prefix>
data_dir = ''
data_file = 'train-v2.0.json'
max_seq_length = 384                 # tokens per feature; also the Adapter default `max_seq`

# --- Adapter architecture ------------------------------------------------------
res_size = 3            # consecutive layers handled by one ResAdapter group
non_linear = "gelu"     # Adapter activation: one of "relu", "gelu", "tanh"
project_dim = 200       # adapter bottleneck width (MultiHeadAttention default: 4 heads)
layers = 12             # NOTE(review): not referenced elsewhere in this notebook
hidden_dim = 768        # size of each BERT hidden state fed to the adapters

# --- Optimisation --------------------------------------------------------------
epochs = 5
batch_size = 8
adam_epsilon = 1e-8
max_grad_norm = 0.1
dropout_r = 0.3
lr = 3e-5

# Seed torch's default RNG (CPU and all CUDA devices) so adapter initialisation,
# dropout masks and the RandomSampler shuffle are reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [None]:
# Compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')

# Tokenizer for the checkpoint named by `model_prefix`.
tokenizer = BertTokenizer.from_pretrained(model_prefix)

# Read the SQuAD v2.0 training file (`data_file` inside `data_dir`) into examples.
processor = SquadV2Processor()
train_examples = processor.get_train_examples(data_dir=data_dir, filename=data_file)

In [None]:
# Convert the first 50,000 training examples (presumably capped to bound
# preprocessing/training time) into BERT features and a PyTorch dataset.
print("Loading train features")
train_features, train_dataset = squad_convert_examples_to_features(
    examples=train_examples[:50000],
    tokenizer=tokenizer,
    is_training=True,
    max_seq_length=max_seq_length,
    max_query_length=64,
    doc_stride=128,
    threads=1,
    return_dataset="pt",
)

In [None]:
# Pretrained extractive-QA model. output_hidden_states=True makes the forward
# pass also return the per-layer hidden states that the adapters consume.
config = AutoConfig.from_pretrained(model_prefix, output_hidden_states=True)

# Single-device placement (no multi-GPU wrapping); Module.to() returns the module.
model = AutoModelForQuestionAnswering.from_pretrained(model_prefix, config=config).to(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AdamW
import numpy as np

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a residual LayerNorm.

    Args:
        in_dim: feature size of the input; must be divisible by ``n_heads``.
        n_heads: number of attention heads (default 4).

    ``forward(h)`` expects ``h`` of shape (batch, seq, in_dim) and returns a
    tensor of the same shape, ``LayerNorm(h + linear(concat_heads))``.
    No attention mask is applied: every position attends to every position.
    """

    def __init__(self, in_dim, n_heads=4):
        super().__init__()
        assert in_dim % n_heads == 0
        self.n_heads = n_heads
        self.d = in_dim // n_heads          # per-head feature size
        heads_dim = self.n_heads * self.d   # equals in_dim thanks to the assert

        # Creation order (WQ, WK, WV, linear, layer_norm) is fixed so seeded
        # initialisation and saved state_dict keys stay the same.
        self.WQ = nn.Linear(in_dim, heads_dim)
        self.WK = nn.Linear(in_dim, heads_dim)
        self.WV = nn.Linear(in_dim, heads_dim)
        self.linear = nn.Linear(heads_dim, in_dim)
        self.layer_norm = nn.LayerNorm(in_dim)

    def _split_heads(self, x, batch):
        # (batch, seq, heads*d) -> (batch, heads, seq, d)
        return x.view(batch, -1, self.n_heads, self.d).transpose(1, 2)

    def forward(self, h):
        batch = h.shape[0]
        queries = self._split_heads(self.WQ(h), batch)
        keys = self._split_heads(self.WK(h), batch)
        values = self._split_heads(self.WV(h), batch)

        # Attention weights over key positions: (batch, heads, seq, seq).
        weights = F.softmax(queries @ keys.transpose(-1, -2) / np.sqrt(self.d), dim=-1)

        # Weighted values, heads merged back: (batch, seq, heads*d).
        merged = (weights @ values).transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.n_heads * self.d)
        return self.layer_norm(self.linear(merged) + h)
        

class Adapter(nn.Module):
    """Bottleneck adapter with self-attention inside the bottleneck.

    Pipeline (input and output feature size ``in_dim``):
        project_down -> BatchNorm1d -> MultiHeadAttention -> activation
        -> dropout -> project_up -> LayerNorm

    Args:
        in_dim: feature size of the incoming hidden states.
        project_dim: bottleneck width; must be divisible by the attention's
            head count (MultiHeadAttention default: 4).
        non_linear: activation name, one of "relu", "gelu", "tanh".
        p: dropout probability (defaults to the notebook's ``dropout_r``).
        max_seq: sequence length, used as the BatchNorm1d channel count
            (defaults to the notebook's ``max_seq_length``).
    """

    def __init__(self, in_dim, project_dim, non_linear, p = dropout_r, max_seq = max_seq_length):
        super(Adapter, self).__init__()
        activations = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh}
        assert non_linear in activations
        self.non_linear = activations[non_linear]()

        self.project_down = nn.Linear(in_dim, project_dim)
        self.project_up = nn.Linear(project_dim, in_dim)
        self.dropout = nn.Dropout(p=p)
        # BatchNorm1d treats dim 1 of a 3-D input as the channel axis. Assuming
        # h is (batch, seq, dim) — TODO confirm — this normalises each sequence
        # position separately and requires seq == max_seq.
        self.batchnorm = nn.BatchNorm1d(max_seq)
        # LayerNorm's second positional argument is `eps`; passing max_seq there
        # made eps=384. Use the default eps, normalising over the feature axis.
        self.layernorm = nn.LayerNorm(in_dim)
        self.attention = MultiHeadAttention(project_dim)

    def forward(self, h):
        h = self.batchnorm(self.project_down(h))
        h = self.attention(h)
        # Apply the activation selected by `non_linear` before dropout/up-projection.
        h = self.dropout(self.non_linear(h))
        return self.layernorm(self.project_up(h))

class ResAdapter(nn.Module):
    """Chain of ``res_size`` adapters over a group of consecutive layer states.

    ``forward(h_list, h_last=None)``:
        h_list: sequence of at least ``res_size`` same-shaped tensors, one per
            layer in this group.
        h_last: running state carried over from the previous group; a zero
            tensor like ``h_list[0]`` (same shape/dtype/device) when omitted.

    With A_i the i-th adapter and h the carried state, it computes
        h = A_0(h_list[0] + h) + h_list[1]
        h = A_i(h) + h_list[i + 1]      for 0 < i < res_size - 1
        h = A_last(h) + h_list[0]
    and returns the final h.

    Raises:
        ValueError: if ``res_size < 2`` (the first step needs ``h_list[1]``).
    """

    def __init__(self, in_dim, project_dim, non_linear, res_size, max_seq=max_seq_length):
        super(ResAdapter, self).__init__()
        # The first step adds h_list[1], so a chain needs at least two layers;
        # fail at construction instead of with an IndexError inside forward().
        if res_size < 2:
            raise ValueError("ResAdapter requires res_size >= 2, got {}".format(res_size))
        self.res_size = res_size
        # Forward max_seq so each Adapter's BatchNorm matches the configured length.
        self.adapter_list = nn.ModuleList(
            [Adapter(in_dim, project_dim, non_linear, max_seq=max_seq) for _ in range(res_size)])

    def forward(self, h_list, h_last=None):
        # zeros_like keeps dtype/device consistent with the inputs (no global `device`).
        h = torch.zeros_like(h_list[0]) if h_last is None else h_last
        last = self.res_size - 1

        for i, adapter in enumerate(self.adapter_list):
            if i == 0:
                h = adapter(h_list[0] + h) + h_list[1]
            elif i == last:
                h = adapter(h) + h_list[0]
            else:
                h = adapter(h) + h_list[i + 1]

        return h

class ResAdapterModel(nn.Module):
    """Stack of ResAdapter groups over per-layer hidden states plus a 1-unit scorer.

    Args:
        in_dim: feature size of each layer's hidden state.
        project_dim, non_linear: forwarded to every Adapter.
        res_size: layers per ResAdapter group; must divide ``n_layers``.
        max_seq: sequence length forwarded to each ResAdapter.
        n_layers: number of hidden-state layers consumed (default 12).

    ``forward(all_h)`` takes a sequence of at least ``n_layers`` same-shaped
    tensors, presumably (batch, seq, in_dim) — TODO confirm. Layers
    [g*res_size, (g+1)*res_size) feed group g, carrying the state from group
    g-1. It returns ``linear(h).unsqueeze(0)``, i.e. shape
    (1, *all_h[0].shape[:-1], 1).
    """

    def __init__(self, in_dim, project_dim, non_linear, res_size, max_seq, n_layers=12):
        super(ResAdapterModel, self).__init__()
        assert (n_layers % res_size == 0)

        self.res_size = res_size
        self.n_layers = n_layers
        self.res_list = nn.ModuleList([
            ResAdapter(in_dim, project_dim, non_linear, res_size, max_seq=max_seq)
            for _ in range(n_layers // res_size)
        ])
        self.linear = nn.Linear(in_dim, 1)

    def forward(self, all_h):
        if len(all_h) < self.n_layers:
            raise ValueError("expected at least {} hidden states, got {}".format(self.n_layers, len(all_h)))

        # zeros_like keeps dtype/device consistent with the inputs (no global `device`).
        h = torch.zeros_like(all_h[0])
        for group, res_adapter in enumerate(self.res_list):
            start = group * self.res_size
            h = res_adapter(list(all_h[start:start + self.res_size]), h)
        return self.linear(h).unsqueeze(0)

In [None]:
# Two independent adapter stacks: adaptor_s scores answer-start tokens and
# adaptor_e scores answer-end tokens. They are built in that order, and
# Module.to() returns the module, so each is placed on `device` as it is created.
print("Initializing adaptors")
adaptor_s, adaptor_e = (
    ResAdapterModel(hidden_dim, project_dim, non_linear, res_size, max_seq_length).to(device)
    for _ in range(2)
)

# One AdamW optimiser per stack, with bias correction disabled.
start_optimizer, end_optimizer = (
    AdamW(adaptor.parameters(), lr=lr, eps=adam_epsilon, correct_bias=False)
    for adaptor in (adaptor_s, adaptor_e)
)

In [None]:
# Lowest single-batch training loss seen so far; a checkpoint is written each
# time it improves. Starts at +inf so the first batch always sets it.
min_loss = float("inf")

# Start/end hidden-state stores. NOTE(review): not read anywhere in this notebook.
start_hidden = {}
end_hidden = {}

# Create results/<model_prefix>/. makedirs also creates intermediate directories,
# which hub ids containing a slash (e.g. "org/model") need; os.mkdir would fail.
os.makedirs(os.path.join("results", model_prefix), exist_ok=True)

In [None]:
# Training epochs. BERT stays frozen (eval mode, no_grad, not in any optimiser);
# only the two adapter stacks are trained on its per-layer hidden states.
model.eval()

for epoch in range(epochs):

    # Per-epoch checkpoint directory: results/<model_prefix>/epoch_<n>
    epoch_dir = os.path.join("results", model_prefix, "epoch_" + str(epoch + 1))
    os.makedirs(epoch_dir, exist_ok=True)

    print("Training epoch: {}".format(epoch+1))
    adaptor_s.train()
    adaptor_e.train()

    # Track epoch loss
    epoch_loss = 0.0
    loss_list = []

    # Fresh shuffle every epoch
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)

    for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration"), start=1):

        # Move the batch to the device and build the BERT input dict
        batch = tuple(t.to(device) for t in batch)
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
            "start_positions": batch[3],
            "end_positions": batch[4],
        }

        # Frozen BERT forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # outputs[3] holds the hidden states (loss and start/end logits come first
        # because positions are passed); [1:] drops the embedding-layer output.
        all_layer_hidden_states = list(outputs[3][1:])

        # Adapter output is (1, batch, seq, 1). Squeeze only those singleton axes so
        # a final batch of size 1 keeps its batch axis: result is (batch, seq).
        s_scores = adaptor_s(all_layer_hidden_states).squeeze(0).squeeze(-1)
        e_scores = adaptor_e(all_layer_hidden_states).squeeze(0).squeeze(-1)

        # Clamp the *target positions* (not the logits) into [0, seq]; anything
        # mapped to `ignored_index` is ignored by the loss.
        ignored_index = s_scores.size(1)
        start_targets = batch[3].clamp(0, ignored_index)
        end_targets = batch[4].clamp(0, ignored_index)

        criterion = nn.CrossEntropyLoss(weight=None, ignore_index=ignored_index)
        start_loss = criterion(s_scores, start_targets)
        end_loss = criterion(e_scores, end_targets)

        # Weighted mean: start loss counts 1.5x relative to end loss.
        loss = (1.5 * start_loss + end_loss) / 2.5

        loss.backward()
        torch.nn.utils.clip_grad_norm_(adaptor_s.parameters(), max_grad_norm)
        torch.nn.utils.clip_grad_norm_(adaptor_e.parameters(), max_grad_norm)
        start_optimizer.step()
        end_optimizer.step()
        adaptor_s.zero_grad()
        adaptor_e.zero_grad()

        batch_loss = loss.item()
        loss_list.append(batch_loss)
        epoch_loss += batch_loss

        # Every 100 steps: show the loss curve (as a line, then as points) and
        # write a periodic checkpoint.
        if step % 100 == 0:
            for fmt in ("-", "o"):
                fig, ax = plt.subplots()
                ax.plot(loss_list, fmt)
                ax.set(title="Training loss (epoch {})".format(epoch + 1), xlabel="step", ylabel="loss")
                plt.show()
                plt.close(fig)
            torch.save(adaptor_s.state_dict(), os.path.join(epoch_dir, "_start_idx_per100"))
            torch.save(adaptor_e.state_dict(), os.path.join(epoch_dir, "_end_idx_per100"))

        # Checkpoint whenever the single-batch loss reaches a new minimum
        if batch_loss < min_loss:
            torch.save(adaptor_s.state_dict(), os.path.join(epoch_dir, "_start_idx"))
            torch.save(adaptor_e.state_dict(), os.path.join(epoch_dir, "_end_idx"))
            min_loss = batch_loss

    print("Epoch loss {}".format(epoch_loss))
    print("Epoch {} complete, saving probes".format(epoch+1))
    torch.save(adaptor_s.state_dict(), os.path.join(epoch_dir, "_start_idx_epoch_end"))
    torch.save(adaptor_e.state_dict(), os.path.join(epoch_dir, "_end_idx_epoch_end"))