In [127]:
import torch 
from torch import nn
from torch.utils.data import DataLoader,Dataset

In [128]:
class CRFLayer(nn.Module):
    def __init__(self, n_tags, n_features):
        super().__init__()
        self.n_tags = n_tags
        self.n_features = n_features
        # 定义模型参数
        self.transitions = nn.Parameter(torch.empty(n_tags, n_tags))
        self.emission_weight = nn.Parameter(torch.empty(n_features, n_tags))
        self.start_transitions = nn.Parameter(torch.empty(n_tags))
        self.end_transitions = nn.Parameter(torch.empty(n_tags))
        self.reset_parameters()
        # 调试信息
        print(f"CRFLayer初始化: n_tags={n_tags}, n_features={n_features}")
        print(f"transitions shape: {self.transitions.shape}")
        
    # 使用（-0.1,0.1）之间的均匀分布初始化参数
    def reset_parameters(self):
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        nn.init.uniform_(self.emission_weight, -0.1, 0.1)
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
    
    # 使用动态规划计算得分
    def compute_score(self, emissions, tags, masks):
        seq_len, batch_size, n_tags = emissions.shape
        # 调试信息
        print(f"compute_score: emissions.shape={emissions.shape}, tags.shape={tags.shape}")
        print(f"transitions.shape={self.transitions.shape}, n_tags={n_tags}")
        # 确保所有标签索引在有效范围内
        clamped_tags = torch.clamp(tags, 0, self.n_tags - 1)
        
        # 处理起始转移分数
        start_indices = torch.clamp(clamped_tags[0], 0, self.start_transitions.size(0) - 1)
        score = self.start_transitions[start_indices] + \
            emissions[0, torch.arange(batch_size), start_indices]
        
        for i in range(1, seq_len):
            # 确保转移矩阵的索引在有效范围内
            prev_tags = torch.clamp(clamped_tags[i-1], 0, self.transitions.size(0) - 1)
            curr_tags = torch.clamp(clamped_tags[i], 0, self.transitions.size(1) - 1)
            
            # 确保批次索引在有效范围内
            batch_indices = torch.clamp(torch.arange(batch_size), 0, emissions.size(1) - 1)
            
            score += self.transitions[prev_tags, curr_tags] * masks[i]
            score += emissions[i, batch_indices, curr_tags] * masks[i]
        
        # 处理结束转移分数
        seq_ends = masks.long().sum(dim=0) - 1
        seq_ends = torch.clamp(seq_ends, 0, seq_len - 1)
        batch_indices = torch.clamp(torch.arange(batch_size), 0, clamped_tags.size(1) - 1)
        last_tags = clamped_tags[seq_ends, batch_indices]
        last_tags = torch.clamp(last_tags, 0, self.end_transitions.size(0) - 1)
        score += self.end_transitions[last_tags]
        return score
    
    # 计算配分函数
    def computer_normalizer(self, emissions, masks):
        seq_len, batch_size, n_tags = emissions.shape
        # 确保起始转移索引在有效范围内
        start_indices = torch.clamp(torch.arange(min(n_tags, self.start_transitions.size(0))), 0, self.start_transitions.size(0) - 1)
        start_trans = self.start_transitions[start_indices]
        score = start_trans + emissions[0]
        
        for i in range(1, seq_len):
            # batch_size * n_tags * 1 [y_{i-1}为某tag的总分]
            broadcast_score = score.unsqueeze(2)
            # batch_size * 1 * n_tags [y_i为某标签的发射分数]
            broadcast_emissions = emissions[i].unsqueeze(1)
            # batch_size * n_tags * n_tags [任意y_{i-1}到y_i的总分]
            next_score = broadcast_score + self.transitions + broadcast_emissions
            # batch_size * n_tags [对y_{i-1}求和]
            next_score = torch.logsumexp(next_score, dim=1)
            # masks为True则更新，否则保留
            score = torch.where(masks[i].unsqueeze(1), next_score, score)
            
        # 确保结束转移索引在有效范围内
        end_indices = torch.clamp(torch.arange(min(n_tags, self.end_transitions.size(0))), 0, self.end_transitions.size(0) - 1)
        end_trans = self.end_transitions[end_indices]
        score += end_trans
        return torch.logsumexp(score, dim=1)
    
    def forward(self, features, tags, masks):
        """
        features: seq_len * batch_size * n_features
        tags/masks: seq_len * batch_size
        """
        _, batch_size, _ = features.size()
        emissions = torch.matmul(features, self.emission_weight)
        masks = masks.to(torch.bool)
        
        # 确保标签索引在模型定义的范围内
        tags = torch.clamp(tags, 0, self.n_tags - 1)
        
        score = self.compute_score(emissions, tags, masks)
        partition = self.computer_normalizer(emissions, masks)
        
        likelihood = score - partition
        return likelihood.sum() / batch_size
    
    def decode(self, features, masks):
        # 与computer_normalizer类似，sum变为max
        emissions = torch.matmul(features, self.emission_weight)
        masks = masks.to(torch.bool)
        
        seq_len, batch_size, n_tags = emissions.shape
        
        # 确保起始转移索引在有效范围内
        start_indices = torch.clamp(torch.arange(min(n_tags, self.start_transitions.size(0))), 0, self.start_transitions.size(0) - 1)
        start_trans = self.start_transitions[start_indices]
        score = start_trans + emissions[0]
        history = []
        
        for i in range(1, seq_len):
            broadcast_score = score.unsqueeze(2)
            broadcast_emission = emissions[i].unsqueeze(1)
            
            next_score = broadcast_score + self.transitions + broadcast_emission
            next_score, indices = next_score.max(dim=1)
            
            score = torch.where(masks[i].unsqueeze(1), next_score, score)
            history.append(indices)
            
        # 确保结束转移索引在有效范围内
        end_indices = torch.clamp(torch.arange(min(n_tags, self.end_transitions.size(0))), 0, self.end_transitions.size(0) - 1)
        end_trans = self.end_transitions[end_indices]
        score += end_trans
        
        seq_ends = masks.long().sum(dim=0) - 1
        best_tags_list = []
        
        for idx in range(batch_size):
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]
            
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())
                
            best_tags.reverse()
            best_tags_list.append(best_tags)
            
        return best_tags_list

