In [3]:
import numpy as np

import torch
from torch import nn
from torch import functional as F

from livelossplot import PlotLosses

import matplotlib.pyplot as plt
plt.style.use('ggplot')

In [4]:
from Inference import BBVI 

In [10]:
from Inference.BBVI import VariationalNetwork

In [None]:
sigma_noise = 0.1
(x_data, y_data) = torch.load('Data/foong_data.pt')

In [44]:
class BoostingModel(nn.Module):
    
    def __init__(self, nComponents, tolerance, input_size, output_size, layer_width, nb_layers, device=None):
        super(BoostingModel, self).__init__()
        
        self.H = layer_width
        self.fixed_components = []
        self.fixed_mixture_probas = torch.tensor([1.])
        self.learnable_proba = None
        self.learnable_component = VariationalNetwork(input_size, output_size, layer_width, nb_layers, device)
        self.nComponents = nComponents
        self.current_nComponents = 1
        self.current_component = self.learnable_component
        self.current_component_index = None
        self.current_hyper_parameters = {'weight_mus': [layer.q_weight_mu for layer in self.current_component.registered_layers],
                                   'weight_rhos': [layer.q_weight_rho for layer in self.current_component.registered_layers],
                                   'bias_mus' : [layer.q_bias_mu for layer in self.current_component.registered_layers],
                                   'bias_rhos': [layer.q_bias_rho for layer in self.current_component.registered_layers]}
        
        self.current_sampled_parameters = {'weight': [layer.weight_sample for layer in self.current_component.registered_layers],
                                          'bias': [layer.bias_sample for layer in self.current_component.registered_layers]}
        
        self.tolerance = tolerance
        self.potential_starting_points = [(0, torch.tensor(np.inf)) for i in range(self.tolerance)]
        self.device = device
        
    def refresh_current_hyper_parameters(self):
        self.current_hyper_parameters = {'weight_mus': [layer.q_weight_mu for layer in self.current_component.registered_layers],
                                   'weight_rhos': [layer.q_weight_rho for layer in self.current_component.registered_layers],
                                   'bias_mus' : [layer.q_bias_mu for layer in self.current_component.registered_layers],
                                   'bias_rhos': [layer.q_bias_rho for layer in self.current_component.registered_layers]}
    
    def refresh_current_sampled_parameters(self):
        self.current_sampled_parameters = {'weight': [layer.weight_sample for layer in self.current_component.registered_layers],
                                          'bias': [layer.bias_sample for layer in self.current_component.registered_layers]}
    
    def forward(self, x):
        out = self.current_component(x)
        return out
    
    def resample_parameters_in_eval(self):
        self.sample_component(last = True)
        self.sample_parameters()
        self.refresh_current_sampled_parameters()
        
    def resample_parameters_in_train(self):
        self.sample_component(last = False)
        self.sample_parameters()
        self.refresh_current_sampled_parameters()
    
    def sample_component(self, last = False):
        
        if len(self.fixed_components) == 0:
            self.current_component = self.current_component
        
        elif not last:
            self.current_component = np.random.choice(self.fixed_components, p = self.fixed_mixture_probas.data.numpy())
        else:
            pi = sigmoid(self.learnable_proba.detach())
            if uniform.sample()< pi:
                self.current_component = self.learnable_component
            else:
                self.sample_component(last = False)
                
    def sample_parameters(self):
        self.current_component.resample_parameters()
        
    def mixture_log_pdf(self):
        pi = sigmoid(self.learnable_proba)
        probs = torch.cat((self.fixed_mixture_probas*(1-pi), pi.unsqueeze(0)))
        #print(probs)
        log_q = []
        for i, component in enumerate(self.fixed_components):
            #component.set_sampled_parameters(self.current_parameters['weight_mus'],
                                     #self.current_parameters['weight_rhos'],
                                     #self.current_parameters['bias_mus'],
                                     #self.current_parameters['bias_rhos'])
            component.set_sampled_parameters(self.current_sampled_parameters['weight'],
                                            self.current_sampled_parameters['bias'])
            log_q.append(component.q_log_pdf())
        #self.learnable_component.set_parameters(self.current_parameters['weight_mus'],
                                     #self.current_parameters['weight_rhos'],
                                     #self.current_parameters['bias_mus'],
                                     #self.current_parameters['bias_rhos'])
        self.learnable_component.set_sampled_parameters(self.current_sampled_parameters['weight'],
                                                       self.current_sampled_parameters['bias'])
        log_q.append(self.learnable_component.q_log_pdf())
        log_q = torch.stack(log_q)
        #signs = log_q.detach().sign()
        log_q = log_q.add(torch.log(probs))
        log_q = torch.logsumexp(log_q, dim = 0)
        """investigate gradient of max"""
        #maximum = log_q.max()
        #log_q = log_q - maximum
        #log_q = torch.exp(log_q)
        #log_q = log_q*probs
        #log_q = torch.log(torch.sum(log_q)) + maximum
        return(log_q)
    
    def compute_mixture_elbo(self, x_data, y_data, sample_size):
        if type(self.learnable_proba) != type(None):
            #print("1")
            
            pi_new = sigmoid(self.learnable_proba)
        
            """Compute expectancy w.r.t old mixture"""
            L_old_mixture = []
            for _ in range(sample_size):
                self.resample_parameters_in_train()
                LQ = self.mixture_log_pdf()
             #   print(LQ.grad_fn)
                y_pred = self.forward(x_data)
                LL = log_norm(y_data, y_pred.t(), torch.tensor(sigma_noise)).sum()
                LP = self.current_component.prior_log_pdf()
                L_old_mixture.append(LQ - LL - LP)
            L_old_mixture = torch.stack(L_old_mixture)
            #print(L_old_mixture.grad_fn)
            L_old_mixture = torch.mean(L_old_mixture)
            #print(L_old_mixture < self.potential_starting_points[-1][1])
            if L_old_mixture < self.potential_starting_points[-1][1]:
                self.potential_starting_points[-1] = (self.current_sampled_parameters, L_old_mixture.detach().data)
                self.potential_starting_points = sorted(self.potential_starting_points, key = lambda x: x[1])
            
            #print(L_old_mixture.grad_fn)
            #print(L_old_mixture)

            """Compute expectancy w.r.t new component"""
            L_new_component = []
            for _ in range(sample_size):
                self.current_component = self.learnable_component
                self.sample_parameters()
                self.refresh_current_sampled_parameters()
                LQ = self.mixture_log_pdf()
                self.sample_component(last = True)
                y_pred = self.forward(x_data)
                LL = log_norm(y_data, y_pred.t(), torch.tensor(sigma_noise)).sum()
                LP = self.current_component.prior_log_pdf()
                L_new_component.append(LQ - LL - LP)
            L_new_component = torch.stack(L_new_component)
            L_new_component = torch.mean(L_new_component)
            #print(L_new_component.grad_fn)
            L = (1-pi_new)*L_old_mixture + pi_new*L_new_component
            #
            return L
        else:
            #print("2")
            L = self.learnable_component.compute_elbo(x_data, y_data, sample_size, sigma_noise, self.device)
            #print(L)
            if L.detach() < self.potential_starting_points[-1][1]:
                self.potential_starting_points[-1] = (self.current_sampled_parameters, L.detach().data)
                self.potential_starting_points = sorted(self.potential_starting_points, key = lambda x: x[1])
                #print(self.potential_starting_points)
                #self.potential_starting_points = self.potential_starting_points[1:]
            return(L)
    
    
    def new_component(self, losses, epsilon, new_pi):
        #print('std losses',torch.std(losses))
        
        if epsilon and len(self.fixed_components)+1<self.nComponents:
        #if torch.std(losses) < epsilon and len(self.components)<self.nComponents:
            self.refresh_current_hyper_parameters()
            print("WE GOT THERE !!!!")
            self.learnable_component.lock_means()
            self.learnable_component.lock_rhos()
            self.fixed_components.append(self.learnable_component)
            self.current_component = self.learnable_component
            self.refresh_current_hyper_parameters()
            self.learnable_component = RegressionModel(self.H)
            #self.learnable_component.set_hyper_parameters(torch.mean(torch.stack([c[0]['weight'] for c in self.potential_starting_points])),
            #                         torch.tensor(,
            #                         torch.mean(torch.stack([c[0]['bias_mus'] for c in self.potential_starting_points])),
            #                         torch.mean(torch.stack([c[0]['bias_rhos'] for c in self.potential_starting_points])))
            #self.learnable_component.lock_means()
            #self.learnable_component.lock_rhos()
            if type(self.learnable_proba) != type(None):
                self.learnable_proba.detach_()
                pi = sigmoid(self.learnable_proba)
                self.fixed_mixture_probas = torch.cat((self.fixed_mixture_probas*(1 - pi), pi.unsqueeze(0)))
            self.learnable_proba = torch.tensor(float(new_pi), requires_grad = True)
            print('NEW COMPONENT OK')
            return 1
