In [1]:
import numpy as np

import torch
from torch import nn
from torch import functional as F

from livelossplot import PlotLosses

import matplotlib.pyplot as plt
plt.style.use('ggplot')

In [2]:
from Inference import BBVI 

In [3]:
from Inference.BBVI import VariationalNetwork

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [5]:
sigma_noise = 0.1
data = torch.load('Data/foong_data.pt')
x_data = data[0].to(device)
y_data = data[1].to(device)
y_data = y_data.unsqueeze(-1)

In [6]:
class MixtureVariationalNetwork(nn.Module):
    def __init__(self, input_size, output_size, layer_width, nb_layers, device=None):
        super(MixtureVariationalNetwork, self).__init__()
        
        self.layer_width = layer_width
        self.input_size = input_size
        self.output_size = output_size
        self.nb_layers = nb_layers
        self.device = device
        
        self.components = []
        self.pi = torch.tensor([], device=device)
        
    def add_component(self, component, proportion):
        #todo check compatibility with other components
        self.components.append(component)
        proportion = proportion.to(self.device)
        self.pi = torch.cat((self.pi*(1-proportion), proportion.unsqueeze(0)))
        
    def sample_parameters(self, M=1):
        D = torch.distributions.multinomial.Multinomial(M, self.pi)
        m = D.sample()
        S = []
        for j in range(len(self.pi)):
            S.append(self.components[j].sample_parameters(int(m[j])))
        return ([torch.cat([c[0][k] for c in S]) for k in range(self.nb_layers)], [torch.cat([c[1][k] for c in S]) for k in range(self.nb_layers)])

    def q_log_pdf(self, layered_w_samples, layered_bias_samples):
        log_q = [c.q_log_pdf(layered_w_samples, layered_bias_samples) for c in self.components]
        return torch.logsumexp(torch.stack(log_q) + torch.log(self.pi).unsqueeze(0).t(), dim=0)

    def prior_log_pdf(self, layered_w_samples, layered_bias_samples):
        log_prior = [c.prior_log_pdf(layered_w_samples, layered_bias_samples) for c in self.components]
        return torch.logsumexp(torch.stack(log_prior) + torch.log(self.pi).unsqueeze(0).t(), dim=0)
    
    def compute_elbo(self, x_data, y_data, n_samples_ELBO, sigma_noise, new_component=None, new_proportion=None):
    
        # sample X^(c)
        (layered_w_samples_XC, layered_bias_samples_XC) = self.sample_parameters(n_samples_ELBO)

        LP_XC = self.prior_log_pdf(layered_w_samples_XC, layered_bias_samples_XC)
        y_pred_XC = self.forward(x_data)
        LL_XC = self._log_norm(y_pred_XC, y_data, torch.tensor(sigma_noise).to(self.device))
        posterior_XC = torch.sum(LP_XC.unsqueeze(-1).unsqueeze(-1) + LL_XC, dim=[1,2])

        qC_log_XC = self.q_log_pdf(layered_w_samples_XC, layered_bias_samples_XC)
        
        if new_component is None:
            return torch.mean(qC_log_XC - posterior_XC)
        
        qN_log_XC = new_component.q_log_pdf(layered_w_samples_XC, layered_bias_samples_XC)
        qCN_log_XC = torch.logsumexp(torch.stack([torch.log(torch.tensor(1.0)-new_proportion) + qC_log_XC, torch.log(new_proportion) + qN_log_XC],dim=0),dim=0)
        
        # sample X_(c+1)
        (layered_w_samples_XN, layered_bias_samples_XN) = new_component.sample_parameters(n_samples_ELBO)
        
        LP_XN = new_component.prior_log_pdf(layered_w_samples_XN, layered_bias_samples_XN)
        y_pred_XN = new_component.forward(x_data)
        LL_XN = new_component._log_norm(y_pred_XN, y_data, torch.tensor(sigma_noise).to(self.device))
        posterior_XN = torch.sum(LP_XN.unsqueeze(-1).unsqueeze(-1) + LL_XN, dim=[1,2])

        qC_log_XN = self.q_log_pdf(layered_w_samples_XN, layered_bias_samples_XN)
        
        qN_log_XN = new_component.q_log_pdf(layered_w_samples_XN, layered_bias_samples_XN)
        qCN_log_XN = torch.logsumexp(torch.stack([torch.log(torch.tensor(1.0)-new_proportion) + qC_log_XN, torch.log(new_proportion) + qN_log_XN],dim=0),dim=0)
        
        L = (torch.tensor(1.0)-new_proportion) * qCN_log_XC.mean() + new_proportion * qCN_log_XN.mean() 
        
        return L
    
    def requires_grad_rhos(self, v = False):
        for k in range(len(self.components)):
            self.components[k].requires_grad_rhos(v)
        
    def requires_grad_mus(self, v = False):
        for k in range(len(self.components)):
            self.components[k].requires_grad_mus(v)
    
    def forward(self, x):
        return torch.cat([self.components[k].forward(x) for k in range(len(self.components))],dim=0)

    def _log_norm(self, x, mu, std):
        return -0.5 * torch.log(2*np.pi*std**2) -(0.5 * (1/(std**2))* (x-mu)**2)
    
    #def KL_log_pdf(self):
    

