In [None]:
!git clone https://github.com/Riroaki/CapsNet.git

fatal: destination path 'CapsNet' already exists and is not an empty directory.


In [None]:
import torch
from torch import nn

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def squash(x, dim=-1):
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * x / (squared_norm.sqrt() + 1e-8)

In [None]:
class Residual_Block(nn.Module):

    def __init__(self, in_dim, mid_dim, out_dim):
      super(Residual_Block,self).__init__()
        # Residual Block
      self.residual_block = nn.Sequential(
                nn.Conv2d(in_dim, mid_dim, kernel_size=3, padding=1),
                nn.ReLU,
                nn.Conv2d(mid_dim, out_dim, kernel_size=3, padding=1),
            )
      self.relu = nn.ReLU()

    def forward(self, x):
       out = self. residual_block(x)  # F(x)
       out = out + x  # F(x) + x
       out = self.relu(out)
       return out

In [None]:
class PrimaryCaps(nn.Module):
    """Primary capsule layer."""

    def __init__(self, num_conv_units, in_channels, out_channels, kernel_size, stride):
        super(PrimaryCaps, self).__init__()

        # Each conv unit stands for a single capsule.
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels * num_conv_units,
                              kernel_size=kernel_size,
                              stride=stride)
        self.out_channels = out_channels

    def forward(self, x):
        # Shape of x: (batch_size, in_channels, height, weight)
        # Shape of out: out_capsules * (batch_size, out_channels, height, weight)
        out = self.conv(x)
        # Flatten out: (batch_size, out_capsules * height * weight, out_channels)
        batch_size = out.shape[0]
        return squash(out.contiguous().view(batch_size, -1, self.out_channels), dim=-1)

