In [37]:
import os
import random
import datasets
import numpy as np
import torch
from datasets import load_from_disk
from torch.utils.data import Dataset, DataLoader, RandomSampler, TensorDataset, SequentialSampler
from tqdm import tqdm
import pandas as pd
import pickle
import re
from transformers import AutoTokenizer

In [2]:
def preprocess(text):
    """Flatten a passage by replacing newlines, whitespace runs and '#' with spaces.

    The patterns are applied one after another, in this order:
    real newline characters, the two-character literal backslash + "n",
    any run of whitespace, and finally every '#'. Because '#' is handled
    last, a '#' surrounded by spaces leaves several consecutive spaces.

    Args:
        text: The raw passage string.

    Returns:
        The cleaned string.
    """
    patterns_in_order = (
        r"\n",   # actual newline characters
        r"\\n",  # escaped newline that survived as literal text
        r"\s+",  # collapse runs of whitespace
        r"#",    # markup hashes
    )
    for pattern in patterns_in_order:
        text = re.sub(pattern, " ", text)
    return text

In [3]:
# Tokenizer for the 'klue/bert-base' checkpoint; it is passed to the dataset
# builder in a later cell.
tokenizer = AutoTokenizer.from_pretrained('klue/bert-base')

In [35]:
def _select_bm25_negatives(candidates, ground_truth, answer, num_neg, max_candidates=200):
    """Return `num_neg` retrieved passages usable as hard negatives for one question.

    A candidate is accepted when it differs from `ground_truth` and does not
    contain `answer` as a substring. Only the first `max_candidates` entries
    are scanned; if too few are accepted, the result is padded with the last
    entry of `candidates`.
    """
    negatives = []
    for passage in candidates[:max_candidates]:
        if len(negatives) >= num_neg:
            break
        # Similar to the question, but must not be able to answer it.
        if passage != ground_truth and answer not in passage:
            negatives.append(passage)

    missing = num_neg - len(negatives)
    if missing > 0:
        if len(candidates) == 0:
            raise ValueError("cannot pad negatives: the question has no retrieved candidates")
        negatives.extend([candidates[-1]] * missing)
    return negatives


