In [None]:
#%%
import os
os.environ["TMPDIR"] = "./tmp"
os.makedirs("./tmp", exist_ok=True)
import copy
import json
import os
import torch
import logging
import argparse
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor

from tqdm import tqdm
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, Sampler
import wandb
import transformers
from typing import Sequence
import datasets
import shutil
import json
import random


from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

took_qa_pos = []
took_index = []
sampled_ids = set()

class WeightedRandomSampler(Sampler[int]):
    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = False, manual_seed=2147483647) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0 or num_samples > len(weights):
            raise ValueError("num_samples should be a positive integer "
                             "value less than or equal to len(weights), but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        global sampled_ids
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = False
        self.generator = torch.Generator()
        self.generator.manual_seed(manual_seed)
        self.rand_list = torch.multinomial(self.weights, self.weights.shape[0], self.replacement, generator=self.generator).tolist()
        self.pos = 0
        self.sampled_ids = sampled_ids
        print('weights', list(self.weights))

    def __iter__(self):
        while self.pos < self.num_samples:
            idx = self.rand_list[self.pos]
            self.pos += 1
            self.sampled_ids.add(idx)
            took_qa_pos.append(idx)
            yield idx

    def __len__(self) -> int:
        return self.num_samples

    def update_dynamic_weight(self, new_weights: Sequence[float]):
        if len(new_weights) != len(self.weights):
            raise ValueError("Length of new_weights must match the current weights")

        self.weights = torch.as_tensor(new_weights, dtype=torch.double)

        available_indices = list(set(range(len(self.weights))) - self.sampled_ids)
        available_weights = [self.weights[i] for i in available_indices]

        # Resample taking into account already sampled ids
        new_samples = torch.multinomial(torch.as_tensor(available_weights), len(available_indices), self.replacement, generator=self.generator)
        new_list = [available_indices[i] for i in new_samples.tolist()]
        self.pos = len(self.sampled_ids)
        self.rand_list[self.pos:] = new_list
        assert len(self.rand_list) == len(new_weights)

class HuatuoGPT_data(torch.utils.data.Dataset):
    def __init__(self, config, tokenizer, debug=False):
        self.config = config
        self.tokenizer = tokenizer
        with open(config.data_path) as f:
            self.data_dict = json.load(f)
        self.datacollatorforseq2seq = transformers.DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding=True)
        self.ignore_index = -100
        self.sep = '\n'
        self.sep_ids = self.tokenizer.encode(self.sep,add_special_tokens= False)
        self.roles = ('<问>：','<答>：')
        self.ignore_len = len(self.tokenizer.encode(self.sep + self.roles[1],add_special_tokens= False))
        self.debug = debug

        self.lengths = {k: len(self.data_dict[k]) for k in self.data_dict.keys()}
        self.keys = list(self.data_dict.keys())
        
        # you need to set
        # When you want random sampling, please set the same data priority
        self.data_priority = {'Meidcal_Web_Corpus_en': 32,
                              'Meidcal_Web_Corpus_cn': 32,
                            'Meidcal_Literature_cn': 16,
                            'Meidcal_Literature_en': 16,
                            'Meidcal_Encyclopedia_cn':8,
                            'Meidcal_Encyclopedia_en':8,
                            'Meidcal_Books_cn': 4,
                            'Meidcal_Books_en': 4,
                            'SFT_data': 1}
        
        self.data_epoch = {'Meidcal_Web_Corpus_en': 1,
                              'Meidcal_Web_Corpus_cn': 1,
                            'Meidcal_Literature_cn': 1,
                            'Meidcal_Literature_en': 1,
                            'Meidcal_Encyclopedia_cn': 1,
                            'Meidcal_Encyclopedia_en': 1,
                            'Meidcal_Books_cn': 1,
                            'Meidcal_Books_en': 1,
                            'SFT_data': 3}

        self.weights = []
        self.pos_key = []

        for keyi,key in enumerate(self.keys):
            priority = self.data_priority[key]
            epoch = self.data_epoch[key]
            self.weights += [priority] * int(self.lengths[key]*epoch)
            self.pos_key += [keyi] * int(self.lengths[key]*epoch)

    
    def __getitem__(self, index):
        key = self.keys[self.pos_key[index]]
        sub_index = index % self.lengths[key]
        took_index.append((index, key))
        da = self.preprocess(self.data_dict[key][sub_index])
        da['data_type'] = key
        return da

    def get_data_info(self):
        res = {}
        total = 0
        for k,v in self.data_epoch.items():
            res[k] = self.lengths[k]*v
            total += self.lengths[k]*v
        res['sum'] = total
        return res

    def preprocess(self, data):
        input_ids = []
        labels = []
        if not isinstance(data, list):
            raise ValueError('The data must be a list.')
        for ind, d in enumerate(data):
            if ind % 2 == 1:
                value_ids = self.tokenizer.encode(self.sep + self.roles[1] + d,add_special_tokens= False, max_length=self.config.max_seq_len, truncation=True)
                input_ids += value_ids
                labels += [self.ignore_index] *self.ignore_len + value_ids[self.ignore_len:]
                if len(labels) >= self.config.max_seq_len:
                    if self.debug:
                        print('break max len', len(labels))
                    break
            else:
                pre_str = self.sep if len(input_ids) > 0 else ''
                value_ids = self.tokenizer.encode(pre_str + self.roles[0] + d,add_special_tokens= False, max_length=self.config.max_seq_len, truncation=True)
                input_ids += value_ids

                if len(labels) > 0:
                    labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * (len(value_ids)-1)
                else:
                    labels += [self.ignore_index] * len(value_ids)
        if self.debug and len(data) != 2:
            print('data len more than 2', len(data), len(input_ids))
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        if self.debug:
            print('input_ids',self.tokenizer.decode(input_ids))
            labels = [item if item != self.ignore_index else self.tokenizer.pad_token_id for item in labels]
            # print('labels',self.tokenizer.convert_ids_to_tokens(labels))
            print('labels',self.tokenizer.decode(labels))
            self.debug = False
        if self.debug:
            print('len after preprocess', len(input_ids), len(labels))
        return {'input_ids': input_ids[:self.config.max_seq_len], 'labels': labels[:self.config.max_seq_len]}

    def __len__(self):
        return len(self.weights)

    def sample_num(self):
        return len(self.weights)

    def collate_fn(self, batch):
        return batch


