In [70]:
import numpy as np
import torch
import tensorflow.contrib.keras as kr
import torch.utils.data as Data

def get_labels():
    """
    获取所有类别标签，并转换成id字典形式
    """
    labels = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
    labels_dict = dict(zip(labels, range(len(labels))))
    return labels, labels_dict

def get_vocab(filename):
    """
    获取词表，并转换成ID字典形式
    """
    with open(filename, encoding='utf-8', errors='ignore') as file:
        word = [line.strip() for line in file.readlines()]
        word_dict = dict(zip(word, range(len(word))))
        
    return word, word_dict

def process_file(filename, word_dict, label_dict, maxlen):
    """
    从文件中获取数据并转换成id形式
    """
    context, labels = [],[]
    with open(filename, encoding='utf-8', errors='ignore') as file:
        for line in file:
            try:
                label, content = line.strip().split('\t')
                if content:
                    context.append(list(content))
                    labels.append(label)
            except:
                pass

    data_id, label_id = [], []
    for i in range(len(context)):
        data_id.append([word_dict[x] for x in context[i] if x in word_dict])#将每句话id化
        label_id.append(label_dict[labels[i]])#每句话对应的类别的id

    # # 使用keras提供的pad_sequences来将文本pad为固定长度
    x_pad = kr.preprocessing.sequence.pad_sequences(data_id, maxlen)
    y_pad = kr.utils.to_categorical(label_id, num_classes=len(label_dict))  # 将标签转换为one-hot表示
    return x_pad, y_pad

label, label_dict = get_labels()
word, word_dict = get_vocab('cnews.vocab.txt')

x_train, y_train = process_file('cnews.train.txt', word_dict, label_dict, 600)
x_val, y_val = process_file('cnews.val.txt', word_dict, label_dict, 600)
x_test, y_test = process_file('cnews.test.txt', word_dict, label_dict, 600)

BATCH_SIZE = 128

x_train, y_train = torch.LongTensor(x_train), torch.Tensor(y_train)
x_val, y_val = torch.LongTensor(x_val), torch.Tensor(y_val)

# train_dataset = Data.TensorDataset(x_train, y_train) 
# train_dataloader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# val_dataset = Data.TensorDataset(x_val, y_val)
# val_dataloader = Data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)


def batch_iter(x, y, batch_size=64):
    """生成批次数据"""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1
 
    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]
 
    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]


In [72]:
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F

class TextRNN(nn.Module):
    def __init__(self):
        super(TextRNN, self).__init__()
        self.embedding = nn.Embedding(5000, 64)
        self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=1, bidirectional=False)
        self.f1 = nn.Sequential(nn.Linear(128*600, 128,bias=True),
                                
                                nn.ReLU(),
                                nn.Dropout(0.8)
                                )
        self.f2 = nn.Sequential(nn.Linear(128, 10,bias=True))
        
    def forward(self, x):
        x = self.embedding(x)
        
        
        x,_ = self.rnn(x)
       
        x=x.view(-1,128*600)
        
        
#         x=x[:,-1,:]
        x = self.f1(x)
        x = self.f2(x)
        
        return x
    
def train():
    model = TextRNN().cuda()
    Loss = nn.CrossEntropyLoss()#nn.MultiLabelSoftMarginLoss()
    optimer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimer,
                                                           mode='max', 
                                                           patience=5,
                                                           verbose=True,
                                                           min_lr=1.e-6)
    best_val_acc = 0
    
    
    total_iter=0
    for epoch in range(100):
        
        model.train()
        for step,(batch_x, batch_y) in enumerate(train_dataloader):
            
            
            
            total_iter+=1
            x = batch_x.cuda()
            y = batch_y.cuda().long()
            y_label= torch.argmax(y,1)
            
            out = model(x)
            loss = Loss(out, y_label)
            optimer.zero_grad()
            loss.backward()
            optimer.step()
            
            preds=torch.softmax(out,dim=1)
            acc_train = np.mean((torch.argmax(preds,1) == torch.argmax(y,1)).cpu().numpy())
            
            if step%100==0:
               
                print("epoch:{},step:{},acc_train:{},loss_train:{}".format(epoch,total_iter,acc_train,loss))
            
        acc_val_val=0
        
        
        model.eval()
        for step,(batch_x, batch_y) in enumerate(val_dataloader):
            
            x = batch_x.cuda()
            y = batch_y.cuda().long()
            y_label= torch.argmax(y,1)
            out = model(x)
            
            preds=torch.softmax(out,dim=1)
            acc_val = np.mean((torch.argmax(preds,1) == torch.argmax(y,1)).cpu().numpy())
            acc_val_val+=acc_val
            if acc_val > best_val_acc:
                torch.save(model.state_dict(), './model_params.pkl')
                best_val_acc = acc_val
        acc_val_=acc_val_val/(step+1)
        print("epoch:{},step:{},acc_val:{}".format(epoch,total_iter,acc_val_))
        scheduler.step(acc_val_)
