In [1]:
import argparse
import transformers
import torch
import random
from tqdm.auto import tqdm
from datasets import DatasetDict, load_from_disk, load_metric
import faiss
import json
from sklearn.preprocessing import OneHotEncoder

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import wandb

import logging

# 로그 파일 경로와 파일명
log_file_path = 'log.txt'

# 로거 생성
logger = logging.getLogger('my_logger')
logger.setLevel(logging.DEBUG)

# 파일 핸들러 생성
file_handler = logging.FileHandler(log_file_path)
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)

# seed 고정
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [51]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, inputs, contexts, targets=[], target_labels=[]):
        self.inputs = inputs
        self.contexts = contexts
        self.targets = targets
        self.target_labels = target_labels

    def __getitem__(self, idx):
        if len(self.targets) == 0:
            return {"input_ids": self.inputs[idx]["input_ids"],
                    "attention_mask": self.inputs[idx]["attention_mask"],
                    "context": self.contexts[idx]}
        else:
            return {"input_ids": self.inputs[idx]["input_ids"],
                    "attention_mask": self.inputs[idx]["attention_mask"],
                    "context": self.contexts[idx],
                    'target_input_ids': self.targets[idx]["input_ids"],
                    'target_attention_mask': self.targets[idx]["attention_mask"],
                    'target_labels': self.target_labels[idx]["input_ids"]
                    }

    def __len__(self):
        return len(self.inputs)

In [52]:
class Dataloader(pl.LightningDataModule):
    def __init__(self, q_model_name, gen_model_name, batch_size, shuffle, train_path, dev_path, test_path, predict_path):
        super().__init__()
        self.q_model_name = q_model_name
        self.gen_model_name = gen_model_name
        self.batch_size = batch_size
        self.shuffle = shuffle

        self.train_path = train_path
        self.dev_path = dev_path
        self.test_path = test_path
        self.predict_path = predict_path

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None

        self.query_tokenizer = transformers.AutoTokenizer.from_pretrained(q_model_name)
        self.generation_tokenizer = transformers.AutoTokenizer.from_pretrained(gen_model_name)
        

    def tokenizing(self, data):
        input_data = []
        context_data = []
        target_data = []
        target_label_data = []

        for item in tqdm(data, desc='tokenizing', total=len(data)):
            question = item['question']
            context_data.append(item['question']) #1

            q_outputs = self.query_tokenizer(question, add_special_tokens=True, padding='max_length', truncation=True)
            for key in q_outputs:
                q_outputs[key] = torch.tensor(q_outputs[key], dtype=torch.long)

            try:
                answer = '<s>'+item['answers']['text'][0]+'</s>'
                target_answer = item['answers']['text'][0]+'</s>'
                a_outputs = self.generation_tokenizer(answer, max_length=30, add_special_tokens=True, padding='max_length', truncation=True, return_token_type_ids=False)
                a_target_outputs = self.generation_tokenizer(target_answer, max_length=30, add_special_tokens=True, padding='max_length', truncation=True, return_token_type_ids=False)
                
                for key in a_outputs.keys():
                    a_outputs[key] = torch.tensor(a_outputs[key], dtype=torch.long)
                    target_data.append(a_outputs) #2

                    a_target_outputs[key] = torch.tensor(a_target_outputs[key], dtype=torch.long)
                    target_label_data.append(a_target_outputs) #3
            except:
                pass

            input_data.append(q_outputs) #4

        return input_data, context_data, target_data, target_label_data

    def preprocessing(self, data):
        inputs, contexts, targets, labels = self.tokenizing(data)
        return inputs, contexts, targets, labels

    def setup(self, stage='fit'):
        if stage == 'fit':
            # 학습 데이터와 검증 데이터셋을 호출합니다
            datasets = load_from_disk(self.train_path)
            train_datasets, val_datasets = datasets['train'], datasets['validation']

            # 학습데이터 준비
            train_inputs, train_contexts, train_targets, train_labels = self.preprocessing(train_datasets)

            # 검증데이터 준비
            val_inputs, val_contexts ,val_targets, val_labels = self.preprocessing(val_datasets)

            # train 데이터만 shuffle을 적용해줍니다, 필요하다면 val, test 데이터에도 shuffle을 적용할 수 있습니다
            self.train_dataset = Dataset(train_inputs, train_contexts, train_targets, train_labels)
            self.val_dataset = Dataset(val_inputs, val_contexts, val_targets, val_labels)
        else:
            # 평가데이터 준비
            test_data = load_from_disk(self.test_path)
            test_inputs, test_contexts, test_targets, test_labels = self.preprocessing(test_data['validation'])
            self.test_dataset = Dataset(test_inputs, test_contexts, test_targets, test_labels)

            predict_data = load_from_disk(self.predict_path)
            predict_inputs, predict_contexts, _ , _  = self.preprocessing(predict_data['validation'])
            self.predict_dataset = Dataset(predict_inputs, predict_contexts)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size)

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(self.predict_dataset, batch_size=self.batch_size)

