## 1. init model
### 1.1 config model

In [1]:
from pathlib import Path
import numpy as np

In [2]:
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from modeling_rankprompter2 import RankPrompter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

prompter_tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
prompter_config = AutoConfig.from_pretrained("google/umt5-small")
# baichuan
language_model_config = AutoConfig.from_pretrained("baichuan-inc/Baichuan-7B", trust_remote_code=True)
language_model_tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-7B", trust_remote_code=True)
language_model_tokenizer.pad_token_id = language_model_config.pad_token_id

In [3]:
from misc import count_parameters
prompter_config.num_soft_prompt_tokens = 32
prompter_config.llm_dim = language_model_config.hidden_size
prompter = RankPrompter(prompter_config).to(device)
trainable_params, all_param = count_parameters(prompter)
print("prompter trainable params: {:.2f}B || all params: {:.2f}B || trainable%: {:.4f}".format(
        trainable_params / 1e9, all_param / 1e9, 100 * trainable_params / all_param
))

prompter trainable params: 0.31B || all params: 0.31B || trainable%: 100.0000


In [4]:
# language_model_path = "/data2/huggingface_models/baichuan-inc/Baichuan-7B"
language_model_path = "/root/autodl-tmp/Baichuan-7B/"
language_model = AutoModelForCausalLM.from_pretrained(language_model_path, 
                device_map=device, trust_remote_code=True, torch_dtype=torch.bfloat16)
language_model.requires_grad_(False) # fix all model params
trainable_params, all_param = count_parameters(language_model)
print("language model trainable params: {:.2f}B || all params: {:.2f}B || trainable%: {:.4f}".format(
        trainable_params / 1e9, all_param / 1e9, 100 * trainable_params / all_param
))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

language model trainable params: 0.00B || all params: 7.00B || trainable%: 0.0000


In [5]:
# load prompter if needed
# ckpt = torch.load("/root/autodl-tmp/saved_model/prompter/best_ckpt.pt")
# prompter.load_state_dict(ckpt["model"])

In [6]:
prompter.soft_prompt_embeds

AttributeError: 'RankPrompter' object has no attribute 'soft_prompt_embeds'

## 2. Init Dataset
### 2.1 load dataset

In [7]:
import json
from datasets import load_from_disk

document_path = "wikipedia-cn-20230720-documents_10k.json"
qa_path = "wikipedia-cn-20230720_qa-with-retrieval_10k/"

docid2doc = {d["docid"]:d["document"] for d in json.load(open(document_path))}

qa_dataset = load_from_disk(qa_path)

### 2.2 tokenize dataset

In [8]:
def preprocess_dataset(example):
    # 
    num_doc = 20
    doc_max_length = 512
    ques_max_length = 32
    ans_max_length = 128
    #
    pos_docid = example["docid"]
    # put pos_docid in the first place
    docids = [pos_docid] + [docid for docid in example["retrieved_docids"] if docid != pos_docid]
    docs = [docid2doc[docid] for docid in docids[:num_doc]]
    # padding to specific length, make all example have the same shape
    prompter_tokenzied_docs = prompter_tokenizer(docs, padding="max_length", 
                                                truncation=True, max_length=doc_max_length)
    prompter_tokenzied_question = prompter_tokenizer(example["question"], padding="max_length", 
                                                truncation=True, max_length=ques_max_length)
    prompter_tokenzied_answer = prompter_tokenizer(example["answer"], padding="max_length", 
                                                truncation=True, max_length=ans_max_length)
    language_model_tokenzied_question = language_model_tokenizer(example["question"], padding="max_length",
                                                truncation=True, max_length=ques_max_length)
    language_model_tokenzied_answer = language_model_tokenizer(example["answer"], padding="max_length",
                                                truncation=True, max_length=ans_max_length)
    return {"document_input_ids": prompter_tokenzied_docs.input_ids,
            "document_attention_mask": prompter_tokenzied_docs.attention_mask,
            "prompter_question_input_ids": prompter_tokenzied_question.input_ids,
            "prompter_question_attention_mask": prompter_tokenzied_question.attention_mask,
            "prompter_answer_input_ids": prompter_tokenzied_answer.input_ids,
            "prompter_answer_attention_mask": prompter_tokenzied_answer.attention_mask,
            "language_model_question_input_ids": language_model_tokenzied_question.input_ids,
            "language_model_question_attention_mask": language_model_tokenzied_question.attention_mask,
            "language_model_answer_input_ids": language_model_tokenzied_answer.input_ids,
            "language_model_answer_attention_mask": language_model_tokenzied_answer.attention_mask}

