# 文本相似度实例（单模型单塔模型）模型分别编码两句话再进行分类-交互式

## Step1 导入相关包

In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch

  from .autonotebook import tqdm as notebook_tqdm


## Step2 加载数据集

In [2]:

dataset = load_dataset('json', data_files='./data/train_pair_1w.json', split="train") # 如果是加载固定的json文件则用load_dataset
# dataset = DatasetDict.load_from_disk('./data') # 加载的是huggingface的数据集
dataset

Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 10000
})

## Step3 划分数据集

In [3]:
datasets = dataset.train_test_split(test_size=0.2)
datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 2000
    })
})

In [4]:
datasets['train'][0]

{'sentence1': '我？噢，只要他签一个字，我可以给他垫付六百万。',
 'sentence2': '那人问，—代达勒斯拿着，壮鹿马利根说。',
 'label': '0'}

## Step4 数据集预处理

In [5]:
import torch

tokenizer = AutoTokenizer.from_pretrained("D:/pretrained_model/models--hfl--chinese-macbert-base")

def process_function(examples):
    sentences = []
    labels = []
    for sen1, sen2, label in zip(examples["sentence1"], examples["sentence2"], examples["label"]):
        sentences.append(sen1)
        sentences.append(sen2)
        labels.append(1 if int(label) == 1 else 0)
    
    tokenizer_examples = tokenizer(sentences, max_length=250, truncation=True, padding="max_length")
    tokenizer_examples = {k: [v[i : i + 2] for i in range(0, len(v), 2)] for k, v in tokenizer_examples.items()}
    tokenizer_examples['labels'] = labels
    return tokenizer_examples

tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets['train'].column_names)
tokenized_datasets

Map: 100%|██████████| 8000/8000 [00:02<00:00, 2729.11 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 3091.78 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2000
    })
})

## Step5 创建模型

In [6]:

from transformers import BertForSequenceClassification, BertPreTrainedModel, BertModel
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
from torch.nn import CosineSimilarity, CosineEmbeddingLoss

