In [1]:
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange

from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)

from typing import List
from torch.utils.data import Sampler

In [2]:
# 난수 고정
def set_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
    random.seed(random_seed)
    np.random.seed(random_seed)
    
set_seed(42) # magic number :)

In [3]:
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.7.1].
device:[cuda:0].


## Tokenizer 체크

In [15]:
model_checkpoint = 'klue/bert-base'

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# num_added_toks = tokenizer.add_tokens(['\\n']) # 32000번째로 삽입

In [7]:
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

In [None]:
from tqdm import tqdm

unk_count = []
for i in tqdm(range(len(train_dataset))) :
    unk_count.append(tokenizer(train_dataset['context'][i], max_length = True)['input_ids'].count(1)) # 1 == UNK

In [None]:
print(np.median(unk_count))
print(np.mean(unk_count))
print(np.max(unk_count))
print(np.min(unk_count))

## Training


In [4]:
class BertEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(BertEncoder, self).__init__(config)

        self.bert = BertModel(config)
        self.init_weights()
      
    def forward(
            self,
            input_ids, 
            attention_mask=None,
            token_type_ids=None
        ): 

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        pooled_output = outputs[1]
        return pooled_output

In [5]:
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

In [6]:
from datasets import Dataset
dataset = pd.read_csv('/opt/ml/data/train_dataset/Aug_Encoder.csv')
dataset = Dataset.from_pandas(dataset)

In [7]:
class CustomSampler(Sampler) :
    def __init__(self, data_source, batch_size) :
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self) :
        n = len(self.data_source)
        index_list = []
        while True :
            out = True
            for i in range(self.batch_size) :
                tmp_data = random.randint(0, n-1)
                index_list.append(tmp_data)
            for f, s in zip(index_list, index_list[1:]) :
                if abs(s-f) <= 2 :
                    out = False
            if out == True :
                break

        while True : # 추가 삽입
            tmp_data = random.randint(0, n-1)
            if (tmp_data not in index_list) and \
                (abs(tmp_data-index_list[-i]) > 2 for i in range(1,self.batch_size+1)) \
            : 
                index_list.append(tmp_data)
            if len(index_list) == n :
                break
        return iter(index_list)

    def __len__(self) :
        return len(self.data_source)

