In [1]:
## Define an FCN-style segmentation network: ResNet-50 encoder, pyramid pooling, and a skip-connection decoder


In [10]:
import scipy.misc as misc
import torch
import copy
import torchvision.models as models
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class Net(nn.Module):
    """FCN-style segmentation network steered by a pointer mask and an ROI mask.

    Architecture (all built in ``__init__``):
      * ``Encoder``: torchvision ResNet-50.
      * ``PSPLayers`` / ``PSPSqueeze``: the deepest encoder features are resized
        to several scales, convolved, resized back, concatenated and squeezed
        to 512 channels.
      * ``SkipConnections`` / ``SqueezeUpsample``: decoder that repeatedly
        upsamples and fuses the outputs of encoder ``layer3``, ``layer2`` and
        ``layer1``.
      * ``FinalPrediction``: per-pixel class scores.

    The pointer/ROI attention layers are NOT created by the constructor; call
    ``AddAttentionLayer()`` once before the first ``forward()``.
    """

    # Per-channel statistics used to normalise the input image in forward().
    _RGB_MEAN = (123.68, 116.779, 103.939)
    _RGB_STD = (65.0, 65.0, 65.0)

    def __init__(self, num_classes=2, pretrained=True):
        """Build the encoder, the PSP module and the skip-connection decoder.

        Args:
            num_classes: number of output channels of the final prediction layer.
            pretrained: when True (default) the ResNet-50 encoder is loaded with
                torchvision's ImageNet weights; pass False to skip the download.
        """
        super(Net, self).__init__()

        # ResNet-50 acts as the encoder.
        self.Encoder = self._build_encoder(pretrained)

        # Relative sizes at which the deepest feature map is processed.
        self.PSPScales = [1, 1 / 2, 1 / 4, 1 / 8]

        # One 3x3 convolution per scale (2048 -> 1024 channels each).
        self.PSPLayers = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(2048, 1024, stride=1, kernel_size=3, padding=1, bias=True)
            )
            for _ in self.PSPScales
        )

        # Squeeze the 4 x 1024 concatenated PSP channels down to 512.
        # NOTE: the second convolution is 3x3 with padding=0, so it trims one
        # pixel from every border and needs an input of at least 3x3. The
        # decoder resizes the result afterwards, so shapes still line up.
        self.PSPSqueeze = nn.Sequential(
            nn.Conv2d(4096, 512, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, stride=1, kernel_size=3, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        # 1x1 projections of the encoder features used as skip connections
        # (ordered deepest first: layer3, layer2, layer1).
        self.SkipConnections = nn.ModuleList([
            self._conv_bn_relu(1024, 512),
            self._conv_bn_relu(512, 256),
            self._conv_bn_relu(256, 128),
        ])

        # 1x1 fusion applied after each skip feature is concatenated with the
        # upsampled decoder state.
        self.SqueezeUpsample = nn.ModuleList([
            self._conv_bn_relu(512 + 512, 512),
            self._conv_bn_relu(256 + 512, 256),
            self._conv_bn_relu(256 + 128, 128),
        ])

        # Final per-pixel class scores.
        self.FinalPrediction = nn.Conv2d(
            128, num_classes, stride=1, kernel_size=3, padding=1, bias=False
        )

    @staticmethod
    def _build_encoder(pretrained):
        """Create the ResNet-50 encoder on both new and old torchvision APIs."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # Older torchvision: only the boolean ``pretrained`` flag exists.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` used to load.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return a 1x1 convolution followed by batch norm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    @staticmethod
    def _as_float_tensor(array, device):
        """Copy a NumPy array into a float32 tensor on ``device``."""
        return torch.from_numpy(np.asarray(array).astype(np.float32)).to(device)

    def AddAttentionLayer(self):
        """Create the layers that inject the ROI mask and the pointer mask.

        forward() merges them with the first encoder features as
        ``x * PointerEncoder(pointer) + ROIEncoder(roi)``. The layers are
        initialised so that this merge starts as the identity:

          * ``ROIEncoder``: zero weight, zero bias  -> additive term is 0.
          * ``PointerEncoder``: zero weight, bias 1 -> multiplicative gate is 1.

        (A zero-bias gate would multiply the encoder features by zero and
        discard the image entirely.)
        """
        # Remember where the already-built layers live so the new ones can be
        # placed on the same device; None when the module has no parameters yet.
        reference = next(self.parameters(), None)

        self.AttentionLayers = nn.ModuleList()

        self.ROIEncoder = nn.Conv2d(1, 64, stride=1, kernel_size=3, padding=1, bias=True)
        nn.init.zeros_(self.ROIEncoder.bias)
        nn.init.zeros_(self.ROIEncoder.weight)

        self.PointerEncoder = nn.Conv2d(1, 64, stride=1, kernel_size=3, padding=1, bias=True)
        nn.init.ones_(self.PointerEncoder.bias)
        nn.init.zeros_(self.PointerEncoder.weight)

        if reference is not None:
            self.ROIEncoder.to(reference.device)
            self.PointerEncoder.to(reference.device)

    def forward(self, Images, Pointer, ROI, UseGPU=True):
        """Segment a batch of images given a pointer mask and an ROI mask.

        Args:
            Images: NumPy array of shape (N, H, W, 3); channels are normalised
                with ``_RGB_MEAN`` / ``_RGB_STD`` (presumably 0-255 RGB values;
                confirm against the data loader).
            Pointer: NumPy array of shape (N, H, W), the pointer mask.
            ROI: NumPy array of shape (N, H, W), the region-of-interest mask.
            UseGPU: kept for backward compatibility and not consulted; inputs
                are placed on whatever device the encoder weights are on.

        Returns:
            (Prob, Labels): ``Prob`` is the softmax over classes with shape
            (N, num_classes, H, W); ``Labels`` holds the index of the
            highest-scoring class per pixel with shape (N, H, W).

        Raises:
            AttributeError: if ``AddAttentionLayer()`` has not been called.
        """
        if not (hasattr(self, "ROIEncoder") and hasattr(self, "PointerEncoder")):
            raise AttributeError(
                "Attention layers are missing: call AddAttentionLayer() before forward()."
            )

        device = self.Encoder.conv1.weight.device

        # Image: NHWC NumPy array -> normalised NCHW float tensor.
        InpImages = self._as_float_tensor(Images, device).permute(0, 3, 1, 2)
        mean = torch.tensor(self._RGB_MEAN, dtype=torch.float32, device=device).view(1, -1, 1, 1)
        std = torch.tensor(self._RGB_STD, dtype=torch.float32, device=device).view(1, -1, 1, 1)
        InpImages = (InpImages - mean) / std

        # Masks: (N, H, W) -> (N, 1, H, W).
        ROImap = self._as_float_tensor(ROI, device).unsqueeze(dim=1)
        Pointermap = self._as_float_tensor(Pointer, device).unsqueeze(dim=1)

        # First encoder layer.
        x = self.Encoder.conv1(InpImages)
        x = self.Encoder.bn1(x)

        # Encode both masks, resize them to the feature map and merge:
        # the pointer acts as a multiplicative gate, the ROI as an additive bias.
        sp = (x.shape[2], x.shape[3])
        pt = F.interpolate(self.PointerEncoder(Pointermap), size=sp, mode='bilinear', align_corners=False)
        r = F.interpolate(self.ROIEncoder(ROImap), size=sp, mode='bilinear', align_corners=False)
        x = x * pt + r

        # Remaining encoder layers; keep intermediate outputs for the decoder.
        x = self.Encoder.relu(x)
        x = self.Encoder.maxpool(x)
        SkipConFeatures = []
        for stage in (self.Encoder.layer1, self.Encoder.layer2, self.Encoder.layer3):
            x = stage(x)
            SkipConFeatures.append(x)
        x = self.Encoder.layer4(x)

        # PSP module: process the deepest features at several scales.
        PSPSize = (x.shape[2], x.shape[3])
        PSPFeatures = []
        for scale, PSPLayer in zip(self.PSPScales, self.PSPLayers):
            # Truncate towards zero but never go below one pixel per side.
            NewSize = tuple(max(1, int(side * scale)) for side in PSPSize)
            y = F.interpolate(x, size=NewSize, mode='bilinear', align_corners=False)
            y = PSPLayer(y)
            y = F.interpolate(y, size=PSPSize, mode='bilinear', align_corners=False)
            PSPFeatures.append(y)
        x = torch.cat(PSPFeatures, dim=1)
        x = self.PSPSqueeze(x)

        # Decoder: upsample to each skip feature (deepest first), concatenate, fuse.
        for SkipLayer, Squeeze, SkipFeature in zip(
            self.SkipConnections, self.SqueezeUpsample, reversed(SkipConFeatures)
        ):
            sp = (SkipFeature.shape[2], SkipFeature.shape[3])
            x = F.interpolate(x, size=sp, mode='bilinear', align_corners=False)
            x = torch.cat((SkipLayer(SkipFeature), x), dim=1)
            x = Squeeze(x)

        # Per-pixel prediction, resized back to the input resolution.
        x = self.FinalPrediction(x)
        x = F.interpolate(
            x, size=(InpImages.shape[2], InpImages.shape[3]), mode='bilinear', align_corners=False
        )

        Prob = F.softmax(x, dim=1)
        _, Labels = x.max(dim=1)
        return Prob, Labels




