In [2]:
import os
import time
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np
import pandas as pd
import sklearn


dir_all_data='data/train.tsv'

#超参数设置
BATCH_SIZE = 32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
#从文件中读取数据
data_all = pd.read_csv(dir_all_data, sep='\t')
print(data_all.shape)    #(156060, 4)
print(data_all.keys())   #['PhraseId', 'SentenceId', 'Phrase', 'Sentiment']
print(data_all.head())

(156060, 4)
Index(['PhraseId', 'SentenceId', 'Phrase', 'Sentiment'], dtype='object')
   PhraseId  SentenceId                                             Phrase  \
0         1           1  A series of escapades demonstrating the adage ...   
1         2           1  A series of escapades demonstrating the adage ...   
2         3           1                                           A series   
3         4           1                                                  A   
4         5           1                                             series   

   Sentiment  
0          1  
1          2  
2          2  
3          2  
4          2  


In [4]:
#shuffle、划分验证集、测试集,并保存
idx = np.arange(data_all.shape[0])
seed = 0
np.random.seed(seed)
np.random.shuffle(idx)  

train_size = int(len(idx) * 0.6)
test_size = int(len(idx) * 0.8)

data_all.iloc[idx[:train_size], :].to_csv('data/task2_train.csv', index=False)
data_all.iloc[idx[train_size:test_size], :].to_csv("data/task2_test.csv", index=False)
data_all.iloc[idx[test_size:], :].to_csv("data/task2_dev.csv", index=False)

In [5]:
#使用Torchtext采用声明式方法加载数据
#参考https://blog.csdn.net/JWoswin/article/details/92821752
from torchtext import data
PAD_TOKEN = '<pad>'
TEXT = data.Field(sequential=True, batch_first=True, lower=True, pad_token=PAD_TOKEN)
LABEL = data.Field(sequential=False, batch_first=True, unk_token=None)

In [6]:
#读取数据
datafields = [("PhraseId", None), # 不需要的filed设置为None
              ("SentenceId", None),
              ('Phrase', TEXT),
              ('Sentiment', LABEL)]
train_data = data.TabularDataset(path='data/task2_train.csv', format='csv', fields=datafields)
dev_data = data.TabularDataset(path='data/task2_dev.csv', format='csv', fields=datafields)
test_data = data.TabularDataset(path='data/task2_test.csv', format='csv', fields=datafields)

In [7]:
#构建词典，字符映射到embedding
#TEXT.vocab.vectors 就是词向量
TEXT.build_vocab(train_data, vectors='glove.6B.50d', 
                 unk_init= lambda x:torch.nn.init.uniform_(x, a=-0.25, b=0.25))
LABEL.build_vocab(train_data)
#得到索引，PAD_TOKEN='<pad>'
PAD_INDEX = TEXT.vocab.stoi[PAD_TOKEN]
TEXT.vocab.vectors[PAD_INDEX] = 0.0

In [8]:
print(TEXT.vocab.itos[1510])
print(TEXT.vocab.stoi['bore'])
# 词向量矩阵: TEXT.vocab.vectors
print(TEXT.vocab.vectors.shape)
word_vec = TEXT.vocab.vectors[TEXT.vocab.stoi['bore']]
print(word_vec.shape)
print(word_vec)

succeeds
1486
torch.Size([16473, 50])
torch.Size([50])
tensor([ 0.7493,  0.7730,  0.5915, -0.3801,  0.4761,  1.3279,  0.3476,  0.0737,
        -0.0291, -0.2731, -0.3928, -0.1822, -0.0110, -0.3036, -0.5352, -0.4523,
        -0.8613, -0.0940, -0.3921, -0.3335, -0.6319, -0.2460,  0.3667, -0.9392,
         0.3502, -0.9397, -1.1096,  0.8062,  0.5669, -0.3130,  1.5001, -0.1960,
         0.3081,  0.1727,  0.5624,  0.2619,  0.4756, -0.5688, -0.5013,  0.1903,
         0.0685, -0.0869, -0.1641, -0.2432,  0.3557, -0.1629, -0.1993, -0.1561,
         0.3508, -0.9423])


