In [1]:
import torch
import torch.nn as nn
from main import run

Reference: [arXiv:1507.06821](https://arxiv.org/pdf/1507.06821.pdf)
![image.png](rgbdfusion.png)

In [2]:
# Flatten layer: flattens a tensor into two dimensions, (batch, features)
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    An input of shape ``(N, d1, d2, ...)`` comes out as ``(N, d1*d2*...)``.
    The module has no parameters.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, feat):
        """Return ``feat`` reshaped to ``(feat.size(0), -1)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous
        inputs (for example the result of ``permute``) are copied instead
        of raising a RuntimeError. Contiguous inputs are still returned as
        a view, so the common path is unchanged.
        """
        return feat.reshape(feat.size(0), -1)

# example rgbNet: using only rgb images as input to train a model
class rgbNet(nn.Module):
    """Baseline classifier that uses the RGB image only.

    Four ``conv3x3 -> ReLU -> maxpool2x2`` stages extract features from the
    3-channel input; a two-layer fully connected head then produces
    15 scores per sample.
    """

    def __init__(self):
        super(rgbNet, self).__init__()

        # (in_channels, out_channels) for each convolutional stage, in order.
        stage_channels = [(3, 32), (32, 32), (32, 64), (64, 64)]

        layers = []
        for c_in, c_out in stage_channels:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        layers.append(Flatten())
        self.net = nn.Sequential(*layers)

        # The head expects 6*6*64 flattened features, i.e. a 6x6 map with
        # 64 channels after the four 2x2 poolings (e.g. 96x96 inputs --
        # TODO confirm against the data loader).
        flat_dim = 6 * 6 * 64
        self.mix_net = nn.Sequential(
            nn.Linear(flat_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 15),
        )

    def forward(self, img, dep):
        """Return class scores computed from ``img``.

        ``dep`` is accepted so that the call signature matches the other
        networks in this notebook, but it is not used.
        """
        return self.mix_net(self.net(img))

# example depthNet: using only depth image as input to train a model
class depthNet(nn.Module):
    """Baseline classifier that uses the depth image only.

    Same layout as the RGB baseline except that the first convolution
    takes a single input channel: four ``conv3x3 -> ReLU -> maxpool2x2``
    stages followed by a two-layer fully connected head with 15 outputs.
    """

    def __init__(self):
        super(depthNet, self).__init__()

        # (in_channels, out_channels) for each convolutional stage, in order.
        stage_channels = [(1, 32), (32, 32), (32, 64), (64, 64)]

        layers = []
        for c_in, c_out in stage_channels:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        layers.append(Flatten())
        self.net = nn.Sequential(*layers)

        # The head expects 6*6*64 flattened features, i.e. a 6x6 map with
        # 64 channels after the four 2x2 poolings (e.g. 96x96 inputs --
        # TODO confirm against the data loader).
        flat_dim = 6 * 6 * 64
        self.mix_net = nn.Sequential(
            nn.Linear(flat_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 15),
        )

    def forward(self, img, dep):
        """Return class scores computed from ``dep``.

        ``img`` is accepted so that the call signature matches the other
        networks in this notebook, but it is not used.
        """
        return self.mix_net(self.net(dep))

# *****************************IMPORTANT******************************
# YOU NEED TO FILL IN THE CLASS BELOW TO BUILD AN RGB-D FUSION NETWORK.
# NOTE THAT YOU ONLY NEED TO DEFINE THE NETWORK; THE OTHER PARTS (DATA LOADER,
# TRAINING CODE, ...) ARE ALREADY PROVIDED.
# AFTER FINISHING THE NETWORK, EXECUTING run(rgbdNet) STARTS TRAINING, AND YOU CAN OBSERVE THE TRAINING PROCESS AND THE
# ACCURACY ON THE VALIDATION SET.

# YOU CAN ALSO RUN run(rgbNet) AND run(depthNet) TO TRAIN ON ONLY THE RGB OR DEPTH MODALITY, AND OBSERVE WHETHER FUSION
# GIVES AN ACCURACY BOOST.

# IF YOU HAVE ANY TROUBLE, REFER TO THE rgbNet AND depthNet CLASSES ABOVE.
class rgbdNet(nn.Module):
    """Two-stream RGB-D classifier that fuses features by concatenation.

    ``rgb_net`` (3 input channels) and ``dep_net`` (1 input channel) are
    structurally identical convolutional streams. Their flattened outputs
    are concatenated along the last dimension and passed to ``mix_net``,
    a two-layer fully connected head that produces 15 scores per sample.
    """

    @staticmethod
    def _make_stream(in_channels):
        """Build one convolutional stream for an ``in_channels``-channel input.

        Four ``conv3x3 -> ReLU -> maxpool2x2`` stages, with a dropout
        (p=0.1) placed directly after the first convolution, ending in
        ``Flatten``.
        """
        return nn.Sequential(
            # stage 1 (with dropout)
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.Dropout(p=0.1, inplace=False),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # stage 2
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # stage 3
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # stage 4
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
        )

    def __init__(self):
        super(rgbdNet, self).__init__()

        # Streams are created in this order: RGB first, then depth.
        self.rgb_net = self._make_stream(3)
        self.dep_net = self._make_stream(1)

        # Each stream is expected to contribute 6*6*64 features (a 6x6 map
        # with 64 channels after four 2x2 poolings, e.g. 96x96 inputs --
        # TODO confirm against the data loader); the head sees both.
        fused_dim = 6 * 6 * 64 * 2
        self.mix_net = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 15),
        )

    def forward(self, img, dep):
        """Return class scores from the RGB image ``img`` and depth image ``dep``."""
        # RGB stream is evaluated before the depth stream.
        fused = torch.cat([self.rgb_net(img), self.dep_net(dep)], dim=-1)
        return self.mix_net(fused)
