In [1]:
import collections
import math
import os
import os.path
import random
import time

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset

In [3]:
'''
@Project ：NLPNewsClassification 
@File    ：data_processing.py
@Author  ：DZY
@Date    ：2025/3/10 11:55 
'''


def tokenizer(samples):
    """
    Split every sample into whitespace-separated tokens.

    Args:
        samples: 1-D list in which each element is one sample string.

    Returns:
        2-D list; each inner list holds the tokens (strings) of one sample.

    Examples:
        samples: ["1 2 3 4", "5 6 7 8"]
        Returns: [['1', '2', '3', '4'], ['5', '6', '7', '8']]

    """
    tokenized = []
    for text in samples:
        tokenized.append(text.split())
    return tokenized


def sample_truncate_pad(sample_tokens, num_steps, padding_token):
    """
    Cut or pad a token list so that it has exactly ``num_steps`` elements.

    Args:
        sample_tokens: token list of one tokenized sample.
        num_steps: target length.
        padding_token: token appended when the sample is too short.

    Returns:
        A new list of length ``num_steps``.
    """
    missing = num_steps - len(sample_tokens)
    if missing < 0:
        # Too long: keep only the leading num_steps tokens.
        return sample_tokens[:num_steps]
    return sample_tokens + [padding_token] * missing


def get_tokens_and_segments(tokens_a, tokens_b=None):
    """
    Build a BERT-style input sequence together with its segment ids.

    Args:
        tokens_a: token list of sequence A.
        tokens_b: optional token list of sequence B.

    Returns:
        (tokens, segments): the sequence wrapped in '<cls>'/'<sep>' markers
        and a list of the same length holding 0 for the A part and 1 for
        the B part.

    Examples:
        tokens_a: [1, 2, 3]
        tokens_b: [4, 5, 6]
        Returns:
            tokens: ['<cls>', 1, 2, 3, '<sep>', 4, 5, 6, '<sep>']
            segments: [0, 0, 0, 0, 0, 1, 1, 1, 1]

    """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * len(tokens)
    if tokens_b is None:
        return tokens, segments
    # Second sentence plus its closing separator all belong to segment 1.
    second_part = tokens_b + ['<sep>']
    return tokens + second_part, segments + [1] * len(second_part)

In [4]:
'''
@Project ：NLPNewsClassification 
@File    ：vocabulary.py
@Author  ：DZY
@Date    ：2025/3/10 17:19 
'''


