In [1]:
from fastai_bayesian import *

In [2]:
import torch
import torch.nn as nn
from torch.distributions import Uniform, Bernoulli

In [16]:
def norm2(x):
    """Return the mean of the squared entries of ``x``.

    Note: despite the name this is a *mean* of squares, not a sum
    (i.e. ``||x||^2 / x.numel()``).
    """
    return (x ** 2).mean()


def neg_entropy(p, eps: float = 1e-7):
    """Negative binary entropy ``p*log(p) + (1-p)*log(1-p)`` of a Bernoulli(p).

    ``p`` is clamped to ``[eps, 1 - eps]`` first. Without the clamp the result
    is NaN at ``p == 0`` or ``p == 1`` (``0 * log(0)``), e.g. for a layer built
    with ``dp=0.``. It is also NaN for values that drifted outside ``[0, 1]``
    during optimisation. Inside ``[eps, 1 - eps]`` the value is unchanged.

    Args:
        p: tensor of probabilities.
        eps: clamp margin keeping both logarithms finite.
    """
    p = torch.clamp(p, eps, 1 - eps)
    return p * torch.log(p) + (1 - p) * torch.log(1 - p)

def get_layer(m, buffer, layer, output="list"):
    """Recursively collect every descendant of ``m`` that is an instance of ``layer``.

    The walk is depth-first and pre-order: a matching child is recorded
    before its own descendants are searched. ``m`` itself is not tested, only
    its children and their descendants.

    Args:
        m: root ``nn.Module`` to search.
        buffer: a ``list`` (matches are appended) or a ``dict`` (matches are
            stored under the key ``hex(id(module))``). Mutated in place.
        layer: class (or tuple of classes) accepted by ``isinstance``.
        output: unused; kept only for backward compatibility.

    Returns:
        ``buffer`` itself, for convenience.

    Raises:
        TypeError: if ``buffer`` is neither a list nor a dict. Such buffers
            used to be silently left untouched.
    """
    if isinstance(buffer, list):
        add = buffer.append
    elif isinstance(buffer, dict):
        def add(module):
            buffer[hex(id(module))] = module
    else:
        raise TypeError(f"buffer must be a list or dict, got {type(buffer).__name__}")

    def _walk(parent):
        # Pre-order DFS over the module tree below ``parent``.
        for child in parent.children():
            if isinstance(child, layer):
                add(child)
            _walk(child)

    _walk(m)
    return buffer

