In [1]:
import torch
import torch.nn as nn
from main import run

In [2]:
# Flatten layer: collapses every dimension after the first, producing a two-dimensional tensor
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    An input of shape ``(N, d1, d2, ...)`` comes back as ``(N, d1*d2*...)``.
    The layer has no parameters.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, feat):
        """Return ``feat`` reshaped to ``(feat.size(0), -1)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous inputs
        (for example the result of ``transpose`` or ``permute``) are handled
        instead of raising; for inputs that ``view`` could already handle the
        result is the same.
        """
        return feat.reshape(feat.size(0), -1)

# example rgbNet: using only rgb images as input to train a model
class rgbNet(nn.Module):
    """RGB-only example classifier.

    ``self.net`` is four stages of 3x3 convolution -> ReLU -> 2x2 max-pool
    followed by ``Flatten``; ``self.mix_net`` is a two-layer fully connected
    head that outputs 15 scores per sample.

    The head expects ``6*6*64`` flattened features, i.e. a 6x6 map with 64
    channels after the four pooling steps (e.g. 96x96 inputs -- TODO confirm
    against the data loader).
    """

    def __init__(self):
        super(rgbNet, self).__init__()

        # (in_channels, out_channels) for each convolution, in forward order.
        stage_channels = [(3, 32), (32, 32), (32, 64), (64, 64)]

        stages = []
        for c_in, c_out in stage_channels:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        stages.append(Flatten())
        self.net = nn.Sequential(*stages)

        hidden = nn.Linear(6*6*64, 512)
        classifier = nn.Linear(512, 15)
        self.mix_net = nn.Sequential(hidden, nn.ReLU(inplace=True), classifier)

    def forward(self, img, dep):
        """Score ``img``; ``dep`` is accepted but never read.

        The depth argument is accepted so that this model can be called the
        same way as ``depthNet`` and ``rgbdNet``.
        """
        return self.mix_net(self.net(img))

# example depthNet: using only depth image as input to train a model
class depthNet(nn.Module):
    """Depth-only example classifier.

    ``self.net`` is four stages of 3x3 convolution -> ReLU -> 2x2 max-pool
    followed by ``Flatten``; the first convolution takes a single input
    channel. ``self.mix_net`` is a two-layer fully connected head that
    outputs 15 scores per sample.

    The head expects ``6*6*64`` flattened features, i.e. a 6x6 map with 64
    channels after the four pooling steps (e.g. 96x96 inputs -- TODO confirm
    against the data loader).
    """

    def __init__(self):
        super(depthNet, self).__init__()

        # (in_channels, out_channels) for each convolution, in forward order.
        stage_channels = [(1, 32), (32, 32), (32, 64), (64, 64)]

        stages = []
        for c_in, c_out in stage_channels:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        stages.append(Flatten())
        self.net = nn.Sequential(*stages)

        hidden = nn.Linear(6*6*64, 512)
        classifier = nn.Linear(512, 15)
        self.mix_net = nn.Sequential(hidden, nn.ReLU(inplace=True), classifier)

    def forward(self, img, dep):
        """Score ``dep``; ``img`` is accepted but never read.

        The RGB argument is accepted so that this model can be called the
        same way as ``rgbNet`` and ``rgbdNet``.
        """
        return self.mix_net(self.net(dep))

# *****************************IMPORTANT******************************
# YOU NEED TO FILL IN THE CLASS BELOW TO BUILD AN RGB-D FUSION NETWORK.
# NOTE THAT YOU ONLY NEED TO DEFINE THE NETWORK; WE HAVE ALREADY BUILT THE OTHER PARTS (DATA LOADER,
# TRAINING CODE, ...).
# AFTER FINISHING THE NETWORK, JUST EXECUTE run(rgbdNet) TO START TRAINING; YOU CAN THEN OBSERVE THE TRAINING PROCESS AND THE
# ACCURACY ON THE VALIDATION SET.