class Vocab:
    """Bidirectional mapping between tokens and integer indices.

    The vocabulary always contains '<unk>' plus every reserved token; corpus
    tokens are included when their frequency is at least ``min_freq``.
    Indices follow the sorted order of the token strings.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """
        Args:
            tokens: 1-D list of tokens, or 2-D list (one token list per sample).
            min_freq: minimum number of occurrences for a corpus token to be kept.
            reserved_tokens: tokens that are always part of the vocabulary.
        """
        if reserved_tokens is None:
            reserved_tokens = []
        if tokens is None:
            tokens = []
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        counter = collections.Counter(tokens)
        # (token, frequency) pairs, most frequent first.
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                  reverse=True)
        kept_tokens = [token for token, freq in self.token_freqs if freq >= min_freq]
        self.idx_to_token = sorted(set(['<unk>'] + reserved_tokens + kept_tokens))
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token (or a possibly nested list/tuple of tokens) to indices.

        Unknown tokens map to the index of '<unk>'.
        """
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index, or a sequence of indices, back to token(s).

        A list or tuple always yields a list, including the empty and the
        single-element case (these previously raised TypeError because the
        list itself was used as a list index). Other inputs keep the old
        rule: anything with a length greater than one yields a list,
        everything else is used directly as a single index.
        """
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[int(index)] for index in indices]
        if hasattr(indices, '__len__') and len(indices) > 1:
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[indices]

    def create_vocab_txt(self, vocab_txt_relative_path):
        """
        Write the vocabulary to a text file, one token per line in index order.

        Args:
            vocab_txt_relative_path: path of the txt file to (over)write.
        """
        with open(vocab_txt_relative_path, 'w') as file:
            # token_to_idx was built by enumerating idx_to_token, so iterating
            # it yields the tokens in index order.
            for token in self.token_to_idx:
                # Every token is followed by '\n'; readers must strip it.
                file.write(f"{token}\n")

    @property
    def unk(self):
        """Index of the '<unk>' token."""
        return self.token_to_idx['<unk>']

In [6]:
'''
@Project ：NLPNewsClassification 
@File    ：attention.py
@Author  ：DZY
@Date    ：2025/3/11 16:12 
'''


def masked_softmax(X, valid_lens):
    """Softmax over the last axis that ignores positions beyond a valid length.

    Args:
        X: 3-D tensor of scores; the softmax is taken over the last axis.
        valid_lens: None (plain softmax), a 1-D tensor with one length per
            element of the first axis of X, or a 2-D tensor with one length
            per row of X.

    Returns:
        Tensor with the shape of X. Entries at positions >= valid length are
        (numerically) zero.

    The input tensor is left untouched: the mask is applied out of place.
    The previous in-place assignment on a reshaped view overwrote the
    caller's scores.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch element -> repeat it for every query row.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)

    flat_scores = X.reshape(-1, shape[-1])
    positions = torch.arange(shape[-1], dtype=torch.float32, device=X.device)
    keep = positions[None, :] < valid_lens[:, None]
    # Replace masked elements with a very large negative value, whose
    # exponentiation outputs 0.
    flat_scores = flat_scores.masked_fill(~keep, -1e6)
    return nn.functional.softmax(flat_scores.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights."""

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Attend from ``queries`` to ``keys``/``values``.

        The scores are batched matrix products of queries and transposed keys,
        divided by the square root of the query feature size, then passed
        through ``masked_softmax`` with ``valid_lens``. The resulting weights
        are stored on ``self.attention_weights``.
        """
        scale = math.sqrt(queries.shape[-1])
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        weights = masked_softmax(scores, valid_lens)
        self.attention_weights = weights
        return torch.bmm(self.dropout(weights), values)


class MultiHeadAttention(nn.Module):
    """Multi-head attention computed with a single batched dot-product attention.

    Heads are folded into the batch axis so that all heads are processed by
    one ``DotProductAttention`` call.
    """

    def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        # Input projections for queries / keys / values and the output projection.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=use_bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=use_bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=use_bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=use_bias)

    def forward(self, queries, keys, values, valid_lens):
        """Project the inputs, attend per head, merge the heads and project again."""
        q_heads = self.transpose_qkv(self.W_q(queries))
        k_heads = self.transpose_qkv(self.W_k(keys))
        v_heads = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # Heads live on the batch axis, so every length is repeated num_heads times.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

        per_head_output = self.attention(q_heads, k_heads, v_heads, valid_lens)
        merged = self.transpose_output(per_head_output)
        return self.W_o(merged)

    def transpose_qkv(self, X):
        """(batch, steps, hiddens) -> (batch * heads, steps, hiddens / heads)."""
        batch_size, num_steps = X.shape[0], X.shape[1]
        # Split the feature axis into heads, then move the head axis next to the batch axis.
        split = X.reshape(batch_size, num_steps, self.num_heads, -1).permute(0, 2, 1, 3)
        return split.reshape(-1, split.shape[2], split.shape[3])

    def transpose_output(self, X):
        """Inverse of ``transpose_qkv``: (batch * heads, steps, d) -> (batch, steps, heads * d)."""
        num_steps, head_dim = X.shape[1], X.shape[2]
        unfolded = X.reshape(-1, self.num_heads, num_steps, head_dim).permute(0, 2, 1, 3)
        return unfolded.reshape(unfolded.shape[0], unfolded.shape[1], -1)

In [9]:
'''
@Project ：NLPNewsClassification 
@File    ：bert.py
@Author  ：DZY
@Date    ：2025/3/11 21:32 
'''


class BERTEncoder(nn.Module):
    """BERT encoder: token + segment + learnable positional embeddings followed
    by a stack of ``TransformerEncoderBlock`` layers."""

    def __init__(self, vocab_size, query_size, key_size, value_size, num_hiddens, normalized_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout, max_len=1000, **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        # Two segment ids: 0 for sentence A, 1 for sentence B.
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        # Learnable positional embedding, one vector per position up to max_len.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(f"TransformerEncoderBlock:{i}",
                                   TransformerEncoderBlock(query_size, key_size, value_size, num_hiddens,
                                                           normalized_shape, ffn_num_input, ffn_num_hiddens,
                                                           num_heads, dropout, True))

    def forward(self, tokens, segments, valid_lens):
        """Encode a batch of token ids.

        Args:
            tokens: integer tensor of token ids, indexed as (batch, step).
            segments: integer tensor of segment ids with the shape of ``tokens``.
            valid_lens: passed through to every encoder block.

        Returns:
            Tensor with one ``num_hiddens``-sized vector per input position.
        """
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        # Index the parameter itself (not ``.data``) so that autograd tracks
        # the addition and the positional embedding actually gets trained.
        X = X + self.pos_embedding[:, :X.shape[1], :]
        for block in self.blocks:
            X = block(X, valid_lens)
        return X


class BERTLM(nn.Module):
    """BERT language model: a ``BERTEncoder`` with a masked-language-model head
    and a next-sentence-prediction head."""

    def __init__(self, vocab_size, query_size, key_size, value_size, num_hiddens, normalized_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, mlm_in_features, mlm_hiddens, nsp_in_features, nsp_hiddens,
                 dropout, max_len=1000, **kwargs):
        super().__init__(**kwargs)
        self.encoder = BERTEncoder(vocab_size, query_size, key_size, value_size, num_hiddens, normalized_shape,
                                   ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=max_len)
        self.mlm = MaskLM(mlm_in_features, mlm_hiddens, vocab_size)
        self.nsp = NextSentencePred(nsp_in_features, nsp_hiddens)

    def forward(self, tokens, segments, valid_lens=None, pred_position=None):
        """Run the encoder and both heads.

        Returns:
            (encoded_X, mlm_Y_hat, nsp_Y_hat); ``mlm_Y_hat`` is None when no
            ``pred_position`` is given.
        """
        encoded_X = self.encoder(tokens, segments, valid_lens)
        mlm_Y_hat = None if pred_position is None else self.mlm(encoded_X, pred_position)
        # Next-sentence prediction only needs the representation at index 0
        # of each sequence (the '<cls>' marker).
        cls_representation = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(cls_representation)
        return encoded_X, mlm_Y_hat, nsp_Y_hat

In [7]:
'''
@Project ：NLPNewsClassification 
@File    ：encoder.py
@Author  ：DZY
@Date    ：2025/3/11 16:13 
'''


class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network (Linear -> ReLU -> Linear) applied to the last axis."""

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)


class AddNorm(nn.Module):
    """Residual connection followed by layer normalization."""

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.layernorm = nn.LayerNorm(normalized_shape)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, Y):
        # Dropout is applied to the sub-layer output Y only, never to the residual X.
        residual_sum = self.dropout(Y) + X
        return self.layernorm(residual_sum)


class TransformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention and a position-wise FFN, each wrapped in AddNorm."""

    def __init__(self, query_size, key_size, value_size, num_hiddens, normalized_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.multi_head_attention = MultiHeadAttention(query_size, key_size, value_size, num_hiddens, num_heads,
                                                       dropout, use_bias)
        self.add_norm1 = AddNorm(normalized_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.add_norm2 = AddNorm(normalized_shape, dropout)

    def forward(self, X, valid_lens):
        # Self-attention: X serves as queries, keys and values.
        attended = self.multi_head_attention(X, X, X, valid_lens)
        Y = self.add_norm1(X, attended)
        transformed = self.ffn(Y)
        return self.add_norm2(Y, transformed)

In [8]:
'''
@Project ：NLPNewsClassification 
@File    ：pretrain_tasks.py
@Author  ：DZY
@Date    ：2025/3/12 11:34 
'''


class MaskLM(nn.Module):
    """Masked-language-model head: predicts a vocabulary distribution for
    selected positions of the encoder output."""

    def __init__(self, mlm_in_features, mlm_hiddens, vocab_size, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        self.mlp = nn.Sequential(nn.Linear(mlm_in_features, mlm_hiddens), nn.ReLU(), nn.LayerNorm(mlm_hiddens),
                                 nn.Linear(mlm_hiddens, vocab_size))

    def forward(self, X, pred_position):
        """
        Args:
            X: encoder output, (batch_size, num_steps, num_hiddens).
            pred_position: positions to predict, (batch_size, num_pred_position).

        Returns:
            Logits of shape (batch_size, num_pred_position, vocab_size).
        """
        num_pred_position = pred_position.shape[-1]
        batch_size = X.shape[0]
        # (batch_size * num_pred_position,)
        pred_position = pred_position.reshape(-1)
        # Row index for every predicted position, e.g. [0, 0, 1, 1, ...].
        # Created on the same device as the position tensor so that the two
        # index tensors never live on different devices.
        batch_indices = torch.arange(0, batch_size, device=pred_position.device)
        batch_indices = torch.repeat_interleave(batch_indices, num_pred_position)
        # Advanced indexing picks the hidden vector of each masked position:
        # (batch_size * num_pred_position, num_hiddens)
        masked_X = X[batch_indices, pred_position]
        # (batch_size, num_pred_position, num_hiddens)
        masked_X = masked_X.reshape((batch_size, num_pred_position, -1))

        mlm_Y_hat = self.mlp(masked_X)
        return mlm_Y_hat


class NextSentencePred(nn.Module):
    """Next-sentence-prediction head: Linear -> Tanh -> Linear producing two logits."""

    def __init__(self, nsp_in_features, nsp_hiddens, **kwargs):
        super().__init__(**kwargs)
        layers = [
            nn.Linear(nsp_in_features, nsp_hiddens),
            nn.Tanh(),
            nn.Linear(nsp_hiddens, 2),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, X):
        logits = self.mlp(X)
        return logits


In [None]:
'''
@Project ：NLPNewsClassification 
@File    ：checkpoints.py
@Author  ：DZY
@Date    ：2025/3/24 19:07 
'''


def save_pretrain_checkpoint(model, optimizer, step, checkpoint_pretrain_info_tuple, checkpoints_relative_path,
                             checkpoint_dir_name,
                             checkpoint_file_name):
    """
    Save a pre-training checkpoint under
    ``<cwd>/<checkpoints_relative_path>/<checkpoint_dir_name>/<checkpoint_file_name>``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        step: number of completed training steps.
        checkpoint_pretrain_info_tuple: (total_mlm_loss, total_processed_samples, cum_time_list).
        checkpoints_relative_path: checkpoint root, relative to the current working directory.
        checkpoint_dir_name: sub-directory created for this checkpoint.
        checkpoint_file_name: file name of the checkpoint.

    Returns:
        None
    """
    checkpoint_dir_path = os.path.join(os.getcwd(), checkpoints_relative_path, checkpoint_dir_name)
    # exist_ok avoids the check-then-create race of "if not exists: makedirs".
    os.makedirs(checkpoint_dir_path, exist_ok=True)
    checkpoint_file_path = os.path.join(checkpoint_dir_path, checkpoint_file_name)
    total_mlm_loss, total_processed_samples, cum_time_list = checkpoint_pretrain_info_tuple[:3]
    torch.save({
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'total_mlm_loss': total_mlm_loss,
        'total_processed_samples': total_processed_samples,
        'cum_time_list': cum_time_list
    }, checkpoint_file_path)
    print(f"检查点已保存到 {checkpoint_file_path}")


def save_finetuning_model(model, checkpoints_relative_path,
                          checkpoint_dir_name,
                          checkpoint_file_name):
    """
    Save only the model parameters under
    ``<cwd>/<checkpoints_relative_path>/<checkpoint_dir_name>/<checkpoint_file_name>``.

    Args:
        model: module whose ``state_dict()`` is stored under 'model_state_dict'.
        checkpoints_relative_path: checkpoint root, relative to the current working directory.
        checkpoint_dir_name: sub-directory created for this checkpoint.
        checkpoint_file_name: file name of the checkpoint.

    Returns:
        None
    """
    checkpoint_dir_path = os.path.join(os.getcwd(), checkpoints_relative_path, checkpoint_dir_name)
    # exist_ok avoids the check-then-create race of "if not exists: makedirs".
    os.makedirs(checkpoint_dir_path, exist_ok=True)
    checkpoint_file_path = os.path.join(checkpoint_dir_path, checkpoint_file_name)
    torch.save({
        'model_state_dict': model.state_dict(),
    }, checkpoint_file_path)
    print(f"模型参数已保存到 {checkpoint_file_path}")


def load_checkpoint(model, checkpoint_relative_path, device, optimizer=None, scheduler=None):
    """
    Restore model parameters and, optionally, optimizer / scheduler state
    from a pre-training checkpoint, and return the stored training progress.

    Args:
        model: module that receives the stored parameters.
        checkpoint_relative_path: checkpoint file, relative to the current working directory.
        device: ``map_location`` used when loading the checkpoint.
        optimizer: optional optimizer to restore from 'optimizer_state_dict'.
        scheduler: optional scheduler; restored only when the checkpoint
            contains a non-empty 'scheduler_state_dict'.

    Returns:
        (step, total_mlm_loss, total_processed_samples, cum_time_list)
    """
    # Load the checkpoint file and the model parameters.
    checkpoint = load_pretrained_model_params(model, checkpoint_relative_path, device)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("加载预训练BERT优化器参数成功")
    # Checkpoints may have been written without scheduler state, so the key is optional.
    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state:
        scheduler.load_state_dict(scheduler_state)
        print("加载预训练BERT调度器参数成功")

    # Restore the training progress. (A stray trailing comma used to turn
    # total_mlm_loss into a 1-tuple here.)
    step = checkpoint['step']
    total_mlm_loss = checkpoint['total_mlm_loss']
    total_processed_samples = checkpoint['total_processed_samples']
    cum_time_list = checkpoint['cum_time_list']
    print(f"恢复步数: {step} "
          f"恢复当前总损失: {total_mlm_loss} "
          f"恢复当前已处理样本数: {total_processed_samples} "
          f"恢复当前训练迭代时间列表: {cum_time_list}")

    return step, total_mlm_loss, total_processed_samples, cum_time_list


def _process_model_state_dict_keys_name(model_state_dict):
    return {key.replace("module.", ""): value for key, value in model_state_dict.items()}


def load_pretrained_model_params(model, checkpoint_relative_path, device):
    """Load a checkpoint file and copy its model parameters into ``model``.

    Args:
        model: module that receives the parameters.
        checkpoint_relative_path: checkpoint file, relative to the current working directory.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The loaded checkpoint dict, with 'model_state_dict' replaced by the
        key-cleaned version that was loaded into the model.
    """
    checkpoint_path = os.path.join(os.getcwd(), checkpoint_relative_path)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(f"加载检查点路径{checkpoint_relative_path}")

    cleaned_state_dict = _process_model_state_dict_keys_name(checkpoint['model_state_dict'])
    checkpoint['model_state_dict'] = cleaned_state_dict
    model.load_state_dict(cleaned_state_dict)
    print("加载预训练BERT模型参数成功")
    return checkpoint


def save_pretrain_info(pretrain_results_relative_path, pretrain_result):
    """Write ``pretrain_result`` (text) to a file, replacing previous content.

    The path is resolved relative to the current working directory.
    """
    target_path = os.path.join(os.getcwd(), pretrain_results_relative_path)
    with open(target_path, 'w') as out:
        out.write(pretrain_result)

In [10]:
'''
@Project ：NLPNewsClassification 
@File    ：environment.py
@Author  ：DZY
@Date    ：2025/3/14 15:41 
'''


def use_cpu():
    """Return the ``torch.device`` that represents the CPU."""
    cpu_device = torch.device('cpu')
    return cpu_device


def use_gpu(i=0):
    """Return the ``torch.device`` for CUDA device number ``i``.

    Only the device handle is built; availability is not checked here.
    """
    device_name = f'cuda:{i}'
    return torch.device(device_name)


def num_gpus():
    """Return the number of CUDA devices reported by ``torch.cuda.device_count()``."""
    gpu_count = torch.cuda.device_count()
    return gpu_count


def try_gpu(i=0):
    """Return GPU number ``i`` when at least ``i + 1`` GPUs are available, otherwise the CPU."""
    gpu_is_available = num_gpus() >= i + 1
    return use_gpu(i) if gpu_is_available else use_cpu()


def try_all_gpus():
    """Return all available GPUs, or [cpu] if no GPU exists.

    The result is never empty, so callers may always index ``devices[0]``.
    (The list used to be empty on CPU-only machines, contradicting this
    contract and raising IndexError at the first ``devices[0]``.)

    Returns:
        List of ``torch.device`` objects.
    """
    devices = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device('cpu')]


In [None]:
'''
@Project ：NLPNewsClassification 
@File    ：scheduler.py
@Author  ：DZY
@Date    ：2025/3/25 11:35 
'''


class BERTScheduler():
    """Learning-rate schedule with linear warm-up followed by inverse-square-root decay.

    The schedule writes the learning rate directly into every parameter group
    of the optimizer, overriding whatever rate the optimizer was built with.
    """

    def __init__(self, optimizer, num_hiddens, warmup_steps, current_step=0):
        """
        Args:
            optimizer: optimizer whose ``param_groups`` are updated.
            num_hiddens: hidden size; the base rate is ``num_hiddens ** -0.5``.
            warmup_steps: number of warm-up steps.
            current_step: number of steps already taken (for resuming).

        The constructor performs one ``step()`` so that the optimizer starts
        with a scheduled learning rate.
        """
        self.optimizer = optimizer
        self.init_lr = np.power(num_hiddens, -0.5)
        self.warmup_steps = warmup_steps
        self.current_step = current_step
        self.step()

    def get_lr(self):
        """Return the learning rate for ``self.current_step``."""
        warmup_finished = not (self.current_step < self.warmup_steps)
        if warmup_finished:
            # Decay proportional to 1 / sqrt(step).
            return self.init_lr * np.power(self.current_step, -0.5)
        # Warm-up: the rate grows linearly with the step count.
        return self.init_lr * np.power(self.warmup_steps, -1.5) * self.current_step

    def step(self):
        """Advance one step and write the new learning rate into the optimizer."""
        self.current_step += 1
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

In [11]:
'''
@Project ：NLPNewsClassification 
@File    ：timer.py
@Author  ：DZY
@Date    ：2025/3/14 17:05 
'''


class Timer():
    """Stopwatch that records the duration of every start()/stop() interval."""

    def __init__(self):
        # One entry (seconds) per completed start()/stop() pair.
        self.iter_step_time_list = []

    def start(self):
        """Remember the current wall-clock time as the interval start."""
        self.start_time = time.time()

    def stop(self):
        """Close the running interval, record its duration and return it."""
        self.stop_time = time.time()
        elapsed = self.stop_time - self.start_time
        self.iter_step_time_list.append(elapsed)
        return elapsed

    def get_time_diff(self):
        """Duration of the most recently recorded interval."""
        return self.iter_step_time_list[-1]

    def get_iter_step_time_list(self):
        """The list of all recorded durations."""
        return self.iter_step_time_list

    def get_step_time_diff(self, start_step_index, end_step_index):
        """Sum of the durations in the slice [start_step_index, end_step_index)."""
        selected = self.iter_step_time_list[start_step_index:end_step_index]
        return sum(selected)

    def get_total_time(self):
        """Sum of all recorded durations."""
        return sum(self.iter_step_time_list)

    def get_avg_time(self):
        """Mean of all recorded durations."""
        count = len(self.iter_step_time_list)
        return self.get_total_time() / count

    def get_cumulate_time(self):
        """
        Returns:
            List with the running (cumulative) sum of the recorded durations.

        """
        durations = np.array(self.iter_step_time_list)
        return np.cumsum(durations).tolist()


# 将训练时长格式化为 小时:分:秒
def format_duration(seconds):
    """Format a duration given in seconds as ``hh:mm:ss`` (fractions are dropped)."""
    parts = (seconds // 3600, (seconds % 3600) // 60, seconds % 60)
    hours, minutes, secs = (int(part) for part in parts)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


In [12]:
'''
@Project ：NLPNewsClassification 
@File    ：pretrain_data_create.py
@Author  ：DZY
@Date    ：2025/3/12 16:25 
'''


def _get_inputs(samples_tokens, max_len):
    """
    Build the BERT-style input of every sample that fits into ``max_len``.

    Args:
        samples_tokens: 2-D list; each inner list holds the tokens of one sample.
        max_len: maximum length of a BERT input sequence.

    Returns:
        List of tuples ``(tokens, segments)`` as produced by
        ``get_tokens_and_segments``. Samples whose length plus the two
        markers '<cls>' and '<sep>' exceeds ``max_len`` are dropped.

    """
    return [
        get_tokens_and_segments(sample_tokens)
        for sample_tokens in samples_tokens
        if len(sample_tokens) + 2 <= max_len
    ]


def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab):
    """
    用于掩蔽语言模型替换词元

    Args:
        tokens: 用于MLM的输入序列（BERT形式的输入），一维列表，每个元素是一个词元，包含特殊标识符
        candidate_pred_positions: 候选需要替换（预测）的词元位置下标索引（不包括特殊标识符，特殊标识符不被预测）
        num_mlm_preds: 需要替换（预测）的词元数量
        vocab: 词表

    Returns:
        随机替换后的序列
        替换的位置和替换之前的词元组成的元组列表

    """
    # 复制一份MLM输入序列词元列表
    mlm_input_tokens = [token for token in tokens]
    # 预测的词元位置索引和被替换前的词元
    pred_positions_and_labels = []
    # 随机打乱候选替换位置索引
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) > num_mlm_preds:
            break
        masked_tokens = None
        # 80%时间替换为<mask>
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            # 10%时间保持不变
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            # 10%时间替换为随机词元
            else:
                masked_token = random.choice(vocab.idx_to_token)
        # 替换输入序列中指定位置的词元
        mlm_input_tokens[mlm_pred_position] = masked_token
        # 保存替换的位置和替换之前的词元
        pred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels


def _get_mlm_data(tokens, vocab):
    """
    Create the masked-language-model data for one input sequence.

    Args:
        tokens: BERT-style input sequence (list of tokens incl. special markers).
        vocab: vocabulary used to map tokens to indices.

    Returns:
        (token_ids, pred_positions, pred_label_ids): indices of the sequence
        after replacement, the predicted positions in ascending order, and the
        indices of the original tokens at those positions.

    """
    special_markers = ('<cls>', '<sep>')
    # Special markers never take part in the prediction task.
    candidates = [index for index, token in enumerate(tokens) if token not in special_markers]
    # About 15% of the tokens are predicted, but at least one.
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    mlm_input_tokens, replaced = _replace_mlm_tokens(tokens, candidates, num_mlm_preds, vocab)
    replaced = sorted(replaced, key=lambda pair: pair[0])
    pred_positions = [position for position, _ in replaced]
    pred_labels = [label for _, label in replaced]
    return vocab[mlm_input_tokens], pred_positions, vocab[pred_labels]


def _pad_bert_inputs(examples, max_len, vocab):
    max_num_mlm_preds = round(max_len * 0.15)
    all_token_ids, all_segments, valid_lens, = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments) in examples:
        all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (
                max_len - len(token_ids)), dtype=torch.long))
        all_segments.append(torch.tensor(segments + [0] * (
                max_len - len(segments)), dtype=torch.long))
        # valid_lens不包括'<pad>'的计数
        valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
        all_pred_positions.append(torch.tensor(pred_positions + [0] * (
                max_num_mlm_preds - len(pred_positions)), dtype=torch.long))
        # 填充词元的预测将通过乘以0权重在损失中过滤掉
        all_mlm_weights.append(
            torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (
                    max_num_mlm_preds - len(pred_positions)),
                         dtype=torch.float32))
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (
                max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))
    return (all_token_ids, all_segments, valid_lens, all_pred_positions,
            all_mlm_weights, all_mlm_labels)


