In [1]:
%load_ext autoreload
%autoreload 2

import os
# Make the project root the working directory so the relative paths used in
# later cells (e.g. 'THUCNews/data/vocab.pkl', 'logs/log_3/model.ckpt')
# resolve. The guard makes this cell safe to re-run: without it, a second run
# would try to chdir into 'Chinese-Text-Classification-Pytorch/Chinese-...'
# and raise FileNotFoundError.
PROJECT_DIR = 'Chinese-Text-Classification-Pytorch'
if os.path.basename(os.getcwd()) != PROJECT_DIR:
    os.chdir(PROJECT_DIR)

In [3]:
import torch
import numpy as np
from importlib import import_module
import pickle as pkl

# Model configuration: dataset directory, pretrained-embedding file name and
# model module name (models/<model_name>.py).
dataset = 'THUCNews'
embedding = 'embedding_SougouNews.npz'
model_name = 'TextCNN'

# Dynamically import the model module; it provides both `Config` and `Model`.
x = import_module('models.' + model_name)
config = x.Config(dataset, embedding)

# Choose the inference device ONCE and use it for the model, the checkpoint
# tensors and (in the prediction cell) the input tensors. Previously the model
# went to `config.device` while weights/inputs went to a separately computed
# `device`; if those ever differed, inference failed with a device mismatch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config.device = device

# Load the vocabulary (token -> id mapping).
# NOTE(review): pickle.load can execute arbitrary code — only load a vocab.pkl
# from a trusted source.
vocab_path = os.path.join(dataset, 'data', 'vocab.pkl')
with open(vocab_path, 'rb') as f:
    vocab = pkl.load(f)

# Set n_vocab before constructing the model, since Model may read it
# (presumably to size the embedding layer — confirm in models/TextCNN.py).
config.n_vocab = len(vocab)
model = x.Model(config).to(device)

# Load trained weights; map_location lets a GPU-saved checkpoint load on a
# CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file; on torch >= 1.13 consider
# weights_only=True when the checkpoint is a plain state_dict.
model_path = 'logs/log_3/model.ckpt'
model.load_state_dict(torch.load(model_path, map_location=device))

# Switch to inference mode (disables dropout). As the cell's last expression,
# this also displays the model architecture.
model.eval()


Model(
  (embedding): Embedding(4762, 300)
  (convs): ModuleList(
    (0): Conv2d(1, 256, kernel_size=(2, 300), stride=(1, 1))
    (1): Conv2d(1, 256, kernel_size=(3, 300), stride=(1, 1))
    (2): Conv2d(1, 256, kernel_size=(4, 300), stride=(1, 1))
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=768, out_features=10, bias=True)
)

In [6]:
def build_dataset(config, use_word):
    """Return a tokenizer function for the requested granularity.

    Despite its name, this helper only builds the tokenizer.

    Args:
        config: accepted for signature compatibility; not used.
        use_word: if True, split on single spaces (word-level; presumably the
            input text is already segmented); otherwise every character is a
            token (char-level).

    Returns:
        A callable mapping a string to a list of token strings.
    """
    if not use_word:
        # char-level: each character becomes one token
        return lambda text: list(text)
    # word-level: tokens are separated by a single space
    return lambda text: text.split(' ')

# Special tokens looked up in `vocab` (presumably the same tokens used when
# vocab.pkl was built — confirm against the training utilities). Defined
# before process_title, which reads them as globals at call time.
UNK, PAD = '<UNK>', '<PAD>'


def process_title(title, tokenizer, vocab, pad_size):
    """Convert a single title into a batch-of-one of token ids.

    Args:
        title: raw text to classify.
        tokenizer: callable mapping text to a list of tokens.
        vocab: dict-like token -> id mapping; unknown tokens map to
            ``vocab.get(UNK)``.
        pad_size: target length. If truthy, the tokens are padded with PAD or
            truncated so exactly ``pad_size`` ids are produced; if falsy, the
            tokens are used as-is.

    Returns:
        ``([ids], [seq_len])`` — each wrapped in a list so it can be passed
        directly to ``torch.LongTensor`` as a batch of one. ``seq_len`` is the
        unpadded token count, capped at ``pad_size``.
    """
    tokens = tokenizer(title)
    length = len(tokens)
    if pad_size:
        if length >= pad_size:
            tokens, length = tokens[:pad_size], pad_size
        else:
            tokens.extend([PAD] * (pad_size - length))
    # word -> id, falling back to the UNK id for out-of-vocabulary tokens
    ids = [vocab.get(tok, vocab.get(UNK)) for tok in tokens]
    return [ids], [length]

In [8]:

# Title to classify, tokenized at char level (use_word=False) — presumably the
# granularity the model was trained with; confirm if changing the dataset.
title = "东5环海棠公社230-290平2居准现房98折优惠"
tokenizer = build_dataset(config, use_word=False)

# Tokenize, pad/truncate to config.pad_size and map tokens to vocabulary ids.
title_seq, title_len = process_title(title, tokenizer, vocab, config.pad_size)

# Convert the batch-of-one lists to LongTensors on the same device as the model.
title_seq = torch.LongTensor(title_seq).to(device)
title_len = torch.LongTensor(title_len).to(device)

# Inference: the model takes a (token_ids, seq_len) tuple; the argmax over the
# class dimension is the predicted label index (first max on ties, same as
# torch.max(..., 1)[1]). `.data` is unnecessary inside no_grad.
with torch.no_grad():
    outputs = model((title_seq, title_len))
    predicted = outputs.argmax(dim=1).cpu().numpy()

# Map the label index to its class name: class.txt holds one name per line and
# the line position is used as the label index. The file is closed explicitly
# and decoded as UTF-8 rather than the platform-dependent default encoding.
class_path = os.path.join(dataset, 'data', 'class.txt')
with open(class_path, encoding='utf-8') as class_file:
    class_list = [line.strip() for line in class_file]
predicted_class = class_list[predicted[0]]
print(f'预测类别: {predicted_class}')

预测类别: realty
