# Notebook 1: Estimating Variance of Multi-Points

This notebook illustrates how an MH Dropout network learns the variance of multiple target points.


In [None]:
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from models.mhdNetwork import MhdNetwork
from models.mlp import MLP

# Apply one seaborn style and colour palette to the figures drawn in this notebook.
sns.set_style('ticks')
sns.set_palette("Set1")

#### Helper Functions

In [None]:
def train_loop(X_data, Y_data, model, optimizer, epochs=10, display=100, batch_size=32, hypo_count=None):
    """Train ``model`` with mini-batches for a fixed number of epochs.

    The rows are permuted once, before the first epoch; every epoch then walks
    through that same shuffled order in consecutive slices of ``batch_size``
    rows. Each slice is handed to ``train_step`` and the loss it returns is
    printed whenever the running step count is a multiple of ``display``.

    Args:
        X_data: Inputs, indexed along their first dimension.
        Y_data: Targets, indexed along their first dimension in the same order.
        model: Model to optimise; it is switched to training mode here.
        optimizer: Optimizer that updates the model.
        epochs: Number of passes over the shuffled data.
        display: Logging interval, counted in optimisation steps.
        batch_size: Number of rows per mini-batch.
        hypo_count: Forwarded unchanged to ``train_step``.
    """
    model.train()
    optimizer.zero_grad()

    batches_per_epoch = int(np.ceil(len(X_data) / batch_size))

    # A single permutation is drawn and shared by all epochs.
    order = torch.randperm(Y_data.shape[0])
    inputs, targets = X_data[order], Y_data[order]

    global_step = 0
    for epoch in range(epochs):
        for batch_idx in range(batches_per_epoch):
            global_step += 1
            lo = batch_idx * batch_size
            hi = lo + batch_size

            log_loss = train_step(inputs[lo:hi], targets[lo:hi], model, optimizer, hypo_count=hypo_count)
            if global_step % display == 0:
                print(f"E{epoch + 1}: loss={log_loss}")

def train_step(x, y, model, optimizer, hypo_count=None):
    """Run one optimisation step on a single batch and return its loss value.

    The way the loss is obtained depends on the model's class name:
      * ``'mhdNetwork'``: ``model.loss(x, y=y)`` is called and the mean of
        ``outputs['loss']`` is used.
      * ``'MLP'``: the mean of the last element returned by ``model(x, y)``.
      * any other class: ``model(x, y=y, hypo_count=hypo_count)['loss']``.

    NOTE(review): the comparison is case-sensitive, while the class imported
    at the top of this notebook is spelled ``MhdNetwork``; such instances
    presumably end up in the fallback branch. Confirm that this is intended.

    Args:
        x: Batch of inputs.
        y: Batch of targets.
        model: Model being trained.
        optimizer: Optimizer that updates the model.
        hypo_count: Only forwarded in the fallback branch.

    Returns:
        The loss of this batch as a Python float, read before the update.
    """
    optimizer.zero_grad()

    kind = type(model).__name__
    if kind == 'mhdNetwork':
        # The second element returned by ``loss`` is not needed here.
        outputs, _ = model.loss(x, y=y)
        loss = outputs['loss'].mean()
    elif kind == 'MLP':
        loss = model(x, y)[-1].mean()
    else:
        loss = model(x, y=y, hypo_count=hypo_count)['loss']

    # Read the scalar before the parameters are updated.
    loss_value = loss.detach().item()

    loss.backward()
    optimizer.step()
    return loss_value



### 1. Create dataset

In [None]:
# Dataset parameters
inp_dim = 2
out_dim = 2
dataset_sz = 1
outputs_per_input = 5

# Draw one Gaussian input per dataset entry and tile it so that the same
# input is paired with each of its `outputs_per_input` uniform targets.
X_data = torch.normal(
    mean=0.0,
    std=torch.ones(dataset_sz, 1, inp_dim),
).repeat(1, outputs_per_input, 1)
Y_data = torch.rand(dataset_sz, outputs_per_input, out_dim)

