<a href="https://colab.research.google.com/github/nan-hk/motional_artifacts_dnn/blob/master/Test_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Different Modalities Dataset
 

1.   RGB as depth dataset without noise
2.   Depth as depth dataset with noise
3.   GT as GT dataset



In [None]:
import zipfile
from google.colab import drive

# Mount Google Drive so the archives under "My Drive" become readable.
drive.mount('/content/drive/')
# Unpack the depth-only test set into /content/tmp. The context manager
# closes the archive even if extraction fails part-way through.
with zipfile.ZipFile("/content/drive/My Drive/testdataset_only_depth.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/tmp")

Mounted at /content/drive/


In [None]:
# Unpack the second test archive into the same /content/tmp directory.
# The context manager closes the archive even if extraction fails part-way through.
with zipfile.ZipFile("/content/drive/My Drive/testdataset.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/tmp")

## Res2Net Model

In [None]:
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
# Names exported by a star-import of this code.
__all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b']


# Download locations of the pretrained checkpoints, keyed by architecture name.
# NOTE(review): there is no 'res2net152_v1b_26w_4s' entry here, although
# res2net152_v1b_26w_4s() looks that key up when called with pretrained=True.
model_urls = {
    'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
    'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
}


class Bottle2neck(nn.Module):
    """Multi-scale bottleneck block.

    A 1x1 conv produces ``width * scale`` channels, which are split into
    ``scale`` groups of ``width`` channels. All groups but the last go through
    their own 3x3 conv + BN + ReLU; the last group is passed through unchanged
    ('normal') or average-pooled ('stage'). The groups are concatenated again,
    projected by a 1x1 conv to ``planes * expansion`` channels and added to the
    (optionally down-sampled) input.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality (before multiplying by ``expansion``)
            stride: stride of the 3x3 convs (and of the pooling used in 'stage' blocks)
            downsample: module applied to the input before the residual add, or None
            baseWidth: basic width of conv3x3
            scale: number of channel groups
            stype: 'normal': normal set. 'stage': first block of a new stage.
        """
        super().__init__()

        group_width = int(math.floor(planes * (baseWidth / 64.0)))
        split_channels = group_width * scale

        self.conv1 = nn.Conv2d(inplanes, split_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(split_channels)

        # With a single group there is still one 3x3 conv; otherwise the last
        # group bypasses the convs.
        self.nums = 1 if scale == 1 else scale - 1

        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)

        self.convs = nn.ModuleList([
            nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(self.nums)
        ])
        self.bns = nn.ModuleList([nn.BatchNorm2d(group_width) for _ in range(self.nums)])

        self.conv3 = nn.Conv2d(split_channels, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = group_width

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        shortcut = x

        reduced = self.relu(self.bn1(self.conv1(x)))
        groups = torch.split(reduced, self.width, 1)

        restart_each_group = self.stype == 'stage'
        pieces = []
        previous = None
        for idx in range(self.nums):
            # In 'stage' blocks every group is processed independently;
            # otherwise each group also receives the previous group's output.
            if idx == 0 or restart_each_group:
                current = groups[idx]
            else:
                current = previous + groups[idx]
            previous = self.relu(self.bns[idx](self.convs[idx](current)))
            pieces.append(previous)

        if self.scale != 1:
            if self.stype == 'normal':
                pieces.append(groups[self.nums])
            elif self.stype == 'stage':
                pieces.append(self.pool(groups[self.nums]))

        out = self.bn3(self.conv3(torch.cat(pieces, 1)))

        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)


class Res2Net(nn.Module):
    """Classification network: three-conv stem, four residual stages, global pooling and a linear head."""

    def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
        """
        Args:
            block: residual block class; must expose an ``expansion`` attribute.
            layers: number of blocks in each of the four stages.
            baseWidth: forwarded to every block.
            scale: forwarded to every block.
            num_classes: output size of the final linear layer.
        """
        super().__init__()
        # Channel count entering the next stage; updated by _make_layer.
        self.inplanes = 64
        self.baseWidth = baseWidth
        self.scale = scale

        # Stem made of three 3x3 convs; only the first one strides.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        first, second, third, fourth = layers[0], layers[1], layers[2], layers[3]
        self.layer1 = self._make_layer(block, 64, first)
        self.layer2 = self._make_layer(block, 128, second, stride=2)
        self.layer3 = self._make_layer(block, 256, third, stride=2)
        self.layer4 = self._make_layer(block, 512, fourth, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Re-initialise every conv and batch-norm layer created above.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.inplanes``."""
        stage_channels = planes * block.expansion

        # The shortcut needs a projection whenever resolution or width changes.
        shortcut = None
        if stride != 1 or self.inplanes != stage_channels:
            shortcut = nn.Sequential(
                nn.AvgPool2d(kernel_size=stride, stride=stride,
                             ceil_mode=True, count_include_pad=False),
                nn.Conv2d(self.inplanes, stage_channels,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(stage_channels),
            )

        stage = [block(self.inplanes, planes, stride, downsample=shortcut,
                       stype='stage', baseWidth=self.baseWidth, scale=self.scale)]
        self.inplanes = stage_channels
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the output of the linear head for input batch ``x``."""
        feats = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        pooled = self.avgpool(feats)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)



class Res2Net_Ours(nn.Module):
    """Feature-extractor variant: same stem and four stages as Res2Net, but no
    pooling/linear head; forward returns the intermediate feature maps."""

    def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000):
        """
        Args:
            block: residual block class; must expose an ``expansion`` attribute.
            layers: number of blocks in each of the four stages.
            baseWidth: forwarded to every block.
            scale: forwarded to every block.
            num_classes: accepted for signature compatibility; not used here.
        """
        super().__init__()
        # Channel count entering the next stage; updated by _make_layer.
        self.inplanes = 64

        self.baseWidth = baseWidth
        self.scale = scale

        stem_layers = [
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
        ]
        self.conv1 = nn.Sequential(*stem_layers)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Re-initialise every conv and batch-norm layer created above.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.inplanes``."""
        out_planes = planes * block.expansion
        width_changes = self.inplanes != out_planes

        projection = None
        if stride != 1 or width_changes:
            projection = nn.Sequential(
                nn.AvgPool2d(kernel_size=stride, stride=stride,
                             ceil_mode=True, count_include_pad=False),
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        # The first block of the stage handles the stride and the projection.
        head = block(self.inplanes, planes, stride, downsample=projection,
                     stype='stage', baseWidth=self.baseWidth, scale=self.scale)
        self.inplanes = out_planes
        tail = [block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)
                for _ in range(1, blocks)]

        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Return ``(x0, x1, x2, x3, x4)``: the pooled stem output followed by
        the output of each of the four stages."""
        stem = self.relu(self.bn1(self.conv1(x)))
        x0 = self.maxpool(stem)

        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        return x0, x1, x2, x3, x4
    
    

def res2net50_v1b(pretrained=False, **kwargs):
    """Build a Res2Net-50_v1b classifier (the Res2Net-50_v1b_26w_4s configuration).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to the Res2Net constructor.
    """
    net = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if not pretrained:
        return net

    weights = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'], map_location='cpu')
    net.load_state_dict(weights)
    return net