# ********************************************************************

In [3]:
class Args:
    """Run configuration handed to ``run`` through its ``args`` keyword.

    Every setting is a keyword argument whose default is this notebook's
    standard value, so ``Args()`` yields the usual configuration while
    e.g. ``Args(lr=0.001, exp='./exps/rgb/')`` overrides individual fields.

    The fields are consumed by ``main.run``, which is not shown in this
    notebook; the per-field notes below are inferred from the names and
    should be confirmed against that code.
    """

    def __init__(self, stage='train', root='../rgbd/', lr=0.01, batch_size=64,
                 weight_decay=5e-4, max_epoch=30, exp='./exps/', resume_path=''):
        self.stage = stage                # presumably the run mode -- TODO confirm
        self.root = root                  # presumably the dataset directory -- TODO confirm
        self.lr = lr                      # initial learning rate (the logs start at this value)
        self.batch_size = batch_size      # presumably samples per mini-batch -- TODO confirm
        self.weight_decay = weight_decay  # presumably the optimizer weight decay -- TODO confirm
        self.max_epoch = max_epoch        # presumably the number of training epochs -- TODO confirm
        self.exp = exp                    # presumably the experiment output directory -- TODO confirm
        self.resume_path = resume_path    # presumably a checkpoint to resume from; '' for none -- TODO confirm

In [4]:
# Train the RGB-D fusion network using the default Args configuration.
run(net=rgbdNet, args=Args())

Epoch:0	 loss:2.6701, lr:0.01
Accuracy:0.2073

Epoch:1	 loss:1.5422, lr:0.01
Accuracy:0.6238

Epoch:2	 loss:0.4095, lr:0.01
Accuracy:0.7054

Epoch:3	 loss:0.1475, lr:0.01
Accuracy:0.7607

Epoch:4	 loss:0.0862, lr:0.01
Accuracy:0.7784

Epoch:5	 loss:0.0752, lr:0.01
Accuracy:0.6961

Epoch:6	 loss:0.0298, lr:0.01
Accuracy:0.7574

Epoch:7	 loss:0.0174, lr:0.01
Accuracy:0.7548

Epoch:8	 loss:0.0106, lr:0.01
Accuracy:0.7740

Epoch:9	 loss:0.0153, lr:0.01
Accuracy:0.7773

Epoch:10	 loss:0.0015, lr:0.001
Accuracy:0.7890

Epoch:11	 loss:0.0008, lr:0.001
Accuracy:0.7864

Epoch:12	 loss:0.0007, lr:0.001
Accuracy:0.7873

Epoch:13	 loss:0.0006, lr:0.001
Accuracy:0.7875

Epoch:14	 loss:0.0005, lr:0.001
Accuracy:0.7868

Epoch:15	 loss:0.0005, lr:0.001
Accuracy:0.7869

Epoch:16	 loss:0.0005, lr:0.001
Accuracy:0.7875

Epoch:17	 loss:0.0005, lr:0.001
Accuracy:0.7871

Epoch:18	 loss:0.0004, lr:0.001
Accuracy:0.7871

Epoch:19	 loss:0.0004, lr:0.001
Accuracy:0.7879

Epoch:20	 loss:0.0004, lr:0.0001
Accurac

In [5]:
# Train the RGB-only baseline with the same default Args, for comparison with the fusion network.
run(rgbNet, args=Args())

Epoch:0	 loss:2.5960, lr:0.01
Accuracy:0.2591

Epoch:1	 loss:1.0183, lr:0.01
Accuracy:0.6303

Epoch:2	 loss:0.2584, lr:0.01
Accuracy:0.5928

Epoch:3	 loss:0.1704, lr:0.01
Accuracy:0.6684

Epoch:4	 loss:0.0606, lr:0.01
Accuracy:0.6468

Epoch:5	 loss:0.0292, lr:0.01
Accuracy:0.6597

Epoch:6	 loss:0.0148, lr:0.01
Accuracy:0.6730

Epoch:7	 loss:0.0042, lr:0.01
Accuracy:0.7402

Epoch:8	 loss:0.0008, lr:0.01
Accuracy:0.7442

Epoch:9	 loss:0.0003, lr:0.01
Accuracy:0.7416

Epoch:10	 loss:0.0002, lr:0.001
Accuracy:0.7420

Epoch:11	 loss:0.0002, lr:0.001
Accuracy:0.7431

Epoch:12	 loss:0.0002, lr:0.001
Accuracy:0.7433

Epoch:13	 loss:0.0002, lr:0.001
Accuracy:0.7433

Epoch:14	 loss:0.0002, lr:0.001
Accuracy:0.7433

Epoch:15	 loss:0.0002, lr:0.001
Accuracy:0.7429

Epoch:16	 loss:0.0002, lr:0.001
Accuracy:0.7433

Epoch:17	 loss:0.0002, lr:0.001
Accuracy:0.7429

Epoch:18	 loss:0.0002, lr:0.001
Accuracy:0.7433

Epoch:19	 loss:0.0002, lr:0.001
Accuracy:0.7437

Epoch:20	 loss:0.0002, lr:0.0001
Accurac

In [None]:
# Train the depth-only baseline with the same default Args, for comparison with the fusion network.
run(depthNet, args=Args())