In [8]:
# Anwer
class DenseRetrieval:
    def __init__(self,
        args,
        dataset,
        tokenizer,
        p_encoder,
        q_encoder,
        sampler
    ):
        """
        학습과 추론에 사용될 여러 셋업을 마쳐봅시다.
        """

        self.args = args
        self.dataset = dataset

        self.tokenizer = tokenizer
        self.p_encoder = p_encoder
        self.q_encoder = q_encoder
        self.sampler = sampler

    def train(self, args=None, tokenizer = None):
        if args is None:
            args = self.args
        if tokenizer is None :
            tokenizer = self.tokenizer

        for i in (range(len(self.dataset))) :
            if i == 0 :
                q_seqs = tokenizer(
                    self.dataset[i]['question'],
                    padding='max_length',
                    max_length=512,
                    truncation = True,
                    return_tensors='pt'
                )
                p_seqs = tokenizer(
                    self.dataset[i]['context'],
                    truncation = True,
                    stride = 128,
                    padding='max_length',
                    max_length=512,
                    return_overflowing_tokens=True,
                    return_offsets_mapping=True,
                    return_tensors='pt'
                )

                p_seqs.pop('overflow_to_sample_mapping')
                p_seqs.pop('offset_mapping')

                for k in q_seqs.keys() :
                    q_seqs[k] = q_seqs[k].tolist()
                    p_seqs[k] = p_seqs[k].tolist()
                
                initial_p_seqs_length = len(p_seqs['input_ids'])
                for j in range(len(p_seqs['input_ids'])) :
                    q_seqs['input_ids'].append(q_seqs['input_ids'][0])
                    q_seqs['token_type_ids'].append(q_seqs['token_type_ids'][0])
                    q_seqs['attention_mask'].append(q_seqs['attention_mask'][0])
                    p_seqs['input_ids'].append(p_seqs['input_ids'][j])
                    p_seqs['token_type_ids'].append(p_seqs['token_type_ids'][j])
                    p_seqs['attention_mask'].append(p_seqs['attention_mask'][j])

                q_seqs['input_ids'] = q_seqs['input_ids'][1:]
                q_seqs['token_type_ids'] = q_seqs['token_type_ids'][1:]
                q_seqs['attention_mask'] = q_seqs['attention_mask'][1:]
                p_seqs['input_ids'] = p_seqs['input_ids'][initial_p_seqs_length:]
                p_seqs['token_type_ids'] = p_seqs['token_type_ids'][initial_p_seqs_length:]
                p_seqs['attention_mask'] = p_seqs['attention_mask'][initial_p_seqs_length:]

            else :
                tmp_q_seq = tokenizer(
                    self.dataset[i]['question'],
                    padding='max_length',
                    max_length=512,
                    truncation = True,
                    return_tensors='pt')
                tmp_p_seq = tokenizer(
                    self.dataset[i]['context'],
                    truncation = True,
                    stride = 128,
                    padding='max_length',
                    max_length=512,
                    return_overflowing_tokens=True,
                    return_offsets_mapping=True,
                    return_tensors='pt'
                )

                tmp_p_seq.pop('overflow_to_sample_mapping')
                tmp_p_seq.pop('offset_mapping')

                for k in tmp_p_seq.keys() :
                    tmp_p_seq[k] = tmp_p_seq[k].tolist()
                    tmp_q_seq[k] = tmp_q_seq[k].tolist()


                for j in range(len(tmp_p_seq['input_ids'])) :
                    q_seqs['input_ids'].append(tmp_q_seq['input_ids'][0])
                    q_seqs['token_type_ids'].append(tmp_q_seq['token_type_ids'][0])
                    q_seqs['attention_mask'].append(tmp_q_seq['attention_mask'][0])
                    p_seqs['input_ids'].append(tmp_p_seq['input_ids'][j])
                    p_seqs['token_type_ids'].append(tmp_p_seq['token_type_ids'][j])
                    p_seqs['attention_mask'].append(tmp_p_seq['attention_mask'][j])

        for k in q_seqs.keys() :
            q_seqs[k] = torch.tensor(q_seqs[k])
            p_seqs[k] = torch.tensor(p_seqs[k])

        train_dataset = TensorDataset(p_seqs['input_ids'], p_seqs['attention_mask'], p_seqs['token_type_ids'], 
                        q_seqs['input_ids'], q_seqs['attention_mask'], q_seqs['token_type_ids'])
        sampler = self.sampler(train_dataset, args.per_device_train_batch_size)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.per_device_train_batch_size,
                                      sampler = sampler)

        no_decay = ["bias" ,"LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.p_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.p_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
            {"params": [p for n, p in self.q_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.q_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            # eps=args.adam_epsilon
        )

        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
        
        global_step = 0

        self.p_encoder.zero_grad()
        self.q_encoder.zero_grad()
        # torch.cuda.empty_cache()

        train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
        self.q_encoder.train()
        self.p_encoder.train()
        for epoch, _ in enumerate(train_iterator):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            loss_value=0 # Accumulation할 때 진행
            losses = 0
            for step, batch in enumerate(epoch_iterator):
                if torch.cuda.is_available():
                    batch = tuple(t.cuda() for t in batch)

                p_inputs = {'input_ids': batch[0],
                            'attention_mask': batch[1],
                            'token_type_ids': batch[2]
                            }
                
                q_inputs = {'input_ids': batch[3],
                            'attention_mask': batch[4],
                            'token_type_ids': batch[5]}

                p_outputs = self.p_encoder(**p_inputs)  # (batch_size, emb_dim)
                q_outputs = self.q_encoder(**q_inputs)  # (batch_size, emb_dim)

                # Calculate similarity score & loss
                sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1))  # (batch_size, emb_dim) x (emb_dim, batch_size) = (batch_size, batch_size)

                # target: position of positive samples = diagonal element 
                # targets = torch.arange(0, args.per_device_train_batch_size).long()
                targets = torch.arange(0, len(p_inputs['input_ids'])).long()
                
                if torch.cuda.is_available():
                    targets = targets.to('cuda')
                    
                sim_scores = F.log_softmax(sim_scores, dim=1)

                loss = F.nll_loss(sim_scores, targets)
                losses += loss.item()
                if step % 100 == 0 :
                    print(f'{epoch}epoch loss: {losses/(step+1)}') # Accumulation할 경우 주석처리

                # ################ACCUMULATION###############################
                # if (step+1) % args.gradient_accumulation_steps == 0 :
                #     optimizer.step()
                #     scheduler.step()
                #     self.q_encoder.zero_grad()
                #     self.p_encoder.zero_grad()

                # losses += loss.item()
                # if (step+1) % 64 == 0:
                #     train_loss = losses / 64
                #     print(f'training loss: {train_loss:4.4}')
                #     losses = 0
                # ###########################################################
                self.q_encoder.zero_grad()
                self.p_encoder.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # global_step += 1
                
                #torch.cuda.empty_cache()
                del p_inputs, q_inputs

        return self.p_encoder, self.q_encoder

