In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite
from torchvision import transforms
from torch.utils.data import DataLoader
import ignite
from ignite import metrics, engine
import torch.optim as optim
import torch_xla_py.xla_model as xm
from torchsummary import summary

In [0]:
class Crop(nn.Module):
    """Crop ``x`` to the size of ``ref`` along ``axis`` and every later axis.

    On each cropped axis the window ``[offset, offset + ref.size(axis))`` of
    ``x`` is kept. FCN8s uses this to align skip-connection score maps with
    upsampled score maps before adding them.

    Args:
        axis: first axis to crop; it and all following axes of ``x`` are cropped.
        offset: start index of the kept window on every cropped axis.
    """

    def __init__(self, axis = 2, offset = 2):
        super(Crop, self).__init__()
        self.axis = axis
        self.offset = offset

    def __repr__(self):
        return 'Crop(axis=%d, offset=%d)' % (self.axis, self.offset)

    def forward(self, x, ref):
        """Return ``x`` cropped so that axes ``axis..x.dim()-1`` match ``ref``.

        Raises:
            RuntimeError: if ``offset + ref.size(axis)`` exceeds ``x.size(axis)``.
        """
        for axis in range(self.axis, x.dim()):
            # narrow() selects the contiguous slice [offset, offset + length)
            # directly (same elements as index_select with an arange index),
            # without building an index tensor or using the removed
            # torch.autograd.Variable wrapper. Gradients flow through it.
            x = x.narrow(axis, self.offset, ref.size(axis))
        return x


In [0]:
class FCN8s(nn.Module):
    """FCN-8s semantic-segmentation network on a VGG-16-style backbone.

    Class scores from pool3, pool4 and the conv7 output are fused
    (upsample x2 + add, upsample x2 + add) and the fused map is upsampled x8
    with a transposed convolution. The result is softmax-normalised over the
    class axis.

    Args:
        classes: number of output channels, one per segmentation class.
        input_shape: not used by the network; kept for backward compatibility.

    NOTE(review): for a 3x224x224 input the output map is 88x88 (pool5 is 7x7,
    conv6 reduces it to 1x1, then x2, x2, x8 upsampling), so it does not match
    the input resolution. The reference FCN-8s pads the first convolution by
    100 pixels and crops the final map back to the input size — confirm the
    intended design before training against full-resolution masks. Inputs
    smaller than 224 pixels per side make pool5 smaller than conv6's 7x7
    kernel and fail.
    """

    def __init__(self, classes, input_shape = (3, 224, 224)):
        super(FCN8s, self).__init__()
        # VGG-16 feature stages: 2, 2, 3, 3, 3 convs, each followed by a 2x2 pool.
        self.conv1 = self._conv_block(3, 64, 2)
        self.maxpool1 = nn.MaxPool2d((2, 2), stride = 2)
        self.conv2 = self._conv_block(64, 128, 2)
        self.maxpool2 = nn.MaxPool2d((2, 2), stride = 2)
        self.conv3 = self._conv_block(128, 256, 3)
        self.maxpool3 = nn.MaxPool2d((2, 2), stride = 2)
        self.conv4 = self._conv_block(256, 512, 3)
        self.maxpool4 = nn.MaxPool2d((2, 2), stride = 2)
        self.conv5 = self._conv_block(512, 512, 3)
        self.maxpool5 = nn.MaxPool2d((2, 2), stride = 2)
        # fc6 / fc7 of VGG-16 recast as convolutions: a 7x7 conv over the
        # 512-channel pool5 map, then a 1x1 conv. Default padding (0) is used;
        # Conv2d does not accept padding=None.
        self.conv6 = nn.Sequential(
            nn.Conv2d(512, 4096, kernel_size = 7),
            nn.ReLU(),
            nn.Dropout(0.5)
            )
        self.conv7 = nn.Sequential(
            nn.Conv2d(4096, 4096, kernel_size = 1),
            nn.ReLU(),
            nn.Dropout(0.5)
            )
        # Score heads and the skip-fusion path.
        self.score_7 = nn.Conv2d(4096, classes, kernel_size = 1)
        self.upscore_7 = nn.ConvTranspose2d(classes, classes, kernel_size = 4, stride = 2)
        self.score_pool4 = nn.Conv2d(512, classes, kernel_size = 1)
        self.crop_4 = Crop()
        self.upscore_7_4 = nn.ConvTranspose2d(classes, classes, kernel_size = 4, stride = 2)
        self.score_pool3 = nn.Conv2d(256, classes, kernel_size = 1)
        self.crop_3 = Crop()
        self.upscore_7_4_3 = nn.ConvTranspose2d(classes, classes, kernel_size = 16, stride = 8)

    @staticmethod
    def _conv_block(in_channels, out_channels, num_convs):
        """Build ``num_convs`` (3x3 pad-1 Conv2d, ReLU) pairs as one Sequential.

        The layer order (conv at even indices, ReLU at odd) keeps the same
        state_dict keys as spelling the Sequential out by hand.
        """
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(nn.Conv2d(channels, out_channels, stride = 1, padding = 1, kernel_size = 3))
            layers.append(nn.ReLU())
            channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, input):
        """Compute per-pixel class probabilities.

        Args:
            input: image batch, assumed (N, 3, H, W) — the first conv takes 3
                channels and the crops act on dims 2 and 3.

        Returns:
            Tensor (N, classes, H_out, W_out) of softmax probabilities over
            dim 1. NOTE: nn.CrossEntropyLoss expects logits, not probabilities;
            train this output with nn.NLLLoss on its log instead.
        """
        x = self.conv1(input)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x_pool3 = self.maxpool3(x)
        x_score3 = self.score_pool3(x_pool3)
        x = self.conv4(x_pool3)
        x_pool4 = self.maxpool4(x)
        x_score4 = self.score_pool4(x_pool4)
        x = self.conv5(x_pool4)
        x = self.maxpool5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x_score7 = self.score_7(x)
        x_score7 = self.upscore_7(x_score7)
        # Crop the pool4 scores to the upsampled size before fusing.
        x_score4 = self.crop_4(x_score4, x_score7)
        x_score_7_4 = x_score4 + x_score7
        x_score_7_4 = self.upscore_7_4(x_score_7_4)
        # Same alignment for the pool3 scores.
        x_score3 = self.crop_3(x_score3, x_score_7_4)
        x_score_7_4_3 = x_score3 + x_score_7_4
        X = self.upscore_7_4_3(x_score_7_4_3)
        # Explicit dim=1 (class axis); softmax without `dim` is deprecated and
        # implicitly chose dim 1 for 4-D input.
        output = F.softmax(X, dim = 1)
        return output