In [53]:
def retrieval(q, index, wiki_doc_db, r_num=10):
    q_ = q.detach().cpu().numpy()
    score, relevant_wiki_id = index.search(q_, r_num) # (batch_size, num_retrieval = r_num), (batch_size, num_retrieval = r_num)
    batch_retrieved_docs = []
    for b in range(relevant_wiki_id.shape[0]):
        retrieved_docs = [wiki_doc_db[str(id)] for id in relevant_wiki_id[b]] # ( num_retrieval by *doc_length) : list
        batch_retrieved_docs.append(retrieved_docs)

    return score, batch_retrieved_docs # (batch_size, num_retrieval) : list, dtype:int / (batch_size, num_retrieval by *doc_length) : list, dtype:str

def concat_and_tokenize(context, batch_retrieved_docs, tokenizer):
    '''
    context : (batch_size by *context_length) : list
    batch_retrieved_docs : (batch_size, num_retrieval by *doc_length) : list, dtype:str
    '''
    # concat context and retrieved_docs by batch and tokenizing
    tokenized_context_docs = []
    for c, docs in zip(context, batch_retrieved_docs):
        '''
        c : (*context_length) : list
        docs : (num_retrieval by *doc_length) : list
        '''
        c_ = [c]*len(docs) # (num_retrieval by *context_length) : list

        tokenized = tokenizer(c_, docs, add_special_tokens=True, max_length=512, padding='max_length', truncation=True, return_token_type_ids=False)
        for key in tokenized:
            tokenized[key] = torch.tensor(tokenized[key], dtype=torch.long).to('cuda')
        tokenized_context_docs.append(tokenized)
        
    return tokenized_context_docs # (batch_size, num_retrieval by dict[consists of 'input_ids', 'attention_mask']) : list