# Merge the dataset and per-input dimensions into a single row dimension.
X_flat = X_data.flatten(start_dim=0, end_dim=1)
Y_flat = Y_data.flatten(start_dim=0, end_dim=1)

print(X_data.shape, X_flat.shape)
print(Y_data.shape, Y_flat.shape)

In [None]:
### Visualize labels (Y) with same input
didx = 0  # dataset index

# Table of the targets belonging to the selected input, tagged for plotting.
df = pd.DataFrame(
    Y_data[didx].cpu().numpy(),
    columns=['y_1', 'y_2'],
).assign(Type='labels', Size=1.0)

sns.scatterplot(data=df, x='y_1', y='y_2', hue='Type')

plt.title("Initial State")

### 2. Initialize Models

In [None]:
# Model parameters
hid_dim = 4
mhd_hid_dim = 4
dropout_rate = 0.5
num_layers = 1

wta_loss = 'vanilla'
out_act_fn = 'sigmoid'

models = {}
opts = {}

subset_ratios = [1.0, 0.95, 0.9, 0.75, 0.5, 0.25, 0.01]

# Constructor arguments shared by every MH Dropout network; only the subset
# ratio differs between them.
shared_mhd_kwargs = dict(
    inp_dim=inp_dim,
    hid_dim=hid_dim,
    out_dim=out_dim,
    mhd_hid_dim=mhd_hid_dim,
    num_layers=num_layers,
    wta_loss=wta_loss,
    out_act_fn=out_act_fn,
)

for r in subset_ratios:
    assert 1.0 >= r > 0.0

    # Each network is stored under its subset ratio written as a percentage.
    model_name = str(int(r*100))
    models[model_name] = MhdNetwork(subset_ratio=r, **shared_mhd_kwargs)

    # Copy weights so that initial states are the same
    if model_name != '100':
        models[model_name].load_state_dict(models['100'].state_dict())

    opts[model_name] = torch.optim.AdamW(list(models[model_name].parameters()))


# Check parameter counts
param_count = sum(map(torch.numel, models['100'].parameters()))
print(param_count)

max_hypos = 2 ** mhd_hid_dim - 1
print(max_hypos)


In [None]:
# Setup benchmark
mix_components = max_hypos

# Benchmark model: an MLP built with a dropout setting of 0.5.
models['baseline'] = MLP(
    inp_dim=inp_dim,
    hid_dim=hid_dim,
    out_dim=out_dim,
    num_layers=3,
    out_act_fn=out_act_fn,
    dropout=0.5,
)

# Optimizer for the benchmark model
opts['baseline'] = torch.optim.AdamW(list(models['baseline'].parameters()))

# Check parameter counts
param_count = sum(map(torch.numel, models['baseline'].parameters()))
print(param_count)

### 3. Visualize initial state.

In [None]:
# Value written to the 'Size' column of prediction rows (label rows use 1.0);
# the scatter plots map that column to marker size.
size = 0.5

def compute_mean_std_diffs(pred, y):
    """Compare the per-dimension mean and standard deviation of two point sets.

    Each statistic is taken over the first dimension of ``pred`` and of ``y``;
    the squared gap between the two results is then averaged over the
    remaining dimensions.

    Args:
        pred: Tensor of predictions, one per row.
        y: Tensor of targets, one per row.

    Returns:
        Tuple ``(mean_distance, std_distance)`` of Python floats.
    """
    def _squared_gap(statistic):
        # Mean squared difference of one population statistic.
        return torch.mean((statistic(pred, dim=0) - statistic(y, dim=0)) ** 2).item()

    return _squared_gap(torch.mean), _squared_gap(torch.std)

def compute_error(pred, y):
    """Average, over predictions, of the error to the closest target.

    For every (prediction, target) pair the squared difference is averaged
    over the last dimension; each prediction keeps the smallest of those
    values and the result is the mean over all predictions.

    Args:
        pred: Tensor with one prediction per row.
        y: Tensor with one target per row.

    Returns:
        The averaged nearest-target error as a Python float.
    """
    # Broadcasting pairs every prediction (dim 0) with every target (dim 1).
    pairwise_mse = ((pred.unsqueeze(1) - y.unsqueeze(0)) ** 2).mean(-1)

    # Keep only the closest target for each prediction.
    nearest = pairwise_mse.min(dim=1).values
    return nearest.mean().item()