In [9]:
word_vec = TEXT.vocab.vectors[TEXT.vocab.stoi['<pad>']]
print(word_vec.shape)
print(word_vec)

torch.Size([50])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])


In [10]:
#构建迭代器
train_iterator = data.BucketIterator(train_data, batch_size=BATCH_SIZE, train=True, shuffle=True, device=DEVICE)
dev_iterator = data.Iterator(dev_data, batch_size=BATCH_SIZE, train=False, sort=False, device=DEVICE) #batch_size应该为len(dev_data) 
test_iterator = data.Iterator(test_data, batch_size=BATCH_SIZE, train=False, sort=False, device=DEVICE)# 在 test_iter , sort一定要设置成 False, 要不然会被 torchtext 搞乱样本顺序

In [11]:
embedding_choice = 'glove'   #  'static'    'non-static'
num_embeddings = len(TEXT.vocab)
embedding_dim = 50
dropout_p = 0.5
hidden_size = 50  #隐藏单元数
num_layers = 2  #层数
vocab_size = len(TEXT.vocab)
label_num = len(LABEL.vocab)
print(vocab_size, label_num)

16473 6


In [28]:
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.embedding_choice = embedding_choice        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        if self.embedding_choice == 'rand':
            self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if self.embedding_choice == 'glove':
            self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx = PAD_INDEX) \
                                .from_pretrained(TEXT.vocab.vectors, freeze=True)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers,
                            batch_first=True, dropout=dropout_p, bidirectional=True)
        self.dropout = nn.Dropout(dropout_p)    
        self.fc = nn.Linear(hidden_size * 2, label_num)  # 2 for bidirection
        
    def forward(self, x): # (Batch_size, Length) 
        # h_n (num_layers * num_directions, batch, hidden_size) 注意第一维不是Batch_size
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(DEVICE) 
        # c_n (num_layers * num_directions, batch, hidden_size): 
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(DEVICE)
        x = self.embedding(x) #(Batch_size,  Length, Dimention) 
        out, _ = self.lstm(x, (h0, c0)) # (batch_size, Length, hidden_size * 2)  
        out = self.dropout(out)
        out = torch.cat((out[:,0,self.hidden_size:], out[:,-1,:self.hidden_size]), dim=1)
        out = self.fc(out) # (batch_size, label_num)  
        return out 

In [29]:
#构建模型
model = LSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)#创建优化器SGD
criterion = nn.CrossEntropyLoss()   #损失函数
model.to(DEVICE)

LSTM(
  (embedding): Embedding(16473, 50)
  (lstm): LSTM(50, 50, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=100, out_features=6, bias=True)
)

In [32]:
#开始训练
epoch = 1
best_accuracy = 0.0
start_time = time.time()

for i in range(epoch):
    model.train()
    total_loss = 0.0
    accuracy = 0.0
    total_correct = 0.0
    total_data_num = len(train_iterator.dataset)
    steps = 0.0
    for batch in train_iterator:
        steps += 1
        optimizer.zero_grad() #  梯度缓存清零
        batch_text = batch.Phrase
        batch_label = batch.Sentiment
        out = model(batch_text)    #[batch_size, label_num]
        loss = criterion(out, batch_label)
        total_loss += loss.item() 
        loss.backward()
        optimizer.step()        
        correct = (torch.max(out, dim=1)[1] == batch_label).sum()
        total_correct += correct.item()
        if steps % 100 == 0:
            print("Epoch %d_%.3f%%: Training average Loss: %f" 
                  % (i, steps * train_iterator.batch_size * 100 / len(train_iterator.dataset), total_loss / steps))  
    #每个epoch都验证一下
    model.eval()
    total_loss = 0.0
    accuracy = 0.0
    total_correct = 0.0
    total_data_num = len(dev_iterator.dataset)
    steps = 0.0    
    for batch in dev_iterator:
        steps += 1
        batch_text = batch.Phrase
        batch_label = batch.Sentiment
        out = model(batch_text)
        loss = criterion(out, batch_label)
        total_loss += loss.item()
        correct = (torch.max(out, dim=1)[1] == batch_label).sum()
        total_correct += correct.item()
        print("Epoch %d :  Verification average Loss: %f, Verification accuracy: %f%%, Total Time:%f"
          %(i, total_loss / steps, total_correct * 100 / total_data_num, time.time() - start_time))  
    if best_accuracy < total_correct / total_data_num:
        best_accuracy = total_correct / total_data_num 
        torch.save(model, 'model_saved/epoch_%d_accuracy_%f' % (i, total_correct / total_data_num))
        print('Model is saved in model_saved/epoch_%d_accuracy_%f' % (i, total_correct / total_data_num))
    #推荐使用 torch.save(net.state_dict(),path)  net.load_state_dict(torch.load(path)):

