In [None]:
import torch
import torch.nn as nn
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torchvision.transforms as transforms
from torchvision import models


In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class trimap_to_labels():
    '''
    Convert a trimap into a 3-channel one-hot label map.

    `sample` is an integer tensor of shape (1, height, width) holding the trimap
    values 1 (foreground), 2 (background) and 3 (border between them).
    Returns a float tensor of shape (3, height, width) where channel 0 is 1.0
    wherever the trimap equals 1, channel 1 wherever it equals 2, and channel 2
    wherever it equals 3. Any other value yields 0.0 in every channel.
    '''

    def __call__(self, sample):
        _channels, rows, cols = sample.size()
        one_hot = torch.zeros(3, rows, cols)
        # Channel k marks the pixels whose trimap value is k + 1.
        for channel, trimap_value in enumerate((1, 2, 3)):
            one_hot[channel] = sample.eq(trimap_value)
        return one_hot


In [None]:
# Every image and target is resized to this (height, width).
IMAGE_SIZE = (128, 128)

training_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(IMAGE_SIZE, antialias=True),
        # NOTE(review): contrast=1 is a single number, not a range, so the contrast
        # factor is drawn from [max(0, 1 - 1), 1 + 1] = [0, 2]. A factor near 0 makes
        # the image almost uniformly gray. If a range like the brightness/saturation
        # ones was intended, pass e.g. (0.75, 1.25) instead.
        transforms.ColorJitter(brightness=(0.75, 1.25), contrast=1, saturation=(0.75, 1.25), hue=(-0.1, 0.1)),
    ]
)
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(IMAGE_SIZE, antialias=True),
    ]
)
target_transform = transforms.Compose(
    [
        transforms.PILToTensor(),
        # The trimap holds discrete labels (1, 2, 3), so it must be resized with
        # nearest-neighbour interpolation. The default bilinear mode blends
        # neighbouring labels (e.g. a 1|3 edge can become 2), which assigns wrong
        # classes along every object boundary. antialias only applies to
        # bilinear/bicubic, so it is not passed here.
        transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
        trimap_to_labels(),
    ]
)

training_data = datasets.OxfordIIITPet("../data", "trainval", transform=training_transform, download=True, target_types="segmentation", target_transform=target_transform)

test_data = datasets.OxfordIIITPet("../data", "test", transform=test_transform, download=True, target_types="segmentation", target_transform=target_transform)


In [None]:
# Quick look at one training sample and its label map.
index = 10
# Fetch the sample once: training_transform applies a random ColorJitter, so every
# training_data[index] call would return a differently augmented image.
sample_image, sample_target = training_data[index]

fig, (image_ax, overlay_ax) = plt.subplots(1, 2, figsize=(8, 4))
image_ax.imshow(sample_image.movedim(0, 2))
image_ax.set_title("image")
image_ax.axis("off")

# The one-hot target has 3 channels and renders as RGB: trimap value 1 in red,
# 2 in green, 3 in blue. alpha=0.5 overlays it on top of the image.
overlay_ax.imshow(sample_image.movedim(0, 2))
overlay_ax.imshow(sample_target.movedim(0, 2), alpha=0.5)
overlay_ax.set_title("mask overlay")
overlay_ax.axis("off")
plt.show()

