In [1]:
import torch
import numpy as np
import itertools
from torch.distributions.normal import Normal


# SWAG Implementation

#### Remarks to Understand the Code:

###### Inputs for class SWAG
- base, args, kwargs: <br>
base is the base model (default="VGG16") we are going to use. args and kwargs contain parameter of the base.
- no_cov_mat: <br>
When sampling, do we include only the simple SWAG-Diagonal formal. Or add the Low Rank Covariance matrix $DD^T$.
- max_num_models: <br>
Equivalent to variable $K$ in the paper. Maximum number of columns in deviation matrix.
- var_clamp: <br>
it is used for calculating variance in torch.clamp. Equivalent to assigning max(computed_var, 1e-30) to variance. Goal: avoid 0 variance errors.

###### init for SWAG
- n_models: <br>
A parameter of the model which should be saved and restored in the state_dict, but not trained by the optimizer, that's why it is registered as buffer.

- self.params: <br>
The other parameters of the model which will be trained by the optimizer.

- n_models: <br>
Corresponds to $n$ in the pseudo-code description of the algorithm.

###### swag_parameters fct (Not Sure if I understood Correctly !!!)
Basically copy the non-trainable parameters of the model to the list params (and then self.params in SWAG class) without taking into account params with key None. Plus create param for the low rank covariance matrix if we want to include it.

###### sample fct in SWAG
- scale: <br>
scaling constant to obtain a valid distribution. (When used it is stored in args) <br>
(Example in uncertainty.py: `model.sample(scale=args.scale, cov=sample_with_cov)`)

# Quelques questions pour Vincent:
- Lien de leur code: <br>
https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
- Pourquoi dans def sample_fullrank(self, scale, cov, fullrank), ils ont la variable bool fullrank, par contre ils l'utilisent jamais dans la fonction?
- Scale, c'est bien la valeur de normalisation ? Elle dépend de quoi, Les entrées de l'utilisateur ? 
c'est ce qu'il y a dans uncertainty.py: <br>
`parser.add_argument("--scale", type=float, default=1.0)`
<br>
- Je comprends pas trop la diff entre sample_blockwise et sample_fullrank, mmh ont-ils parlé de ça dans le papier?
- J'ai cherché un peu dans leur code et à chaque fois qu'ils utilisent la fct sample c'est comme ça: <br>
`swag_model.sample(0.0)`
<br>
ça veut dire quoi ahaha?

In [2]:
def swag_parameters(module, params, no_cov_mat=True):
    for name in list(module._parameters.keys()):
        #iterate through parameters of module
        if module._parameters[name] is None:
            continue
        data = module._parameters[name].data
        module._parameters.pop(name)
        module.register_buffer("%s_mean" % name, data.new(data.size()).zero_())
        #Example, non-trainable parameter \theta_0 called: VGG16s_mean
        module.register_buffer("%s_sq_mean" % name, data.new(data.size()).zero_())
        # \theta_{0}^2
        
        if no_cov_mat is False:
            module.register_buffer(
                "%s_cov_mat_sqrt" % name, data.new_empty((0, data.numel())).zero_()
            )
            
        params.append((module, name))
        
        

In [3]:
class SWAG(torch.nn.Module):
    def __init__(
        self, base, no_cov_mat=True, max_num_models=0, var_clamp=1e-30, *args, **kwargs
    ):
        super(SWAG, self).__init__()
        
        self.register_buffer("n_models", torch.zeros([1], dtype=torch.long))
        self.params = list()
        
        self.no_cov_mat = no_cov_mat
        self.max_num_models = max_num_models
        
        self.base = base(*args, **kwargs)
        self.base.apply(
            lambda module: swag_parameters(
                module=module, params=self.params, no_cov_mat=self.no_cov_mat
            )
        )
        
        
    def forward(self, base, *args, **kwargs):
        return self.base(*args, **kwargs)
    
    
    def sample(self, scale=1.0, cov=False, seed=None, block=False, fullrank=True):
        if seed is not None:
            torch.manual_seed(seed)
            #If we need to rerun an experiment we should have a fixed seed. Otherwise torch chooses.
        if not block:
            self.sample_fullrank(scale, cov, fullrank)
        else:
            self.sample_blockwise(scale, cov, fullrank)
        
        
    def sample_fullrank(scale, cov, fullrank):
        pass
    def sample_blockwise(scale, cov, fullrank):
        pass
    
    
    def collect_model(self, base_model):
        for (module, name), base_param in zip(self.params, base_model.parameters()):
            mean = module.__getattr__("%s_mean" % name) #\theta_0
            sq_mean = module.__getattr__("%s_sq_mean" % name) #\theta_0^2
            
            # First Moment
            mean = mean * self.n_models.item() / 
            
            
    def generate_mean_var_covar(self):
        mean_list = []
        var_list = []
        cov_mat_root_list = []
        for module, name in self.params:
            mean = module.__getattr__("%s_mean" % name)
            sq_mean = module.__getattr__("%s_sq_mean" % name)
            cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name)
            
            mean_list.append(mean)
            var_list.append(sq_mean - mean ** 2.0)
            cov_mat_root_list.append(cov_mat_sqrt)
        return mean_list, var_list, cov_mat_root_list
    
    

SyntaxError: invalid syntax (<ipython-input-3-45f18af88dee>, line 47)

# Don't forget:
To sample from SWAG we use the following identity:
\begin{equation}
\tilde{\theta} = \theta_{\text{SWA}} + \frac{1}{\sqrt{2}} \cdot \Sigma_{\text{diag}}^{\frac{1}{2}}
z_1 + \frac{1}{\sqrt{2(K-1)}}\hat{D} z_2, \quad \text{where} \quad z_1 \sim \mathcal{N}(0, I_d), z_2 \sim \mathcal{N}(0, I_K)
\end{equation} 