Skip to content

Commit

Permalink
Merge pull request #1 from laubonghaudoi/master
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
alexlimh committed Dec 5, 2018
2 parents 1a945f1 + ea9d752 commit 2bc1114
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ __pycache__/
# data directory
data/
ckpt/
runs/
4 changes: 2 additions & 2 deletions CapsNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def __init__(self, opt):
def forward(self, x):
'''
Args:
`x`: [batch_size, 1, 28, 28] A MNIST sample
`x`: [batch_size, 1, 28, 28] MNIST samples
Return:
`v`: [batch_size, 10, 16] CapsNet outputs, 16D rediction vectors of
`v`: [batch_size, 10, 16] CapsNet outputs, 16D prediction vectors of
10 digit capsules
The dimension transformation procedure of an input tensor in each layer:
Expand Down
8 changes: 4 additions & 4 deletions Decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def __init__(self, opt):
def forward(self, v, target):
'''
Args:
v: [batch_size, 10, 16]
target: [batch_size, 10]
`v`: [batch_size, 10, 16]
`target`: [batch_size, 10]
Return:
`reconstruction`: [batch_size, 784]
Expand All @@ -53,8 +53,8 @@ def forward(self, v, target):
assert v_masked.size() == torch.Size([batch_size, 16])

# Forward
v = self.fc1(v_masked)
v = self.fc2(v)
v = F.relu(self.fc1(v_masked))
v = F.relu(self.fc2(v))
reconstruction = torch.sigmoid(self.fc3(v))

assert reconstruction.size() == torch.Size([batch_size, 784])
Expand Down
6 changes: 3 additions & 3 deletions DigitCaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
class DigitCaps(nn.Module):
'''
The `DigitCaps` layer consists of 10 16D capsules. Compared to the traditional
scalar output neurons in fully connected layers(FCN), the `DigitCaps` layer
can be seen as an FCN with 16-dimensional output neurons, where we call
scalar output neurons in fully connected networks(FCN), the `DigitCaps` layer
can be seen as an FCN with ten 16-dimensional output neurons, which we call
these neurons "capsules".
In this layer, we take the input `[1152, 8]` tensor `u` as 1152 [8,] vectors
Expand All @@ -34,7 +34,7 @@ def __init__(self, opt):
The the coupling coefficients `b` [1152, 10] is a temporary variable which
does NOT belong to the layer's parameters. In other words, `b` is not updated
by gradient back-propagations. Instead, we update `b` by Dynamic Routing
in every forward propagation. See docstring of `self.forward` for details.
in every forward propagation. See the docstring of `self.forward` for details.
'''
super(DigitCaps, self).__init__()
self.opt = opt
Expand Down
4 changes: 2 additions & 2 deletions PrimaryCaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class PrimaryCaps(nn.Module):
'''
The `PrimaryCaps` layer consists of 32 capsule units. Each unit takes
the output of the `Conv1` layer, which is a `[256, 20, 20]` feature
tensor (ignoring `batch_size`), and performs a 2D convolution with 8
tensor (omitting `batch_size`), and performs a 2D convolution with 8
output channels, kernel size 9 and stride 2, thus outputing a [8, 6, 6]
tensor. In other words, you can see these 32 capsules as 32 paralleled 2D
convolutional layers. Then we concatenate these 32 capsules' outputs and
Expand All @@ -16,7 +16,7 @@ class PrimaryCaps(nn.Module):
As indicated in Section 4, Page 4 in the paper, *One can see PrimaryCaps
as a Convolution layer with Eq.1 as its block non-linearity.*, outputs of
the `PrimaryCaps` layer are squashed before passing to the next layer.
the `PrimaryCaps` layer are squashed before being passed to the next layer.
Reference: Section 4, Fig. 1
'''
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ As I am busy these days, I might not have time to checkout and fix every issue.

## Requirements

- pytorch 0.2.0
- pytorch 0.4.1
- torchvision
- pytorch-extras (For one-hot vector conversion)
- tensorboard-pytorch
- tqdm

All codes are tested under Python 3.6.3.
All codes are tested under Python 3.6.

## Get Started

Expand Down
26 changes: 14 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
8. `train()` and `test()` in `main.py`
You might find helpful with the paper *Dynamic Routing Between Capsules*
at your hand for referencing.
at your hand for referencing when reading these codes.
"""

import os
Expand Down Expand Up @@ -61,7 +61,8 @@ def train(opt, train_loader, test_loader, model, writer):
assert target.size() == torch.Size([batch_size, 10])

# Use GPU if available
data, target = Variable(data), Variable(target)
with torch.no_grad():
data, target = Variable(data), Variable(target)
if opt.use_cuda & torch.cuda.is_available():
data, target = data.cuda(), target.cuda()

Expand Down Expand Up @@ -113,23 +114,24 @@ def test(opt, test_loader, model, writer, epoch, num_batches):
assert target.size() == torch.Size([batch_size, 10])

# Use GPU if available
data, target = Variable(data, volatile=True), Variable(target)
with torch.no_grad():
data, target = Variable(data), Variable(target)
if opt.use_cuda & torch.cuda.is_available():
data, target = data.cuda(), target.cuda()

# Output predictions
output = model(data)
L, m_loss, r_loss = model.loss(output, target, data)
loss += L
margin_loss += m_loss
recons_loss += r_loss
loss += L.item()
margin_loss += m_loss.item()
recons_loss += r_loss.item()

# Count correct numbers
# norms: [batch_size, 10, 16]
norms = torch.sqrt(torch.sum(output**2, dim=2))
# pred: [batch_size,]
pred = norms.data.max(1, keepdim=True)[1].type(torch.LongTensor)
correct += pred.eq(label.view_as(pred)).cpu().sum()
correct += pred.eq(label.view_as(pred)).cpu().sum().item()

# Visualize reconstructed images of the last batch
recons = model.Decoder(output, target)
Expand All @@ -142,20 +144,20 @@ def test(opt, test_loader, model, writer, epoch, num_batches):
margin_loss /= len(test_loader)
recons_loss /= len(test_loader)
acc = correct / len(test_loader.dataset)
writer.add_scalar('test/loss', loss.item(), step)
writer.add_scalar('test/marginal_loss', margin_loss.item(), step)
writer.add_scalar('test/reconstruction_loss', recons_loss.item(), step)
writer.add_scalar('test/loss', loss, step)
writer.add_scalar('test/marginal_loss', margin_loss, step)
writer.add_scalar('test/reconstruction_loss', recons_loss, step)
writer.add_scalar('test/accuracy', acc, step)

# Print test losses
print('\nTest loss: {:.4f} Marginal loss: {:.4f} Recons loss: {:.4f}'.format(
loss.item(), margin_loss.item(), recons_loss.item()))
loss, margin_loss, recons_loss))
print('Accuracy: {}/{} ({:.0f}%)\n'.format(correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))

# Checkpoint model
torch.save(model, './ckpt/epoch_{}-loss_{:.6f}-acc_{:.6f}.pt'.format(
epoch, loss.item(), acc))
epoch, loss, acc))


if __name__ == "__main__":
Expand Down

0 comments on commit 2bc1114

Please sign in to comment.