def res2net101_v1b(pretrained=False, **kwargs):
    """Constructs a Res2Net-101_v1b model.
    Res2Net-101 refers to the Res2Net-101_v1b_26w_4s.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to the Res2Net constructor.
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        # map_location='cpu' matches res2net50_v1b, so loading does not depend
        # on the device the checkpoint tensors were saved from.
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'], map_location='cpu'))
    return model



def res2net50_v1b_Ours(pretrained=False, **kwargs):
    """Constructs the Res2Net-50_v1b feature extractor (Res2Net_Ours).
    Res2Net-50 refers to the Res2Net-50_v1b_26w_4s.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to the Res2Net_Ours constructor.
    """
    model = Res2Net_Ours(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'], map_location='cpu')
        # The checkpoint is the one loaded strictly into Res2Net, which has an
        # `fc` head that Res2Net_Ours lacks; a strict load of the full
        # checkpoint would be rejected for unexpected keys. Keep only the
        # entries this model owns -- missing entries still raise.
        own_keys = model.state_dict().keys()
        model.load_state_dict({k: v for k, v in checkpoint.items() if k in own_keys})
    return model

def res2net101_v1b_Ours(pretrained=False, **kwargs):
    """Constructs the Res2Net-101_v1b_26w_4s feature extractor (Res2Net_Ours).
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to the Res2Net_Ours constructor.
    """
    model = Res2Net_Ours(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'], map_location='cpu')
        # The checkpoint is the one loaded strictly into Res2Net, which has an
        # `fc` head that Res2Net_Ours lacks; a strict load of the full
        # checkpoint would be rejected for unexpected keys. Keep only the
        # entries this model owns -- missing entries still raise.
        own_keys = model.state_dict().keys()
        model.load_state_dict({k: v for k, v in checkpoint.items() if k in own_keys})
    return model



def res2net50_v1b_26w_4s(pretrained=False, **kwargs):
    """Build a Res2Net-50_v1b_26w_4s classifier.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to the Res2Net constructor.
    """
    stage_depths = [3, 4, 6, 3]
    net = Res2Net(Bottle2neck, stage_depths, baseWidth=26, scale=4, **kwargs)
    if not pretrained:
        return net

    checkpoint_url = model_urls['res2net50_v1b_26w_4s']
    net.load_state_dict(model_zoo.load_url(checkpoint_url, map_location='cpu'))
    return net

def res2net101_v1b_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-101_v1b_26w_4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to the Res2Net constructor.
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs)
    if pretrained:
        # map_location='cpu' matches res2net50_v1b_26w_4s, so loading does not
        # depend on the device the checkpoint tensors were saved from.
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'], map_location='cpu'))
    return model

def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-152_v1b_26w_4s model.
    Args:
        pretrained (bool): If True, loads the checkpoint registered under
            'res2net152_v1b_26w_4s' in model_urls.
        **kwargs: forwarded to the Res2Net constructor.
    Raises:
        KeyError: if pretrained is True and model_urls has no
            'res2net152_v1b_26w_4s' entry.
    """
    url = None
    if pretrained:
        # Check the URL before building the (large) network so a missing
        # entry fails immediately and with an explanation.
        url = model_urls.get('res2net152_v1b_26w_4s')
        if url is None:
            raise KeyError(
                "no pretrained checkpoint URL is registered for "
                "'res2net152_v1b_26w_4s'; add it to model_urls or call with pretrained=False")
    model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth = 26, scale = 4, **kwargs)
    if url is not None:
        model.load_state_dict(model_zoo.load_url(url, map_location='cpu'))
    return model



   
def Res2Net_model(ind=50):
    """Build a Res2Net_Ours feature extractor initialised from pretrained classifier weights.

    A pretrained Res2Net classifier of the requested depth is created, and every
    entry of its state dict whose key also exists in the feature extractor is
    copied over (the classifier-only entries are dropped).

    Args:
        ind (int): backbone depth, 50 or 101.

    Returns:
        The initialised Res2Net_Ours model.

    Raises:
        ValueError: if ``ind`` is neither 50 nor 101.
    """
    if ind == 50:
        model_base = res2net50_v1b(pretrained=True)
        model      = res2net50_v1b_Ours()
    elif ind == 101:
        model_base = res2net101_v1b(pretrained=True)
        model      = res2net101_v1b_Ours()
    else:
        # Without this branch an unsupported depth would surface later as an
        # UnboundLocalError on `model_base`.
        raise ValueError("Res2Net_model: ind must be 50 or 101, got %r" % (ind,))

    pretrained_dict = model_base.state_dict()
    model_dict      = model.state_dict()

    # Keep only the weights the feature extractor actually has.
    shared_weights = {k: v for k, v in pretrained_dict.items() if k in model_dict}

    model_dict.update(shared_weights)
    model.load_state_dict(model_dict)

    return model





if __name__ == '__main__':
    # Smoke check: build an untrained Res2Net-50 classifier.
    # `images` is a random batch created alongside it; it is not passed
    # through the model here.
    images = torch.rand(1, 3, 352, 352)
    model = res2net50_v1b_26w_4s(pretrained=False)

## Salient Object Detection Model

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import functional as F

def maxpool():
    """Return a new 2x2 max-pooling layer with stride 2 and no padding."""
    return nn.MaxPool2d(kernel_size=2, stride=2, padding=0)


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch normalisation.

    A ReLU module is stored as ``self.relu`` but ``forward`` does not apply it;
    callers that want an activation add it themselves.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, dilation=dilation, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``bn(conv(x))`` (no activation)."""
        return self.bn(self.conv(x))



#Global Contextual module
class GCM(nn.Module):
    """Global contextual module.

    Four parallel branches (a plain 1x1 conv, and three branches that follow a
    1x1 conv with 1xk / kx1 convs and a 3x3 conv dilated by k, for k = 3, 5, 7)
    are concatenated, fused by a 3x3 conv, added to a 1x1 projection of the
    input and passed through ReLU.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.relu = nn.ReLU(True)

        def context_branch(k):
            # 1x1 -> 1xk -> kx1 -> 3x3 with dilation k; paddings keep the spatial size.
            half = k // 2
            return nn.Sequential(
                BasicConv2d(in_channel, out_channel, 1),
                BasicConv2d(out_channel, out_channel, kernel_size=(1, k), padding=(0, half)),
                BasicConv2d(out_channel, out_channel, kernel_size=(k, 1), padding=(half, 0)),
                BasicConv2d(out_channel, out_channel, 3, padding=k, dilation=k),
            )

        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = context_branch(3)
        self.branch2 = context_branch(5)
        self.branch3 = context_branch(7)

        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        """Return the fused multi-branch features, with ``out_channel`` channels."""
        branch_outputs = [
            branch(x)
            for branch in (self.branch0, self.branch1, self.branch2, self.branch3)
        ]
        fused = self.conv_cat(torch.cat(branch_outputs, 1))
        return self.relu(fused + self.conv_res(x))



###############################################################################

