In [None]:
#INSTALLS
!pip install torch
!pip install pytorch-lightning
!pip install torch_optimizer
!pip install easydict
!pip install torchsummary
!pip install wandb



In [None]:
import numpy as np
import torch.nn as nn
import torch
import pytorch_lightning as pl
from easydict import EasyDict
import torch_optimizer as optim
from torchsummary import summary
import torch.nn.functional as F
from pytorch_lightning.loggers import WandbLogger
import wandb

# Experiment hyperparameters.  Of these, the cells in this notebook read
# `batch_size` and, under `fcae`, `lr`, `lr_decay`, `weight_decay` and
# `decoder.lr_coeff`; the remaining entries are carried along unchanged.
args = EasyDict(
    batch_size=500,
    num_epochs=300,
    model="fcae",
    dataset="mnist",
    fcae=dict(
        # capsule layout
        num_caps=16,
        caps_dim=6,
        feat_dim=16,

        # optimisation
        optimizer="radam",
        lr=0.01,
        lr_decay=0.998,
        lr_restart_interval=4000,
        lr_scheduler="cosrestarts",

        # regularisation and loss weights
        weight_decay=0.0,
        loss_ll_coeff=1.0,
        loss_temp_l1_coeff=0.01,
        loss_mse_coeff=0.0,

        encoder=dict(
            noise_scale=4.0,
            inverse_space_transform=True,
        ),
        decoder=dict(
            alpha_channel=False,
            template_nonlin="sigmoid",
            color_nonlin="sigmoid",
            output_size=[40, 40],
            lr_coeff=1.0,
        ),
    ),
)

# Weights & Biases settings: run name, project, owning team, whether to upload
# (otherwise the logger runs offline) and the frequency handed to `watch`.
logargs = EasyDict(
    log=dict(
        run_name=None,
        project="StackedCapsuleAutoEncoders",
        team="mlatberkeley",
        upload=True,
        frequency=2,
    )
)

# The logger is configured entirely from `logargs`; the same object is also
# passed along as the run's config.
logger = WandbLogger(
    name=logargs.log.run_name,
    project=logargs.log.project,
    entity=logargs.log.team,
    offline=not logargs.log.upload,
    config=logargs,
)

In [None]:
class CapsuleEncoder(nn.Module):
    """Convolutional encoder that maps a single-channel image to a latent vector.

    Two ``Conv2d(kernel_size=5) -> MaxPool2d(kernel_size=3) -> ReLU`` stages
    are followed by a flatten and a two-layer fully connected head.  The first
    linear layer expects 256 flattened features, which is what the
    convolutional stack produces for 40x40 inputs
    (40 -> 36 -> 12 -> 8 -> 2 spatially, with 64 channels).

    Args:
        num_caps: stored on the instance; not used by the layers built here.
        caps_dim: stored on the instance; not used by the layers built here.
        output_channels: accepted for backward compatibility; currently unused
            (the channel widths are fixed at 32 and 64).
        pool_dim: accepted for backward compatibility; currently unused.
        latent: size of the output vector.
    """

    def __init__(self, num_caps, caps_dim, output_channels=(32, 64), pool_dim=2, latent=8):
        super(CapsuleEncoder, self).__init__()
        self.num_caps = num_caps
        self.caps_dim = caps_dim
        self.cnn = torch.nn.Sequential()
        self.fc = torch.nn.Sequential()

        # Sub-module names are part of the state_dict keys; keep them stable.
        self.cnn.add_module("conv_0", torch.nn.Conv2d(1, 32, kernel_size=5))
        self.cnn.add_module("maxpool_0", torch.nn.MaxPool2d(kernel_size=3))
        self.cnn.add_module("relu_0", torch.nn.ReLU())

        self.cnn.add_module("conv_2", torch.nn.Conv2d(32, 64, kernel_size=5))
        self.cnn.add_module("maxpool_2", torch.nn.MaxPool2d(kernel_size=3))
        self.cnn.add_module("relu_2", torch.nn.ReLU())

        self.cnn.add_module("flatten_0", torch.nn.Flatten(1))

        self.fc.add_module("fc_0", torch.nn.Linear(256, 32))
        self.fc.add_module("tanh_0", torch.nn.Tanh())
        self.fc.add_module("fc_1", torch.nn.Linear(32, latent, bias=True))

    def forward(self, x):
        """Encode ``x`` of shape (N, 1, H, W) into a tensor of shape (N, latent).

        The input is moved to the device the weights currently live on, so the
        module works on CPU as well as GPU and never relocates its own weights
        in the middle of a forward pass.
        """
        device = next(self.parameters()).device
        x = x.to(device)
        return self.fc(self.cnn(x))


