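### Implementation workflow:

#### 1. Read the raw dataset (text corpus)

#### 2. Text preprocessing
* **2.1 Remove punctuation**
* **2.2 Split on the newline character \n**
* **2.3 Build the word --> index dictionary**
* **2.4 Convert labels to 1 / 0**
* **2.5 Remove samples that are too short or too long**
* **2.6 Map words to integers**
* **2.7 Set a uniform text length and pad or truncate every review to it**

#### 3. Feature engineering
* **3.1 array --> tensor**
* **3.2 Split the dataset into train, val and test sets in the ratio 0.8 : 0.1 : 0.1**
* **3.3 Batch the data with DataLoader**

#### 4. Define the network architecture

#### 5. Define the hyperparameters

#### 6. Define the training function (training + validation)

#### 7. Define the test function

#### 8. Define the prediction function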

----------------------------

* **[Bilibili account: 唐国梁Tommy]** <https://space.bilibili.com/474347248/channel/index>
* **For the code and dataset download, see the profile on my Bilibili page**

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

### 1. Load the text and label data

In [2]:
# read the review text
with open("data/reviews.txt", 'r') as file:
    text = file.read()

In [3]:
len(text) # 33,678,267 characters in total

33678267

In [4]:
type(text) # type

str

In [5]:
text[:10] # first 10 characters

'bromwell h'

In [6]:
# read the label data
with open('data/labels.txt', 'r') as file:
    labels = file.read()

In [7]:
len(labels) # 225,000 characters in total

225000

In [8]:
type(labels) # type

str

In [9]:
labels[:10] # first 10 characters

'positive\nn'

### 2. Data EDA

In [10]:
# 2.1 Remove punctuation
from string import punctuation

print("Punctuation: ", punctuation)

Punctuation:  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~


In [11]:
clean_text = ''.join([char for char in text if char not in punctuation]) # iterate over every character and skip punctuation

In [12]:
len(clean_text) # number of characters after cleaning

33351075
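The character-by-character join above tests each of the roughly 33.7 million characters against the punctuation string in Python code. A common, faster alternative is `str.translate` with a deletion table built by `str.maketrans`; the sketch below checks on a made-up sentence that both approaches give the same result.

```python
from string import punctuation

# str.maketrans with a third argument builds a table that deletes those characters
strip_table = str.maketrans("", "", punctuation)

sample = "I didn't like it... at all!"   # made-up review text
fast = sample.translate(strip_table)
slow = "".join(ch for ch in sample if ch not in punctuation)  # the notebook's approach

print(fast)          # I didnt like it at all
print(fast == slow)  # True
```

In the notebook this would read `clean_text = text.translate(strip_table)`.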

In [13]:
# 2.2 Split on the newline character \n
clean_text = clean_text.split('\n')

In [14]:
len(clean_text)

25001
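The split yields 25001 pieces for 25,000 reviews because the file ends with a newline: splitting on `'\n'` leaves an empty string after the last separator. That empty entry is the length-0 review found at index 25000 in step 2.5. A toy illustration with an invented string:

```python
raw = "first review\nsecond review\n"  # made-up file content that ends with a newline

pieces = raw.split("\n")   # keeps the empty string after the final '\n'
lines = raw.splitlines()   # does not produce a trailing empty element

print(pieces)  # ['first review', 'second review', '']
print(lines)   # ['first review', 'second review']
```

`splitlines()` would avoid the empty sample up front, provided the reviews contain no other line-boundary characters (such as `\r`) that it also splits on; the notebook instead removes the empty sample in step 2.5.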

In [15]:
clean_text[0]

'bromwell high is a cartoon comedy  it ran at the same time as some other programs about school life  such as  teachers   my   years in the teaching profession lead me to believe that bromwell high  s satire is much closer to reality than is  teachers   the scramble to survive financially  the insightful students who can see right through their pathetic teachers  pomp  the pettiness of the whole situation  all remind me of the schools i knew and their students  when i saw the episode in which a student repeatedly tried to burn down the school  i immediately recalled          at           high  a classic line inspector i  m here to sack one of your teachers  student welcome to bromwell high  i expect that many adults of my age think that bromwell high is far fetched  what a pity that it isn  t   '

In [16]:
# split the labels on \n as well
labels = labels.split('\n')

len(labels)

25001

In [17]:
labels[:5]

['positive', 'negative', 'positive', 'negative', 'positive']

In [18]:
# 2.3 Dictionary: word --> index

# collect every word from all reviews
words = [word.lower() for sentence in clean_text for word in sentence.split(' ')]

In [19]:
words[:10] # 显示前10个单词

['bromwell', 'high', 'is', 'a', 'cartoon', 'comedy', '', 'it', 'ran', 'at']

In [20]:
various_words = list(set(words)) # unique words across all reviews (set order varies between Python runs, so the indices do too)

In [21]:
various_words.remove('') # remove the empty string produced by split(' ') on consecutive spaces

In [22]:
len(various_words) # number of unique words

74072

In [23]:
# create a dictionary of the form integer : word (indices start at 1, leaving 0 free for padding)

int_word = dict(enumerate(various_words, 1))

In [24]:
int_word

{1: 'jgar',
 2: 'centrist',
 3: 'piteously',
 4: 'talmadges',
 5: 'confidential',
 6: 'flickerino',
 7: 'kleine',
 8: 'narrations',
 9: 'woodlanders',
 10: 'personal',
 11: 'hardback',
 12: 'verbatim',
 13: 'disputed',
 14: 'indiania',
 15: 'ney',
 16: 'relatives',
 17: 'tintin',
 18: 'dollying',
 19: 'fuher',
 20: 'diplomat',
 21: 'cork',
 22: 'screamer',
 23: 'basterds',
 24: 'retrieval',
 25: 'colours',
 26: 'ssst',
 27: 'emphasise',
 28: 'humdinger',
 29: 'pleaseee',
 30: 'adherents',
 31: 'worf',
 32: 'altro',
 33: 'uglying',
 34: 'nuovo',
 35: 'deepak',
 36: 'passions',
 37: 'alkie',
 38: 'notle',
 39: 'revisitation',
 40: 'abuses',
 41: 'inquisitive',
 42: 'guarded',
 43: 'pissing',
 44: 'tmtm',
 45: 'abovementioned',
 46: 'feroze',
 47: 'lair',
 48: 'garrard',
 49: 'consort',
 50: 'unquiet',
 51: 'jiggly',
 52: 'lockstock',
 53: 'ktla',
 54: 'mcconnell',
 55: 'yeon',
 56: 'maryln',
 57: 'emu',
 58: 'trim',
 59: 'kombat',
 60: 'mathematical',
 61: 'dogmas',
 62: 'glassed',
 ...

In [25]:
# dictionary of the form word : integer
word_int = {w: i for i, w in int_word.items()}

In [26]:
word_int

{'jgar': 1,
 'centrist': 2,
 'piteously': 3,
 'talmadges': 4,
 'confidential': 5,
 'flickerino': 6,
 'kleine': 7,
 'narrations': 8,
 'woodlanders': 9,
 'personal': 10,
 'hardback': 11,
 'verbatim': 12,
 'disputed': 13,
 'indiania': 14,
 'ney': 15,
 'relatives': 16,
 'tintin': 17,
 'dollying': 18,
 'fuher': 19,
 'diplomat': 20,
 'cork': 21,
 'screamer': 22,
 'basterds': 23,
 'retrieval': 24,
 'colours': 25,
 'ssst': 26,
 'emphasise': 27,
 'humdinger': 28,
 'pleaseee': 29,
 'adherents': 30,
 'worf': 31,
 'altro': 32,
 'uglying': 33,
 'nuovo': 34,
 'deepak': 35,
 'passions': 36,
 'alkie': 37,
 'notle': 38,
 'revisitation': 39,
 'abuses': 40,
 'inquisitive': 41,
 'guarded': 42,
 'pissing': 43,
 'tmtm': 44,
 'abovementioned': 45,
 'feroze': 46,
 'lair': 47,
 'garrard': 48,
 'consort': 49,
 'unquiet': 50,
 'jiggly': 51,
 'lockstock': 52,
 'ktla': 53,
 'mcconnell': 54,
 'yeon': 55,
 'maryln': 56,
 'emu': 57,
 'trim': 58,
 'kombat': 59,
 'mathematical': 60,
 'dogmas': 61,
 'glassed': 62,
 ...
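Because `set` iteration order for strings changes between Python runs, the indices above are not reproducible: a model saved in one session will not match a vocabulary rebuilt in another. A minimal sketch of a deterministic alternative, assuming frequency ordering is acceptable (`build_vocab` is a hypothetical helper, not part of the notebook):

```python
from collections import Counter

def build_vocab(tokenized_texts):
    """Map each word to an index >= 1 (0 stays free for padding), most frequent first.
    Ties are broken alphabetically, so the mapping is identical on every run."""
    counts = Counter(w for tokens in tokenized_texts for w in tokens)
    ordered = sorted(counts, key=lambda w: (-counts[w], w))
    return {w: i for i, w in enumerate(ordered, start=1)}

vocab = build_vocab([["the", "movie", "was", "great"], ["the", "plot", "was", "thin"]])
print(vocab)  # {'the': 1, 'was': 2, 'great': 3, 'movie': 4, 'plot': 5, 'thin': 6}
```

Saving the resulting dictionary next to the model weights (for example with `json`) keeps training and prediction consistent.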

In [27]:
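# 2.4 Convert labels to 1 / 0
# positive : 1,  negative : 0
# (the empty string after the file's trailing newline also becomes 0, hence 12501 zeros below)
label_int = np.array([1 if x == 'positive' else 0 for x in labels])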

In [28]:
len(label_int)

25001

In [29]:
from collections import Counter

Counter(label_int)

Counter({1: 12500, 0: 12501})

In [30]:
# 2.5 Remove samples that are too short or too long

# length (number of words) of each review
sentence_length = [len(sentence.split()) for sentence in clean_text]

In [31]:
counts = Counter(sentence_length) # number of reviews with each length

In [32]:
# shortest review length, as a (length, count) tuple
min_sen = min(counts.items())

In [33]:
min_sen

(0, 1)

In [34]:
# longest review length, as a (length, count) tuple
max_sen = max(counts.items())

In [35]:
max_sen

(2514, 1)
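The extremes (0 words and 2514 words) say little about where to cut. Before fixing the sequence length in step 2.7, it can help to check what fraction of reviews fit within a candidate length. A small sketch with made-up lengths; in the notebook, `sentence_length` would be passed instead (`coverage` is a hypothetical helper):

```python
import numpy as np

def coverage(lengths, seq_len):
    """Fraction of reviews whose word count is at most seq_len (i.e. not truncated)."""
    lengths = np.asarray(lengths)
    return float(np.mean(lengths <= seq_len))

lengths = [0, 12, 150, 230, 2514]    # made-up review lengths
print(coverage(lengths, 200))        # 0.6
print(np.percentile(lengths, 50))    # 150.0, the median length
```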

In [36]:
# indices of the reviews with the minimum and maximum length

min_index = [i for i, length in enumerate(sentence_length) if length == min_sen[0]]

max_index = [i for i, length in enumerate(sentence_length) if length == max_sen[0]]

In [37]:
min_index

[25000]

In [38]:
max_index

[3908]

In [39]:
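# delete the too-short / too-long reviews by index

new_text = np.delete(clean_text, min_index)

print("Original number of reviews:", len(clean_text))
print("New number of reviews:", len(new_text))

Original number of reviews: 25001
New number of reviews: 25000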


In [40]:
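new_text2 = np.delete(new_text, max_index) # safe here: removing index 25000 first did not shift index 3908

print("Original number of reviews:", len(new_text))
print("New number of reviews:", len(new_text2))

Original number of reviews: 25000
New number of reviews: 24999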


In [41]:
# delete the matching labels by the same indices

new_labels = np.delete(label_int, min_index)

new_labels = np.delete(new_labels, max_index)

print("Original number of labels:", len(label_int))
print("New number of labels:", len(new_labels))

Original number of labels: 25001
New number of labels: 24999
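Deleting with two separate `np.delete` calls only works here because the removed short review (index 25000) comes after the long one (index 3908). In general, each deletion shifts the positions after it, so later indices point at the wrong samples. A sketch that removes all positions from texts and labels in one call (`drop_indices` is a hypothetical helper):

```python
import numpy as np

def drop_indices(texts, labels, bad_indices):
    """Remove the same positions from texts and labels in a single np.delete call,
    so no deletion can shift the indices of another."""
    bad = sorted(set(bad_indices))
    return np.delete(texts, bad), np.delete(labels, bad)

texts = np.array(["a", "b", "c", "d", "e"])
labels = np.array([1, 0, 1, 0, 1])
t, l = drop_indices(texts, labels, [3, 1])
print(t.tolist(), l.tolist())  # ['a', 'c', 'e'] [1, 1, 1]

# One index at a time: after removing position 1, the old position 3 has become 2,
# so the second call removes 'e' instead of 'd'
seq = np.delete(np.delete(texts, [1]), [3])
print(seq.tolist())  # ['a', 'c', 'd']
```

In the notebook this would be `drop_indices(clean_text, label_int, min_index + max_index)`.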


In [42]:
new_text2[0]

'bromwell high is a cartoon comedy  it ran at the same time as some other programs about school life  such as  teachers   my   years in the teaching profession lead me to believe that bromwell high  s satire is much closer to reality than is  teachers   the scramble to survive financially  the insightful students who can see right through their pathetic teachers  pomp  the pettiness of the whole situation  all remind me of the schools i knew and their students  when i saw the episode in which a student repeatedly tried to burn down the school  i immediately recalled          at           high  a classic line inspector i  m here to sack one of your teachers  student welcome to bromwell high  i expect that many adults of my age think that bromwell high is far fetched  what a pity that it isn  t   '

In [43]:
# 2.6 Map words to integers

text_ints = []
for sentence in new_text2:
    sample = list()
    for word in sentence.split():
        int_value = word_int[word] # look up the word's integer index
        sample.append(int_value)
    text_ints.append(sample)

In [44]:
text_ints[0] # first review

[7267,
 40335,
 68363,
 51508,
 39596,
 39305,
 72091,
 68728,
 5556,
 4157,
 58431,
 40054,
 54028,
 12419,
 53777,
 1651,
 15038,
 21831,
 33552,
 47447,
 54028,
 51827,
 62758,
 65921,
 55270,
 4157,
 17051,
 32040,
 56213,
 20722,
 14740,
 65289,
 61837,
 7267,
 40335,
 64500,
 70199,
 68363,
 65395,
 47644,
 14740,
 40972,
 33569,
 68363,
 51827,
 4157,
 55629,
 14740,
 60871,
 23387,
 4157,
 73423,
 24863,
 60705,
 50611,
 3387,
 35849,
 47980,
 58377,
 7561,
 51827,
 64431,
 4157,
 10577,
 65668,
 4157,
 8857,
 66530,
 8278,
 53415,
 20722,
 65668,
 4157,
 73590,
 18926,
 8943,
 11623,
 58377,
 24863,
 2704,
 18926,
 49266,
 4157,
 6587,
 55270,
 69478,
 51508,
 22784,
 38252,
 223,
 14740,
 71943,
 11433,
 4157,
 21831,
 18926,
 4878,
 20563,
 5556,
 40335,
 51508,
 11612,
 67798,
 10312,
 18926,
 44888,
 60468,
 14740,
 24671,
 7909,
 65668,
 32948,
 51827,
 22784,
 58725,
 14740,
 7267,
 40335,
 18926,
 63100,
 61837,
 58629,
 34074,
 65668,
 62758,
 58401,
 278,
 61837,
 726

In [45]:
len(text_ints) # total number of reviews

24999

In [46]:
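# 2.7 Set a uniform text length and pad or truncate every review to it
# Every review is fixed at 200 words: shorter reviews are padded with 0 at the end, longer ones are truncated

def reset_text(text, seq_len):
    dataset = np.zeros((len(text), seq_len)) # float64 by default; the model casts to long later
    for index, sentence in enumerate(text):
        if len(sentence) < seq_len:
            dataset[index, :len(sentence)] = sentence # copy the words; the rest stays 0 (padding)
        else:
            dataset[index, :] = sentence[:seq_len] # truncate
            
    return dataset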
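`reset_text` pads at the end, so for a short review the LSTM reads many padding steps after the last real word, while the model later classifies from the output at the final time step. A common alternative is to left-pad so the real tokens end exactly at the last step. A sketch of that variant (`pad_left` is a hypothetical helper; it also uses an integer dtype, which avoids the float-to-long cast later):

```python
import numpy as np

def pad_left(sequences, seq_len):
    """Left-pad each sequence with 0 and keep at most its first seq_len tokens."""
    out = np.zeros((len(sequences), seq_len), dtype=np.int64)
    for i, seq in enumerate(sequences):
        seq = seq[:seq_len]            # truncate long reviews, keeping the beginning like reset_text
        if len(seq) > 0:               # out[i, -0:] would select the whole row, so skip empty input
            out[i, -len(seq):] = seq   # right-align the real tokens
    return out

print(pad_left([[5, 6, 7], [1, 2, 3, 4, 5, 6], []], 4))
# [[0 5 6 7]
#  [1 2 3 4]
#  [0 0 0 0]]
```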

In [47]:
dataset = reset_text(text_ints, seq_len=200)

In [48]:
dataset.shape

(24999, 200)

In [49]:
dataset[0,:]

array([ 7267., 40335., 68363., 51508., 39596., 39305., 72091., 68728.,
        5556.,  4157., 58431., 40054., 54028., 12419., 53777.,  1651.,
       15038., 21831., 33552., 47447., 54028., 51827., 62758., 65921.,
       55270.,  4157., 17051., 32040., 56213., 20722., 14740., 65289.,
       61837.,  7267., 40335., 64500., 70199., 68363., 65395., 47644.,
       14740., 40972., 33569., 68363., 51827.,  4157., 55629., 14740.,
       60871., 23387.,  4157., 73423., 24863., 60705., 50611.,  3387.,
       35849., 47980., 58377.,  7561., 51827., 64431.,  4157., 10577.,
       65668.,  4157.,  8857., 66530.,  8278., 53415., 20722., 65668.,
        4157., 73590., 18926.,  8943., 11623., 58377., 24863.,  2704.,
       18926., 49266.,  4157.,  6587., 55270., 69478., 51508., 22784.,
       38252.,   223., 14740., 71943., 11433.,  4157., 21831., 18926.,
        4878., 20563.,  5556., 40335., 51508., 11612., 67798., 10312.,
       18926., 44888., 60468., 14740., 24671.,  7909., 65668., 32948.,
      

### 3. Data type conversion

In [50]:
type(dataset)

numpy.ndarray

In [51]:
type(label_int)

numpy.ndarray

In [52]:
import torch
import torch.nn as nn

# 3.1 Convert the numpy arrays to tensors (dataset_tensor keeps numpy's float64 dtype)
dataset_tensor = torch.from_numpy(dataset)
label_tensor = torch.from_numpy(new_labels)

In [53]:
dataset_tensor.shape

torch.Size([24999, 200])

In [54]:
label_tensor.shape

torch.Size([24999])

In [55]:
# 3.2 Split the data into train, val and test

# total number of samples
all_samples = len(dataset_tensor)
print("Total samples:", all_samples)

# split ratio
ratio = 0.8
train_size = int(all_samples * ratio) # training samples
print("Training samples:", train_size)

rest_size = all_samples - train_size # remaining samples

val_size = int(rest_size * 0.5) # validation samples
print("Validation samples:", val_size)

test_size = rest_size - val_size # test samples (absorbs any odd remainder)
print("Test samples:", test_size)

Total samples: 24999
Training samples: 19999
Validation samples: 2500
Test samples: 2500
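The split slices the tensors in file order without shuffling. If the file were ordered by class or source, the three parts could end up with different label mixes. A sketch of a shuffled split with `torch.utils.data.random_split` and a seeded generator on a toy dataset (`split_sizes` is a hypothetical helper that reproduces the notebook's 19999 / 2500 / 2500 sizes):

```python
import torch
from torch.utils.data import TensorDataset, random_split

def split_sizes(n, train_ratio=0.8):
    """Sizes for train / val / test; test absorbs any odd remainder."""
    train = int(n * train_ratio)
    val = (n - train) // 2
    return [train, val, n - train - val]

features = torch.arange(20).reshape(10, 2)   # toy data: 10 samples, 2 features
labels = torch.arange(10)
full = TensorDataset(features, labels)

sizes = split_sizes(len(full))
train_ds, val_ds, test_ds = random_split(full, sizes, generator=torch.Generator().manual_seed(42))
print(sizes, len(train_ds), len(val_ds), len(test_ds))  # [8, 1, 1] 8 1 1
print(split_sizes(24999))                                # [19999, 2500, 2500]
```

The seeded generator makes the shuffle repeatable, so the same reviews land in the test set on every run.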


In [56]:
# slice out the train, val and test samples (in file order, without shuffling)

# train
train = dataset_tensor[:train_size]
train_labels = label_tensor[:train_size]

In [57]:
train.shape

torch.Size([19999, 200])

In [58]:
train_labels.shape

torch.Size([19999])

In [59]:
# remaining samples
rest_samples = dataset_tensor[train_size:]
rest_labels = label_tensor[train_size:]

In [60]:
# val
val = rest_samples[:val_size]
val_labels = rest_labels[:val_size]

In [61]:
val.shape

torch.Size([2500, 200])

In [62]:
val_labels.shape

torch.Size([2500])

In [63]:
# test
test = rest_samples[val_size:]
test_labels = rest_labels[val_size:]

In [64]:
test.shape

torch.Size([2500, 200])

In [65]:
test_labels.shape

torch.Size([2500])

In [66]:
# 3.3 Batch the data with DataLoader
from torch.utils.data import TensorDataset, DataLoader

# wrap the data as (review, label) pairs
train_dataset = TensorDataset(train, train_labels)
val_dataset = TensorDataset(val, val_labels)
test_dataset = TensorDataset(test, test_labels)

batch_size = 128
# batching; drop_last=True keeps every batch at exactly batch_size, matching the fixed-size hidden state from init_hidden(batch_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=True)
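The cost of `drop_last=True` is that the leftover samples are skipped in every pass over a loader. The arithmetic for the three splits with `batch_size = 128` (`batch_counts` is a hypothetical helper):

```python
import math

def batch_counts(num_samples, batch_size):
    """Return (full batches, samples dropped with drop_last=True, batches with drop_last=False)."""
    full = num_samples // batch_size
    return full, num_samples - full * batch_size, math.ceil(num_samples / batch_size)

for name, n in [("train", 19999), ("val", 2500), ("test", 2500)]:
    print(name, batch_counts(n, 128))
# train (156, 31, 157)
# val (19, 68, 20)
# test (19, 68, 20)
```

The 19 test batches match the `Batch 0` to `Batch 18` lines printed by the test function, which therefore scores 2432 of the 2500 test samples.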

In [67]:
# 获取train中的一批数据
data, label = next(iter(train_loader))

In [68]:
data.shape

torch.Size([128, 200])

In [69]:
label.shape

torch.Size([128])

In [70]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

### 4. Define the network architecture

In [71]:
class sentiment(nn.Module):
    def __init__(self, input_size, embedding_dim, hidden_dim, output_size, num_layers, dropout=0.5):
        super(sentiment, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.output_size = output_size
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(input_size, embedding_dim) # word-embedding layer
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)
        self.linear = nn.Linear(hidden_dim, output_size)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, hidden):
        '''
        x shape : (batch_size, seq_len), integer word indices
        hidden  : (h_0, c_0), each of shape (num_layers, batch_size, hidden_dim)
        '''
        batch_size = x.size(0) # get batch_size
        x = x.long() # cast the (float64) indices to int64 for nn.Embedding
        #print('x shape : ', x.shape) # torch.Size([128, 200])
        embeds = self.embedding(x) # word-embedding representation
        #print('embeds shape : ', embeds.shape) # torch.Size([128, 200, 400])
        out, hidden = self.lstm(embeds, hidden) # lstm out shape : (batch_size, seq_len, hidden_dim)
        #print('out_1 shape : ', out.shape) # torch.Size([128, 200, 128])
        #print('hidden_0 shape : ', hidden[0].shape) # torch.Size([2, 128, 128])
        #print('hidden_1 shape : ', hidden[1].shape) # torch.Size([2, 128, 128])
        out = out.reshape(-1, self.hidden_dim) # (batch_size * seq_len, hidden_dim)
        #print('out_2 shape : ', out.shape) # torch.Size([25600, 128])
        out = self.linear(out) # fully connected layer
        #print('out_3 shape : ', out.shape) # torch.Size([25600, 1])
        sigmoid_out = self.sigmoid(out) # probability of a positive review at every time step
        #print('sigmoid_out_1 shape : ', sigmoid_out.shape) # torch.Size([25600, 1])
        sigmoid_out = sigmoid_out.reshape(batch_size, -1)
        #print('sigmoid_out_2 shape : ', sigmoid_out.shape) # torch.Size([128, 200])
        sigmoid_out = sigmoid_out[:, -1] # keep only the output at the last time step
        #print('sigmoid_out_3 shape : ', sigmoid_out.shape) # torch.Size([128])
        return sigmoid_out, hidden
    
    def init_hidden(self, batch_size):
        # zero (h_0, c_0) with the same dtype and device as the model parameters
        weight = next(self.parameters()).data
        #print("weight :", weight.shape) # torch.Size([74073, 400])
        hidden = (weight.new_zeros(self.num_layers, batch_size, self.hidden_dim),
                  weight.new_zeros(self.num_layers, batch_size, self.hidden_dim))
        return hidden
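`forward` applies the linear layer and sigmoid to all 200 time steps and then keeps only the last one. Slicing the last time step first gives the same numbers with about 1/200 of the linear-layer work. A check on random tensors standing in for the LSTM output (the sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, hidden_dim = 4, 7, 8
out = torch.randn(batch_size, seq_len, hidden_dim)   # stand-in for the LSTM output
linear = nn.Linear(hidden_dim, 1)

# notebook style: linear + sigmoid on every time step, then keep the last one
a = torch.sigmoid(linear(out.reshape(-1, hidden_dim))).reshape(batch_size, -1)[:, -1]
# cheaper: take the last time step first, then apply linear + sigmoid once
b = torch.sigmoid(linear(out[:, -1, :])).squeeze(-1)

print(a.shape, torch.allclose(a, b))  # torch.Size([4]) True
```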

In [72]:
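# Initialise the hyperparameters
input_size = len(word_int) + 1 # vocabulary size: unique words + 1 for the padding index 0
output_size = 1 # one output: probability of a positive review
embedding_dim = 400 # word-embedding dimension
hidden_dim = 128 # number of hidden units per LSTM layer
num_layers = 2 # number of stacked LSTM layers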

In [73]:
# create the model
model = sentiment(input_size, embedding_dim, hidden_dim, output_size, num_layers)

model

sentiment(
  (embedding): Embedding(74073, 400)
  (lstm): LSTM(400, 128, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [74]:
criterion = torch.nn.BCELoss() # binary cross-entropy loss (the model already outputs sigmoid probabilities)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # optimizer
num_epochs = 50 # number of training epochs

In [75]:
model = model.to(device)

In [76]:
# Training function (training + validation)
def train(model, device, data_loader, criterion, optimizer, num_epochs, val_loader):
    for epoch in range(num_epochs):
        hs = model.init_hidden(batch_size)
        train_loss = []
        train_correct = 0
        model.train()
        for data, target in data_loader:
            data = data.to(device) # move to device
            target = target.to(device)
            optimizer.zero_grad() # reset gradients
            output, hs = model(data, hs) # forward pass
            hs = tuple([h.data for h in hs]) # detach the hidden state so backprop stops at this batch
            #print('output shape : ', output.shape) # torch.Size([128])
            loss = criterion(output, target.float()) # compute the loss
            train_loss.append(loss.item()) # record the loss
            loss.backward() # backpropagation
            optimizer.step() # parameter update
            train_correct += torch.sum(torch.round(output) == target).item() # round probabilities to 0/1 before comparing
            
        # validation
        model.eval()
        hs = model.init_hidden(batch_size)
        val_loss = []
        val_correct = 0
        with torch.no_grad():
            for data, target in val_loader:
                data = data.to(device)
                target = target.to(device)
                preds, hs = model(data, hs) # forward pass
                hs = tuple([h.data for h in hs])
                loss = criterion(preds, target.float()) # compute the loss
                val_loss.append(loss.item()) # record the loss
                val_correct += torch.sum(torch.round(preds) == target).item() # round probabilities to 0/1 before comparing
        print(f'Epoch {epoch}/{num_epochs} --- train loss {np.round(np.mean(train_loss), 5)} --- val loss {np.round(np.mean(val_loss),5)}')
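The model outputs probabilities in (0, 1) while the targets are the integers 0 and 1, so a bare `output == target` compares, for example, 0.91 with 1 and is almost never true. Thresholding with `torch.round` first gives a meaningful count. A toy example with made-up probabilities:

```python
import torch

probs = torch.tensor([0.91, 0.12, 0.55, 0.40])   # made-up sigmoid outputs
target = torch.tensor([1, 0, 0, 0])

raw_matches = torch.sum(probs == target).item()            # probabilities vs. labels
correct = torch.sum(torch.round(probs) == target).item()   # round to 0/1 first
accuracy = correct / len(target)

print(raw_matches, correct, accuracy)  # 0 3 0.75
```

Note that `torch.round` rounds exact halves to even, so a probability of exactly 0.5 becomes 0.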

In [77]:
train(model, device, train_loader, criterion, optimizer, num_epochs, val_loader)

Epoch 0/50 --- train loss 0.69579 --- val loss 0.69153
Epoch 1/50 --- train loss 0.65006 --- val loss 0.61376
Epoch 2/50 --- train loss 0.50846 --- val loss 0.63338
Epoch 3/50 --- train loss 0.41366 --- val loss 0.573
Epoch 4/50 --- train loss 0.31818 --- val loss 0.55107
Epoch 5/50 --- train loss 0.23296 --- val loss 0.58043
Epoch 6/50 --- train loss 0.19104 --- val loss 0.73331
Epoch 7/50 --- train loss 0.16312 --- val loss 0.69132
Epoch 8/50 --- train loss 0.1371 --- val loss 0.81286
Epoch 9/50 --- train loss 0.12594 --- val loss 0.80081
Epoch 10/50 --- train loss 0.11844 --- val loss 0.89682
Epoch 11/50 --- train loss 0.12253 --- val loss 0.82829
Epoch 12/50 --- train loss 0.12665 --- val loss 0.81739
Epoch 13/50 --- train loss 0.10503 --- val loss 1.06689
Epoch 14/50 --- train loss 0.11575 --- val loss 1.04593
Epoch 15/50 --- train loss 0.09994 --- val loss 0.77949
Epoch 16/50 --- train loss 0.10297 --- val loss 1.07149
Epoch 17/50 --- train loss 0.09307 --- val loss 1.03334
...
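In this log the validation loss is lowest at epoch 4 (0.55107) and never drops below that again in the epochs shown, while the training loss keeps falling: a typical sign of overfitting. One common remedy is early stopping: keep the weights from the best validation epoch and stop after a few epochs without improvement. A sketch of the bookkeeping, run on the first nine validation losses from the log (`early_stop_epoch` and `patience` are hypothetical names; in a real loop you would also save `copy.deepcopy(model.state_dict())` at each new best):

```python
val_losses = [0.69153, 0.61376, 0.63338, 0.573, 0.55107, 0.58043, 0.73331, 0.69132, 0.81286]

def early_stop_epoch(val_losses, patience=3):
    """Return (best epoch, epoch at which training would stop)."""
    best, best_epoch, waited = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, waited = loss, epoch, 0   # new best: this is where weights would be saved
        else:
            waited += 1
            if waited >= patience:                      # no improvement for `patience` epochs
                return best_epoch, epoch
    return best_epoch, len(val_losses) - 1

print(early_stop_epoch(val_losses))  # (4, 7)
```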

In [78]:
# Test function

def test(model, data_loader, device, criterion):
    test_losses = []
    num_correct = 0
    # initialise the hidden state
    hs = model.init_hidden(batch_size)
    model.eval()
    with torch.no_grad(): # no gradients are needed for evaluation
        for i, (data, target) in enumerate(data_loader):
            data = data.to(device) # move to device
            target = target.to(device)
            output, hs = model(data, hs) # forward pass
            loss = criterion(output, target.float()) # compute the loss
            pred = torch.round(output) # round the probabilities to 0 or 1
            test_losses.append(loss.item()) # keep the loss
            correct_tensor = pred.eq(target.float().view_as(pred)) # tensor of True / False
            result = np.sum(correct_tensor.cpu().numpy())
            num_correct += result
            print(f'Batch {i}')
            print(f'loss : {np.round(loss.item(), 3)}')
            print(f'accuracy : {np.round(result / len(data), 3) * 100} %')
            print()
    print("Overall test loss : {:.2f}".format(np.mean(test_losses)))
    # caveat: test_loader uses drop_last=True, so only len(data_loader) * batch_size samples are scored;
    # dividing by len(data_loader.dataset) therefore slightly understates the accuracy
    print("Overall test accuracy : {:.2f}".format(num_correct / len(data_loader.dataset)))
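A sketch that scores every sample instead, by creating the hidden state at each batch's actual size so a short final batch works. It assumes a model with the same `forward(x, hidden)` / `init_hidden(batch_size)` interface as `sentiment` and a loader built with `drop_last=False`. `evaluate` and `ConstantModel` are hypothetical names; the stand-in model only exists to make the example runnable.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate(model, data_loader, criterion, device):
    """Return (mean loss per sample, accuracy) over every sample in data_loader."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    for data, target in data_loader:
        data, target = data.to(device), target.to(device).float()
        hs = model.init_hidden(data.size(0))   # sized to this batch, so a short last batch works
        output, _ = model(data, hs)
        total_loss += criterion(output, target).item() * data.size(0)  # undo the per-batch mean
        correct += (torch.round(output) == target).sum().item()
        seen += data.size(0)
    return total_loss / seen, correct / seen

class ConstantModel(nn.Module):
    """Hypothetical stand-in with the same interface: always predicts 0.9."""
    def forward(self, x, hidden):
        return torch.full((x.size(0),), 0.9), hidden

    def init_hidden(self, batch_size):
        return None

toy_dataset = TensorDataset(torch.zeros(5, 3), torch.tensor([1, 1, 1, 0, 0]))
loader = DataLoader(toy_dataset, batch_size=2, drop_last=False)   # batches of 2, 2 and 1
loss, acc = evaluate(ConstantModel(), loader, nn.BCELoss(), "cpu")
print(acc)  # 0.6
```

With the notebook's model this would be called as `evaluate(model, DataLoader(test_dataset, batch_size=batch_size), criterion, device)`.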

In [79]:
test(model, test_loader, device, criterion)

Batch 0
loss : 1.749
accuracy : 68.0 %

Batch 1
loss : 1.57
accuracy : 71.1 %

Batch 2
loss : 1.489
accuracy : 72.7 %

Batch 3
loss : 1.456
accuracy : 75.0 %

Batch 4
loss : 1.201
accuracy : 77.3 %

Batch 5
loss : 1.432
accuracy : 73.4 %

Batch 6
loss : 1.409
accuracy : 74.2 %

Batch 7
loss : 1.507
accuracy : 71.1 %

Batch 8
loss : 1.626
accuracy : 70.3 %

Batch 9
loss : 1.331
accuracy : 77.3 %

Batch 10
loss : 1.654
accuracy : 71.89999999999999 %

Batch 11
loss : 1.355
accuracy : 75.0 %

Batch 12
loss : 1.45
accuracy : 72.7 %

Batch 13
loss : 1.737
accuracy : 67.2 %

Batch 14
loss : 1.577
accuracy : 72.7 %

Batch 15
loss : 2.027
accuracy : 65.60000000000001 %

Batch 16
loss : 1.747
accuracy : 67.2 %

Batch 17
loss : 1.482
accuracy : 75.8 %

Batch 18
loss : 1.498
accuracy : 69.5 %

Overall test loss : 1.54
Overall test accuracy : 0.70


### Prediction (testing)

In [80]:
# Example 1
text = 'this movie is so amazing. the plot is attractive. and I really like it.'

In [81]:
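# Step 1: convert the text to integer indices
from string import punctuation

def converts(text):
    # strip punctuation
    new_text = ''.join([char for char in text if char not in punctuation])
    print("new text :\n", new_text)
    # map each word to its index; words missing from the training vocabulary are skipped (avoids a KeyError)
    text_ints = [word_int[word.lower()] for word in new_text.split() if word.lower() in word_int]
    print("text mapped to indices:\n", text_ints)
    return text_ints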

In [82]:
text_ints = converts(text)

new text :
 this movie is so amazing the plot is attractive and I really like it
text mapped to indices:
 [12542, 26331, 68363, 55476, 13980, 4157, 3491, 68363, 25499, 11623, 18926, 8585, 69980, 72091]


In [83]:
text_ints

[12542,
 26331,
 68363,
 55476,
 13980,
 4157,
 3491,
 68363,
 25499,
 11623,
 18926,
 8585,
 69980,
 72091]

In [84]:
# pad / truncate the text to sequence_length = 200
new_text_ints = reset_text([text_ints], seq_len=200) # note the extra [], because reset_text works on a list of sequences (2-D data)

In [85]:
new_text_ints

array([[12542., 26331., 68363., 55476., 13980.,  4157.,  3491., 68363.,
        25499., 11623., 18926.,  8585., 69980., 72091.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.,     0.,     0.,     0.,     0.,     0.,     0., 

In [86]:
new_text_ints.shape

(1, 200)

In [87]:
# numpy --> tensor
text_tensor = torch.from_numpy(new_text_ints)

print(text_tensor.shape)

torch.Size([1, 200])


In [88]:
# Prediction function
def predict(model, text_tensor, device):
    model.eval() # disable dropout for inference
    batch_size = text_tensor.size(0) # 1 here
    hs = model.init_hidden(batch_size) # initialise the hidden state
    text_tensor = text_tensor.to(device)
    with torch.no_grad():
        pred, hs = model(text_tensor, hs) # forward pass
    print("Probability:", pred.item())
    # round the probability to 0 or 1
    pred = torch.round(pred)
    print("Class:", pred.item())
    # decide
    if pred.item() == 1:
        print("Positive review")
    else:
        print("Negative review")

In [89]:
predict(model, text_tensor, device)

Probability: 0.839598536491394
Class: 1.0
Positive review