class CIM0(nn.Module):
    """Cross-modal fusion of two feature maps that already have ``out_dim`` channels.

    Each modality is gated by a sigmoid map computed from the other one, the
    gated features are refined, and their element-wise product and element-wise
    maximum are concatenated and fused by a final conv block.

    ``in_dim`` is accepted but not used; ``gamma1`` / ``gamma2`` are registered
    parameters that ``forward`` does not read.
    """

    def __init__(self,in_dim, out_dim):
        super().__init__()

        # One ReLU instance is shared by all three conv-BN-ReLU stacks.
        shared_act = nn.ReLU(inplace=True)

        def conv3(channels_in):
            return nn.Conv2d(channels_in, out_dim, kernel_size=3, stride=1, padding=1)

        def conv_bn_act(channels_in):
            return nn.Sequential(conv3(channels_in), nn.BatchNorm2d(out_dim), shared_act)

        self.layer_10 = conv3(out_dim)
        self.layer_20 = conv3(out_dim)

        self.layer_11 = conv_bn_act(out_dim)
        self.layer_21 = conv_bn_act(out_dim)

        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))

        self.layer_ful1 = conv_bn_act(out_dim * 2)

    def forward(self, rgb, depth):
        """Fuse ``rgb`` and ``depth`` (same shape) into one ``out_dim``-channel map."""
        rgb_gate = torch.sigmoid(self.layer_10(rgb))
        dep_gate = torch.sigmoid(self.layer_20(depth))

        # Cross gating with a residual connection: each stream is modulated by
        # the other stream's gate.
        rgb_enhanced = rgb.mul(dep_gate) + rgb
        dep_enhanced = depth.mul(rgb_gate) + depth

        rgb_refined = self.layer_11(rgb_enhanced)
        dep_refined = self.layer_21(dep_enhanced)

        product = torch.mul(rgb_refined, dep_refined)
        # Stack the two streams on a new axis and reduce over it.
        maximum = torch.stack((rgb_refined, dep_refined), dim=1).max(dim=1)[0]

        return self.layer_ful1(torch.cat((product, maximum), dim=1))