In [0]:
# Instantiate the network with 21 output classes — presumably PASCAL VOC
# (20 object classes + background); confirm against the dataset in use.
NUM_CLASSES = 21
FCN = FCN8s(classes=NUM_CLASSES)

In [0]:
# Loss function and optimizer for the FCN model.
# NOTE(review): nn.CrossEntropyLoss applies log-softmax to its input, but
# FCN8s.forward already returns softmax probabilities; either return logits
# from the model or train with nn.NLLLoss on log-probabilities.
criterion = nn.CrossEntropyLoss()
loss = criterion  # original name, kept for any cell that still refers to `loss`
optimizer = optim.Adam(FCN.parameters(), lr=0.001)

In [0]:
# Install PyTorch/XLA into the notebook runtime (presumably Colab: `#@param`
# below is Colab form syntax that turns VERSION into a dropdown field).
# NOTE(review): the first cell already imports `torch_xla_py.xla_model`, so on
# a fresh runtime this setup cell must run before that import — consider
# moving it to the top of the notebook.
# NOTE(review): this downloads and executes a script from the `master` branch
# of pytorch/xla; pin a specific commit for a reproducible install.
VERSION = "1.5"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# `$VERSION` is expanded by IPython from the Python variable above.
!python pytorch-xla-env-setup.py --version $VERSION

In [0]:

# Train FCN and report pixel accuracy on the test set every EVAL_EVERY steps.
# NOTE(review): `num_epochs`, `train_loader` and `test_loader` are not defined
# anywhere in this notebook; they must exist before this cell runs. The loaders
# are assumed to yield (images, masks) with masks holding integer (long)
# per-pixel class indices whose spatial size matches the model output.
model = FCN
# FCN8s.forward returns softmax probabilities, so use NLL on their log;
# CrossEntropyLoss would apply a second softmax on top.
criterion = nn.NLLLoss()
EVAL_EVERY = 500  # optimizer steps between evaluations

step = 0
for epoch in range(num_epochs):
    model.train()
    for images, masks in train_loader:
        optimizer.zero_grad()
        probs = model(images)
        # clamp keeps log() finite when a probability underflows to 0
        batch_loss = criterion(torch.log(probs.clamp_min(1e-12)), masks)
        batch_loss.backward()
        optimizer.step()
        step += 1
        if step % EVAL_EVERY == 0:
            model.eval()  # disable dropout while evaluating
            correct = 0
            total = 0
            with torch.no_grad():
                for test_images, test_masks in test_loader:
                    predicted = model(test_images).argmax(dim=1)
                    # Accuracy is per pixel, so count every mask element,
                    # not just the batch size.
                    total += test_masks.numel()
                    correct += (predicted == test_masks).sum().item()
            accuracy = 100 * correct / total

            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(step, batch_loss.item(), accuracy))
            model.train()