In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./..")

In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import tensorflow as tf
import numpy as np

#local import
from smallnorb.smallnorb import SmallNORB

In [None]:
class Mask(tf.keras.layers.Layer):
    """
    Mask operation described in 'Dynamic routing between capsules'.

    Zeroes out every capsule except the selected one(s) and flattens the
    result per sample with ``batch_flatten``. The selection is either derived
    from the capsule lengths (the longest capsule wins) or supplied by the
    caller as a one-hot mask.

    Methods
    -------
    call(inputs, double_mask)
        mask a capsule layer
        set double_mask for multimnist dataset
    """
    def call(self, inputs, double_mask=None, **kwargs):
        """
        Apply the capsule mask.

        Parameters
        ----------
        inputs : tensor, or list/tuple of tensors
            Either the capsule tensor alone, in which case the mask is derived
            from the capsule lengths, or ``[capsules, mask]``
            (``[capsules, mask1, mask2]`` when ``double_mask`` is truthy) with
            caller-supplied masks. Capsules are presumably
            (batch, n_caps, dim) — TODO confirm against the calling model.
        double_mask : bool, optional
            If truthy, build/apply two masks: the longest and the second-longest
            capsule (used for the multimnist setting).

        Returns
        -------
        tensor, or tuple of two tensors when ``double_mask`` is truthy
            Masked capsules, flattened per sample.
        """
        # Accept tuples as well as lists: a container of (capsules, mask...) is
        # never a single capsule tensor.
        if isinstance(inputs, (list, tuple)):
            if double_mask:
                inputs, mask1, mask2 = inputs
            else:
                inputs, mask = inputs
        else:
            # Euclidean (L2) length of every capsule vector (last axis).
            x = tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
            n_caps = x.get_shape().as_list()[1]
            if double_mask:
                # Sort once and pick the two longest capsules.
                order = tf.argsort(x, direction='DESCENDING', axis=-1)
                mask1 = tf.keras.backend.one_hot(order[..., 0], num_classes=n_caps)
                mask2 = tf.keras.backend.one_hot(order[..., 1], num_classes=n_caps)
            else:
                mask = tf.keras.backend.one_hot(indices=tf.argmax(x, 1), num_classes=n_caps)

        if double_mask:
            masked1 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask1, -1))
            masked2 = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask2, -1))
            return masked1, masked2
        masked = tf.keras.backend.batch_flatten(inputs * tf.expand_dims(mask, -1))
        return masked

    def compute_output_shape(self, input_shape):
        """Flattened shape (None, n_caps * dim) of the (single) masked output."""
        if type(input_shape[0]) is tuple:
            # [capsules_shape, mask_shape, ...] was passed in.
            return tuple([None, input_shape[0][1] * input_shape[0][2]])
        else:  # generation step
            return tuple([None, input_shape[1] * input_shape[2]])

    def get_config(self):
        """The layer has no extra hyper-parameters; reuse the base config."""
        config = super().get_config()
        return config

In [None]:
def masking_max_norm(u):
    """
    Keep only the longest capsule of each sample and zero out all others.

    IN:
        u (b, n, d) ... capsules
    OUT:
        masked(u)  (b, n, d) where:
        - the Euclidean norm over dimension d picks, per sample, the capsule
          with the largest length along dimension n
        - that capsule is kept unchanged (it is NOT normalised)
        - every other capsule is set to zero
    """
    _, n_classes, _ = u.shape
    u_norm = torch.norm(u, dim=2)
    # one_hot returns int64; cast so the product keeps u's floating dtype.
    mask = F.one_hot(torch.argmax(u_norm, dim=1), num_classes=n_classes).to(u.dtype)
    return torch.einsum('bnd,bn->bnd', u, mask)

In [None]:
torch.manual_seed(0)
u = torch.rand((1,5,16))
u_mask = masking_max_norm(u)
# Masking must preserve the (b, n, d) shape of the capsules.
assert u_mask.shape == u.shape
u_mask

In [None]:
# num_classes=-1 (the default) infers the size as index + 1, i.e. 5 here.
F.one_hot(torch.tensor(4), num_classes=-1)

In [None]:
# Test split (train=False) of SmallNORB in "stereo" mode.
A = SmallNORB(
    root='data/SmallNORB',
    train=False,
    download=True,
    mode="stereo",
)

In [None]:
a = np.asarray(A[0][0])
b = A[3][2]  # presumably the class label in stereo mode — TODO confirm
# F.one_hot only accepts int64 tensors; the label may be a plain Python/NumPy
# int, so convert it first (a no-op if it already is an int64 tensor).
print(b, F.one_hot(torch.as_tensor(b), num_classes=5))

In [None]:
def masking_y_true(u, y_true):
    """
    Keep only the capsule selected by each sample's ground-truth label.

    IN:
        u (b, n, d) ... capsules
        y_true (b,)  ... integer class index per sample, in [0, n)
                         (tensor of any integer dtype, or a sequence of ints)
    OUT:
        masked(u)  (b, n, d) where:
        - the capsule at index y_true[i] of sample i is kept unchanged
          (it is NOT normalised)
        - every other capsule is set to zero
    """
    _, n_classes, _ = u.shape
    # F.one_hot requires int64 input; also move labels next to u.
    labels = torch.as_tensor(y_true, dtype=torch.long, device=u.device)
    # Cast the int64 one-hot mask so the product keeps u's floating dtype.
    mask = F.one_hot(labels, num_classes=n_classes).to(u.dtype)
    return torch.einsum('bnd,bn->bnd', u, mask)

In [None]:
torch.manual_seed(6)
batch_size = 3
u = torch.rand((batch_size, 5, 16))
y_true = torch.randint(high=5, size=(batch_size,))
# Label shape, labels and capsule shape, one per line.
print(y_true.size(), y_true, u.size(), sep="\n")
u_mask = masking_y_true(u, y_true)
u_mask

In [None]:
# Loader settings shared by the train and validation pipelines.
BATCH_SIZE = 3
NUM_WORKERS = 4

# mode="left": presumably only the left-camera image of each pair — TODO confirm.
ds_train = SmallNORB(root='data/SmallNORB', train=True, download=True,
                     transform=T.ToTensor(), mode="left")
ds_valid = SmallNORB(root='data/SmallNORB', train=False, download=True,
                     transform=T.ToTensor(), mode="left")

dl_train = torch.utils.data.DataLoader(ds_train,
                                       batch_size=BATCH_SIZE,
                                       shuffle=True,
                                       num_workers=NUM_WORKERS)
# Validation does not benefit from shuffling; a fixed order keeps runs comparable.
dl_valid = torch.utils.data.DataLoader(ds_valid,
                                       batch_size=BATCH_SIZE,
                                       shuffle=False,
                                       num_workers=NUM_WORKERS)

In [None]:
# Fetch one shuffled training batch (images x, labels y); nothing is plotted here.
# Each iter() call creates a fresh DataLoader iterator (starting its workers).
x, y = next(iter(dl_train))
assert x.shape[0] == y.shape[0]

In [None]:
# Shape of the label batch (torch.Size).
y.shape

In [None]:
# Mask random capsules with the real labels of a fresh training batch
# (nothing is plotted here).
x, y = next(iter(dl_train))
# Seed after fetching so u does not depend on the loader's RNG consumption.
torch.manual_seed(6)
# Derive the batch size from the labels so u always matches y (e.g. a short
# final batch), instead of hard-coding the loader's batch size.
batch_size = y.shape[0]
u = torch.rand((batch_size,5,16))
y_true = y
print(y_true.size())
print(y_true)
print(u.size())
u_mask = masking_y_true(u, y_true)
u_mask