In [17]:
class PLU(nn.Module):
    """Probability Linear Unit.

    Parameter-free module that clips every element of its input into the
    closed interval ``[0, 1]``. It maps an unconstrained tensor (such as a
    learnable keep-rate) onto a valid probability.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.clamp(x, min=0, max=1)
  
class AutoDropout(nn.Module):
    """Dropout layer that stores its keep-rate ``p`` (= ``1 - dp``) on the module.

    ``p`` is an ``nn.Parameter`` when ``requires_grad`` is True, otherwise a
    plain buffer. On each forward pass ``p`` is clipped into ``[0, 1]`` by a
    ``PLU``. One Bernoulli(p) mask is then drawn over the non-batch
    dimensions and multiplied into the input, so every sample in the batch
    shares the same mask.

    There is no ``self.training`` check, so the mask is applied in both train
    and eval mode. The output is not rescaled by ``1 / p``.

    NOTE(review): ``Bernoulli.sample`` is not differentiable, so the task loss
    sends no gradient to ``p`` through this mask. Confirm that ``p`` is meant
    to be driven only by other terms (e.g. a regulariser).

    Args:
        dp: initial dropout rate.
        requires_grad: register ``p`` as a trainable parameter (True) or a buffer.
    """

    def __init__(self, dp=0., requires_grad=True):
        super().__init__()

        # Stored as the keep-rate, not the drop-rate.
        keep_rate = torch.tensor(1 - dp)
        self.plu = PLU()

        if not requires_grad:
            self.register_buffer("p", keep_rate)
        else:
            self.register_parameter("p", nn.Parameter(keep_rate))

    def forward(self, x):
        # Indexing shape[0] (unused) keeps the original failure on 0-d input.
        _batch_size, sample_shape = x.shape[0], x.shape[1:]
        keep_prob = self.plu(self.p)
        mask = Bernoulli(keep_prob).sample(sample_shape)
        return x * mask

    def extra_repr(self):
        return f"p={self.p.item()}"

In [22]:
class ConcreteDropout(nn.Module):
    """Concrete (relaxed-Bernoulli) dropout with an optionally learnable keep-rate.

    The mask is ``sigmoid((logit(p) + logit(u)) / t)`` with ``u ~ U(0, 1)``.
    This is a continuous relaxation of a Bernoulli(p) keep-mask, and it is
    differentiable with respect to ``p``. One mask is drawn per call over the
    non-batch dimensions and broadcast across the batch.

    Args:
        t: relaxation temperature (smaller -> closer to a hard 0/1 mask).
        dp: initial dropout rate; the module stores the keep-rate ``1 - dp``.
        requires_grad: register ``p`` as a trainable parameter (True) or a buffer.
        eps: clamp margin keeping ``p`` and ``u`` strictly inside ``(0, 1)`` so
            the logits stay finite.
    """

    def __init__(self, t: float = 0.1, dp: float = 0.5, requires_grad=True, eps: float = 1e-7):
        super(ConcreteDropout, self).__init__()

        # We first invert the dropout rate to the keeping rate
        p = torch.tensor(1 - dp)
        if requires_grad:
            self.register_parameter("p", nn.Parameter(p))
        else:
            self.register_buffer("p", p)

        self.register_buffer("t", torch.tensor(t))
        self.eps = eps

        # Kept for backward compatibility only. forward() samples with
        # torch.rand on p's device, because this distribution always samples
        # on the CPU and broke GPU models.
        self.u = Uniform(0, 1)

    def forward(self, x):
        shape = x.shape[1:]

        # Clamp so log(p), log(1-p), log(u), log(1-u) are finite even if the
        # learned p drifted outside (0, 1) or u hit 0.
        p = self.p.clamp(self.eps, 1 - self.eps)
        u = torch.rand(shape, dtype=p.dtype, device=p.device).clamp(self.eps, 1 - self.eps)

        logits = torch.log(p) - torch.log(1 - p) + torch.log(u) - torch.log(1 - u)
        m = torch.sigmoid(logits / self.t)

        # Leading singleton dim broadcasts the single mask over the batch.
        return m[None] * x

    def extra_repr(self):
        return 'p={}, t={}'.format(self.p.item(), self.t.item())

In [23]:
class DropLinear(nn.Module):
    """``nn.Linear`` followed by a dropout-style module applied to its output.

    At construction the whole weight matrix is rescaled to have norm
    ``1 - dp`` (``Tensor.norm()`` over all entries).

    Args:
        in_features, out_features, bias: forwarded to ``nn.Linear`` (stored as ``self.W``).
        dp_module: dropout class, instantiated as
            ``dp_module(dp=dp, requires_grad=requires_grad)`` (stored as ``self.dp``).
        dp: initial dropout rate passed to ``dp_module``.
        requires_grad: passed to ``dp_module``.
    """

    def __init__(self, in_features, out_features, dp_module, dp=0., bias=True, requires_grad=True):
        super(DropLinear, self).__init__()

        self.dp = dp_module(dp=dp, requires_grad=requires_grad)
        self.W = nn.Linear(in_features=in_features,
                           out_features=out_features, bias=bias)

        # Rescale in place under no_grad instead of reassigning ``.data``.
        # The result is numerically identical (divide by the norm, then scale).
        with torch.no_grad():
            self.W.weight.div_(self.W.weight.norm()).mul_(1 - dp)

    def forward(self, x):
        return self.dp(self.W(x))

In [52]:
from fastai.callbacks.hooks import HookCallback

class KLHook(HookCallback):
    """fastai callback that adds a KL-style regulariser from every ``DropLinear`` to the loss.

    A forward hook on each ``DropLinear`` computes a per-layer term from that
    layer's raw keep-rate ``m.dp.p`` and its weights. ``on_backward_begin``
    sums the terms of the current batch, divides the sum by the training-set
    size and adds it to ``last_loss``.

    Args:
        learn: the ``Learner``; its model must contain at least one ``DropLinear``.
        l: regularisation coefficient; the weight term is scaled by
            ``l**2 * in_features``.
        do_remove: stored as ``self.do_remove`` (HookCallback's hook-removal flag).
        recording: if True, keep per-hook stats (``self.stats``) and per-step
            losses (``self.loss``) for plotting.

    Raises:
        NotImplementedError: if the model contains no ``DropLinear``.
    """

    def __init__(self, learn, l: float = 1e-2, do_remove: bool = True, recording=False):
        super().__init__(learn)

        # Hook every DropLinear in the model (HookCallback hooks self.modules).
        buffer = []
        get_layer(learn.model, buffer, DropLinear)
        if not buffer:
            raise NotImplementedError(f"No {DropLinear} Linear found")

        self.modules = buffer
        self.do_remove = do_remove

        # Per-layer KL terms produced by the current batch's forward pass.
        self.kls = []

        self.N = len(learn.data.train_ds)
        self.l = l

        self.recording = recording
        if recording:
            self.stats = []
            self.loss = []

    def on_batch_begin(self, **kwargs):
        """Discard KL terms left over from a batch that had no backward pass.

        The forward hooks also fire during validation and prediction, where
        fastai does not call ``on_backward_begin`` (no optimizer step). Without
        this reset, those stale terms would pile up and be added to the next
        training batch's loss.
        """
        self.kls = []

    def on_backward_begin(self, last_loss, **kwargs):
        """Add this batch's summed KL, divided by the training-set size, to ``last_loss``."""
        if not self.kls:
            # No DropLinear forward ran this batch: leave the loss untouched.
            return

        total_kl = sum(self.kls) / self.N
        total_loss = last_loss + total_kl

        if self.recording:
            self.loss.append({"total_kl": total_kl.item(), "last_loss": last_loss.item(),
                              "total_loss": total_loss.item()})

        # We empty the buffer of kls
        self.kls = []

        return {"last_loss": total_loss}

    def hook(self, m: nn.Module, i, o):
        "Compute and store the KL term of DropLinear ``m`` for the current forward pass."
        # Read the live parameters (not the hooked i/o) so gradients reach p and W.
        p = m.dp.p

        W = m.W.weight
        norm_w = norm2(W)

        K_out = m.W.out_features
        K_in = m.W.in_features

        l = (self.l ** 2) * K_in

        kl = l * p * norm_w / 2 + K_out * neg_entropy(p)

        self.kls.append(kl)

        if self.recording:
            self.stats.append({"dropout": 1 - p.item(), "w": norm_w.item(), "module": hex(id(m))})

    def plot_stats(self):
        "Plot the recorded per-hook dropout rate and weight statistic; returns the Axes."
        assert self.recording, "Recording mode was off during initialization"
        df = pd.DataFrame(self.stats)
        return df.plot()

    def plot_losses(self):
        "Plot the recorded KL, base loss and total loss per step; returns the Axes."
        assert self.recording, "Recording mode was off during initialization"
        df = pd.DataFrame(self.loss)
        return df.plot()

In [31]:
# Random input batch of shape (64, 4) for quick experiments.
# NOTE(review): no random seed is set, so values differ per run; `x` is not used in the visible cells.
x = torch.randn(64,4)

In [37]:
# Dropout variant to plug into DropLinear below.
dp_module = ConcreteDropout

In [38]:
# 4 -> 4 DropLinear with initial dropout rate dp=0.5 (keep-rate 0.5).
lin = DropLinear(4,4,dp_module,0.5)

In [42]:
# Sanity check: output size of the wrapped nn.Linear.
lin.W.out_features

4

In [43]:
# NOTE(review): mid-notebook wildcard import; consider moving it to the import cell at the top.
from fastai.vision import *

In [44]:
# Presumably downloads/extracts fastai's MNIST_SAMPLE dataset and returns its local path — confirm.
path = untar_data(URLs.MNIST_SAMPLE)

In [47]:
# Build a fastai ImageDataBunch from the folder layout under `path`.
data = ImageDataBunch.from_folder(path)

In [49]:
# Training-set size; a KLHook on a Learner using this data would use it as N (its KL divisor).
len(data.train_ds)

12396