: 

In [2]:
parser = argparse.ArgumentParser(description='Args of Data Preprocess')
debug = False

class Args:
    def __init__(self, data_path="/mnt/c/Users/HOME/Downloads/HuatuoGPT-II/adaption/one_stage_training/train_qa_1p.json",
                 model_path="baichuan-inc/Baichuan2-13B-Base",
                 max_seq_len=4096):
        self.data_path = data_path
        self.model_path = model_path
        self.max_seq_len = max_seq_len

args = Args()

args.train_bsz_per_gpu = 32
args.save_path = '.'.join(os.path.split(args.data_path)[-1].split('.')[:-1])+'_'+os.path.split(args.model_path)[-1]+f'_{args.max_seq_len}_dataset'
print(f'The dataset will save in {args.save_path}')
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = '<PAD>'

#%%
train_dataset = HuatuoGPT_data(args, tokenizer)

sampler = WeightedRandomSampler(train_dataset.weights, num_samples=train_dataset.sample_num(), replacement=False)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_bsz_per_gpu, sampler=sampler, drop_last=False, collate_fn=train_dataset.collate_fn)

train_dataloader_iterator = tqdm(enumerate(train_dataloader))
args.log_step = len(train_dataloader) // 10

from collections import defaultdict
key_nums = defaultdict(int)
args.experiment_name = 'huatuo2_datapre'

wandb.init(project = args.experiment_name, config=args, dir= os.path.join('./train_logs',args.experiment_name))

all_inputs_ids = []
all_labels = []
pad_id = tokenizer.pad_token_id
ignore_index = -100
for batch_cnt, batch in train_dataloader_iterator:
    cur_input = []
    cur_label = []

    for i, da in enumerate(batch):
        key_nums[da['data_type']] += 1
        if len(da['input_ids']) + len(cur_input) <= args.max_seq_len:
            cur_input += da['input_ids']
            cur_label +=  da['labels']
        else:
            if debug:
                print('appended len', len(cur_input), len(cur_label), 'batch', batch_cnt, 'index', i)
            pad_len = args.max_seq_len - len(cur_input)
            cur_input += [pad_id] * pad_len
            cur_label += [ignore_index] * pad_len
            all_inputs_ids.append(cur_input)
            all_labels.append(cur_label)
            cur_input = da['input_ids']
            cur_label =  da['labels']
    if debug:
        print('appended len', len(cur_input), len(cur_label), 'batch', batch_cnt)
    pad_len = args.max_seq_len - len(cur_input)
    cur_input += [pad_id] * pad_len
    cur_label += [ignore_index] * pad_len
    all_inputs_ids.append(cur_input)
    all_labels.append(cur_label)
    assert len(cur_input) == len(cur_label) == args.max_seq_len, f'{len(cur_input)},{len(cur_label)}'

    if batch_cnt % args.log_step == 0:
        logdata = {}
        for key in key_nums:
            logdata[key + '_num'] = key_nums[key]
            if debug:
                print('log', key, key_nums[key])
        wandb.log(logdata)
        key_nums = defaultdict(int)