In [9]:
tokenized_qa_dataset = qa_dataset.map(preprocess_dataset, 
                                    num_proc=16, 
                                    remove_columns=["retrieved_docids"]).with_format("torch")

In [10]:
tokenized_qa_dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'docid', 'document_input_ids', 'document_attention_mask', 'prompter_question_input_ids', 'prompter_question_attention_mask', 'prompter_answer_input_ids', 'prompter_answer_attention_mask', 'language_model_question_input_ids', 'language_model_question_attention_mask', 'language_model_answer_input_ids', 'language_model_answer_attention_mask'],
        num_rows: 55192
    })
    test: Dataset({
        features: ['question', 'answer', 'docid', 'document_input_ids', 'document_attention_mask', 'prompter_question_input_ids', 'prompter_question_attention_mask', 'prompter_answer_input_ids', 'prompter_answer_attention_mask', 'language_model_question_input_ids', 'language_model_question_attention_mask', 'language_model_answer_input_ids', 'language_model_answer_attention_mask'],
        num_rows: 6133
    })
})

### 2.3 init dataloader

In [11]:
from torch.utils.data import DataLoader
from tqdm import tqdm
# Create a DataLoader with the desired batch size
batch_size = 1 # if gradient_accumulation_steps > 1, this is the micro-batch size
gradient_accumulation_steps = 16 # accumulate gradients over n batches

train_dataloader = DataLoader(tokenized_qa_dataset["train"], batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(tokenized_qa_dataset["test"], batch_size=batch_size, shuffle=False)

## 3. train
### 3.1 config optimizer and scheduler

In [12]:
import inspect
from transformers import get_polynomial_decay_schedule_with_warmup
# optimizer config 
learning_rate = 1e-4
# Create AdamW optimizer and use the fused version if it is available
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(prompter.parameters(), lr=learning_rate, **extra_args)
print(f"using fused AdamW: {use_fused}")
# scheduler config
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader) // gradient_accumulation_steps
lr_scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer=optimizer,  # scheduler是针对optimizer的lr的
    lr_end=1e-7,
    power=1, # 当power=1时（默认）等价于linear_schedule_with_warmup
    num_warmup_steps=1000 // gradient_accumulation_steps,
    num_training_steps=num_training_steps)
print(f"num_training_steps: {num_training_steps}")

using fused AdamW: False
num_training_steps: 3449


In [13]:
# just for check lr scheduler, which make scheduler empty, not run this cell when training
# from matplotlib import pyplot as plt
# lst = []
# for _ in range(num_training_steps):
#     lr_scheduler.step()
#     lst.append(lr_scheduler.get_lr())
# plt.plot(lst)

### 3.2 config evaluation

In [14]:
def generate_language_model_labels(soft_prompts, language_model_question_input_ids, language_model_answer_input_ids):
    prompt_labels = torch.zeros(
            soft_prompts.shape[:2], dtype=torch.long, device=device
        ).fill_(-100)
    ques_labels = torch.zeros(
            language_model_question_input_ids.shape[:2], dtype=torch.long, device=device
        ).fill_(-100)
    answer_labels = language_model_answer_input_ids.masked_fill(
            language_model_answer_input_ids == language_model_tokenizer.pad_token_id, -100
        ).to(device)
    language_model_labels = torch.cat([prompt_labels, ques_labels, answer_labels], dim=1)
    return language_model_labels