In [7]:
class VariationalBoostingOptimizer():
    def __init__(self, mixture, sigma_noise, optimizer, optimizer_params, scheduler=None, scheduler_params=None):
        self.mixture = mixture
        self.sigma_noise = sigma_noise
        self.device = model.device
        
        self.optimizer = optimizer
        if scheduler is None:           
            self.scheduler = None
        else:
            self.scheduler = scheduler
            
    def run(self, data, n_epoch=100, n_iter=1, n_ELBO_samples=1, seed=None, plot=False, savePath=None, xpName=None, networkName=None, saveName=None):
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        x_data = data[0].to(self.device)
        y_data = data[1].to(self.device)

        self.optimizer.zero_grad()

        if saveName is not None and savePath is not None:
            liveloss = PlotLosses(fig_path=str(savePath)+str(xpName)+'_'+str(networkName)+'_'+str(saveName))
        else:
            liveloss = PlotLosses()

        for j in range(n_epoch):
            logs = {}
            losses = [None] * n_iter

            for k in range(n_iter):
                self.optimizer.zero_grad()
                loss = self.model.compute_elbo(x_data, y_data, n_ELBO_samples, self.sigma_noise, self.device)
                losses[k] = loss
                loss.backward()
                self.optimizer.step()

            logs['expected_loss'] = torch.stack(losses).mean().detach().clone().cpu().numpy()
            logs['learning rate'] = self.optimizer.param_groups[0]['lr']
            liveloss.update(logs)
            if plot is True:
                liveloss.draw()
            if self.scheduler is not None:
                self.scheduler.step(logs['expected_loss'])
        return self.model

In [8]:
c1 = VariationalNetwork(1, 1, 20, 2, device=device)

In [None]:
optimizer = torch.optim.Adam
optimizer_params = {'lr': 0.05}

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
scheduler_params = {'patience': 3, 'factor': 0.8}

In [None]:
voptimizer = BBVI.VariationalOptimizer(model=c1, sigma_noise=0.1, optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params)

In [None]:
c1 = voptimizer.run((x_data,y_data), n_epoch=100, n_iter=100, n_ELBO_samples=50, plot=True)

In [9]:
mix = MixtureVariationalNetwork(1, 1, 20, 2, device=device)

In [10]:
mix.add_component(c1, torch.tensor(1.0))

In [11]:
new_component = VariationalNetwork(1, 1, 20, 2, device=device)

In [12]:
a = nn.Parameter(torch.tensor(0.5, requires_grad=True, device=device))

In [13]:
mix.add_component(new_component, torch.tensor(0.5))

In [14]:
x_test = torch.linspace(-2.0, 2.0).unsqueeze(1).to(device)

In [15]:
D = torch.distributions.multinomial.Multinomial(1, mix.pi)

In [16]:
m = D.sample()
S = []

In [17]:
for j in range(len(mix.pi)):
    S.append(mix.components[j].sample_parameters(int(m[j])))

In [26]:
S[1]

([], [])

In [18]:
([torch.cat([c[0][k] for c in S]) for k in range(mix.nb_layers)], [torch.cat([c[1][k] for c in S]) for k in range(mix.nb_layers)])

IndexError: list index out of range

In [None]:
x_test = torch.linspace(-2.0, 2.0).unsqueeze(1).to(device)
fig, ax = plt.subplots()
fig.set_size_inches(11.7, 8.27)
plt.scatter(x_data.cpu(), y_data.cpu())
for _ in range(1000):
    mix.sample_parameters()

    y_test = mix.forward(x_test)
    plt.plot(x_test.detach().cpu().numpy(), y_test.squeeze(0).detach().cpu().numpy(), alpha=0.05, linewidth=1, color='lightblue')

In [None]:
optimizer = torch.optim.Adam
optimizer_params = {'lr': 0.005}

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
scheduler_params = {'patience': 3, 'factor': 0.8}

In [None]:
parameters = list(new_component.parameters()) + [a]

In [None]:
vo = optimizer(parameters, **optimizer_params)

In [None]:
s = scheduler(vo, **scheduler_params)

In [None]:
n_iter = 1
n_epoch = 1000
liveloss = PlotLosses()
plot = True

In [None]:
for j in range(n_epoch):
    logs = {}
    losses = [None] * n_iter

    for k in range(n_iter):
        vo.zero_grad()
        loss = mix.compute_elbo(x_data, y_data, 100, .1, new_component=new_component, new_proportion=a)
        losses[k] = loss
        loss.backward()
        vo.step()

    logs['expected_loss'] = torch.stack(losses).mean().detach().clone().cpu().numpy()
    logs['learning rate'] = vo.param_groups[0]['lr']
    liveloss.update(logs)
    if plot is True:
        liveloss.draw()
    if s is not None:
        s.step(logs['expected_loss'])