def eval_model(model, X_data, Y_data, max_hypos=None):
    """Sample predictions for every dataset entry and score them against its targets.

    For each index along the first dimension of ``X_data`` the model is
    sampled and the samples are compared with ``Y_data[idx]`` through
    ``compute_error`` and ``compute_mean_std_diffs``.

    How samples are drawn depends on the model's class name:
      * ``'MhdNetwork'`` (the older spelling ``'mhdNetwork'`` is accepted as
        well): one ``model.sample(x=x, hypo_count=max_hypos)`` call with the
        model in eval mode. ``'pred_sample'`` is used when the result
        contains it, otherwise ``'hypotheses'`` with its first two
        dimensions flattened.
      * ``'MLP'``: ``model.sample(x=x)`` is called ``max_hypos`` times (50
        when ``max_hypos`` is None) with the model in training mode --
        presumably to keep dropout active -- and the results are
        concatenated along the first dimension.

    Note that an ``MLP`` model is left in training mode when this returns.

    Args:
        model: Model to evaluate.
        X_data: Inputs, indexed along their first dimension.
        Y_data: Targets, indexed the same way as ``X_data``.
        max_hypos: Number of hypotheses / samples to draw per input.

    Returns:
        Dict with keys ``'error'``, ``'mean'`` and ``'std'``; each holds a
        list with one float per dataset entry.

    Raises:
        TypeError: If the model's class is none of the supported ones.
    """
    model_kind = model.__class__.__name__
    # The imported class is spelled ``MhdNetwork``; comparing only against
    # 'mhdNetwork' never matched it and left ``preds`` unassigned.
    is_mhd = model_kind in ('MhdNetwork', 'mhdNetwork')
    is_mlp = model_kind == 'MLP'
    if not (is_mhd or is_mlp):
        raise TypeError(f"eval_model does not support models of type {model_kind!r}")

    model.eval()
    if is_mlp:
        # Sampling for the MLP is done in training mode.
        model.train()
        sample_count = 50 if max_hypos is None else max_hypos

    metrics = {'error': [], 'mean': [], 'std': []}
    with torch.no_grad():
        for idx in range(X_data.shape[0]):
            x = X_data[idx]
            y = Y_data[idx]

            if is_mhd:
                model_outputs = model.sample(x=x, hypo_count=max_hypos)
                if 'pred_sample' in model_outputs:
                    preds = model_outputs['pred_sample']
                else:
                    preds = torch.flatten(model_outputs['hypotheses'], start_dim=0, end_dim=1)
            else:
                preds = torch.cat(
                    [model.sample(x=x)['pred_sample'] for _ in range(sample_count)],
                    dim=0,
                )

            mean, std = compute_mean_std_diffs(preds, y)
            metrics['error'].append(compute_error(preds, y))
            metrics['mean'].append(mean)
            metrics['std'].append(std)

    return metrics

