In [8]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import sys
from typing import Optional, List, Union, Dict, Tuple
import numpy as np
from dataclasses import dataclass, field
import commentjson

from transformers import (
    HfArgumentParser
)

from datasets import load_dataset
from transformers import BertTokenizerFast, AutoTokenizer
import jittor as jt
from jittor.dataset import Dataset, DataLoader
from model import BertForCL, BertConfig
from dataset import CLDataset
from tool import calc_loss
from tqdm import tqdm
jt.flags.use_cuda = jt.has_cuda

In [9]:

# 读取配置文件
@dataclass
class ModelArguments:
    tokenizer_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The local dir of tokenizer"}
    )
    params_path: Optional[str] = field(
        default=None,
        metadata={"help": "The path of the parameters"}
    )

@dataclass
class DataArguments:
    dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "The path of the dataset"}
    )
    max_seq_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences logonger"
            "than this will be truncated"
        }
    )
    
@dataclass
class TrainingArguments:
    temperature: int = field(
        default=0.05,
        metadata={"help": "Temperature of Loss function"}
    )
    mode: Optional[str] = field(
        default=None,
        metadata={"help": "The mode of the training. It must be \"supervised\" or \"unsupervised\"."}
    )
    batch_size: int = field(
        default=128,
        metadata={"help": "batch size"}
    )
    epoch: int = field(
        default=1,
        metadata={"help": "epoch"}
    )
    learning_rate: float = field(
        default=1e-5,
        metadata={"help": "learning_rate"}
    )
    def __post_init__(self):
        allowed_mode = ["supervised", "unsupervised"]
        if self.mode not in allowed_mode:
            raise ValueError("mode must be supervised or unsupervised")

parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

config_path = "../config.jsonc"
with open(config_path, 'r', encoding='utf-8') as f:
    config_data = commentjson.load(f)
model_args, data_args, training_args = parser.parse_dict(config_data)

In [10]:
# 读取数据
if training_args.mode == "supervised":
    dataset = load_dataset("csv", data_files=data_args.dataset_path)
    # print(dataset['train'][0])
elif training_args.mode == "unsupervised":
    dataset = load_dataset("text", data_files=data_args.dataset_path)
    # print(dataset["train"][0])
else:
    raise ValueError("The mode must be \"supervised\" or \"unsupervised\".")