class ImplicitDecoder(nn.Module):
    """MLP that predicts one value per query point, conditioned on a latent code.

    Each query point is concatenated with the latent vector of its batch item
    and passed through six linear layers whose widths shrink from
    ``gf_dim * 16`` down to a single output.  The first five layers are
    followed by a leaky ReLU (slope 0.02).

    Args:
        z_dim: size of the latent vector.
        gf_dim: base width multiplier for the hidden layers.
        point_dim: number of coordinates per query point.
    """

    def __init__(self, z_dim, gf_dim=2, point_dim=2):
        super(ImplicitDecoder, self).__init__()
        self.z_dim = z_dim
        self.point_dim = point_dim
        self.gf_dim = gf_dim
        # Attribute names are part of the state_dict keys; keep them stable.
        self.linear_1 = nn.Linear(self.z_dim + self.point_dim, self.gf_dim * 16, bias=True)
        self.linear_2 = nn.Linear(self.gf_dim * 16, self.gf_dim * 8, bias=True)
        self.linear_3 = nn.Linear(self.gf_dim * 8, self.gf_dim * 4, bias=True)
        self.linear_4 = nn.Linear(self.gf_dim * 4, self.gf_dim * 2, bias=True)
        self.linear_5 = nn.Linear(self.gf_dim * 2, self.gf_dim * 1, bias=True)
        self.linear_6 = nn.Linear(self.gf_dim * 1, 1, bias=True)

        # Small-variance Gaussian weights and zero biases for every layer.
        for layer in self._hidden_layers() + (self.linear_6,):
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)
            nn.init.constant_(layer.bias, 0)

    def _hidden_layers(self):
        """Return the five layers that are followed by a leaky ReLU, in order."""
        return (self.linear_1, self.linear_2, self.linear_3, self.linear_4, self.linear_5)

    def forward(self, points, z, is_training=False):
        """Evaluate the decoder at every query point.

        Args:
            points: tensor of shape (N, P, point_dim).
            z: tensor of shape (N, z_dim).
            is_training: accepted for backward compatibility; unused.

        Returns:
            Tensor of shape (N, P, 1).
        """
        # Run wherever the weights are; inputs follow the module, not vice versa.
        device = self.linear_1.weight.device
        points = points.to(device)
        # expand() broadcasts the latent over the P points without copying it.
        zs = z.to(device).unsqueeze(1).expand(-1, points.shape[1], -1)
        hidden = torch.cat([points, zs], 2)
        for layer in self._hidden_layers():
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.02, inplace=True)
        out = self.linear_6(hidden)
        # Leaky clamp: identity on [0, 1], slope 0.01 outside that interval.
        return torch.max(torch.min(out, out * 0.01 + 0.99), out * 0.01)

In [None]:
def to_wandb_im(x, **kwargs):  # TODO: move to utils
    """Convert a tensor into a ``wandb.Image``.

    A 3-D tensor is treated as (C, H, W) and rearranged to (H, W, C).  If the
    last axis then has exactly two entries they are read as (value, alpha) and
    expanded to four channels by repeating the value three times.  Extra
    keyword arguments are forwarded to ``wandb.Image``.
    """
    image = x.detach()

    # Torch uses C, H, W; move channels last.
    if image.dim() == 3:
        image = image.permute(1, 2, 0)

    # Two channels mean (val, alpha): convert to RGBA.
    if image.shape[-1] == 2:
        value, alpha = image[..., 0], image[..., 1]
        image = torch.stack((value, value, value, alpha), dim=-1)

    array = image.cpu().numpy()
    return wandb.Image(array, **kwargs)