In [15]:
def model_forward(batch):
    document_input_ids = batch["document_input_ids"].to(device)
    document_attention_mask = batch["document_attention_mask"].to(device)
    prompter_question_input_ids = batch["prompter_question_input_ids"].to(device)
    prompter_question_attention_mask = batch["prompter_question_attention_mask"].to(device)
    prompter_output = prompter(
        document_input_ids=document_input_ids,
        document_attention_mask=document_attention_mask,
        question_input_ids=prompter_question_input_ids,
        question_attention_mask=prompter_question_attention_mask,
    )
    language_model_answer_input_ids = batch["language_model_answer_input_ids"].to(device)
    language_model_ans_embeds = language_model.get_input_embeddings()(language_model_answer_input_ids)
    language_model_question_input_ids = batch["language_model_question_input_ids"].to(device)
    language_model_ques_embeds = language_model.get_input_embeddings()(language_model_question_input_ids)
    language_model_input_embeds = torch.cat([prompter_output.soft_prompts, 
                                             language_model_ques_embeds, language_model_ans_embeds], dim=1)
    language_model_labels = generate_language_model_labels(prompter_output.soft_prompts, 
                                                           language_model_question_input_ids,
                                                           language_model_answer_input_ids)
    language_model_output = language_model(inputs_embeds=language_model_input_embeds, 
                                           labels = language_model_labels)
    return prompter_output, language_model_output

In [16]:
import importlib
import metric
importlib.reload(metric)

# helps estimate the loss of the model
@torch.no_grad()
def evaluate_prompter(prompter, language_model, dataloader, eval_iters=None):
    out = {}
    prompter.eval()
    language_model.eval()
    losses, rank_losses, lm_losses = [], [], []
    generation_metrics = metric.GenerationMetrics(language_model_tokenizer)
    retrieval_metrics = metric.RetrievalMetrics()
    eval_iters = len(dataloader) if eval_iters is None else eval_iters
    step = 0
    for batch in tqdm(dataloader, desc="eval", total=eval_iters):
        prompter_output, language_model_output = model_forward(batch)
        # loss
        loss = (prompter_output.loss + language_model_output.loss).item()
        losses.append(loss)
        rank_losses.append(prompter_output.loss.item())
        lm_losses.append(language_model_output.loss.item())
        # generate metric
        language_model_question_input_ids = batch["language_model_question_input_ids"].to(device)
        language_model_ques_embeds = language_model.get_input_embeddings()(language_model_question_input_ids)
        language_model_input_embeds = torch.cat([prompter_output.soft_prompts, 
                                                 language_model_ques_embeds], dim=1)
        language_model_pred = language_model.generate(inputs_embeds=language_model_input_embeds, max_new_tokens=64)
        generation_metrics.update(language_model_pred.cpu(), batch["language_model_answer_input_ids"])
        # rank metric
        batch_size, num_doc = batch["document_input_ids"].shape[:2]
        rank_preds = prompter_output.logits.cpu().tolist()
        rank_targets = [[True] + [False] * (num_doc - 1) for _ in range(batch_size)]
        retrieval_metrics.update(rank_preds, rank_targets)
        # 
        if step > eval_iters:
            break
        step += 1
        
    out["loss"] = {"val_loss": np.mean(losses), "val_rank_loss": np.mean(rank_losses),
                   "val_lm_loss": np.mean(lm_losses)}
    out["generation"] = generation_metrics.compute()
    out["retrieval"] = retrieval_metrics.compute()
    prompter.train()
    language_model.train()
    return out

In [18]:
model_output_dir = Path("/root/autodl-tmp/saved_model/prompter")
model_output_dir.mkdir(exist_ok=True, parents=True)
log_interval = 1000
eval_iters = 500
total_micro_steps = num_training_steps * gradient_accumulation_steps

