In [1]:
!nvidia-smi

Sat Oct 30 09:45:49 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:00:05.0 Off |                  Off |
| N/A   41C    P0    38W / 250W |      0MiB / 32510MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange

from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)

from typing import List

In [3]:
# 난수 고정
def set_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
    random.seed(random_seed)
    np.random.seed(random_seed)
    
set_seed(42) # magic number :)

In [4]:
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.7.1].
device:[cuda:0].


# Train

## dpr + sparse embedding (wiki)

In [5]:
# Anwer
from retrieval import SparseRetrieval 

class BertEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(BertEncoder, self).__init__(config)

        self.bert = BertModel(config)
        self.init_weights()
      
    def forward(
            self,
            input_ids, 
            attention_mask=None,
            token_type_ids=None
        ): 

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        pooled_output = outputs[1]
        return pooled_output
    
def proprecessing(text):
    new_text = text.replace(r"\n\n", "")
    return new_text


class DenseRetrieval:
    def __init__(self,
        args,
        dataset,
        tokenizer,
        p_encoder,
        q_encoder
    ):
        """
        학습과 추론에 사용될 여러 셋업을 마쳐봅시다.
        """

        self.args = args
        self.train_dataset = dataset

        self.tokenizer = tokenizer
        self.p_encoder = p_encoder
        self.q_encoder = q_encoder
        self.num_neg = 1
        self.prepare_in_batch_negative()
        
    def prepare_in_batch_negative(self,
        train_dataset=None,
        num_neg=1,
        tokenizer=None
    ):
        self.num_neg = num_neg
        
        if train_dataset is None:
            train_dataset = self.train_dataset

        if tokenizer is None:
            tokenizer = self.tokenizer

        batch_size = args.per_device_train_batch_size
        
        # 1. hard-Negative 추가하기
        # sparse embedding -> top1 neg passage
        sparse_retriever =   SparseRetrieval(
                tokenize_fn=tokenizer
            )
        full_ds = concatenate_datasets(
            [
                train_dataset.flatten_indices(),
                # dataset["validation"].flatten_indices(),
            ]
        ) 

        sparse_retriever.get_sparse_embedding()
        df = sparse_retriever.retrieve(full_ds, topk=args.per_device_train_batch_size) # sparse-retriever에서 뽑은 top_k

        p_with_neg = []
        step=0
        
        # CORPUS를 np.array로 변환해줍니다.
        corpus = np.array([proprecessing(example) for example in train_dataset["context"]])
        batch_size = batch_size
        for idx, c in enumerate(train_dataset["context"]):
            if idx % batch_size==0:
                # [step, step+1, .. , step+batch_size]. 
                batch_idxs = np.arange(step*batch_size, min((step+1)*batch_size, len(train_dataset)))
                step +=1
                
            p_neg = list(corpus[batch_idxs])
            
            # sparse-retriever에서 뽑은 top_k indices
            hard_negs = df['context_id'][idx]
            for h in hard_negs:
                if len(p_neg)==batch_size + num_neg: break # 한 배치당 hard_neg num_neg(1)만큼 추가
                
                hard_neg = proprecessing(sparse_retriever.contexts[h]) # sparse-retriever top_k indices -> context
                if hard_neg in p_neg: continue
                p_neg.append(hard_neg)
                
            p_with_neg.extend(p_neg)
                
                
        # 2. (Question, Passage) 데이터셋 만들어주기
        q_seqs = tokenizer(
            train_dataset["question"],
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        p_seqs = tokenizer(
            p_with_neg,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        max_len = p_seqs["input_ids"].size(-1)
        p_seqs["input_ids"] = p_seqs["input_ids"].view(-1, batch_size+num_neg, max_len)
        p_seqs["attention_mask"] = p_seqs["attention_mask"].view(-1, batch_size+num_neg, max_len)
        p_seqs["token_type_ids"] = p_seqs["token_type_ids"].view(-1, batch_size+num_neg, max_len)

        train_dataset = TensorDataset(
            p_seqs["input_ids"], p_seqs["attention_mask"], p_seqs["token_type_ids"], 
            q_seqs["input_ids"], q_seqs["attention_mask"], q_seqs["token_type_ids"]
        )

        self.train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            batch_size=self.args.per_device_train_batch_size
        )


    def train(self, args=None, tokenizer = None):
        if args is None:
            args = self.args
        if tokenizer is None :
            tokenizer = self.tokenizer

        train_dataloader = self.train_dataloader

        no_decay = ["bias" ,"LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.p_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.p_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
            {"params": [p for n, p in self.q_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.q_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            # eps=args.adam_epsilon
        )

        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
        
        global_step = 0

        self.p_encoder.zero_grad()
        self.q_encoder.zero_grad()
        # torch.cuda.empty_cache()

        train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
        self.q_encoder.train()
        self.p_encoder.train()
        for epoch, _ in enumerate(train_iterator):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            # loss_value=0 # Accumulation할 때 진행
            losses = 0
            for step, batch in enumerate(epoch_iterator):
                if torch.cuda.is_available():
                    batch = tuple(t.cuda() for t in batch)

                p_inputs = {
                    "input_ids": batch[0].view(args.per_device_train_batch_size * (args.per_device_train_batch_size + self.num_neg), -1),
                    "attention_mask": batch[1].view(args.per_device_train_batch_size * (args.per_device_train_batch_size + self.num_neg), -1),
                    "token_type_ids": batch[2].view(args.per_device_train_batch_size * (args.per_device_train_batch_size + self.num_neg), -1)
                }       
                
                q_inputs = {'input_ids': batch[3],
                            'attention_mask': batch[4],
                            'token_type_ids': batch[5]}

                p_outputs = self.p_encoder(**p_inputs)  # (batch_size+1, emb_dim)
                # p_outputs = torch.transpose(p_outputs.view(batch_size, self.num_neg+1,-1), 1, 2)
                
                q_outputs = self.q_encoder(**q_inputs)  # (batch_size, emb_dim)
                # print(f'p_outputs {p_outputs.shape}')
                # print(f'q_outputs {q_outputs.shape}')
                
                # Calculate similarity score & loss
                sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1))  # (batch_size, emb_dim) x (emb_dim, batch_size+num_neg) = (batch_size, batch_size+num_neg)
                # print(f'sim_scores {sim_scores.shape}')
                
                # target: position of positive samples = diagonal element 
                targets = torch.arange(0, args.per_device_train_batch_size).long()
                
                if torch.cuda.is_available():
                    targets = targets.to('cuda')

                sim_scores = F.log_softmax(sim_scores, dim=1)

                loss = F.nll_loss(sim_scores, targets)
                losses += loss.item()
                if step % 100 == 0 :
                    print(f'{epoch}epoch loss: {losses/(step+1)}') # Accumulation할 경우 주석처리

                
                #################ACCUMULATION###############################
                # loss_value += loss
                # if (step+1) % args.gradient_accumulation_steps == 0 :
                #     optimizer.step()
                #     scheduler.step()
                #     self.q_encoder.zero_grad()
                #     self.p_encoder.zero_grad()
                #     global_step += 1
                #     print(loss_value/args.gradient_accumulation_steps)
                #     loss_value = 0
                ############################################################
                self.q_encoder.zero_grad()
                self.p_encoder.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                global_step += 1
                
                #torch.cuda.empty_cache()
                del p_inputs, q_inputs

        return self.p_encoder, self.q_encoder

In [7]:

args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=10,
    weight_decay=0.01
)

dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']
model_checkpoint = "klue/bert-base"

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
p_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)
q_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)


Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bi

In [8]:
# Retriever는 아래와 같이 사용할 수 있도록 코드를 짜봅시다.
retriever = DenseRetrieval(
    args=args,
    dataset=train_dataset,
    tokenizer=tokenizer,
    p_encoder=p_encoder,
    q_encoder=q_encoder
)

p_encoder, q_encoder = retriever.train()

Loading cached processed dataset at /opt/ml/data/train_dataset/train/cache-fbc57aa6e699fb0c.arrow


Lengths of unique contexts : 56737
Embedding pickle load.


Sparse retrieval:  24%|██▍       | 940/3952 [00:00<00:00, 9393.81it/s]

[query exhaustive search] done in 41.005 s


Sparse retrieval: 100%|██████████| 3952/3952 [00:00<00:00, 9277.03it/s]
Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

0epoch loss: 64.46937561035156
0epoch loss: 8.713131009942234
0epoch loss: 6.130762319659713
0epoch loss: 5.19589794593
0epoch loss: 4.687264641026903
0epoch loss: 4.384292032904254
0epoch loss: 4.174694070006765
0epoch loss: 4.01329259015354
0epoch loss: 3.895762870671895
0epoch loss: 3.801125058587992