#torch.std(losses) < epsilon
        else:
            return 0

In [45]:
nComponents = 3
tolerance = 10
input_size = 1
output_size = 1
layer_width = 20
nb_layers = 2

In [46]:
model = BoostingModel(nComponents, tolerance, input_size, output_size, layer_width, nb_layers)

In [41]:
learning_rate = 0.05
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, factor=0.95,verbose=True)
optimizer.zero_grad()

In [42]:
%matplotlib inline
num_epoch = 300
num_iterations = 100
liveloss = PlotLosses()

In [49]:
n_ELBO_samples = 10
loss = model.compute_mixture_elbo(x_data, y_data, n_ELBO_samples)
loss

tensor(7.5689e+08, grad_fn=<DivBackward0>)

In [50]:
#M = int(.005/learning_rate)+1
weights = []
n_ELBO_samples=5
j = 0
std_steps = 100
stop = False
while stop == False:
#for j in range(num_epoch):
    logs = {}
    losses = [None] * num_iterations
    
    for k in range(num_iterations):
        optimizer.zero_grad()
        #if new_comp_this_epoch:
            #print('before grad', model.learnable_proba)
        loss = model.compute_mixture_elbo(x_data, y_data, n_ELBO_samples)
        losses[k] = loss.detach()
        loss.backward()
        #if new_comp_this_epoch:
        #    print('after grad', model.current_proba_parameter)
        #print(model.components[0].linear1.weight_sample.grad)
        gradients = torch.sum(model.learnable_component.linear1.q_weight_mu.detach()**2)
        weights.append(gradients)
        
        optimizer.step()
        model.refresh_current_hyper_parameters()
        #print(model.current_proba_parameter)
    if j>std_steps and model.current_nComponents < model.nComponents:
