In [365]:
import os
import sys
import math
from copy import deepcopy

import numpy as np
import pandas as pd

%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn 
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.datasets import MNIST

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Restrict this process to physical GPU 5. This must happen before CUDA is
# first initialised, otherwise the setting is ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# True only when a usable GPU is present; the rest of the notebook branches on it.
use_cuda = torch.cuda.is_available()

# With a single visible GPU the device is enumerated as index 0, whatever its
# physical index is (the tensors printed further down report `cuda:0`).
# `torch.device` is the device handle; `torch.cuda.device` is only a context
# manager for switching the current device and cannot be passed to `.to()`.
device = torch.device("cuda:0" if use_cuda else "cpu")

In [538]:
class SteinLinear(nn.Module):
    """Linear layer holding ``n_particles`` independent parameter sets.

    Parameters are stored with a leading particle axis:

    * ``weight`` has shape ``(n_particles, in_features, out_features)``
    * ``bias`` has shape ``(n_particles, 1, out_features)``, or is ``None``

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        n_particles: number of parameter copies (particles).
        bias: when ``False`` the layer has no additive bias.
    """

    def __init__(self, in_features, out_features, n_particles=1, bias=True):
        super(SteinLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.n_particles = n_particles

        # Storage is left uninitialised here; reset_parameters() fills it.
        self.weight = nn.Parameter(torch.empty(n_particles, in_features, out_features))

        if bias:
            self.bias = nn.Parameter(torch.empty(n_particles, 1, out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter from U(-stdv, stdv)."""
        # NOTE(review): the bound is derived from weight.size(2), i.e.
        # out_features. A fan-in based bound would use size(1) (in_features)
        # for this (particles, in, out) layout -- confirm which is intended.
        stdv = 1. / math.sqrt(self.weight.size(2))
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.uniform_(-stdv, stdv)

    def forward(self, X):
        """Apply each particle's affine map to its slice of ``X``.

        ``X`` is multiplied with ``weight`` via ``torch.matmul``, so it must
        broadcast against ``(n_particles, in_features, out_features)``,
        e.g. ``(n_particles, batch, in_features)``.
        """
        out = torch.matmul(X, self.weight)
        # The bias is optional; adding None would raise a TypeError.
        if self.bias is not None:
            out = out + self.bias
        return out

In [546]:
# Scratch check of autograd through a PReLU: three inputs (one negative, so
# the learnable slope is exercised) and their regression targets.
x = torch.tensor([1., -2., 3.])
t = torch.tensor([1., 2., 3.])
rr = nn.PReLU()

In [551]:
# Mean squared error between the PReLU output and the targets.
loss = (rr(x) - t).pow(2).mean()

In [552]:
# Back-propagate the scalar loss so the parameters of `rr` receive `.grad`.
loss.backward()

In [554]:
# Show the gradient stored on every parameter of the PReLU module.
for param in rr.parameters():
    print(param.grad)

tensor([ 3.3333])


In [568]:
# Every layer must carry the same number of particles: SteinNet reads the
# count from the first layer and reshapes all parameters with it. Keeping it
# in one constant prevents the layers from drifting apart.
N_PARTICLES = 10

# 10 -> 5 -> 4 -> 1 regression network, replicated once per particle.
arc = nn.Sequential(
    SteinLinear(10, 5, N_PARTICLES),
    nn.LeakyReLU(),
    SteinLinear(5, 4, N_PARTICLES),
    nn.LeakyReLU(),
    SteinLinear(4, 1, N_PARTICLES)
)

In [576]:
class SteinNet():
    """Log-density and kernel terms for a particle ensemble of networks.

    Probabilistic model (normalising constants are dropped throughout):

        p(y|x,w) = N(y | f(x,w), alpha^-1)
        p0(w)    = N(w | 0, betta^-1)

    RBF kernel between particles:

        k(w, w') = exp(-1/h * ||w - w'||^2)

    ``alpha^-1`` and ``betta^-1`` are variances, so the quadratic terms are
    weighted by ``alpha / 2`` and ``betta / 2``.
    """

    def __init__(self, arc, train_size):
        '''
        Args:
            arc: indexable module (e.g. ``nn.Sequential``) whose first layer
                exposes ``n_particles`` and whose parameters all have the
                particle axis as their first dimension.
            train_size: number of samples in the full training set; used to
                rescale the mini-batch mean in ``calc_data_term``.
        '''
        # Use the GPU whenever one is available instead of reading a
        # notebook-level flag, so the class does not depend on hidden state.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        ### f(x, w)
        self.arc = arc.to(self.device)

        ### n_particles
        self.n_particles = arc[0].n_particles

        self.train_size = train_size

        ### precisions (inverse variances) of the probabilistic model
        self.alpha = torch.tensor(1., device=self.device)
        self.betta = torch.tensor(1., device=self.device)

        ### bandwidth factor of the kernel
        self.h = torch.tensor(1., device=self.device)

    ### SUM for all j {log p(Dj|w)}
    ### return tensor [n_particles]
    def calc_data_term(self, X, y):
        """Unnormalised log-likelihood of the data for every particle.

        ``X`` is a 2-D batch ``(batch, features)`` that is shared by all
        particles; ``y`` must broadcast against ``(n_particles, batch)``.
        The batch mean of the squared error is multiplied by ``train_size``
        to stand in for the sum over the whole training set.
        """
        # Same batch for every particle: (n_particles, batch, features).
        X = X.unsqueeze(0).expand(self.n_particles, *X.shape)
        y_p = self.arc(X).view(self.n_particles, X.shape[1])

        mean_sq_err = torch.mean((y - y_p) ** 2, dim=1)
        # Gaussian with variance alpha^-1: -(alpha / 2) * squared error.
        return -(self.alpha / 2.) * mean_sq_err * self.train_size

    ### log p0(w)
    ### return tensor [n_particles]
    def calc_prior_term(self):
        """Unnormalised Gaussian log-prior of every particle's parameters."""
        result = torch.zeros(self.n_particles, device=self.device)

        for param in self.arc.parameters():
            # Gaussian with variance betta^-1: -(betta / 2) * w^2.
            log_p0 = -(self.betta / 2.) * param ** 2
            # Collapse everything except the leading particle axis.
            result = result + log_p0.view(param.shape[0], -1).sum(dim=1)
        return result

    ### k(w, w)
    ### return tensor [n_particles, n_particles]
    def calc_kernel_term(self):
        """RBF kernel matrix between all pairs of particles."""
        distances = torch.zeros(self.n_particles, self.n_particles, device=self.device)

        # The squared distance between two particles is the sum of the
        # squared distances of their flattened parameter tensors.
        for param in self.arc.parameters():
            flat = param.view(self.n_particles, -1)
            distances = distances + self.pairwise_distances(flat)

        return torch.exp(-distances / self.h)

    @staticmethod
    def pairwise_distances(x, y=None):
        '''
        Input: x is a Nxd matrix
               y is an optional Mxd matrix
        Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
                if y is not given then use 'y=x'.
        i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
        '''
        x_norm = (x ** 2).sum(1).view(-1, 1)
        if y is None:
            y = x
            y_norm = x_norm.view(1, -1)
        else:
            y_norm = (y ** 2).sum(1).view(1, -1)

        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
        dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
        # Rounding can push entries slightly below zero; clip them.
        return torch.clamp(dist, min=0.0)

    def calc_objective(self, X, y):
        """Back-propagate the log-posterior and kernel terms.

        Both backward passes accumulate into the ``.grad`` of the parameters
        of ``self.arc``; callers should zero the gradients beforehand if a
        fresh value is wanted.

        Returns:
            ``(log_term, kernel)``: the detached per-particle log-posterior
            ``[n_particles]`` and the detached kernel matrix
            ``[n_particles, n_particles]``.
        """
        prior = self.calc_prior_term()
        data = self.calc_data_term(X, y)
        kernel = self.calc_kernel_term()

        log_term = prior + data
        # backward() needs a scalar. Each particle's term depends only on its
        # own parameters, so summing yields the per-particle gradients.
        log_term.sum().backward()

        # NOTE(review): this adds the gradient of the summed kernel matrix on
        # top of the log-posterior gradient. It is not yet the kernel-weighted
        # Stein update -- confirm the intended combination before training.
        kernel.sum().backward()

        return log_term.detach(), kernel.detach()

In [577]:
# Synthetic regression batch: 20 samples with 10 features each (the input
# width of the first layer of `arc`) and one target per sample.
X = torch.rand([20, 10])
y = torch.rand([20])

# Move to the GPU only when one exists, so the cell also runs on CPU.
if torch.cuda.is_available():
    X, y = X.cuda(), y.cuda()

In [578]:
# Wrap the particle network; the training set is taken to hold 1000 samples.
s = SteinNet(arc, train_size=1000)

In [579]:
# Inspect the log-prior term: one value per particle.
s.calc_prior_term()

tensor([-3.9315, -3.2729, -3.8804, -3.7918, -3.3743, -3.4731, -3.4531,
        -4.1226, -3.4281, -4.1543], device='cuda:0')

In [580]:
# Inspect the data (log-likelihood) term on the synthetic batch: one value per particle.
s.calc_data_term(X, y)

tensor([-250.0467, -829.3492, -305.4180,  -47.2312,  -53.1171,  -72.4147,
         -61.3859, -127.1481, -103.7411, -913.6259], device='cuda:0')

In [584]:
# Inspect the particle-by-particle kernel matrix.
s.calc_kernel_term()

tensor([[ 1.0000e+00,  4.8717e-08,  1.1981e-07,  1.1502e-07,  4.4647e-06,
          1.2512e-08,  1.3269e-06,  1.5255e-08,  1.8422e-08,  3.2467e-09],
        [ 4.8717e-08,  1.0000e+00,  1.0961e-06,  1.4858e-06,  2.7744e-06,
          2.2724e-06,  6.0164e-07,  6.8032e-08,  9.8728e-07,  5.6944e-06],
        [ 1.1981e-07,  1.0961e-06,  1.0000e+00,  5.9653e-08,  2.0812e-07,
          2.8557e-07,  2.0287e-08,  1.1702e-07,  6.0130e-08,  1.4912e-08],
        [ 1.1502e-07,  1.4858e-06,  5.9653e-08,  1.0000e+00,  2.3602e-07,
          1.0098e-05,  9.4667e-07,  2.9938e-09,  1.8713e-06,  6.7317e-07],
        [ 4.4647e-06,  2.7744e-06,  2.0812e-07,  2.3602e-07,  1.0000e+00,
          9.3492e-07,  7.7692e-07,  7.0353e-06,  3.2348e-07,  1.7565e-09],
        [ 1.2512e-08,  2.2724e-06,  2.8557e-07,  1.0098e-05,  9.3492e-07,
          1.0000e+00,  4.3401e-06,  6.5919e-08,  1.6981e-06,  7.9088e-07],
        [ 1.3269e-06,  6.0164e-07,  2.0287e-08,  9.4667e-07,  7.7692e-07,
          4.3401e-06,  1.0000e+0