Epoch 0_3.417%: Training average Loss: 1.202789
Epoch 0_6.835%: Training average Loss: 1.175800
Epoch 0_10.252%: Training average Loss: 1.154274
Epoch 0_13.670%: Training average Loss: 1.141255
Epoch 0_17.087%: Training average Loss: 1.125114
Epoch 0_20.505%: Training average Loss: 1.114919
Epoch 0_23.922%: Training average Loss: 1.105873
Epoch 0_27.340%: Training average Loss: 1.100409
Epoch 0_30.757%: Training average Loss: 1.096192
Epoch 0_34.175%: Training average Loss: 1.092510
Epoch 0_37.592%: Training average Loss: 1.084081
Epoch 0_41.009%: Training average Loss: 1.081887
Epoch 0_44.427%: Training average Loss: 1.080616
Epoch 0_47.844%: Training average Loss: 1.077013
Epoch 0_51.262%: Training average Loss: 1.074782
Epoch 0_54.679%: Training average Loss: 1.071123
Epoch 0_58.097%: Training average Loss: 1.069022
Epoch 0_61.514%: Training average Loss: 1.067044
Epoch 0_64.932%: Training average Loss: 1.064643
Epoch 0_68.349%: Training average Loss: 1.061310
Epoch 0_71.767%: Train

Epoch 0 :  Verification average Loss: 0.950529, Verification accuracy: 4.110467%, Total Time:61.630423
Epoch 0 :  Verification average Loss: 0.953767, Verification accuracy: 4.158524%, Total Time:61.636420
Epoch 0 :  Verification average Loss: 0.954569, Verification accuracy: 4.203377%, Total Time:61.645416
Epoch 0 :  Verification average Loss: 0.953856, Verification accuracy: 4.257841%, Total Time:61.651412
Epoch 0 :  Verification average Loss: 0.952905, Verification accuracy: 4.321917%, Total Time:61.659422
Epoch 0 :  Verification average Loss: 0.948841, Verification accuracy: 4.398808%, Total Time:61.666403
Epoch 0 :  Verification average Loss: 0.953411, Verification accuracy: 4.446865%, Total Time:61.674398
Epoch 0 :  Verification average Loss: 0.955123, Verification accuracy: 4.507737%, Total Time:61.680395
Epoch 0 :  Verification average Loss: 0.957554, Verification accuracy: 4.562202%, Total Time:61.686391
Epoch 0 :  Verification average Loss: 0.957296, Verification accuracy: 4.

Epoch 0 :  Verification average Loss: 0.969933, Verification accuracy: 8.858488%, Total Time:62.185537
Epoch 0 :  Verification average Loss: 0.969531, Verification accuracy: 8.919361%, Total Time:62.185537
Epoch 0 :  Verification average Loss: 0.969808, Verification accuracy: 8.973825%, Total Time:62.201161
Epoch 0 :  Verification average Loss: 0.970324, Verification accuracy: 9.025086%, Total Time:62.201161
Epoch 0 :  Verification average Loss: 0.969956, Verification accuracy: 9.095569%, Total Time:62.201161
Epoch 0 :  Verification average Loss: 0.970642, Verification accuracy: 9.159645%, Total Time:62.216783
Epoch 0 :  Verification average Loss: 0.973284, Verification accuracy: 9.214110%, Total Time:62.227189
Epoch 0 :  Verification average Loss: 0.972637, Verification accuracy: 9.278185%, Total Time:62.234191
Epoch 0 :  Verification average Loss: 0.972535, Verification accuracy: 9.342261%, Total Time:62.240181
Epoch 0 :  Verification average Loss: 0.973131, Verification accuracy: 9.