def rec_to_wandb_im(x, **kwargs):  # TODO: move to utils
    """Convert a reconstruction into a ``wandb.Image``.

    Currently a plain pass-through to ``to_wandb_im`` with the same arguments.
    """
    # TODO: unpack reconstruction template components
    image = to_wandb_im(x, **kwargs)
    return image

class FCAE(pl.LightningModule):
    """Auto-encoder pairing a ``CapsuleEncoder`` with an ``ImplicitDecoder``.

    The encoder turns each image into a latent vector; the decoder is then
    queried at one (x, y) coordinate per pixel to rebuild the image.  Training
    minimises the mean squared error between input and reconstruction.

    Args:
        args: configuration; ``args.fcae.lr``, ``args.fcae.lr_decay``,
            ``args.fcae.weight_decay`` and ``args.fcae.decoder.lr_coeff`` are read.
        num_caps: forwarded to the encoder and stored on the instance.
        input_dims, latent_dims, transform_dims: summed into ``caps_dims``;
            ``latent_dims`` is also the size of the latent vector shared by
            encoder and decoder.
        watch: when true, registers the decoder with the module-level W&B
            ``logger`` (uses the module-level ``logargs`` for the frequency).
    """

    # Example images (and, in training, the loss) are sent to W&B on every
    # batch whose index is a multiple of this value.
    IMAGE_LOG_INTERVAL = 250

    def __init__(self, args: EasyDict, num_caps=16, input_dims=2, latent_dims=8, transform_dims=4, watch=False):
        super(FCAE, self).__init__()
        self.caps_dims = input_dims + latent_dims + transform_dims
        # The encoder's output width must match what the decoder is built for.
        self.encoder = CapsuleEncoder(num_caps, self.caps_dims, latent=latent_dims)
        self.decoder = ImplicitDecoder(latent_dims)
        device = self._compute_device()
        self.encoder.to(device)
        self.decoder.to(device)
        if watch:
            logger.watch(self.decoder, log='all', log_freq=logargs.log.frequency)
        self.num_caps = num_caps
        self.input_dims = input_dims
        self.latent_dims = latent_dims
        self.transform_dims = transform_dims
        # Misspelled name kept as an alias for code that already reads it.
        self.tranform_dims = transform_dims
        self.mse = nn.MSELoss()
        self.lr = args.fcae.lr
        self.args = args
        self.lr_decay = args.fcae.lr_decay
        self.weight_decay = args.fcae.weight_decay

    @staticmethod
    def _compute_device():
        """Return the CUDA device when one is usable, otherwise the CPU.

        This keeps the notebook's existing preference for the GPU while no
        longer crashing on machines without CUDA.
        """
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _sampling_grid(h, w):
        """Return an (h*w, 2) float tensor of (x, y) coordinates in [0, 1].

        Rows are ordered like the pixels of an (h, w) image in row-major
        order: x varies along the width, y along the height.
        """
        xs, ys = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
        return torch.FloatTensor(np.stack((xs, ys), axis=2)).view(h * w, 2)

    def forward(self, im):
        """Reconstruct ``im`` of shape (N, C, H, W); returns shape (N, 1, H, W)."""
        batch_size, h, w = im.shape[0], im.shape[2], im.shape[3]
        im = im.to(self._compute_device())
        latent = self.encoder(im)
        # One query point per pixel, shared (not copied) across the batch.
        grid = self._sampling_grid(h, w).to(latent.device)
        grid = grid.unsqueeze(0).expand(batch_size, -1, -1)
        pred = self.decoder(grid, latent)
        return pred.reshape(batch_size, 1, h, w)

    def _reconstruction_step(self, batch, batch_idx, log_loss):
        """Shared train/validation logic: reconstruct, score, occasionally log.

        Returns the MSE between the reconstruction and the input images.
        """
        img, _ = batch
        pred = self(img)
        # The reconstruction decides the device; bring the target alongside it.
        img = img.to(pred.device)
        rec_mse = self.mse(pred, img)
        if batch_idx % self.IMAGE_LOG_INTERVAL == 0:
            payload = {
                "imgs": [to_wandb_im(img[0], caption='gt_image')],
                "reconstructions": [rec_to_wandb_im(pred[0], caption='rec_image')],
            }
            if log_loss:
                payload["train loss"] = rec_mse
            self.logger.experiment.log(payload)
        return rec_mse

    def training_step(self, batch, batch_idx):
        """Return the reconstruction MSE for one training batch."""
        return self._reconstruction_step(batch, batch_idx, log_loss=True)

    def validation_step(self, batch, batch_idx):
        """Return the reconstruction MSE for one validation batch."""
        return self._reconstruction_step(batch, batch_idx, log_loss=False)

    def configure_optimizers(self):
        """RAdam over encoder and decoder with a per-epoch exponential LR decay.

        The decoder's learning rate is the base rate scaled by
        ``args.fcae.decoder.lr_coeff``.
        """
        param_sets = [
            {'params': self.encoder.parameters()},
            {'params': self.decoder.parameters(), 'lr': self.lr * self.args.fcae.decoder.lr_coeff}
        ]
        opt = optim.RAdam(param_sets, lr=self.lr, weight_decay=self.weight_decay)
        scheduler_step = 'epoch'
        lr_sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=self.lr_decay)

        return [opt], [{
            'scheduler': lr_sched,
            'interval': scheduler_step,
            'name': 'fcae'
        }]


