# Weight Uncertainty Autoencoder
_(Requires Python 3, PyTorch 1.0.1, TorchVision 0.2.2)_

**Reference**: _C. Blundell et al,_ [Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424)

### Libraries
Import PyTorch (core, `nn`, functional API, data utilities) and TorchVision.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.utils.data as Data
import torchvision

In [None]:
from torch.distributions.normal import Normal

We'll also need numpy and matplotlib.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

### Test CUDA configuration

In [None]:
# Fail fast unless the notebook runs in a fresh session with a usable GPU.
# A CUDA-capable device and driver must be present.
assert torch.cuda.is_available()
# The rest of the notebook uses bare `.cuda()` calls, i.e. the current device;
# make sure that is device 0.
assert torch.cuda.current_device() == 0 
assert torch.cuda.device_count() >= 1
# Nothing may have been allocated (or kept in PyTorch's cache) on the GPU yet.
assert torch.cuda.memory_allocated() == 0
assert torch.cuda.memory_cached() == 0

In [None]:
# Hand unused blocks held by PyTorch's caching allocator back to the GPU.
torch.cuda.empty_cache()

## Model
### Weight Uncertainty Layer

In [None]:
class WULayer(nn.Module):
    """Linear layer whose weight matrix is sampled on every forward pass.

    The weights follow a factorised Gaussian posterior with mean ``mu`` and
    standard deviation ``sigma = softplus(ro)``; the bias ``bs`` is a plain
    (point-estimate) parameter.  Inputs are laid out column-wise: ``x`` has
    shape ``(prev_d, n)`` and the output has shape ``(d, n)``.

    Args:
        prev_d: Number of input features.
        d: Number of output features.
    """

    def __init__(self, prev_d, d):
        super(WULayer, self).__init__()
        self.mu = nn.Parameter(torch.zeros(d, prev_d))
        # 'ro' is the inverse softplus of 1.5 / sqrt(d), so every weight
        # starts with standard deviation sigma = 1.5 / sqrt(d).
        self.ro = nn.Parameter(torch.log(torch.exp(1.5*torch.ones(d, prev_d)/np.sqrt(d))-1))
        self.bs = nn.Parameter(torch.zeros(d, 1))

    def forward(self, x):
        """Sample weights, apply the layer and score the sample.

        Args:
            x: Tensor of shape ``(prev_d, n)`` on the same device as the
                layer parameters.

        Returns:
            Tuple ``(out, cost)`` where ``out = W @ x + bs`` has shape
            ``(d, n)`` and ``cost`` is the scalar ``log q(W) - log p(W)``
            for the sampled ``W``.  ``q`` is the Gaussian posterior and ``p``
            the mixture prior ``0.7 * N(0, sd=0.01) + 0.3 * N(0, sd=2)``.
        """
        # softplus is the overflow-safe form of log(1 + exp(ro)).
        sigma = func.softplus(self.ro)
        # Draw the noise directly on the parameters' device/dtype so the
        # layer works on CPU as well as on any GPU.
        weights = torch.randn_like(self.mu) * sigma + self.mu
        loss_postr = torch.sum(Normal(self.mu, sigma).log_prob(weights))
        # Log mixing proportions, created next to the sampled weights.
        log_pi = torch.log(weights.new_tensor([0.7, 0.3]))
        mix_a = Normal(0, 0.01).log_prob(weights) + log_pi[0]
        mix_b = Normal(0, 2).log_prob(weights) + log_pi[1]
        # Element-wise log of the mixture density, summed over all weights.
        loss_prior = torch.sum(torch.logsumexp(torch.stack([mix_a, mix_b]), dim=0))
        return torch.mm(weights, x) + self.bs, loss_postr - loss_prior

### Coder