Epoch 0 :  Verification average Loss: 0.976754, Verification accuracy: 13.635344%, Total Time:62.677557
Epoch 0 :  Verification average Loss: 0.977147, Verification accuracy: 13.689809%, Total Time:62.683553
Epoch 0 :  Verification average Loss: 0.976860, Verification accuracy: 13.753885%, Total Time:62.689550
Epoch 0 :  Verification average Loss: 0.976794, Verification accuracy: 13.814757%, Total Time:62.695546
Epoch 0 :  Verification average Loss: 0.976130, Verification accuracy: 13.878833%, Total Time:62.701543
Epoch 0 :  Verification average Loss: 0.975174, Verification accuracy: 13.955724%, Total Time:62.704844
Epoch 0 :  Verification average Loss: 0.974790, Verification accuracy: 14.029411%, Total Time:62.704844
Epoch 0 :  Verification average Loss: 0.974471, Verification accuracy: 14.093487%, Total Time:62.720469
Epoch 0 :  Verification average Loss: 0.975394, Verification accuracy: 14.144747%, Total Time:62.720469
Epoch 0 :  Verification average Loss: 0.974686, Verification acc

Epoch 0 :  Verification average Loss: 0.967563, Verification accuracy: 18.546759%, Total Time:63.161119
Epoch 0 :  Verification average Loss: 0.967199, Verification accuracy: 18.614039%, Total Time:63.168120
Epoch 0 :  Verification average Loss: 0.967217, Verification accuracy: 18.674911%, Total Time:63.175123
Epoch 0 :  Verification average Loss: 0.967731, Verification accuracy: 18.726172%, Total Time:63.181106
Epoch 0 :  Verification average Loss: 0.967332, Verification accuracy: 18.796655%, Total Time:63.187104
Epoch 0 :  Verification average Loss: 0.968170, Verification accuracy: 18.841508%, Total Time:63.193099
Epoch 0 :  Verification average Loss: 0.967990, Verification accuracy: 18.892769%, Total Time:63.199096
Epoch 0 :  Verification average Loss: 0.968803, Verification accuracy: 18.956845%, Total Time:63.205094
Epoch 0 :  Verification average Loss: 0.968899, Verification accuracy: 19.020921%, Total Time:63.206143
Epoch 0 :  Verification average Loss: 0.969549, Verification acc

Epoch 0 :  Verification average Loss: 0.972490, Verification accuracy: 23.118572%, Total Time:63.608265
Epoch 0 :  Verification average Loss: 0.972434, Verification accuracy: 23.185852%, Total Time:63.616279
Epoch 0 :  Verification average Loss: 0.971814, Verification accuracy: 23.262743%, Total Time:63.622270
Epoch 0 :  Verification average Loss: 0.972225, Verification accuracy: 23.323615%, Total Time:63.629252
Epoch 0 :  Verification average Loss: 0.971849, Verification accuracy: 23.384487%, Total Time:63.635249
Epoch 0 :  Verification average Loss: 0.972050, Verification accuracy: 23.442156%, Total Time:63.638270
Epoch 0 :  Verification average Loss: 0.971865, Verification accuracy: 23.496620%, Total Time:63.638270
Epoch 0 :  Verification average Loss: 0.971257, Verification accuracy: 23.570307%, Total Time:63.638270
Epoch 0 :  Verification average Loss: 0.971241, Verification accuracy: 23.631179%, Total Time:63.658051
Epoch 0 :  Verification average Loss: 0.971074, Verification acc

Epoch 0 :  Verification average Loss: 0.972363, Verification accuracy: 27.885817%, Total Time:64.063302
Epoch 0 :  Verification average Loss: 0.972093, Verification accuracy: 27.949893%, Total Time:64.063302
Epoch 0 :  Verification average Loss: 0.972737, Verification accuracy: 27.988338%, Total Time:64.078912
Epoch 0 :  Verification average Loss: 0.972392, Verification accuracy: 28.055618%, Total Time:64.078912
Epoch 0 :  Verification average Loss: 0.972688, Verification accuracy: 28.106879%, Total Time:64.078912
Epoch 0 :  Verification average Loss: 0.972532, Verification accuracy: 28.174158%, Total Time:64.094555
Epoch 0 :  Verification average Loss: 0.972732, Verification accuracy: 28.225419%, Total Time:64.094555
Epoch 0 :  Verification average Loss: 0.972684, Verification accuracy: 28.292699%, Total Time:64.110160
Epoch 0 :  Verification average Loss: 0.972438, Verification accuracy: 28.359978%, Total Time:64.110160
Epoch 0 :  Verification average Loss: 0.971739, Verification acc