class PretrainDataset(Dataset):
    """Dataset of padded masked-language-model examples built from raw text samples."""

    def __init__(self, samples, max_len):
        """
        Args:
            samples: list of raw sample strings.
            max_len: maximum BERT input length; longer samples are dropped.
        """
        # Split the raw samples into word tokens and build the vocabulary from them.
        samples_tokens = tokenizer(samples)
        self.vocab = Vocab(samples_tokens, min_freq=5, reserved_tokens=['<cls>', '<sep>', '<pad>', '<mask>'])
        # BERT-style inputs, then the MLM data of each input with its segments appended.
        bert_inputs = _get_inputs(samples_tokens, max_len)
        examples = []
        for tokens, segments in bert_inputs:
            examples.append(_get_mlm_data(tokens, self.vocab) + (segments,))
        padded = _pad_bert_inputs(examples, max_len, self.vocab)
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels) = padded

    def __getitem__(self, idx):
        fields = (self.all_token_ids, self.all_segments, self.valid_lens,
                  self.all_pred_positions, self.all_mlm_weights, self.all_mlm_labels)
        return tuple(field[idx] for field in fields)

    def __len__(self):
        return len(self.all_token_ids)


def get_dataloader_workers(num_workers):
    """
    Return the number of worker processes to use for data loading.

    Args:
        num_workers: requested number of processes.

    Returns:
        The value that was passed in, unchanged.

    """
    worker_count = num_workers
    return worker_count