In [9]:
args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=20,
    weight_decay=0.01
)
model_checkpoint = "klue/bert-base"

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# num_added_toks = tokenizer.add_tokens(['\\n']) # 32000번째로 삽입
p_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)
q_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.de

In [10]:
# sampler = CustomSampler()
# Retriever는 아래와 같이 사용할 수 있도록 코드를 짜봅시다.
retriever = DenseRetrieval(
    args=args,
    dataset=dataset,
    tokenizer=tokenizer,
    p_encoder=p_encoder,
    q_encoder=q_encoder,
    sampler = CustomSampler
)
p_encoder, q_encoder = retriever.train()

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

0epoch loss: 18.882301330566406
0epoch loss: 2.747803631681248
0epoch loss: 1.8054619641900433
0epoch loss: 1.4123133636636955
0epoch loss: 1.1848713733812755
0epoch loss: 1.0417346696662655
0epoch loss: 0.9306071474814306
0epoch loss: 0.8390770195624015
0epoch loss: 0.7728449736591047
0epoch loss: 0.7189405193238176
0epoch loss: 0.6735993443199785
0epoch loss: 0.6318205580153833
0epoch loss: 0.6004730541335069
0epoch loss: 0.5747150797251602
0epoch loss: 0.5516593690286893
0epoch loss: 0.5282819543282045
0epoch loss: 0.509971809012273
0epoch loss: 0.49323402389324744
0epoch loss: 0.47776585388611437
0epoch loss: 0.4650395929750753
0epoch loss: 0.4524701598909352
0epoch loss: 0.44089864972736387
0epoch loss: 0.42957448807852655
0epoch loss: 0.419596689456713
0epoch loss: 0.41118150971156925
0epoch loss: 0.4011327781237818
0epoch loss: 0.3935750661385359