class DualModel(BertPreTrainedModel):

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = BertModel(config)
        self.post_init()
    
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        senA_input_ids, senB_input_ids = input_ids[:, 0], input_ids[:, 1]
        senA_attention_mask, senB_attention_mask = attention_mask[:, 0], attention_mask[:, 1]
        senA_token_type_ids, senB_token_type_ids = token_type_ids[:, 0], token_type_ids[:, 1]

        senA_outputs = self.bert(
            senA_input_ids,
            attention_mask=senA_attention_mask,
            token_type_ids=senA_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        senA_pooled_output = senA_outputs[1]

        senB_outputs = self.bert(
            senB_input_ids,
            attention_mask=senB_attention_mask,
            token_type_ids=senB_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        senB_pooled_output = senB_outputs[1]

        cos = CosineSimilarity()(senA_pooled_output, senB_pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CosineEmbeddingLoss(0.3)
            loss = loss_fct(senA_pooled_output, senB_pooled_output, labels)
        output = (cos,)
        return ((loss,) + output) if loss is not None else output

model = DualModel.from_pretrained('D:/pretrained_model/models--hfl--chinese-macbert-base')


  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at D:/pretrained_model/models--hfl--chinese-macbert-base were not used when initializing DualModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing DualModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DualModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Step6 创建评估函数

In [7]:
import evaluate

acc_metric = evaluate.load("./metric_accuracy.py")
f1_metirc = evaluate.load("./metric_f1.py")

In [8]:
def eval_metric(eval_predict):
    predictions, labels = eval_predict
    predictions = [int(p > 0.5) for p in predictions]
    labels = [int(l) for l in labels]
    # predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metirc.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc

## Step7 创建TrainingArguments

In [9]:
train_args = TrainingArguments(output_dir="./dual_model",      # 输出文件夹
                               per_device_train_batch_size=4,  # 训练时的batch_size
                               per_device_eval_batch_size=32,   # 验证时的batch_size
                               logging_steps=10,                # log 打印的频率
                               evaluation_strategy="epoch",           # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="f1",      # 设定评估指标
                               load_best_model_at_end=True,
                               max_steps=1000
                               )     # 训练完成后加载最优模型

## Step8 创建Trainer

In [10]:
trainer = Trainer(model=model, 
                  args=train_args, 
                  tokenizer=tokenizer,
                  train_dataset=tokenized_datasets["train"], 
                  eval_dataset=tokenized_datasets["test"], 
                  compute_metrics=eval_metric)

## Step9 模型训练

In [11]:
trainer.train()

  0%|          | 0/1000 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  1%|          | 10/1000 [00:03<05:19,  3.10it/s]

{'loss': 0.0142, 'learning_rate': 1.98e-05, 'epoch': 0.01}


  2%|▏         | 20/1000 [00:07<05:12,  3.14it/s]

{'loss': 0.008, 'learning_rate': 1.9600000000000002e-05, 'epoch': 0.01}


  3%|▎         | 30/1000 [00:10<05:04,  3.18it/s]

{'loss': 0.004, 'learning_rate': 1.94e-05, 'epoch': 0.01}


  4%|▍         | 40/1000 [00:13<05:01,  3.18it/s]

{'loss': 0.0017, 'learning_rate': 1.9200000000000003e-05, 'epoch': 0.02}


  5%|▌         | 50/1000 [00:16<05:00,  3.16it/s]

{'loss': 0.0015, 'learning_rate': 1.9e-05, 'epoch': 0.03}


  6%|▌         | 60/1000 [00:19<04:54,  3.19it/s]

{'loss': 0.0009, 'learning_rate': 1.88e-05, 'epoch': 0.03}


  7%|▋         | 70/1000 [00:22<04:53,  3.17it/s]

{'loss': 0.0007, 'learning_rate': 1.86e-05, 'epoch': 0.04}


  8%|▊         | 80/1000 [00:26<04:56,  3.10it/s]

{'loss': 0.001, 'learning_rate': 1.8400000000000003e-05, 'epoch': 0.04}


  9%|▉         | 90/1000 [00:29<05:04,  2.99it/s]

{'loss': 0.001, 'learning_rate': 1.8200000000000002e-05, 'epoch': 0.04}


 10%|█         | 100/1000 [00:32<04:48,  3.12it/s]

{'loss': 0.0006, 'learning_rate': 1.8e-05, 'epoch': 0.05}


 11%|█         | 110/1000 [00:35<04:40,  3.17it/s]

{'loss': 0.0006, 'learning_rate': 1.7800000000000002e-05, 'epoch': 0.06}


 12%|█▏        | 120/1000 [00:38<04:37,  3.17it/s]

{'loss': 0.0008, 'learning_rate': 1.76e-05, 'epoch': 0.06}


 13%|█▎        | 130/1000 [00:42<04:41,  3.09it/s]

{'loss': 0.0011, 'learning_rate': 1.7400000000000003e-05, 'epoch': 0.07}


 14%|█▍        | 140/1000 [00:45<04:31,  3.17it/s]

{'loss': 0.0005, 'learning_rate': 1.72e-05, 'epoch': 0.07}


 15%|█▌        | 150/1000 [00:48<04:28,  3.17it/s]

{'loss': 0.0005, 'learning_rate': 1.7e-05, 'epoch': 0.07}


 16%|█▌        | 160/1000 [00:51<04:22,  3.20it/s]

{'loss': 0.0006, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.08}


 17%|█▋        | 170/1000 [00:54<04:19,  3.20it/s]

{'loss': 0.0005, 'learning_rate': 1.66e-05, 'epoch': 0.09}


 18%|█▊        | 180/1000 [00:58<04:19,  3.16it/s]

{'loss': 0.0004, 'learning_rate': 1.64e-05, 'epoch': 0.09}


 19%|█▉        | 190/1000 [01:01<04:13,  3.20it/s]

{'loss': 0.0005, 'learning_rate': 1.62e-05, 'epoch': 0.1}


 20%|██        | 200/1000 [01:04<04:12,  3.17it/s]

{'loss': 0.0004, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.1}


 21%|██        | 210/1000 [01:07<04:09,  3.16it/s]

{'loss': 0.0003, 'learning_rate': 1.58e-05, 'epoch': 0.1}


 22%|██▏       | 220/1000 [01:10<04:06,  3.17it/s]

{'loss': 0.0002, 'learning_rate': 1.5600000000000003e-05, 'epoch': 0.11}


 23%|██▎       | 230/1000 [01:13<04:03,  3.17it/s]

{'loss': 0.0003, 'learning_rate': 1.54e-05, 'epoch': 0.12}


 24%|██▍       | 240/1000 [01:16<04:01,  3.15it/s]

{'loss': 0.0003, 'learning_rate': 1.5200000000000002e-05, 'epoch': 0.12}


 25%|██▌       | 250/1000 [01:20<03:56,  3.18it/s]

{'loss': 0.0003, 'learning_rate': 1.5000000000000002e-05, 'epoch': 0.12}


 26%|██▌       | 260/1000 [01:23<03:52,  3.18it/s]

{'loss': 0.0002, 'learning_rate': 1.48e-05, 'epoch': 0.13}


 27%|██▋       | 270/1000 [01:26<03:51,  3.16it/s]

{'loss': 0.0002, 'learning_rate': 1.46e-05, 'epoch': 0.14}


 28%|██▊       | 280/1000 [01:29<03:46,  3.18it/s]

{'loss': 0.0002, 'learning_rate': 1.4400000000000001e-05, 'epoch': 0.14}


 29%|██▉       | 290/1000 [01:32<03:41,  3.20it/s]

{'loss': 0.0002, 'learning_rate': 1.4200000000000001e-05, 'epoch': 0.14}


 30%|███       | 300/1000 [01:35<03:40,  3.18it/s]

{'loss': 0.0002, 'learning_rate': 1.4e-05, 'epoch': 0.15}


 31%|███       | 310/1000 [01:39<03:41,  3.12it/s]

{'loss': 0.0002, 'learning_rate': 1.38e-05, 'epoch': 0.15}


 32%|███▏      | 320/1000 [01:42<03:33,  3.18it/s]

{'loss': 0.0003, 'learning_rate': 1.3600000000000002e-05, 'epoch': 0.16}


 33%|███▎      | 330/1000 [01:45<03:36,  3.09it/s]

{'loss': 0.0002, 'learning_rate': 1.3400000000000002e-05, 'epoch': 0.17}


 34%|███▍      | 340/1000 [01:48<03:28,  3.17it/s]

{'loss': 0.0002, 'learning_rate': 1.3200000000000002e-05, 'epoch': 0.17}


 35%|███▌      | 350/1000 [01:51<03:23,  3.19it/s]

{'loss': 0.0002, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.17}


 36%|███▌      | 360/1000 [01:54<03:20,  3.19it/s]

{'loss': 0.0001, 'learning_rate': 1.2800000000000001e-05, 'epoch': 0.18}


 37%|███▋      | 370/1000 [01:58<03:17,  3.19it/s]

{'loss': 0.0001, 'learning_rate': 1.2600000000000001e-05, 'epoch': 0.18}


 38%|███▊      | 380/1000 [02:01<03:18,  3.12it/s]

{'loss': 0.0001, 'learning_rate': 1.2400000000000002e-05, 'epoch': 0.19}


 39%|███▉      | 390/1000 [02:04<03:10,  3.19it/s]

{'loss': 0.0001, 'learning_rate': 1.22e-05, 'epoch': 0.2}


 40%|████      | 400/1000 [02:07<03:07,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 1.2e-05, 'epoch': 0.2}


 41%|████      | 410/1000 [02:10<03:05,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 1.18e-05, 'epoch': 0.2}


 42%|████▏     | 420/1000 [02:13<03:05,  3.13it/s]

{'loss': 0.0001, 'learning_rate': 1.16e-05, 'epoch': 0.21}


 43%|████▎     | 430/1000 [02:17<02:59,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 1.14e-05, 'epoch': 0.21}


 44%|████▍     | 440/1000 [02:20<02:55,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 1.1200000000000001e-05, 'epoch': 0.22}


 45%|████▌     | 450/1000 [02:23<02:53,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.23}


 46%|████▌     | 460/1000 [02:26<02:50,  3.16it/s]

{'loss': 0.0001, 'learning_rate': 1.0800000000000002e-05, 'epoch': 0.23}


 47%|████▋     | 470/1000 [02:29<02:45,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 1.0600000000000002e-05, 'epoch': 0.23}


 48%|████▊     | 480/1000 [02:32<02:42,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 1.04e-05, 'epoch': 0.24}


 49%|████▉     | 490/1000 [02:36<02:42,  3.14it/s]

{'loss': 0.0001, 'learning_rate': 1.02e-05, 'epoch': 0.24}


 50%|█████     | 500/1000 [02:39<02:38,  3.16it/s]

{'loss': 0.0001, 'learning_rate': 1e-05, 'epoch': 0.25}


 51%|█████     | 510/1000 [02:42<02:35,  3.16it/s]

{'loss': 0.0001, 'learning_rate': 9.800000000000001e-06, 'epoch': 0.26}


 52%|█████▏    | 520/1000 [02:45<02:32,  3.15it/s]

{'loss': 0.0001, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.26}


 53%|█████▎    | 530/1000 [02:48<02:27,  3.19it/s]

{'loss': 0.0001, 'learning_rate': 9.4e-06, 'epoch': 0.27}


 54%|█████▍    | 540/1000 [02:52<02:24,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 9.200000000000002e-06, 'epoch': 0.27}


 55%|█████▌    | 550/1000 [02:55<02:20,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 9e-06, 'epoch': 0.28}


 56%|█████▌    | 560/1000 [02:58<02:18,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 8.8e-06, 'epoch': 0.28}


 57%|█████▋    | 570/1000 [03:01<02:20,  3.07it/s]

{'loss': 0.0001, 'learning_rate': 8.6e-06, 'epoch': 0.28}


 58%|█████▊    | 580/1000 [03:04<02:11,  3.21it/s]

{'loss': 0.0001, 'learning_rate': 8.400000000000001e-06, 'epoch': 0.29}


 59%|█████▉    | 590/1000 [03:07<02:08,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 8.2e-06, 'epoch': 0.29}


 60%|██████    | 600/1000 [03:11<02:07,  3.13it/s]

{'loss': 0.0, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.3}


 61%|██████    | 610/1000 [03:14<02:06,  3.09it/s]

{'loss': 0.0001, 'learning_rate': 7.800000000000002e-06, 'epoch': 0.3}


 62%|██████▏   | 620/1000 [03:17<01:59,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 7.600000000000001e-06, 'epoch': 0.31}


 63%|██████▎   | 630/1000 [03:20<01:55,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 7.4e-06, 'epoch': 0.32}


 64%|██████▍   | 640/1000 [03:23<01:54,  3.13it/s]

{'loss': 0.0001, 'learning_rate': 7.2000000000000005e-06, 'epoch': 0.32}


 65%|██████▌   | 650/1000 [03:27<01:50,  3.17it/s]

{'loss': 0.0001, 'learning_rate': 7e-06, 'epoch': 0.33}


 66%|██████▌   | 660/1000 [03:30<01:50,  3.08it/s]

{'loss': 0.0001, 'learning_rate': 6.800000000000001e-06, 'epoch': 0.33}


 67%|██████▋   | 670/1000 [03:33<01:44,  3.16it/s]

{'loss': 0.0001, 'learning_rate': 6.600000000000001e-06, 'epoch': 0.34}


 68%|██████▊   | 680/1000 [03:36<01:40,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 6.4000000000000006e-06, 'epoch': 0.34}


 69%|██████▉   | 690/1000 [03:39<01:37,  3.16it/s]

{'loss': 0.0001, 'learning_rate': 6.200000000000001e-06, 'epoch': 0.34}


 70%|███████   | 700/1000 [03:43<01:37,  3.07it/s]

{'loss': 0.0001, 'learning_rate': 6e-06, 'epoch': 0.35}


 71%|███████   | 710/1000 [03:46<01:32,  3.13it/s]

{'loss': 0.0001, 'learning_rate': 5.8e-06, 'epoch': 0.35}


 72%|███████▏  | 720/1000 [03:49<01:32,  3.03it/s]

{'loss': 0.0001, 'learning_rate': 5.600000000000001e-06, 'epoch': 0.36}


 73%|███████▎  | 730/1000 [03:52<01:25,  3.17it/s]

{'loss': 0.0001, 'learning_rate': 5.400000000000001e-06, 'epoch': 0.36}


 74%|███████▍  | 740/1000 [03:55<01:23,  3.11it/s]

{'loss': 0.0001, 'learning_rate': 5.2e-06, 'epoch': 0.37}


 75%|███████▌  | 750/1000 [03:59<01:23,  3.00it/s]

{'loss': 0.0001, 'learning_rate': 5e-06, 'epoch': 0.38}


 76%|███████▌  | 760/1000 [04:02<01:19,  3.01it/s]

{'loss': 0.0001, 'learning_rate': 4.800000000000001e-06, 'epoch': 0.38}


 77%|███████▋  | 770/1000 [04:05<01:13,  3.12it/s]

{'loss': 0.0001, 'learning_rate': 4.600000000000001e-06, 'epoch': 0.39}


 78%|███████▊  | 780/1000 [04:08<01:09,  3.18it/s]

{'loss': 0.0001, 'learning_rate': 4.4e-06, 'epoch': 0.39}


 79%|███████▉  | 790/1000 [04:12<01:05,  3.20it/s]

{'loss': 0.0001, 'learning_rate': 4.2000000000000004e-06, 'epoch': 0.4}


 80%|████████  | 800/1000 [04:15<00:59,  3.35it/s]

{'loss': 0.0001, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.4}


 81%|████████  | 810/1000 [04:18<00:54,  3.51it/s]

{'loss': 0.0001, 'learning_rate': 3.8000000000000005e-06, 'epoch': 0.41}


 82%|████████▏ | 820/1000 [04:20<00:51,  3.48it/s]

{'loss': 0.0001, 'learning_rate': 3.6000000000000003e-06, 'epoch': 0.41}


 83%|████████▎ | 830/1000 [04:23<00:48,  3.53it/s]

{'loss': 0.0001, 'learning_rate': 3.4000000000000005e-06, 'epoch': 0.41}


 84%|████████▍ | 840/1000 [04:26<00:45,  3.49it/s]

{'loss': 0.0001, 'learning_rate': 3.2000000000000003e-06, 'epoch': 0.42}


 85%|████████▌ | 850/1000 [04:29<00:42,  3.52it/s]

{'loss': 0.0, 'learning_rate': 3e-06, 'epoch': 0.42}


 86%|████████▌ | 860/1000 [04:32<00:39,  3.52it/s]

{'loss': 0.0001, 'learning_rate': 2.8000000000000003e-06, 'epoch': 0.43}


 87%|████████▋ | 870/1000 [04:35<00:36,  3.53it/s]

{'loss': 0.0, 'learning_rate': 2.6e-06, 'epoch': 0.43}


 88%|████████▊ | 880/1000 [04:38<00:33,  3.54it/s]

{'loss': 0.0, 'learning_rate': 2.4000000000000003e-06, 'epoch': 0.44}


 89%|████████▉ | 890/1000 [04:40<00:31,  3.53it/s]

{'loss': 0.0001, 'learning_rate': 2.2e-06, 'epoch': 0.45}


 90%|█████████ | 900/1000 [04:43<00:28,  3.56it/s]

{'loss': 0.0001, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.45}


 91%|█████████ | 910/1000 [04:46<00:25,  3.54it/s]

{'loss': 0.0001, 'learning_rate': 1.8000000000000001e-06, 'epoch': 0.46}


 92%|█████████▏| 920/1000 [04:49<00:22,  3.52it/s]

{'loss': 0.0, 'learning_rate': 1.6000000000000001e-06, 'epoch': 0.46}


 93%|█████████▎| 930/1000 [04:52<00:19,  3.52it/s]

{'loss': 0.0001, 'learning_rate': 1.4000000000000001e-06, 'epoch': 0.47}


 94%|█████████▍| 940/1000 [04:55<00:16,  3.56it/s]

{'loss': 0.0001, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.47}


 95%|█████████▌| 950/1000 [04:57<00:14,  3.56it/s]

{'loss': 0.0, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.47}


 96%|█████████▌| 960/1000 [05:00<00:11,  3.54it/s]

{'loss': 0.0001, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.48}


 97%|█████████▋| 970/1000 [05:03<00:08,  3.55it/s]

{'loss': 0.0, 'learning_rate': 6.000000000000001e-07, 'epoch': 0.48}


 98%|█████████▊| 980/1000 [05:06<00:05,  3.54it/s]

{'loss': 0.0, 'learning_rate': 4.0000000000000003e-07, 'epoch': 0.49}


 99%|█████████▉| 990/1000 [05:09<00:02,  3.54it/s]

{'loss': 0.0001, 'learning_rate': 2.0000000000000002e-07, 'epoch': 0.49}


100%|██████████| 1000/1000 [05:12<00:00,  3.55it/s]

{'loss': 0.0, 'learning_rate': 0.0, 'epoch': 0.5}


                                                   
100%|██████████| 1000/1000 [05:52<00:00,  3.55it/s]

{'eval_loss': 2.6713312308856985e-06, 'eval_accuracy': 0.4035, 'eval_f1': 0.5749910936943355, 'eval_runtime': 40.6989, 'eval_samples_per_second': 49.141, 'eval_steps_per_second': 1.548, 'epoch': 0.5}


  state_dict = torch.load(best_model_path, map_location="cpu")
100%|██████████| 1000/1000 [05:55<00:00,  2.82it/s]

{'train_runtime': 355.209, 'train_samples_per_second': 11.261, 'train_steps_per_second': 2.815, 'train_loss': 0.0004791009426116943, 'epoch': 0.5}





TrainOutput(global_step=1000, training_loss=0.0004791009426116943, metrics={'train_runtime': 355.209, 'train_samples_per_second': 11.261, 'train_steps_per_second': 2.815, 'train_loss': 0.0004791009426116943, 'epoch': 0.5})

## Step10 模型评估

## Step11 模型预测

In [15]:
class SentenceSimilarityPipeline:

    def __init__(self, model, tokenizer) -> None:
        self.model = model.bert
        self.tokenizer = tokenizer
        self.device = model.device

    def preprocess(self, senA, senB):
        return self.tokenizer([senA, senB], max_length=128, truncation=True, return_tensors="pt", padding=True)

    def predict(self, inputs):
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        return self.model(**inputs)[1]  # [2, 768]

    def postprocess(self, logits):
        cos = CosineSimilarity()(logits[None, 0, :], logits[None,1, :]).squeeze().cpu().item()
        return cos

    def __call__(self, senA, senB, return_vector=False):
        inputs = self.preprocess(senA, senB)
        logits = self.predict(inputs)
        result = self.postprocess(logits)
        if return_vector:
            return result, logits
        else:
            return result

In [16]:
pipe = SentenceSimilarityPipeline(model, tokenizer)

In [17]:
pipe("我喜欢北京", "明天不行", return_vector=True)

(0.9999879598617554,
 tensor([[ 0.9961,  0.9973,  1.0000,  ..., -0.9981, -0.9995, -0.9978],
         [ 0.9951,  0.9973,  1.0000,  ..., -0.9979, -0.9992, -0.9980]],
        device='cuda:0', grad_fn=<TanhBackward0>))