Epoch 0 :  Verification average Loss: 0.970723, Verification accuracy: 32.659469%, Total Time:64.509518
Epoch 0 :  Verification average Loss: 0.971112, Verification accuracy: 32.710730%, Total Time:64.509518
Epoch 0 :  Verification average Loss: 0.971092, Verification accuracy: 32.774805%, Total Time:64.525141
Epoch 0 :  Verification average Loss: 0.971065, Verification accuracy: 32.838881%, Total Time:64.525141
Epoch 0 :  Verification average Loss: 0.971008, Verification accuracy: 32.899753%, Total Time:64.525141
Epoch 0 :  Verification average Loss: 0.971175, Verification accuracy: 32.960625%, Total Time:64.540765
Epoch 0 :  Verification average Loss: 0.970938, Verification accuracy: 33.024701%, Total Time:64.540765
Epoch 0 :  Verification average Loss: 0.970979, Verification accuracy: 33.082370%, Total Time:64.556403
Epoch 0 :  Verification average Loss: 0.970933, Verification accuracy: 33.143242%, Total Time:64.556403
Epoch 0 :  Verification average Loss: 0.971219, Verification acc

Epoch 0 :  Verification average Loss: 0.967654, Verification accuracy: 37.497197%, Total Time:64.959106
Epoch 0 :  Verification average Loss: 0.967917, Verification accuracy: 37.551661%, Total Time:64.959106
Epoch 0 :  Verification average Loss: 0.967940, Verification accuracy: 37.622145%, Total Time:64.959106
Epoch 0 :  Verification average Loss: 0.967829, Verification accuracy: 37.679813%, Total Time:64.974742
Epoch 0 :  Verification average Loss: 0.967879, Verification accuracy: 37.734277%, Total Time:64.974742
Epoch 0 :  Verification average Loss: 0.967570, Verification accuracy: 37.804761%, Total Time:64.990363
Epoch 0 :  Verification average Loss: 0.967813, Verification accuracy: 37.852818%, Total Time:64.990363
Epoch 0 :  Verification average Loss: 0.967882, Verification accuracy: 37.910486%, Total Time:64.990363
Epoch 0 :  Verification average Loss: 0.967784, Verification accuracy: 37.971358%, Total Time:65.005976
Epoch 0 :  Verification average Loss: 0.967492, Verification acc

Epoch 0 :  Verification average Loss: 0.968318, Verification accuracy: 42.222792%, Total Time:65.412201
Epoch 0 :  Verification average Loss: 0.968632, Verification accuracy: 42.280460%, Total Time:65.427837
Epoch 0 :  Verification average Loss: 0.968407, Verification accuracy: 42.357351%, Total Time:65.427837
Epoch 0 :  Verification average Loss: 0.968089, Verification accuracy: 42.424631%, Total Time:65.427837
Epoch 0 :  Verification average Loss: 0.967841, Verification accuracy: 42.491910%, Total Time:65.443449
Epoch 0 :  Verification average Loss: 0.967997, Verification accuracy: 42.543171%, Total Time:65.443449
Epoch 0 :  Verification average Loss: 0.967941, Verification accuracy: 42.610451%, Total Time:65.459085
Epoch 0 :  Verification average Loss: 0.968104, Verification accuracy: 42.671323%, Total Time:65.459085
Epoch 0 :  Verification average Loss: 0.968034, Verification accuracy: 42.735399%, Total Time:65.459085
Epoch 0 :  Verification average Loss: 0.967860, Verification acc

