Skip to content

Commit

Permalink
multi-gpu enhance
Browse files: browse the repository at this point in the history
  • Loading branch information
meijieru committed Jun 29, 2017
1 parent 22feea5 commit f9dd5ce
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 41 deletions.
11 changes: 5 additions & 6 deletions crnn_main.py
@@ -70,13 +70,10 @@
test_dataset = dataset.lmdbDataset(
root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))

ngpu = int(opt.ngpu)
nh = int(opt.nh)
alphabet = opt.alphabet
nclass = len(alphabet) + 1
nclass = len(opt.alphabet) + 1
nc = 1

converter = utils.strLabelConverter(alphabet)
converter = utils.strLabelConverter(opt.alphabet)
criterion = CTCLoss()


@@ -89,7 +86,8 @@ def weights_init(m):
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)

crnn = crnn.CRNN(opt.imgH, nc, nclass, nh, ngpu)

crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.crnn != '':
print('loading pretrained model from %s' % opt.crnn)
@@ -102,6 +100,7 @@ def weights_init(m):

if opt.cuda:
crnn.cuda()
crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
image = image.cuda()
criterion = criterion.cuda()

Expand Down
2 changes: 1 addition & 1 deletion demo.py
@@ -11,7 +11,7 @@
img_path = './data/demo.png'
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'

model = crnn.CRNN(32, 1, 37, 256, 1).cuda()
model = crnn.CRNN(32, 1, 37, 256).cuda()
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))

Expand Down
34 changes: 13 additions & 21 deletions models/crnn.py
@@ -1,34 +1,29 @@
import torch.nn as nn
import utils


class BidirectionalLSTM(nn.Module):

def __init__(self, nIn, nHidden, nOut, ngpu):
def __init__(self, nIn, nHidden, nOut):
super(BidirectionalLSTM, self).__init__()
self.ngpu = ngpu

self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
self.embedding = nn.Linear(nHidden * 2, nOut)

def forward(self, input):
recurrent, _ = utils.data_parallel(
self.rnn, input, self.ngpu) # [T, b, h * 2]

recurrent, _ = self.rnn(input)
T, b, h = recurrent.size()
t_rec = recurrent.view(T * b, h)
output = utils.data_parallel(
self.embedding, t_rec, self.ngpu) # [T * b, nOut]

output = self.embedding(t_rec) # [T * b, nOut]
output = output.view(T, b, -1)

return output


class CRNN(nn.Module):

def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False):
def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
super(CRNN, self).__init__()
self.ngpu = ngpu
assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

ks = [3, 3, 3, 3, 3, 3, 2]
@@ -57,31 +52,28 @@ def convRelu(i, batchNormalization=False):
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
convRelu(2, True)
convRelu(3)
cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2),
(2, 1),
(0, 1))) # 256x4x16
cnn.add_module('pooling{0}'.format(2),
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
convRelu(4, True)
convRelu(5)
cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2),
(2, 1),
(0, 1))) # 512x2x16
cnn.add_module('pooling{0}'.format(3),
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
convRelu(6, True) # 512x1x16

self.cnn = cnn
self.rnn = nn.Sequential(
BidirectionalLSTM(512, nh, nh, ngpu),
BidirectionalLSTM(nh, nh, nclass, ngpu)
)
BidirectionalLSTM(512, nh, nh),
BidirectionalLSTM(nh, nh, nclass))

def forward(self, input):
# conv features
conv = utils.data_parallel(self.cnn, input, self.ngpu)
conv = self.cnn(input)
b, c, h, w = conv.size()
assert h == 1, "the height of conv must be 1"
conv = conv.squeeze(2)
conv = conv.permute(2, 0, 1) # [w, b, c]

# rnn features
output = utils.data_parallel(self.rnn, conv, self.ngpu)
output = self.rnn(conv)

return output
13 changes: 0 additions & 13 deletions models/utils.py

This file was deleted.

0 comments on commit f9dd5ce

Please sign in to comment.