## 0. Install, import

In [None]:
# Notebook shell commands: install the dataset, model and experiment-tracking dependencies.
!pip install datasets
#!pip install transformers
# transformers is installed from the GitHub repository instead of the PyPI release.
!pip install git+https://github.com/huggingface/transformers
!pip install wandb --upgrade
# Authenticate the Weights & Biases CLI before a run is started.
!wandb login
import wandb
import os
# Start a wandb run under the "dpr" project; the wandb.log calls in the training cell report to it.
wandb.init(project="dpr")
# NOTE(review): these variables are set after wandb.init() and look like settings read by the
# Hugging Face Trainer's wandb integration -- confirm they have any effect with the hand-written
# training loop used in this notebook.
os.environ['WANDB_LOG_MODEL'] = 'true' #false by default
os.environ['WANDB_WATCH'] = 'all'


from datasets import load_dataset, load_from_disk, concatenate_datasets
from transformers import AutoTokenizer
import numpy as np
from tqdm.notebook import tqdm,trange
import argparse
import random
import torch
import torch.nn.functional as F
from transformers import BertModel, BertPreTrainedModel, AdamW, TrainingArguments, get_linear_schedule_with_warmup
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)

## 1. Parameters

In [None]:
# Pretrained checkpoint name passed to both the tokenizer and the two encoders.
model_checkpoint = "bert-base-multilingual-cased"
# Seed passed to seed_everything() and to TrainingArguments.
seed = 2021


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and CUDA
    generators) with ``seed``, then puts cuDNN into deterministic mode with
    benchmarking disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU runs
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class BertEncoder(BertPreTrainedModel):
    """BERT wrapper that yields one vector per input sequence.

    ``forward`` returns element 1 of the wrapped ``BertModel`` output (the
    pooled output), which the notebook uses as the sequence embedding.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run BERT on the batch and return its pooled output."""
        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return bert_outputs[1]

# Fix all RNG seeds before any model or dataloader is created.
seed_everything(seed)

# Hyper-parameters are kept in a TrainingArguments object and read from it
# as plain attributes by the cells below.
args = TrainingArguments(
    output_dir="/opt/ml/input/dpr",
    seed=seed,
    # optimisation
    num_train_epochs=10,
    learning_rate=2e-5,
    weight_decay=0.01,
    gradient_accumulation_steps=2,
    # batching
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # data loading
    dataloader_drop_last=True,
    dataloader_num_workers=4,
)

# Tokenizer matching the pretrained checkpoint used for the encoders.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)



## 2. Dataset

In [None]:
# Sources: the "squad_kor_v1" dataset from the hub plus a dataset saved on local disk.
dataset = load_dataset("squad_kor_v1")
mrc_dataset = load_from_disk('/opt/ml/input/data/data/train_dataset')
# Drop the two columns by name before the splits are given the hub dataset's features.
mrc_dataset = mrc_dataset.remove_columns(['__index_level_0__', 'document_id'])

# Give each local split the features of the matching hub split, then put the
# local rows in front of the hub rows.
mrc_dataset_train = mrc_dataset['train'].map(features=dataset['train'].features)
mrc_dataset_validation = mrc_dataset['validation'].map(features=dataset['validation'].features)
dataset['train'] = concatenate_datasets([mrc_dataset_train, dataset['train']])
dataset['validation'] = concatenate_datasets([mrc_dataset_validation, dataset['validation']])

# Every text is padded to the maximum length, truncated, and returned as
# PyTorch tensors.
tokenize_options = {"padding": "max_length", "truncation": True, "return_tensors": 'pt'}
# Order of the tensors inside each TensorDataset row: the three passage
# tensors first, then the three question tensors.
tensor_fields = ('input_ids', 'attention_mask', 'token_type_ids')

training_dataset = dataset['train']
q_seqs = tokenizer(training_dataset['question'], **tokenize_options)
p_seqs = tokenizer(training_dataset['context'], **tokenize_options)
train_dataset = TensorDataset(
    *[p_seqs[name] for name in tensor_fields],
    *[q_seqs[name] for name in tensor_fields],
)

validate_dataset = dataset['validation']
v_q_seqs = tokenizer(validate_dataset['question'], **tokenize_options)
v_p_seqs = tokenizer(validate_dataset['context'], **tokenize_options)
valid_dataset = TensorDataset(
    *[v_p_seqs[name] for name in tensor_fields],
    *[v_q_seqs[name] for name in tensor_fields],
)

In [None]:
# Notebook cell output: show the merged training split.
training_dataset

In [None]:
# Notebook cell output: show the merged validation split.
validate_dataset

In [None]:
# Passage and question encoders are two separate models loaded from the same
# pretrained checkpoint.
p_encoder = BertEncoder.from_pretrained(model_checkpoint)
q_encoder = BertEncoder.from_pretrained(model_checkpoint)

# Move both encoders to the GPU when one is available; otherwise they stay on CPU.
use_gpu = torch.cuda.is_available()
for encoder in (p_encoder, q_encoder):
    if use_gpu:
        encoder.cuda()

print("done")

## 3. Train

In [None]:
# Dataloaders
# Training batches are reshuffled every epoch. Incomplete final batches are
# dropped in both loaders so every in-batch similarity matrix is square and has
# the configured batch size.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.per_device_train_batch_size,
    drop_last=True,
    shuffle=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=args.per_device_eval_batch_size,
    drop_last=True,
)

