In [1]:
import math
import re
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AutoModelForMaskedLM, BertConfig, get_linear_schedule_with_warmup
from transformers.models.bert.modeling_bert import BertSelfAttention
from tqdm import tqdm
from typing import Optional, Tuple

In [2]:
# 添加稀疏注意力的实现
class SparseBertSelfAttention(BertSelfAttention):
    """BERT self-attention that prunes low-probability attention weights.

    After the usual scaled dot-product softmax, every attention probability that is
    not above ``sparsity_threshold`` is zeroed, and each row is renormalised to sum
    to ~1. The largest probability of every row is always kept. A row whose
    attention is spread thinly (all entries <= threshold) therefore degrades to
    top-1 attention instead of collapsing to an all-zero context vector.

    Limitations of this implementation:
      * relative position scores (``position_embedding_type`` "relative_key*") are
        NOT added to the attention scores;
      * ``past_key_value`` is accepted but ignored and never returned, so this layer
        is only suitable for non-cached (encoder-style) use.
    """

    def __init__(self, config, position_embedding_type=None, sparsity_threshold=0.1):
        """
        Args:
            config: BERT config forwarded to ``BertSelfAttention``.
            position_embedding_type: forwarded to ``BertSelfAttention``.
            sparsity_threshold: probabilities must be strictly greater than this
                (or be their row's maximum) to survive pruning. Defaults to 0.1.
        """
        super().__init__(config, position_embedding_type=position_embedding_type)
        self.sparsity_threshold = sparsity_threshold

    def _sparsify(self, attention_probs):
        """Zero probabilities <= ``sparsity_threshold`` (keeping each row's max) and renormalise."""
        row_max = attention_probs.amax(dim=-1, keepdim=True)
        # Keeping the row maximum guarantees at least one surviving entry per row.
        keep = (attention_probs > self.sparsity_threshold) | (attention_probs >= row_max)
        sparse_probs = attention_probs * keep.to(attention_probs.dtype)
        # Epsilon kept as a numerical guard; with the row max retained the sum is > 0.
        return sparse_probs / (sparse_probs.sum(dim=-1, keepdim=True) + 1e-6)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        """Compute sparse scaled dot-product (self- or cross-) attention.

        Args:
            hidden_states: query-side input; presumably (batch, seq, hidden) — TODO confirm.
            attention_mask: additive mask added to the raw scores (replaced by
                ``encoder_attention_mask`` for cross-attention).
            head_mask: optional multiplicative mask applied to the sparse probabilities.
            encoder_hidden_states: if given, keys/values are computed from it.
            encoder_attention_mask: mask used with ``encoder_hidden_states``.
            past_key_value: accepted for signature compatibility; ignored.
            output_attentions: if True, also return the (post-dropout) sparse probabilities.
            **kwargs: extra keyword arguments (e.g. from other transformers versions) are ignored.

        Returns:
            ``(context_layer,)`` or ``(context_layer, sparse_attention_probs)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder side, with its mask.
            key_value_states = encoder_hidden_states
            attention_mask = encoder_attention_mask
        else:
            key_value_states = hidden_states
        key_layer = self.transpose_for_scores(self.key(key_value_states))
        value_layer = self.transpose_for_scores(self.value(key_value_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        sparse_attention_probs = self._sparsify(attention_probs)

        if head_mask is not None:
            sparse_attention_probs = sparse_attention_probs * head_mask

        sparse_attention_probs = self.dropout(sparse_attention_probs)
        context_layer = torch.matmul(sparse_attention_probs, value_layer)

        # Merge heads back: assumes (batch, heads, seq, head_size) from transpose_for_scores.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return (context_layer, sparse_attention_probs) if output_attentions else (context_layer,)

def replace_attention_layers(model):
    """Replace every encoder self-attention module of ``model`` with ``SparseBertSelfAttention``.

    The model is modified in place and also returned for convenience. For each layer:
      * the new module is built from ``model.config`` when the model has one, so fields
        such as ``is_decoder`` / ``max_position_embeddings`` are preserved. Otherwise a
        minimal ``BertConfig`` is reconstructed from the original module;
      * it receives a full copy of the original module's parameters and buffers via
        ``load_state_dict``, which includes any extra weights beyond query/key/value;
      * it is placed on the original's device and dtype, reuses its dropout probability,
        and inherits its train/eval mode. Without this, an eval-mode model would
        silently get dropout in attention.

    Args:
        model: a BERT-style model exposing ``model.encoder.layer[i].attention.self``.

    Returns:
        The same ``model`` object.
    """
    model_config = getattr(model, 'config', None)
    for layer in model.encoder.layer:
        original_attention = layer.attention.self

        if model_config is not None:
            config = model_config
        else:
            # Fallback: minimal config sufficient for BertSelfAttention.
            config = BertConfig(
                hidden_size=original_attention.all_head_size,
                num_attention_heads=original_attention.num_attention_heads,
                attention_probs_dropout_prob=original_attention.dropout.p,
                hidden_dropout_prob=0.1
            )

        # Keep the original position embedding type, if the module has one.
        position_embedding_type = getattr(original_attention, 'position_embedding_type', None)

        sparse_attention = SparseBertSelfAttention(config, position_embedding_type)
        reference_weight = original_attention.query.weight
        sparse_attention.to(device=reference_weight.device, dtype=reference_weight.dtype)
        # Copies query/key/value and any additional parameters/buffers (strict key match).
        sparse_attention.load_state_dict(original_attention.state_dict())
        sparse_attention.dropout.p = original_attention.dropout.p
        sparse_attention.train(original_attention.training)

        layer.attention.self = sparse_attention

    return model


In [None]:
# Runs of characters NOT in this whitelist are replaced by one space in clean_text():
# CJK unified ideographs, CJK symbols/punctuation, full-width forms, curly and ASCII
# quotes, and a few explicitly listed Chinese punctuation marks.
_CLEAN_TEXT_DISALLOWED = re.compile(
    r"[^\u4e00-\u9fff\u3000-\u303f\uff00-\uff65"
    r"\u201c\u201d\u2018\u2019\"'"
    r"。，！？；：（）《》、]+"
)

# Full-width punctuation and curly quotes mapped to their ASCII (half-width) forms.
_PUNCT_TO_ASCII = str.maketrans({
    '：': ':', '；': ';', '！': '!', '？': '?', '（': '(', '）': ')',
    '\u201c': '"', '\u201d': '"', '\u2018': "'", '\u2019': "'",
})


def clean_text(text, min_length=50):
    """Normalise one raw corpus line; return None if the result is too short.

    Steps:
      1. collapse whitespace runs to single spaces and strip the ends;
      2. replace every run of characters outside the whitelist (CJK ideographs
         U+4E00-U+9FFF, CJK symbols U+3000-U+303F, full-width forms U+FF00-U+FF65,
         curly/ASCII quotes, listed Chinese punctuation) with a single space. ASCII
         letters, digits and non-quote ASCII punctuation are therefore removed;
      3. convert full-width ：；！？（） and curly quotes to ASCII;
      4. collapse whitespace again and strip.

    Args:
        text: raw input line.
        min_length: minimum length in characters (spaces included) of the cleaned
            text. Shorter results are rejected. Defaults to 50.

    Returns:
        The cleaned string, or None if it is shorter than ``min_length``.
    """
    text = ' '.join(text.split())
    text = _CLEAN_TEXT_DISALLOWED.sub(' ', text)
    text = text.translate(_PUNCT_TO_ASCII)
    # Disallowed characters at either end became spaces, so strip again.
    text = re.sub(r'\s+', ' ', text).strip()

    if len(text) < min_length:
        return None
    return text

def load_corpus(max_texts=100000, path='copus.txt', min_length=50, min_chinese_ratio=0.7, seed=None):
    """Read, clean, split and filter the pre-training corpus.

    Each line of ``path`` goes through ``clean_text``. Surviving lines are split into
    segments on 。！？!? and each segment is stripped. A segment is kept when it is:
      * non-empty and at least ``min_length`` characters long,
      * not a duplicate of an earlier segment, and
      * made of at least ``min_chinese_ratio`` CJK ideographs (U+4E00-U+9FFF).
    Deduplication and filtering happen while reading, and reading stops once
    ``max_texts`` segments are kept. The result is therefore only shorter than
    ``max_texts`` when the file runs out, rather than being cut down after the cap.

    Args:
        max_texts: maximum number of segments to return.
        path: UTF-8 corpus file, one document per line. The default 'copus.txt'
            matches the original file name.
        min_length: minimum segment length in characters.
        min_chinese_ratio: minimum fraction of CJK ideographs per segment.
        seed: if given, shuffle with ``random.Random(seed)`` for reproducibility.
            Otherwise use the global ``random`` state.

    Returns:
        list[str]: the kept segments in shuffled order.
    """
    import random

    if max_texts <= 0:
        return []

    texts = []
    seen = set()
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            cleaned_text = clean_text(line)
            if not cleaned_text:
                continue
            for segment in re.split(r'[。！？!?]+', cleaned_text):
                segment = segment.strip()
                if not segment or len(segment) < min_length or segment in seen:
                    continue
                seen.add(segment)

                chinese_char_count = len(re.findall(r'[\u4e00-\u9fff]', segment))
                if chinese_char_count / len(segment) < min_chinese_ratio:
                    continue

                texts.append(segment)
                if len(texts) >= max_texts:
                    break
            if len(texts) >= max_texts:
                break

    if seed is None:
        random.shuffle(texts)
    else:
        random.Random(seed).shuffle(texts)
    return texts

In [4]:
class MLMDataset(Dataset):
    """Map-style dataset that tokenises texts and applies BERT-style random masking on access.

    Each ``__getitem__`` call tokenises ``texts[idx]``, truncated and padded to
    ``max_length``, and draws a fresh mask from torch's global RNG:
      * each non-special token is selected with probability ``mlm_probability``;
      * about 80% of selected tokens become the mask token;
      * about 10% become a uniformly random vocabulary id;
      * about 10% are left unchanged.
    ``labels`` holds the original ids at selected positions and -100 elsewhere, so
    those positions are ignored by the loss.
    """

    def __init__(self, texts, tokenizer, max_length=128, mlm_probability=0.15):
        """
        Args:
            texts: sequence of strings.
            tokenizer: Hugging Face-style tokenizer, called with
                ``return_special_tokens_mask=True``.
            max_length: fixed sequence length after truncation/padding.
            mlm_probability: selection probability per non-special token (default 0.15).

        Raises:
            ValueError: if ``mlm_probability`` is outside [0, 1].
        """
        if not 0.0 <= mlm_probability <= 1.0:
            raise ValueError(f"mlm_probability must be in [0, 1], got {mlm_probability}")
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return a dict of ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        inputs = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_special_tokens_mask=True
        )

        input_ids = torch.tensor(inputs['input_ids'])
        attention_mask = torch.tensor(inputs['attention_mask'])
        special_tokens_mask = torch.tensor(inputs['special_tokens_mask'])

        # Select tokens for prediction; special tokens are never selected.
        probability_matrix = torch.full(input_ids.shape, self.mlm_probability)
        probability_matrix.masked_fill_(special_tokens_mask.bool(), 0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels = input_ids.clone()

        # 80% of selected tokens -> mask token.
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Half of the remaining 20% (i.e. 10%) -> random token id.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # The last 10% stay unchanged; unselected positions are ignored by the loss.
        labels[~masked_indices] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

In [None]:
def train_sparse_attention(model_name="huawei-noah/TinyBERT_4L_zh", num_epochs=5, batch_size=32,
                           lr=2e-4, max_texts=100000):
    """Train only the self-attention projections of a sparse-attention BERT with an MLM loss.

    Steps:
      1. load the config, tokenizer, a standalone ``BertModel`` and an MLM model for
         ``model_name``;
      2. convert the ``BertModel``'s self-attention layers with
         ``replace_attention_layers`` and install it as ``mlm_model.bert``;
      3. freeze every parameter whose name does not contain ``'attention.self'``;
      4. run ``num_epochs`` epochs of MLM training on ``load_corpus()``. Training uses
         AdamW, a linear schedule with warmup over the first 10% of steps, and
         gradient-norm clipping at 1.0;
      5. after each epoch with a new lowest average loss, save ``best_pretrain_model.pt``;
      6. unfreeze all parameters and save ``final_pretrained_sparse_bert.pt``.

    Args:
        model_name: Hugging Face hub id or local path.
        num_epochs: number of passes over the corpus (must be >= 1).
        batch_size: DataLoader batch size.
        lr: peak AdamW learning rate (reached after warmup).
        max_texts: forwarded to ``load_corpus``.

    Returns:
        The trained encoder (``mlm_model.bert``) with all parameters unfrozen.

    Raises:
        ValueError: if ``num_epochs < 1`` or the corpus yields no training texts.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = BertConfig.from_pretrained(model_name)
    config.model_type = "bert"
    bert_model = BertModel.from_pretrained(model_name, config=config)
    tokenizer = BertTokenizer.from_pretrained(model_name)
    # NOTE(review): trust_remote_code=True lets the model repository execute arbitrary
    # Python at load time. A standard BERT checkpoint should not need it; consider removing.
    mlm_model = AutoModelForMaskedLM.from_pretrained(model_name, config=config, trust_remote_code=True)
    # Install the sparse-attention version of the standalone BertModel as the MLM encoder.
    # NOTE(review): this likely unties the MLM decoder from the new word embeddings.
    # Harmless while embeddings stay frozen (values are identical); confirm if that changes.
    mlm_model.bert = replace_attention_layers(bert_model)
    mlm_model = mlm_model.to(device)

    # Freeze everything except the self-attention parameters, and collect the trainable
    # ones once for both the optimizer and gradient clipping.
    trainable_params = []
    for name, param in mlm_model.named_parameters():
        if 'attention.self' in name:
            trainable_params.append(param)
        else:
            param.requires_grad = False

    optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    pretrain_texts = load_corpus(max_texts=max_texts)
    if not pretrain_texts:
        raise ValueError("load_corpus() returned no texts; nothing to train on")
    pretrain_dataset = MLMDataset(pretrain_texts, tokenizer)
    pretrain_loader = DataLoader(pretrain_dataset, batch_size=batch_size, shuffle=True)

    num_training_steps = len(pretrain_loader) * num_epochs
    num_warmup_steps = num_training_steps // 10  # first 10% of steps are warmup
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    mlm_model.train()

    best_loss = float('inf')
    avg_loss = float('nan')
    for epoch in range(num_epochs):
        total_loss = 0.0
        progress_bar = tqdm(pretrain_loader, desc=f'Pretrain Epoch {epoch + 1}')

        for batch in progress_bar:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = mlm_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            optimizer.step()
            scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })

        avg_loss = total_loss / len(pretrain_loader)
        print(f'Epoch {epoch + 1}:')
        print(f'Average pretrain loss: {avg_loss:.4f}')
        print(f'Learning rate: {scheduler.get_last_lr()[0]:.2e}')

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'mlm_model_state_dict': mlm_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': best_loss,
                'config': config,
            }, 'best_pretrain_model.pt')
            print(f'New best model saved with loss: {best_loss:.4f}')

    # Unfreeze everything so the returned encoder can be fine-tuned end to end.
    for param in mlm_model.parameters():
        param.requires_grad = True

    torch.save({
        'bert_state_dict': mlm_model.bert.state_dict(),
        'config': config,
        'final_loss': avg_loss,
        'scheduler_state_dict': scheduler.state_dict(),
    }, 'final_pretrained_sparse_bert.pt')
    print('Final pretrained model saved')

    return mlm_model.bert

In [None]:
# Runs the full (slow) pre-training. Binding the result keeps the trained encoder
# available to later cells and stops the notebook from rendering the model's repr.
sparse_bert = train_sparse_attention()