In [None]:
from torchmetrics.aggregation import RunningMean
train_loss = RunningMean(window=log_interval).to(device)
train_loss_rank = RunningMean(window=log_interval).to(device)
train_loss_lm = RunningMean(window=log_interval).to(device)

step = 0 # total steps = num_training_steps * gradient_accumulation_steps
best_val_loss = 1e9
for epoch in range(num_epochs):
    # Iterate through batches
    for batch in train_dataloader:
        prompter_output, language_model_output = model_forward(batch)                                  
        loss = (prompter_output.loss + language_model_output.loss) / gradient_accumulation_steps
        loss.backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
        train_loss_rank.update(prompter_output.loss)
        train_loss_lm.update(language_model_output.loss)
        train_loss.update(loss * gradient_accumulation_steps)
        # log the loss on train set
        if step % log_interval == 0:
            print(f"micro step {step}/{total_micro_steps}:",
                  f"loss {train_loss.compute():.4f}",
                  f"rank-loss {train_loss_rank.compute():.4f}",
                  f"lm-loss {train_loss_lm.compute():.4f}",
                  f"lr {lr_scheduler.get_lr()[0]:.7f}", sep="  ")
        step += 1
        
    eval_results = evaluate_prompter(prompter, language_model, test_dataloader, 
                                     eval_iters=eval_iters)
    if eval_results['loss']['val_loss'] < best_val_loss:
        best_val_loss = eval_results['loss']['val_loss']
        checkpoint = {
            'model': prompter.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': prompter_config,
            'iter_num': step,
            'eval_results': eval_results,
            'config': None,
        }
        print(f"saving checkpoint to {model_output_dir}")
        torch.save(checkpoint, model_output_dir / 'best_ckpt.pt')
    
    



micro step 0/55184:  loss 4.5071  rank-loss 0.7571  lm-loss 3.7500  lr 0.0000000
micro step 1000/55184:  loss 2.6397  rank-loss 0.5537  lm-loss 2.0859  lr 0.0001000
micro step 2000/55184:  loss 1.9683  rank-loss 0.2091  lm-loss 1.7592  lr 0.0000981
micro step 3000/55184:  loss 1.8656  rank-loss 0.2013  lm-loss 1.6642  lr 0.0000963
micro step 4000/55184:  loss 1.8976  rank-loss 0.1988  lm-loss 1.6988  lr 0.0000945
micro step 5000/55184:  loss 1.8869  rank-loss 0.1978  lm-loss 1.6891  lr 0.0000926
micro step 6000/55184:  loss 1.8673  rank-loss 0.1985  lm-loss 1.6688  lr 0.0000908
micro step 7000/55184:  loss 1.8672  rank-loss 0.1970  lm-loss 1.6701  lr 0.0000889
micro step 8000/55184:  loss 1.8177  rank-loss 0.1979  lm-loss 1.6198  lr 0.0000871
micro step 9000/55184:  loss 1.8108  rank-loss 0.1966  lm-loss 1.6142  lr 0.0000853
micro step 10000/55184:  loss 1.8606  rank-loss 0.1969  lm-loss 1.6637  lr 0.0000834
micro step 11000/55184:  loss 1.8072  rank-loss 0.1964  lm-loss 1.6108  lr 0.0

eval:  72%|███████▏  | 358/500 [10:46<04:30,  1.90s/it]

In [37]:
i = 0
for batch in test_dataloader:    
    print("q:", batch["question"])
    print("a:", batch["answer"])
    language_model_question_input_ids = batch["language_model_question_input_ids"].to(device)
    language_model_ques_embeds = language_model.get_input_embeddings()(language_model_question_input_ids)
    language_model_input_embeds = torch.cat([prompter_output.soft_prompts, 
                                             language_model_ques_embeds], dim=1)
    pred = language_model.generate(inputs_embeds=language_model_input_embeds, max_new_tokens=128)
    origin_pred = language_model.generate(inputs_embeds=language_model_ques_embeds, max_new_tokens=128)
    print("prompter a:", language_model_tokenizer.batch_decode(pred.cpu(), skip_special_tokens=True))
    print("origin a:", language_model_tokenizer.batch_decode(origin_pred.cpu(), skip_special_tokens=True))
    if i > 5:
        break
    i += 1
    print()