def read_pretrain_data(pretrain_sentences_relative_path):
    """Read the pre-training text file and return its lines (line endings kept).

    The path is resolved relative to the current working directory.
    """
    source_path = os.path.join(os.getcwd(), pretrain_sentences_relative_path)
    with open(source_path, 'r') as handle:
        lines = handle.readlines()
    return lines


def load_pretrain_data(pretrain_data_relative_path, vocab_txt_relative_path, batch_size, max_len):
    """Build the pre-training data loader and write the vocabulary file.

    Args:
        pretrain_data_relative_path: text file with the samples, one per line.
        vocab_txt_relative_path: file to which the vocabulary is written.
        batch_size: batch size of the returned loader.
        max_len: maximum BERT input length.

    Returns:
        (pretrain_iter, vocab): a shuffling ``DataLoader`` over the
        ``PretrainDataset`` and the dataset's vocabulary.
    """
    raw_samples = read_pretrain_data(pretrain_data_relative_path)
    dataset = PretrainDataset(raw_samples, max_len)
    vocab = dataset.vocab
    vocab.create_vocab_txt(vocab_txt_relative_path)
    loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,
                                         num_workers=get_dataloader_workers(0))
    return loader, vocab

# batch_size, max_len = 128, 256
# train_iter, vocab = load_pretrain_data(batch_size, max_len)
#
# for (tokens_X, segments_X, valid_lens_X, pred_positions_X, mlm_weights_X,mlm_Y) in train_iter:
#     print(tokens_X.shape, segments_X.shape, valid_lens_X.shape,
#           pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape)
#     break


In [13]:
'''
@Project ：NLPNewsClassification
@File    ：pretrain_bert.py
@Author  ：DZY
@Date    ：2025/3/14 12:38
'''


def load_pretrain_checkpoint(model, optimizer, device, checkpoint_relative_path=None):
    """Return the training progress to start from.

    Without a checkpoint path the progress of a fresh run is returned;
    otherwise model and optimizer are restored via ``load_checkpoint``.

    Returns:
        (step, total_mlm_loss, total_processed_samples, cum_time_list)
    """
    if checkpoint_relative_path is not None:
        step, total_mlm_loss, total_processed_samples, cum_time_list = load_checkpoint(
            model, checkpoint_relative_path, device, optimizer)
        return step, total_mlm_loss, total_processed_samples, cum_time_list
    # Fresh run: nothing processed yet.
    return 0, 0.0, 0.0, []