In [None]:
class ResBock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut, then an activation.

    Computes ``final_activation(side_convolution(X) + block(X))``, where ``block`` is
    conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm. ``side_convolution`` is the
    identity when the channel counts match, otherwise a 1x1 convolution.
    Padding 1 on the 3x3 convolutions keeps the spatial size unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        final_activation: module applied after the residual sum. Defaults to a new
            ``nn.ReLU()`` for each block; pass ``nn.Identity()`` for no activation.
    """

    def __init__(self, in_channels, out_channels, final_activation=None):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
        )

        # The shortcut needs a 1x1 projection when the channel count changes.
        if in_channels == out_channels:
            self.side_convolution = nn.Identity()
        else:
            self.side_convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # Build the default activation per instance. A module used as a default
        # argument is created once at definition time and would be shared (and
        # registered as a child) by every ResBock.
        self.final_activation = nn.ReLU() if final_activation is None else final_activation

    def forward(self, X):
        block_output = self.block(X)
        side_output = self.side_convolution(X)
        return self.final_activation(side_output + block_output)

In [None]:
class UNet(nn.Module):
    """U-Net-style segmentation network built from residual blocks.

    Structure:
    - Four encoder ``ResBock``s (3->64->128->256->512 channels), each followed by
      2x2 max pooling.
    - A 512->1024 latent ``ResBock``.
    - Four decoder stages. Each upsamples (``nn.Upsample`` + 3x3 conv that halves
      the channels), concatenates the matching encoder output (skip connection) and
      applies a ``ResBock``.

    The final 3 channels are passed through a softmax over the channel dimension.
    The original U-Net used plain double conv+ReLU blocks and ``ConvTranspose2d``
    upsampling; here those are replaced by ``ResBock`` and upsample + conv.

    Input height and width must be divisible by 16 (four 2x poolings) so the
    upsampled feature maps line up with the skip connections.
    """

    def __init__(self):
        super().__init__()
        self.encoder_blocks = nn.ModuleList([
            ResBock(3, 64),
            ResBock(64, 128),
            ResBock(128, 256),
            ResBock(256, 512),
        ])

        self.decoder_blocks = nn.ModuleList([
            ResBock(1024, 512),
            ResBock(512, 256),
            ResBock(256, 128),
            # No ReLU on the last block: its output feeds the softmax. Clamping those
            # logits at 0 turns all-negative pixels into a uniform 1/3 with zero
            # gradient. nn.Identity has no parameters, so state_dict keys are unchanged.
            ResBock(128, 3, final_activation=nn.Identity()),
        ])

        # Kept inside nn.Sequential so parameter names stay "latent_block.0.*".
        self.latent_block = nn.Sequential(
            ResBock(512, 1024),
        )

        self.pooling = nn.MaxPool2d(2, 2)
        # Each stage doubles the spatial size and halves the channel count. The
        # concatenated skip connection then restores the decoder block's input width.
        self.up_convolutions = nn.ModuleList([
            nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1),
            )
            for in_channels in (1024, 512, 256, 128)
        ])
        self.final_activation = nn.Softmax(1)

    def forward(self, X):
        """Map images (N, 3, H, W) to per-pixel class probabilities (N, 3, H, W).

        Raises:
            ValueError: if H or W is not divisible by 2 ** len(self.encoder_blocks).
        """
        downsampling = 2 ** len(self.encoder_blocks)
        if X.shape[-2] % downsampling or X.shape[-1] % downsampling:
            raise ValueError(
                f"input height and width must be divisible by {downsampling}, "
                f"got {tuple(X.shape[-2:])}"
            )

        # Left side of the U: keep each block's output for the skip connections.
        encoder_outputs = []
        for encoder_block in self.encoder_blocks:
            X = encoder_block(X)
            encoder_outputs.append(X)
            X = self.pooling(X)

        # Bottom of the U.
        X = self.latent_block(X)

        # Right side of the U: upsample + conv, concatenate the matching encoder
        # output (deepest first) along the channel axis, then decode.
        for up_convolution, decoder_block, skip in zip(
            self.up_convolutions, self.decoder_blocks, reversed(encoder_outputs)
        ):
            X = up_convolution(X)
            X = torch.cat((X, skip), 1)
            X = decoder_block(X)

        return self.final_activation(X)


In [None]:
# Build the network, then move its parameters and buffers to the selected device.
model = UNet()
model = model.to(device)


In [None]:
# Count the trainable (requires_grad) scalar parameters of the model.
pytorch_total_params = 0
for param in model.parameters():
    if param.requires_grad:
        pytorch_total_params += param.numel()
print(pytorch_total_params)


In [None]:
# Sanity-check the model's prediction on one random training sample.
model.eval()
try:
    with torch.no_grad():
        index = torch.randint(len(training_data), (1,)).item()
        # Fetch the sample once: training_transform contains a random ColorJitter, so
        # every training_data[index] call would return a differently jittered image
        # (and re-read the file from disk).
        image, target = training_data[index]

        # (3, H, W) per-pixel class probabilities for this single image.
        model_output = model(image.unsqueeze(0).to(device)).squeeze(0).cpu()

        # Channel 0 is the object and channel 2 the border, as in the target.
        # Keep object pixels at full intensity and border pixels at half.
        predicted = nn.functional.one_hot(model_output.argmax(0), 3)
        mask = predicted[:, :, 0] + 0.5 * predicted[:, :, 2]

        panels = [
            ("image", image),
            ("mask", target),
            ("masked", image * (target[0] + 0.5 * target[2])),
            ("model output", model_output),
            ("model masked", image * mask),
        ]
        fig = plt.figure(figsize=(8, 8))
        for position, (title, picture) in enumerate(panels, start=1):
            ax = fig.add_subplot(4, 4, position)
            ax.imshow(picture.movedim(0, 2))
            ax.axis("off")
            ax.set_title(title)
        plt.show()
finally:
    # Restore training mode even if plotting fails.
    model.train()


In [None]:
# Mini-batch loaders. test_loader is shuffled too, so any partial pass over it
# sees a random subset of the test set.
training_loader = DataLoader(dataset=training_data, shuffle=True, batch_size=32)
test_loader = DataLoader(dataset=test_data, shuffle=True, batch_size=8)

In [None]:
# Mean squared error between the model's per-pixel class probabilities and the
# one-hot targets, optimised with Adam at a small learning rate.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)


In [None]:
num_epochs = 10
num_mini_batches = len(training_loader)
# Evaluate on roughly 10% of the test batches each epoch to keep epochs short.
# Use at least one batch so the average below never divides by zero.
num_test_batches = max(1, len(test_loader) // 10)
train_losses = [None] * num_mini_batches * num_epochs
test_losses = [None] * num_epochs

for epoch in range(num_epochs):
    # Train one epoch.
    model.train()
    running_loss = 0.0
    for batch_number, (inputs, labels) in enumerate(training_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        train_losses[epoch * num_mini_batches + batch_number] = batch_loss

        if batch_number % 10 == 9:  # print the mean loss of every 10 mini-batches
            print('[%d, %5d] loss: %.7f' %
                  (epoch, batch_number + 1, running_loss / 10))
            running_loss = 0.0

    # Evaluate in eval mode. Under no_grad alone, BatchNorm would still use
    # batch statistics and update its running statistics with test data.
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for test_batch, (inputs, labels) in enumerate(test_loader):
            if test_batch >= num_test_batches:
                break
            outputs = model(inputs.to(device))
            test_loss += loss_fn(outputs, labels.to(device)).item()
    model.train()

    test_loss = test_loss / num_test_batches
    print('Epoch %d test loss = %.7f' % (epoch, test_loss))
    test_losses[epoch] = test_loss

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    

In [None]:


# Per-mini-batch training loss, with the once-per-epoch test loss overlaid.
x_train = list(range(num_epochs * num_mini_batches))
# Each test loss is plotted at the end of its epoch.
x_test = list(range(num_mini_batches, (num_epochs + 1) * num_mini_batches, num_mini_batches))

plt.plot(x_train, train_losses, label="training loss")
plt.plot(x_test, test_losses, label="test loss")
plt.title("loss")
plt.xlabel("mini-batch")
plt.ylabel("loss")
# Without legend() the curve labels above are never displayed.
plt.legend()
plt.show()

In [None]:
# Inspect the trained model's prediction on one random test sample.
model.eval()
try:
    with torch.no_grad():
        index = torch.randint(len(test_data), (1,)).item()
        # Fetch the sample once instead of re-loading it for every panel.
        image, target = test_data[index]

        # (3, H, W) per-pixel class probabilities for this single image.
        model_output = model(image.unsqueeze(0).to(device)).squeeze(0).cpu()

        # Channel 0 is the object and channel 2 the border, as in the target.
        # Keep object pixels at full intensity and border pixels at half.
        predicted = nn.functional.one_hot(model_output.argmax(0), 3)
        mask = predicted[:, :, 0] + 0.5 * predicted[:, :, 2]

        panels = [
            ("image", image),
            ("mask", target),
            ("masked", image * (target[0] + 0.5 * target[2])),
            ("model output", model_output),
            ("model masked", image * mask),
        ]
        fig = plt.figure(figsize=(8, 8))
        for position, (title, picture) in enumerate(panels, start=1):
            ax = fig.add_subplot(4, 4, position)
            ax.imshow(picture.movedim(0, 2))
            ax.axis("off")
            ax.set_title(title)
        plt.show()
finally:
    # Restore training mode even if plotting fails.
    model.train()


In [None]:
# Persist only the learned weights (parameters and buffers), not the module object.
model_weights = model.state_dict()
torch.save(model_weights, "../model_name")