class CIM(nn.Module):
    """Cross-modal fusion with channel reduction and a carried-over fused map.

    Both inputs are reduced from ``in_dim`` to ``out_dim`` channels, gated by
    each other, refined, combined through product and maximum, and finally
    fused with ``xx`` (expected to carry ``out_dim // 2`` channels).

    ``gamma1`` / ``gamma2`` are registered parameters that ``forward`` does not read.
    """

    def __init__(self,in_dim, out_dim):
        super().__init__()

        # One ReLU instance is shared by every stack built below.
        shared_act = nn.ReLU(inplace=True)

        def conv3(channels_in):
            return nn.Conv2d(channels_in, out_dim, kernel_size=3, stride=1, padding=1)

        def conv_bn_act(channels_in):
            return nn.Sequential(conv3(channels_in), nn.BatchNorm2d(out_dim), shared_act)

        self.reduc_1 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1), shared_act)
        self.reduc_2 = nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=1), shared_act)

        self.layer_10 = conv3(out_dim)
        self.layer_20 = conv3(out_dim)

        self.layer_11 = conv_bn_act(out_dim)
        self.layer_21 = conv_bn_act(out_dim)

        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))

        self.layer_ful1 = conv_bn_act(out_dim * 2)
        self.layer_ful2 = conv_bn_act(out_dim + out_dim // 2)

    def forward(self, rgb, depth, xx):
        """Fuse ``rgb`` and ``depth`` and merge the result with ``xx``."""
        rgb_reduced = self.reduc_1(rgb)
        dep_reduced = self.reduc_2(depth)

        rgb_gate = torch.sigmoid(self.layer_10(rgb_reduced))
        dep_gate = torch.sigmoid(self.layer_20(dep_reduced))

        # Cross gating with a residual connection.
        rgb_enhanced = rgb_reduced.mul(dep_gate) + rgb_reduced
        dep_enhanced = dep_reduced.mul(rgb_gate) + dep_reduced

        rgb_refined = self.layer_11(rgb_enhanced)
        dep_refined = self.layer_21(dep_enhanced)

        product = torch.mul(rgb_refined, dep_refined)
        # Stack the two streams on a new axis and reduce over it.
        maximum = torch.stack((rgb_refined, dep_refined), dim=1).max(dim=1)[0]

        fused = self.layer_ful1(torch.cat((product, maximum), dim=1))
        return self.layer_ful2(torch.cat([fused, xx], dim=1))



class MFA(nn.Module):
    """Modulate a fused feature map with two auxiliary maps and add the result back.

    ``layer_10`` and ``layer_20`` are registered but not used by ``forward``.
    """

    def __init__(self,in_dim):
        super().__init__()

        self.relu = nn.ReLU(inplace=True)

        def same_conv(channels_in):
            return nn.Conv2d(channels_in, in_dim, kernel_size=3, stride=1, padding=1)

        self.layer_10 = same_conv(in_dim)
        self.layer_20 = same_conv(in_dim)
        self.layer_cat1 = nn.Sequential(same_conv(in_dim * 2), nn.BatchNorm2d(in_dim))

    def forward(self, x_ful, x1, x2):
        """Return ``relu(x_ful + conv_bn(cat(x_ful * x1, x_ful * x2)))``."""
        modulated = torch.cat((x_ful.mul(x1), x_ful.mul(x2)), dim=1)
        return self.relu(x_ful + self.layer_cat1(modulated))
    
    

  
   
###############################################################################

class SPNet(nn.Module):
    """Two-stream RGB + depth network with a cross-modal fusion stream.

    Two independent Res2Net_model backbones encode the RGB image and the depth
    map (expanded from 1 to 3 channels by a 1x1 conv). CIM0/CIM modules fuse the
    two streams level by level. Three decoders (RGB, depth, fused), built from
    GCM blocks, each end in a 1x1 conv to a single channel followed by 4x
    bilinear up-sampling.

    Args:
        channel (int): width of every decoder (GCM / MFA) feature map.
        ind (int): backbone depth passed to Res2Net_model.
    """

    def __init__(self, channel=32,ind=50):
        super(SPNet, self).__init__()

        # Shared by the *_conv_* stacks created below.
        self.relu = nn.ReLU(inplace=True)

        self.upsample_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample_4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.upsample_8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)

        # Backbone model: one encoder per modality.
        self.layer_rgb  = Res2Net_model(ind)
        self.layer_dep  = Res2Net_model(ind)

        # Lifts the single-channel depth map to the 3 channels the backbone stem expects.
        self.layer_dep0 = nn.Conv2d(1, 3, kernel_size=1)

        ###############################################
        # fusion encoders #
        ###############################################
        self.fu_0 = CIM0(64, 64)#

        self.fu_1 = CIM(256, 128) #MixedFusion_Block_IMfusion
        self.pool_fu_1 = maxpool()

        self.fu_2 = CIM(512, 256)
        self.pool_fu_2 = maxpool()

        self.fu_3 = CIM(1024, 512)
        self.pool_fu_3 = maxpool()

        self.fu_4 = CIM(2048, 1024)
        self.pool_fu_4 = maxpool()


        ###############################################
        # decoders #
        ###############################################
        # Every GCM below level 4 consumes an encoder/fusion feature map
        # concatenated with a `channel`-wide decoder map, so its input width is
        # "<feature width> + channel". With the default channel=32 this equals
        # the previously hard-coded "+32".
        #
        # The *_conv_* stacks are not used by forward(); they are kept, with
        # their original widths, so the module's state_dict keys and shapes
        # stay the same.

        ## rgb
        self.rgb_conv_4   = nn.Sequential(BasicConv2d(2048,    256, 3, padding=1),self.relu)
        self.rgb_gcm_4    = GCM(2048,  channel)

        self.rgb_conv_3   = nn.Sequential(BasicConv2d(1024+32, 256, 3, padding=1),self.relu)
        self.rgb_gcm_3    = GCM(1024+channel,  channel)

        self.rgb_conv_2   = nn.Sequential(BasicConv2d(512+32, 128, 3, padding=1),self.relu)
        self.rgb_gcm_2    = GCM(512+channel,  channel)

        self.rgb_conv_1   = nn.Sequential(BasicConv2d(256+32, 128, 3, padding=1),self.relu)
        self.rgb_gcm_1    = GCM(256+channel,  channel)

        self.rgb_conv_0   = nn.Sequential(BasicConv2d(64+32, 64, 3, padding=1),self.relu)
        self.rgb_gcm_0    = GCM(64+channel,  channel)
        self.rgb_conv_out = nn.Conv2d(channel, 1, 1)

        ## depth
        self.dep_conv_4   = nn.Sequential(BasicConv2d(2048, 256, 3, padding=1),self.relu)
        self.dep_gcm_4    = GCM(2048,  channel)

        self.dep_conv_3   = nn.Sequential(BasicConv2d(1024+32, 256, 3, padding=1),self.relu)
        self.dep_gcm_3    = GCM(1024+channel,  channel)

        self.dep_conv_2   = nn.Sequential(BasicConv2d(512+32, 128, 3, padding=1),self.relu)
        self.dep_gcm_2    = GCM(512+channel,  channel)

        self.dep_conv_1   = nn.Sequential(BasicConv2d(256+32, 128, 3, padding=1),self.relu)
        self.dep_gcm_1    = GCM(256+channel,  channel)

        self.dep_conv_0   = nn.Sequential(BasicConv2d(64+32, 64, 3, padding=1),self.relu)
        self.dep_gcm_0    = GCM(64+channel,  channel)
        self.dep_conv_out = nn.Conv2d(channel, 1, 1)

        ## fusion
        self.ful_conv_4   = nn.Sequential(BasicConv2d(2048, 256, 3, padding=1),self.relu)
        self.ful_gcm_4    = GCM(1024,  channel)

        self.ful_conv_3   = nn.Sequential(BasicConv2d(1024+32*3, 256, 3, padding=1),self.relu)
        self.ful_gcm_3    = GCM(512+channel,  channel)

        self.ful_conv_2   = nn.Sequential(BasicConv2d(512+32*3, 128, 3, padding=1),self.relu)
        self.ful_gcm_2    = GCM(256+channel,  channel)

        self.ful_conv_1   = nn.Sequential(BasicConv2d(256+32*3, 128, 3, padding=1),self.relu)
        self.ful_gcm_1    = GCM(128+channel,  channel)

        self.ful_conv_0   = nn.Sequential(BasicConv2d(128+32*3, 64, 3, padding=1),self.relu)
        self.ful_gcm_0    = GCM(64+channel,  channel)
        self.ful_conv_out = nn.Conv2d(channel, 1, 1)

        self.ful_layer4   = MFA(channel)
        self.ful_layer3   = MFA(channel)
        self.ful_layer2   = MFA(channel)
        self.ful_layer1   = MFA(channel)
        self.ful_layer0   = MFA(channel)



    def forward(self, imgs, depths):
        """Run the three decoders.

        Args:
            imgs: RGB batch with 3 channels.
            depths: depth batch with 1 channel, same spatial size as ``imgs``.
                NOTE(review): the 2x up-sampling steps presumably require the
                spatial size to be divisible by 32 so feature maps line up --
                confirm against the data loader.

        Returns:
            ``(rgb_out, dep_out, ful_out)``: single-channel maps from the RGB,
            depth and fused decoders. No sigmoid is applied here.
        """
        img_0, img_1, img_2, img_3, img_4 = self.layer_rgb(imgs)
        dep_0, dep_1, dep_2, dep_3, dep_4 = self.layer_dep(self.layer_dep0(depths))

        ####################################################
        ## fusion
        ####################################################
        # Each level fuses both modalities and receives the previous fused map.
        # Levels 0 and 1 share a resolution, so only levels 2-4 pool it first.
        ful_0    = self.fu_0(img_0, dep_0)
        ful_1    = self.fu_1(img_1, dep_1, ful_0)
        ful_2    = self.fu_2(img_2, dep_2, self.pool_fu_1(ful_1))
        ful_3    = self.fu_3(img_3, dep_3, self.pool_fu_2(ful_2))
        ful_4    = self.fu_4(img_4, dep_4, self.pool_fu_3(ful_3))

        ####################################################
        ## decoder rgb
        ####################################################
        # Top-down: each level concatenates the encoder feature with the
        # up-sampled output of the level above.
        x_rgb_42    = self.rgb_gcm_4(img_4)

        x_rgb_3_cat = torch.cat([img_3, self.upsample_2(x_rgb_42)], dim=1)
        x_rgb_32    = self.rgb_gcm_3(x_rgb_3_cat)

        x_rgb_2_cat = torch.cat([img_2, self.upsample_2(x_rgb_32)], dim=1)
        x_rgb_22    = self.rgb_gcm_2(x_rgb_2_cat)

        x_rgb_1_cat = torch.cat([img_1, self.upsample_2(x_rgb_22)], dim=1)
        x_rgb_12    = self.rgb_gcm_1(x_rgb_1_cat)

        # Level 0 has the same resolution as level 1, so no up-sampling here.
        x_rgb_0_cat = torch.cat([img_0, x_rgb_12], dim=1)
        x_rgb_02    = self.rgb_gcm_0(x_rgb_0_cat)
        rgb_out     = self.upsample_4(self.rgb_conv_out(x_rgb_02))


        ####################################################
        ## decoder depth
        ####################################################
        # Same structure as the RGB decoder, on the depth features.
        x_dep_42    = self.dep_gcm_4(dep_4)

        x_dep_3_cat = torch.cat([dep_3, self.upsample_2(x_dep_42)], dim=1)
        x_dep_32    = self.dep_gcm_3(x_dep_3_cat)

        x_dep_2_cat = torch.cat([dep_2, self.upsample_2(x_dep_32)], dim=1)
        x_dep_22    = self.dep_gcm_2(x_dep_2_cat)

        x_dep_1_cat = torch.cat([dep_1, self.upsample_2(x_dep_22)], dim=1)
        x_dep_12    = self.dep_gcm_1(x_dep_1_cat)

        x_dep_0_cat = torch.cat([dep_0, x_dep_12], dim=1)
        x_dep_02    = self.dep_gcm_0(x_dep_0_cat)
        dep_out     = self.upsample_4(self.dep_conv_out(x_dep_02))


        ####################################################
        ## decoder fusion
        ####################################################
        # Like the other decoders, but the map handed down from the level above
        # is first modulated (MFA) by the RGB and depth decoder maps of that level.
        x_ful_42    = self.ful_gcm_4(ful_4)

        x_ful_3_cat = torch.cat([ful_3, self.ful_layer3(self.upsample_2(x_ful_42),self.upsample_2(x_rgb_42),self.upsample_2(x_dep_42))], dim=1)
        x_ful_32    = self.ful_gcm_3(x_ful_3_cat)

        x_ful_2_cat = torch.cat([ful_2, self.ful_layer2(self.upsample_2(x_ful_32),self.upsample_2(x_rgb_32),self.upsample_2(x_dep_32))], dim=1)
        x_ful_22    = self.ful_gcm_2(x_ful_2_cat)

        x_ful_1_cat = torch.cat([ful_1, self.ful_layer1(self.upsample_2(x_ful_22),self.upsample_2(x_rgb_22),self.upsample_2(x_dep_22))], dim=1)
        x_ful_12    = self.ful_gcm_1(x_ful_1_cat)

        x_ful_0_cat = torch.cat([ful_0, self.ful_layer0(x_ful_12, x_rgb_12, x_dep_12)], dim=1)
        x_ful_02    = self.ful_gcm_0(x_ful_0_cat)
        ful_out     = self.upsample_4(self.ful_conv_out(x_ful_02))


        return rgb_out, dep_out, ful_out



    def _make_agant_layer(self, inplanes, planes):
        """Return a 1x1 conv + BN + ReLU stack mapping ``inplanes`` to ``planes`` channels.

        Not called anywhere in this class.
        """
        layers = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True)
        )
        return layers

    def _make_transpose(self, block, planes, blocks, stride=1):
        """Build a stack of ``blocks`` blocks whose last block changes width / resolution.

        NOTE(review): this reads ``self.inplanes``, which SPNet.__init__ never
        sets, and nothing in this class calls it -- it would raise
        AttributeError if used as is.
        """
        upsample = None
        if stride != 1:
            upsample = nn.Sequential(
                nn.ConvTranspose2d(self.inplanes, planes,
                                   kernel_size=2, stride=stride,
                                   padding=0, bias=False),
                nn.BatchNorm2d(planes),
            )
        elif self.inplanes != planes:
            upsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

        layers = []

        for i in range(1, blocks):
            layers.append(block(self.inplanes, self.inplanes))

        layers.append(block(self.inplanes, planes, stride, upsample))
        self.inplanes = planes

        return nn.Sequential(*layers)