class PretrainBERT():
    """Runs masked-language-model pre-training and keeps the running statistics
    (step count, accumulated loss, processed samples, cumulative step times)."""

    def __init__(self, current_step=0, total_mlm_loss=0.0, total_processed_samples=0.0, cum_time_list=None):
        """
        Args:
            current_step: number of training steps already completed.
            total_mlm_loss: sum of the per-step losses so far.
            total_processed_samples: number of samples processed so far.
            cum_time_list: cumulative step times of earlier training (seconds).
        """
        if cum_time_list is None:
            cum_time_list = []
        self.current_step = current_step
        self.total_mlm_loss = total_mlm_loss
        self.total_processed_samples = total_processed_samples
        self.cum_time_list = cum_time_list

    def _get_bert_batch_loss(self, net, loss, vocab_size, tokens_X, segments_X, valid_lens_X, pred_positions_X,
                             mlm_weights_X, mlm_Y):
        """Return the MLM loss of one batch, averaged over real prediction slots.

        Padded prediction slots (weight 0) are removed *before* the criterion
        is evaluated. With a mean-reducing criterion the loss is already a
        scalar when the weights are applied, so multiplying afterwards cannot
        mask anything; filtering first works for both reduced and
        per-element criteria.
        """
        _, mlm_Y_hat, _ = net(tokens_X, segments_X, valid_lens_X, pred_positions_X)
        logits = mlm_Y_hat.reshape(-1, vocab_size)
        labels = mlm_Y.reshape(-1)
        weights = mlm_weights_X.reshape(-1)
        real_slots = weights > 0
        if not bool(real_slots.any()):
            # Nothing to predict in this batch: zero loss that is still attached to the graph.
            return logits.sum() * 0.0
        slot_weights = weights[real_slots]
        slot_loss = loss(logits[real_slots], labels[real_slots])
        # Broadcasting covers both a scalar (reduced) and a per-slot loss.
        return (slot_loss * slot_weights).sum() / (slot_weights.sum() + 1e-8)

    def pretrain(self, pretrain_iter, net, loss, vocab_size, pretrain_optimizer, pretrain_scheduler,
                 num_pretrain_iter_steps,
                 checkpoints_relative_path, devices):
        """Train until ``num_pretrain_iter_steps`` steps have been completed in total.

        Args:
            pretrain_iter: iterable of batches
                (tokens, segments, valid_lens, pred_positions, mlm_weights, mlm_labels).
            net: model; wrapped in ``nn.DataParallel`` over ``devices``.
            loss: criterion applied to (logits, labels).
            vocab_size: size of the vocabulary (last axis of the MLM logits).
            pretrain_optimizer: optimizer to step.
            pretrain_scheduler: optional scheduler stepped after every optimizer step.
            num_pretrain_iter_steps: total number of steps to reach.
            checkpoints_relative_path: root directory for checkpoints.
            devices: list of devices; batches are moved to ``devices[0]``.

        Raises:
            ValueError: if ``pretrain_iter`` yields no batch (training could never finish).
        """
        net = nn.DataParallel(net, device_ids=devices).to(devices[0])
        pretrain_timer = Timer()
        save_checkpoint_interval = 1000
        pretrain_info_interval = 100
        # Timings restored from a checkpoint; this session's times are appended
        # to them instead of overwriting them.
        restored_cum_times = list(self.cum_time_list)
        restored_time = restored_cum_times[-1] if restored_cum_times else 0.0

        def _refresh_cum_time_list():
            self.cum_time_list = restored_cum_times + [
                restored_time + elapsed for elapsed in pretrain_timer.get_cumulate_time()]

        print("start pretraining...")
        while self.current_step < num_pretrain_iter_steps:
            step_at_pass_start = self.current_step
            for tokens_X, segments_X, valid_lens_X, pred_positions_X, mlm_weights_X, mlm_Y in pretrain_iter:
                tokens_X = tokens_X.to(devices[0])
                segments_X = segments_X.to(devices[0])
                valid_lens_X = valid_lens_X.to(devices[0])
                pred_positions_X = pred_positions_X.to(devices[0])
                mlm_weights_X = mlm_weights_X.to(devices[0])
                mlm_Y = mlm_Y.to(devices[0])

                pretrain_timer.start()
                pretrain_optimizer.zero_grad()
                mlm_loss = self._get_bert_batch_loss(net, loss, vocab_size, tokens_X, segments_X, valid_lens_X,
                                                     pred_positions_X, mlm_weights_X, mlm_Y)
                mlm_loss.backward()
                pretrain_optimizer.step()

                if pretrain_scheduler is not None:
                    pretrain_scheduler.step()
                pretrain_timer.stop()

                self.total_mlm_loss += mlm_loss.item()
                self.total_processed_samples += tokens_X.shape[0]

                if (self.current_step + 1) % pretrain_info_interval == 0:
                    _refresh_cum_time_list()
                    # The last entry is the cumulative time of the step just
                    # finished; indexing by the global step count would fail
                    # after resuming from a checkpoint.
                    print(
                        f"Iter Steps: {self.current_step + 2 - pretrain_info_interval}-{self.current_step + 1} ---- "
                        f"Avg MLM Loss: {self.total_mlm_loss / (self.current_step + 1):.4f} ---- "
                        f"Cumulative Iter Time: {self.cum_time_list[-1]:.4f} sec")

                if (self.current_step + 1) % save_checkpoint_interval == 0:
                    _refresh_cum_time_list()
                    checkpoint_dir_name = f"checkpoint_step_{self.current_step + 1}"
                    checkpoint_file_name = f"checkpoint_step_{self.current_step + 1}.pth"
                    checkpoint_pretrain_info_tuple = (
                        self.total_mlm_loss, self.total_processed_samples, self.cum_time_list)
                    save_pretrain_checkpoint(net, pretrain_optimizer,
                                             self.current_step + 1, checkpoint_pretrain_info_tuple,
                                             checkpoints_relative_path,
                                             checkpoint_dir_name, checkpoint_file_name)

                self.current_step += 1

                if self.current_step >= num_pretrain_iter_steps:
                    break

            if self.current_step == step_at_pass_start:
                # A full pass produced no batch; looping again would never terminate.
                raise ValueError("pretrain_iter yielded no batches; cannot make training progress")

        _refresh_cum_time_list()
        # Loss and sample counters are cumulative across resumed runs, so the
        # time used for the statistics includes the restored time as well.
        pretrain_total_time = restored_time + pretrain_timer.get_total_time()
        avg_mlm_loss = self.total_mlm_loss / self.current_step if self.current_step else 0.0
        samples_per_sec = self.total_processed_samples / pretrain_total_time if pretrain_total_time > 0 else 0.0
        print(f"Pretrain BERT Total Time: {format_duration(pretrain_total_time)}")
        print(f"Pretrain BERT Total Avg MLM Loss: {avg_mlm_loss:.4f}")
        print(f"Pretrain BERT Total Num Processed Samples: {self.total_processed_samples}")
        print(
            f"Pretrain BERT Processing Samples Speed: {samples_per_sec:.4f} samples/sec")


In [None]:
# Batch size of the data loader and maximum BERT input length (in tokens,
# including the '<cls>' and '<sep>' markers).
batch_size, max_len = 128, 256
# Input text file with one sample per line; paths are resolved relative to
# the current working directory.
pretrain_data_relative_path = "../../dataset/train/pretrain/pretrain_sentences_test.txt"
# pretrain_data_relative_path = "../../dataset/train/pretrain/pretrain_sentences_128.txt"
# File that receives the vocabulary (one token per line).
vocab_txt_relative_path = "../../pretrain_results/vocab.txt"
# Root directory under which checkpoints are written.
checkpoints_relative_path = "../../pretrain_results/checkpoints"
# checkpoint_relative_path = "../../pretrain_results/checkpoints/checkpoint_step_10/checkpoint_step_10.pth"