q: ['Marvin Friedrich曾经效力于哪个德甲球队？']
a: ['Marvin Friedrich曾经效力于史浩克04这个德甲球队。']
prompter a: ['Marvin Friedrich曾经效力于德甲球队科隆。他在2009年1月从科隆转会到德甲球队汉堡。他在2010年1月从汉堡转会到德甲球队美因茨。他在2010年12月从美因茨转会到德甲球队科隆。他在2011年1月从科隆转会到德甲球队汉诺威96。他在2011年12月从汉诺威96转会到德甲球队云达不莱梅。他在']
origin a: ['\n拜仁慕尼黑\n\n\n\n 展开全部 拜仁慕尼黑\n\n\n\n 展开全部 拜仁慕尼黑\n\n\n\n 展开全部 拜仁慕尼黑\n\n\n\n 展开全部 拜仁慕尼黑\n\n\n\n 展开全部 拜仁慕尼黑']

q: ['有哪些经典的神学论战案例？']
a: ['马丁·路德在1525年的著作《意志的束缚》，是神学论战的一个典型例子。另一个例子是两册1896年的著作《基督教世界科学与神学论战史》（分上、下卷），由美国外交家，教育家及作家安德鲁·迪克森·怀特所著。']
prompter a: ['经典的神学论战案例包括：亚伯拉罕论战、摩西论战、耶稣论战、保罗论战、奥古斯丁论战、阿奎那论战、加尔文论战、马丁·路德论战、加尔文论战、清教徒论战、加尔文论战、马丁·路德论战、加尔文论战、清教徒论战、加尔文论战、马丁·路德论战、加尔文论战、清教徒论战、加尔文论战']
origin a: ['噢，我明白了，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是想说，你就是']