## Data Preprocessing Functions

In [None]:
import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
import random
import numpy as np
from PIL import ImageEnhance

# several data augmentation strategies
def cv_random_flip(img, label,depth):
    """With probability 1/2, mirror image, label and depth left-to-right together.

    Returns the three inputs unchanged when no flip is drawn.
    """
    should_flip = random.randint(0, 1) == 1
    if not should_flip:
        return img, label, depth

    # Apply the same horizontal flip to all three so they stay aligned.
    img = img.transpose(Image.FLIP_LEFT_RIGHT)
    label = label.transpose(Image.FLIP_LEFT_RIGHT)
    depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
    return img, label, depth
def randomCrop(image, label,depth):
    """Crop image, label and depth to the same centred window of random size.

    The window is up to 30 pixels narrower and shorter than ``image`` (never
    the full size), and is centred on it.
    """
    margin = 30
    full_w, full_h = image.size[0], image.size[1]

    # Width is drawn first, then height.
    win_w = np.random.randint(full_w - margin, full_w)
    win_h = np.random.randint(full_h - margin, full_h)

    left = (full_w - win_w) >> 1
    upper = (full_h - win_h) >> 1
    right = (full_w + win_w) >> 1
    lower = (full_h + win_h) >> 1
    box = (left, upper, right, lower)

    return image.crop(box), label.crop(box), depth.crop(box)
def randomRotation(image,label,depth):
    """With probability of about 0.2, rotate all three inputs by one random angle.

    The angle is an integer drawn from [-15, 15) degrees and the rotation uses
    bicubic resampling. Otherwise the inputs are returned unchanged.
    """
    resample = Image.BICUBIC
    if random.random() <= 0.8:
        return image, label, depth

    angle = np.random.randint(-15, 15)
    rotated_image = image.rotate(angle, resample)
    rotated_label = label.rotate(angle, resample)
    rotated_depth = depth.rotate(angle, resample)
    return rotated_image, rotated_label, rotated_depth
def colorEnhance(image):
    """Randomly jitter brightness, contrast, colour and sharpness, in that order.

    Each factor is a random integer number of tenths: brightness and contrast
    in [0.5, 1.5], colour in [0.0, 2.0], sharpness in [0.0, 3.0].
    """
    # (enhancer class, lowest factor in tenths, highest factor in tenths)
    adjustments = (
        (ImageEnhance.Brightness, 5, 15),
        (ImageEnhance.Contrast, 5, 15),
        (ImageEnhance.Color, 0, 20),
        (ImageEnhance.Sharpness, 0, 30),
    )
    for enhancer_cls, low_tenths, high_tenths in adjustments:
        factor = random.randint(low_tenths, high_tenths) / 10.0
        image = enhancer_cls(image).enhance(factor)
    return image
def randomGaussian(image, mean=0.1, sigma=0.35):
    """Add independent Gaussian noise to every pixel and return a new 8-bit image.

    The noise is added in floating point, rounded to the nearest integer and
    clipped to [0, 255] before converting back to uint8. Adding it directly to
    a uint8 buffer would truncate toward zero and could wrap around at 0 / 255.

    Args:
        image: image (or array-like) convertible with ``np.asarray``; any
            number of dimensions is accepted.
        mean: mean of the noise, in pixel-value units.
        sigma: standard deviation of the noise, in pixel-value units.

    Returns:
        A PIL image built from the noisy uint8 array.

    Note:
        The noise is drawn from NumPy's global random state (``np.random``).
    """
    pixels = np.asarray(image, dtype=np.float64)
    # One vectorised draw for the whole image instead of a per-pixel Python loop.
    noisy = pixels + np.random.normal(mean, sigma, size=pixels.shape)
    quantised = np.clip(np.rint(noisy), 0, 255).astype(np.uint8)
    return Image.fromarray(quantised)
def randomPeper(img):
    """Sprinkle salt-and-pepper noise over a copy of ``img``.

    ``int(0.0015 * rows * cols)`` positions are drawn (repeats allowed); each
    drawn position is set to 0 or 255 with equal probability.

    Returns:
        A new PIL image built from the modified pixel array.
    """
    pixels = np.array(img)
    rows, cols = pixels.shape[0], pixels.shape[1]
    for _ in range(int(0.0015 * rows * cols)):
        # Draw order matters for reproducibility: row, column, then the coin.
        row = random.randint(0, rows - 1)
        col = random.randint(0, cols - 1)
        pixels[row, col] = 0 if random.randint(0, 1) == 0 else 255
    return Image.fromarray(pixels)

# dataset for training
#The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps
#(e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.