def sample_model(model, X_data, max_hypos, df, size, data_index):
    """Sample predictions for one input and append them to a plotting table.

    How samples are drawn depends on the model's class name:
      * ``'MhdNetwork'`` (the older spelling ``'mhdNetwork'`` is accepted as
        well): one ``model.sample(x=x, hypo_count=max_hypos)`` call with the
        model in eval mode. ``'pred_sample'`` is used when the result
        contains it, otherwise ``'hypotheses'`` with its first two
        dimensions flattened.
      * ``'MLP'``: ``model.sample(x=x)`` is called ``max_hypos`` times with
        the model in training mode -- presumably to keep dropout active --
        and the results are concatenated along the first dimension.

    Note that an ``MLP`` model is left in training mode when this returns.

    Args:
        model: Model to sample from.
        X_data: Inputs, indexed along their first dimension.
        max_hypos: Number of hypotheses / samples to draw.
        df: Existing table (e.g. the label rows); it is not modified.
        size: Value stored in the ``'Size'`` column of the new rows.
        data_index: Index of the input inside ``X_data``.

    Returns:
        A new DataFrame: ``df`` followed by one row per prediction with
        columns ``'y_1'``, ``'y_2'``, ``'Type'`` (``'predictions'``) and
        ``'Size'``.

    Raises:
        TypeError: If the model's class is none of the supported ones.
    """
    model_kind = model.__class__.__name__
    model.eval()

    # Prepare outputs for graph
    x = X_data[data_index]

    # The imported class is spelled ``MhdNetwork``; comparing only against
    # 'mhdNetwork' never matched it and left ``preds`` unassigned.
    if model_kind in ('MhdNetwork', 'mhdNetwork'):
        with torch.no_grad():
            model_outputs = model.sample(x=x, hypo_count=max_hypos)
        if 'pred_sample' in model_outputs:
            preds = model_outputs['pred_sample']
        else:
            preds = torch.flatten(model_outputs['hypotheses'], start_dim=0, end_dim=1)
    elif model_kind == 'MLP':
        # Sampling for the MLP is done in training mode.
        model.train()
        with torch.no_grad():
            preds = torch.cat(
                [model.sample(x=x)['pred_sample'] for _ in range(max_hypos)],
                dim=0,
            )
    else:
        raise TypeError(f"sample_model does not support models of type {model_kind!r}")

    # Prepare inputs for graph
    temp_df = pd.DataFrame(preds.cpu().numpy(), columns=['y_1', 'y_2'])
    temp_df['Type'] = 'predictions'
    temp_df['Size'] = size
    return pd.concat([df, temp_df])


In [None]:
metric_container = {}
init_dfs = {}
data_index = 0

# Score every untrained model and keep a table of its sampled predictions
# (appended to the label rows in `df`) for the initial-state figures.
for model_name, current_model in models.items():
    print(model_name)
    metrics = eval_model(current_model, X_data, Y_data, max_hypos)
    metric_container.setdefault(model_name, []).append(metrics)

    init_dfs[model_name] = sample_model(
        current_model, X_data, max_hypos, df, size, data_index)


## Show initial state for each model

In [None]:
# One row of three panels: two MH Dropout networks and the baseline, all
# before training.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

ylim_max = 1.0

panels = [
    ('100', "(a) $r$=1.0"),
    ('75', "(b) $r$=0.75"),
    ('baseline', "(c) baseline"),
]

for idx, (key, title) in enumerate(panels):
    ax = axes[idx]
    sns.scatterplot(data=init_dfs[key], x='y_1', y='y_2', hue='Type', size='Size', ax=ax, legend=False)
    ax.set_title(title)
    ax.set_xlim(0, 1.0)
    ax.set_ylim(0, ylim_max)

    # Only the left-most panel keeps its y-axis label and ticks.
    if idx > 0:
        ax.set_ylabel("")
        ax.set_yticks([])

plt.tight_layout()


### 4. Train model

In [None]:
epochs = 1000
display = 500

# Train every model on the flattened data with its own optimizer.
for model_name, current_model in models.items():
    print(f"Training model={model_name}.")
    train_loop(X_flat, Y_flat, current_model, opts[model_name], epochs=epochs, display=display)


### 5. Visualize outputs

In [None]:
metric_container = {}
steady_dfs = {}
data_index = 0

# Score every trained model and keep a table of its sampled predictions
# (appended to the label rows in `df`) for the steady-state figures.
for model_name, current_model in models.items():
    metrics = eval_model(current_model, X_data, Y_data, max_hypos)
    metric_container.setdefault(model_name, []).append(metrics)

    steady_dfs[model_name] = sample_model(
        current_model, X_data, max_hypos, df, size, data_index)


In [None]:
# Side-by-side comparison of the initial and the steady (post-training) state

model_name = 'baseline'

xlim_max = 1.0
ylim_max = 1.0

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: samples before training; right panel: samples after training.
stages = [
    (init_dfs, "(a) Initial State"),
    (steady_dfs, "(b) Steady State"),
]