In [60]:
class TrainModel(pl.LightningModule):
    def __init__(self, q_model_name, gen_model_name , wiki_doc_db, gen_tokenizer ,lr, r_num=10):
        super().__init__()
        self.save_hyperparameters()

        self.q_model_name = q_model_name
        self.gen_model_name = gen_model_name
        self.gen_tokenizer = gen_tokenizer
        self.lr = lr
        self.r_num = r_num

        self.index = faiss.read_index('sent_emb.index') # ntotal wikidata, dim : 768
        # self.wiki_db = torch.load('index_to_vector.pt') # (ntotal, emb_dim=768)
        self.wiki_doc_db = wiki_doc_db # 

        self.query_encoder = transformers.AutoModel.from_pretrained(q_model_name, cache_dir='./tmp').to("cuda")
        self.encoder_decoder_generation_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(gen_model_name, cache_dir='./tmp').to("cuda")
        self.softmax = torch.nn.Softmax(dim=1)

        self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=self.encoder_decoder_generation_model.config.pad_token_id, reduction='mean')
    
    def forward(self, **x):
        '''
        x['input_ids'] : (batch_size, max_length),
        x['attention_mask'] : (batch_size, max_length)
        x['context'] : (batch_size) : list
        x['target_input_ids'] : (batch_size, answer_max_length),
        x['target_attention_mask'] : (batch_size, answer_max_length)
        x['target_labels'] : (batch_size, answer_max_length)
        '''

        q = self.query_encoder(**{'input_ids':x['input_ids'].to('cuda'), 'attention_mask':x['attention_mask'].to('cuda')})[1] # pooler_output : (batch_size, 768)
        
        # retrieval
        score, retrieved_docs = retrieval(q, self.index, self.wiki_doc_db, self.r_num) 

        # print('retrieved_docs : \n',len(retrieved_docs))
        # query-document concat and tokenizing
        gen_encoder_inputs = concat_and_tokenize(x['context'], retrieved_docs, self.gen_tokenizer) # (batch_size, num_retrieval by dict[consists of 'input_ids', 'attention_mask']) : list
        
        # generation per batch
        outputs = []
        # print(x['input_ids'].size(0))
        for b in range(x['input_ids'].size(0)):
            # print('gen_input_ids: ',gen_encoder_inputs[b]['input_ids'])
            # print('gen_attention_mask: ',gen_encoder_inputs[b]['attention_mask'])
            # print('decoder_input_ids: ',x['target_input_ids'].size())
            # print('decoder_attention_mask: ',x['target_attention_mask'].size())

            output = self.encoder_decoder_generation_model(gen_encoder_inputs[b]['input_ids'],
                                                            attention_mask=gen_encoder_inputs[b]['attention_mask']
                                                           ,decoder_input_ids=x['target_input_ids'][b].unsqueeze(0).expand(self.r_num,-1).to('cuda'), 
                                                           decoder_attention_mask=x['target_attention_mask'][b].unsqueeze(0).expand(self.r_num,-1).to('cuda'))
            outputs.append(output.logits) # (num_retrieval, answer_max_length, vocab_size)
        outputs = torch.stack(outputs, dim=0) # (batch_size, num_retrieval, answer_max_length, vocab_size)
        
        # weighted mean outputs with score
        scores = self.softmax(torch.tensor(score)[:,:,None,None]) # (batch_size, num_retrieval, 1, 1)
        # (batch_size, num_retrieval, answer_max_length, vocab_size) -> (batch_size, answer_max_length, vocab_size)
        weighted_sum_outputs = torch.sum(scores*outputs.detach().cpu(), dim=1) 

        return weighted_sum_outputs.to('cuda') # (batch_size, answer_max_length, vocab_size)

    def training_step(self, batch, batch_idx):
        x = batch
        y = batch['target_labels'] # (batch_size, answer_max_length)

        logits = self(**x)
        
        logits = logits.swapaxes(1, 2) # (batch_size, vocab_size, answer_max_length)
        loss = self.loss_func(logits, y)
        self.log("train_loss", loss.item())

        return loss
    
    def on_validation_epoch_start(self):
        self.em = 0
        self.count = 0

    def validation_step(self, batch, batch_idx):
        x = batch
        y = batch['target_labels'] # (batch_size, answer_max_length)

        logits = self(**x)
        
        logits = logits.swapaxes(1, 2) # (batch_size, vocab_size, answer_max_length)
        loss = self.loss_func(logits, y)
        self.log("val_loss", loss)

        # em metric
        predicts = torch.argmax(logits, dim=-2)
        for s, pred, y in zip(x['target_labels'], predicts,y):
            try:
                logger.debug(f'x: {s},\npred : {pred},\ny : {y}')
                pred_end = pred.tolist().index(1) # </s> token idx : 1
                y_end = y.tolist().index(1)
                # logger.debug(f'pred : {pred[:pred_end]}, y : {y[:y_end]}')
                if all(pred[:pred_end] == y[:y_end]):
                    self.em += 1
                else:
                    pass
            except:
                pass
        self.count += batch['input_ids'].size(0)

        return loss
    
    
    def on_validation_epoch_end(self):
        self.log("val_em", self.em/self.count)
        logger.debug(f'em : {self.em/self.count}')

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(**x)

    def predict_step(self, batch, batch_idx):
        x = batch
        logits = self(**x)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer
    

In [61]:
config = {"q_model_name":'klue/bert-base',
              "gen_model_name":'gogamza/kobart-base-v2',
              "model_detail" : "v2",

              "batch_size": 8, 
              "shuffle":True,
              "learning_rate":1e-5,
              "rm_num":5,
              "epoch": 10,

              "train_path":'./data/train_dataset', 
              "dev_path":'./data/train_dataset',
              "test_path":'./data/train_dataset', 
              "predict_path":'./data/test_dataset',
              }

In [62]:
dataloader = Dataloader(config["q_model_name"], config["gen_model_name"],config["batch_size"],
                            config["shuffle"], config["train_path"], config["dev_path"],
                            config["test_path"], config["predict_path"])
    
with open('unique_wiki_passages.json', 'r') as f:
    wiki_doc_db = json.load(f)
    
model = TrainModel(config["q_model_name"], config["gen_model_name"], wiki_doc_db, dataloader.generation_tokenizer, config["learning_rate"], config["rm_num"])

early_stop_custom_callback = EarlyStopping(
        "val_loss", patience=3, verbose=True, mode="min"
    )

checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        save_top_k=1,
        dirpath="./",
        filename='_'.join(config["q_model_name"].split()+config["gen_model_name"].split() + config["model_detail"].split()), # model에 따라 변화
        save_weights_only=False,
        verbose=True,
        mode="min",
    )

trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=config["epoch"], callbacks=[checkpoint_callback,early_stop_custom_callback],log_every_n_steps=1) # ,logger=wandb_logger

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [63]:
# 학습
trainer.fit(model=model, datamodule=dataloader)

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
tokenizing: 100%|██████████| 3952/3952 [00:01<00:00, 3903.02it/s]
tokenizing: 100%|██████████| 240/240 [00:00<00:00, 4030.64it/s]
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                             | Type                         | Params
----------------------------------------------------------------------------------
0 | query_encoder                    | BertModel                    | 110 M 
1 | encoder_decoder_generation_model | BartForConditionalGeneration | 123 M 
2 | softmax                          | Softmax                      |

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:05<00:05,  5.79s/it]



                                                                           

  rank_zero_warn(


Epoch 0:   0%|          | 0/494 [00:00<?, ?it/s] 

KeyboardInterrupt: 

## Debug

In [56]:
with open('unique_wiki_passages.json', 'r') as f:
    wiki_doc_db = json.load(f)

In [42]:
dataloader = Dataloader(config["q_model_name"], config["gen_model_name"],config["batch_size"],
                            config["shuffle"], config["train_path"], config["dev_path"],
                            config["test_path"], config["predict_path"])
dataloader.setup()
dataloader.setup('test')

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
tokenizing: 100%|██████████| 3952/3952 [00:00<00:00, 4011.62it/s]
tokenizing: 100%|██████████| 240/240 [00:00<00:00, 4116.30it/s]
tokenizing: 100%|██████████| 240/240 [00:00<00:00, 4189.97it/s]
tokenizing: 100%|██████████| 600/600 [00:00<00:00, 1893.17it/s]


In [43]:
model = TrainModel(config["q_model_name"], config["gen_model_name"], wiki_doc_db, dataloader.generation_tokenizer, config["learning_rate"], config["rm_num"]).to('cuda')

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.


In [45]:
sample_batch = next(iter(dataloader.train_dataloader()))

In [47]:
model.eval()
with torch.no_grad():
    sample_output = model(**sample_batch)

sample_output

retrieved_docs : 
 8
8
gen_input_ids:  tensor([[14031, 10476, 14318,  ...,     3,     3,     3],
        [14031, 10476, 14318,  ...,     3,     3,     3],
        [14031, 10476, 14318,  ..., 17603, 15833, 15178],
        [14031, 10476, 14318,  ...,     3,     3,     3],
        [14031, 10476, 14318,  ...,     3,     3,     3]], device='cuda:0')
gen_attention_mask:  tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')
decoder_input_ids:  torch.Size([8, 30])
decoder_attention_mask:  torch.Size([8, 30])
gen_input_ids:  tensor([[14061, 11280, 13758,  ...,     3,     3,     3],
        [14061, 11280, 13758,  ...,     3,     3,     3],
        [14061, 11280, 13758,  ..., 15105, 14611, 21032],
        [14061, 11280, 13758,  ...,     3,     3,     3],
        [14061, 11280, 13758,  ...,     3,     3,     3]], device='cuda:0')
gen_attention_mask:  tensor([[1, 1,

tensor([[[  0.3069,  15.0308,  -4.1431,  ...,  -3.3671,  -2.0973,  -2.0289],
         [ -7.4116,   3.9579,  -7.8030,  ...,  -9.1004,  -7.4681,  -5.0446],
         [ -0.3583,  13.2390,  -4.5103,  ...,  -3.6807,  -5.2330,  -1.0779],
         ...,
         [  4.9943,  13.0604,   4.4204,  ...,   5.3715,   3.6483,   2.4233],
         [  5.4812,  13.8295,   4.7953,  ...,   5.5017,   4.0653,   2.6994],
         [  4.7982,  13.5762,   4.1617,  ...,   4.9586,   3.3007,   2.1863]],

        [[ -2.3386,  13.6352,  -7.4798,  ...,  -3.5879,  -4.5814,  -6.5483],
         [ -2.2348,  10.6086,  -2.2230,  ...,  -0.4062,  -4.9052,  -0.0586],
         [ -1.8402,   8.1862,  -3.3572,  ...,  -2.5393,   0.2993,  -4.6089],
         ...,
         [ -1.9475,   5.5597,  -1.4390,  ...,   0.7220,  -3.1223,  -2.9362],
         [ -1.3532,   6.4398,  -0.8055,  ...,   1.2200,  -2.4596,  -2.4568],
         [ -1.4832,   6.9156,  -0.9540,  ...,   1.2626,  -2.8209,  -2.4839]],

        [[  2.5680,  15.5362,  -0.8978,  ...

In [48]:
sample_output.size()

torch.Size([8, 30, 30000])

In [64]:
import torch, gc
gc.collect()

11230

In [65]:
torch.cuda.empty_cache()
gc.collect()

0