train()





### may reach 90+

epoch:0,step:1,acc_train:0.0,loss_train:2.317220687866211
epoch:0,step:101,acc_train:0.34375,loss_train:1.9952791929244995
epoch:0,step:201,acc_train:0.3125,loss_train:1.7078511714935303
epoch:0,step:301,acc_train:0.46875,loss_train:1.44394052028656
epoch:0,step:401,acc_train:0.71875,loss_train:0.8115702271461487
epoch:0,step:501,acc_train:0.65625,loss_train:0.783383309841156
epoch:0,step:601,acc_train:0.8125,loss_train:0.6549707055091858
epoch:0,step:701,acc_train:0.9375,loss_train:0.36371129751205444
epoch:0,step:801,acc_train:0.875,loss_train:0.4315113127231598
epoch:0,step:901,acc_train:0.8125,loss_train:0.6004424691200256
epoch:0,step:1001,acc_train:0.75,loss_train:0.698325514793396
epoch:0,step:1101,acc_train:0.84375,loss_train:0.6217854022979736
epoch:0,step:1201,acc_train:0.65625,loss_train:0.9829463362693787
epoch:0,step:1301,acc_train:0.78125,loss_train:0.5777592062950134
epoch:0,step:1401,acc_train:0.875,loss_train:0.23794180154800415
epoch:0,step:1501,acc_train:0.90625,loss

epoch:7,step:11842,acc_train:0.9375,loss_train:0.1495477259159088
epoch:7,step:11942,acc_train:0.96875,loss_train:0.11055899411439896
epoch:7,step:12042,acc_train:0.96875,loss_train:0.07442628592252731
epoch:7,step:12142,acc_train:0.9375,loss_train:0.14182725548744202
epoch:7,step:12242,acc_train:0.9375,loss_train:0.0753379538655281
epoch:7,step:12342,acc_train:0.9375,loss_train:0.19853970408439636
epoch:7,step:12442,acc_train:1.0,loss_train:0.028598982840776443
epoch:7,step:12504,acc_val:0.8960987261146497
epoch:8,step:12505,acc_train:0.96875,loss_train:0.10463523864746094
epoch:8,step:12605,acc_train:0.90625,loss_train:0.581418514251709
epoch:8,step:12705,acc_train:1.0,loss_train:0.053128331899642944
epoch:8,step:12805,acc_train:0.9375,loss_train:0.09758573770523071
epoch:8,step:12905,acc_train:1.0,loss_train:0.022051610052585602
epoch:8,step:13005,acc_train:0.96875,loss_train:0.04223306477069855
epoch:8,step:13105,acc_train:0.9375,loss_train:0.13847634196281433
epoch:8,step:13205,ac

KeyboardInterrupt: 

In [None]:
# 测试

    
state_dict = torch.load('model_params.pkl')
model = TextRNN()
model.load_state_dict(state_dict)