for ax, (frames, title) in zip(axes, stages):
    sns.scatterplot(data=frames[model_name], x='y_1', y='y_2', hue='Type', size='Size', ax=ax, legend=False)
    ax.set_title(title)
    ax.set_xlim(0, xlim_max)
    ax.set_ylim(0, ylim_max)

plt.tight_layout()


In [None]:
# Side-by-side comparison of the initial and the steady (post-training) state
# for the MH Dropout network stored under '100' (subset ratio 1.0)

model_name = '100'

xlim_max = 1.0
ylim_max = 1.0

fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: samples before training; right panel: samples after training.
stages = [
    (init_dfs, "(a) Initial State"),
    (steady_dfs, "(b) Steady State"),
]

for ax, (frames, title) in zip(axes, stages):
    sns.scatterplot(data=frames[model_name], x='y_1', y='y_2', hue='Type', size='Size', ax=ax, legend=False)
    ax.set_title(title)
    ax.set_xlim(0, xlim_max)
    ax.set_ylim(0, ylim_max)

plt.tight_layout()


In [None]:
subset_ratios_show = [1.0, 0.75, .5, .01]

# One panel per selected subset ratio, each three inches wide.
plot_count = len(subset_ratios_show)
fig_len = plot_count * 3
fig, axes = plt.subplots(1, plot_count, figsize=(fig_len, 4))

for idx, r in enumerate(subset_ratios_show):
    assert 1.0 >= r > 0.0

    # Same key scheme as used when the models were created.
    model_name = str(int(r*100))
    ax = axes[idx]

    sns.scatterplot(
        data=steady_dfs[model_name],
        x='y_1', y='y_2', hue='Type', size='Size', ax=ax, legend=False,
    )
    ax.set_title(f"$r$={r:.2f}")

    # Only the left-most panel keeps its y-axis label and ticks.
    if idx != 0:
        ax.set_ylabel("")
        ax.set_yticks([])

plt.tight_layout()




## Experiment from Section 3.2

In [None]:
def run_multiple_experiments(
        exp_count,
        inp_dim,
        out_dim,
        outputs_per_input,
        dataset_sz,
        hid_dim,
        num_layers,
        wta_loss,
        subset_ratios,
        epochs,
):
    """Repeat the train-and-evaluate experiment on freshly drawn datasets.

    Every repetition draws a new random dataset, builds one ``MLP`` baseline
    plus one ``MhdNetwork`` per subset ratio, trains each model with
    ``train_loop`` and scores it with ``eval_model``. ``eval_model`` is
    called without ``max_hypos``, so its own defaults decide how many
    samples are drawn.

    All MH Dropout networks of one repetition start from the same weights:
    the state of the first network built is copied into the others.

    Args:
        exp_count: Number of repetitions.
        inp_dim: Size of the last dimension of the inputs.
        out_dim: Size of the last dimension of the targets.
        outputs_per_input: Number of targets paired with every input.
        dataset_sz: Number of distinct inputs per dataset.
        hid_dim: ``hid_dim`` passed to both model constructors.
        num_layers: ``num_layers`` passed to both model constructors.
        wta_loss: ``wta_loss`` passed to ``MhdNetwork``.
        subset_ratios: Iterable of ratios in ``(0, 1]``; each becomes one
            ``MhdNetwork`` keyed by its rounded integer percentage.
        epochs: Training epochs per model.

    Returns:
        Nested dict ``{model_name: {metric_name: [values...]}}`` in which the
        values of all repetitions are concatenated. ``model_name`` is
        ``'baseline'`` or the percentage string of a subset ratio.
    """
    out_act_fn = 'sigmoid'

    metric_container = {}

    for exp_idx in range(exp_count):
        print("Starting experiment {}/{}.".format(exp_idx + 1, exp_count))

        # Fresh dataset: one Gaussian input per entry, tiled over its targets.
        X_data = torch.normal(mean=0.0, std=torch.ones(dataset_sz, 1, inp_dim)).repeat(1, outputs_per_input, 1)
        Y_data = torch.rand(dataset_sz, outputs_per_input, out_dim)
        X_flat = torch.flatten(X_data, start_dim=0, end_dim=1)
        Y_flat = torch.flatten(Y_data, start_dim=0, end_dim=1)

        models = {}
        opts = {}

        # Baseline model
        models['baseline'] = MLP(
            inp_dim=inp_dim,
            hid_dim=hid_dim,
            out_dim=out_dim,
            num_layers=num_layers,
            out_act_fn=out_act_fn,
            dropout=0.5,
        )
        opts['baseline'] = torch.optim.AdamW(list(models['baseline'].parameters()))

        reference_state = None
        for r in subset_ratios:
            assert 1.0 >= r > 0.0

            # round() rather than int(): truncation would turn e.g. 0.29
            # (0.29 * 100 == 28.999999999999996) into the key '28'.
            model_name = str(round(r * 100))

            models[model_name] = MhdNetwork(
                inp_dim=inp_dim,
                hid_dim=hid_dim,
                out_dim=out_dim,
                num_layers=num_layers,
                wta_loss=wta_loss,
                out_act_fn=out_act_fn,
                subset_ratio=r,
            )

            # Copy weights so that initial states are the same. The first
            # network built is the reference, so the list does not have to
            # start with 1.0. No model has been trained yet at this point.
            if reference_state is None:
                reference_state = models[model_name].state_dict()
            else:
                models[model_name].load_state_dict(reference_state)

            opts[model_name] = torch.optim.AdamW(list(models[model_name].parameters()))

        for model_name, current_model in models.items():
            # The display interval is large enough to silence loss printing.
            train_loop(X_flat, Y_flat, current_model, opts[model_name], epochs=epochs, display=10000000)

            metrics = eval_model(current_model, X_data, Y_data)

            per_model = metric_container.setdefault(model_name, {})
            for metric_name, metric_values in metrics.items():
                per_model.setdefault(metric_name, []).extend(metric_values)

    return metric_container