# Reads the samples, builds dataset and vocabulary, and writes the vocabulary file.
pretrain_iter, vocab = load_pretrain_data(pretrain_data_relative_path, vocab_txt_relative_path,
                                          batch_size, max_len)


In [None]:
'''
@Project ：NLPNewsClassification
@File    ：run_pretrain_bert.py
@Author  ：DZY
@Date    ：2025/3/17 17:07
'''

# Driver cell for BERT pretraining: hyperparameters, model, optimizer/scheduler, checkpoint restore, training.
# Relies on `vocab`, `max_len`, `pretrain_iter` and `checkpoints_relative_path` defined by an earlier cell.

# One shared width: every *_size / *_hiddens / *_in_features value below is tied to it.
initial_size = 768
vocab_size = len(vocab)
query_size = initial_size
key_size = initial_size
value_size = initial_size
num_hiddens = initial_size
normalized_shape = [initial_size]
ffn_num_input = initial_size
ffn_num_hiddens = 3072
num_heads = 12
num_layers = 12
mlm_in_features = initial_size
mlm_hiddens = initial_size
nsp_in_features = initial_size
nsp_hiddens = initial_size
dropout = 0.1
lr = 1e-4
# Number of optimisation steps requested from PretrainBERT.pretrain.
num_pretrain_iter_steps = 30

net = BERTLM(vocab_size, query_size, key_size, value_size, num_hiddens, normalized_shape, ffn_num_input,
             ffn_num_hiddens, num_heads, num_layers, mlm_in_features, mlm_hiddens, nsp_in_features, nsp_hiddens,
             dropout, max_len=max_len)

# NOTE(review): try_all_gpus presumably returns the list of usable devices - confirm against its definition.
devices = try_all_gpus()
loss = nn.CrossEntropyLoss()

pretrain_optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=0.01)
# NOTE(review): warmup_steps (10000) is far larger than num_pretrain_iter_steps (30);
# looks like a smoke-test configuration - confirm before a real run.
pretrain_scheduler = BERTScheduler(pretrain_optimizer, num_hiddens, warmup_steps=10000)

# checkpoint_relative_path=None: no checkpoint file is named here.
# NOTE(review): presumably this makes the loader return fresh counters - confirm in load_pretrain_checkpoint.
step, total_mlm_loss, total_processed_samples, cum_time_list = load_pretrain_checkpoint(net,
                                                                                        pretrain_optimizer,
                                                                                        devices[0],
                                                                                        checkpoint_relative_path=None)

# Seed the trainer with the restored counters, then run pretraining.
pretrainbert = PretrainBERT(step, total_mlm_loss, total_processed_samples, cum_time_list)
pretrainbert.pretrain(pretrain_iter, net, loss, vocab_size, pretrain_optimizer, pretrain_scheduler,
                      num_pretrain_iter_steps,
                      checkpoints_relative_path, devices)


In [None]:
'''
@Project ：NLPNewsClassification 
@File    ：fine_tuning_data_create.py
@Author  ：DZY
@Date    ：2025/3/18 20:19 
'''


def load_fine_tuning_set_data(fine_tuning_set_relative_path, sequence_length):
    """
    Read a fine-tuning data file and return its tokenized samples and labels.

    Each non-blank line must look like ``<int label>\\t<whitespace separated tokens>``.
    Blank (whitespace-only) lines are skipped instead of aborting the load.

    Args:
        fine_tuning_set_relative_path: path of the data file, joined onto the current working directory
        sequence_length: maximum number of tokens kept per sample; longer samples are cut off

    Returns:
        Tuple ``(samples_tokens, labels)``: a 2D list of string tokens (one sub-list per sample)
        and the list of integer labels in the same order

    Raises:
        ValueError: if the label field of a non-blank line is not an integer
        IndexError: if a non-blank line contains no tab separator
    """
    script_directory = os.getcwd()
    full_path = os.path.join(script_directory, fine_tuning_set_relative_path)
    samples_tokens = []
    labels = []
    with open(full_path, 'r') as file:
        # Stream the file line by line rather than materialising every line first.
        for line in file:
            if not line.strip():
                # A stray empty line carries no sample.
                continue
            data = line.split('\t')
            labels.append(int(data[0]))
            # Slicing already handles both short and over-long samples.
            samples_tokens.append(data[1].split()[:sequence_length])
    return samples_tokens, labels


def _create_ft_inputs(samples_tokens, max_len):
    """
    Build the BERT-style input sequence for every sample.

    Samples that would exceed ``max_len`` once '<cls>' and '<sep>' are added are truncated
    from the end. The caller's token lists are left untouched (a truncated copy is used).

    Args:
        samples_tokens: all samples, a 2D list; each sub-list is one sample and each element a token
        max_len: maximum length of a BERT input sequence, including '<cls>' and '<sep>'

    Returns:
        List of tuples: [(BERT-style input sequence of a sample, its segment indices), ...]

    Raises:
        ValueError: if ``max_len`` is too small to hold '<cls>' and '<sep>'
    """
    # Two positions are reserved for the '<cls>' and '<sep>' markers.
    max_num_tokens = max_len - 2
    if max_num_tokens < 0:
        raise ValueError(f"max_len must be at least 2 to hold '<cls>' and '<sep>', got {max_len}")
    # get_tokens_and_segments returns a (tokens, segments) tuple for a single sequence.
    return [get_tokens_and_segments(sample_tokens[:max_num_tokens])
            for sample_tokens in samples_tokens]


def _pad_ft_inputs(input_data, max_len, vocab):
    """
    Convert BERT-style token sequences to index tensors padded up to ``max_len``.

    Args:
        input_data: iterable of (tokens, segments) pairs
        max_len: length every token-id / segment sequence is padded to
        vocab: lookup object; ``vocab[list_of_tokens]`` must give a list of indices and
            ``vocab['<pad>']`` the padding index

    Returns:
        Tuple of three lists with one tensor per sample: padded token ids (long),
        padded segment ids (long) and the unpadded length (float32 scalar)
    """
    padded_token_ids = []
    padded_segments = []
    unpadded_lengths = []
    for tokens, segments in input_data:
        token_ids = vocab[tokens]
        token_padding = [vocab['<pad>']] * (max_len - len(tokens))
        padded_token_ids.append(torch.tensor(token_ids + token_padding, dtype=torch.long))

        segment_padding = [0] * (max_len - len(segments))
        padded_segments.append(torch.tensor(segments + segment_padding, dtype=torch.long))

        # The valid length counts real tokens only, never the '<pad>' positions.
        unpadded_lengths.append(torch.tensor(len(tokens), dtype=torch.float32))
    return (padded_token_ids, padded_segments, unpadded_lengths)


class FineTuningDataset(Dataset):
    """
    Dataset of padded BERT inputs and labels for fine-tuning.

    Args:
        dataset: indexable pair; ``dataset[0]`` holds the tokenized samples and ``dataset[1]`` the labels
        max_len: maximum BERT input sequence length used for truncation and padding
        vocab: vocabulary used to map tokens to indices
    """

    def __init__(self, dataset, max_len, vocab):
        self.labels = torch.tensor(dataset[1])
        self.vocab = vocab
        self.max_len = max_len
        # First wrap every sample with the special tokens, then index and pad it.
        bert_inputs = _create_ft_inputs(dataset[0], max_len)
        padded = _pad_ft_inputs(bert_inputs, max_len, self.vocab)
        self.all_token_ids, self.all_segments, self.valid_lens = padded

    def __getitem__(self, idx):
        """Return (token ids, segment ids, valid length, label) of sample ``idx``."""
        token_ids = self.all_token_ids[idx]
        segments = self.all_segments[idx]
        valid_len = self.valid_lens[idx]
        label = self.labels[idx]
        return (token_ids, segments, valid_len, label)

    def __len__(self):
        """Return the number of samples."""
        return len(self.all_token_ids)


