In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import make_classification
from torch.autograd import Variable
from torch.nn import Parameter

get_ipython().run_line_magic('pylab', 'inline')

In [None]:
class ModuleWrapper(nn.Module):
    """Base ``nn.Module`` that chains its children and propagates flags.

    ``forward`` pipes the input through every registered child module in
    registration order. ``set_flag`` stores an attribute on this module and
    forwards the same flag to each child that also implements ``set_flag``.
    """

    def __init__(self):
        super().__init__()

    def set_flag(self, flag_name, value):
        """Set attribute ``flag_name`` to ``value`` here and on every flag-aware child."""
        setattr(self, flag_name, value)
        flag_aware = (child for child in self.children() if hasattr(child, 'set_flag'))
        for child in flag_aware:
            child.set_flag(flag_name, value)

    def forward(self, x):
        """Apply each child module in registration order and return the last output."""
        out = x
        for layer in self.children():
            out = layer(out)
        return out


class LinearVariance(ModuleWrapper):
    """Linear layer with zero-mean weights and learned per-weight scales.

    There is no weight mean. Each forward pass draws fresh ``eps ~ N(0, 1)`` and
    returns ``bias + eps * sqrt(1e-16 + (x**2) @ (sigma**2).T)``, so repeated
    calls on the same input give different outputs. The variance term equals
    ``Var(sum_j w_ij * x_j)`` for independent ``w_ij ~ N(0, sigma_ij**2)``.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        bias: if True, learn an additive mean of shape ``(1, out_features)``;
            if False, the output is zero-mean noise.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(LinearVariance, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Per-weight scales; only sigma**2 is used, so the sign is irrelevant.
        self.sigma = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(1, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw sigma from U(-1/sqrt(in_features), 1/sqrt(in_features)) and zero the bias."""
        stdv = 1. / math.sqrt(self.sigma.size(1))
        with torch.no_grad():
            self.sigma.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """Return one stochastic sample of the layer output for input ``x``."""
        lrt_var = F.linear(x * x, self.sigma * self.sigma)
        # The 1e-16 floor keeps the sqrt gradient finite when the variance is 0.
        lrt_std = torch.sqrt(1e-16 + lrt_var)
        noise = torch.randn_like(lrt_std) * lrt_std
        if self.bias is None:
            # bias=False: zero-mean output (adding None would raise TypeError).
            return noise
        return self.bias + noise

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

In [None]:
def get_data(means, variance=None, n=500, rng=None):
    """Sample a labelled Gaussian-blob dataset, one blob per entry of ``means``.

    Args:
        means: sequence of blob centres; points drawn around ``means[i]`` get label ``i``.
        variance: covariance matrix shared by every blob, passed to
            ``multivariate_normal``. Defaults to ``[[1, 0], [0, 1]]``.
        n: number of points drawn per blob.
        rng: optional ``numpy.random.Generator`` for reproducible sampling.
            When None, the global ``np.random`` state is used.

    Returns:
        X: float32 array of shape ``(len(means) * n, d)``, blob by blob in the
           order of ``means`` (``d`` is the length of each mean).
        y: int64 array of shape ``(len(means) * n,)`` holding each point's
           blob index. int64 is the dtype torch cross-entropy targets expect.
    """
    if variance is None:
        # Avoid a shared mutable default argument.
        variance = [[1, 0], [0, 1]]
    sampler = np.random if rng is None else rng

    xs = [sampler.multivariate_normal(mean, variance, n) for mean in means]
    ys = [np.full(n, label) for label in range(len(means))]

    # Bare `float32` / `long` are not Python builtins and only resolved through a
    # star import (presumably %pylab); `long` is gone on modern NumPy.
    X = np.vstack(xs).astype(np.float32)
    y = np.hstack(ys).astype(np.int64)
    return X, y

# Four blobs at the corners of a square, using get_data's default covariance.
# Train and test are independent draws from the same distribution.
CLUSTER_MEANS = [[3, 3], [3, 10], [10, 3], [10, 10]]
np.random.seed(0)  # make the sampled train/test sets reproducible

Xtr, ytr = get_data(CLUSTER_MEANS)
Xte, yte = get_data(CLUSTER_MEANS)

fig, ax = plt.subplots()
ax.scatter(Xtr[:, 0], Xtr[:, 1], c=ytr)
ax.set(title='Training data (colour = class)', xlabel='$x_1$', ylabel='$x_2$');

In [None]:
# The network returns raw logits. CrossEntropyLoss applies log-softmax
# internally, so a trailing Softmax layer would normalise twice: the loss would
# be computed on bounded inputs and the gradients would be weakened.
torch.manual_seed(0)  # reproducible weight initialisation

model = torch.nn.Sequential(
    torch.nn.Linear(2, 100),
    torch.nn.LeakyReLU(),

    torch.nn.Linear(100, 100),
    torch.nn.LeakyReLU(),

    torch.nn.Linear(100, 100),
    torch.nn.LeakyReLU(),

    torch.nn.Linear(100, 100),
    torch.nn.LeakyReLU(),

    # Stochastic bottleneck: every forward pass draws fresh noise here.
    LinearVariance(100, 2),
    torch.nn.LeakyReLU(),

    torch.nn.Linear(2, 100),
    torch.nn.LeakyReLU(),

    torch.nn.Linear(100, 4),  # one logit per class
)

# `size_average=True` is deprecated; reduction='mean' is the equivalent setting.
loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
N_ITERS = 21000
EVAL_EVERY = 1000
N_ENSEMBLE = 101  # stochastic forward passes averaged for the ensemble prediction

# Build the tensors once rather than on every iteration. Variable has been a
# no-op wrapper since PyTorch 0.4, so plain tensors suffice.
X_train_t = torch.from_numpy(Xtr)
y_train_t = torch.from_numpy(ytr).long()
X_test_t = torch.from_numpy(Xte)

for t in range(N_ITERS):
    y_pred = model(X_train_t)
    loss = loss_fn(y_pred, y_train_t)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t % EVAL_EVERY == 0:
        # Evaluation needs no gradients. Without no_grad() every test pass
        # would build and hold an autograd graph for nothing.
        with torch.no_grad():
            # A single stochastic forward pass ...
            y_test_sam = model(X_test_t)
            # ... versus the mean model output over N_ENSEMBLE independent passes.
            y_test_ens = torch.stack(
                [model(X_test_t) for _ in range(N_ENSEMBLE)]
            ).mean(dim=0)

        loss_train = loss.item()
        acc_sam = np.mean(y_test_sam.argmax(1).numpy() == yte)
        acc_ens = np.mean(y_test_ens.argmax(1).numpy() == yte)

        print('iter %s:' % t,
              'loss_train = %.3f' % loss_train,
              'acc_test_one_sampl = %.3f' % acc_sam,
              'acc_test_ens = %.3f' % acc_ens)