Make Kim CNN ONNX-exportable (#136)

* Kim CNN - only set embedding for corresponding mode

* Kim CNN ONNX Export

* Specify dummy ONNX input size from command line
tuzhucheng committed Aug 4, 2018
1 parent 82bf90f commit 8a00f9cdcdee682a630693aef8cfc1b956bf4f07
Showing with 37 additions and 19 deletions.
  1. +1 −0 .gitignore
  2. +1 −1 common/evaluators/sst_evaluator.py
  3. +1 −1 common/trainers/sst_trainer.py
  4. +16 −9 kim_cnn/__main__.py
  5. +4 −1 kim_cnn/args.py
  6. +14 −7 kim_cnn/model.py
.gitignore
@@ -12,3 +12,4 @@ text/
 kim_cnn/data
 .results
 .qrel
+*.onnx
common/evaluators/sst_evaluator.py
@@ -13,7 +13,7 @@ def get_scores(self):
         total_loss = 0

         for batch_idx, batch in enumerate(self.data_loader):
-            scores = self.model(batch)
+            scores = self.model(batch.text)
             n_dev_correct += (
                 torch.max(scores, 1)[1].view(batch.label.size()).data == batch.label.data).sum().item()
             total_loss += F.cross_entropy(scores, batch.label, size_average=False).item()
common/trainers/sst_trainer.py
@@ -28,7 +28,7 @@ def train_epoch(self, epoch):
             self.iterations += 1
             self.model.train()
             self.optimizer.zero_grad()
-            scores = self.model(batch)
+            scores = self.model(batch.text)
             n_correct += (torch.max(scores, 1)[1].view(batch.label.size()).data == batch.label.data).sum().item()
             n_total += batch.batch_size
             train_acc = 100. * n_correct / n_total
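Both the evaluator and the trainer now pass batch.text, a plain LongTensor of token indices, instead of the torchtext Batch wrapper. This is what enables the ONNX export added below: torch.onnx.export traces the model with tensor inputs, so forward must accept a tensor directly. A minimal sketch of the pattern, with a hypothetical module standing in for KimCNN:

import torch
import torch.nn as nn

# Hypothetical stand-in whose forward takes a (batch, sent_len) LongTensor,
# mirroring KimCNN's signature after this change.
class TinyTextClassifier(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=8, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x: (batch, sent_len) token indices -> (batch, num_classes) scores
        return self.fc(self.embed(x).mean(dim=1))

model = TinyTextClassifier()
tokens = torch.zeros(4, 16, dtype=torch.long)  # plays the role of batch.text
scores = model(tokens)                         # same call shape as self.model(batch.text)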
kim_cnn/__main__.py
@@ -4,6 +4,7 @@

 import numpy as np
 import torch
+import torch.onnx

 from common.evaluation import EvaluatorFactory
 from common.train import TrainerFactory
@@ -60,11 +61,11 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
 if not args.cuda:
     args.gpu = -1
 if torch.cuda.is_available() and args.cuda:
-    print("Note: You are using GPU for training")
+    print('Note: You are using GPU for training')
     torch.cuda.set_device(args.gpu)
     torch.cuda.manual_seed(args.seed)
 if torch.cuda.is_available() and not args.cuda:
-    print("Warning: You have Cuda but not use it. You are using CPU for training.")
+    print('Warning: You have Cuda but not use it. You are using CPU for training.')
 np.random.seed(args.seed)
 random.seed(args.seed)
 logger = get_logger()
@@ -83,12 +84,12 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
 config.target_class = train_iter.dataset.NUM_CLASSES
 config.words_num = len(train_iter.dataset.TEXT_FIELD.vocab)

-print("Dataset {} Mode {}".format(args.dataset, args.mode))
-print("VOCAB num",len(train_iter.dataset.TEXT_FIELD.vocab))
-print("LABEL.target_class:", train_iter.dataset.NUM_CLASSES)
-print("Train instance", len(train_iter.dataset))
-print("Dev instance", len(dev_iter.dataset))
-print("Test instance", len(test_iter.dataset))
+print('Dataset {} Mode {}'.format(args.dataset, args.mode))
+print('VOCAB num',len(train_iter.dataset.TEXT_FIELD.vocab))
+print('LABEL.target_class:', train_iter.dataset.NUM_CLASSES)
+print('Train instance', len(train_iter.dataset))
+print('Dev instance', len(dev_iter.dataset))
+print('Test instance', len(test_iter.dataset))

 if args.resume_snapshot:
     if args.cuda:
@@ -99,7 +100,7 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
     model = KimCNN(config)
     if args.cuda:
         model.cuda()
-        print("Shift model to GPU")
+        print('Shift model to GPU')

 parameter = filter(lambda p: p.requires_grad, model.parameters())
 optimizer = torch.optim.Adadelta(parameter, lr=args.lr, weight_decay=args.weight_decay)
@@ -143,3 +144,9 @@ def evaluate_dataset(split_name, dataset_cls, model, embedding, loader, batch_si
 else:
     raise ValueError('Unrecognized dataset')

+if args.onnx:
+    device = torch.device('cuda') if torch.cuda.is_available() and args.cuda else torch.device('cpu')
+    dummy_input = torch.zeros(args.onnx_batch_size, args.onnx_sent_len, dtype=torch.long, device=device)
+    onnx_filename = 'kimcnn_{}.onnx'.format(args.mode)
+    torch.onnx.export(model, dummy_input, onnx_filename)
+    print('Exported model in ONNX format as {}'.format(onnx_filename))
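Because the export traces a fixed-shape dummy input, the resulting graph expects inputs of exactly (onnx_batch_size, onnx_sent_len). A sketch of consuming the exported file, assuming the onnxruntime package (not part of this commit) and the static-mode filename:

import numpy as np
import onnxruntime as ort  # assumption: onnxruntime is installed separately

# Filename follows the 'kimcnn_{mode}.onnx' pattern from __main__.py above.
session = ort.InferenceSession('kimcnn_static.onnx')

# The dummy input was int64 of shape (onnx_batch_size, onnx_sent_len),
# so inference inputs must match (defaults: 1024 x 32).
tokens = np.zeros((1024, 32), dtype=np.int64)
input_name = session.get_inputs()[0].name
(scores,) = session.run(None, {input_name: tokens})
print(scores.shape)  # (1024, target_class)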
kim_cnn/args.py
@@ -29,7 +29,10 @@ def get_args():
                         default=os.path.join(os.pardir, 'Castor-data', 'embeddings', 'word2vec'))
     parser.add_argument('--word_vectors_file', help='word vectors filename', default='GoogleNews-vectors-negative300.txt')
     parser.add_argument('--trained_model', type=str, default="")
-    parser.add_argument('--weight_decay',type=float, default=0)
+    parser.add_argument('--weight_decay', type=float, default=0)
+    parser.add_argument('--onnx', action='store_true', default=False, help='Export model in ONNX format')
+    parser.add_argument('--onnx_batch_size', type=int, default=1024, help='Batch size for ONNX export')
+    parser.add_argument('--onnx_sent_len', type=int, default=32, help='Sentence length for ONNX export')

     args = parser.parse_args()
     return args
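With these flags in place, export would be triggered by adding --onnx to an ordinary run, along the lines of python -m kim_cnn --mode static --onnx --onnx_batch_size 1024 --onnx_sent_len 32 (a hypothetical invocation; the dataset and embedding flags it also needs depend on the rest of args.py).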
kim_cnn/model.py
@@ -14,14 +14,22 @@ def __init__(self, config):
         words_dim = config.words_dim
         self.mode = config.mode
         Ks = 3 # There are three conv nets here
-        if config.mode == 'multichannel':
+
+        input_channel = 1
+        if config.mode == 'rand':
+            rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
+            self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
+        elif config.mode == 'static':
+            self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
+        elif config.mode == 'non-static':
+            self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)
+        elif config.mode == 'multichannel':
+            self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
+            self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)
             input_channel = 2
         else:
-            input_channel = 1
-        rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
-        self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
-        self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
-        self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)
+            print("Unsupported Mode")
+            exit()

         self.conv1 = nn.Conv2d(input_channel, output_channel, (3, words_dim), padding=(2,0))
         self.conv2 = nn.Conv2d(input_channel, output_channel, (4, words_dim), padding=(3,0))
@@ -31,7 +39,6 @@ def __init__(self, config):
         self.fc1 = nn.Linear(Ks * output_channel, target_class)

     def forward(self, x):
-        x = x.text
         if self.mode == 'rand':
             word_input = self.embed(x) # (batch, sent_len, embed_dim)
             x = word_input.unsqueeze(1) # (batch, channel_input, sent_len, embed_dim)

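After these changes, the model builds only the embedding table(s) the chosen mode actually uses, and forward receives the token tensor directly instead of unwrapping it from a Batch. A minimal sketch of the multichannel case, with hypothetical sizes standing in for the pretrained dataset.TEXT_FIELD.vocab.vectors:

import torch
import torch.nn as nn

# Hypothetical sizes; the real tables are loaded from pretrained vectors.
vocab_size, embed_dim = 100, 300
static_embed = nn.Embedding(vocab_size, embed_dim)      # frozen in the real model
non_static_embed = nn.Embedding(vocab_size, embed_dim)  # fine-tuned in the real model

x = torch.zeros(4, 16, dtype=torch.long)  # (batch, sent_len) token indices
# Both lookups over the same indices are stacked as the two conv input channels.
channels = torch.stack([non_static_embed(x), static_embed(x)], dim=1)
print(channels.shape)  # torch.Size([4, 2, 16, 300])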