Epoch:   5%|▌         | 1/20 [49:01<15:31:19, 2941.01s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

1epoch loss: 0.7511616349220276
1epoch loss: 0.16321438028117513
1epoch loss: 0.14410351590017106
1epoch loss: 0.1356538531519461
1epoch loss: 0.12847854031590708
1epoch loss: 0.13271724936847915
1epoch loss: 0.12802190902596805
1epoch loss: 0.1296040934319382
1epoch loss: 0.1321448438298585
1epoch loss: 0.1359717246027125
1epoch loss: 0.13792185790834746
1epoch loss: 0.14020746292643907
1epoch loss: 0.13952978648308753
1epoch loss: 0.1402367392770272
1epoch loss: 0.1403175803218403
1epoch loss: 0.13982271013818287
1epoch loss: 0.13844243153770702
1epoch loss: 0.1372594161404947
1epoch loss: 0.13693721700956035
1epoch loss: 0.13726230938535916
1epoch loss: 0.13675700249016076
1epoch loss: 0.13692868274312436
1epoch loss: 0.13826964409903858
1epoch loss: 0.13765156910536266
1epoch loss: 0.13666977752814025
1epoch loss: 0.13681827195988003
1epoch loss: 0.13593342950848566


Epoch:  10%|█         | 2/20 [1:49:04<15:41:54, 3139.72s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

2epoch loss: 0.007763679139316082
2epoch loss: 0.08825566637226596
2epoch loss: 0.08698202444232342
2epoch loss: 0.09124828076237222
2epoch loss: 0.08716737583912201
2epoch loss: 0.08735670023163253
2epoch loss: 0.08903909053254645
2epoch loss: 0.08899323599065202
2epoch loss: 0.09122958182966288
2epoch loss: 0.09238426927950018
2epoch loss: 0.09108896682698435
2epoch loss: 0.09041132712712671
2epoch loss: 0.0915475081301921
2epoch loss: 0.09185384694505602
2epoch loss: 0.09113718145681916
2epoch loss: 0.09032655353261965
2epoch loss: 0.09212168157090589
2epoch loss: 0.0910065274064731
2epoch loss: 0.09078199004754327
2epoch loss: 0.09104163705231604
2epoch loss: 0.09191796127064547
2epoch loss: 0.09195659400834828
2epoch loss: 0.09104120861186132
2epoch loss: 0.09162927383335344
2epoch loss: 0.09131146227044777
2epoch loss: 0.09082713832504305
2epoch loss: 0.09106869321609427


Epoch:  15%|█▌        | 3/20 [2:38:40<14:35:40, 3090.63s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

3epoch loss: 0.04122619330883026
3epoch loss: 0.052865210643581614
3epoch loss: 0.058364454257168484
3epoch loss: 0.06478566712992513
3epoch loss: 0.06682240479546053
3epoch loss: 0.06535700859069743
3epoch loss: 0.06504507717539743
3epoch loss: 0.06541823410249206
3epoch loss: 0.06412726471012972
3epoch loss: 0.062160380532382244
3epoch loss: 0.0632121975494457
3epoch loss: 0.06482717621534682
3epoch loss: 0.06354080632016479
3epoch loss: 0.06465575388410015
3epoch loss: 0.06468484456379119
3epoch loss: 0.06573525933454571
3epoch loss: 0.06722514709698255
3epoch loss: 0.06772213765178256
3epoch loss: 0.06769718890470991
3epoch loss: 0.06857564016399847
3epoch loss: 0.06818672790107565
3epoch loss: 0.06802390167119693
3epoch loss: 0.06794432435769412
3epoch loss: 0.06769889704087645
3epoch loss: 0.0678313069343867
3epoch loss: 0.06796109588438475
3epoch loss: 0.06798575875148855


Epoch:  20%|██        | 4/20 [3:28:09<13:34:25, 3054.09s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

4epoch loss: 0.0026811868883669376
4epoch loss: 0.06647193891557895
4epoch loss: 0.05876003609962674
4epoch loss: 0.056513115543879264
4epoch loss: 0.05755748955478888
4epoch loss: 0.056616972846084086
4epoch loss: 0.054613025192708685
4epoch loss: 0.05575891390274886
4epoch loss: 0.05345471175147036
4epoch loss: 0.05310272766506855
4epoch loss: 0.05484525754019086
4epoch loss: 0.05450681639959144
4epoch loss: 0.05529385116286063
4epoch loss: 0.05569726868754772
4epoch loss: 0.05630592090705967
4epoch loss: 0.05545337929111153
4epoch loss: 0.05565160839530899
4epoch loss: 0.05555988624909183
4epoch loss: 0.05635934032778945
4epoch loss: 0.05596593000125785
4epoch loss: 0.05523817701027147
4epoch loss: 0.055197560983081394
4epoch loss: 0.055486777895183616
4epoch loss: 0.05620977661040418
4epoch loss: 0.056431085458984315
4epoch loss: 0.05697010399658644
4epoch loss: 0.05621083811526379


Epoch:  25%|██▌       | 5/20 [4:45:14<14:41:19, 3525.29s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

5epoch loss: 9.511800453765318e-05
5epoch loss: 0.034927003226557916
5epoch loss: 0.03319913854328439
5epoch loss: 0.03759600826443085
5epoch loss: 0.03728396542605073
5epoch loss: 0.037056101301837015
5epoch loss: 0.037591365234931826
5epoch loss: 0.03881895109191824
5epoch loss: 0.03972232892502529
5epoch loss: 0.04032098132051715
5epoch loss: 0.04035350077395884
5epoch loss: 0.04018796786687718
5epoch loss: 0.0400456768206499
5epoch loss: 0.04010884726292888
5epoch loss: 0.04094672597442079
5epoch loss: 0.041027607561889834
5epoch loss: 0.040838528076758446
5epoch loss: 0.042139763338486694
5epoch loss: 0.042342371162584984
5epoch loss: 0.0427219349482905
5epoch loss: 0.04240123390495646
5epoch loss: 0.042404228715377384
5epoch loss: 0.04160721873893785
5epoch loss: 0.04146077774443725
5epoch loss: 0.04168574307278712
5epoch loss: 0.041329070193053506
5epoch loss: 0.041409686281731126


Epoch:  30%|███       | 6/20 [5:34:15<13:01:43, 3350.23s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

6epoch loss: 0.010956690646708012
6epoch loss: 0.03134032486365085
6epoch loss: 0.03433535057271868
6epoch loss: 0.036942428865844375
6epoch loss: 0.0412732129295606
6epoch loss: 0.03904360281177072
6epoch loss: 0.03771873735572093
6epoch loss: 0.038491236684231336
6epoch loss: 0.03856657928534315
6epoch loss: 0.03875095745298885
6epoch loss: 0.03752509420891173
6epoch loss: 0.03716006972340894
6epoch loss: 0.03763101525414767
6epoch loss: 0.037484534219592705
6epoch loss: 0.03771246675218337
6epoch loss: 0.039480210463538494
6epoch loss: 0.039767279863161524
6epoch loss: 0.03925749414605257
6epoch loss: 0.03807785396251579
6epoch loss: 0.03802188016377943
6epoch loss: 0.03745981442968439
6epoch loss: 0.03710469225097946
6epoch loss: 0.03674464987672148
6epoch loss: 0.036671473058015006
6epoch loss: 0.03647697445232626
6epoch loss: 0.03668966249305325
6epoch loss: 0.03717230566505189


Epoch:  35%|███▌      | 7/20 [6:23:08<11:38:45, 3225.03s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

7epoch loss: 0.05905475839972496
7epoch loss: 0.030461209313801706
7epoch loss: 0.027848273200458438
7epoch loss: 0.030789533991861674
7epoch loss: 0.03176456519152208
7epoch loss: 0.029146606855386448
7epoch loss: 0.031422907286145656
7epoch loss: 0.031695667320118005
7epoch loss: 0.031488692080495796
7epoch loss: 0.032433131691703324
7epoch loss: 0.03250395721673155
7epoch loss: 0.0316516014839841
7epoch loss: 0.03146049086797352
7epoch loss: 0.031083177799117165
7epoch loss: 0.03156151607259298
7epoch loss: 0.031361840209207584
7epoch loss: 0.03111973625604937
7epoch loss: 0.03110459341825073
7epoch loss: 0.03118924599298747
7epoch loss: 0.030382275667096455
7epoch loss: 0.031033565900964153
7epoch loss: 0.030855187353367675
7epoch loss: 0.03086829757010128
7epoch loss: 0.030332968360525278
7epoch loss: 0.03043940640104159
7epoch loss: 0.030685321563340313
7epoch loss: 0.031076278710601653


Epoch:  40%|████      | 8/20 [7:12:30<10:29:11, 3145.94s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

8epoch loss: 0.00014339550398290157
8epoch loss: 0.04908157797357341
8epoch loss: 0.0400167804376724
8epoch loss: 0.03730869235878881
8epoch loss: 0.03792313358509004
8epoch loss: 0.036490290958876195
8epoch loss: 0.03571409058429215
8epoch loss: 0.033984692012289576
8epoch loss: 0.033201059974036363
8epoch loss: 0.033600973192820426
8epoch loss: 0.03342707569110174
8epoch loss: 0.032318708040234276
8epoch loss: 0.03163528212262989
8epoch loss: 0.031133815626995107
8epoch loss: 0.031277343423853805
8epoch loss: 0.03104922738154022
8epoch loss: 0.03125815956484612
8epoch loss: 0.031157990007369307
8epoch loss: 0.03046653608938138
8epoch loss: 0.030337156988176173
8epoch loss: 0.029913958625430226
8epoch loss: 0.030218150120019924
8epoch loss: 0.02995407467267858
8epoch loss: 0.02981848185647147
8epoch loss: 0.03021241654479095
8epoch loss: 0.03000265734064584
8epoch loss: 0.02939313131083644


Epoch:  45%|████▌     | 9/20 [8:01:18<9:24:47, 3080.64s/it] 




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

9epoch loss: 0.0037641157396137714
9epoch loss: 0.013735295104764884
9epoch loss: 0.01957248253741369
9epoch loss: 0.021326236147869656
9epoch loss: 0.021490539780443182
9epoch loss: 0.02062804631762756
9epoch loss: 0.0218541521167494
9epoch loss: 0.023238665784067622
9epoch loss: 0.023492891963572207
9epoch loss: 0.02326914058606494
9epoch loss: 0.02194559172413386
9epoch loss: 0.02228037760050942
9epoch loss: 0.02184717296125668
9epoch loss: 0.021976527814394995
9epoch loss: 0.02183824800253198
9epoch loss: 0.022772401252221908
9epoch loss: 0.02331817228537876
9epoch loss: 0.023303892850271152
9epoch loss: 0.023772291881244006
9epoch loss: 0.024301090720461015
9epoch loss: 0.02408335340400858
9epoch loss: 0.024043801545560784
9epoch loss: 0.02413553380869567
9epoch loss: 0.024147355923794644
9epoch loss: 0.024289763912567666
9epoch loss: 0.024487510842647175
9epoch loss: 0.024060061361205514


Epoch:  50%|█████     | 10/20 [8:50:22<8:26:35, 3039.55s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

10epoch loss: 0.0001587505976203829
10epoch loss: 0.020939867297345087
10epoch loss: 0.017780160216153123
10epoch loss: 0.01911538986027806
10epoch loss: 0.021357800703937183
10epoch loss: 0.024600801380261002
10epoch loss: 0.02518479975528102
10epoch loss: 0.025104808474647126
10epoch loss: 0.026035724526642535
10epoch loss: 0.025693760753416728
10epoch loss: 0.02485208372044
10epoch loss: 0.023982676228610677
10epoch loss: 0.02374673383746138
10epoch loss: 0.02347151284893753
10epoch loss: 0.023598324592667686
10epoch loss: 0.024019248424461043
10epoch loss: 0.024111699679799405
10epoch loss: 0.023743139014124396
10epoch loss: 0.023648302760569916
10epoch loss: 0.023387905805408876
10epoch loss: 0.023068657789718637
10epoch loss: 0.023298608537573574
10epoch loss: 0.02271477852849764
10epoch loss: 0.0230467027540465
10epoch loss: 0.02261174606216441
10epoch loss: 0.022584098364052403
10epoch loss: 0.0224585889831095


Epoch:  55%|█████▌    | 11/20 [9:39:18<7:31:18, 3008.73s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

11epoch loss: 0.010604449547827244
11epoch loss: 0.012376709894919492
11epoch loss: 0.01740743691410999
11epoch loss: 0.01720485625310591
11epoch loss: 0.015610767998487686
11epoch loss: 0.017643817043433534
11epoch loss: 0.0178880312261658
11epoch loss: 0.017668079973911108
11epoch loss: 0.019151421959064056
11epoch loss: 0.02043848443159299
11epoch loss: 0.020568021223607662
11epoch loss: 0.020347540244071177
11epoch loss: 0.020450635887103746
11epoch loss: 0.02021647221674869
11epoch loss: 0.019967820288628795
11epoch loss: 0.01945004501042639
11epoch loss: 0.01977380076846089
11epoch loss: 0.019596298532337814
11epoch loss: 0.019940075834643103
11epoch loss: 0.019565971476871457
11epoch loss: 0.019432439255692175
11epoch loss: 0.019305291377558202
11epoch loss: 0.02007750102832638
11epoch loss: 0.019803326811462865
11epoch loss: 0.020121938963016395
11epoch loss: 0.020592112338416905
11epoch loss: 0.02054204849363453


Epoch:  60%|██████    | 12/20 [10:28:12<6:38:09, 2986.16s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

12epoch loss: 3.0617757147410884e-05
12epoch loss: 0.011195245826975437
12epoch loss: 0.015403448970383196
12epoch loss: 0.014131729053947722
12epoch loss: 0.014754825974663566
12epoch loss: 0.017579102274973293
12epoch loss: 0.017437508741109373
12epoch loss: 0.01841590007289862
12epoch loss: 0.018786859626064217
12epoch loss: 0.018787196943166024
12epoch loss: 0.018124008283940035
12epoch loss: 0.018119894101338983
12epoch loss: 0.017756516943029665
12epoch loss: 0.017952024493405883
12epoch loss: 0.017926700004300713
12epoch loss: 0.018081480602742277
12epoch loss: 0.018271662399859237
12epoch loss: 0.018113428675265726
12epoch loss: 0.01818285728522461
12epoch loss: 0.01863058787251358
12epoch loss: 0.01830331001441495
12epoch loss: 0.018148339574728858
12epoch loss: 0.018056331590307905
12epoch loss: 0.017976091428282374
12epoch loss: 0.01777730662919456
12epoch loss: 0.017545249237043963
12epoch loss: 0.017729764673588947


Epoch:  65%|██████▌   | 13/20 [11:18:23<5:49:16, 2993.72s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

13epoch loss: 1.0758226380858105e-05
13epoch loss: 0.012360822353955
13epoch loss: 0.012778847349845909
13epoch loss: 0.01327108878782546
13epoch loss: 0.014130710084987941
13epoch loss: 0.013442651685552573
13epoch loss: 0.014066294876317172
13epoch loss: 0.01369943343693539
13epoch loss: 0.013400510297625465
13epoch loss: 0.014002785521239687
13epoch loss: 0.01403095510850914
13epoch loss: 0.013890023970200642
13epoch loss: 0.01338647038815667
13epoch loss: 0.01310573599068023
13epoch loss: 0.012721769697180193
13epoch loss: 0.012882444292133076
13epoch loss: 0.013110262378555627
13epoch loss: 0.012802785140431316
13epoch loss: 0.012752863255362698
13epoch loss: 0.012634265637207915
13epoch loss: 0.01323529568965336
13epoch loss: 0.013540852527115753
13epoch loss: 0.013186249077560176
13epoch loss: 0.013121358081113586
13epoch loss: 0.013114542351503593
13epoch loss: 0.013124997965995868
13epoch loss: 0.012837521740548575


Epoch:  70%|███████   | 14/20 [12:07:10<4:57:21, 2973.66s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

14epoch loss: 4.267623808118515e-05
14epoch loss: 0.006842044358073522
14epoch loss: 0.011172034565258096
14epoch loss: 0.009316236779204115
14epoch loss: 0.008883859261830015
14epoch loss: 0.009735961699896491
14epoch loss: 0.010101346796241075
14epoch loss: 0.010137508934959183
14epoch loss: 0.0105509377376564
14epoch loss: 0.011111047717921827
14epoch loss: 0.011486696324110387
14epoch loss: 0.0121854665930578
14epoch loss: 0.012529594987197686
14epoch loss: 0.013059914084309072
14epoch loss: 0.013006377697907131
14epoch loss: 0.013075902743151017
14epoch loss: 0.012996672552296827
14epoch loss: 0.013049495761604609
14epoch loss: 0.013518000964172404
14epoch loss: 0.013852643164384278
14epoch loss: 0.014116879065103914
14epoch loss: 0.01397370300084745
14epoch loss: 0.014101290204707455
14epoch loss: 0.014186770904283014
14epoch loss: 0.014654136625213785
14epoch loss: 0.01470658724332873
14epoch loss: 0.014438172656005695


Epoch:  75%|███████▌  | 15/20 [12:56:26<4:07:21, 2968.26s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

15epoch loss: 8.575472747907043e-06
15epoch loss: 0.00963923910165256
15epoch loss: 0.007877981564541614
15epoch loss: 0.00949252961652641
15epoch loss: 0.009412457601334759
15epoch loss: 0.008459751415736736
15epoch loss: 0.008532374713878245
15epoch loss: 0.008527630422495491
15epoch loss: 0.008284106334759329
15epoch loss: 0.009499655738681303
15epoch loss: 0.009246264278464434
15epoch loss: 0.009259242811441443
15epoch loss: 0.008731722941095109
15epoch loss: 0.009016997744265681
15epoch loss: 0.009606523190659386
15epoch loss: 0.009553988014119065
15epoch loss: 0.00943459554393515
15epoch loss: 0.009471771532356999
15epoch loss: 0.009757023224239323
15epoch loss: 0.009849931392091978
15epoch loss: 0.00974393676246682
15epoch loss: 0.009487583073604498
15epoch loss: 0.009617617460463686
15epoch loss: 0.009740394023628888
15epoch loss: 0.010187842304254244
15epoch loss: 0.01024908027403127
15epoch loss: 0.010360626698222821


Epoch:  80%|████████  | 16/20 [13:45:21<3:17:13, 2958.27s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

16epoch loss: 0.00014006874698679894
16epoch loss: 0.006170652665918899
16epoch loss: 0.011331214248309351
16epoch loss: 0.012542305229582352
16epoch loss: 0.010826803874576433
16epoch loss: 0.010133925452210629
16epoch loss: 0.010706036420256707
16epoch loss: 0.01009544326148267
16epoch loss: 0.01002887838773759
16epoch loss: 0.009836685984852796
16epoch loss: 0.009727800448831174
16epoch loss: 0.009997679780630717
16epoch loss: 0.009710250120860676
16epoch loss: 0.009615120227264823
16epoch loss: 0.009976210518837033
16epoch loss: 0.009860326089622494
16epoch loss: 0.009765597822758728
16epoch loss: 0.009578342363386344
16epoch loss: 0.009753232728520128
16epoch loss: 0.009646931870177034
16epoch loss: 0.010446922988886678
16epoch loss: 0.010568032652136922
16epoch loss: 0.010412313777591716
16epoch loss: 0.010437640065044124
16epoch loss: 0.01020256137439018
16epoch loss: 0.010416382848015958
16epoch loss: 0.010395162323317362


Epoch:  85%|████████▌ | 17/20 [14:34:32<2:27:48, 2956.18s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

17epoch loss: 2.9131449537089793e-06
17epoch loss: 0.018183709642569934
17epoch loss: 0.014833919613672248
17epoch loss: 0.013239795144933937
17epoch loss: 0.010881244543277307
17epoch loss: 0.0125884402417538
17epoch loss: 0.012637715399903911
17epoch loss: 0.011633039949809583
17epoch loss: 0.011062886443210684
17epoch loss: 0.010876724264822514
17epoch loss: 0.010884066435830514
17epoch loss: 0.011880059418295992
17epoch loss: 0.011676458096524865
17epoch loss: 0.012106176420673584
17epoch loss: 0.01171514876283003
17epoch loss: 0.011415878801714043
17epoch loss: 0.011519504943972596
17epoch loss: 0.011575391748706832
17epoch loss: 0.011241770476733573
17epoch loss: 0.011273882274029432
17epoch loss: 0.011200671000575427
17epoch loss: 0.011356978021391973
17epoch loss: 0.011245764199024759
17epoch loss: 0.011109224129158412
17epoch loss: 0.01127489944307123
17epoch loss: 0.011056860801899434
17epoch loss: 0.011287465685729822


Epoch:  90%|█████████ | 18/20 [15:23:38<1:38:26, 2953.22s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

18epoch loss: 6.109276910137851e-06
18epoch loss: 0.011822838373672815
18epoch loss: 0.008052835978643468
18epoch loss: 0.009698969755850993
18epoch loss: 0.009932743235720014
18epoch loss: 0.010677051638307906
18epoch loss: 0.010179665626137629
18epoch loss: 0.010082219929933854
18epoch loss: 0.009838267705593727
18epoch loss: 0.009296356699436785
18epoch loss: 0.009589929846220663
18epoch loss: 0.009551733702201602
18epoch loss: 0.008983928652083294
18epoch loss: 0.00914125485405062
18epoch loss: 0.009034382514421005
18epoch loss: 0.00890624743524677
18epoch loss: 0.008948986403667022
18epoch loss: 0.008903065558291045
18epoch loss: 0.00869791022497675
18epoch loss: 0.008585539964981152
18epoch loss: 0.008557495076888869
18epoch loss: 0.008527439435911211
18epoch loss: 0.008690235737563733
18epoch loss: 0.008632488081525991
18epoch loss: 0.008411602816089563
18epoch loss: 0.008624376329850702
18epoch loss: 0.008701413930790039


Epoch:  95%|█████████▌| 19/20 [16:12:55<49:14, 2954.32s/it]  




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2663.0, style=ProgressStyle(description_w…

19epoch loss: 0.00018654382438398898
19epoch loss: 0.011227520171333633
19epoch loss: 0.011887181312152792
19epoch loss: 0.012592037845984382
19epoch loss: 0.011079731446928594
19epoch loss: 0.010283380245244977
19epoch loss: 0.010067659976776316
19epoch loss: 0.009476075752871002
19epoch loss: 0.008828560373742073
19epoch loss: 0.008281208019345924
19epoch loss: 0.008165714428255729
19epoch loss: 0.009063670515746504
19epoch loss: 0.00929253076349785
19epoch loss: 0.009231522420546428
19epoch loss: 0.009170287471323571
19epoch loss: 0.008929619658945735
19epoch loss: 0.009352687166659409
19epoch loss: 0.009027382925977562
19epoch loss: 0.008893892449996
19epoch loss: 0.008763936685610964
19epoch loss: 0.008917171130896975
19epoch loss: 0.008786972284632195
19epoch loss: 0.00873491856303588
19epoch loss: 0.00873052417237153
19epoch loss: 0.008568023868725193
19epoch loss: 0.00856221605094252
19epoch loss: 0.008484285103997225


Epoch: 100%|██████████| 20/20 [17:02:24<00:00, 3067.21s/it]







In [11]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

In [12]:
# p_encoder = retriever.p_encoder
# q_encoder = retriever.q_encoder
with torch.no_grad() :
    p_encoder.eval()

    p_embs = []
    for p in tqdm(corpus) :
        p = tokenizer(p, padding='max_length', truncation=True, return_tensors='pt').to('cuda')
        p_emb = p_encoder(**p).to('cpu').numpy()
        p_embs.append(p_emb)
p_embs = torch.Tensor(p_embs).squeeze()

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




In [15]:
import pickle

file_path = '/opt/ml/custom/passage_embedding_special_customsample_augmentation_20.bin'
with open(file_path, 'wb') as file :
    pickle.dump(p_embs, file)

In [13]:
p_encoder.cpu()
del p_encoder
torch.cuda.empty_cache()

In [16]:
torch.save(q_encoder, '/opt/ml/custom/q_encoder_special_customsample_augmentation_20.pt')

## Get Relavant Documnet

In [6]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

In [7]:
model_checkpoint = "klue/bert-base"

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [11]:
import pickle
with open('/opt/ml/custom/passage_embedding_special_shuffle_100.bin', 'rb') as file :
    p_embs = pickle.load(file)
p_embs = p_embs

q_encoder = torch.load('/opt/ml/custom/q_encoder_special_customsample_70.pt')

In [18]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

### 연산 시작 - Augmentation/Batch_size/epoch:20/CustomSampling/512

In [20]:
original_dataset = load_from_disk('/opt/ml/data/train_dataset')

In [37]:
doc_scores, doc_indices = get_relavant_doc(original_dataset['validation']['question'], q_encoder, p_embs, k = 500)

total = []
for idx, example in enumerate(
        tqdm(original_dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [38]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(original_dataset['validation']))

0.9541666666666667


In [39]:
cqas_100.to_csv('/opt/ml/custom/valid_dpr_b16_speical_customsampling_augmentation_e20_t500.csv', index = False)

### 연산 시작 - Batch_size 16/epoch:10/CustomSampling


In [23]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [24]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.6916666666666667


In [25]:
cqas_100.to_csv('/opt/ml/custom/valid_dpr_b_16_speical_customsampling_e10_t50.csv', index = False)

### 연산 시작 - Batch_size 16/epoch:10/CustomSampling/384


In [21]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [22]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.7166666666666667


In [23]:
cqas_100.to_csv('/opt/ml/custom/valid_dpr_b_16_speical_customsampling_e70_t50.csv', index = False)