In [129]:
class LSTM_CRF(nn.Module):
    def __init__(self,vocab_size,hidden_size,num_layers,dropout,n_tags):
        """
        参数介绍
        vocab_size: 词表大小
        hidden_size: 隐藏层大小
        num_layers: LSTM层数
        dropout: dropout概率
        n_tags: 标签数量
        """
        super(LSTM_CRF, self).__init__()
        #define embedding
        self.embedding=nn.Embedding(vocab_size,hidden_size)
        #define LSTM
        self.lstm=nn.LSTM(hidden_size=hidden_size,num_layers=num_layers,input_size=hidden_size,dropout=dropout,batch_first=False,bidirectional=True)
        self.crf=CRFLayer(n_tags=n_tags,n_features=hidden_size*2)
    def forward(self,input_ids,masks,labels):
        """
        参数解释
        input_ids:输入的id shaperge [batch_size,seq_len]
        masks:输入的mask
        labels:标签
        """
        seq_len=min(input_ids.size(1),masks.size(1),labels.size(1))
        input_ids = input_ids[:, :seq_len]
        masks = masks[:, :seq_len]
        labels = labels[:, :seq_len]
        embed=self.embedding(input_ids)
        embed=torch.transpose(embed,0,1)
        masks=torch.transpose(masks,0,1)
        labels=torch.transpose(labels,0,1)
        hidden_states,_=self.lstm(embed)
        llh=self.crf(hidden_states,labels,masks)
        return llh
    def decode(self,input_ids,masks):
        embed=self.embedding(input_ids)
        embed=torch.transpose(embed,0,1)
        masks=torch.transpose(masks,0,1)
        hidden_states,_=self.lstm(embed)
        return self.crf.decode(hidden_states,masks)

In [130]:
#准备测试数据
class NerDataset(Dataset):
    def __init__(self,sentences,labels,word_to_idx,tag_to_idx):
        self.sentences=sentences
        self.labels=labels
        self.word_to_idx=word_to_idx
        self.tag_to_idx=tag_to_idx
    def __len__(self):
        return len(self.sentences)
    # 在数据预处理时确保标签索引正确
    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        label = self.labels[idx]
        
        # 将单词转换为索引
        sent_ids = [self.word_to_idx.get(word, 0) for word in sentence]
        # 确保标签索引在有效范围内
        label_ids = [self.tag_to_idx.get(tag, 0) for tag in label]
        
        # 创建mask
        mask = [1] * len(sent_ids)
        
        return (torch.tensor(sent_ids, dtype=torch.long), 
                torch.tensor(label_ids, dtype=torch.long), 
                torch.tensor(mask, dtype=torch.long))

In [131]:
# 示例数据
training_data = [
    ("the wall street journal reported today that apple corporation made money".split(),
     "B I I I O O O B I O".split()),
    ("georgia tech is a university in georgia".split(),
     "B I O O O O B".split()),
    ("the cpu speed is very fast".split(),
     "O O O O O O".split()),
    ("microsoft is located in washington".split(),
     "B O O O B".split())
]


In [132]:
#进行简单的数据处理
word_to_idx={'<UNK>':0}
tag_to_idx={'B':0,'I':1,'O':2}
for sentence ,tags in training_data:
    for word in sentence:
        if word not in word_to_idx:
            word_to_idx[word]=len(word_to_idx)
    for tag in tags:
        if tag not in tag_to_idx:
            tag_to_idx[tag]=len(tag_to_idx)


print("词汇表大小:", len(word_to_idx))
print("标签表:", tag_to_idx)
print("标签数量:", len(tag_to_idx)) 