In [None]:
# Settings of the sweep: 30 repetitions, each with one baseline and one
# MH Dropout network per subset ratio listed below.
experiment_settings = dict(
    exp_count=30,
    inp_dim=2,
    out_dim=2,
    outputs_per_input=5,
    dataset_sz=1,
    hid_dim=4,
    num_layers=3,
    wta_loss='vanilla',
    subset_ratios=[
        1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7,
        0.65, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35,
        0.3, 0.25, 0.2, 0.15, 0.1, 0.05, 0.01,
    ],
    epochs=10000,
)

metric_container = run_multiple_experiments(**experiment_settings)

In [None]:
graph_data = []
for model_name, metrics in metric_container.items():
    error = sum(metrics['error']) / len(metrics['error'])
    mean = sum(metrics['mean']) / len(metrics['mean'])
    std = sum(metrics['std']) / len(metrics['std'])

    print(model_name, np.mean(metrics['std']), np.std(metrics['std']))

    if model_name == 'baseline':
        # The baseline has no subset ratio of its own: its values are
        # repeated at three x positions so it spans the plotted range.
        graph_data.extend(
            {'subset ratio (r)': subset_ratio, 'SDD': v, 'group': 'MC dropout'}
            for subset_ratio in [0.01, 0.5, 1.0]
            for v in metrics['std']
        )
    else:
        # Model names are integer percentages; convert back to a ratio.
        subset_ratio = int(model_name) / 100
        graph_data.extend(
            {'subset ratio (r)': subset_ratio, 'SDD': v, 'group': 'Stochastic WTA'}
            for v in metrics['std']
        )


In [None]:

sns.set_style('ticks')
sns.set_palette("Set1")

# Line plot of the 'SDD' values against the subset ratio, one line per group,
# with a 95% confidence interval.
graph_df = pd.DataFrame(graph_data)
ax = sns.lineplot(graph_df, x='subset ratio (r)', y='SDD', hue='group', errorbar=('ci', 95))
plt.yscale('log')
# The plotted column is 'SDD'; the previous label 'log SSD' was a typo.
plt.ylabel('log SDD')
# Dash the first line drawn -- presumably the 'MC dropout' baseline; verify
# against the group order in graph_df.
ax.lines[0].set_linestyle("--")

sns.despine()
plt.legend(title=None, frameon=True)

