In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt

In [2]:
# 2. Define the RNN model
class TextRNN(nn.Module):
    """LSTM text classifier: embedding -> (stacked) LSTM -> dropout -> linear.

    Args:
        vocab_size: number of rows in the embedding table; index 0 is the
            padding token (its embedding is kept at zero).
        embed_dim: size of each token embedding vector.
        hidden_dim: size of the LSTM hidden state.
        output_dim: number of output logits (classes).
        n_layers: number of stacked LSTM layers.
        dropout: dropout probability applied to the final hidden state; it is
            also used between LSTM layers when ``n_layers > 1``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, n_layers = 1, dropout = 0.5):
        super().__init__()

        # nn.LSTM applies its own dropout only *between* stacked layers, so it
        # is meaningless (and PyTorch warns) for a single-layer LSTM.
        between_layer_dropout = dropout if n_layers > 1 else 0

        # NOTE: these attribute names are the state_dict keys of saved
        # checkpoints; renaming them breaks load_state_dict.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=n_layers,
            batch_first=True,  # inputs/outputs are [batch, seq, feature]
            dropout=between_layer_dropout,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Compute class logits for a batch of token-index sequences.

        Args:
            text: LongTensor of token indices, shape [batch_size, seq_len].

        Returns:
            Logits of shape [batch_size, output_dim].
        """
        # nn.LSTM returns (output, (h_n, c_n)); only h_n is needed here.
        # h_n: [n_layers, batch_size, hidden_dim].
        # NOTE(review): sequences are right-padded and not packed, so h_n is
        # the state after the pad tokens too -- consider pack_padded_sequence
        # if the model is ever retrained.
        _, (h_n, _) = self.rnn(self.embedding(text))

        # Final hidden state of the top LSTM layer: [batch_size, hidden_dim].
        top_layer_state = h_n[-1]
        return self.fc(self.dropout(top_layer_state))

In [4]:
# Example prediction
def predict(text, model, vocab, label_to_idx, idx_to_label, max_seq_len, cat_dict, device,
            unk_idx=1, pad_idx=0):
    """Classify one piece of text with a trained character-level model.

    The text is split into characters and each character is mapped to a
    vocabulary index. The sequence is then truncated or right-padded to
    exactly ``max_seq_len`` tokens and fed to ``model`` as a batch of one.

    Args:
        text: input string; every character is treated as one token.
        model: classifier returning logits of shape [1, num_classes].
        vocab: mapping from character to token index.
        label_to_idx: unused; kept only so existing call sites keep working.
        idx_to_label: maps the predicted class index to a label.
        max_seq_len: fixed sequence length the model expects.
        cat_dict: mapping from display name to label; the predicted label is
            reverse-looked-up here to produce the returned name.
        device: torch device the input tensor is moved to (should match the
            model's device).
        unk_idx: index used for characters not found in ``vocab``. Defaults to
            1, the value previously hard-coded; it must match training.
        pad_idx: index used for right-padding. Defaults to 0, matching
            TextRNN's embedding ``padding_idx``.

    Returns:
        ``(pred_name, probabilities)``:
        - ``pred_name`` is the first ``cat_dict`` key whose value equals the
          predicted label, or ``None`` if there is no such key.
        - ``probabilities`` is the softmax over classes with the batch
          dimension squeezed out.
    """
    model.eval()
    with torch.no_grad():
        # Character-level tokenisation; out-of-vocabulary characters -> unk_idx.
        indices = [vocab.get(char, unk_idx) for char in text]
        # Truncate to max_seq_len, then right-pad up to it (no-op when full).
        indices = indices[:max_seq_len]
        indices += [pad_idx] * (max_seq_len - len(indices))

        # Add a batch dimension of 1 and move to the requested device.
        batch = torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)

        logits = model(batch)

        # Class with the highest score.
        pred_idx = logits.argmax(dim=1).item()
        pred_label = idx_to_label[pred_idx]

        probabilities = torch.softmax(logits, dim=1).squeeze()
        # cat_dict maps name -> label; take the first name whose label matches.
        pred_name = next((name for name, label in cat_dict.items() if label == pred_label), None)
        return pred_name, probabilities
    
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters. The architecture values below must match the saved
# checkpoint, otherwise load_state_dict fails.
batch_size = 64  # training-only; unused during inference
embed_dim = 300
hidden_dim = 128
n_layers = 2
dropout = 0.5
n_epoch = 10     # training-only; unused during inference
lr = 0.001       # training-only; unused during inference
max_len = 50     # inference uses checkpoint['max_len'] instead

# Sizes fixed by the trained checkpoint: embedding-table rows and class count.
VOCAB_SIZE = 5002
NUM_CLASSES = 15
NUM_QUERIES = 5  # how many titles to classify interactively

# Load the checkpoint once instead of on every loop iteration.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load('text_rnn.pth', map_location=device)

net = TextRNN(
    vocab_size=VOCAB_SIZE,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    output_dim=NUM_CLASSES,
    n_layers=n_layers,
    dropout=dropout,
).to(device)
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

for _ in range(NUM_QUERIES):
    # Read a news title from the user and print the predicted category.
    sample_title = input("请输入新闻标题：")

    predicted_name, probabilities = predict(
        sample_title,
        net,
        checkpoint['vocab'],
        checkpoint['label_to_idx'],
        checkpoint['idx_to_label'],
        checkpoint['max_len'],
        checkpoint['cat_dict'],
        device,
    )
    print(f"\n预测示例:")
    print(f"标题: {sample_title}")
    print(f"预测类别: {predicted_name}\t")
    

请输入新闻标题：世俱杯大冷门！大巴黎不敌美洲冠军

预测示例:
标题: 世俱杯大冷门！大巴黎不敌美洲冠军
预测类别: news_sports	
请输入新闻标题：深圳宝安黄田荔枝进入采摘季 预计总产量500吨

预测示例:
标题: 深圳宝安黄田荔枝进入采摘季 预计总产量500吨
预测类别: news_agriculture	
请输入新闻标题：电影《炽热年华》热映 多维看点诠释女性力量勾勒时代群像

预测示例:
标题: 电影《炽热年华》热映 多维看点诠释女性力量勾勒时代群像
预测类别: news_entertainment	
请输入新闻标题：SpaceX星舰爆炸现场震撼，官方回应：人员安全，周边无影响

预测示例:
标题: SpaceX星舰爆炸现场震撼，官方回应：人员安全，周边无影响
预测类别: news_tech	
请输入新闻标题：以色列称已掌控德黑兰领空，伊朗目前的整体损失有多大？

预测示例:
标题: 以色列称已掌控德黑兰领空，伊朗目前的整体损失有多大？
预测类别: news_military	