Epoch 0 :  Verification average Loss: 0.967492, Verification accuracy: 47.079742%, Total Time:65.901048
Epoch 0 :  Verification average Loss: 0.967970, Verification accuracy: 47.121392%, Total Time:65.901048
Epoch 0 :  Verification average Loss: 0.968198, Verification accuracy: 47.172652%, Total Time:65.916657
Epoch 0 :  Verification average Loss: 0.967995, Verification accuracy: 47.243136%, Total Time:65.916657
Epoch 0 :  Verification average Loss: 0.968130, Verification accuracy: 47.291193%, Total Time:65.932282
Epoch 0 :  Verification average Loss: 0.968317, Verification accuracy: 47.336046%, Total Time:65.932282
Epoch 0 :  Verification average Loss: 0.968433, Verification accuracy: 47.403326%, Total Time:65.932282
Epoch 0 :  Verification average Loss: 0.968530, Verification accuracy: 47.460994%, Total Time:65.947906
Epoch 0 :  Verification average Loss: 0.968316, Verification accuracy: 47.531477%, Total Time:65.947906
Epoch 0 :  Verification average Loss: 0.968358, Verification acc

Epoch 0 :  Verification average Loss: 0.968347, Verification accuracy: 51.818153%, Total Time:66.340873
Epoch 0 :  Verification average Loss: 0.968681, Verification accuracy: 51.879025%, Total Time:66.356480
Epoch 0 :  Verification average Loss: 0.968873, Verification accuracy: 51.936693%, Total Time:66.356480
Epoch 0 :  Verification average Loss: 0.968965, Verification accuracy: 51.997565%, Total Time:66.372104
Epoch 0 :  Verification average Loss: 0.969266, Verification accuracy: 52.048826%, Total Time:66.372104
Epoch 0 :  Verification average Loss: 0.969272, Verification accuracy: 52.112902%, Total Time:66.372104
Epoch 0 :  Verification average Loss: 0.969428, Verification accuracy: 52.164162%, Total Time:66.387727
Epoch 0 :  Verification average Loss: 0.969254, Verification accuracy: 52.231442%, Total Time:66.387727
Epoch 0 :  Verification average Loss: 0.969141, Verification accuracy: 52.298722%, Total Time:66.387727
Epoch 0 :  Verification average Loss: 0.969422, Verification acc

Epoch 0 :  Verification average Loss: 0.970516, Verification accuracy: 56.454042%, Total Time:66.820263
Epoch 0 :  Verification average Loss: 0.970482, Verification accuracy: 56.514914%, Total Time:66.828258
Epoch 0 :  Verification average Loss: 0.970390, Verification accuracy: 56.569378%, Total Time:66.834255
Epoch 0 :  Verification average Loss: 0.970205, Verification accuracy: 56.633454%, Total Time:66.841251
Epoch 0 :  Verification average Loss: 0.970013, Verification accuracy: 56.703937%, Total Time:66.846248
Epoch 0 :  Verification average Loss: 0.969872, Verification accuracy: 56.764810%, Total Time:66.853250
Epoch 0 :  Verification average Loss: 0.969625, Verification accuracy: 56.841701%, Total Time:66.860241
Epoch 0 :  Verification average Loss: 0.969500, Verification accuracy: 56.908980%, Total Time:66.866236
Epoch 0 :  Verification average Loss: 0.969366, Verification accuracy: 56.979464%, Total Time:66.871233
Epoch 0 :  Verification average Loss: 0.969770, Verification acc

In [33]:
#测试-重新读取文件
PATH = 'model_saved/epoch_0_accuracy_0.596707'
model = torch.load(PATH)
model.to(DEVICE)
total_loss = 0.0
accuracy = 0.0
total_correct = 0.0
total_data_num = len(test_iterator.dataset)
steps = 0.0    
start_time = time.time()
for batch in test_iterator:
    steps += 1
    batch_text = batch.Phrase
    batch_label = batch.Sentiment
    out = model(batch_text)
    loss = criterion(out, batch_label)
    total_loss += loss.item()
    correct = (torch.max(out, dim=1)[1] == batch_label).sum()
    total_correct += correct.item()
print("Test average Loss: %f, Test accuracy: %f，Total time: %f"
  %(total_loss/steps, total_correct/total_data_num, time.time()-start_time) ) 

Test average Loss: 0.964893, Test accuracy: 0.597796，Total time: 4.077614
