In [1]:
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange

from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)

from typing import List
from adamp import AdamP

In [2]:
from easydict import EasyDict

CFG = EasyDict()
CFG.passage_stride = 128
CFG.learning_rate = 1e-5
CFG.num_train_epochs = 100
CFG.max_seq_length = 512
CFG.train_batch_size = 16
CFG.eval_batch_size = 16
CFG.focal_gamma = 1.0
CFG.weight_decay = 0.01

In [3]:
# 난수 고정
def set_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
    random.seed(random_seed)
    np.random.seed(random_seed)
    
set_seed(42) # magic number :)

In [4]:
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.7.1].
device:[cuda:0].


## Use Custom Loss

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

In [6]:
criterion = FocalLoss(gamma=CFG.focal_gamma)

## Preparing Negative Batching

In [7]:
from torch.utils.data import Sampler

class CustomSampler(Sampler) :
    def __init__(self, data_source, batch_size) :
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self) :
        n = len(self.data_source)
        index_list = []
        while True :
            out = True
            for i in range(self.batch_size) :
                tmp_data = random.randint(0, n-1)
                index_list.append(tmp_data)
            for f, s in zip(index_list, index_list[1:]) :
                if abs(s-f) <= 2 :
                    out = False
            if out == True :
                break

        while True : # 추가 삽입
            tmp_data = random.randint(0, n-1)
            if (tmp_data not in index_list) and \
                (abs(tmp_data-index_list[-i]) > 2 for i in range(1,self.batch_size+1)) \
            : 
                index_list.append(tmp_data)
            if len(index_list) == n :
                break
        return iter(index_list)

    def __len__(self) :
        return len(self.data_source)

## Training

In [8]:
from transformers import AutoModel

class RoBERTaEncoder(BertPreTrainedModel):
    def __init__(self, config):
        super(RoBERTaEncoder, self).__init__(config)

        # self.bert = BertModel(config)
        self.roberta = RobertaModel(config)
        self.init_weights()
      
    def forward(
            self,
            input_ids, 
            attention_mask=None,
            # token_type_ids=None
        ): 

        
        outputs = self.roberta(
            input_ids = input_ids,
            attention_mask=attention_mask,
            # token_type_ids=token_type_ids
        )
        
        pooled_output = outputs[
            "pooler_output"
        ]  # [CLS] token's hidden featrues(hidden state)

        return pooled_output

In [9]:
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']

In [10]:
class DenseRetrieval:
    def __init__(self,
        args,
        dataset,
        tokenizer,
        p_encoder,
        q_encoder
    ):
        """
        학습과 추론에 사용될 여러 셋업을 마쳐봅시다.
        """

        self.args = args
        self.dataset = dataset

        self.tokenizer = tokenizer
        self.p_encoder = p_encoder
        self.q_encoder = q_encoder

    def train(self, args=None, tokenizer = None):
        if args is None:
            args = self.args
        if tokenizer is None :
            tokenizer = self.tokenizer

        for i in (range(len(self.dataset))) :
            if i == 0 :
                q_seqs = tokenizer(
                    self.dataset[i]['question'],
                    padding='max_length',
                    truncation = True,
                    return_tensors='pt'
                )
                p_seqs = tokenizer(
                    self.dataset[i]['context'],
                    truncation = True,
                    stride = CFG.passage_stride,
                    padding='max_length',
                    return_overflowing_tokens=True,
                    return_offsets_mapping=True,
                    return_tensors='pt'
                )

                p_seqs.pop('overflow_to_sample_mapping')
                p_seqs.pop('offset_mapping')

                for k in q_seqs.keys() :
                    q_seqs[k] = q_seqs[k].tolist()
                    p_seqs[k] = p_seqs[k].tolist()
            else :
                tmp_q_seq = tokenizer(
                    self.dataset[i]['question'],
                    padding='max_length',
                    truncation = True,
                    return_tensors='pt')
                tmp_p_seq = tokenizer(
                    self.dataset[i]['context'],
                    truncation = True,
                    stride = CFG.passage_stride,
                    padding='max_length',
                    return_overflowing_tokens=True,
                    return_offsets_mapping=True,
                    return_tensors='pt'
                )

                tmp_p_seq.pop('overflow_to_sample_mapping')
                tmp_p_seq.pop('offset_mapping')

                for k in tmp_p_seq.keys() :
                    tmp_p_seq[k] = tmp_p_seq[k].tolist()
                    tmp_q_seq[k] = tmp_q_seq[k].tolist()


                for j in range(len(tmp_p_seq['input_ids'])) :
                    q_seqs['input_ids'].append(tmp_q_seq['input_ids'][0])
                    # q_seqs['token_type_ids'].append(tmp_q_seq['token_type_ids'][0])
                    q_seqs['attention_mask'].append(tmp_q_seq['attention_mask'][0])
                    p_seqs['input_ids'].append(tmp_p_seq['input_ids'][j])
                    # p_seqs['token_type_ids'].append(tmp_p_seq['token_type_ids'][j])
                    p_seqs['attention_mask'].append(tmp_p_seq['attention_mask'][j])

        for k in q_seqs.keys() :
            q_seqs[k] = torch.tensor(q_seqs[k])
            p_seqs[k] = torch.tensor(p_seqs[k])

        train_dataset = TensorDataset(
            p_seqs['input_ids'], 
            p_seqs['attention_mask'], 
            # p_seqs['token_type_ids'],
            q_seqs['input_ids'], 
            q_seqs['attention_mask'], 
            # q_seqs['token_type_ids']
            )
        train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size, shuffle = True)

        no_decay = ["bias" ,"LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.p_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.p_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
            {"params": [p for n, p in self.q_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.q_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
        ]
        optimizer = AdamP(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            # eps=args.adam_epsilon
        )

        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
        
        global_step = 0

        self.p_encoder.zero_grad()
        self.q_encoder.zero_grad()
        # torch.cuda.empty_cache()

        train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
        self.q_encoder.train()
        self.p_encoder.train()
        for epoch, _ in enumerate(train_iterator):
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            loss_value=0 # Accumulation할 때 진행
            losses = 0
            for step, batch in enumerate(epoch_iterator):
                if torch.cuda.is_available():
                    batch = tuple(t.cuda() for t in batch)

                p_inputs = {'input_ids': batch[0],
                            'attention_mask': batch[1],
                            # 'token_type_ids': batch[2]
                            }
                
                q_inputs = {'input_ids': batch[2],
                            'attention_mask': batch[3],
                            # 'token_type_ids': batch[5]
                            }

                p_outputs = self.p_encoder(**p_inputs)  # (batch_size, emb_dim)
                q_outputs = self.q_encoder(**q_inputs)  # (batch_size, emb_dim)

                # Calculate similarity score & loss
                sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1))  # (batch_size, emb_dim) x (emb_dim, batch_size) = (batch_size, batch_size)

                # target: position of positive samples = diagonal element 
                # targets = torch.arange(0, args.per_device_train_batch_size).long()
                targets = torch.arange(0, len(p_inputs['input_ids'])).long()
                
                if torch.cuda.is_available():
                    targets = targets.to('cuda')
                    
                sim_scores = F.log_softmax(sim_scores, dim=1)

                # loss = F.nll_loss(sim_scores, targets)
                loss = criterion(sim_scores, targets)
                
                losses += loss.item()
                if step % 100 == 0 :
                    print(f'{epoch}epoch loss: {losses/(step+1)}') # Accumulation할 경우 주석처리

                # ################ACCUMULATION###############################
                # if (step+1) % args.gradient_accumulation_steps == 0 :
                #     optimizer.step()
                #     scheduler.step()
                #     self.q_encoder.zero_grad()
                #     self.p_encoder.zero_grad()

                # losses += loss.item()
                # if (step+1) % 64 == 0:
                #     train_loss = losses / 64
                #     print(f'training loss: {train_loss:4.4}')
                #     losses = 0
                # ###########################################################
                self.q_encoder.zero_grad()
                self.p_encoder.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # global_step += 1
                
                #torch.cuda.empty_cache()
                del p_inputs, q_inputs

        return self.p_encoder, self.q_encoder

