forked from zfang/cnn-text-classification-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
120 lines (102 loc) · 4.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import sys
import torch
import torch.autograd as autograd
import torch.nn.functional as F
from nltk import word_tokenize
import copy
import torchnet.meter as meter
from util import print_time
def train(train_iter, dev_iter, model, args):
if args.cuda:
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
steps = 0
model.train()
best_dev_accuracy = 0
best_model = copy.deepcopy(model)
best_epoch = 1
epochTimer = meter.TimeMeter('s')
stepTimer = meter.TimeMeter('s')
for epoch in range(1, args.epochs+1):
epochTimer.reset()
for batch in train_iter:
feature, target = batch.text, batch.label
feature.data.t_(), target.data.sub_(1) # batch first, index align
if args.cuda:
feature, target = feature.cuda(), target.cuda()
optimizer.zero_grad()
logit = model(feature)
loss = F.cross_entropy(logit, target)
loss.backward()
optimizer.step()
model.renorm_fc(args.max_norm)
predictions = torch.max(logit, 1)[1].view(target.size())
steps += 1
if steps % args.log_interval == 0:
corrects = (predictions.data == target.data).sum()
accuracy = 100.0 * corrects/batch.batch_size
sys.stdout.write(
'\rBatch[{}] - loss: {:.6f} acc: {:.4f}%({}/{}), time: {:.2f}{}'.format(
steps, loss.data[0], accuracy, corrects, batch.batch_size, stepTimer.value(), stepTimer.unit))
stepTimer.reset()
if args.save_interval != 0 and steps % args.save_interval == 0:
if not os.path.isdir(args.save_dir): os.makedirs(args.save_dir)
save_prefix = os.path.join(args.save_dir, 'snapshot')
save_path = '{}_steps{}.pt'.format(save_prefix, steps)
torch.save(model, save_path)
print()
print_time('epoch {}'.format(epoch), epochTimer)
dev_accuracy = eval(dev_iter, model, args, print_info=True)
if dev_accuracy > best_dev_accuracy:
best_dev_accuracy = dev_accuracy
best_model = copy.deepcopy(model)
best_epoch = epoch
if not os.path.isdir(args.save_dir): os.makedirs(args.save_dir)
torch.save(best_model, os.path.join(args.save_dir, 'model.pt'))
print("Best epoch:", best_epoch)
print("Best dev accuracy:", best_dev_accuracy)
def eval(data_iter, model, args, print_info=False):
model.eval()
corrects, avg_loss = 0, 0
confusionMeter = meter.ConfusionMeter(len(model.label_itos))
for batch in data_iter:
feature, target = batch.text, batch.label
feature.data.t_(), target.data.sub_(1) # batch first, index align
if args.cuda:
feature, target = feature.cuda(), target.cuda()
logit = model(feature)
loss = F.cross_entropy(logit, target, size_average=False)
predictions = torch.max(logit, 1)[1].view(target.size())
avg_loss += loss.data[0]
corrects += (predictions.data == target.data).sum()
confusionMeter.add(predictions.data, target.data)
size = len(data_iter.dataset)
avg_loss = avg_loss/size
accuracy = 100.0 * corrects/size
model.train()
if print_info:
print('Evaluation - loss: {:.6f} acc: {:.4f}%({}/{})'.format(avg_loss,
accuracy,
corrects,
size))
print("Confusion Matrix\n", confusionMeter.value())
print()
return accuracy
def predict(text, model, args):
assert isinstance(text, str)
model.eval()
text = word_tokenize(text)
text = [[model.vocab_stoi[x] for x in text]]
x = model.tensor_type(text)
x = autograd.Variable(x, volatile=True)
if args.cuda:
x = x.cuda()
#print(x)
output = model(x)
if args.debug:
output = F.softmax(output)
for i, v in enumerate(output[0]):
print("{:8} {:.4f}".format(model.label_itos[i+1], v.data.cpu().numpy()[0]))
_, predicted = torch.max(output, 1)
return model.label_itos[predicted.data[0]+1]