In [None]:
class DigitCaps(nn.Module):
    """Digit capsule layer."""

    def __init__(self, in_dim, in_caps, out_caps, out_dim, num_routing):
        """
        Initialize the layer.
        Args:
            in_dim: 		Dimensionality of each capsule vector.
            in_caps: 		Number of input capsules if digits layer.
            out_caps: 		Number of capsules in the capsule layer
            out_dim: 		Dimensionality, of the output capsule vector.
            num_routing:	Number of iterations during routing algorithm
        """
        super(DigitCaps, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dim = out_dim
        self.num_routing = num_routing
        self.device = device
        self.W = nn.Parameter(0.01 * torch.randn(1, out_caps, in_caps, out_dim, in_dim),
                              requires_grad=True)

    def forward(self, x):
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        # W @ x =
        # (1, out_caps, in_caps, out_dim, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, out_caps, in_caps, out_dims, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, out_caps, in_caps, out_dim)
        u_hat = u_hat.squeeze(-1)
        # detach u_hat during routing iterations to prevent gradients from flowing
        temp_u_hat = u_hat.detach()

        b = torch.zeros(batch_size, self.out_caps, self.in_caps, 1).to(self.device)
        for route_iter in range(self.num_routing - 1):
            # (batch_size, out_caps, in_caps, 1) -> Softmax along out_caps
            c = b.softmax(dim=1)

            # element-wise multiplication
            # (batch_size, out_caps, in_caps, 1) * (batch_size, in_caps, out_caps, out_dim) ->
            # (batch_size, out_caps, in_caps, out_dim) sum across in_caps ->
            # (batch_size, out_caps, out_dim)
            s = (c * temp_u_hat).sum(dim=2)
            # apply "squashing" non-linearity along out_dim
            v = squash(s)
            # dot product agreement between the current output vj and the prediction uj|i
            # (batch_size, out_caps, in_caps, out_dim) @ (batch_size, out_caps, out_dim, 1)
            # -> (batch_size, out_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv

        # last iteration is done on the original u_hat, without the routing weights update
        c = b.softmax(dim=1)
        s = (c * u_hat).sum(dim=2)
        # apply "squashing" non-linearity along out_dim
        v = squash(s)

        return v

In [None]:
class CapsNet(nn.Module):
    """Basic implementation of capsule network layer."""

    def __init__(self):
        super(CapsNet, self).__init__()

        # Conv2d layer1
        self.conv1 = nn.Conv2d(1, 32, 9)
        self.relu1 = nn.ReLU(inplace=True)

        # Batch Normalization1
        self.bn1 = nn.BatchNorm2d(32)

        # Conv2d layer2
        self.conv2 = nn.Conv2d(32, 64, 9)
        self.sigmoid2 = nn.Sigmoid()

        # Batch Normalization2
        self.bn2 = nn.BatchNorm2d(64)

        # Conv2d layer3
        self.conv3 = nn.Conv2d(64, 256, 9)
        self.relu3 = nn.ReLU(inplace=True)

        # Primary capsule
        self.primary_caps = PrimaryCaps(num_conv_units=32,
                                        in_channels=256,
                                        out_channels=8,
                                        kernel_size=9,
                                        stride=2)

        # Digit capsule
        self.digit_caps = DigitCaps(in_dim=8,
                                    in_caps=32 * 6 * 6,
                                    out_caps=7,
                                    out_dim=16,
                                    num_routing=3)

        # Reconstruction layer
        self.decoder = nn.Sequential(
            nn.Linear(16 * 7, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1600),
            nn.ReLU(inplace=True),
            nn.Linear(1600, 2304),
            nn.ReLU(inplace=True),
            nn.Linear(2304, 1936),
            nn.Sigmoid())



    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = self.bn1(out)
        out = self.sigmoid2(self.conv2(out))
        out = self.bn2(out)
        out = self.relu3(self.conv3(out))
        out = self.primary_caps(out)
        out = self.digit_caps(out)

        # Shape of logits: (batch_size, out_capsules)
        logits = torch.norm(out, dim=-1)
        pred = torch.eye(7).to(device).index_select(dim=0, index=torch.argmax(logits, dim=1))

        # Reconstruction
        batch_size = out.shape[0]
        reconstruction = self.decoder((out * pred.unsqueeze(2)).contiguous().view(batch_size, -1))

        return logits, reconstruction

In [None]:
class CapsuleLoss(nn.Module):
    """Combine margin loss & reconstruction loss of capsule network."""

    def __init__(self, upper_bound=0.9, lower_bound=0.1, lmda=0.5):
        super(CapsuleLoss, self).__init__()
        self.upper = upper_bound
        self.lower = lower_bound
        self.lmda = lmda
        self.reconstruction_loss_scalar = 5e-3
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, logits, reconstructions):
        # Shape of left / right / labels: (batch_size, num_classes)
        left = (self.upper - logits).relu() ** 2  # True negative
        right = (logits - self.lower).relu() ** 2  # False positive
        margin_loss = torch.sum(labels * left) + self.lmda * torch.sum((1 - labels) * right)

        # Reconstruction loss
        reconstruction_loss = self.mse(reconstructions.contiguous().view(images.shape), images)
        self.reconstruction_loss_scalar *= 1.001

        # Combine two losses
        return margin_loss + self.reconstruction_loss_scalar * reconstruction_loss

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torchvision
from torch import optim
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch.optim import Adam

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def main():
    # Load model
    torch.autograd.set_detect_anomaly(True)
    model = CapsNet().to(device)
    criterion = CapsuleLoss()
    optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

    # Load data
    transform = transforms.Compose([
        # shift by 2 pixels in either direction with zero padding.
        transforms.Grayscale(num_output_channels=1),
        transforms.RandomCrop(44, padding=1),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    BATCH_SIZE = 128
    trainset = torchvision.datasets.ImageFolder(root="/content/drive/MyDrive/wiset/train", transform=transform)
    train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)
    testset = torchvision.datasets.ImageFolder(root="/content/drive/MyDrive/wiset/test", transform=transform)
    test_loader = DataLoader(testset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)

    # Train
    EPOCHES = 10
    model.train()
    for ep in range(EPOCHES):
        batch_id = 10
        correct, total, total_loss = 0, 0, 0.
        for images, labels in train_loader:
            optimizer.zero_grad()
            images = images.to(device)
            labels = torch.eye(7).index_select(dim=0, index=labels).to(device)
            logits, reconstruction = model(images)

            # Compute loss & accuracy
            loss = criterion(images, labels, logits, reconstruction)
            correct += torch.sum(
                torch.argmax(logits, dim=1) == torch.argmax(labels, dim=1)).item()
            total += len(labels)
            accuracy = correct / total
            total_loss += loss
            loss.backward()
            optimizer.step()
            batch_id += 1
        scheduler.step(ep)
        print('Total loss for epoch {}: {}, Accuracy: {}'.format(ep+1, total_loss, accuracy))

if __name__ == '__main__':
    main()

Total loss for epoch 1: 237858.40625, Accuracy: 0.25679886685552406
Total loss for epoch 2: 288460.65625, Accuracy: 0.3435198300283286
Total loss for epoch 3: 356255.96875, Accuracy: 0.3896954674220963
Total loss for epoch 4: 441090.75, Accuracy: 0.4244334277620397
Total loss for epoch 5: 547184.0625, Accuracy: 0.44975212464589237
Total loss for epoch 6: 679452.5625, Accuracy: 0.47694759206798865
Total loss for epoch 7: 844822.375, Accuracy: 0.4992209631728045
Total loss for epoch 8: 1050425.75, Accuracy: 0.524185552407932
Total loss for epoch 9: 1307027.875, Accuracy: 0.5449362606232294
Total loss for epoch 10: 1627966.875, Accuracy: 0.5717776203966005