## Make Q_Encoder & P_Embedding

In [11]:
args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=CFG.learning_rate,
    per_device_train_batch_size=CFG.train_batch_size,
    per_device_eval_batch_size=CFG.eval_batch_size,
    gradient_accumulation_steps=1,
    num_train_epochs=CFG.num_train_epochs,
    weight_decay=CFG.weight_decay # https://github.com/clovaai/AdamP#usage
)
model_checkpoint = "klue/roberta-base"

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_input_names = ["input_ids", "attention_mask"])
p_encoder = RoBERTaEncoder.from_pretrained(model_checkpoint).to(args.device)
q_encoder = RoBERTaEncoder.from_pretrained(model_checkpoint).to(args.device)

You are using a model of type roberta to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at klue/roberta-base were not used when initializing RoBERTaEncoder: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RoBERTaEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RoBERTaEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RoBERTaEncoder were not initialized from the model checkpoint at klue/roberta-base and are new

In [12]:
!nvidia-smi

Sat Oct 30 16:35:38 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:00:05.0 Off |                  Off |
| N/A   34C    P0    36W / 250W |   2121MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [13]:
# Retriever는 아래와 같이 사용할 수 있도록 코드를 짜봅시다.
retriever = DenseRetrieval(
    args=args,
    dataset=train_dataset,
    tokenizer=tokenizer,
    p_encoder=p_encoder,
    q_encoder=q_encoder
)
p_encoder, q_encoder = retriever.train()

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

  logpt = F.log_softmax(input)


0epoch loss: 2.716977119445801
0epoch loss: 2.6416165687070032
0epoch loss: 2.050800065792615
0epoch loss: 1.6063865557857526


