## 1. Initialize models
### 1.1 Configure the prompter and language model

In [15]:
from pathlib import Path
import numpy as np

In [1]:
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from modeling_rankprompter import RankPrompter

# Run on the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prompter backbone: umT5-small (config + tokenizer come from the same checkpoint).
PROMPTER_NAME = "google/umt5-small"
prompter_tokenizer = AutoTokenizer.from_pretrained(PROMPTER_NAME)
prompter_config = AutoConfig.from_pretrained(PROMPTER_NAME)

# Language model: Baichuan-7B ships its own modeling code, hence trust_remote_code=True.
LANGUAGE_MODEL_NAME = "baichuan-inc/Baichuan-7B"
language_model_config = AutoConfig.from_pretrained(LANGUAGE_MODEL_NAME, trust_remote_code=True)
language_model_tokenizer = AutoTokenizer.from_pretrained(LANGUAGE_MODEL_NAME, trust_remote_code=True)
# Keep the tokenizer's padding id consistent with the model config.
language_model_tokenizer.pad_token_id = language_model_config.pad_token_id


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/howard/miniconda3/envs/torch1.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
ERROR: /home/howard/miniconda3/envs/torch1.13/bin/python: undefined symbol: cudaRuntimeGetVersion
CUDA SETUP: libcudart.so path is None
CUDA SETUP: Is seems that your cuda installation is not in your path. See https://github.com/TimDettmers/bitsandbytes/issues/85 for more information.
CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 00
CUDA SETUP: Loading binary /home/howard/miniconda3/envs/torch1.13/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn("The installed version of bitsandbytes was compiled without GPU support. "
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


In [2]:
from misc import count_parameters
# Prompter hyper-parameters: number of soft-prompt tokens it produces and the
# language model's hidden size (presumably the width of those soft prompts — confirm in RankPrompter).
prompter_config.num_soft_prompt_tokens = 32
prompter_config.llm_dim = language_model_config.hidden_size

prompter = RankPrompter(prompter_config).to(device)

# Report parameter counts in billions and the trainable fraction in percent.
trainable_params, all_param = count_parameters(prompter)
trainable_pct = 100 * trainable_params / all_param
print("prompter trainable params: {:.2f}B || all params: {:.2f}B || trainable%: {:.4f}".format(
    trainable_params / 1e9, all_param / 1e9, trainable_pct))

prompter trainable params: 0.44B || all params: 0.44B || trainable%: 100.0000


## 2. Initialize the dataset
### 2.1 Load the dataset

In [3]:
import json
from datasets import load_from_disk

document_path = "wikipedia-cn-20230720-documents_10k.json"
qa_path = "wikipedia-cn-20230720_qa-with-retrieval_10k/"

# Map docid -> document text. The `with` block closes the file handle, and the
# encoding is explicit because the platform default is not guaranteed to be UTF-8
# (the file name indicates Chinese Wikipedia content).
with open(document_path, encoding="utf-8") as document_file:
    docid2doc = {d["docid"]: d["document"] for d in json.load(document_file)}

# QA pairs with retrieved document ids, saved earlier with `save_to_disk`.
qa_dataset = load_from_disk(qa_path)



### 2.2 Tokenize the dataset

In [4]:
def preprocess_dataset(example, num_doc=20, doc_max_length=512,
                       ques_max_length=32, ans_max_length=128):
    """Tokenize one QA example for both the prompter and the language model.

    The gold document (``example["docid"]``) is placed first. The retrieved
    documents follow in their original order, with the gold id removed so it is
    not duplicated. The list is then cut to ``num_doc`` documents. Every text
    field is padded/truncated to a fixed token length.

    Args:
        example: A dataset row with keys "docid", "retrieved_docids",
            "question" and "answer".
        num_doc: Maximum number of documents kept per example.
        doc_max_length: Token length each document is padded/truncated to.
        ques_max_length: Token length of the question (both tokenizers).
        ans_max_length: Token length of the answer (both tokenizers).

    Returns:
        dict mapping column names to ``input_ids`` / ``attention_mask`` lists
        for the documents, question and answer, under both tokenizers.

    NOTE(review): if a row has fewer than ``num_doc`` distinct documents, its
    document dimension is shorter than ``num_doc``. Batches then cannot be
    stacked by the default collate, so confirm every row has enough retrievals.
    """
    pos_docid = example["docid"]
    # Gold document first, then the remaining retrieved ones (gold de-duplicated).
    docids = [pos_docid] + [docid for docid in example["retrieved_docids"] if docid != pos_docid]
    docs = [docid2doc[docid] for docid in docids[:num_doc]]

    def _tokenize(tokenizer, text, max_length):
        # Fixed-length padding/truncation gives every example the same per-field shape.
        return tokenizer(text, padding="max_length", truncation=True, max_length=max_length)

    prompter_tokenized_docs = _tokenize(prompter_tokenizer, docs, doc_max_length)
    prompter_tokenized_question = _tokenize(prompter_tokenizer, example["question"], ques_max_length)
    prompter_tokenized_answer = _tokenize(prompter_tokenizer, example["answer"], ans_max_length)
    language_model_tokenized_question = _tokenize(language_model_tokenizer, example["question"], ques_max_length)
    language_model_tokenized_answer = _tokenize(language_model_tokenizer, example["answer"], ans_max_length)

    return {"document_input_ids": prompter_tokenized_docs.input_ids,
            "document_attention_mask": prompter_tokenized_docs.attention_mask,
            "prompter_question_input_ids": prompter_tokenized_question.input_ids,
            "prompter_question_attention_mask": prompter_tokenized_question.attention_mask,
            "prompter_answer_input_ids": prompter_tokenized_answer.input_ids,
            "prompter_answer_attention_mask": prompter_tokenized_answer.attention_mask,
            "language_model_question_input_ids": language_model_tokenized_question.input_ids,
            "language_model_question_attention_mask": language_model_tokenized_question.attention_mask,
            "language_model_answer_input_ids": language_model_tokenized_answer.input_ids,
            "language_model_answer_attention_mask": language_model_tokenized_answer.attention_mask}

In [5]:
# Tokenize every split with 16 worker processes. The raw retrieved-id column is
# dropped once its documents are tokenized, and columns are exposed as torch tensors.
tokenized_qa_dataset = (
    qa_dataset
    .map(preprocess_dataset, num_proc=16, remove_columns=["retrieved_docids"])
    .with_format("torch")
)

                 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-74e2008293004796.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-fb09bec8ab5a6c4a.arrow


  

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-691dc1b90a74e5f1.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-add749843132a27f.arrow


  

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-594bfa22304c566d.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-919b159bb6e2f7f9.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-6318c30b23dea7c3.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-05a4d046a66b9b2c.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-c19c732dc40e31dc.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-8c42a1ff33f35aaa.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-b0d06cc8ea3330a0.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-76653ed279e8d08f.arrow


  

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-b1f0f132646e8287.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-3857c8b0f80994d9.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-8a20cdf5f38e6d7c.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/train/cache-a1cd63a5c784a3b0.arrow


                  

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-bcdcdc7e268e3f5f.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-081f458ec29b8c58.arrow


   

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-c9a9809ee3b0ef2d.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-08522d06b722d3c2.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-581d0217323ab761.arrow


    

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-fe1c929c32aa2f7a.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-1f046ef237a8175c.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-09d278575afaf9cf.arrow
Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-608fc1f2a7cbe5ec.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-25f49b7325ee34a2.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-2f383080372e8afd.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-d186407a847a120b.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-4282e34bdcacd3c4.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-189ba4d91900efe9.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-9698ba4c6c4c7e57.arrow


 

Loading cached processed dataset at /data/Documents/MyCode/RankPrompter/wikipedia-cn-20230720_qa-with-retrieval_10k/test/cache-53216c8f5accc117.arrow


In [6]:
# Inspect the tokenized splits: row counts and the resulting columns.
tokenized_qa_dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'docid', 'document_input_ids', 'document_attention_mask', 'prompter_question_input_ids', 'prompter_question_attention_mask', 'prompter_answer_input_ids', 'prompter_answer_attention_mask', 'language_model_question_input_ids', 'language_model_question_attention_mask', 'language_model_answer_input_ids', 'language_model_answer_attention_mask'],
        num_rows: 55192
    })
    test: Dataset({
        features: ['question', 'answer', 'docid', 'document_input_ids', 'document_attention_mask', 'prompter_question_input_ids', 'prompter_question_attention_mask', 'prompter_answer_input_ids', 'prompter_answer_attention_mask', 'language_model_question_input_ids', 'language_model_question_attention_mask', 'language_model_answer_input_ids', 'language_model_answer_attention_mask'],
        num_rows: 6133
    })
})

