<a href="https://colab.research.google.com/github/nackjaylor/sydney-innovation-program/blob/main/sip_deep_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Learning (Semantic Segmentation)

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

Let us load a pre-trained DeepLabV3 model with a ResNet50 backbone.

The ResNet50 backbone has been pretrained on ImageNet, a very large image-classification dataset with over a million labelled images. The network has therefore already seen a lot of data and learnt many general-purpose visual features.

In [None]:
from torchvision import models

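# Load DeepLabV3 with pretrained weights (newer torchvision versions prefer the `weights=` argument)
model = models.segmentation.deeplabv3_resnet50(pretrained=True, progress=True)
# Replace the segmentation head: 2048 input channels from ResNet50, 1 output channel (our mask)
model.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, 1)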


In [None]:
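batch_size = 32

# Images: convert to a float tensor in [0, 1], then shrink to 64x64 to keep training fast
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((64,64))
])

# Masks: the trimap labels 1, 2, 3 become 0, 1, 1 after truncating division by 2
mask_transform = transforms.Compose([
    transforms.PILToTensor(),
    lambda x: torch.div(x.type(torch.FloatTensor), 2, rounding_mode='trunc'),
    transforms.Resize((64,64))
])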


train_dataset = datasets.OxfordIIITPet(
    root='./data/OxfordIITPET',
    download=True,
    target_types="segmentation",
    transform=img_transform,
    target_transform=mask_transform,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.OxfordIIITPet(
    root='./data/OxfordIITPET',
    download=True,
    split="test",
    target_types="segmentation",
    transform=img_transform,
    target_transform=mask_transform,
)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
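
To see what the mask transform does to the labels, here is a minimal sketch on a hand-made trimap. It assumes the Oxford-IIIT Pet convention of 1 = pet, 2 = background, 3 = border or unclassified; the `trimap` values are made up for illustration.

```python
import torch

# Assumed Oxford-IIIT Pet trimap labels: 1 = pet, 2 = background, 3 = border / unclassified.
trimap = torch.tensor([[1, 2, 3],
                       [1, 1, 2]], dtype=torch.uint8)

# Same operation as in mask_transform: cast to float, then divide by 2 and truncate.
binary = torch.div(trimap.type(torch.FloatTensor), 2, rounding_mode='trunc')

print(binary)
```

Under that assumption the pet pixels end up as 0 and everything else as 1. The `Resize` step that follows in `mask_transform` may blend these values at the edges of the pet, so the final masks are not guaranteed to be strictly binary.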

Let us look at the segmentation masks!

In [None]:
import numpy as np
train_features, train_labels = next(iter(train_dataloader))

print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]
label = train_labels[0].squeeze()
# imshow expects (height, width, channels), so move the channel axis last
f, axarr = plt.subplots(2,1)
axarr[0].imshow(img.permute(1, 2, 0))
axarr[1].imshow(label, vmin=0, vmax=1)
plt.show()

print(label.max())

And how does our model work out of the box?

In [None]:
model.eval()

# No gradients are needed for a forward pass that is only used for plotting
with torch.no_grad():
    segment_predict = model(train_features)

plt.imshow(segment_predict['out'][0].squeeze())


Not very well... Why is this?

Part of the network is still randomly initialised: the segmentation head we replaced above has never been trained.

We need to train the network on our data so that it learns what we want it to predict and can make meaningful decisions.

In [None]:
criterion = torch.nn.MSELoss(reduction='mean')
# Specify the optimizer with a lower learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
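
For intuition about what this loss measures, the sketch below compares `MSELoss` with the same quantity computed by hand on a tiny made-up prediction and target. The tensors and the name `criterion_sketch` are illustrative only.

```python
import torch

criterion_sketch = torch.nn.MSELoss(reduction='mean')

# A made-up 2x2 "prediction" and its 0/1 target mask.
prediction = torch.tensor([[0.9, 0.2], [0.4, 1.0]])
target = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

loss_builtin = criterion_sketch(prediction, target)

# The same thing by hand: average of the squared per-pixel differences.
loss_manual = ((prediction - target) ** 2).mean()

print(loss_builtin.item(), loss_manual.item())
```

With `reduction='mean'` the loss is averaged over every pixel in the batch, so its scale does not depend on the batch size or image size.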

In [None]:
epochs = 20


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
train_loss_avg = []

model.train()
model.to(device)
print('Training ...')
for epoch in range(epochs):
    train_loss_avg.append(0)
    num_batches = 0
    
    for image_batch, mask_batch in train_dataloader:
        
        image_batch = image_batch.to(device)
        
        segment_predict = model(image_batch)
        
        # mean squared error between the predicted and ground-truth masks
        loss = criterion(segment_predict['out'], mask_batch.to(device))
        
        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        
        # one step of the optimizer (using the gradients from backpropagation)
        optimizer.step()
        
        train_loss_avg[-1] += loss.item()
        num_batches += 1
        
    train_loss_avg[-1] /= num_batches
    print('Epoch [%d / %d] average error: %f' % (epoch+1, epochs, train_loss_avg[-1]))

In [None]:
torch.save(model.state_dict(),"./segnet")
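
To reuse the saved weights later, the state dict has to be loaded into a model with the same architecture. The sketch below shows the round trip on a tiny stand-in model so that it runs quickly; `trained`, `restored` and the temporary path are hypothetical. In this notebook you would instead rebuild DeepLabV3 with the replaced head before calling `load_state_dict`.

```python
import os
import tempfile
import torch

# A tiny stand-in for the trained network (hypothetical, for illustration only).
trained = torch.nn.Linear(4, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "segnet_sketch")

    # Save only the parameters, as in the cell above.
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then load the parameters into it.
    restored = torch.nn.Linear(4, 1)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

restored.eval()

# Every parameter tensor should match after the round trip.
same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same)
```

`map_location="cpu"` lets weights saved on a GPU be loaded on a machine without one.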

## Inference

Now let us see how the trained model has performed!



In [None]:
# Note: despite the variable names, this batch comes from the test set
train_features, train_labels = next(iter(test_dataloader))

img = train_features[0]
label = train_labels[0].squeeze()
# imshow expects (height, width, channels), so move the channel axis last
f, axarr = plt.subplots(2,1)
axarr[0].imshow(img.permute(1, 2, 0))
axarr[1].imshow(label, vmin=0, vmax=1)
plt.show()

Compare the ground-truth mask above with the model's prediction below. Much better, right?

Training for longer, or on larger images, should improve the results further.

In [None]:
model.eval()

# No gradients are needed at inference time
with torch.no_grad():
    segment_predict = model(train_features.to(device))

plt.imshow(segment_predict['out'][0].squeeze().cpu())
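
The network outputs unbounded real values rather than a clean 0/1 mask. One simple way to get a hard mask is to threshold the output, sketched below on made-up numbers. It assumes the target encoding is 0 for the pet and 1 for everything else (which follows from the trimap convention assumed earlier), and the 0.5 cut-off is a choice rather than something the notebook prescribes.

```python
import torch

# Stand-in for segment_predict['out'][0].squeeze(): raw regression outputs.
raw = torch.tensor([[-0.1, 0.3],
                    [0.7, 1.2]])

# Targets were 0 (pet, by assumption) and 1 (everything else),
# so values below 0.5 are closer to the "pet" label.
pet_mask = raw < 0.5

print(pet_mask)
```

The resulting boolean tensor can be passed to `plt.imshow` directly or used to cut the pet out of the original image.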

## Your own data!

You can upload an image of your pet and see if it can be segmented by your network.

Try different photos! What do you notice about how framing, lighting, image orientation and so on affect the results?

In [None]:
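from google.colab import files
import cv2

uploaded = files.upload()  # opens a file picker; returns a dict keyed by filename
filename = "<YOUR_PET_PHOTO_HERE>.png"  # replace with the name of the file you uploaded
img = cv2.imread(filename)
if img is None:
    raise FileNotFoundError(f"Could not read {filename}")
# OpenCV loads images as BGR; convert to RGB for matplotlib and the model
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)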

In [None]:
test_image = img_transform(img).unsqueeze(0)

In [None]:
model.eval()

# No gradients are needed at inference time
with torch.no_grad():
    segment_predict = model(test_image.to(device))

plt.imshow(segment_predict['out'][0].squeeze().cpu())