def create_fine_tuning_iter(batch_size, max_len, sequence_length, vocab, ft_train_set_relative_path,
                            ft_validate_set_relative_path, is_train=True):
    """
    Create the DataLoader over either the fine-tuning training set or the validation set.

    Args:
        batch_size: number of samples per mini-batch
        max_len: maximum BERT input sequence length, forwarded to FineTuningDataset
        sequence_length: maximum number of tokens kept per sample when reading the file
        vocab: vocabulary used to index the tokens
        ft_train_set_relative_path: relative path of the training set file
        ft_validate_set_relative_path: relative path of the validation set file
        is_train: selects the training file and enables shuffling when truthy;
            otherwise the validation file is read without shuffling

    Returns:
        A torch DataLoader over the selected data set
    """
    ft_set_relative_path = ft_train_set_relative_path if is_train else ft_validate_set_relative_path
    samples_and_labels = load_fine_tuning_set_data(ft_set_relative_path, sequence_length)
    ft_dataset = FineTuningDataset(samples_and_labels, max_len, vocab)
    return torch.utils.data.DataLoader(ft_dataset, batch_size=batch_size, shuffle=is_train, num_workers=0)

In [None]:
'''
@Project ：NLPNewsClassification 
@File    ：fine_tuning_bert.py
@Author  ：DZY
@Date    ：2025/3/14 12:38 
'''


def _load_pretrained_vocab_txt_file(pretrained_vocab_relative_path):
    """
    Load the vocabulary txt file of the pretrained BERT.

    Args:
        pretrained_vocab_relative_path: path of the txt file, joined onto the current working directory

    Returns:
        1D list of tokens: every whitespace-separated entry of the file, in file order
    """
    vocab_txt_path = os.path.join(os.getcwd(), pretrained_vocab_relative_path)
    with open(vocab_txt_path, 'r') as vocab_file:
        content = vocab_file.read()
    # Line breaks are whitespace too, so one split over the whole text yields every token.
    return content.split()


def _load_pretrained_vocab(pretrained_vocab_relative_path):
    """
    Rebuild a Vocab object from the vocabulary txt file of the pretrained BERT.

    Args:
        pretrained_vocab_relative_path: relative path of the vocabulary txt file

    Returns:
        A Vocab whose ``idx_to_token`` is the token list read from the file and whose
        ``token_to_idx`` maps every token to its position in that list
    """
    vocab = Vocab()
    # Replace whatever the default constructor built with the saved token order.
    vocab.idx_to_token = _load_pretrained_vocab_txt_file(pretrained_vocab_relative_path)
    index_by_token = {}
    for position, token in enumerate(vocab.idx_to_token):
        index_by_token[token] = position
    vocab.token_to_idx = index_by_token
    return vocab


def load_pretrained_model(query_size, key_size, value_size, num_hiddens, normalized_shape, ffn_num_input,
                          ffn_num_hiddens, num_heads, num_layers, mlm_in_features, mlm_hiddens, nsp_in_features,
                          nsp_hiddens, dropout, max_len, devices, pretrained_vocab_relative_path,
                          checkpoint_relative_path):
    """
    Rebuild the pretrained BERT network and load its saved parameters.

    The vocabulary is read first because its size is the first constructor argument of BERTLM.
    All model hyperparameters are forwarded to BERTLM unchanged.

    Args:
        query_size: forwarded to BERTLM
        key_size: forwarded to BERTLM
        value_size: forwarded to BERTLM
        num_hiddens: forwarded to BERTLM
        normalized_shape: forwarded to BERTLM
        ffn_num_input: forwarded to BERTLM
        ffn_num_hiddens: forwarded to BERTLM
        num_heads: forwarded to BERTLM
        num_layers: forwarded to BERTLM
        mlm_in_features: forwarded to BERTLM
        mlm_hiddens: forwarded to BERTLM
        nsp_in_features: forwarded to BERTLM
        nsp_hiddens: forwarded to BERTLM
        dropout: forwarded to BERTLM
        max_len: forwarded to BERTLM as the ``max_len`` keyword
        devices: indexable collection of devices; only ``devices[0]`` is used and it is handed
            to load_pretrained_model_params
        pretrained_vocab_relative_path: relative path of the vocabulary txt file
        checkpoint_relative_path: checkpoint path handed to load_pretrained_model_params

    Returns:
        Tuple ``(pretrained_bert, pretrained_vocab)``
    """
    pretrained_vocab = _load_pretrained_vocab(pretrained_vocab_relative_path)
    vocab_size = len(pretrained_vocab)
    pretrained_bert = BERTLM(
        vocab_size, query_size, key_size, value_size, num_hiddens, normalized_shape,
        ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
        mlm_in_features, mlm_hiddens, nsp_in_features, nsp_hiddens,
        dropout, max_len=max_len)
    load_pretrained_model_params(pretrained_bert, checkpoint_relative_path, devices[0])
    return pretrained_bert, pretrained_vocab


class BERTClassifier(nn.Module):
    """
    Classification head on top of the encoder of a pretrained BERT.

    Args:
        pretrained_bert: object exposing an ``encoder`` module, which is reused here
        classifier_num_input: input feature size of the first linear layer
        classifier_num_hiddens: hidden size between the two linear layers
        classifier_num_output: output size of the second linear layer
    """

    def __init__(self, pretrained_bert, classifier_num_input, classifier_num_hiddens, classifier_num_output):
        super().__init__()
        self.encoder = pretrained_bert.encoder
        hidden_layer = nn.Linear(classifier_num_input, classifier_num_hiddens)
        output_layer = nn.Linear(classifier_num_hiddens, classifier_num_output)
        self.classifier = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, tokens, segments, valid_lens):
        """Encode the inputs and classify the encoding found at position 0 of every sequence."""
        encoded_X = self.encoder(tokens, segments, valid_lens)
        first_position = encoded_X[:, 0, :]
        return self.classifier(first_position)


# def accuracy(y_hat, y):
#     """Compute the number of correct predictions.
#
#     Defined in :numref:`sec_utils`"""
#     if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
#         y_hat = d2l.argmax(y_hat, axis=1)
#     cmp = d2l.astype(y_hat, y.dtype) == y
#     return float(d2l.reduce_sum(d2l.astype(cmp, y.dtype)))


# todo
def get_num_correct_preds(y_hat, y):
    """
    Count how many predictions match their labels.

    Args:
        y_hat: 2D tensor of per-class scores; the predicted class is the argmax along dim 1
        y: tensor of labels. The labels must end up as indices into the class list so that they line up
            with y_hat: e.g. with 14 classes listed as [1, 2, 3, ..., 14], the labels [1, 2, 3] correspond
            to the indices [0, 1, 2]. As a simple interim solution every element of y is reduced by one
            here; the conversion is meant to move into the data preparation later.

    Returns:
        Number of correct predictions as a Python number
    """
    predicted_classes = y_hat.argmax(dim=1)
    zero_based_labels = y - 1
    matches = predicted_classes == zero_based_labels
    return matches.sum().item()


def evaluate_accuracy(net, valid_iter, device=None):
    """
    Compute the classification accuracy of ``net`` over a whole validation iterator.

    Correct predictions are accumulated over every batch, so the result is
    (all correct predictions) / (all samples).

    Args:
        net: the model. When it is an nn.Module it is switched to eval mode (and left in it), and
            the device of its first parameter is used if ``device`` is not given
        valid_iter: iterable yielding (tokens, segments, valid_lens, labels) batches of tensors
        device: device the batches are moved to; should be given when ``net`` is not an nn.Module

    Returns:
        Accuracy rounded to 4 decimal places; 0.0 when the iterator yields no samples
    """
    if isinstance(net, nn.Module):
        net.eval()
        if not device:
            device = next(iter(net.parameters())).device
    num_correct_preds = 0
    num_samples = 0
    with torch.no_grad():
        for tokens_X, segments_X, valid_lens_X, labels_Y in valid_iter:
            tokens_X = tokens_X.to(device)
            segments_X = segments_X.to(device)
            valid_lens_X = valid_lens_X.to(device)
            labels_Y = labels_Y.to(device)
            Y_hat = net(tokens_X, segments_X, valid_lens_X)
            # Accumulate: the count must cover every batch, not only the last one.
            num_correct_preds += get_num_correct_preds(Y_hat, labels_Y)
            num_samples += tokens_X.shape[0]
    if num_samples == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return round(num_correct_preds / num_samples, 4)