class SalObjDataset(data.Dataset):
    """Training dataset yielding augmented ``(image, gt, depth)`` tensors.

    Files are collected from three directories (images: .jpg/.png,
    ground truth: .jpg/.png, depth: .bmp/.png), sorted by path and paired by
    position. Triplets whose three files differ in size are dropped.

    Args:
        image_root: directory holding the RGB images.
        gt_root: directory holding the ground-truth masks.
        depth_root: directory holding the depth maps.
        trainsize: side length every sample is resized to.
    """

    def __init__(self, image_root, gt_root,depth_root, trainsize):
        self.trainsize = trainsize
        # os.path.join gives the same paths as plain concatenation when the
        # root ends with '/', and also works when it does not.
        self.images = sorted(os.path.join(image_root, f) for f in os.listdir(image_root)
                             if f.endswith(('.jpg', '.png')))
        self.gts = sorted(os.path.join(gt_root, f) for f in os.listdir(gt_root)
                          if f.endswith(('.jpg', '.png')))
        self.depths = sorted(os.path.join(depth_root, f) for f in os.listdir(depth_root)
                             if f.endswith(('.bmp', '.png')))
        self.filter_files()
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])
        self.depths_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        """Load sample ``index``, augment it and return ``(image, gt, depth)`` tensors."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        depth = self.binary_loader(self.depths[index])
        # Geometric augmentations are applied jointly so the three stay aligned.
        image, gt, depth = cv_random_flip(image, gt, depth)
        image, gt, depth = randomCrop(image, gt, depth)
        image, gt, depth = randomRotation(image, gt, depth)
        # Photometric jitter touches the RGB image only; salt-and-pepper noise
        # is applied to the ground truth only.
        image = colorEnhance(image)
        # gt=randomGaussian(gt)
        gt = randomPeper(gt)
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        depth = self.depths_transform(depth)

        return image, gt, depth

    def filter_files(self):
        """Drop every triplet whose image, gt and depth sizes are not identical.

        Raises:
            AssertionError: if the three file lists differ in length, since
                pairing by position would then silently misalign or truncate.
        """
        assert len(self.images) == len(self.gts) == len(self.depths), \
            'image/gt/depth file counts differ: {} / {} / {}'.format(
                len(self.images), len(self.gts), len(self.depths))
        images = []
        gts = []
        depths = []
        for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
            # Only the sizes are needed; close the files right away so a large
            # dataset does not exhaust file descriptors.
            with Image.open(img_path) as img, \
                    Image.open(gt_path) as gt, \
                    Image.open(depth_path) as depth:
                same_size = img.size == gt.size and gt.size == depth.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
                depths.append(depth_path)
        self.images = images
        self.gts = gts
        self.depths = depths

    def rgb_loader(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` and return it as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt, depth):
        """Upscale the triplet so neither side is smaller than ``trainsize``.

        The image is resized bilinearly, gt and depth with nearest-neighbour.
        Inputs that are already large enough are returned unchanged.
        """
        assert img.size == gt.size and gt.size == depth.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST),depth.resize((w, h), Image.NEAREST)
        else:
            return img, gt, depth

    def __len__(self):
        """Number of triplets that survived ``filter_files``."""
        return self.size


###############################################################################
# 0919
#

class SalObjDataset_var(data.Dataset):
    """Training dataset returning two independently augmented views per sample.

    Files are collected from three directories (images: .jpg only,
    ground truth: .jpg/.png, depth: .bmp/.png), sorted by path and paired by
    position. Triplets whose three files differ in size are dropped.

    Args:
        image_root: directory holding the RGB images.
        gt_root: directory holding the ground-truth masks.
        depth_root: directory holding the depth maps.
        trainsize: side length every sample is resized to.
    """

    def __init__(self, image_root, gt_root,depth_root, trainsize):
        self.trainsize = trainsize
        # os.path.join gives the same paths as plain concatenation when the
        # root ends with '/', and also works when it does not.
        self.images = sorted(os.path.join(image_root, f) for f in os.listdir(image_root)
                             if f.endswith('.jpg'))
        self.gts = sorted(os.path.join(gt_root, f) for f in os.listdir(gt_root)
                          if f.endswith(('.jpg', '.png')))
        self.depths = sorted(os.path.join(depth_root, f) for f in os.listdir(depth_root)
                             if f.endswith(('.bmp', '.png')))
        self.filter_files()
        self.size = len(self.images)

        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])
        self.depths_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def _augment(self, image, gt, depth):
        """Run one random augmentation pass and convert the triplet to tensors."""
        # Geometric augmentations are applied jointly so the three stay aligned.
        image, gt, depth = cv_random_flip(image, gt, depth)
        image, gt, depth = randomCrop(image, gt, depth)
        image, gt, depth = randomRotation(image, gt, depth)
        # Photometric jitter on the RGB image, salt-and-pepper on the gt.
        image = colorEnhance(image)
        gt = randomPeper(gt)
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        depth = self.depths_transform(depth)
        return image, gt, depth

    def __getitem__(self, index):
        """Return ``(image, gt, depth, image2, gt2, depth2)`` for sample ``index``.

        Both views start from the same loaded files and go through the same
        augmentation pipeline with independent random draws.
        """
        ## read image, gt, depth
        image0 = self.rgb_loader(self.images[index])
        gt0 = self.binary_loader(self.gts[index])
        depth0 = self.binary_loader(self.depths[index])

        ## out1
        view1 = self._augment(image0, gt0, depth0)
        ## out2
        view2 = self._augment(image0, gt0, depth0)

        return view1 + view2

    def filter_files(self):
        """Drop every triplet whose image, gt and depth sizes are not identical.

        Raises:
            AssertionError: if the three file lists differ in length, since
                pairing by position would then silently misalign or truncate.
        """
        assert len(self.images) == len(self.gts) == len(self.depths), \
            'image/gt/depth file counts differ: {} / {} / {}'.format(
                len(self.images), len(self.gts), len(self.depths))
        images = []
        gts = []
        depths = []
        for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
            # Only the sizes are needed; close the files right away so a large
            # dataset does not exhaust file descriptors.
            with Image.open(img_path) as img, \
                    Image.open(gt_path) as gt, \
                    Image.open(depth_path) as depth:
                same_size = img.size == gt.size and gt.size == depth.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
                depths.append(depth_path)
        self.images = images
        self.gts = gts
        self.depths = depths

    def rgb_loader(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` and return it as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt, depth):
        """Upscale the triplet so neither side is smaller than ``trainsize``.

        The image is resized bilinearly, gt and depth with nearest-neighbour.
        Inputs that are already large enough are returned unchanged.
        """
        assert img.size == gt.size and gt.size == depth.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST),depth.resize((w, h), Image.NEAREST)
        else:
            return img, gt, depth

    def __len__(self):
        """Number of triplets that survived ``filter_files``."""
        return self.size



class SalObjDataset_var_unlabel(data.Dataset):
    """Two-view training dataset whose RGB images are .png files.

    Files are collected from three directories (images: .png only,
    ground truth: .jpg/.png, depth: .bmp/.png), sorted by path and paired by
    position. Triplets whose three files differ in size are dropped. Despite
    the name, a ground-truth directory is still read and returned.

    Args:
        image_root: directory holding the RGB images.
        gt_root: directory holding the ground-truth masks.
        depth_root: directory holding the depth maps.
        trainsize: side length every sample is resized to.
    """

    def __init__(self, image_root, gt_root,depth_root, trainsize):
        self.trainsize = trainsize
        # os.path.join gives the same paths as plain concatenation when the
        # root ends with '/', and also works when it does not.
        self.images = sorted(os.path.join(image_root, f) for f in os.listdir(image_root)
                             if f.endswith('.png'))
        self.gts = sorted(os.path.join(gt_root, f) for f in os.listdir(gt_root)
                          if f.endswith(('.jpg', '.png')))
        self.depths = sorted(os.path.join(depth_root, f) for f in os.listdir(depth_root)
                             if f.endswith(('.bmp', '.png')))
        self.filter_files()
        self.size = len(self.images)

        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])
        self.depths_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def _augment(self, image, gt, depth):
        """Run one random augmentation pass and convert the triplet to tensors."""
        # Geometric augmentations are applied jointly so the three stay aligned.
        image, gt, depth = cv_random_flip(image, gt, depth)
        image, gt, depth = randomCrop(image, gt, depth)
        image, gt, depth = randomRotation(image, gt, depth)
        # Photometric jitter on the RGB image, salt-and-pepper on the gt.
        image = colorEnhance(image)
        gt = randomPeper(gt)
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        depth = self.depths_transform(depth)
        return image, gt, depth

    def __getitem__(self, index):
        """Return ``(image, gt, depth, image2, gt2, depth2)`` for sample ``index``.

        Both views start from the same loaded files and go through the same
        augmentation pipeline with independent random draws.
        """
        ## read image, gt, depth
        image0 = self.rgb_loader(self.images[index])
        gt0 = self.binary_loader(self.gts[index])
        depth0 = self.binary_loader(self.depths[index])

        ## out1
        view1 = self._augment(image0, gt0, depth0)
        ## out2
        view2 = self._augment(image0, gt0, depth0)

        return view1 + view2

    def filter_files(self):
        """Drop every triplet whose image, gt and depth sizes are not identical.

        Raises:
            AssertionError: if the three file lists differ in length, since
                pairing by position would then silently misalign or truncate.
        """
        assert len(self.images) == len(self.gts) == len(self.depths), \
            'image/gt/depth file counts differ: {} / {} / {}'.format(
                len(self.images), len(self.gts), len(self.depths))
        images = []
        gts = []
        depths = []
        for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
            # Only the sizes are needed; close the files right away so a large
            # dataset does not exhaust file descriptors.
            with Image.open(img_path) as img, \
                    Image.open(gt_path) as gt, \
                    Image.open(depth_path) as depth:
                same_size = img.size == gt.size and gt.size == depth.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
                depths.append(depth_path)
        self.images = images
        self.gts = gts
        self.depths = depths

    def rgb_loader(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` and return it as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt, depth):
        """Upscale the triplet so neither side is smaller than ``trainsize``.

        The image is resized bilinearly, gt and depth with nearest-neighbour.
        Inputs that are already large enough are returned unchanged.
        """
        assert img.size == gt.size and gt.size == depth.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST),depth.resize((w, h), Image.NEAREST)
        else:
            return img, gt, depth

    def __len__(self):
        """Number of triplets that survived ``filter_files``."""
        return self.size

#dataloader for training
def get_loader(image_root, gt_root,depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False):
    """Build a ``DataLoader`` over a ``SalObjDataset``.

    Prints the three root directories and the dataset object, then wraps the
    dataset with the given batching options.

    Args:
        image_root, gt_root, depth_root: directories passed to the dataset.
        batchsize: samples per batch.
        trainsize: side length passed to the dataset.
        shuffle, num_workers, pin_memory: forwarded to ``data.DataLoader``.
    """
    print(image_root, gt_root, depth_root)
    train_set = SalObjDataset(image_root, gt_root, depth_root, trainsize)
    print(train_set)
    loader_options = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return data.DataLoader(dataset=train_set, **loader_options)


#dataloader for training2
## 09-19-2020
def get_loader_var(image_root, gt_root,depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False):
    """Build a ``DataLoader`` over a ``SalObjDataset_var`` (two views per sample).

    Args:
        image_root, gt_root, depth_root: directories passed to the dataset.
        batchsize: samples per batch.
        trainsize: side length passed to the dataset.
        shuffle, num_workers, pin_memory: forwarded to ``data.DataLoader``.
    """
    two_view_set = SalObjDataset_var(image_root, gt_root, depth_root, trainsize)
    loader_options = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return data.DataLoader(dataset=two_view_set, **loader_options)


def get_loader_var_unlabel(image_root, gt_root,depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False):
    """Build a ``DataLoader`` over a ``SalObjDataset_var_unlabel``.

    Args:
        image_root, gt_root, depth_root: directories passed to the dataset.
        batchsize: samples per batch.
        trainsize: side length passed to the dataset.
        shuffle, num_workers, pin_memory: forwarded to ``data.DataLoader``.
    """
    unlabel_set = SalObjDataset_var_unlabel(image_root, gt_root, depth_root, trainsize)
    loader_options = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return data.DataLoader(dataset=unlabel_set, **loader_options)


#test dataset and loader
class test_dataset:
    """Sequential, wrap-around reader of (image, gt, depth) test triplets.

    Files are collected from three directories (images: .jpg/.png,
    ground truth: .jpg/.png, depth: .bmp/.png), sorted by path and paired by
    position. ``load_data`` returns one triplet per call and advances an
    internal cursor that wraps back to 0 after the last sample.

    Args:
        image_root: directory holding the RGB images.
        gt_root: directory holding the ground-truth masks.
        depth_root: directory holding the depth maps.
        testsize: side length the image and depth tensors are resized to.
    """

    def __init__(self, image_root, gt_root,depth_root, testsize):
        self.testsize = testsize
        # os.path.join gives the same paths as plain concatenation when the
        # root ends with '/', and also works when it does not.
        self.images = sorted(os.path.join(image_root, f) for f in os.listdir(image_root)
                             if f.endswith(('.jpg', '.png')))
        self.gts = sorted(os.path.join(gt_root, f) for f in os.listdir(gt_root)
                          if f.endswith(('.jpg', '.png')))
        self.depths = sorted(os.path.join(depth_root, f) for f in os.listdir(depth_root)
                             if f.endswith(('.bmp', '.png')))
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # Kept for callers; load_data itself returns the gt as a PIL image
        # at its original resolution.
        self.gt_transform = transforms.ToTensor()
        self.depths_transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor()])
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        """Return the next sample and advance the cursor.

        Returns:
            Tuple ``(image, gt, depth, name, image_for_post)``:
            ``image`` and ``depth`` are transformed tensors with a leading
            batch dimension of 1, ``gt`` is the single-channel PIL image,
            ``name`` is the image file name with a trailing ``.jpg`` replaced
            by ``.png``, and ``image_for_post`` is the RGB image resized to
            the gt size as a NumPy array.
        """
        image_path = self.images[self.index]
        # Decode the RGB file once and reuse it for both the network input
        # and the post-processing copy.
        rgb = self.rgb_loader(image_path)
        image = self.transform(rgb).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        depth = self.binary_loader(self.depths[self.index])
        depth = self.depths_transform(depth).unsqueeze(0)
        name = os.path.basename(image_path)
        image_for_post = rgb.resize(gt.size)
        if name.endswith('.jpg'):
            # Swap only the final extension; splitting on '.jpg' would cut
            # the name at the first occurrence anywhere in it.
            name = name[:-len('.jpg')] + '.png'
        self.index = (self.index + 1) % self.size
        return image, gt, depth, name, np.array(image_for_post)

    def rgb_loader(self, path):
        """Open ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` and return it as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def __len__(self):
        """Number of image files found."""
        return self.size


## Model Testing

### Test Model for same input dataset

In [None]:
import torch
import torch.nn.functional as F
import sys
import torch.nn as nn
import numpy as np
import os, argparse
import cv2

def test_arguments():
    """Return the evaluation options, always at their default values.

    An empty argument list is parsed so the notebook kernel's own
    command line is never consulted.
    """
    option_specs = (
        ('--testsize', int, 352, 'testing size'),
        ('--gpu_id', str, '0', 'select gpu id'),
        ('--test_path', str, '/content/tmp/testdataset_only_depth/', 'test dataset path'),
    )
    parser = argparse.ArgumentParser()
    for flag, kind, default, help_text in option_specs:
        parser.add_argument(flag, type=kind, default=default, help=help_text)
    return parser.parse_args([])

opt = test_arguments()

dataset_path = opt.test_path

#set device for test
if opt.gpu_id=='0':
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('USE GPU 0')


#load the model
model = SPNet(32,50)
model.cuda()

model.load_state_dict(torch.load('/content/drive/MyDrive/Checkpoint/path_with_same_input/SPNet_l2_loss.pth'))
model.eval()

#test
# Only DES is evaluated in this run; the full list is ['NJU2K', 'NLPR', 'DES'].
test_datasets = ['DES']


for dataset in test_datasets:
    save_path = '/content/drive/MyDrive/test_maps/part_with_same_input/' + dataset + '/'
    os.makedirs(save_path, exist_ok=True)

    image_root  = str(dataset_path + dataset + '/RGB/')
    gt_root     = str(dataset_path + dataset + '/GT/')
    depth_root  = str(dataset_path + dataset + '/depth/')
    test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize)
    # Inference only: without no_grad every forward pass would also build an
    # autograd graph that is never used.
    with torch.no_grad():
        for i in range(test_loader.size):
            image, gt,depth, name, image_for_post = test_loader.load_data()
            # The ground truth is only needed here for its (height, width).
            gt      = np.asarray(gt, np.float32)
            gt     /= (gt.max() + 1e-8)
            image   = image.cuda()
            depth   = depth.cuda()
            pre_res = model(image,depth)
            # The third model output is the one saved as the prediction.
            res     = pre_res[2]
            # F.interpolate replaces the deprecated F.upsample (same arguments).
            res     = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
            res     = res.sigmoid().cpu().numpy().squeeze()
            # Min-max normalise to [0, 1] before scaling to 8-bit range.
            res     = (res - res.min()) / (res.max() - res.min() + 1e-8)

            print('save img to: ',save_path+name)
            cv2.imwrite(save_path+name,res*255)
    print('Test Done!')

USE GPU 0


Downloading: "https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth" to /root/.cache/torch/hub/checkpoints/res2net50_v1b_26w_4s-3cf99910.pth


  0%|          | 0.00/98.4M [00:00<?, ?B/s]



save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1500.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1501.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1502.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1510.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1511.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1512.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1520.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1521.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1522.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1530.png
save img to:  /content/drive/MyDrive/test_maps/part_with_same_input/DES/depth_1531.png
save img to:  /content/drive/MyDrive/test_m

### Test Model with different Input

In [None]:
import torch
import torch.nn.functional as F
import sys
import torch.nn as nn
import numpy as np
import os, argparse
import cv2

def test_arguments():
    """Return the evaluation options, always at their default values.

    An empty argument list is parsed so the notebook kernel's own
    command line is never consulted.
    """
    option_specs = (
        ('--testsize', int, 352, 'testing size'),
        ('--gpu_id', str, '0', 'select gpu id'),
        ('--test_path', str, '/content/tmp/testdataset/', 'test dataset path'),
    )
    parser = argparse.ArgumentParser()
    for flag, kind, default, help_text in option_specs:
        parser.add_argument(flag, type=kind, default=default, help=help_text)
    return parser.parse_args([])

opt = test_arguments()

dataset_path = opt.test_path

#set device for test
if opt.gpu_id=='0':
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    print('USE GPU 0')


#load the model
model = SPNet(32,50)
model.cuda()

model.load_state_dict(torch.load('/content/drive/MyDrive/Checkpoint/path_with_different_input/SPNet_l2_loss.pth'))
model.eval()

#test
# Only DES is evaluated in this run; the full list is ['NJU2K', 'NLPR', 'DES'].
test_datasets = ['DES']


for dataset in test_datasets:
    save_path = '/content/drive/MyDrive/test_maps/path_with_different_input/' + dataset + '/'
    os.makedirs(save_path, exist_ok=True)

    image_root  = str(dataset_path + dataset + '/RGB/')
    gt_root     = str(dataset_path + dataset + '/GT/')
    depth_root  = str(dataset_path + dataset + '/depth/')
    test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize)
    # Inference only: without no_grad every forward pass would also build an
    # autograd graph that is never used.
    with torch.no_grad():
        for i in range(test_loader.size):
            image, gt,depth, name, image_for_post = test_loader.load_data()
            # The ground truth is only needed here for its (height, width).
            gt      = np.asarray(gt, np.float32)
            gt     /= (gt.max() + 1e-8)
            image   = image.cuda()
            depth   = depth.cuda()
            pre_res = model(image,depth)
            # The third model output is the one saved as the prediction.
            res     = pre_res[2]
            # F.interpolate replaces the deprecated F.upsample (same arguments).
            res     = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
            res     = res.sigmoid().cpu().numpy().squeeze()
            # Min-max normalise to [0, 1] before scaling to 8-bit range.
            res     = (res - res.min()) / (res.max() - res.min() + 1e-8)

            print('save img to: ',save_path+name)
            cv2.imwrite(save_path+name,res*255)
    print('Test Done!')

USE GPU 0




save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1500.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1501.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1502.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1510.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1511.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1512.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1520.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1521.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1522.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1530.png
save img to:  /content/drive/MyDrive/test_maps/path_with_different_input/DES/RGB_1531.png
save img t

# PNG to NIfTI File

In [None]:
!pip install SimpleITK

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting SimpleITK
  Downloading SimpleITK-2.1.1.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (48.4 MB)
[K     |████████████████████████████████| 48.4 MB 36 kB/s 
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.1.1.2


In [None]:
import SimpleITK as sitk
import os

dirPth = '/content/drive/MyDrive/test_maps/SPNet_new/DES/'
# os.listdir returns names in arbitrary order; sort them so the PNG slices are
# stacked into the volume in a reproducible, name-based order.
file_names = sorted(os.path.join(dirPth, f) for f in os.listdir(dirPth) if f.endswith('.png'))

# Read the PNG files as one image series and write it out as a single
# compressed NIfTI volume next to the inputs.
reader = sitk.ImageSeriesReader()
reader.SetFileNames(file_names)
vol = reader.Execute()
sitk.WriteImage(vol, dirPth + 'des_psnrMse_ssim_100.nii.gz')