In [7]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from base_model import BaseModel

In this notebook, we'll test the capabilities of the `BaseModel` class. Your PyTorch models can inherit from this class to gain the ability to predict segmentation maps from images. We demonstrate this with the following two models:
- A toy model that returns a random segmentation map, ignoring the contents of the input image.
- An untrained UNet model that uses the image to predict a segmentation map.
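The `BaseModel` source isn't shown in this notebook, so the following is only a hypothetical sketch of how such a mixin might work. `SegmentationPredictMixin` and `TinySegmenter` are illustrative names, and the assumed `predict_image` behavior (add a batch dimension, run `forward` without gradients, return a channel-last array) is inferred from the printed output shape later in this notebook, not taken from `BaseModel` itself.

```python
import torch
import torch.nn as nn


class SegmentationPredictMixin:
    """Hypothetical stand-in for BaseModel: adds predict_image() to an nn.Module."""

    def __init__(self, device):
        self.device = device

    def predict_image(self, image):
        # image: (C, H, W) tensor -> add a batch dimension: (1, C, H, W)
        batch = image.unsqueeze(0).to(self.device)
        self.eval()  # disable dropout and use running batch-norm statistics
        with torch.no_grad():
            logits = self(batch)  # (1, num_classes, H, W)
        # Drop the batch dimension and move classes last: (H, W, num_classes)
        return logits[0].permute(1, 2, 0).cpu().numpy()


class TinySegmenter(nn.Module, SegmentationPredictMixin):
    def __init__(self, device, num_classes):
        nn.Module.__init__(self)
        SegmentationPredictMixin.__init__(self, device)
        # A 1x1 convolution maps 3 input channels to num_classes logits per pixel
        self.head = nn.Conv2d(3, num_classes, kernel_size=1)
        self.to(device)

    def forward(self, x):
        return self.head(x)


model = TinySegmenter(device="cpu", num_classes=3)
pred = model.predict_image(torch.rand(3, 224, 224))
print(pred.shape)  # (224, 224, 3)
```

Listing `nn.Module` first among the base classes and calling each parent's `__init__` explicitly mirrors the pattern that `CustomSegmentationModel` uses below.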

### Toy Example

#### Toy Dataset

In [8]:
class ToyDataset(Dataset):
    def __init__(self, num_samples, num_classes):
        self.num_samples = num_samples
        self.num_classes = num_classes
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):

        # Generate a random RGB image of shape (3, 224, 224)
        image = torch.rand(3, 224, 224)
        # Draw a random class label per pixel, then one-hot encode it into shape (num_classes, 224, 224)
        mask = torch.randint(0, self.num_classes, (224, 224))
        mask = nn.functional.one_hot(mask, self.num_classes).permute(2, 0, 1).float()

        return image, mask
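The one-hot step in `__getitem__` is easy to check in isolation. This small sketch (using a 4×5 "image" instead of 224×224 purely for readability) shows that `one_hot` appends the class axis, `permute` moves it to the front, and `argmax` over that axis recovers the original labels.

```python
import torch
import torch.nn.functional as F

num_classes = 3
# Integer class label per pixel: shape (H, W), values in [0, num_classes)
labels = torch.randint(0, num_classes, (4, 5))
# one_hot appends a class axis: (H, W) -> (H, W, num_classes)
one_hot = F.one_hot(labels, num_classes)
# permute moves the class axis first, giving the (C, H, W) layout PyTorch uses
mask = one_hot.permute(2, 0, 1).float()

print(mask.shape)  # torch.Size([3, 4, 5])
# Exactly one channel is "on" at every pixel ...
assert torch.all(mask.sum(dim=0) == 1)
# ... and argmax over the channel axis recovers the original labels
assert torch.equal(mask.argmax(dim=0), labels)
```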



#### Toy Model

In [6]:
class CustomSegmentationModel(nn.Module, BaseModel):
    def __init__(self, device, num_classes=1):
        nn.Module.__init__(self)
        BaseModel.__init__(self, device)
        self.num_classes = num_classes
        # Your model layers and operations here

    def forward(self, x):
        # x has shape (batch_size, channels, height, width)
        batch_size = x.shape[0]

        # Return random logits of shape (batch_size, num_classes, height, width),
        # created on the same device as the input
        return torch.randn(batch_size, self.num_classes, x.shape[2], x.shape[3], device=x.device)

In [16]:
num_classes = 3
toy_model = CustomSegmentationModel(device='cpu', num_classes=num_classes)
toy_dataset = ToyDataset(num_samples=10, num_classes=num_classes)
toy_dataloader = DataLoader(toy_dataset, batch_size=2, shuffle=True)

In [17]:
# Take one batch from the dataloader, then keep only its first image and mask
image, mask = next(iter(toy_dataloader))
image = image[0]
mask = mask[0]
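Why the `[0]` indexing is needed: the `DataLoader` returns whole batches, and its default collate function stacks the per-sample tensors along a new leading dimension. The sketch below uses `TensorDataset` with random tensors of the same shapes as `ToyDataset` (a simplification, since the masks here are not one-hot) to show the shapes before and after indexing.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten random "images" and "masks" with the same shapes ToyDataset produces
images = torch.rand(10, 3, 224, 224)
masks = torch.rand(10, 3, 224, 224)
loader = DataLoader(TensorDataset(images, masks), batch_size=2, shuffle=True)

# next(iter(...)) pulls one batch; the default collate stacks samples on dim 0
image_batch, mask_batch = next(iter(loader))
print(image_batch.shape)  # torch.Size([2, 3, 224, 224])

# Indexing [0] drops the batch dimension, leaving one (C, H, W) sample
image, mask = image_batch[0], mask_batch[0]
print(image.shape)  # torch.Size([3, 224, 224])
print(len(loader))  # 5 (ten samples in batches of two)
```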

In [18]:
predicted_mask = toy_model.predict_image(image)

In [19]:
print(predicted_mask.shape)


(224, 224, 3)


The example above predicts a single segmentation mask. Try changing `num_classes` to 5 and see what happens!

If you'd like, you can obtain the class of each pixel by taking the `argmax` of the model's output over the class dimension.
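Which axis to reduce depends on the layout. A minimal sketch, using random logits in place of a real model: the raw `forward` output is `(batch, num_classes, H, W)`, so the class axis is `dim=1`; a single channel-last prediction, `(H, W, num_classes)` like the printed shape above, needs the last axis instead. If `predict_image` returns a NumPy array (the tuple-style printed shape suggests it might, but that is an assumption), `np.argmax(pred, axis=-1)` is the equivalent.

```python
import torch

num_classes, height, width = 3, 224, 224

# Raw model output in PyTorch's layout: (batch, num_classes, H, W)
logits = torch.randn(2, num_classes, height, width)
# argmax over dim=1 (the class axis) gives one integer label per pixel
class_map = logits.argmax(dim=1)
print(class_map.shape)  # torch.Size([2, 224, 224])

# A single channel-last prediction, (H, W, num_classes), needs argmax over the last axis
single_pred = logits[0].permute(1, 2, 0)
single_class_map = single_pred.argmax(dim=-1)
print(single_class_map.shape)  # torch.Size([224, 224])

# Both routes give the same labels for the first image
assert torch.equal(single_class_map, class_map[0])
```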