q: ['西安路有哪些社会事业？']
a: ['西安路有大连市口腔医院、大连机车医院、大连市第八中学、大连市朝鲜族基督教堂等社会事业。']
prompter a: ['西安路有西安路小学、西安路中学、西安路幼儿园、西安路社区卫生服务中心等社会事业。西安路小学是西安路街道辖区内的一所公办小学，西安路中学是西安路街道辖区内的一所公办中学。西安路幼儿园是西安路街道辖区内的一所公办幼

In [None]:
batch["document_input_ids"][0]

In [None]:
language_model_output = language_model(inputs_embeds=language_model_input_embeds, 
                                       labels = language_model_labels)

In [None]:
language_model_tokenizer.batch_decode(torch.argmax(language_model_output.logits, dim=2))

In [None]:
docs = prompter_tokenizer.batch_decode(batch["document_input_ids"][0], skip_special_tokens=True)

In [None]:
prompter_output.logits.topk(1)

In [None]:
docs[8]

In [None]:
docs[0]

In [None]:
import importlib
import metric
importlib.reload(metric)
generation_metrics = metric.GenerationMetrics(language_model_tokenizer)
generation_metrics.update(pred.cpu(), batch["language_model_answer_input_ids"])
generation_metrics_scores = generation_metrics.compute()
generation_metrics_scores

In [None]:
eval_results = evaluate_prompter(prompter, language_model, test_dataloader, eval_iters=100)
print(json.dumps(eval_results, indent=2))

In [30]:
# 训练1轮，每次argmax取出一个soft prompt作为输入，

eval_results = evaluate_prompter(prompter, language_model, test_dataloader, eval_iters=100)
print(json.dumps(eval_results, indent=2))

eval: 101it [03:00,  1.78s/it]                         


{
  "loss": {
    "val_loss": 1.790435386931195,
    "val_rank_loss": 0.19559393574794134,
    "val_lm_loss": 1.5948414522058822
  },
  "generation": {
    "accuracy": 0.0,
    "rouge-1": 40.217145098039204,
    "rouge-2": 20.987956862745097,
    "rouge-l": 28.878924509803916,
    "bleu-4": 14.73178823529412
  },
  "retrieval": {
    "HitRate@1": 0.11764705926179886,
    "HitRate@5": 0.36274510622024536,
    "HitRate@10": 0.6666666865348816,
    "MRR": 0.27186647057533264,
    "MAP@1": 0.11764705926179886,
    "MAP@5": 0.20718954503536224,
    "MAP@10": 0.2503190338611603,
    "NDCG@1": 0.11764705926179886,
    "NDCG@5": 0.24587662518024445,
    "NDCG@10": 0.3467034697532654,
    "Recall@1": 0.11764705926179886,
    "Recall@5": 0.36274510622024536,
    "Recall@10": 0.6666666865348816
  }
}


In [19]:
eval_results

{'loss': {'val_loss': 1.7964331930979769,
  'val_rank_loss': 0.1924001278193841,
  'val_lm_loss': 1.6040330653903574},
 'generation': {'accuracy': 0.0,
  'rouge-1': 9.436148573292026,
  'rouge-2': 0.10847696070438612,
  'rouge-l': 6.463228501548997,
  'bleu-4': 0.6537248165661177},
 'retrieval': {'HitRate@1': 0.1159302145242691,
  'HitRate@5': 0.438284695148468,
  'HitRate@10': 0.6991684436798096,
  'MRR': 0.28068479895591736,
  'MAP@1': 0.1159302145242691,
  'MAP@5': 0.22521333396434784,
  'MAP@10': 0.25961822271347046,
  'NDCG@1': 0.1159302145242691,
  'NDCG@5': 0.277713805437088,
  'NDCG@10': 0.3616619110107422,
  'Recall@1': 0.1159302145242691,
  'Recall@5': 0.438284695148468,
  'Recall@10': 0.6991684436798096}}

In [28]:
# last eval
eval_results

{'loss': {'val_loss': 2.8331266724517135,
  'val_rank_loss': 0.19326534395080475,
  'val_lm_loss': 2.639861329080385},
 'generation': {'accuracy': 0.0,
  'rouge-1': 14.224255013859448,
  'rouge-2': 0.9631314691015815,
  'rouge-l': 9.532860215229089,
  'bleu-4': 1.5044573292026742},
 'retrieval': {'HitRate@1': 0.10092940181493759,
  'HitRate@5': 0.4130115807056427,
  'HitRate@10': 0.6972118020057678,
  'MRR': 0.2618274390697479,
  'MAP@1': 0.10092940181493759,
  'MAP@5': 0.20314691960811615,
  'MAP@10': 0.24042972922325134,
  'NDCG@1': 0.10092940181493759,
  'NDCG@5': 0.25463661551475525,
  'NDCG@10': 0.34589165449142456,
  'Recall@1': 0.10092940181493759,
  'Recall@5': 0.4130115807056427,
  'Recall@10': 0.6972118020057678}}

# last train loss

micro step 50000/55184:  loss 1.7886  rank-loss 0.1931  lm-loss 1.5955  lr 0.0000097
micro step 51000/55184:  loss 1.8216  rank-loss 0.1930  lm-loss 1.6286  lr 0.0000078
micro step 52000/55184:  loss 1.7938  rank-loss 0.1927  lm-loss 1.6011  lr 0.0000060
micro step 53000/55184:  loss 1.7911  rank-loss 0.1918  lm-loss 1.5993  lr 0.0000041
micro step 54000/55184:  loss 1.8032  rank-loss 0.1921  lm-loss 1.6111  lr 0.0000023
micro step 55000/55184:  loss 1.8028  rank-loss 0.1927  lm-loss 1.6101  lr 0.0000005