# YOU CAN ALSO RUN run(rgbNet) AND run(depthNet) TO TRAIN ON ONLY THE RGB OR ONLY THE DEPTH MODALITY, AND OBSERVE WHETHER THE FUSION
# GIVES AN ACCURACY BOOST.

# IF YOU HAVE ANY TROUBLE, YOU CAN REFER TO THE PREVIOUS rgbNet AND depthNet.
class rgbdNet(nn.Module):
    """RGB-D fusion classifier.

    Two convolutional branches with identical layouts, ``rgb_net`` (3 input
    channels) and ``dep_net`` (1 input channel), each end in ``Flatten``.
    Their outputs are concatenated along dimension 1 and passed to
    ``mix_net``, a two-layer fully connected head that outputs 15 scores per
    sample.

    The head expects ``2*6*6*64`` features, i.e. each branch must flatten to
    ``6*6*64`` values (e.g. 96x96 inputs -- TODO confirm against the data
    loader).
    """

    def __init__(self):
        super(rgbdNet, self).__init__()
        # Modules are created in the same order as before (RGB branch, depth
        # branch, head) and keep the same attribute names and Sequential
        # indices, so state_dict keys are unchanged.
        self.rgb_net = self._make_branch(in_channels=3)
        self.dep_net = self._make_branch(in_channels=1)

        self.mix_net = nn.Sequential(
            nn.Linear(2*6*6*64, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 15),
        )

    @staticmethod
    def _make_branch(in_channels):
        """Build one conv branch: 4 x (3x3 conv -> ReLU -> 2x2 max-pool), then Flatten.

        Args:
            in_channels: number of channels of the branch's input image.

        Returns:
            A new ``nn.Sequential``; every call creates fresh layers, so the
            branches never share weights.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            Flatten(),
        )

    def forward(self, img, dep):
        """Return class scores computed from an RGB input and a depth input.

        Args:
            img: input to the RGB branch (3 channels).
            dep: input to the depth branch (1 channel).

        Returns:
            The output of ``mix_net`` applied to the concatenated branch
            features (15 values per sample).
        """
        rgb_feat = self.rgb_net(img)
        dep_feat = self.dep_net(dep)
        # Late fusion: join the two flattened feature vectors per sample.
        feat = torch.cat([rgb_feat, dep_feat], dim=1)
        return self.mix_net(feat)
# ********************************************************************

In [3]:
# Train the RGB-D fusion model with the provided `run` helper (imported from main).
# The log recorded below shows one loss / learning-rate line and one accuracy line per epoch.
run(rgbdNet)

Epoch:0	 loss:2.6892, lr:0.01
Accuracy:0.1178

Epoch:1	 loss:2.3772, lr:0.01
Accuracy:0.4185

Epoch:2	 loss:1.1307, lr:0.01
Accuracy:0.7170

Epoch:3	 loss:0.3821, lr:0.01
Accuracy:0.7699

Epoch:4	 loss:0.2149, lr:0.01
Accuracy:0.8151

Epoch:5	 loss:0.1062, lr:0.01
Accuracy:0.8509

Epoch:6	 loss:0.0928, lr:0.01
Accuracy:0.8025

Epoch:7	 loss:0.0667, lr:0.01
Accuracy:0.8485

Epoch:8	 loss:0.0462, lr:0.01
Accuracy:0.8818

Epoch:9	 loss:0.0369, lr:0.01
Accuracy:0.8404

Epoch:10	 loss:0.0303, lr:0.001
Accuracy:0.8576

Epoch:11	 loss:0.0092, lr:0.001
Accuracy:0.8530

Epoch:12	 loss:0.0068, lr:0.001
Accuracy:0.8544

Epoch:13	 loss:0.0060, lr:0.001
Accuracy:0.8585

Epoch:14	 loss:0.0055, lr:0.001
Accuracy:0.8583

Epoch:15	 loss:0.0049, lr:0.001
Accuracy:0.8600

Epoch:16	 loss:0.0046, lr:0.001
Accuracy:0.8569

Epoch:17	 loss:0.0043, lr:0.001
Accuracy:0.8583

Epoch:18	 loss:0.0040, lr:0.001
Accuracy:0.8559

Epoch:19	 loss:0.0039, lr:0.001
Accuracy:0.8598



In [7]:
# Train the RGB-only baseline for comparison with the fusion model.
# NOTE(review): the execution counts in this notebook are out of order (3, 7, 5, 6);
# re-run with "Restart Kernel -> Run All" before sharing so the logs match the cell order.
run(rgbNet)

Epoch:0	 loss:2.6965, lr:0.01
Accuracy:0.1178

Epoch:1	 loss:2.6505, lr:0.01
Accuracy:0.1178

Epoch:2	 loss:2.1525, lr:0.01
Accuracy:0.4215

Epoch:3	 loss:0.9394, lr:0.01
Accuracy:0.6416

Epoch:4	 loss:0.3614, lr:0.01
Accuracy:0.7283

Epoch:5	 loss:0.2375, lr:0.01
Accuracy:0.7477

Epoch:6	 loss:0.1373, lr:0.01
Accuracy:0.8088

Epoch:7	 loss:0.0791, lr:0.01
Accuracy:0.7470

Epoch:8	 loss:0.0452, lr:0.01
Accuracy:0.8313

Epoch:9	 loss:0.0365, lr:0.01
Accuracy:0.8199

Epoch:10	 loss:0.0105, lr:0.001
Accuracy:0.8278

Epoch:11	 loss:0.0051, lr:0.001
Accuracy:0.8258

Epoch:12	 loss:0.0043, lr:0.001
Accuracy:0.8206

Epoch:13	 loss:0.0035, lr:0.001
Accuracy:0.8250

Epoch:14	 loss:0.0029, lr:0.001
Accuracy:0.8243

Epoch:15	 loss:0.0027, lr:0.001
Accuracy:0.8215

Epoch:16	 loss:0.0024, lr:0.001
Accuracy:0.8223

Epoch:17	 loss:0.0022, lr:0.001
Accuracy:0.8236

Epoch:18	 loss:0.0021, lr:0.001
Accuracy:0.8208

Epoch:19	 loss:0.0019, lr:0.001
Accuracy:0.8208



In [5]:
# Sanity check: does PyTorch report an available CUDA device in this kernel?
torch.cuda.is_available()

True

In [6]:
# Train the depth-only baseline for comparison with the fusion model.
run(depthNet)

Epoch:0	 loss:2.6797, lr:0.01
Accuracy:0.1178

Epoch:1	 loss:2.4486, lr:0.01
Accuracy:0.3418

Epoch:2	 loss:1.4567, lr:0.01
Accuracy:0.5685

Epoch:3	 loss:0.5814, lr:0.01
Accuracy:0.7322

Epoch:4	 loss:0.3293, lr:0.01
Accuracy:0.8082

Epoch:5	 loss:0.2122, lr:0.01
Accuracy:0.8106

Epoch:6	 loss:0.1611, lr:0.01
Accuracy:0.8184

Epoch:7	 loss:0.1252, lr:0.01
Accuracy:0.7866

Epoch:8	 loss:0.1043, lr:0.01
Accuracy:0.8495

Epoch:9	 loss:0.0636, lr:0.01
Accuracy:0.8496

Epoch:10	 loss:0.0295, lr:0.001
Accuracy:0.8685

Epoch:11	 loss:0.0239, lr:0.001
Accuracy:0.8620

Epoch:12	 loss:0.0222, lr:0.001
Accuracy:0.8617

Epoch:13	 loss:0.0214, lr:0.001
Accuracy:0.8624

Epoch:14	 loss:0.0208, lr:0.001
Accuracy:0.8617

Epoch:15	 loss:0.0201, lr:0.001
Accuracy:0.8618

Epoch:16	 loss:0.0185, lr:0.001
Accuracy:0.8609

Epoch:17	 loss:0.0191, lr:0.001
Accuracy:0.8602

Epoch:18	 loss:0.0189, lr:0.001
Accuracy:0.8606

Epoch:19	 loss:0.0181, lr:0.001
Accuracy:0.8600