### 2.3 Initialize the dataloaders

In [7]:
from torch.utils.data import DataLoader
from tqdm import tqdm

SEED = 42  # fixes the training shuffle order so runs are reproducible
batch_size = 2  # micro-batch size when gradient_accumulation_steps > 1
gradient_accumulation_steps = 8  # accumulate gradients over this many micro-batches (effective batch = 16)

# Training data is shuffled with a dedicated, seeded generator; test data keeps its order.
train_dataloader = DataLoader(tokenized_qa_dataset["train"], batch_size=batch_size, shuffle=True,
                              generator=torch.Generator().manual_seed(SEED))
test_dataloader = DataLoader(tokenized_qa_dataset["test"], batch_size=batch_size, shuffle=False)

## 3. Train
### 3.1 Configure the optimizer and LR scheduler

In [8]:
import inspect
from transformers import get_polynomial_decay_schedule_with_warmup

# optimizer config
learning_rate = 1e-4
# Use the fused AdamW kernel when this torch build supports it and we train on CUDA.
# `device` is a torch.device, so compare its `.type`: a torch.device never compares
# equal to the plain string "cuda", which silently disabled the fused path.
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device.type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(prompter.parameters(), lr=learning_rate, **extra_args)
print(f"using fused AdamW: {use_fused}")