Epoch:  10%|█         | 1/10 [12:54<1:56:07, 774.19s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

1epoch loss: 2.7169995307922363
1epoch loss: 3.0098952869377515
1epoch loss: 3.025026746057159
1epoch loss: 3.025579998263489
1epoch loss: 3.0181673922740906
1epoch loss: 3.0139089939361083
1epoch loss: 3.0099677301682966
1epoch loss: 3.0050437270149524
1epoch loss: 3.0017942578605052
1epoch loss: 3.000652249460083


Epoch:  20%|██        | 2/10 [25:47<1:43:11, 773.90s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

2epoch loss: 2.918689250946045
2epoch loss: 2.9625078475121223
2epoch loss: 2.9789974452251227
2epoch loss: 2.9853754360414424
2epoch loss: 2.9793721095582195
2epoch loss: 2.9775493316307755
2epoch loss: 2.9755210602739686
2epoch loss: 2.973076890436627
2epoch loss: 2.9709773798858032
2epoch loss: 2.967497061942182


Epoch:  30%|███       | 3/10 [38:39<1:30:13, 773.36s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

3epoch loss: 3.1869192123413086
3epoch loss: 2.94145410603816
3epoch loss: 2.9549368910528533
3epoch loss: 2.952835509151319
3epoch loss: 2.9561024764529487
3epoch loss: 2.9558922292705545
3epoch loss: 2.9519385688515154
3epoch loss: 2.948497179061302
3epoch loss: 2.947234131721373
3epoch loss: 2.9461469253345283


Epoch:  40%|████      | 4/10 [51:32<1:17:19, 773.22s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

4epoch loss: 3.032198905944824
4epoch loss: 2.9408191973620124
4epoch loss: 2.9315267593706427
4epoch loss: 2.93070309898782
4epoch loss: 2.9260703274734
4epoch loss: 2.9266225468374776
4epoch loss: 2.930595846223752
4epoch loss: 2.9261756382042945
4epoch loss: 2.927038513616974
4epoch loss: 2.925711704014938


Epoch:  50%|█████     | 5/10 [1:04:25<1:04:25, 773.06s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

5epoch loss: 2.5492897033691406
5epoch loss: 2.904755717456931
5epoch loss: 2.9007016485603296
5epoch loss: 2.911125889648235
5epoch loss: 2.9096349867204774
5epoch loss: 2.910969939774382
5epoch loss: 2.91470556727265
5epoch loss: 2.9141915585957308
5epoch loss: 2.916206397069676
5epoch loss: 2.9171998162116117


Epoch:  60%|██████    | 6/10 [1:17:17<51:31, 772.89s/it]  




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

6epoch loss: 3.044895648956299
6epoch loss: 2.907477570052194
6epoch loss: 2.915490242972303
6epoch loss: 2.909752833882836
6epoch loss: 2.9103252816378626
6epoch loss: 2.9106368280932338
6epoch loss: 2.9081584229048794
6epoch loss: 2.908284685921227
6epoch loss: 2.9079762284972994
6epoch loss: 2.90667826565733


Epoch:  70%|███████   | 7/10 [1:30:09<38:38, 772.72s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

7epoch loss: 2.8771088123321533
7epoch loss: 2.9034086147157274
7epoch loss: 2.9121520021068514
7epoch loss: 2.9078104979176063
7epoch loss: 2.9052108005989816
7epoch loss: 2.904405474424838
7epoch loss: 2.9054545618333356
7epoch loss: 2.905327877542602
7epoch loss: 2.9054859247100486
7epoch loss: 2.9046683176508488


Epoch:  80%|████████  | 8/10 [1:43:01<25:44, 772.44s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

8epoch loss: 2.8463616371154785
8epoch loss: 2.899140546817591
8epoch loss: 2.901031770516391
8epoch loss: 2.90061831553513
8epoch loss: 2.8992427590481955
8epoch loss: 2.8991127061748694
8epoch loss: 2.899466374947902
8epoch loss: 2.9007860226569946
8epoch loss: 2.899254776565323
8epoch loss: 2.8990515458597064


Epoch:  90%|█████████ | 9/10 [1:55:53<12:52, 772.38s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=988.0, style=ProgressStyle(description_wi…

9epoch loss: 2.7852673530578613
9epoch loss: 2.8585264871616176
9epoch loss: 2.875302117855395
9epoch loss: 2.8783825925022266
9epoch loss: 2.8839872406605176
9epoch loss: 2.8866678788038547
9epoch loss: 2.8888835323828825
9epoch loss: 2.890118754029104
9epoch loss: 2.890539919690098
9epoch loss: 2.891009535032689


Epoch: 100%|██████████| 10/10 [2:08:45<00:00, 772.59s/it]







In [None]:
torch.save(q_encoder, '/opt/ml/models/q_encoder.pt')
torch.save(p_encoder, '/opt/ml/models/p_encoder.pt')