test_demo = ['《时光重返四十二难》恶搞唐增取经一款时下最热门的动画人物：猪猪侠，加上创新的故事背景，震撼的操作快感，成就了这部恶搞新作，现正恶搞上市，玩家们抢先赶快体验快感吧。游戏简介：被时光隧道传送到208年的猪猪侠，必须经历六七四十二难的考验，才能借助柯伊诺尔大钻石的力量，开启时光隧道，重返2008年。在迷糊老师、菲菲公主的帮助下，猪猪侠接受了挑战，开始了这段充满了关心和情谊的旅程。    更多精彩震撼感觉，立即下载该款游戏尽情体验吧。玩家交流才是王道，讯易游戏玩家交流中心 QQ群：6306852-----------------生活要有激情，游戏要玩多彩(多彩游戏)。Colourfulgame (多彩游戏)，让你看看快乐游戏的颜色！精品推荐：1：《钟馗传》大战无头关羽，悲壮的剧情伴随各朝英灵反攻地府！2：《中华群英》将和赵云，项羽，岳飞等猛将作战，穿越各朝代抗击日寇。良品推荐：1：《赌王争霸之斗地主》易飞会在四角恋中会选择谁？是否最终成赌神呢？2：勇者后裔和魔王紧缠一起，前代恩怨《圣火伏魔录》将为您揭示一切。  3：颠覆传统概念，恶搞+非主流？！誓必弄死搞残为止《爆笑飞行棋》。4：《中国象棋残局大师》快棋和人机模式让畅快对弈！一切“多彩游戏”资讯，点击Colourfulgame官网http://www.colourfulgame.com一切“多彩游戏”感言，交流Colourfulgame论坛http://121.33.203.124/forum/【客服邮箱】：xunyiwangluo@126.com">xunyiwangluo@126.com">xunyiwangluo@126.com【客服热线】：020-87588437']

#for i in test_demo:
    #print(i,":",model.predict(i))
help(model.load_state_dict)

In [None]:
# 测试
import torch
import torch.nn as nn

class TextRNN(nn.Module):
    def __init__(self):
        super(TextRNN, self).__init__()
        self.embedding = nn.Embedding(5000, 64)
        self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=1, bidirectional=False)
#         self.f1 = nn.Sequential(nn.Linear(128, 128),
#                                 nn.Dropout(0.8),
#                                 nn.ReLU())
        self.f2 = nn.Sequential(nn.Linear(128, 10),
                                nn.Softmax())
        
    def forward(self, x):
        x = self.embedding(x)
        x,_ = self.rnn(x)
        x = F.dropout(x,p=0.8)
        x = self.f2(x[:,-1,:])
        #x = self.f2(x)
        return x
    
    

model = TextRNN()
torch.save(model.state_dict(), './model_params.pkl')
    

state_dict = torch.load('model_params.pkl')
print(state_dict)

model_new = TextRNN()
model_new.load_state_dict(state_dict)

# test_demo = ['《时光重返四十二难》恶搞唐增取经一款时下最热门的动画人物：猪猪侠，加上创新的故事背景，震撼的操作快感，成就了这部恶搞新作，现正恶搞上市，玩家们抢先赶快体验快感吧。游戏简介：被时光隧道传送到208年的猪猪侠，必须经历六七四十二难的考验，才能借助柯伊诺尔大钻石的力量，开启时光隧道，重返2008年。在迷糊老师、菲菲公主的帮助下，猪猪侠接受了挑战，开始了这段充满了关心和情谊的旅程。    更多精彩震撼感觉，立即下载该款游戏尽情体验吧。玩家交流才是王道，讯易游戏玩家交流中心 QQ群：6306852-----------------生活要有激情，游戏要玩多彩(多彩游戏)。Colourfulgame (多彩游戏)，让你看看快乐游戏的颜色！精品推荐：1：《钟馗传》大战无头关羽，悲壮的剧情伴随各朝英灵反攻地府！2：《中华群英》将和赵云，项羽，岳飞等猛将作战，穿越各朝代抗击日寇。良品推荐：1：《赌王争霸之斗地主》易飞会在四角恋中会选择谁？是否最终成赌神呢？2：勇者后裔和魔王紧缠一起，前代恩怨《圣火伏魔录》将为您揭示一切。  3：颠覆传统概念，恶搞+非主流？！誓必弄死搞残为止《爆笑飞行棋》。4：《中国象棋残局大师》快棋和人机模式让畅快对弈！一切“多彩游戏”资讯，点击Colourfulgame官网http://www.colourfulgame.com一切“多彩游戏”感言，交流Colourfulgame论坛http://121.33.203.124/forum/【客服邮箱】：xunyiwangluo@126.com">xunyiwangluo@126.com">xunyiwangluo@126.com【客服热线】：020-87588437']

# #for i in test_demo:
#     #print(i,":",model.predict(i))
# help(model.load_state_dict)


# torch.save(model.state_dict, './model_params.pkl')