#        new_comp = model.new_component(None, epsilon = (optimizer.param_groups[0]['lr'] != learning_rate), new_pi = np.log(1./(4*model.nComponents+3)))
        new_comp = model.new_component(None, epsilon = (torch.std(torch.stack(weights[-std_steps:])) < .1), new_pi = 0.)#np.log(1./(4*model.nComponents+3)))

        #print('STOP', stop)
        if new_comp:
            model.current_nComponents += 1
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)#filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
            optimizer.add_param_group({"params": model.learnable_proba, 'lr': .005})
            #break
            #M+=10
    #if model.nComponents ==3 and (optimizer.param_groups[0]['lr'] != learning_rate):
    #    model.learnable_component.lock_means()
    #    model.learnable_component.lock_rhos()
    #    model.fixed_components.append(model.learnable_component)
    #    pi = model.learnable_proba.detach()
    #    model.fixed_mixture_probas = torch.cat(((1-pi)*model.fixed_mixture_probas, pi.unsqueeze(0)))
    #    break

        #print(model.current_proba_parameter)
    #print(model.mixture_probas)
    #print(torch.std(torch.stack(losses[k-4:k])))
    def f(x):
        if type(x) == type(None):
            return(0.)
        else:
            return(sigmoid(x))
    #f = lambda x: {type(None): 1}.get(type(x), sigmoid(x))
    logs['expected_loss'] = torch.stack(losses).mean().detach().clone().numpy()
    logs['learning rate'] = optimizer.param_groups[0]['lr']
    logs['ncomponents'] = len(model.fixed_components) + 1
    logs['current_proba'] = f(model.learnable_proba)
    logs['gradients_weights'] = gradients
    #iilogs['proba_gradients'] = prob_grad
    #if type(model.current_proba_parameter) != type(None):
    #    logs['current_proba'] = model.current_proba_parameter.detach().data
    lr = optimizer.param_groups[0]['lr']
    #M = int(.005/lr)+1
    #M=5
    #print(M)
    liveloss.update(logs)
    liveloss.draw()
    if j > std_steps:
         weights = weights[-std_steps:]
    #print('epoch', j, 'num_components', len(model.components), 'stop', stop)
    
    scheduler.step(logs['expected_loss'])
    j+=1

KeyboardInterrupt: 