In [None]:
%load_ext autoreload
%autoreload 2

In [22]:
import sys
sys.path.append("./..")

In [80]:
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import tensorflow as tf
import numpy as np

#local import
from smallnorb.smallnorb import SmallNORB

In [6]:
class Mask(tf.keras.layers.Layer):
    """
    Mask operation described in 'Dynamic routinig between capsules'.
    
    ...
    
    Methods
    -------
    call(inputs, double_mask)
        mask a capsule layer
        set double_mask for multimnist dataset
    """
    def call(self, inputs, double_mask=None, **kwargs):
        if type(inputs) is list:
            if double_mask:
                inputs, mask1, mask2 = inputs
            else:
                inputs, mask = inputs
        else:  
            x = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))        # 2 norm ?
            if double_mask:
                mask1 = tf.keras.backend.one_hot(tf.argsort(x,direction='DESCENDING',axis=-1)[...,0],num_classes=x.get_shape().as_list()[1])
                mask2 = tf.keras.backend.one_hot(tf.argsort(x,direction='DESCENDING',axis=-1)[...,1],num_classes=x.get_shape().as_list()[1])
            else:
                mask = tf.keras.backend.one_hot(indices=tf.argmax(x, 1), num_classes=x.get_shape().as_list()[1])

        if double_mask:
            masked1 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask1, -1))
            masked2 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask2, -1))
            return masked1, masked2
        else:
            masked = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask, -1))
            return masked

    def compute_output_shape(self, input_shape):
        if type(input_shape[0]) is tuple:  
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        else:  # generation step
            return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        config = super(Mask, self).get_config()
        return config

In [38]:
def masking_max_norm(u):
    """
    IN:
        u (b, n d) ... capsules
    OUT:
        masked(u)  (b, n, d) where:
        - normalise over dimension d of u
        - keep largest vector in dimension n
        - mask out everything else
    """
    _, n_classes, _ = u.shape
    print(n_classes)
    u_norm = torch.norm(u, dim=2)
    print(u_norm)
    mask = F.one_hot(torch.argmax(u_norm, 1), num_classes=n_classes)
    print(mask)
    return torch.einsum('bnd,bn->bnd', u, mask)

In [39]:
torch.manual_seed(0)
u = torch.rand((1,5,16))
u.size()
u_mask = masking_max_norm(u)
u_mask

5
tensor([[1.9175, 2.4459, 2.0397, 2.3359, 2.6359]])
tensor([[0, 0, 0, 0, 1]])


tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7745, 0.4369, 0.5191, 0.6159, 0.8102, 0.9801, 0.1147, 0.3168,
          0.6965, 0.9143, 0.9351, 0.9412, 0.5995, 0.0652, 0.5460, 0.1872]]])

In [20]:
F.one_hot(torch.tensor(4))

tensor([0, 0, 0, 0, 1])

In [24]:
A = SmallNORB(root = 'data/SmallNORB',train=False,download=True,mode="stereo")

Downloading https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz to data/SmallNORB\raw\smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz


  0%|          | 0/131896188 [00:00<?, ?it/s]

# Extracting data smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat

Downloading https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz to data/SmallNORB\raw\smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz


  0%|          | 0/71052 [00:00<?, ?it/s]

# Extracting data smallnorb-5x46789x9x18x6x2x96x96-training-info.mat

Downloading https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz to data/SmallNORB\raw\smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz


  0%|          | 0/348 [00:00<?, ?it/s]

# Extracting data smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat

Downloading https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz to data/SmallNORB\raw\smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz


  0%|          | 0/130799817 [00:00<?, ?it/s]

# Extracting data smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat

Downloading https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz to data/SmallNORB\raw\smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz


  0%|          | 0/10428 [00:00<?, ?it/s]

# Extracting data smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat

Downloading https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz to data/SmallNORB\raw\smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz


  0%|          | 0/347 [00:00<?, ?it/s]

# Extracting data smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat

Processing...
Done!


In [41]:
a = np.asarray(A[0][0])
b = A[3][2]
print(b, F.one_hot(b,num_classes=5))

tensor(3) tensor([0, 0, 0, 1, 0])


In [45]:
def masking_y_true(u, y_true):
    """
    IN:
        u (b, n d) ... capsules
        y_true (b,)  ... classification value (skalar)
    OUT:
        masked(u)  (b, n, d) where:
        - normalise over dimension d of u
        - keep vector in dimension n with y_true
        - mask out everything else
    """
    _, n_classes, _ = u.shape
    print(n_classes)
    u_norm = torch.norm(u, dim=2)
    print(u_norm)
    mask = F.one_hot(y_true, num_classes=n_classes)
    print(mask)
    return torch.einsum('bnd,bn->bnd', u, mask)

In [71]:
torch.manual_seed(6)
batch_size = 3
u = torch.rand((batch_size,5,16))
y_true = torch.randint(high = 5, size=(batch_size,))
print(y_true.size())
print(y_true)
print(u.size())
u_mask = masking_y_true(u, y_true)
u_mask 

torch.Size([3])
tensor([2, 1, 1])
torch.Size([3, 5, 16])
5
tensor([[2.3595, 2.0025, 2.6382, 2.2739, 2.3331],
        [2.1719, 2.5303, 2.0609, 2.6190, 2.5837],
        [2.1154, 2.3592, 2.5738, 2.2197, 2.0663]])
tensor([[0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0]])


tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8837, 0.0947, 0.7794, 0.6970, 0.3245, 0.2406, 0.8824, 0.2953,
          0.9455, 0.5017, 0.9711, 0.8482, 0.3557, 0.4164, 0.5201, 0.8180],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9835, 0.2522, 0.3455, 0.2098, 0.7763, 0.7467, 0.6680, 0.3442,
          0.5228, 0.8239, 0.64

In [86]:
ds_train = SmallNORB(root='data/SmallNORB',train=True, download=True,transform=T.ToTensor(), mode="left")
ds_valid = SmallNORB(root='data/SmallNORB',train=False, download=True,transform=T.ToTensor(),  mode="left")

dl_train = torch.utils.data.DataLoader(ds_train, 
                                       batch_size=3, 
                                       shuffle=True, 
                                       num_workers=4)
dl_valid = torch.utils.data.DataLoader(ds_valid, 
                                       batch_size=3, 
                                       shuffle=True, 
                                       num_workers=4)

In [87]:
# plot train imgs
x, y = next(iter(dl_train))

In [88]:
y.size()

torch.Size([3])

In [90]:
# plot train imgs
x, y = next(iter(dl_train))
torch.manual_seed(6)
batch_size = 3
u = torch.rand((batch_size,5,16))
y_true = y
print(y_true.size())
print(y_true)
print(u.size())
u_mask = masking_y_true(u, y_true)
u_mask 

torch.Size([3])
tensor([0, 3, 2])
torch.Size([3, 5, 16])
5
tensor([[2.3595, 2.0025, 2.6382, 2.2739, 2.3331],
        [2.1719, 2.5303, 2.0609, 2.6190, 2.5837],
        [2.1154, 2.3592, 2.5738, 2.2197, 2.0663]])
tensor([[1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0]])


tensor([[[0.5722, 0.5539, 0.9868, 0.6080, 0.2347, 0.4492, 0.6743, 0.7480,
          0.5601, 0.1674, 0.3333, 0.4648, 0.6332, 0.7692, 0.2147, 0.7815],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.00