assert len(all_inputs_ids) == len(all_labels)
print('all_inputs_ids len', len(all_inputs_ids))
save_dataset = datasets.Dataset.from_dict({'input_ids': all_inputs_ids, 'labels':all_labels})
save_dataset.save_to_disk(args.save_path)

table = wandb.Table(columns=["data_priority", "data_epoch","data_num"])
table.add_data(json.dumps(train_dataset.data_priority,ensure_ascii=False,indent=2),json.dumps(train_dataset.data_epoch,ensure_ascii=False,indent=2),json.dumps(train_dataset.get_data_info(),ensure_ascii=False,indent=2))
wandb.log({"data_sample_info": table})
print(json.dumps(train_dataset.data_priority,ensure_ascii=False,indent=2),json.dumps(train_dataset.data_epoch,ensure_ascii=False,indent=2),json.dumps(train_dataset.get_data_info(),ensure_ascii=False,indent=2))

The dataset will save in train_qa_1p_Baichuan2-13B-Base_4096_dataset
weights [tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tensor(32., dtype=torch.float64), tens

0it [00:00, ?it/s][34m[1mwandb[0m: Currently logged in as: [33mlengoctuong23052002[0m ([33mlengoctuong23052002-fpt-retail[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


1786it [01:38, 18.13it/s]


all_inputs_ids len 4938


: 

: 

In [29]:
import numpy as np
for i in took_index:
    print(i)

(205, 'Meidcal_Literature_en')
(143, 'Meidcal_Literature_en')
(182, 'Meidcal_Literature_en')
(28, 'Meidcal_Web_Corpus_en')
(21, 'Meidcal_Web_Corpus_en')
(168, 'Meidcal_Literature_en')
(186, 'Meidcal_Literature_en')
(121, 'Meidcal_Literature_cn')
(38, 'Meidcal_Web_Corpus_en')
(72, 'Meidcal_Web_Corpus_cn')
(98, 'Meidcal_Web_Corpus_cn')
(257, 'Meidcal_Encyclopedia_en')
(76, 'Meidcal_Web_Corpus_cn')
(129, 'Meidcal_Literature_en')
(29, 'Meidcal_Web_Corpus_en')
(68, 'Meidcal_Web_Corpus_cn')
(352, 'Meidcal_Books_cn')
(62, 'Meidcal_Web_Corpus_cn')
(78, 'Meidcal_Web_Corpus_cn')
(55, 'Meidcal_Web_Corpus_cn')
(159, 'Meidcal_Literature_en')
(86, 'Meidcal_Web_Corpus_cn')
(152, 'Meidcal_Literature_en')
(34, 'Meidcal_Web_Corpus_en')
(223, 'Meidcal_Encyclopedia_cn')
(89, 'Meidcal_Web_Corpus_cn')
(50, 'Meidcal_Web_Corpus_cn')
(39, 'Meidcal_Web_Corpus_en')
(170, 'Meidcal_Literature_en')
(266, 'Meidcal_Encyclopedia_en')
(132, 'Meidcal_Literature_en')
(517, 'Meidcal_Books_en')
(94, 'Meidcal_Web_Corpus_cn'

In [30]:
print(tokenizer.decode(all_inputs_ids[1]))

 <问>：APOA5基因与心血管疾病的关联是什么？有何具体机制？
<答>：APOA5基因与心血管疾病的关联是很密切的。APOA5基因位于人类11号染色体上，编码的蛋白质被称为载脂蛋白A-V(apoA-V)，主要在肝脏中表达。apoA-V是血浆甘油三酯水平的重要决定因素，而高血浆甘油三酯水平是冠状动脉疾病的主要风险因素之一。

研究表明，apoA-V通过与LDL-R基因家族受体相互作用，影响脂蛋白代谢。apoA-V是几种脂蛋白分子的组成部分，包括VLDL、HDL和乳糜微粒。APOA5基因与代谢综合征密切相关，由此也可看出其与心血管疾病的关联。

此外，APOA5基因中也包含27个与冠状动脉疾病增加风险有关的SNP（单核苷酸多态性）。

值得一提的是，APOA5基因的发现是通过对人类和小鼠DNA进行比较测序得到的。APOA5基因与其他载脂蛋白基因(APOA1、APOC3、APOA4)位于同一基因簇中，位于人类11号染色体的11q23位置。通过建立两种小鼠模型(APoA5转基因和APoA5敲除)，证实了该基因在血浆甘油三酯水平决定中的重要作用。转基因小鼠的血浆甘油三酯水平较低，而敲除小鼠的血浆甘油三酯水平较高，但两者的血浆胆固醇水平均保持不变。另外，还有一组荷兰研究人员描述了这个相同的基因，将其与肝脏再生的早期阶段相关联，但没有意识到它在血浆甘油三酯水平决定中的重要作用。

蛋白质结构方面，APOA5基因包含4个外显子和3个内含子，基因位于11号染色体的11q23区域，靠近载脂蛋白基因簇。apoA5蛋白属于载脂蛋白A1/A4/E家族，包含2个螺旋卷曲结构域，整体上预测apoA5约有60%的α-螺旋结构。

综上所述，APOA5基因通过调节血浆甘油三酯水平等机制与心血管疾病密切相关。</s> <问>：夏季少吃冷饮以防月经不准
<答>：夏季少吃冷饮确实可以起到一定的作用来预防月经不准。在中医理论中，女性的月经和子宫是寒热之间的平衡状态，而冷饮属于寒性食物，会导致体内寒气增加，从而影响子宫的温度平衡，导致月经不调。因此，夏季少吃冷饮是为了维持体内的热量平衡，有助于月经的规律。

除了少吃冷饮外，还有其他一些方法可以帮助调节月经不准的情况。首先，建议你定期进行适量的运动，例如散步、瑜伽等，有助于提高身体的新陈代谢和血液循环，帮助调节月经。此外，保持良好的生活规律，避免过度劳累和精

In [None]:
index = took_index[0][0]

key = train_dataset.keys[train_dataset.pos_key[index]]
sub_index = index % train_dataset.lengths[key]
da = train_dataset.preprocess(train_dataset.data_dict[key][sub_index])
tokenizer.decode(da['input_ids'])

' <问>：ip给大鼠注射乳清蛋白水解物对其血压和肾脏钠排泄的影响是怎样的？\n<答>：根据目前的研究，我们可以得出关于给大鼠注射乳清蛋白水解物对其血压和肾脏钠排泄的影响的一些结论。这项研究使用具有自发性高血压的大鼠进行实验，并通过腹腔注射乳清蛋白水解物（WPH）来评估其对收缩期动脉血压（SBP）和肾脏钠处理的急性影响。\n\n实验结果显示，在给予0.5 g/kg和1.0 g/kg剂量的WPH后的2小时内，WPH的腹腔注射剂量依赖性降低了高血压大鼠的SBP。与注射0.15 M NaCl（188.5 +/- 9.3 mmHg）相比，注射WPH的SBP分别为176.6 +/- 4.9 mmHg（P = 0.001）和163.8 +/- 5.9 mmHg（P = 0.0018）。\n\n此外，研究还发现，与注射0.15 M NaCl和卡托普利（captopril）治疗的大鼠相比，WPH治疗组的肌酐清除率显著降低（P = 0.0084）。钠的排泄率也明显下降（P = 0.033），而钾的排泄率也显著降低（P = 0.0063）。\n\n需要注意的是，尽管WPH具有体外的血管紧张素转换酶（ACE）抑制活性（IC50 = 0.68 mg/mL），但在注射WPH后，肾小管对钠的重吸收升高，这与ACE抑制的作用相反。\n\n总结来说，该研究结果显示，在给予自发性高血压大鼠注射乳清蛋白水解物后，大鼠的收缩期动脉血压下降，并且肾小管对钠的重吸收增加。</s>'

In [None]:
index = took_index[1][0]

key = train_dataset.keys[train_dataset.pos_key[index]]
sub_index = index % train_dataset.lengths[key]
da = train_dataset.preprocess(train_dataset.data_dict[key][sub_index])
tokenizer.decode(da['input_ids'])

' <问>：什么是自噬，并且它如何影响肿瘤细胞对Apo2L/TRAIL疗法的耐药性？\n<答>：自噬是一种细胞自身进行的清理和再利用过程。在细胞遭受损伤或压力的情况下，自噬可以通过将细胞内部的蛋白质和细胞器降解为基本分子来提供能量和营养。这种过程可以帮助细胞应对压力和维持正常功能。\n\n对于Apo2L/TRAIL疗法，它能够选择性地诱导肿瘤细胞凋亡而不伤害正常细胞，因此备受关注。然而，耐药性是该疗法的一个严重限制因素。研究发现，自噬是在凋亡诱导刺激（包括Apo2L/TRAIL）作用下几乎始终被激活的细胞清理和维持过程。与凋亡不同，自噬可能导致细胞死亡或存活，具体取决于环境因素。\n\n研究已经发现了自噬调控Apo2L/TRAIL诱导凋亡的多种分子机制。此外，自噬是否完成（即完整的自噬通路）可能决定癌细胞的命运，即细胞存活或死亡。因此，针对自噬的治疗策略被视为克服Apo2L/TRAIL耐药性的一种有吸引力的方法。\n\n总结来说，自噬是一种细胞清理和再利用的过程，在Apo2L/TRAIL疗法中起到重要的调节作用。了解自噬与凋亡之间的相互作用及其调控机制，对于预测和克服肿瘤细胞对Apo2L/TRAIL疗法的耐药性具有重要意义。</s>'

In [25]:
logdata

{'SFT_data_num': 2}