In [1]:
import torch
import torch.nn as nn
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
class C3D(nn.Module):

    # input_shape = (1,3,5,512,512)
    def __init__(self, num_classes, pretrained=False):
        super(C3D, self).__init__()

        self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1,0,0))
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1,0,0))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))

        self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))
        
        self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        # self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 3, 3), padding=(1, 1, 1))
        self.adaptpool = nn.AdaptiveAvgPool3d(output_size=(3,3,3)) # 512,3,3,3
        
        self.fc6 = nn.Linear(512*3*3*3, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, num_classes)

        self.dropout = nn.Dropout(p=0.5)

        self.relu = nn.ReLU()

        self.__init_weight()

        if pretrained:
            self.__load_pretrained_weights()

    def forward(self, x):

        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        # print(x.shape)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        # print(x.shape)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool3(x)
        # print(x.shape)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        x = self.pool4(x)
        # print(x.shape)
        x = self.relu(self.conv5a(x))
        x = self.relu(self.conv5b(x))
        x = self.adaptpool(x)
        # x = self.pool5(x)
        print('after adaptpool x shape : ', x.shape)
        
        x = x.view(-1, 512*x.shape[2]*x.shape[3]*x.shape[4])
        print('after flatten x shape', x.shape)
        x = self.relu(self.fc6(x))
        x = self.dropout(x)
        x = self.relu(self.fc7(x))
        x = self.dropout(x)

        logits = self.fc8(x)
        return logits
    
    def __load_pretrained_weights(self):
        print('Loading Pretrained weights')
        """Initialiaze network."""
        corresp_name = {
                        # Conv1
                        "features.0.weight": "conv1.weight",
                        "features.0.bias": "conv1.bias",
                        # Conv2
                        "features.3.weight": "conv2.weight",
                        "features.3.bias": "conv2.bias",
                        # Conv3a
                        "features.6.weight": "conv3a.weight",
                        "features.6.bias": "conv3a.bias",
                        # Conv3b
                        "features.8.weight": "conv3b.weight",
                        "features.8.bias": "conv3b.bias",
                        # Conv4a
                        "features.11.weight": "conv4a.weight",
                        "features.11.bias": "conv4a.bias",
                        # Conv4b
                        "features.13.weight": "conv4b.weight",
                        "features.13.bias": "conv4b.bias",
                        # Conv5a
                        "features.16.weight": "conv5a.weight",
                        "features.16.bias": "conv5a.bias",
                         # Conv5b
                        "features.18.weight": "conv5b.weight",
                        "features.18.bias": "conv5b.bias",
                        # # # fc6
                        # # "classifier.0.weight": "fc6.weight",
                        # # "classifier.0.bias": "fc6.bias",
                        # # fc7 
                        # "classifier.3.weight": "fc7.weight",
                        # "classifier.3.bias": "fc7.bias",
                        }

        p_dict = torch.load('./c3d-pretrained.pth')
        s_dict = self.state_dict()
        for name in p_dict:
            if name not in corresp_name:
                continue
            s_dict[corresp_name[name]] = p_dict[name]
        self.load_state_dict(s_dict)

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [12]:
inputs = torch.rand(1, 3, 5, 512, 512)
net = C3D(num_classes=6, pretrained=True)

outputs = net.forward(inputs)
# print(outputs.size())

Loading Pretrained weights
after adaptpool x shape :  torch.Size([1, 512, 3, 3, 3])
after flatten x shape torch.Size([1, 13824])


In [13]:
class C3D_3(nn.Module):

    # input_shape = (1,3,5,512,512)
    def __init__(self, num_classes, pretrained=False):
        super(C3D_3, self).__init__()

        self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1,0,0))
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1,0,0))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(1, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 0, 0))
        self.adaptpool = nn.AdaptiveAvgPool3d(output_size=(3,3,3)) # 256,3,3,3
        
        self.fc6 = nn.Linear(256*3*3*3, 4096)
        self.fc7 = nn.Linear(4096, 1024)
        self.fc8 = nn.Linear(1024, num_classes)

        self.dropout = nn.Dropout(p=0.5)

        self.relu = nn.ReLU()

        self.__init_weight()

        if pretrained:
            self.__load_pretrained_weights()

    def forward(self, x):

        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        # print(x.shape)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        # print(x.shape)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.adaptpool(x)
        print('after adaptpool x shape : ', x.shape)
        
        x = x.view(-1, 512*x.shape[2]*x.shape[3]*x.shape[4])
        print('after flatten x shape', x.shape)
        x = self.relu(self.fc6(x))
        x = self.dropout(x)
        x = self.relu(self.fc7(x))
        x = self.dropout(x)

        logits = self.fc8(x)
        return logits
    
    def __load_pretrained_weights(self):
        print('Loading Pretrained weights')
        """Initialiaze network."""
        corresp_name = {
                        # Conv1
                        "features.0.weight": "conv1.weight",
                        "features.0.bias": "conv1.bias",
                        # Conv2
                        "features.3.weight": "conv2.weight",
                        "features.3.bias": "conv2.bias",
                        # Conv3a
                        "features.6.weight": "conv3a.weight",
                        "features.6.bias": "conv3a.bias",
                        # Conv3b
                        "features.8.weight": "conv3b.weight",
                        "features.8.bias": "conv3b.bias",
                        # Conv4a
                        "features.11.weight": "conv4a.weight",
                        "features.11.bias": "conv4a.bias",
                        # Conv4b
                        "features.13.weight": "conv4b.weight",
                        "features.13.bias": "conv4b.bias",
                        # Conv5a
                        "features.16.weight": "conv5a.weight",
                        "features.16.bias": "conv5a.bias",
                         # Conv5b
                        "features.18.weight": "conv5b.weight",
                        "features.18.bias": "conv5b.bias",
                        # # # fc6
                        # # "classifier.0.weight": "fc6.weight",
                        # # "classifier.0.bias": "fc6.bias",
                        # # fc7 
                        # "classifier.3.weight": "fc7.weight",
                        # "classifier.3.bias": "fc7.bias",
                        }

        p_dict = torch.load('./c3d-pretrained.pth')
        s_dict = self.state_dict()
        for name in p_dict:
            if name not in corresp_name:
                continue
            s_dict[corresp_name[name]] = p_dict[name]
        self.load_state_dict(s_dict)

    def __init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()