In [None]:
import torch
import torch.nn as nn
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torchvision.transforms as transforms
from torchvision import models


In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class trimap_to_labels():
    '''
    Convert a trimap tensor into a one-hot label tensor (used as a target transform).

    sample is a tensor of integers with the shape (1, height, width).
    The result has shape (len(trimap_values), height, width) in the default float
    dtype: channel i is 1.0 where the trimap equals trimap_values[i], else 0.0.
    With the default values (1, 2, 3) the channels are:
        0 -> trimap value 1 (foreground)
        1 -> trimap value 2 (background)
        2 -> trimap value 3 (in between foreground and background)
    '''

    def __init__(self, trimap_values=(1, 2, 3)):
        """trimap_values: the trimap ids to encode, one output channel each, in order."""
        self.trimap_values = tuple(trimap_values)
        if not self.trimap_values:
            raise ValueError("trimap_values must contain at least one value")

    def __call__(self, sample):
        """Return the one-hot encoding of ``sample`` (see class docstring).

        Raises:
            ValueError: if ``sample`` is not shaped (1, height, width).
        """
        if sample.dim() != 3 or sample.size(0) != 1:
            raise ValueError(
                f"expected a trimap of shape (1, height, width), got {tuple(sample.shape)}"
            )
        # Each comparison yields a (1, H, W) bool mask; concatenating them along
        # dim 0 gives one channel per trimap value.
        one_hot = torch.cat([sample == value for value in self.trimap_values], dim=0)
        return one_hot.to(torch.get_default_dtype())


In [None]:
# Shared spatial size so the image and its mask can never drift apart.
IMAGE_SIZE = (128, 128)

# Images: converted to tensors and resized with antialiasing.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(IMAGE_SIZE, antialias=True),
    ]
)
# Trimaps hold integer class ids (1, 2, 3), so they must be resized with NEAREST
# interpolation. Interpolating modes blend neighbouring ids -- e.g. averaging
# foreground (1) and boundary (3) produces 2 (background) -- which silently
# corrupts the labels along object edges.
target_transform = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
        trimap_to_labels(),
    ]
)

training_data = datasets.OxfordIIITPet(
    "../data",
    "trainval",
    transform=transform,
    download=True,
    target_types="segmentation",
    target_transform=target_transform,
)


In [None]:
index = 10
# Each `training_data[index]` re-reads and re-transforms the sample, so load it once.
image, mask = training_data[index]
plt.imshow(image.movedim(0, 2))
plt.imshow(mask.movedim(0, 2), alpha=0.5)  # one-hot mask drawn as a translucent RGB overlay
plt.show()