In [None]:
from torchvision.datasets import MNIST
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import transforms

import inspect
import os

data_path = Path('data')
num_classes = 10
im_channels = 1
image_size = (40, 40)
# RandomCrop with pad_if_needed pads images smaller than 40x40 before cropping.
t = transforms.Compose([
    transforms.RandomCrop(size=image_size, pad_if_needed=True),
    transforms.ToTensor(),
])

# Never request more loader workers than there are CPU cores.
num_workers = min(10, os.cpu_count() or 1)
dataloader_args = EasyDict(batch_size=args.batch_size, shuffle=False,
                           num_workers=num_workers)
# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_dataloader_args = dict(dataloader_args, shuffle=True)
train_dataloader = DataLoader(MNIST(data_path/'mnist', train=True, transform=t, download=True), **train_dataloader_args)
val_dataloader = DataLoader(MNIST(data_path/'mnist', train=False, transform=t, download=True), **dataloader_args)

model = FCAE(args=args, watch=True)
if torch.cuda.is_available():
    model.cuda()

# NOTE(review): max_epochs is fixed at 20 here while args.num_epochs is 300 - confirm which is intended.
trainer_kwargs = dict(max_epochs=20, logger=logger)
# Lightning releases whose Trainer still accepts `gpus` run on the CPU unless
# a GPU is requested explicitly, so ask for one when CUDA is present.
if torch.cuda.is_available() and "gpus" in inspect.signature(pl.Trainer.__init__).parameters:
    trainer_kwargs["gpus"] = 1
trainer = pl.Trainer(**trainer_kwargs)
print("MODEL IS USING", model.device)
trainer.fit(model, train_dataloader, val_dataloader)

  cpuset_checked))
[34m[1mwandb[0m: Currently logged in as: [33mmlatberkeley[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name    | Type            | Params
--------------------------------------------
0 | encoder | CapsuleEncoder  | 60.6 K
1 | decoder | ImplicitDecoder | 1.1 K 
2 | mse     | MSELoss         | 0     
--------------------------------------------
61.6 K    Trainable params
0         Non-trainable params
61.6 K    Total params
0.247     Total estimated model params size (MB)


MODEL IS USING cuda:0


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  cpuset_checked))


torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

torch.Size([500, 1, 40, 40])



1