# Prepare

In [None]:
!pip install scikit-learn~=0.24.2
!pip install salib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import numpy as np
import pandas as pd

# Use the GPU when one is present; fall back to CPU so the notebook still runs
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Model

In [None]:
# Placeholder: load the dataset and the trained model in this cell.
# NOTE: a bare annotation such as `model: ...` does NOT bind the name, so until
# real assignments replace these lines, `simulate` below fails with NameError.
# `simulate` calls `model.to(...)` and `model(x)` on a float32 tensor with one
# row per sample and 311 columns (Treatment, Var1..Var10, G1..G300), then
# flattens the output; presumably `model` is a torch.nn.Module — TODO confirm.
dataset: ...
model: ...

# Sensitivity Analysis

## SALib

### Import

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Global plot styling: plain white axes background, fonts scaled for slides.
sns.set_style(style='white')
sns.set_context(context='talk')

#Import the sampling and analysis modules for a Sobol variance-based
#sensitivity analysis
from SALib.sample import saltelli
from SALib.analyze import sobol

### Function definitions

In [None]:
def get_problem_and_param_values(n=1024, n_vars=10, n_genes=300, calc_second_order=True):
    """Build the SALib problem definition and a Saltelli sample of model inputs.

    The input vector is laid out as::

        [Treatment, Var1 .. Var<n_vars>, G1 .. G<n_genes>]

    ``Treatment`` and every ``G*`` input are sampled on (0, 2) and then
    thresholded at 1, so they only take the values 0.0 / 1.0. The ``Var*``
    inputs stay continuous on (0, 1).

    NOTE(review): Sobol estimators assume continuous inputs. Thresholding the
    sample to binary values is a pragmatic approximation for binary features.

    Parameters
    ----------
    n : int
        Base sample size passed to ``saltelli.sample``. Per the SALib docs the
        result has ``n * (2 * D + 2)`` rows when ``calc_second_order`` is True,
        where ``D`` is the number of inputs.
    n_vars : int
        Number of continuous ``Var*`` inputs (default 10).
    n_genes : int
        Number of binary ``G*`` inputs (default 300).
    calc_second_order : bool
        Forwarded to ``saltelli.sample``. It must match the value later passed
        to ``sobol.analyze``.

    Returns
    -------
    problem : dict
        SALib problem with ``num_vars``, ``names`` and ``bounds`` keys.
    param_values : numpy.ndarray
        Sample matrix, one input vector per row, with binary columns as 0.0/1.0.
    """
    names = [
        'Treatment',
        *('Var{}'.format(i) for i in range(1, n_vars + 1)),
        *('G{}'.format(i) for i in range(1, n_genes + 1)),
    ]
    bounds = [(0, 2)] + [(0, 1)] * n_vars + [(0, 2)] * n_genes
    problem = {
        # Derived from `names`, so the three fields can never disagree.
        'num_vars': len(names),
        'names': names,
        'bounds': bounds,
    }
    param_values = saltelli.sample(problem, n, calc_second_order=calc_second_order)

    # Treatment (column 0) and the G* block (the columns after the Var* block)
    # are binary: they are sampled on (0, 2) and thresholded at the midpoint.
    binary_cols = [0, *range(1 + n_vars, 1 + n_vars + n_genes)]
    param_values[:, binary_cols] = param_values[:, binary_cols] >= 1
    return problem, param_values


In [None]:
def simulate(param_values, net=None, device=None, batch_size=65536):
    """Evaluate the trained model on every row of a SALib sample matrix.

    Parameters
    ----------
    param_values : numpy.ndarray
        2-D sample matrix with one input vector per row. It is converted to a
        float32 tensor before evaluation.
    net : torch.nn.Module, optional
        Model to evaluate. Defaults to the notebook-level ``model``.
    device : torch.device or str, optional
        Where to run the forward pass. Defaults to CUDA when available,
        otherwise CPU.
    batch_size : int or None, optional
        Rows per forward pass. This bounds peak device memory for the large
        Saltelli matrices. ``None`` evaluates everything in a single pass.

    Returns
    -------
    numpy.ndarray
        Model outputs flattened to 1-D. This assumes one scalar output per row
        — TODO confirm for the loaded model.
    """
    if net is None:
        net = model  # notebook global assigned in the "Load Model" section
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    x = torch.from_numpy(param_values).float()
    chunks = [x] if batch_size is None else torch.split(x, batch_size)

    net.to(device)
    was_training = net.training
    # no_grad() alone leaves dropout/batch-norm in training mode. That makes the
    # outputs stochastic and corrupts the Sobol estimates, so switch to eval
    # mode and restore the caller's mode afterwards.
    net.eval()
    try:
        with torch.no_grad():
            outputs = [net(chunk.to(device)).cpu() for chunk in chunks]
    finally:
        net.train(was_training)
    return torch.cat(outputs).numpy().reshape(-1)


### Run the analysis

In [None]:
# Saltelli sample with base size N = 1024. Per the SALib docs this yields
# N * (2D + 2) rows when second-order terms are included.
problem, param_values = get_problem_and_param_values(n=1024)

In [None]:
# Evaluate the model on every sampled input, then estimate the Sobol indices.
Y = simulate(param_values)
# SALib infers N from len(Y). A multi-output model (flattened to N*k values)
# or non-finite outputs would silently corrupt the indices, so fail fast.
assert Y.shape == (param_values.shape[0],), "expected one scalar model output per sample row"
assert np.isfinite(Y).all(), "model produced non-finite outputs"
# calc_second_order must match the sampling call. `seed` fixes the bootstrap
# resampling behind the *_conf intervals so reruns reproduce them.
Si = sobol.analyze(problem, Y, calc_second_order=True, seed=42)

In [None]:
# One row per input: total-order (ST) and first-order (S1) indices with their
# confidence half-widths. Show the 30 inputs with the largest ST.
Si_df = pd.DataFrame(
    data={key: Si[key] for key in ('ST', 'ST_conf', 'S1', 'S1_conf')},
    index=problem['names'],
)
Si_df.sort_values('ST', ascending=False).head(30)

### Visualize

In [None]:
%matplotlib inline

fig, ax = plt.subplots(1)

indices = Si_df[['S1','ST']]
err = Si_df[['S1_conf','ST_conf']]

indices.plot.bar(yerr=err.values.T,ax=ax)
fig.set_size_inches(160,16)

plt.show()