In [None]:
class Coder(nn.Module):
    """Stack of ``WULayer`` blocks used as either encoder or decoder.

    Args:
        i_dim: Width of the input.
        h_dim: Iterable with the widths of the hidden layers, in order.
        o_dim: Width of the output.
    """

    def __init__(self, i_dim, h_dim, o_dim):
        super(Coder, self).__init__()
        self.i_dim = i_dim
        self.h_dims = np.array(h_dim)
        self.o_dim = o_dim

        # Consecutive pairs of widths describe one layer each:
        # i_dim -> h_dim[0] -> ... -> h_dim[-1] -> o_dim.
        widths = [i_dim] + list(h_dim) + [o_dim]
        self.layers = nn.ModuleList(
            [WULayer(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x):
        """Run ``x`` through all layers.

        ReLU follows every layer except the last one, whose output is
        returned as is.

        Returns:
            Tuple ``(output, cost)`` where ``cost`` is the sum of the second
            return value of every layer.
        """
        *hidden, head = self.layers
        total = 0
        for block in hidden:
            x, penalty = block(x)
            x = func.relu(x)
            total += penalty
        x, penalty = head(x)
        total += penalty
        return x, total

### Weight Uncertainty AutoEncoder

In [None]:
class WUAE(nn.Module):
    """Autoencoder assembled from two ``Coder`` stacks.

    The decoder mirrors the encoder: it uses the hidden widths in reverse
    order and maps the latent code back to the input width.

    Args:
        i_dim: Width of the input (and of the reconstruction).
        h_dim: Hidden-layer widths of the encoder.
        l_dim: Width of the latent code.
    """

    def __init__(self, i_dim, h_dim, l_dim):
        super(WUAE, self).__init__()
        mirrored = np.flip(h_dim)
        self.encoder = Coder(i_dim, h_dim, l_dim)
        self.decoder = Coder(l_dim, mirrored, i_dim)

    def forward(self, x):
        """Reconstruct ``x`` and return the summed encoder/decoder cost.

        ``x`` is transposed on the way in and the reconstruction is
        transposed back on the way out, so caller-side tensors keep the
        layout of ``x``.

        Returns:
            Tuple ``(reconstruction, cost)``; the reconstruction has been
            passed through a sigmoid.
        """
        columns = x.t()
        code, enc_cost = self.encoder(columns)
        logits, dec_cost = self.decoder(code)
        recon = torch.sigmoid(logits).t()
        return recon, enc_cost + dec_cost

## Training

### Dataset

In [None]:
# Load the tab-separated integer matrix (one sample per row).
all_data = np.loadtxt('sc_mouse_binary.txt.gz', delimiter='\t', dtype=int)
# Generate random indices (fixed seed so the split is reproducible)
np.random.seed(10)
idx = np.arange(len(all_data))
np.random.shuffle(idx)
# Take 10% for test. Integer division: slice bounds must be ints, and 'idx'
# is one-dimensional so it is sliced along its only axis.
n_test = len(all_data) // 10
test_idx = idx[:n_test]
train_idx = idx[n_test:]
# Generate sets: convert to float tensors (binary cross-entropy needs a
# floating-point target) before moving them to the GPU.
train_data = torch.tensor(all_data[train_idx, :], dtype=torch.float).cuda()
test_data = torch.tensor(all_data[test_idx, :], dtype=torch.float).cuda()

### Instantiate WUAE

In [None]:
# Input width = number of columns of the training matrix.
i_dim = train_data.shape[1]
# Encoder hidden widths; the decoder uses them in reverse order.
h_dim = [1500, 800, 600]
# Width of the latent code.
o_dim = 10

wuae = WUAE(i_dim, h_dim, o_dim).cuda()

In [None]:
# The model and both of its halves must hold their parameters on the GPU.
for module in (wuae, wuae.encoder, wuae.decoder):
    assert next(module.parameters()).is_cuda

### Training loop

In [None]:
# Mini-batch size and number of passes over the training set.
batch_size, n_epochs = 128, 500

In [None]:
# Adam over every parameter of the autoencoder.
optimizer = torch.optim.Adam(params=wuae.parameters(), lr=1e-3)
# Multiply the learning rate by 0.316 at epochs 50 and 100
# (0.316 ** 2 is roughly 0.1, i.e. a tenfold reduction after both milestones).
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=[50, 100],
    gamma=0.316,
)

In [None]:
%matplotlib notebook

# Lists to store training losses
train_rec_loss = []
train_weight_loss = []
test_rec_loss  = []

# Set model to training mode
wuae.train()

# Figure carrying a one-line status text that is refreshed every epoch.
f = plt.figure()
loss_text = f.text(0, 0, "Initializing...")

# Row indices of the training set; reshuffled in place at every epoch.
n_train = train_data.shape[0]
epoch_idx = np.arange(n_train)
# np.array_split needs an integer number of sections (at least one).
n_batches = max(1, n_train // batch_size)

for e in range(n_epochs):
    # Shuffle training data
    np.random.shuffle(epoch_idx)
    batch_idx = np.array_split(epoch_idx, n_batches)
    N = len(batch_idx)

    rec_loss = 0
    weight_loss = 0
    batch_loss = 0
    for batch_no, rows in enumerate(batch_idx):
        # Input and target data (the autoencoder reconstructs its input)
        rows = torch.as_tensor(rows, dtype=torch.long, device=train_data.device)
        x = train_data[rows, :]
        # Forward pass of the data through the network
        y, w_loss = wuae(x)
        # Weight of the weight cost for this mini-batch: geometrically
        # decreasing over the epoch and summing to 1 across the N batches.
        pi_i = 2**(N-batch_no-1) / (2**N-1)
        # Compute the loss: reconstruction error plus weighted weight cost.
        # Both terms must be back-propagated, otherwise the prior is ignored.
        rec = nn.functional.binary_cross_entropy(y, x, reduction='sum')
        loss = rec + pi_i * w_loss
        rec_loss += float(rec) / n_train
        weight_loss += pi_i * float(w_loss)
        batch_loss += float(loss)
        # Reset the gradients
        optimizer.zero_grad()
        # Compute gradients
        loss.backward()
        # Update parameters
        optimizer.step()
        del loss

    # End of epoch, compute train & test loss
    train_rec_loss.append(rec_loss)
    train_weight_loss.append(weight_loss)
    # Evaluation only: do not record an autograd graph for the test pass.
    with torch.no_grad():
        out, _ = wuae(test_data)
        loss = float(nn.functional.binary_cross_entropy(out, test_data, reduction='sum')) / len(test_data)
    test_rec_loss.append(loss)

    status = "epoch: {}, train: {:.3f}, test: {:.3f}".format(e+1, rec_loss, loss)
    print(status)
    loss_text.set_text(status)
    f.canvas.draw()

    # MultiStepLR counts epochs; its step() must not be given a loss value.
    scheduler.step()


In [None]:
%matplotlib notebook
# Overlay the per-epoch reconstruction losses: training first, then test.
for curve in (train_rec_loss, test_rec_loss):
    plt.plot(curve)
plt.show()

In [None]:
# Persist the trained model object itself (not just its state_dict).
# NOTE(review): loading this file back presumably requires the WUAE / Coder /
# WULayer class definitions to be available — confirm before sharing the file.
torch.save(wuae, 'wuae_bsc_wo_dropout.trc')