Found cached dataset text (/home/aiuser/.cache/huggingface/datasets/text/default-2b8fc52e575fe36a/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
100%|██████████| 1/1 [00:00<00:00, 68.89it/s]


In [11]:

# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_dir, local_files_only=True)
# jt.display_memory_info()

training_dataset = CLDataset(dataset, tokenizer, data_args, training_args)
training_dataloader = DataLoader(training_dataset, batch_size=training_args.batch_size)

# jt.display_memory_info()

In [12]:
bert_config = BertConfig()
model = BertForCL(bert_config).cuda()
params_dict = jt.load(model_args.params_path)
params_dict = {k.replace("LayerNorm.gamma", "LayerNorm.weight"): v for k, v in params_dict.items()}
params_dict = {k.replace("LayerNorm.beta", "LayerNorm.bias"): v for k, v in params_dict.items()}
model.load_state_dict(params_dict)
params_dict = None  # 不加这行代码会导致内存泄漏，没有明白原因，可能和全局变量引用导致计算图没有释放有关

In [13]:
optimizer = jt.optim.AdamW(model.parameters(), lr=1e-5)
model.train()

min_loss = float("inf")
best_sd = {}

total_len = len(training_dataset) // training_args.batch_size
for i in range(training_args.epoch):
    for batch_idx, (y1, y2, y3) in enumerate(tqdm(training_dataloader, total=total_len)):
        input_ids1 = y1["input_ids"].squeeze(1)
        token_type_ids1 = y1["token_type_ids"].squeeze(1)
        attention_mask1 = y1["attention_mask"].squeeze(1)
        input_ids2 = y2["input_ids"].squeeze(1)
        token_type_ids2 = y2["token_type_ids"].squeeze(1)
        attention_mask2 = y2["attention_mask"].squeeze(1)

        if training_args.mode == "supervised":
            input_ids3 = y3["input_ids"].squeeze(1)
            token_type_ids3 = y3["token_type_ids"].squeeze(1)
            attention_mask3 = y3["attention_mask"].squeeze(1)

            _, z1 = model(input_ids1, token_type_ids1, attention_mask1)
            _, z2 = model(input_ids2, token_type_ids2, attention_mask2)
            _, z3 = model(input_ids3, token_type_ids3, attention_mask3)

            loss = calc_loss(training_args, z1, z2, z3)
            optimizer.step(loss)
            jt.sync_all()
            jt.gc()

            loss = loss.detach().item()

            if loss < min_loss:
                min_loss = loss
                best_sd = {k: v.clone() for k, v in model.state_dict().items()}
            
        elif training_args.mode == "unsupervised":
            _, z1 = model(input_ids1, token_type_ids1, attention_mask1)
            _, z2 = model(input_ids2, token_type_ids2, attention_mask2)

            loss = calc_loss(training_args, z1, z2)
            optimizer.step(loss)
            jt.sync_all()
            jt.gc()
            
            loss = loss.detach().item()

            if loss < min_loss:
                min_loss = loss
                best_sd = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            raise ValueError("The mode must be \"supervised\" or \"unsupervised\".")

        if batch_idx % 200 == 0:
            print(f"Epoch {i}, batch_idx {batch_idx}, loss={loss}")
    print(batch_idx)

  0%|          | 1/7812 [00:00<1:25:26,  1.52it/s]

Epoch 0, batch_idx 0, loss=0.8145176768302917


  3%|▎         | 201/7812 [01:48<1:08:20,  1.86it/s]

Epoch 0, batch_idx 200, loss=0.7649548649787903


  5%|▌         | 401/7812 [03:36<1:08:19,  1.81it/s]

Epoch 0, batch_idx 400, loss=1.4548181295394897


  8%|▊         | 601/7812 [05:24<1:04:45,  1.86it/s]

Epoch 0, batch_idx 600, loss=0.7639535665512085


 10%|█         | 801/7812 [07:13<1:03:51,  1.83it/s]

Epoch 0, batch_idx 800, loss=0.7581392526626587


 13%|█▎        | 1001/7812 [09:01<1:00:44,  1.87it/s]

Epoch 0, batch_idx 1000, loss=0.7716906070709229


 15%|█▌        | 1201/7812 [10:48<58:46,  1.87it/s]  

Epoch 0, batch_idx 1200, loss=0.7451072335243225


 18%|█▊        | 1401/7812 [12:35<58:31,  1.83it/s]  

Epoch 0, batch_idx 1400, loss=0.7633715867996216


 20%|██        | 1601/7812 [14:22<54:42,  1.89it/s]

Epoch 0, batch_idx 1600, loss=0.7469677925109863


 23%|██▎       | 1801/7812 [16:09<53:57,  1.86it/s]  

Epoch 0, batch_idx 1800, loss=0.7766804695129395


 26%|██▌       | 2001/7812 [17:57<51:53,  1.87it/s]

Epoch 0, batch_idx 2000, loss=0.7625852823257446


 28%|██▊       | 2201/7812 [19:44<49:35,  1.89it/s]

Epoch 0, batch_idx 2200, loss=0.8569320440292358


 31%|███       | 2401/7812 [21:31<47:42,  1.89it/s]

Epoch 0, batch_idx 2400, loss=0.7979599237442017


 33%|███▎      | 2601/7812 [23:17<46:30,  1.87it/s]

Epoch 0, batch_idx 2600, loss=0.7603951692581177


 36%|███▌      | 2801/7812 [25:05<45:55,  1.82it/s]

Epoch 0, batch_idx 2800, loss=0.739495038986206


 38%|███▊      | 3001/7812 [26:52<43:03,  1.86it/s]

Epoch 0, batch_idx 3000, loss=0.849277675151825


 41%|████      | 3201/7812 [28:40<40:56,  1.88it/s]

Epoch 0, batch_idx 3200, loss=0.8493638634681702


 44%|████▎     | 3401/7812 [30:27<39:05,  1.88it/s]

Epoch 0, batch_idx 3400, loss=0.8626919984817505


 46%|████▌     | 3601/7812 [32:14<38:01,  1.85it/s]

Epoch 0, batch_idx 3600, loss=0.743420422077179


 49%|████▊     | 3801/7812 [34:02<35:33,  1.88it/s]

Epoch 0, batch_idx 3800, loss=0.7405920624732971


 51%|█████     | 4001/7812 [35:49<34:21,  1.85it/s]

Epoch 0, batch_idx 4000, loss=0.7368035912513733


 54%|█████▍    | 4201/7812 [37:38<32:24,  1.86it/s]

Epoch 0, batch_idx 4200, loss=0.7328453063964844


 56%|█████▋    | 4401/7812 [39:25<30:24,  1.87it/s]

Epoch 0, batch_idx 4400, loss=0.7633770108222961


 59%|█████▉    | 4601/7812 [41:13<28:54,  1.85it/s]

Epoch 0, batch_idx 4600, loss=0.7377058267593384


 61%|██████▏   | 4801/7812 [43:01<26:47,  1.87it/s]

Epoch 0, batch_idx 4800, loss=0.7838277816772461


 64%|██████▍   | 5001/7812 [44:49<25:07,  1.87it/s]

Epoch 0, batch_idx 5000, loss=0.7585068941116333


 67%|██████▋   | 5201/7812 [46:37<25:00,  1.74it/s]

Epoch 0, batch_idx 5200, loss=0.7358042597770691


 69%|██████▉   | 5401/7812 [48:24<21:21,  1.88it/s]

Epoch 0, batch_idx 5400, loss=0.7381942868232727


 72%|███████▏  | 5601/7812 [50:12<19:35,  1.88it/s]

Epoch 0, batch_idx 5600, loss=0.7285133004188538


 74%|███████▍  | 5801/7812 [51:59<17:54,  1.87it/s]

Epoch 0, batch_idx 5800, loss=0.7320502400398254


 77%|███████▋  | 6001/7812 [53:46<16:07,  1.87it/s]

Epoch 0, batch_idx 6000, loss=0.7373573780059814


 79%|███████▉  | 6201/7812 [55:33<14:18,  1.88it/s]

Epoch 0, batch_idx 6200, loss=0.7347224354743958


 82%|████████▏ | 6401/7812 [57:22<12:30,  1.88it/s]

Epoch 0, batch_idx 6400, loss=1.3077831268310547


 84%|████████▍ | 6601/7812 [59:11<11:09,  1.81it/s]

Epoch 0, batch_idx 6600, loss=0.7240882515907288


 87%|████████▋ | 6801/7812 [1:00:58<09:01,  1.87it/s]

Epoch 0, batch_idx 6800, loss=0.7277665138244629


 90%|████████▉ | 7001/7812 [1:02:46<07:21,  1.84it/s]

Epoch 0, batch_idx 7000, loss=0.7971813678741455


 92%|█████████▏| 7201/7812 [1:04:34<05:26,  1.87it/s]

Epoch 0, batch_idx 7200, loss=0.7404295206069946


 95%|█████████▍| 7401/7812 [1:06:21<03:40,  1.86it/s]

Epoch 0, batch_idx 7400, loss=0.7279433012008667


 97%|█████████▋| 7601/7812 [1:08:07<01:54,  1.85it/s]

Epoch 0, batch_idx 7600, loss=0.731085479259491


100%|█████████▉| 7801/7812 [1:09:55<00:05,  1.89it/s]

Epoch 0, batch_idx 7800, loss=0.7291449308395386


7813it [1:10:01,  1.86it/s]                          

7812





In [14]:
print(f"Final Loss: {min_loss}")
jt.save(best_sd,"../ckpt/my_bert/my_bert2.bin")

Final Loss: 0.7179750204086304
