In [2]:
from sklearn.model_selection import train_test_split
import os,gc
import cv2
import random
import numpy as np
from PIL import Image, ImageStat
from torch.utils.data import Dataset, DataLoader,random_split
from torch import randperm,unique
from torch._utils import _accumulate
import torchvision.transforms.functional as F
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from skimage import io, transform
import sklearn.metrics as skm
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch import manual_seed,zeros_like,zeros,ones,unique,autograd,device,cuda,cat,save,load,tensor,utils,rand
from torch import sum as torch_sum
# Seed every RNG library this notebook imports (torch, random, numpy) so that
# Restart Kernel -> Run All reproduces the same weights, splits and samples.
SEED = 17
manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [3]:
class LocationAwareConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose output receives an extra, externally supplied bias tensor.

    ``forward`` returns ``conv(inputs) + lb``, with ``lb`` broadcast against the
    convolution output, so ``lb`` must be broadcast-compatible with that output
    (e.g. shape ``(1, out_channels, H_out, W_out)``).

    Args:
        lb: the additive bias, typically an ``nn.Parameter``. It is stored by
            reference as ``self.locationBias``; when it is a Parameter it is
            registered on this module and can be shared with other modules
            holding the same object.
        w, h: not used by this class; kept for signature compatibility.
        in_channels, out_channels, kernel_size, bias: forwarded to ``nn.Conv2d``
            (stride/padding are not passed, so the ``nn.Conv2d`` defaults apply).
    """

    def __init__(self, lb, w, h, in_channels, out_channels, kernel_size, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, bias=bias)
        self.locationBias = lb

    def forward(self, inputs):
        """Apply the convolution and add the location bias."""
        # Move the bias to the input's device *locally* instead of reassigning
        # self.locationBias: get_device() returns -1 for CPU tensors (not a valid
        # .to() target), and assigning a plain Tensor over a registered Parameter
        # raises TypeError. .to() is differentiable, so gradients still flow
        # back to the (possibly shared) parameter.
        location_bias = self.locationBias
        if location_bias.device != inputs.device:
            location_bias = location_bias.to(inputs.device)
        return super().forward(inputs) + location_bias


# NOTE(review): module-level padding constant that no code in this cell reads
# (decoder_block receives its padding explicitly); remove if unused elsewhere.
PAD = 0
class NimbroNet2(nn.Module):
    """Encoder-decoder network with two 3-channel location-aware output heads.

    A ResNet-18 backbone (with ``avgpool`` and ``fc`` removed) is split into four
    encoder stages. The decoder upsamples the deepest features with stride-2
    transposed convolutions, concatenating 1x1-projected skip connections from
    the earlier stages. Two ``LocationAwareConv2d`` heads then add a learned
    per-pixel bias to a 1x1 convolution.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet18`` (default ``True``,
            as before).
        bias_height, bias_width: spatial size of the learned location bias
            (default 120x160, as before). The bias is added by broadcasting, so
            the head output must have exactly this size. The heads run at 1/4
            of the input resolution, so 120x160 corresponds to a 480x640 input.
            Skip concatenation also assumes H and W are divisible by 32 — TODO
            confirm against the data pipeline.
    """

    def __init__(self, pretrained=True, bias_height=120, bias_width=160):
        super().__init__()
        # Encoder: ResNet-18 without its classification head.
        resnet_encoder = models.resnet18(pretrained=pretrained)
        del resnet_encoder.avgpool
        del resnet_encoder.fc
        # Remaining children, in order: conv1, bn1, relu, maxpool, layer1..layer4.
        encoder_layers = list(resnet_encoder.children())
        self.resnet_encoder_block1 = nn.Sequential(*encoder_layers[0:5])  # stem + layer1 -> 64 ch
        self.resnet_encoder_block2 = nn.Sequential(*encoder_layers[5:6])  # layer2 -> 128 ch
        self.resnet_encoder_block3 = nn.Sequential(*encoder_layers[6:7])  # layer3 -> 256 ch
        self.resnet_encoder_block4 = nn.Sequential(*encoder_layers[7:8])  # layer4 -> 512 ch

        # Decoder: each ConvTranspose2d (kernel 2, stride 2) doubles H and W.
        self.decoder_block_1 = decoder_block(1, 512, 256, 2, 2, 0, 0)  # 512 from layer4
        self.decoder_block_2 = decoder_block(2, 512, 256, 2, 2, 0, 0)  # 256 up + 256 skip
        self.decoder_block_3 = decoder_block(3, 512, 128, 2, 2, 0, 0)  # 256 up + 256 skip
        self.decoder_block_4 = decoder_block(4, 256)                   # 128 up + 128 skip, no upsampling

        # 1x1 projections of encoder features, used as skip connections.
        self.conv_en3_de1 = nn.Conv2d(256, 256, kernel_size=1)
        self.conv_en2_de2 = nn.Conv2d(128, 256, kernel_size=1)
        self.conv_en1_de3 = nn.Conv2d(64, 128, kernel_size=1)

        # Learned per-pixel bias, initialised uniformly in [0, 1).
        self.locationBias = nn.Parameter(rand(1, 3, bias_height, bias_width), requires_grad=True)

        # Location-dependent output heads.
        # NOTE(review): both heads receive the *same* Parameter object, so they
        # share (and jointly train) a single location bias — confirm this is intended.
        self.lbS = LocationAwareConv2d(self.locationBias, bias_height, bias_width,
                                       in_channels=256, out_channels=3, kernel_size=1, bias=True)
        self.lbD = LocationAwareConv2d(self.locationBias, bias_height, bias_width,
                                       in_channels=256, out_channels=3, kernel_size=1, bias=True)

    def forward(self, x):
        """Run the network; returns the pair ``(lbS(features), lbD(features))``."""
        # Encoder, keeping a 1x1-projected copy of each intermediate stage.
        x = self.resnet_encoder_block1(x)
        skip_1 = self.conv_en1_de3(x)
        x = self.resnet_encoder_block2(x)
        skip_2 = self.conv_en2_de2(x)
        x = self.resnet_encoder_block3(x)
        skip_3 = self.conv_en3_de1(x)
        x = self.resnet_encoder_block4(x)

        # Decoder: upsample, then concatenate the matching skip along channels.
        x = self.decoder_block_1(x)
        x = self.decoder_block_2(cat((x, skip_3), 1))
        x = self.decoder_block_3(cat((x, skip_2), 1))
        features = self.decoder_block_4(cat((x, skip_1), 1))

        return self.lbS(features), self.lbD(features)

def decoder_block(block_number, input_channels, output_channels=0, kernel=0, stride=0, padding=0, output_padding=0):
    """Build one NimbroNet2 decoder stage as an ``nn.Sequential``.

    The layer layout depends on ``block_number``:
      * 1             -> ReLU, ConvTranspose2d
      * 4             -> ReLU, BatchNorm2d   (no transposed conv)
      * anything else -> ReLU, BatchNorm2d, ConvTranspose2d

    ``output_channels``, ``kernel``, ``stride``, ``padding`` and ``output_padding``
    are passed positionally to ``nn.ConvTranspose2d`` and are ignored for block 4.
    Their 0 defaults only suit block 4; other blocks need a real kernel and stride.
    """
    layers = [nn.ReLU()]
    # Every stage except the first normalises its (concatenated) input.
    if block_number != 1:
        layers.append(nn.BatchNorm2d(input_channels))
    # Every stage except the last upsamples.
    if block_number != 4:
        layers.append(nn.ConvTranspose2d(input_channels, output_channels, kernel,
                                         stride, padding, output_padding))
    return nn.Sequential(*layers)