def fine_tuning(ft_train_iter, ft_valid_iter, net, loss, lr, num_epochs, devices, ft_checkpoints_relative_path):
    """
    Fine-tune a classifier, report progress and save the model after every epoch.

    The network is wrapped in nn.DataParallel and optimised with Adam. After each epoch the
    validation accuracy is evaluated and save_finetuning_model is called.

    Args:
        ft_train_iter: iterable yielding (tokens, segments, valid_lens, labels) training batches
        ft_valid_iter: iterable of validation batches with the same layout
        net: the classifier to fine-tune
        loss: loss callable invoked as ``loss(Y_hat, targets)``; its result is summed before backward
        lr: learning rate of the Adam optimizer
        num_epochs: number of passes over ``ft_train_iter``
        devices: devices used as ``device_ids`` of nn.DataParallel; batches are moved to ``devices[0]``
        ft_checkpoints_relative_path: location handed to save_finetuning_model

    Returns:
        None. Progress and summary statistics are printed.
    """
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    ft_optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    ft_timer = Timer()
    current_ft_iter_step = 0
    total_ft_train_loss = 0.0
    total_ft_train_correct_preds = 0
    total_ft_processed_samples = 0
    total_ft_valid_acc = 0.0
    ft_info_interval = 100

    print("start fine tuning...")
    for epoch in range(num_epochs):
        # evaluate_accuracy leaves the network in eval mode, so re-enable training every epoch.
        net.train()
        for tokens_X, segments_X, valid_lens_X, labels_Y in ft_train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_X = valid_lens_X.to(devices[0])
            labels_Y = labels_Y.to(devices[0])

            ft_timer.start()
            ft_optimizer.zero_grad()
            Y_hat = net(tokens_X, segments_X, valid_lens_X)

            # get_num_correct_preds shifts the labels to zero-based indices itself.
            total_ft_train_correct_preds += get_num_correct_preds(Y_hat, labels_Y)

            # Train on the same zero-based class indices that the accuracy computation uses.
            # Feeding the raw labels would optimise for a class shifted by one and push the
            # highest label out of the valid target range of the loss.
            targets_Y = labels_Y - 1
            batch_loss = loss(Y_hat, targets_Y).sum()
            batch_loss.backward()
            ft_optimizer.step()
            ft_timer.stop()

            total_ft_train_loss += batch_loss.item()
            total_ft_processed_samples += tokens_X.shape[0]

            if (current_ft_iter_step + 1) % ft_info_interval == 0:
                cum_time_list = ft_timer.get_cumulate_time()
                print(
                    f"Iter Steps: {current_ft_iter_step + 2 - ft_info_interval}-{current_ft_iter_step + 1} ---- "
                    f"Cumulative Avg Loss: {total_ft_train_loss / (current_ft_iter_step + 1):.4f} ---- "
                    f"Cumulative Correct Train Preds/Cumulative Processed Samples: {total_ft_train_correct_preds}/{total_ft_processed_samples} ---- "
                    f"Cumulative Avg Train Acc: {total_ft_train_correct_preds / total_ft_processed_samples:.4f} ---- "
                    f"Cumulative Iter Time: {cum_time_list[current_ft_iter_step]:.4f} sec")

            current_ft_iter_step += 1

        total_ft_valid_acc = evaluate_accuracy(net, ft_valid_iter)
        # The training accuracy printed here is cumulative over all epochs so far.
        print(f"epoch {epoch} Avg Train Acc: {total_ft_train_correct_preds / total_ft_processed_samples:.4f}")
        print(f"epoch {epoch} Avg Valid Acc: {total_ft_valid_acc}")

        checkpoint_dir_name = f"checkpoint_epoch_{epoch + 1}"
        checkpoint_file_name = f"checkpoint_epoch_{epoch + 1}.pth"
        save_finetuning_model(net, ft_checkpoints_relative_path, checkpoint_dir_name, checkpoint_file_name)

    ft_total_time = ft_timer.get_total_time()
    print("FT Total Time: ", format_duration(ft_total_time))
    print("FT Total Steps: ", current_ft_iter_step)
    print(f"FT Avg Train Acc: {total_ft_train_correct_preds / total_ft_processed_samples:.4f}")
    print("FT Total Num Processed Samples: ", total_ft_processed_samples)
    print(f"FT Processing Samples Speed: {total_ft_processed_samples / ft_total_time:.4f} samples/sec")
    print("FT Avg Valid Acc: ", total_ft_valid_acc)


In [None]:
'''
@Project ：NLPNewsClassification 
@File    ：run_fine_tuning_bert.py
@Author  ：DZY
@Date    ：2025/3/18 21:11 
'''


# Driver cell for fine-tuning: rebuild the pretrained BERT, create the data iterators,
# attach a classification head and run fine_tuning.

max_len = 256
# One shared width: every *_size / *_hiddens / *_in_features value below is tied to it.
# NOTE(review): these presumably have to match the values used when the checkpoint was pretrained - confirm.
initial_size = 768
query_size = initial_size
key_size = initial_size
value_size = initial_size
num_hiddens = initial_size
normalized_shape = [initial_size]
ffn_num_input = initial_size
ffn_num_hiddens = 3072
num_heads = 12
num_layers = 12
mlm_in_features = initial_size
mlm_hiddens = initial_size
nsp_in_features = initial_size
nsp_hiddens = initial_size
dropout = 0.1
devices = try_all_gpus()

pretrained_vocab_relative_path = "../../pretrain_results/vocab.txt"
# Despite the plural name, this is the path of a single checkpoint file.
checkpoints_relative_path = "../../pretrain_results/checkpoints/checkpoint_step_10/checkpoint_step_10.pth"
ft_train_set_relative_path = "../../dataset/train/fine_tuning/train/fine_tuning_train_set_test.txt"
ft_validate_set_relative_path = "../../dataset/train/fine_tuning/validate/fine_tuning_validate_set_test.txt"
ft_checkpoints_relative_path = "../../fine_tuning_results"

# Rebuild the pretrained network and the vocabulary it was trained with.
pretrained_bert, pretrained_vocab = load_pretrained_model(query_size, key_size, value_size,
                                                                           num_hiddens, normalized_shape, ffn_num_input,
                                                                           ffn_num_hiddens, num_heads, num_layers,
                                                                           mlm_in_features, mlm_hiddens,
                                                                           nsp_in_features, nsp_hiddens,
                                                                           dropout, max_len, devices,
                                                                           pretrained_vocab_relative_path,
                                                                           checkpoints_relative_path)

batch_size = 32
# Maximum number of tokens kept per sample when the data files are read.
sequence_length = 256
# Both calls receive both paths; is_train selects which file is read and whether it is shuffled.
ft_train_iter = create_fine_tuning_iter(batch_size, max_len, sequence_length, pretrained_vocab,
                                                                ft_train_set_relative_path,
                                                                ft_validate_set_relative_path, is_train=True)
ft_valid_iter = create_fine_tuning_iter(batch_size, max_len, sequence_length, pretrained_vocab,
                                                                ft_train_set_relative_path,
                                                                ft_validate_set_relative_path, is_train=False)

# for (tokens_X, segments_X, valid_lens_X, Y) in ft_train_iter:
#     print(tokens_X.shape, segments_X.shape, valid_lens_X.shape,Y.shape)
#     break


lr, num_epochs = 2e-5, 3
# reduction='none' keeps one loss value per sample; fine_tuning sums them before calling backward.
loss = nn.CrossEntropyLoss(reduction='none')
# Classifier head sizes: input width, hidden width and number of output classes.
classifier_num_input, classifier_num_hiddens, classifier_num_output = 768, 3072, 14

net = BERTClassifier(pretrained_bert, classifier_num_input, classifier_num_hiddens,
                                      classifier_num_output)
fine_tuning(ft_train_iter, ft_valid_iter, net, loss, lr, num_epochs, devices,
                             ft_checkpoints_relative_path)