In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder that outputs per-pixel class probabilities.

    Args:
        in_channels: channels of the input image (3 for RGB).
        num_classes: number of output channels (one per class).
        base_channels: width of the first encoder block; each deeper block doubles it.
        depth: number of encoder/decoder blocks. Input height and width must be
            divisible by ``2 ** depth``.

    The defaults reproduce the 64-128-256-512 / 1024 layout. The last 3x3
    convolution is intentionally NOT followed by ReLU: clamping the logits at
    zero before Softmax leaves any pixel whose logits are all negative stuck at
    a uniform prediction with zero gradient.
    """

    def __init__(self, in_channels=3, num_classes=3, base_channels=64, depth=4):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be at least 1, got {depth}")

        widths = [base_channels * 2 ** level for level in range(depth)]
        latent_width = base_channels * 2 ** depth

        # Contracting path: in_channels -> 64 -> 128 -> 256 -> 512 with the defaults.
        encoder_inputs = [in_channels] + widths[:-1]
        self.encoder_blocks = nn.ModuleList([
            self._double_conv(c_in, c_out, c_out)
            for c_in, c_out in zip(encoder_inputs, widths)
        ])

        # Expanding path: each block receives the upsampled features concatenated
        # with the matching skip connection, hence 2 * width input channels.
        # The final block maps to num_classes and leaves the logits unclamped.
        decoder_widths = widths[::-1]
        self.decoder_blocks = nn.ModuleList(
            [self._double_conv(2 * width, width, width) for width in decoder_widths[:-1]]
            + [self._double_conv(2 * base_channels, base_channels, num_classes, final_relu=False)]
        )

        self.latent_block = self._double_conv(widths[-1], latent_width, latent_width)

        self.pooling = nn.MaxPool2d(2, 2)
        self.up_convolutions = nn.ModuleList([
            nn.ConvTranspose2d(2 * width, width, 2, 2) for width in decoder_widths
        ])
        self.final_activation = nn.Softmax(1)

    @staticmethod
    def _double_conv(in_channels, mid_channels, out_channels, final_relu=True):
        """Two same-padded 3x3 convolutions with ReLU; the trailing ReLU is optional."""
        layers = [
            nn.Conv2d(in_channels, mid_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, 3, padding=1),
        ]
        if final_relu:
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, X):
        """Return class probabilities shaped (batch, num_classes, H, W).

        X is expected to be (batch, in_channels, H, W). H and W must be divisible
        by 2 ** depth, otherwise an upsampled map would not match its skip connection.

        Raises:
            ValueError: if H or W is not divisible by 2 ** depth.
        """
        factor = 2 ** len(self.encoder_blocks)
        height, width = X.shape[-2:]
        if height % factor or width % factor:
            raise ValueError(
                f"input height and width must be divisible by {factor}, got {height}x{width}"
            )

        # Left side of the U: keep every block's output for the skip connections.
        skips = []
        for block in self.encoder_blocks:
            X = block(X)
            skips.append(X)
            X = self.pooling(X)

        # Bottom of the U.
        X = self.latent_block(X)

        # Right side of the U: upsample, concatenate the matching skip, refine.
        for up_convolution, block, skip in zip(
            self.up_convolutions, self.decoder_blocks, reversed(skips)
        ):
            X = up_convolution(X)
            X = torch.cat((X, skip), 1)
            X = block(X)

        return self.final_activation(X)


In [None]:
# Instantiate the U-Net and place its weights on the chosen device.
model = UNet()
model = model.to(device)


In [None]:
# Number of scalar weights that will receive gradients during training.
pytorch_total_params = sum(
    map(torch.numel, filter(lambda tensor: tensor.requires_grad, model.parameters()))
)
print(pytorch_total_params)


In [None]:
with torch.no_grad():
    index = torch.randint(len(training_data), (1,)).item()
    # Load (and transform) the sample once; every `training_data[index]` access
    # re-reads the image and re-runs both transform pipelines.
    image, mask = training_data[index]

    # Add a batch axis for the model, then drop only that axis again.
    model_output = model(image.unsqueeze(0).to(device)).squeeze(0).cpu()
    predicted_mask = model_output.round()

    # Mask channel 0 is trimap value 1 and channel 2 is trimap value 3, so the
    # "masked" panels keep value-1 pixels fully and value-3 pixels at half intensity.
    panels = [
        (image, "image"),
        (mask, "mask"),
        (image * (mask[0] + 0.5 * mask[2]), "masked"),
        (model_output, "model output"),
        (image * (predicted_mask[0] + 0.5 * predicted_mask[2]), "model masked"),
    ]

    fig = plt.figure(figsize=(8, 8))
    for position, (panel, title) in enumerate(panels, start=1):
        fig.add_subplot(4, 4, position)
        plt.imshow(panel.movedim(0, 2))
        plt.axis("off")
        plt.title(title)

    plt.show()


In [None]:
# Shuffled mini-batches of 32 (image, one-hot mask) pairs.
training_loader = DataLoader(
    dataset=training_data,
    batch_size=32,
    shuffle=True,
)


In [None]:

# Make sure the weights live on the selected device, then display the architecture.
model = model.to(device)
model


In [None]:
# Mean squared error between the model's softmax probabilities and the one-hot targets.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


In [None]:
NUM_EPOCHS = 10
LOG_EVERY = 10  # report the mean loss over this many mini-batches

model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for batch_number, (inputs, labels) in enumerate(training_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Outputs and labels share the (batch, channels, H, W) layout and MSELoss
        # is element-wise, so no channel-last permutation is needed.
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_number % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.7f' %
                  (epoch + 1, batch_number + 1, running_loss / LOG_EVERY))
            running_loss = 0.0


In [None]:
with torch.no_grad():
    index = torch.randint(len(training_data), (1,)).item()
    # Load (and transform) the sample once; every `training_data[index]` access
    # re-reads the image and re-runs both transform pipelines.
    image, mask = training_data[index]

    # Add a batch axis for the model, then drop only that axis again.
    model_output = model(image.unsqueeze(0).to(device)).squeeze(0).cpu()
    predicted_mask = model_output.round()

    # Mask channel 0 is trimap value 1 and channel 2 is trimap value 3, so the
    # "masked" panels keep value-1 pixels fully and value-3 pixels at half intensity.
    panels = [
        (image, "image"),
        (mask, "mask"),
        (image * (mask[0] + 0.5 * mask[2]), "masked"),
        (model_output, "model output"),
        (image * (predicted_mask[0] + 0.5 * predicted_mask[2]), "model masked"),
    ]

    fig = plt.figure(figsize=(8, 8))
    for position, (panel, title) in enumerate(panels, start=1):
        fig.add_subplot(4, 4, position)
        plt.imshow(panel.movedim(0, 2))
        plt.axis("off")
        plt.title(title)

    plt.show()