# Interpretability-framework - Uncertainty Tutorial 2

In [7]:
import torch
from torch.nn import Sequential
from torch.nn import Softmax

from interpretability_framework import modules
from pytorch_fcn.fcn32s import FCN32s

import data_utils
from data_utils import DataConfig
import matplotlib.pyplot as plt

### Step 1:
Load pre-trained FCN32 model

In [8]:
fully_conf_net = FCN32s()
fully_conf_net.load_state_dict(torch.load(FCN32s.download()))

[/home/fabian/data/models/pytorch/fcn32s_from_caffe.pth] Checking md5 (8acf386d722dc3484625964cbe2aba49)


### Step 2:
Build and prepare ensemble

In [9]:
ensemble = Sequential(
    modules.MeanEnsemble(fully_conf_net, 20),
    Softmax(dim=1)
)

ensemble.eval()

Sequential(
  (0): MeanEnsemble(
    (inner): FCN32s(
      (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(100, 100))
      (relu1_1): ReLU(inplace)
      (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1_2): ReLU(inplace)
      (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
      (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2_1): ReLU(inplace)
      (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2_2): ReLU(inplace)
      (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
      (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu3_1): ReLU(inplace)
      (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu3_2): ReLU(inplace)
      (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), paddi

### Step 3:

Load (and show) image

In [10]:
img = data_utils.get_example_from_path("../data/fcn_example.jpg", DataConfig.FCN32)
    
plt.imshow(img.squeeze(dim=0).permute(1, 2, 0))
plt.show()

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


In [None]:
pred = ensemble(img)
plt.imshow(pred.argmax(dim=1).squeeze(dim=0).detach().numpy())
plt.show()