# scheduler config: the scheduler steps once per optimizer step, i.e. once per
# `gradient_accumulation_steps` micro-batches.
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader) // gradient_accumulation_steps
lr_scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer=optimizer,  # the scheduler drives this optimizer's learning rate
    lr_end=1e-7,
    power=1,  # power=1 (the default) is equivalent to a linear schedule with warmup
    num_warmup_steps=1000 // gradient_accumulation_steps,  # ~1000 micro-batches of warmup
    num_training_steps=num_training_steps)
print(f"num_training_steps: {num_training_steps}")

using fused AdamW: False
num_training_steps: 10348


In [9]:
# Only for visualizing the LR schedule: running this steps the scheduler all the way to the end, so do NOT run this cell before training.
# from matplotlib import pyplot as plt
# lst = []
# for _ in range(num_training_steps):
#     lr_scheduler.step()
#     lst.append(lr_scheduler.get_lr())
# plt.plot(lst)

### 3.2 Training

In [16]:
# Mean ranking loss of a prompter over a dataloader (no gradients).
@torch.no_grad()
def evaluate_prompter(model, dataloader, device=None):
    """Compute the mean loss of ``model`` over every batch of ``dataloader``.

    Args:
        model: The prompter to evaluate. It is called with document/question
            input ids and attention masks and must return an object with a
            scalar tensor ``.loss``.
        dataloader: Iterable of batches holding "document_input_ids",
            "document_attention_mask", "prompter_question_input_ids" and
            "prompter_question_attention_mask" tensors.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        dict with key "ranker_val_loss": the mean per-batch loss, or NaN when
        the dataloader yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    losses = []
    try:
        for batch in tqdm(dataloader):
            # Evaluate the model that was passed in (not a global) on this batch.
            output = model(
                document_input_ids=batch["document_input_ids"].to(device),
                document_attention_mask=batch["document_attention_mask"].to(device),
                question_input_ids=batch["prompter_question_input_ids"].to(device),
                question_attention_mask=batch["prompter_question_attention_mask"].to(device),
            )
            losses.append(output.loss.item())
    finally:
        # Restore the caller's mode instead of unconditionally switching to train().
        model.train(was_training)
    mean_loss = float(np.mean(losses)) if losses else float("nan")
    return {"ranker_val_loss": mean_loss}

In [10]:
# Train the prompter with gradient accumulation. Each optimizer/scheduler step
# consumes `gradient_accumulation_steps` micro-batches, and the progress bar
# counts optimizer steps (num_training_steps in total).
prompter.train()
optimizer.zero_grad()  # drop stale gradients, e.g. left by an interrupted earlier run of this cell
step = 0  # micro-batch counter: ~num_training_steps * gradient_accumulation_steps at the end
with tqdm(total=num_training_steps) as pbar:
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            prompter_output = prompter(
                document_input_ids=batch["document_input_ids"].to(device),
                document_attention_mask=batch["document_attention_mask"].to(device),
                question_input_ids=batch["prompter_question_input_ids"].to(device),
                question_attention_mask=batch["prompter_question_attention_mask"].to(device),
            )
            # Scale so the accumulated gradient is the mean over the micro-batches.
            loss = prompter_output.loss / gradient_accumulation_steps
            loss.backward()
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                pbar.update(1)
            step += 1
    

100%|██████████| 10348/10348 [8:51:49<00:00,  3.08s/it] 

In [14]:
# Persist the trained prompter together with optimizer and LR-scheduler state so
# training can be resumed from this checkpoint.
model_output_dir = Path("saved_model/ranker")
model_output_dir.mkdir(exist_ok=True, parents=True)
checkpoint = {
    'model': prompter.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'model_args': prompter_config,
    'iter_num': step,  # micro-batch count from the training loop, not optimizer steps
    'best_val_loss': None,
    'config': None,
}
checkpoint_path = model_output_dir / 'ckpt.pt'
print(f"saving checkpoint to {model_output_dir}")
torch.save(checkpoint, checkpoint_path)

saving checkpoint to saved_model/ranker


In [17]:
# Mean prompter loss on the held-out test split.
eval_results = evaluate_prompter(model=prompter, dataloader=test_dataloader)

In [18]:
# Display the test-set metrics returned by evaluate_prompter.
eval_results

{'ranker_val_loss': 0.18672097759533285}

In [21]:
# Toy sanity check of torchmetrics' RetrievalRecall with top_k=1 on two queries.
#   query 0: the top-scored doc (0.5) is relevant     -> recall 1/1
#   query 1: the top-scored doc (0.5) is not relevant -> recall 0/2
# Averaging over the two queries gives 0.5.
from torch import tensor
from torchmetrics.retrieval import RetrievalRecall

query_ids = tensor([0, 0, 0, 1, 1, 1, 1])
scores = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
relevant = tensor([False, False, True, False, True, False, True])
recall_at_1 = RetrievalRecall(top_k=1)
recall_at_1(scores, relevant, indexes=query_ids)

tensor(0.5000)