def get_tensor_for_dense_negative(
    data_path: str,
    bm25_path: str,
    max_context_seq_length: int,
    max_question_seq_length: int,
    tokenizer,
    num_neg: int = 2,
    max_candidates: int = 200,
) -> TensorDataset:
    """Build a TensorDataset of (gold + negative passages, question) encodings.

    Along dimension 1 of the passage tensors, index 0 is the (preprocessed)
    ground-truth passage and indices 1..num_neg are the retrieved negatives.

    Args:
        data_path: Directory given to `load_from_disk`; the loaded dataset
            must provide "context", "question" and "answers" columns.
        bm25_path: Pickle file holding a mapping from question text to a
            sequence of retrieved passage strings.
        max_context_seq_length: Padding/truncation length for passages.
        max_question_seq_length: Padding/truncation length for questions.
        tokenizer: Callable tokenizer returning "input_ids",
            "attention_mask" and "token_type_ids" tensors.
        num_neg: Negatives kept per question (default 2, the previously
            hard-coded value).
        max_candidates: How many retrieved passages are scanned per question
            (default 200, the previously hard-coded value).

    Returns:
        TensorDataset of six tensors: passage input_ids, attention_mask and
        token_type_ids, each shaped (N, 1 + num_neg, max_context_seq_length),
        followed by question input_ids, attention_mask and token_type_ids,
        each shaped (N, max_question_seq_length).

    Raises:
        ValueError: If `num_neg` is smaller than 1, or a question has no
            retrieved candidates to pad with.
    """
    if num_neg < 1:
        raise ValueError("num_neg must be at least 1")

    dataset = load_from_disk(data_path).to_pandas()

    # Preprocessing every context is slow; consider caching it ahead of time.
    dataset["context"] = dataset["context"].apply(preprocess)
    ctx = dataset["context"].to_list()
    q_list = dataset["question"].to_list()
    answers = dataset["answers"].to_list()

    # NOTE(security): pickle.load can execute arbitrary code; only point
    # bm25_path at files produced by a trusted pipeline.
    with open(bm25_path, "rb") as handle:
        elastic_pair = pickle.load(handle)

    neg_ctx = []
    for i in tqdm(range(len(dataset))):
        neg_ctx.extend(
            _select_bm25_negatives(
                elastic_pair[q_list[i]],
                ctx[i],
                answers[i]["text"][0],
                num_neg,
                max_candidates,
            )
        )

    def encode(texts, max_length):
        return tokenizer(
            texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    q_seqs = encode(q_list, max_question_seq_length)
    p_seqs = encode(ctx, max_context_seq_length)
    np_seqs = encode(neg_ctx, max_context_seq_length)

    # TensorDataset only accepts tensors sharing their first dimension, so the
    # gold passage is stacked in front of its negatives instead of being
    # passed alongside them in a tuple.
    max_len = np_seqs["input_ids"].size(-1)
    passages = {}
    for key in ("input_ids", "attention_mask", "token_type_ids"):
        negatives = np_seqs[key].view(-1, num_neg, max_len)
        passages[key] = torch.cat((p_seqs[key].unsqueeze(1), negatives), dim=1)

    return TensorDataset(
        passages["input_ids"],
        passages["attention_mask"],
        passages["token_type_ids"],
        q_seqs["input_ids"],
        q_seqs["attention_mask"],
        q_seqs["token_type_ids"],
    )

In [82]:
class InBatchNegativeRandomDataset(Dataset):
    """Dataset yielding a gold passage, `neg_num` negative passages and a question.

    Every item is a 9-tuple of tensors:
    passage (input_ids, attention_mask, token_type_ids),
    negatives (input_ids, attention_mask, token_type_ids), each with a leading
    dimension of size `neg_num`, and
    question (input_ids, attention_mask, token_type_ids).

    The attribute names keep the historical "attension" spelling so existing
    code that reads them continues to work.
    """

    # Only the first MAX_CANDIDATES retrieved passages are scanned for negatives.
    MAX_CANDIDATES = 200

    def __init__(self, data_path: str,bm25_path: str,max_context_seq_length: int,max_question_seq_length: int,neg_num,tokenizer):
        """Load, mine negatives for, and tokenize the dataset found at `data_path`.

        Args:
            data_path: Directory given to `load_from_disk`.
            bm25_path: Pickle file mapping question text to retrieved passages.
            max_context_seq_length: Padding/truncation length for passages.
            max_question_seq_length: Padding/truncation length for questions.
            neg_num: Number of negatives kept per question (must be >= 1).
            tokenizer: Callable tokenizer returning "input_ids",
                "attention_mask" and "token_type_ids" tensors.
        """
        preprocess_data = self.preprocess_pos_neg(data_path, bm25_path, max_context_seq_length, max_question_seq_length,neg_num, tokenizer)

        (
            self.p_input_ids,
            self.p_attension_mask,
            self.p_token_type_ids,
            self.np_input_ids,
            self.np_attension_mask,
            self.np_token_type_ids,
            self.q_input_ids,
            self.q_attension_mask,
            self.q_token_type_ids,
        ) = preprocess_data

    def __len__(self):
        """Number of examples (rows of the passage input_ids tensor)."""
        return self.p_input_ids.size(0)

    def __getitem__(self, index):
        """Return the 9-tuple (passage x3, negatives x3, question x3) at `index`."""
        return (
            self.p_input_ids[index], self.p_attension_mask[index], self.p_token_type_ids[index],
            self.np_input_ids[index], self.np_attension_mask[index], self.np_token_type_ids[index],
            self.q_input_ids[index], self.q_attension_mask[index], self.q_token_type_ids[index],
        )

    @staticmethod
    def _select_negatives(candidates, ground_truth, answer, num_neg, max_candidates):
        """Pick `num_neg` candidates that are not the gold passage and lack the answer.

        Scans at most `max_candidates` entries and pads with the last entry of
        `candidates` when too few qualify.
        """
        negatives = []
        for passage in candidates[:max_candidates]:
            if len(negatives) >= num_neg:
                break
            # Similar to the question, but must not be able to answer it.
            if passage != ground_truth and answer not in passage:
                negatives.append(passage)

        missing = num_neg - len(negatives)
        if missing > 0:
            if len(candidates) == 0:
                raise ValueError("cannot pad negatives: the question has no retrieved candidates")
            negatives.extend([candidates[-1]] * missing)
        return negatives

    def preprocess_pos_neg(self, data_path: str,bm25_path: str,max_context_seq_length: int,max_question_seq_length: int, neg_num,tokenizer):
        """Tokenize gold passages, mined negatives and questions.

        Returns:
            A 9-tuple of tensors: passage input_ids / attention_mask /
            token_type_ids shaped (N, max_context_seq_length); negative
            input_ids / attention_mask / token_type_ids shaped
            (N, neg_num, max_context_seq_length); question input_ids /
            attention_mask / token_type_ids shaped (N, max_question_seq_length).

        Raises:
            ValueError: If `neg_num` is smaller than 1, or a question has no
                retrieved candidates to pad with.
        """
        if neg_num < 1:
            raise ValueError("neg_num must be at least 1")

        dataset = load_from_disk(data_path).to_pandas()
        dataset["context"] = dataset["context"].apply(self.preprocess_text)
        ctx = dataset["context"].to_list()
        q_list = dataset["question"].to_list()
        answers = dataset["answers"].to_list()

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # bm25_path at files produced by a trusted pipeline.
        with open(bm25_path, "rb") as handle:
            elastic_pair = pickle.load(handle)

        neg_ctx = []
        for i in tqdm(range(len(dataset))):
            neg_ctx.extend(
                self._select_negatives(
                    elastic_pair[q_list[i]],
                    ctx[i],
                    answers[i]["text"][0],
                    neg_num,
                    self.MAX_CANDIDATES,
                )
            )

        def encode(texts, max_length):
            return tokenizer(
                texts,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        q_seqs = encode(q_list, max_question_seq_length)
        p_seqs = encode(ctx, max_context_seq_length)
        np_seqs = encode(neg_ctx, max_context_seq_length)

        # Negatives were tokenized as one flat list; regroup them per question.
        max_len = np_seqs['input_ids'].size(-1)
        neg_input_ids = np_seqs['input_ids'].view(-1, neg_num, max_len)
        neg_attention_mask = np_seqs['attention_mask'].view(-1, neg_num, max_len)
        neg_token_type_ids = np_seqs['token_type_ids'].view(-1, neg_num, max_len)

        return (
            p_seqs['input_ids'], p_seqs['attention_mask'], p_seqs['token_type_ids'],
            neg_input_ids, neg_attention_mask, neg_token_type_ids,
            # The last slot must carry the question's token_type_ids, not the negatives'.
            q_seqs['input_ids'], q_seqs['attention_mask'], q_seqs['token_type_ids'],
        )

    def preprocess_text(self,text):
        """Replace newlines, literal backslash-n, whitespace runs and '#' with spaces."""
        text = re.sub(r"\n", " ", text)
        text = re.sub(r"\\n", " ", text)  # remove newline character
        text = re.sub(r"\s+", " ", text)  # remove continuous spaces
        text = re.sub(r"#", " ", text)
        return text


In [83]:
# Build the training set with 5 mined negatives per question.
# The class defined in this notebook is `InBatchNegativeRandomDataset`; the
# shorter name used before is not defined anywhere here, so the cell raised
# NameError on a fresh kernel.
train_dataset = InBatchNegativeRandomDataset(
    data_path="/opt/ml/data/train_dataset/train",
    bm25_path='/opt/ml/data/elastic_new_train_500.bin',
    max_context_seq_length=512,
    max_question_seq_length=64,
    neg_num=5,
    tokenizer=tokenizer,
)

100%|██████████| 3952/3952 [00:00<00:00, 43366.16it/s]


In [84]:
# Shape notes (presumably for the encoder outputs — TODO confirm):
#   (sample size, emb size)
#   (sample size * neg cnt, emb size)

# Number of examples in the training set.
len(train_dataset)

3952

In [85]:
# Unpack the nine tensors of the first example. Presumably p* = gold passage,
# n* = negatives, q* = question, each as (input_ids, attention_mask,
# token_type_ids) — verify against the dataset's __getitem__.
p1,p2,p3,n1,n2,n3,q1,q2,q3 = train_dataset[0]


In [86]:
# Batches of 16 examples; no shuffle or sampler argument is passed, and
# drop_last=True discards a trailing batch that would have fewer than 16 rows.
train_dataloader = DataLoader(train_dataset,batch_size=16,drop_last=True)

In [91]:
import random

In [111]:
# Inspect one batch: draw a single random negative slot and keep that slot's
# negative for every example in the batch.
for batch in train_dataloader:
    print(batch[0].size())

    # batch[3], batch[4], batch[5] hold the negatives' input_ids,
    # attention_mask and token_type_ids, shaped (batch, neg, seq_len).
    # Read the number of negatives from the tensor instead of hard-coding it
    # (and the batch size), so the cell keeps working when either changes.
    num_negatives = batch[3].size(1)
    neg_random_idx = random.randrange(num_negatives)

    # One slice replaces the per-row unsqueeze + cat; result is (batch, seq_len).
    neg_batch_ids = batch[3][:, neg_random_idx]
    neg_batch_att = batch[4][:, neg_random_idx]
    neg_batch_tti = batch[5][:, neg_random_idx]

    print(batch[3].size())
    print(neg_batch_ids.size())
    break

torch.Size([16, 512])
torch.Size([16, 5, 512])
torch.Size([16, 512])


In [110]:
# Scratch check: random.randrange(0, 10) yields an integer in [0, 10).
random.randrange(0,10)

2