# Batches follow the encoders: GPU when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The loop below calls optimizer.step() / scheduler.step() once per batch (it
# does not accumulate gradients), so the schedule horizon must be
# batches-per-epoch * epochs. Dividing by args.gradient_accumulation_steps here
# would make the linear decay reach a learning rate of 0 halfway through
# training.
t_total = len(train_dataloader) * int(args.num_train_epochs)
optimizer = AdamW(
    [
        {'params': p_encoder.parameters()},
        {'params': q_encoder.parameters()},
    ],
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=args.warmup_steps,
    num_training_steps=t_total,
)

# Run validation every `eval_every` optimizer steps.
eval_every = 5000
best_p_path = "/opt/ml/input/dpr/best_p_model.pth"
best_q_path = "/opt/ml/input/dpr/best_q_model.pth"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(best_p_path), exist_ok=True)


def _evaluate_in_batch(p_model, q_model, loader, device):
    """Score in-batch retrieval on ``loader``.

    For every batch, each question is scored against all passages of the same
    batch; the passage at the same index is the positive.

    Returns:
        (accuracy, mean_loss): fraction of questions whose top-scoring passage
        is its own, and the NLL loss averaged over batches. ``(0.0, nan)`` is
        returned when the loader yields no batch.
    """
    p_model.eval()
    q_model.eval()

    loss_sum = 0.0
    n_batches = 0
    n_correct = 0
    n_examples = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation", leave=False):
            batch = tuple(t.to(device) for t in batch)
            p_emb = p_model(input_ids=batch[0],
                            attention_mask=batch[1],
                            token_type_ids=batch[2])
            q_emb = q_model(input_ids=batch[3],
                            attention_mask=batch[4],
                            token_type_ids=batch[5])

            scores = F.log_softmax(torch.matmul(q_emb, p_emb.transpose(0, 1)), dim=1)
            targets = torch.arange(scores.shape[0], device=scores.device)

            # nll_loss already averages over the batch, so the running sum is
            # normalised by the number of batches, not the number of examples.
            loss_sum += F.nll_loss(scores, targets).item()
            n_batches += 1
            n_correct += (scores.argmax(dim=1) == targets).sum().item()
            n_examples += scores.shape[0]

    if n_batches == 0:
        return 0.0, float("nan")
    return n_correct / n_examples, loss_sum / n_batches


# Start training!
global_step = 0
best_acc = 0.0
best_step = 0

p_encoder.zero_grad()
q_encoder.zero_grad()
torch.cuda.empty_cache()

# trange is already a progress bar; wrapping it in tqdm again would draw two.
train_iterator = trange(int(args.num_train_epochs), desc="Epoch")

for _ in train_iterator:
    # Wrap the dataloader itself so the bar knows the number of batches.
    for batch in tqdm(train_dataloader, desc="Iteration", leave=False):
        # Validation switches the encoders to eval mode; switch back each step.
        q_encoder.train()
        p_encoder.train()

        batch = tuple(t.to(device) for t in batch)

        p_inputs = {'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2]}
        q_inputs = {'input_ids': batch[3],
                    'attention_mask': batch[4],
                    'token_type_ids': batch[5]}

        p_outputs = p_encoder(**p_inputs)  # (batch_size, emb_dim)
        q_outputs = q_encoder(**q_inputs)  # (batch_size, emb_dim)

        # Similarity of every question with every passage in the batch:
        # (batch_size, emb_dim) x (emb_dim, batch_size) = (batch_size, batch_size)
        sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1))

        # The positive passage of question i sits at index i (the diagonal).
        targets = torch.arange(sim_scores.shape[0], device=sim_scores.device)

        sim_scores = F.log_softmax(sim_scores, dim=1)
        loss = F.nll_loss(sim_scores, targets)

        loss.backward()
        optimizer.step()
        scheduler.step()
        q_encoder.zero_grad()
        p_encoder.zero_grad()
        global_step += 1

        if global_step % eval_every == 0:
            valid_acc, avg_valid_loss = _evaluate_in_batch(
                p_encoder, q_encoder, valid_loader, device
            )
            # Loss of the most recent training batch only.
            train_loss = loss.item()

            print( f"train loss:{train_loss} eval acc:{valid_acc} avg_val_loss:{avg_valid_loss}")
            wandb.log({"train_loss":train_loss, 'eval acc': valid_acc, 'avg_val_loss': avg_valid_loss})

            if valid_acc > best_acc:
                best_acc = valid_acc
                best_step = global_step
                print(f"best_acc: {best_acc} ,best_step_saved:{best_step}")
                wandb.log({"best_acc":best_acc, 'best_step_saved loss': best_step})

                torch.save(p_encoder.state_dict(), best_p_path)
                torch.save(q_encoder.state_dict(), best_q_path)

            # Release cached GPU blocks after the validation pass instead of
            # after every training step.
            torch.cuda.empty_cache()

print("done")

In [None]:
# Persist the encoders' current weights (the state they are in when this cell
# runs), in files separate from the best_*_model.pth checkpoints written during
# training.
final_checkpoints = (
    (p_encoder, f"/opt/ml/input/dpr/best_p_model_last.pth"),
    (q_encoder, f"/opt/ml/input/dpr/best_q_model_last.pth"),
)
for encoder, checkpoint_path in final_checkpoints:
    torch.save(encoder.state_dict(), checkpoint_path)
print("done")