Epoch:   1%|          | 1/100 [06:33<10:49:34, 393.68s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

1epoch loss: 0.3210725486278534
1epoch loss: 0.3774994014218302
1epoch loss: 0.37678124992853373
1epoch loss: 0.3683483200364335


Epoch:   2%|▏         | 2/100 [13:06<10:42:38, 393.46s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

2epoch loss: 0.26838165521621704
2epoch loss: 0.20917895103408263
2epoch loss: 0.20773002384832842
2epoch loss: 0.201532123666071


Epoch:   3%|▎         | 3/100 [19:40<10:36:06, 393.47s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

3epoch loss: 0.05772494524717331
3epoch loss: 0.13488956147085618
3epoch loss: 0.13965720477147928
3epoch loss: 0.142232786138477


Epoch:   4%|▍         | 4/100 [26:14<10:29:52, 393.67s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

4epoch loss: 0.2454783171415329
4epoch loss: 0.10328603709360935
4epoch loss: 0.11481746848364505
4epoch loss: 0.11276836190071031


Epoch:   5%|▌         | 5/100 [32:47<10:23:11, 393.59s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

5epoch loss: 0.003066950710490346
5epoch loss: 0.08272237741990052
5epoch loss: 0.08512703306012699
5epoch loss: 0.0875371725248572


Epoch:   6%|▌         | 6/100 [39:20<10:16:13, 393.33s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

6epoch loss: 0.2604781985282898
6epoch loss: 0.0658203292604061
6epoch loss: 0.06920803154790334
6epoch loss: 0.06968055638448688


Epoch:   7%|▋         | 7/100 [45:53<10:09:33, 393.26s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

7epoch loss: 0.02321595512330532
7epoch loss: 0.06745205412037891
7epoch loss: 0.05881213590118163
7epoch loss: 0.06172483175257345


Epoch:   8%|▊         | 8/100 [52:26<10:02:52, 393.18s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

8epoch loss: 0.02911366894841194
8epoch loss: 0.06446518910020627
8epoch loss: 0.06441809906362468
8epoch loss: 0.06459617480298


Epoch:   9%|▉         | 9/100 [58:59<9:56:18, 393.17s/it] 




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

9epoch loss: 0.10342981666326523
9epoch loss: 0.06500991094756023
9epoch loss: 0.06009708006173576
9epoch loss: 0.057833198319302184


Epoch:  10%|█         | 10/100 [1:05:32<9:49:34, 393.05s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

10epoch loss: 0.023590171709656715
10epoch loss: 0.07113701856515306
10epoch loss: 0.06283260144893339
10epoch loss: 0.05794860949999718


Epoch:  11%|█         | 11/100 [1:12:05<9:43:01, 393.05s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

11epoch loss: 0.13806118071079254
11epoch loss: 0.065341319778178
11epoch loss: 0.06203590451636345
11epoch loss: 0.051580215159691856


Epoch:  12%|█▏        | 12/100 [1:18:38<9:36:30, 393.07s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

12epoch loss: 0.026282599195837975
12epoch loss: 0.032736700473671394
12epoch loss: 0.03077023222297372
12epoch loss: 0.03494134219967647


Epoch:  13%|█▎        | 13/100 [1:25:12<9:30:24, 393.39s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

13epoch loss: 0.003269219771027565
13epoch loss: 0.04313046728418725
13epoch loss: 0.042574568428197616
13epoch loss: 0.04555706556277188


Epoch:  14%|█▍        | 14/100 [1:31:46<9:23:55, 393.43s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

14epoch loss: 0.036463961005210876
14epoch loss: 0.043623473577945965
14epoch loss: 0.03512643148958398
14epoch loss: 0.0380162429254509


Epoch:  15%|█▌        | 15/100 [1:38:19<9:17:06, 393.26s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

15epoch loss: 0.007860668003559113
15epoch loss: 0.03176009674881594
15epoch loss: 0.030430414718263018
15epoch loss: 0.029576152810864415


Epoch:  16%|█▌        | 16/100 [1:44:52<9:10:29, 393.21s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

16epoch loss: 0.011041329242289066
16epoch loss: 0.03856306123065458
16epoch loss: 0.04055815566891795
16epoch loss: 0.038226855333086736


Epoch:  17%|█▋        | 17/100 [1:51:25<9:03:54, 393.19s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

17epoch loss: 0.09135337918996811
17epoch loss: 0.043603206457371944
17epoch loss: 0.03803762145249511
17epoch loss: 0.03703447482832924


Epoch:  18%|█▊        | 18/100 [1:57:57<8:56:49, 392.80s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

18epoch loss: 0.1166367456316948
18epoch loss: 0.04066101105487214
18epoch loss: 0.03720295392763071
18epoch loss: 0.034889408016194084


Epoch:  19%|█▉        | 19/100 [2:04:29<8:49:56, 392.55s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

19epoch loss: 0.0013978895731270313
19epoch loss: 0.025569363053850477
19epoch loss: 0.02594234134212728
19epoch loss: 0.025251564781447872


Epoch:  20%|██        | 20/100 [2:11:01<8:43:18, 392.48s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

20epoch loss: 0.004771482199430466
20epoch loss: 0.03814449048931607
20epoch loss: 0.03630638531876458
20epoch loss: 0.034355662927565055


Epoch:  21%|██        | 21/100 [2:17:34<8:36:51, 392.55s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

21epoch loss: 0.010472560301423073
21epoch loss: 0.02830311146306593
21epoch loss: 0.02599336092314852
21epoch loss: 0.02304105641903343


Epoch:  22%|██▏       | 22/100 [2:24:06<8:30:07, 392.41s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

22epoch loss: 0.025925004854798317
22epoch loss: 0.030105303575807495
22epoch loss: 0.02932953545736092
22epoch loss: 0.026593871478537548


Epoch:  23%|██▎       | 23/100 [2:30:37<8:23:15, 392.15s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

23epoch loss: 0.02297316864132881
23epoch loss: 0.02199225043627263
23epoch loss: 0.024462290871129673
23epoch loss: 0.024611882817857214


Epoch:  24%|██▍       | 24/100 [2:37:10<8:16:50, 392.24s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

24epoch loss: 0.0006170705310069025
24epoch loss: 0.02273005612531249
24epoch loss: 0.02108563392468921
24epoch loss: 0.0219580365559556


Epoch:  25%|██▌       | 25/100 [2:43:41<8:09:59, 391.99s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

25epoch loss: 0.06525056064128876
25epoch loss: 0.02985947102285653
25epoch loss: 0.025700722488770013
25epoch loss: 0.025806207384750656


Epoch:  26%|██▌       | 26/100 [2:50:13<8:03:27, 391.99s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

26epoch loss: 0.030820634216070175
26epoch loss: 0.02391882090980271
26epoch loss: 0.018659125725252808
26epoch loss: 0.021037229639387985


Epoch:  27%|██▋       | 27/100 [2:56:45<7:56:49, 391.92s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

27epoch loss: 0.042098842561244965
27epoch loss: 0.02386587199282868
27epoch loss: 0.02220067754673826
27epoch loss: 0.018933194619037814


Epoch:  28%|██▊       | 28/100 [3:03:17<7:50:22, 391.98s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

28epoch loss: 0.01680493727326393
28epoch loss: 0.023918411260899237
28epoch loss: 0.021581252708719106
28epoch loss: 0.021515515014401524


Epoch:  29%|██▉       | 29/100 [3:09:48<7:43:37, 391.80s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

29epoch loss: 0.00316160311922431
29epoch loss: 0.01754295358580631
29epoch loss: 0.01948357267700437
29epoch loss: 0.01732876662244841


Epoch:  30%|███       | 30/100 [3:16:20<7:37:06, 391.80s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

30epoch loss: 0.014135805889964104
30epoch loss: 0.015719569373893794
30epoch loss: 0.01945076042471981
30epoch loss: 0.01965357049104502


Epoch:  31%|███       | 31/100 [3:22:52<7:30:24, 391.67s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

31epoch loss: 0.00030492598307318985
31epoch loss: 0.01770533104900096
31epoch loss: 0.01653116059089226
31epoch loss: 0.017619627407271233


Epoch:  32%|███▏      | 32/100 [3:29:24<7:24:08, 391.89s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

32epoch loss: 0.016795819625258446
32epoch loss: 0.0338915457367434
32epoch loss: 0.026677403807076387
32epoch loss: 0.025453687926074053


Epoch:  33%|███▎      | 33/100 [3:35:56<7:17:40, 391.95s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

33epoch loss: 0.02562800422310829
33epoch loss: 0.013356838680816635
33epoch loss: 0.014676398599709699
33epoch loss: 0.013223584488815108


Epoch:  34%|███▍      | 34/100 [3:42:28<7:11:16, 392.07s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

34epoch loss: 0.0018816021038219333
34epoch loss: 0.015671687858257156
34epoch loss: 0.013486950881925076
34epoch loss: 0.014166864440430407


Epoch:  35%|███▌      | 35/100 [3:49:00<7:04:33, 391.89s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

35epoch loss: 0.00012180488556623459
35epoch loss: 0.01747267993832703
35epoch loss: 0.01843520074795261
35epoch loss: 0.01731986768551975


Epoch:  36%|███▌      | 36/100 [3:55:32<6:58:02, 391.91s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

36epoch loss: 0.0036423488054424524
36epoch loss: 0.017589002204523682
36epoch loss: 0.01723858743123367
36epoch loss: 0.015215125511310794


Epoch:  37%|███▋      | 37/100 [4:02:04<6:51:28, 391.88s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

37epoch loss: 0.00020292887347750366
37epoch loss: 0.012790439269004913
37epoch loss: 0.012998564863590136
37epoch loss: 0.015266024902608068


Epoch:  38%|███▊      | 38/100 [4:08:35<6:44:47, 391.73s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

38epoch loss: 0.005850392859429121
38epoch loss: 0.021122217313697348
38epoch loss: 0.017857614028241018
38epoch loss: 0.015515986424700643


Epoch:  39%|███▉      | 39/100 [4:15:07<6:38:13, 391.70s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

39epoch loss: 0.0005988128832541406
39epoch loss: 0.016567947810016234
39epoch loss: 0.015327508608309934
39epoch loss: 0.014988204939241433


Epoch:  40%|████      | 40/100 [4:21:38<6:31:40, 391.67s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

40epoch loss: 0.0011335855815559626
40epoch loss: 0.01128869038728514
40epoch loss: 0.019971937903288172
40epoch loss: 0.018025835962060022


Epoch:  41%|████      | 41/100 [4:28:10<6:25:10, 391.70s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

41epoch loss: 1.4657191059086472e-05
41epoch loss: 0.022926785621661616
41epoch loss: 0.017917615663544266
41epoch loss: 0.01663187005445923


Epoch:  42%|████▏     | 42/100 [4:34:42<6:18:38, 391.69s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

42epoch loss: 0.0002841547830030322
42epoch loss: 0.014188810515028452
42epoch loss: 0.015133233637664204
42epoch loss: 0.01491576288920493


Epoch:  43%|████▎     | 43/100 [4:41:14<6:12:15, 391.86s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

43epoch loss: 9.772866906132549e-05
43epoch loss: 0.008742572760503648
43epoch loss: 0.009557059924616315
43epoch loss: 0.010131404641930412


Epoch:  44%|████▍     | 44/100 [4:47:45<6:05:38, 391.75s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

44epoch loss: 0.0006441058358177543
44epoch loss: 0.015362576322799987
44epoch loss: 0.013057300815611897
44epoch loss: 0.015455061342001131


Epoch:  45%|████▌     | 45/100 [4:54:18<5:59:21, 392.02s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

45epoch loss: 0.0004244135634507984
45epoch loss: 0.00984203664799256
45epoch loss: 0.009599323763657467
45epoch loss: 0.009204212503596786


Epoch:  46%|████▌     | 46/100 [5:00:50<5:52:44, 391.94s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

46epoch loss: 6.200126517796889e-05
46epoch loss: 0.01088186040097789
46epoch loss: 0.011307607693179397
46epoch loss: 0.01081654869404774


Epoch:  47%|████▋     | 47/100 [5:07:23<5:46:23, 392.15s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

47epoch loss: 0.07908079028129578
47epoch loss: 0.01730184223434563
47epoch loss: 0.016269885519659737
47epoch loss: 0.015267857794638271


Epoch:  48%|████▊     | 48/100 [5:13:54<5:39:47, 392.07s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

48epoch loss: 7.246893073897809e-05
48epoch loss: 0.00919905592303225
48epoch loss: 0.006703543560185033
48epoch loss: 0.00904611127252475


Epoch:  49%|████▉     | 49/100 [5:20:26<5:33:12, 392.01s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

49epoch loss: 0.0009528467198833823
49epoch loss: 0.010452120107181532
49epoch loss: 0.011335032112748734
49epoch loss: 0.011054995548816091


Epoch:  50%|█████     | 50/100 [5:26:59<5:26:45, 392.10s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

50epoch loss: 0.0003871243097819388
50epoch loss: 0.011855554418040207
50epoch loss: 0.011382700591165171
50epoch loss: 0.010188994837138327


Epoch:  51%|█████     | 51/100 [5:33:30<5:20:06, 391.97s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

51epoch loss: 2.567153933341615e-05
51epoch loss: 0.011927400167651555
51epoch loss: 0.009711221573438004
51epoch loss: 0.008110484474615725


Epoch:  52%|█████▏    | 52/100 [5:40:02<5:13:33, 391.96s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

52epoch loss: 0.01605701446533203
52epoch loss: 0.011955503253298813
52epoch loss: 0.009066283657684427
52epoch loss: 0.00985516907047611


Epoch:  53%|█████▎    | 53/100 [5:46:34<5:06:56, 391.83s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

53epoch loss: 0.00026130431797355413
53epoch loss: 0.011802887426706593
53epoch loss: 0.010933930340997475
53epoch loss: 0.010606862810718713


Epoch:  54%|█████▍    | 54/100 [5:53:06<5:00:29, 391.95s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

54epoch loss: 5.923998469370417e-05
54epoch loss: 0.004989968764511384
54epoch loss: 0.008209472007652876
54epoch loss: 0.00895313255911857


Epoch:  55%|█████▌    | 55/100 [5:59:39<4:54:10, 392.23s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

55epoch loss: 0.06919576972723007
55epoch loss: 0.010287375541446465
55epoch loss: 0.00882750656340711
55epoch loss: 0.0076010032887965946


Epoch:  56%|█████▌    | 56/100 [6:06:11<4:47:43, 392.36s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

56epoch loss: 6.819409463787451e-05
56epoch loss: 0.008053353067333603
56epoch loss: 0.008123189276496382
56epoch loss: 0.007472632680743493


Epoch:  57%|█████▋    | 57/100 [6:12:44<4:41:09, 392.31s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

57epoch loss: 2.0396073523443192e-05
57epoch loss: 0.009488756412570137
57epoch loss: 0.007608100197507466
57epoch loss: 0.0074577663197371


Epoch:  58%|█████▊    | 58/100 [6:19:16<4:34:35, 392.27s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

58epoch loss: 4.541894304566085e-05
58epoch loss: 0.006537395034186911
58epoch loss: 0.006105007799840393
58epoch loss: 0.006898985149575198


Epoch:  59%|█████▉    | 59/100 [6:25:48<4:28:05, 392.34s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

59epoch loss: 2.9653640012838878e-05
59epoch loss: 0.0053864578319529986
59epoch loss: 0.006259966483058358
59epoch loss: 0.007776770379369099


Epoch:  60%|██████    | 60/100 [6:32:20<4:21:29, 392.23s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

60epoch loss: 0.053082458674907684
60epoch loss: 0.005227444907061477
60epoch loss: 0.004844129596265881
60epoch loss: 0.004795646354709169


Epoch:  61%|██████    | 61/100 [6:38:53<4:14:57, 392.24s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

61epoch loss: 0.06449602544307709
61epoch loss: 0.007194939236524528
61epoch loss: 0.005317101465178537
61epoch loss: 0.005069653610912554


Epoch:  62%|██████▏   | 62/100 [6:45:25<4:08:24, 392.22s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

62epoch loss: 0.0579523965716362
62epoch loss: 0.006969353683297242
62epoch loss: 0.007159838649687967
62epoch loss: 0.006269426860289481


Epoch:  63%|██████▎   | 63/100 [6:51:57<4:01:52, 392.24s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

63epoch loss: 0.24474699795246124
63epoch loss: 0.008938802355556158
63epoch loss: 0.006963536895044123
63epoch loss: 0.007297048788734704


Epoch:  64%|██████▍   | 64/100 [6:58:28<3:55:11, 391.98s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

64epoch loss: 9.067263454198837e-05
64epoch loss: 0.008406040166877043
64epoch loss: 0.007155475178822981
64epoch loss: 0.008506964733006472


Epoch:  65%|██████▌   | 65/100 [7:05:00<3:48:36, 391.91s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

65epoch loss: 0.000315819721436128
65epoch loss: 0.004981185383834293
65epoch loss: 0.004727731357330589
65epoch loss: 0.005248139742187418


Epoch:  66%|██████▌   | 66/100 [7:11:32<3:42:05, 391.92s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

66epoch loss: 1.0944032737825182e-06
66epoch loss: 0.008263939831821002
66epoch loss: 0.006483126965692836
66epoch loss: 0.006021201068642031


Epoch:  67%|██████▋   | 67/100 [7:18:03<3:35:27, 391.74s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

67epoch loss: 1.8414086298434995e-05
67epoch loss: 0.008715332352525397
67epoch loss: 0.00840734484616086
67epoch loss: 0.007916639863499064


Epoch:  68%|██████▊   | 68/100 [7:24:35<3:28:55, 391.75s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

68epoch loss: 0.00021084689069539309
68epoch loss: 0.005313300713152272
68epoch loss: 0.0068494755761530635
68epoch loss: 0.007059814998633344


Epoch:  69%|██████▉   | 69/100 [7:31:07<3:22:24, 391.74s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

69epoch loss: 1.4791738067287952e-05
69epoch loss: 0.003577996580024588
69epoch loss: 0.0051426695458395635
69epoch loss: 0.005437610335544191


Epoch:  70%|███████   | 70/100 [7:37:39<3:15:56, 391.88s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

70epoch loss: 0.0016577471978962421
70epoch loss: 0.008303027703366545
70epoch loss: 0.006184543207011267
70epoch loss: 0.005854246577101963


Epoch:  71%|███████   | 71/100 [7:44:11<3:09:24, 391.86s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

71epoch loss: 4.2631319956853986e-05
71epoch loss: 0.006506302436180132
71epoch loss: 0.005879281729657929
71epoch loss: 0.005072130366931866


Epoch:  72%|███████▏  | 72/100 [7:50:44<3:02:58, 392.08s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

72epoch loss: 0.10696970671415329
72epoch loss: 0.008450325925794274
72epoch loss: 0.008082952556152061
72epoch loss: 0.006784831967240205


Epoch:  73%|███████▎  | 73/100 [7:57:16<2:56:29, 392.19s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

73epoch loss: 5.024424353905488e-06
73epoch loss: 0.008637917582439002
73epoch loss: 0.00796581983786792
73epoch loss: 0.007696585899595946


Epoch:  74%|███████▍  | 74/100 [8:03:48<2:49:58, 392.24s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

74epoch loss: 0.00019752119260374457
74epoch loss: 0.006021690681812756
74epoch loss: 0.005967241790510747
74epoch loss: 0.006807335994899665


Epoch:  75%|███████▌  | 75/100 [8:10:20<2:43:23, 392.15s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

75epoch loss: 0.058174312114715576
75epoch loss: 0.0052567931927952295
75epoch loss: 0.005412396696213486
75epoch loss: 0.004652503596475316


Epoch:  76%|███████▌  | 76/100 [8:16:53<2:36:56, 392.37s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

76epoch loss: 0.00584443612024188
76epoch loss: 0.003260007752369885
76epoch loss: 0.0032090475561499564
76epoch loss: 0.004724140943059095


Epoch:  77%|███████▋  | 77/100 [8:23:26<2:30:26, 392.45s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

77epoch loss: 4.6993777687021066e-06
77epoch loss: 0.004907745071967262
77epoch loss: 0.0036603899204458406
77epoch loss: 0.004195538829005314


Epoch:  78%|███████▊  | 78/100 [8:29:58<2:23:54, 392.47s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

78epoch loss: 0.00043895660201087594
78epoch loss: 0.0048378452647609175
78epoch loss: 0.004087081706456059
78epoch loss: 0.0038978935612041733


Epoch:  79%|███████▉  | 79/100 [8:36:31<2:17:22, 392.50s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

79epoch loss: 8.413815521635115e-06
79epoch loss: 0.008837764442414662
79epoch loss: 0.006414936201621378
79epoch loss: 0.007327343405659087


Epoch:  80%|████████  | 80/100 [8:43:03<2:10:50, 392.52s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

80epoch loss: 4.065473603986902e-06
80epoch loss: 0.004695349390744912
80epoch loss: 0.005438414001258226
80epoch loss: 0.004899344044916391


Epoch:  81%|████████  | 81/100 [8:49:36<2:04:20, 392.63s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

81epoch loss: 0.0001279541611438617
81epoch loss: 0.003409219524215105
81epoch loss: 0.004135931748598027
81epoch loss: 0.005016638495084964


Epoch:  82%|████████▏ | 82/100 [8:56:08<1:57:43, 392.41s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

82epoch loss: 0.0008933261269703507
82epoch loss: 0.007312834923685991
82epoch loss: 0.006971195977305741
82epoch loss: 0.005844089714998921


Epoch:  83%|████████▎ | 83/100 [9:02:40<1:51:09, 392.33s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

83epoch loss: 8.573541094847315e-07
83epoch loss: 0.0027914389467025443
83epoch loss: 0.004708041100868319
83epoch loss: 0.003677679325901369


Epoch:  84%|████████▍ | 84/100 [9:09:13<1:44:36, 392.26s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

84epoch loss: 1.0381048923591152e-05
84epoch loss: 0.0044631264416015
84epoch loss: 0.004547248132760191
84epoch loss: 0.005557277192721983


Epoch:  85%|████████▌ | 85/100 [9:15:46<1:38:07, 392.50s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

85epoch loss: 2.804367795761209e-05
85epoch loss: 0.0045736535451575185
85epoch loss: 0.004782366852076513
85epoch loss: 0.004350011861868328


Epoch:  86%|████████▌ | 86/100 [9:22:18<1:31:33, 392.36s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

86epoch loss: 0.04750896617770195
86epoch loss: 0.003981336318700615
86epoch loss: 0.0039013302545569672
86epoch loss: 0.00407493686500987


Epoch:  87%|████████▋ | 87/100 [9:28:50<1:25:01, 392.43s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

87epoch loss: 0.00019178612274117768
87epoch loss: 0.00802291176809799
87epoch loss: 0.005671443631538525
87epoch loss: 0.004946804559660872


Epoch:  88%|████████▊ | 88/100 [9:35:23<1:18:31, 392.61s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

88epoch loss: 0.0010463525541126728
88epoch loss: 0.003945705155159669
88epoch loss: 0.0032692930080528626
88epoch loss: 0.00356012042101125


Epoch:  89%|████████▉ | 89/100 [9:41:56<1:11:57, 392.54s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

89epoch loss: 0.0014642009045928717
89epoch loss: 0.005770076594321785
89epoch loss: 0.006079706518316592
89epoch loss: 0.005548048883957683


Epoch:  90%|█████████ | 90/100 [9:48:28<1:05:26, 392.60s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

90epoch loss: 0.00010879450564971194
90epoch loss: 0.003893448287951518
90epoch loss: 0.0032996256665291106
90epoch loss: 0.0029439864952381517


Epoch:  91%|█████████ | 91/100 [9:55:00<58:51, 392.40s/it]  




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

91epoch loss: 1.85851713467855e-05
91epoch loss: 0.005160342262680877
91epoch loss: 0.005039740733359826
91epoch loss: 0.00490982525629587


Epoch:  92%|█████████▏| 92/100 [10:01:33<52:19, 392.43s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

92epoch loss: 1.3412449106908753e-07
92epoch loss: 0.0038397092632399795
92epoch loss: 0.003947100778911049
92epoch loss: 0.004256157064166588


Epoch:  93%|█████████▎| 93/100 [10:08:05<45:46, 392.42s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

93epoch loss: 4.028856892546173e-06
93epoch loss: 0.005208615946655264
93epoch loss: 0.004073545734460295
93epoch loss: 0.0031772899612653727


Epoch:  94%|█████████▍| 94/100 [10:14:37<39:14, 392.38s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

94epoch loss: 4.2598872823873535e-05
94epoch loss: 0.004538227673585449
94epoch loss: 0.004106244853936791
94epoch loss: 0.004013447387886293


Epoch:  95%|█████████▌| 95/100 [10:21:10<32:42, 392.43s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

95epoch loss: 1.2969710951438174e-05
95epoch loss: 0.00576222622074879
95epoch loss: 0.003581738068294973
95epoch loss: 0.0033736139279898557


Epoch:  96%|█████████▌| 96/100 [10:27:42<26:09, 392.36s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

96epoch loss: 0.00014291024126578122
96epoch loss: 0.0037267872515753895
96epoch loss: 0.0038018666349518538
96epoch loss: 0.0033083527989438057


Epoch:  97%|█████████▋| 97/100 [10:34:15<19:37, 392.48s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

97epoch loss: 3.050120653824706e-07
97epoch loss: 0.0037341451212865332
97epoch loss: 0.003590364876776183
97epoch loss: 0.0035700307815681


Epoch:  98%|█████████▊| 98/100 [10:40:48<13:05, 392.73s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

98epoch loss: 0.0018786677392199636
98epoch loss: 0.0015858460506675577
98epoch loss: 0.003682142714119382
98epoch loss: 0.0034045530193971157


Epoch:  99%|█████████▉| 99/100 [10:47:21<06:32, 392.67s/it]




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=350.0, style=ProgressStyle(description_wi…

99epoch loss: 0.07664334028959274
99epoch loss: 0.0030382736148791977
99epoch loss: 0.002428999598620668
99epoch loss: 0.0030879192978329577


Epoch: 100%|██████████| 100/100 [10:53:53<00:00, 392.34s/it]







## Validation testing from the wiki corpus

In [14]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

In [15]:
from preprocess import preprocess
with torch.no_grad() :
    p_encoder.eval()

    p_embs = []
    for p in tqdm(corpus) :
        p = preprocess(p) #TODO: Check whether preprocessing function improves or deteriorates performance
        p = tokenizer(p, padding='max_length', truncation=True, return_tensors='pt').to('cuda')
        p_emb = p_encoder(**p).to('cpu').numpy()
        p_embs.append(p_emb)
p_embs = torch.Tensor(p_embs).squeeze()

HBox(children=(FloatProgress(value=0.0, max=56737.0), HTML(value='')))




In [16]:
import pickle

file_path = '/opt/ml/mrc-level2-nlp-15/custom/passage_embedding_100.bin'
with open(file_path, 'wb') as file :
    pickle.dump(p_embs, file)

In [17]:
p_encoder.cpu()
del p_encoder
torch.cuda.empty_cache()

In [18]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Sun Oct 31 03:46:38 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:00:05.0 Off |                  Off |
| N/A   50C    P0    48W / 250W |   2291MiB / 32510MiB |    100%      Default |
|                               |            

In [19]:
torch.save(q_encoder, '/opt/ml/mrc-level2-nlp-15/custom/q_encoder_100.pt')

## Get Relavant Documnet

In [20]:
with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

corpus = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

In [21]:
tokenizer

PreTrainedTokenizerFast(name_or_path='klue/roberta-base', vocab_size=32000, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'bos_token': '[CLS]', 'eos_token': '[SEP]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [22]:
import pickle
with open(file_path, 'rb') as file :
    p_embs = pickle.load(file)
p_embs = p_embs

In [24]:
q_encoder = torch.load('/opt/ml/mrc-level2-nlp-15/custom/q_encoder_100.pt')

In [25]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

### 1개 확인

In [26]:
query = dataset['validation']['question'][0]
queries = dataset['validation']['question'][:2]

In [27]:
with torch.no_grad() :
    q_encoder.eval()
    q_seqs_val = tokenizer([query], padding="max_length", truncation=True, return_tensors="pt").to(device)
    q_emb = q_encoder(**q_seqs_val).to('cpu')

In [28]:
dot_prod_scores = torch.mm(q_emb, p_embs.T)
dot_prod_scores

tensor([[-2.4953,  0.7282,  2.8620,  ..., -1.0062,  7.3725,  8.9615]])

In [29]:
dot_prod_scores = torch.mm(q_emb, p_embs.T)
rank = torch.argsort(dot_prod_scores, dim=1, descending=True)

In [30]:
rank

tensor([[ 5694, 27301, 22306,  ...,  7364,  4316, 15047]])

### 여러개

In [31]:
with torch.no_grad() :
    q_encoder.eval()
    q_seqs_val = tokenizer(queries, padding="max_length", truncation=True, return_tensors="pt").to(device)
    q_emb = q_encoder(**q_seqs_val).to('cpu')

In [32]:
dot_prod_scores = torch.mm(q_emb, p_embs.T)
sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

In [33]:
scores = sort_result[0]
ranks = sort_result[1]

In [34]:
k = 5
print("[Search query]\n", query, "\n")
print("[Ground truth passage]")
print(dataset['validation']['context'][0], "\n")

for i in range(k):
  print("Top-%d passage with score %.4f" % (i+1, scores[0].squeeze()[i]))
  print(corpus[ranks[0][i]])

[Search query]
 처음으로 부실 경영인에 대한 보상 선고를 받은 회사는? 

[Ground truth passage]
순천여자고등학교 졸업, 1973년 이화여자대학교를 졸업하고 1975년 제17회 사법시험에 합격하여 판사로 임용되었고 대법원 재판연구관, 수원지법 부장판사, 사법연수원 교수, 특허법원 부장판사 등을 거쳐 능력을 인정받았다. 2003년 최종영 대법원장의 지명으로 헌법재판소 재판관을 역임하였다.\n\n경제민주화위원회(위원장 장하성이 소액주주들을 대표해 한보철강 부실대출에 책임이 있는 이철수 전 제일은행장 등 임원 4명을 상대로 제기한 손해배상청구소송에서 서울지방법원 민사합의17부는 1998년 7월 24일에 "한보철강에 부실 대출하여 은행에 막대한 손해를 끼친 점이 인정된다"며 "원고가 배상을 청구한 400억원 전액을 은행에 배상하라"고 하면서 부실 경영인에 대한 최초의 배상 판결을 했다. \n\n2004년 10월 신행정수도의건설을위한특별조치법 위헌 확인 소송에서 9인의 재판관 중 유일하게 각하 견해를 내었다. 소수의견에서 전효숙 재판관은 다수견해의 문제점을 지적하면서 관습헌법 법리를 부정하였다. 전효숙 재판관은 서울대학교 근대법학교육 백주년 기념관에서 열린 강연에서, 국회가 고도의 정치적인 사안을 정치로 풀기보다는 헌법재판소에 무조건 맡겨서 해결하려는 자세는 헌법재판소에게 부담스럽다며 소회를 밝힌 바 있다. 

Top-1 passage with score 13.5171
더불어민주당 홍익표 의원은 2015년 10월 22일 청와대 서별관회의에 제출한 대우조선 관련 문건을 입수해 2016년 7월 4일 공개했다. 문건에는 “대우조선에 5조 원 이상의 부실이 현실화돼 사실 관계 규명을 위해 감리가 필요하다는 문제를 제기했다. 금융감독원이 그간 자발적 소명 기회를 부여했으나 회사(대우조선)는 소명 자료 제출에 소극적”이라고 적혀 있었다. 분식회계 의혹이 있음에도 정부는 회의 일주일 뒤 대우조선의 최대주주인 산업은행은 4조2000억 원 규모의 자금 지

### 연산 시작 - Batch_size:16/epoch:5

In [35]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [36]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 150)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [37]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.7166666666666667


### 연산 시작 - Batch_size:16/epoch:10

In [38]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [42]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

In [43]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [44]:
correct_length = []
for i in range(len(cqas)) :
    if cqas['original_context'][i] in cqas['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.6166666666666667


In [None]:

cqas.to_csv('/opt/ml/mrc-level2-nlp-15/custom/valid_dpr_b16_e10_t50.csv')

In [None]:
correct_length_50 = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length_50.append(i)
print(len(correct_length_50) / len(dataset['validation']))

0.3625


In [None]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.7416666666666667


In [None]:
correct_length_150 = []
for i in range(len(cqas_150)) :
    if cqas_150['original_context'][i] in cqas_150['context'][i] :
        correct_length_150.append(i)
print(len(correct_length_150) / len(dataset['validation']))

0.7625


In [None]:
correct_length_200 = []
for i in range(len(cqas_200)) :
    if cqas_200['original_context'][i] in cqas_200['context'][i] :
        correct_length_200.append(i)
print(len(correct_length_200) / len(dataset['validation']))

0.825


In [None]:
cqas_200.to_csv('/opt/ml/custom/valid_dpr_200.csv', index = False)

In [None]:
correct_length_300 = []
for i in range(len(cqas_300)) :
    if cqas_300['original_context'][i] in cqas_300['context'][i] :
        correct_length_300.append(i)
print(len(correct_length_200) / len(dataset['validation']))

0.825


In [None]:
cqas.to_csv('/opt/ml/custom/valid_dpr.csv', index = False)

### 연산 시작 - Batch_size:16/epoch:40

In [None]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 200)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.85


In [None]:
cqas_50.to_csv('/opt/ml/custom/valid_dpr_b16_e40_t200.csv', index = False)

### 연산 시작 - Batch_size:16/epoch:100

In [None]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 200)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length.append(i)
print(len(correct_length) / len(dataset['validation']))

0.8166666666666667


In [None]:
cqas_50.to_csv('/opt/ml/custom/valid_dpr_b16_e100_t200.csv', index = False)

### 연산 시작 - Batch_size 128/epoch:10


In [None]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 100)

In [None]:
total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length_50 = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length_50.append(i)
print(len(correct_length_50) / len(dataset['validation']))

0.6083333333333333


In [None]:
correct_length_50 = []
for i in range(len(cqas_20)) :
    if cqas_20['original_context'][i] in cqas_20['context'][i] :
        correct_length_20.append(i)
print(len(correct_length_20) / len(dataset['validation']))

### 연산 시작 - Batch_size 128/epoch:5


In [None]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_50 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length_50 = []
for i in range(len(cqas_50)) :
    if cqas_50['original_context'][i] in cqas_50['context'][i] :
        correct_length_50.append(i)
print(len(correct_length_50) / len(dataset['validation']))

0.30416666666666664


### 연산 시작 - Batch_size 128/epoch:20


In [None]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 20)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.5041666666666667


In [None]:
len(correct_length_100)

120

### 연산 시작 - Batch_size 16/epoch:10/Special/NoShuffle


In [None]:
def get_relavant_doc(queries, q_encoder, p_embs, k=1) :
    with torch.no_grad() :
        q_encoder.eval()
        q_seqs_val = tokenizer(queries, padding='max_length',truncation=True,return_tensors='pt').to(device)
        q_emb = q_encoder(**q_seqs_val).to('cpu')
    dot_prod_scores = torch.mm(q_emb, p_embs.T)
    sort_result = torch.sort(dot_prod_scores, dim=1, descending=True)

    scores, ranks = sort_result[0], sort_result[1]

    result_scores = []
    result_indices = []
    for i in range(len(ranks)) :
        result_scores.append(scores[i].tolist()[:k])
        result_indices.append(ranks[i].tolist()[:k])
    
    return result_scores, result_indices

In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 100)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.6916666666666667


### 연산 시작 - Batch_size 16/epoch:10/Special/Shuffle


In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 100)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.775


### 연산 시작 - Batch_size 16/epoch:100/Special/Shuffle

In [None]:
doc_scores, doc_indices = get_relavant_doc(dataset['validation']['question'], q_encoder, p_embs, k = 50)

total = []
for idx, example in enumerate(
        tqdm(dataset['validation'], desc="Dense retrieval: ")
    ):
        tmp = {
            # Query와 해당 id를 반환합니다.
            "question": example["question"],
            "id": example["id"],
            # Retrieve한 Passage의 id, context를 반환합니다.
            "context_id": doc_indices[idx],
            "context": " ".join(  # 기존에는 ' '.join()
                [corpus[pid] for pid in doc_indices[idx]]
            ),
        }
        if "context" in example.keys() and "answers" in example.keys():
            # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다.
            tmp["original_context"] = example["context"]
            tmp["answers"] = example["answers"]
        total.append(tmp)

cqas_100 = pd.DataFrame(total)

HBox(children=(FloatProgress(value=0.0, description='Dense retrieval: ', max=240.0, style=ProgressStyle(descri…




In [None]:
correct_length_100 = []
for i in range(len(cqas_100)) :
    if cqas_100['original_context'][i] in cqas_100['context'][i] :
        correct_length_100.append(i)
print(len(correct_length_100) / len(dataset['validation']))

0.7458333333333333