词汇表大小: 25
标签表: {'B': 0, 'I': 1, 'O': 2}
标签数量: 3


In [133]:
#创建数据集
dataset=NerDataset(
    [s[0] for s in training_data],
    [s[1] for s in training_data],
    word_to_idx=word_to_idx,
    tag_to_idx=tag_to_idx
)
batch_size=1
train_data_loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [134]:
#设置模型参数
vocab_size=len(word_to_idx)
hidden_size=128
num_layers=1
dropout=0.1
n_tags=len(tag_to_idx)
print(n_tags)


3


In [135]:
#初始化模型
model=LSTM_CRF(
    vocab_size=vocab_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    n_tags=n_tags
)

#定义优化器
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)


CRFLayer初始化: n_tags=3, n_features=256
transitions shape: torch.Size([3, 3])




In [136]:
#开始训练模型
print("\n开始训练...")
model.train()
for epoch in range(30):
    total_loss = 0
    num_batches = 0
    for batch_idx, batch in enumerate(train_data_loader):
        input_ids, labels, masks = batch
        
        # 确保标签在有效范围内
        labels = torch.clamp(labels, 0, n_tags - 1)
        
        try:
            loss = model(input_ids, masks, labels)
            optimizer.zero_grad()
            (-loss).backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
        except Exception as e:
            print(f"Batch {batch_idx} 出错: {e}")
            continue
    
    if epoch % 10 == 0 and num_batches > 0:
        print(f"Epoch {epoch}, Average Loss: {total_loss/num_batches}")

print("训练完成!")


开始训练...
compute_score: emissions.shape=torch.Size([7, 1, 3]), tags.shape=torch.Size([7, 1])
transitions.shape=torch.Size([3, 3]), n_tags=3
compute_score: emissions.shape=torch.Size([6, 1, 3]), tags.shape=torch.Size([6, 1])
transitions.shape=torch.Size([3, 3]), n_tags=3
compute_score: emissions.shape=torch.Size([10, 1, 3]), tags.shape=torch.Size([10, 1])
transitions.shape=torch.Size([3, 3]), n_tags=3
compute_score: emissions.shape=torch.Size([5, 1, 3]), tags.shape=torch.Size([5, 1])
transitions.shape=torch.Size([3, 3]), n_tags=3
Epoch 0, Average Loss: -7.062752366065979
compute_score: emissions.shape=torch.Size([7, 1, 3]), tags.shape=torch.Size([7, 1])
transitions.shape=torch.Size([3, 3]), n_tags=3
compute_score: emissions.shape=torch.Size([6, 1, 3]), tags.shape=torch.Size([6, 1])
transitions.shape=torch.Size([3, 3]), n_tags=3
compute_score: emissions.shape=torch.Size([10, 1, 3]), tags.shape=torch.Size([10, 1])
transitions.shape=torch.Size([3, 3]), n_tags=3
compute_score: emissions.sha

In [137]:
# 测试模型
model.eval()
test_sentences = [
    "the wall street journal reported today".split(),
    "georgia tech is a university".split()
]

print("\n模型预测结果:")
with torch.no_grad():
    for sentence in test_sentences:
        # 将句子转换为索引
        sent_ids = torch.tensor([[word_to_idx.get(word, 0) for word in sentence]])
        masks = torch.tensor([[1] * len(sentence)])
        
        # 获取预测结果
        predictions = model.decode(sent_ids, masks)
        
        print(f"输入句子: {' '.join(sentence)}")
        print(f"预测标签: {predictions[0]}")
        
        # 将标签索引转换为标签名称
        tag_names = list(tag_to_idx.keys())
        pred_tags = [tag_names[idx] for idx in predictions[0]]
        print(f"标签名称: {pred_tags}")
        print("-" * 50)

# 展示模型参数
print(f"模型参数:")
print(f"词汇表大小: {vocab_size}")
print(f"隐藏层大小: {hidden_size}")
print(f"LSTM层数: {num_layers}")
print(f"标签数量: {n_tags}")
print(f"词汇表: {list(word_to_idx.keys())[:10]}...")  # 显示前10个词汇
print(f"标签表: {list(tag_to_idx.keys())}")


模型预测结果:
输入句子: the wall street journal reported today
预测标签: [0, 1, 1, 1, 2, 2]
标签名称: ['B', 'I', 'I', 'I', 'O', 'O']
--------------------------------------------------
输入句子: georgia tech is a university
预测标签: [0, 1, 2, 2, 2]
标签名称: ['B', 'I', 'O', 'O', 'O']
--------------------------------------------------
模型参数:
词汇表大小: 25
隐藏层大小: 128
LSTM层数: 1
标签数量: 3
词汇表: ['<UNK>', 'the', 'wall', 'street', 'journal', 'reported', 'today', 'that', 'apple', 'corporation']...
标签表: ['B', 'I', 'O']
