# Semantic segmentation experiments

Using vanilla PyTorch

Faisal Qureshi      
faisal.qureshi@ontariotechu.ca

In [None]:
# Input image to segment. The path is relative, so it is resolved against the
# notebook's current working directory.
filepath = '../../data/pinyon-jay-bird.jpg'

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import torch

In [None]:
# Load the input image and preview it.
# Image files are not guaranteed to be 3-channel (greyscale, CMYK and RGBA are
# all possible), while the Normalize step in the preprocessing pipeline is
# given exactly three mean/std values. Force RGB so the pipeline works for any
# input; for an image that is already RGB this does not change the pixels.
img = Image.open(filepath).convert('RGB')
plt.imshow(img);

We need to process this image before we can perform semantic segmentation on it.

In [None]:
import torchvision.transforms as T

In [None]:
# Per-channel statistics handed to T.Normalize.
# NOTE(review): these look like the usual ImageNet mean/std - confirm they
# match what the pretrained model expects.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in order: resize to 256x256, centre-crop to
# 244x244, convert to a tensor, then normalise each channel.
# NOTE(review): the crop size is 244 rather than the more common 224; kept
# exactly as written - confirm this is intended.
preprocess_steps = [
    T.Resize((256, 256)),
    T.CenterCrop((244, 244)),
    T.ToTensor(),
    T.Normalize(mean=channel_mean, std=channel_std),
]
transforms = T.Compose(preprocess_steps)

# Run the pipeline on the loaded image and report the resulting tensor shape.
img_transformed = transforms(img)
print(img_transformed.shape)

Let's load a segmentation model and set it up for inference.

In [None]:
from torchvision import models

In [None]:
# Build torchvision's FCN segmentation model (ResNet-101 variant) and ask for
# its pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` - check the installed version before changing this call.
fcn = models.segmentation.fcn_resnet101(pretrained=True)

In [None]:
# Put the model into evaluation (inference) mode before running it.
fcn.eval();

We are ready to perform the inference.

In [None]:
# Add a leading batch dimension to the single preprocessed image and run it
# through the model.
# NOTE: this is intentionally not wrapped in torch.no_grad(). The torchviz
# cell at the end of the notebook draws the graph starting from
# output['out'], which presumably needs the autograd history kept here.
batch = img_transformed.unsqueeze(0)
output = fcn(batch)

# The result is a dict; show the shape of each of the two tensors it holds.
for head in ('out', 'aux'):
    print(output[head].shape)

The output is a dict.  We can pull out the relevant tensor using `output['out']`.  Note that in this case the output tensor has 21 channels.  This is because this model was trained on 21 classes.

In [None]:
import numpy as np

In [None]:
def _logits_to_label_map(logits):
    """Turn a tensor of per-class scores into a NumPy array of class indices."""
    # Drop singleton dimensions (the batch axis of the single-image batch),
    # then pick the highest-scoring entry along the first remaining axis.
    labels = logits.squeeze().argmax(dim=0)
    # Detach from autograd and move to host memory before converting to NumPy.
    return labels.detach().cpu().numpy()


# Label map from the main output.
seg_map = _logits_to_label_map(output['out'])
print(seg_map.shape)

# Label map from the auxiliary output.
seg_map_aux = _logits_to_label_map(output['aux'])
print(seg_map_aux.shape)

In [None]:
def vis_segmentation_map(seg_map, label_colors):
    """
    Convert a map of class indices into an RGB image.

    seg_map is an h-by-w integer array holding one class index per
    pixel (for example the argmax over the class channel of the
    model output).

    label_colors is an n-by-3 colormap, where n is the number of
    classes; row i is the (r, g, b) color used for class index i.

    Returns an h-by-w-by-3 array with the same dtype as seg_map.
    Pixels whose index has no row in label_colors are left black.
    """
    seg_map = np.asarray(seg_map)
    label_colors = np.asarray(label_colors)

    # Start from an all-black image so unknown indices need no special case.
    rgb = np.zeros(seg_map.shape + (3,), dtype=seg_map.dtype)

    # Only look up indices that have a colormap row. Without this mask a
    # too-large index would raise and a negative one would wrap around.
    known = (seg_map >= 0) & (seg_map < len(label_colors))

    # One vectorised colormap lookup replaces the per-class Python loop.
    rgb[known] = label_colors[seg_map[known], :3]
    return rgb

In [None]:
# Colormap used to paint segmentation maps: row i is the (r, g, b) colour for
# class index i. There are 21 rows, one per class named in the comments below.
# NOTE(review): the values look like the standard Pascal VOC palette - confirm.
label_colors = np.array([
    # 0=background
    (0, 0, 0),  
    # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
    (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
    # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
    (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
    # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
    (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
    # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
    (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)
])

# Sanity check: expect one row per class and three colour components per row.
print(label_colors.shape)

In [None]:
# Colourise the label map from the main output, then show it in a new figure.
seg_rgb = vis_segmentation_map(seg_map, label_colors)
plt.figure()
plt.imshow(seg_rgb)

In [None]:
# Colourise the label map from the auxiliary output, then show it in a new
# figure.
seg_rgb_aux = vis_segmentation_map(seg_map_aux, label_colors)
plt.figure()
plt.imshow(seg_rgb_aux)

In [None]:
# Print the model's textual representation to inspect its structure.
print(fcn)

In [None]:
# Collect the model's named parameters into a dict keyed by parameter name,
# then list the names.
p = dict(fcn.named_parameters())
print(p.keys())

In [None]:
import torchviz

# Build a graph visualisation starting from the main output tensor, passing
# the model's named parameters via `params`, then render it as a PNG using
# the base name "fcn_torchviz".
named_params = dict(fcn.named_parameters())
graph = torchviz.make_dot(output['out'], params=named_params)
graph.render("fcn_torchviz", format="png")