In [1]:
from dpr_train_generate import run_dpr, make_dataset, train, make_dense_embedding, BertEncoder
from retrieval import SparseRetrieval
from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer

from datasets import (
    load_metric,
    load_from_disk,
    Sequence,
    Value,
    Features,
    Dataset,
    DatasetDict,
)
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)

from transformers import (
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)

import os
import json
import torch
import pickle
import torch.nn.functional as F

In [2]:
tokenizer = AutoTokenizer.from_pretrained(
        "klue/bert-base",
        use_fast=True,
    )

data_path = "/opt/ml/data/"
context_path = "wikipedia_documents.json"
with open(os.path.join(data_path, context_path), "r", encoding="utf-8") as f:
    wiki = json.load(f)

context = list(
    dict.fromkeys([v["text"] for v in wiki.values()])
)  # set 은 매번 순서가 바뀌므로

retriever = SparseRetrieval(tokenize_fn=tokenizer)

Lengths of unique contexts : 56737


In [3]:
# negative sampling한 결과
train_dataset, valid_dataset, tokenizer = make_dataset(retriever)


Embedding pickle load.
[query exhaustive search] done in 49.868 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=3952.0, style=ProgressStyle(desc…


[query exhaustive search] done in 3.096 s


HBox(children=(FloatProgress(value=0.0, description='Sparse retrieval: ', max=240.0, style=ProgressStyle(descr…




In [12]:
# load pre-trained model on cuda (if available)
p_encoder_p = BertEncoder.from_pretrained("klue/bert-base")
q_encoder_p = BertEncoder.from_pretrained("klue/bert-base")

# 학습 설정
if torch.cuda.is_available():
    p_encoder_p.cuda()
    q_encoder_p.cuda()


args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01
)

train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.per_device_train_batch_size)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertEncoder: ['cls.predictions.bi

In [None]:
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.per_device_train_batch_size)

In [None]:
pickle_name = f"dense_embedding.bin"
q_encoder_name = f"q_encoder0.pt"
p_encoder_name = f"p_encoder0.pt"
emd_path = os.path.join(data_path, pickle_name)
q_model_path = os.path.join("./models/train_dataset", q_encoder_name)
p_model_path = os.path.join("./models/train_dataset", p_encoder_name)

if os.path.isfile(q_model_path):
    q_encoder = torch.load(q_model_path)
    p_encoder = torch.load(p_model_path)
    print("Dense Embedding pickle load.")

Dense Embedding pickle load.


In [14]:
torch.cuda.empty_cache()

In [None]:
for step, batch in enumerate(train_dataloader):
#     print(step)
#     print(len(batch), batch[0].shape, batch[3].shape)
    
    
    if torch.cuda.is_available():
        batch = tuple(t.cuda() for t in batch)
#         targets = targets.cuda()

        p_inputs = {'input_ids': batch[0].view(
                                        args.per_device_train_batch_size*(4+1), -1),
                    'attention_mask': batch[1].view(
                                        args.per_device_train_batch_size*(4+1), -1),
                    'token_type_ids': batch[2].view(
                                        args.per_device_train_batch_size*(4+1), -1)
                    }

        q_inputs = {'input_ids': batch[3],
                    'attention_mask': batch[4],
                    'token_type_ids': batch[5]}
    
    
    
    p_outputs = p_encoder(**p_inputs)  #(batch_size*(num_neg+1), emb_dim)
    q_outputs = q_encoder(**q_inputs)  #(batch_size*, emb_dim)

    
    p_outputs = torch.transpose(p_outputs.view(args.per_device_train_batch_size, 4+1, -1), 1, 2)
    q_outputs = q_outputs.view(args.per_device_train_batch_size, 1, -1)
    
    sim_scores = torch.bmm(q_outputs, p_outputs).squeeze()  #(batch_size, num_neg+1)
#     print(sim_scores.shape)

    sim_scores = F.log_softmax(sim_scores, dim=1)
    preds = torch.argmax(sim_scores, dim=-1)
    
    print(preds)
    if step == 10:
        break

RuntimeError: CUDA out of memory. Tried to allocate 240.00 MiB (GPU 0; 31.75 GiB total capacity; 30.17 GiB already allocated; 199.50 MiB free; 30.42 GiB reserved in total by PyTorch)

In [32]:
sim_scores = torch.bmm(q_outputs, p_outputs).squeeze()  #(batch_size, num_neg+1)
sim_scores.shape

torch.Size([4, 5])

In [34]:
sim_scores = F.log_softmax(sim_scores, dim=1)
preds = torch.argmax(sim_scores, dim=-1)

In [35]:
print(sim_scores)
print(preds)

tensor([[-1.8299e-03, -1.0183e+01, -1.0769e+01, -1.0456e+01, -6.3535e+00],
        [-2.7464e+00, -9.9608e-01, -1.4863e+00, -4.4468e+00, -1.1130e+00],
        [-3.0129e+00, -1.2581e+01, -1.0812e+01, -7.9466e+00, -5.0797e-02],
        [-1.4200e+00, -8.0894e+00, -1.8731e+00, -8.2324e+00, -5.0406e-01]],
       device='cuda:0', grad_fn=<LogSoftmaxBackward>)
tensor([0, 1, 4, 4], device='cuda:0')


In [38]:
targets = torch.zeros(args.per_device_train_batch_size).long()
if torch.cuda.is_available():
    targets = targets.cuda()
loss = F.nll_loss(sim_scores, targets)
print(loss)
matches = 0
matches += (preds == targets).sum()
print(matches)

tensor(1.7953, device='cuda:0', grad_fn=<NllLossBackward>)
tensor(1, device='cuda:0')
