# In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim

# for image
import matplotlib.pyplot as plt
import numpy as np

class CNN_classification(nn.Module):
    """Small convolutional classifier for single-channel square images.

    The network is two (conv, conv, max-pool) stages followed by a
    three-layer fully connected head.

    Args:
        parameters: Configuration mapping. The keys read here are
            ``"dataset"`` (only ``"MNIST"`` is supported) and
            ``"use_cuda"`` (when truthy, the layers are moved to the GPU).

    Raises:
        ValueError: If ``parameters["dataset"]`` is not a supported dataset.
    """

    def __init__(self, parameters: dict):
        super().__init__()
        self.dataset_name = parameters["dataset"]
        self.use_cuda = parameters["use_cuda"]

        self.kernel_size = 3
        self.pooling_size = 2

        self.input_size, self.output_size = self.return_input_output_size()

        # The last convolution emits a single channel, so the flattened
        # feature vector has side * side entries (4 ** 2 = 16 for MNIST).
        self.cnn_output_size = self._feature_map_side() ** 2

        # Spatial size for MNIST:
        # 28 -> 26(conv) -> 24(conv) -> 12(pool) -> 10(conv) -> 8(conv) -> 4(pool)
        self.cnn_layer_2d = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4,
                      kernel_size=self.kernel_size, stride=1, padding=0),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=4, out_channels=16,
                      kernel_size=self.kernel_size, stride=1, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            # Pooling was described by the shape comment and assumed by
            # cnn_output_size, but was missing from the layer stack.
            nn.MaxPool2d(kernel_size=self.pooling_size),

            nn.Conv2d(in_channels=16, out_channels=4,
                      kernel_size=self.kernel_size, stride=1, padding=0),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),

            nn.Conv2d(in_channels=4, out_channels=1,
                      kernel_size=self.kernel_size, stride=1, padding=0),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=self.pooling_size),
        )

        # NOTE(review): the trailing ReLU clamps the class scores to be
        # non-negative. It is kept to preserve the original output range;
        # if training uses a loss that expects raw logits (for example
        # nn.CrossEntropyLoss), consider removing it.
        self.fully_connected_layer = nn.Sequential(
            nn.Linear(self.cnn_output_size, 256),
            nn.Dropout(p=0.5),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.Dropout(p=0.5),
            nn.ReLU(inplace=True),
            nn.Linear(128, self.output_size),
            nn.ReLU(inplace=True),
        )

        if self.use_cuda:
            # Reference the attribute that is actually defined above
            # (the original read a non-existent ``self.cnn_layer``).
            self.cnn_layer_2d = self.cnn_layer_2d.cuda()
            self.fully_connected_layer = self.fully_connected_layer.cuda()

    def _feature_map_side(self) -> int:
        """Return the side length of the feature map after the conv stack."""
        side = self.input_size
        for _ in range(2):  # two (conv, conv, pool) stages
            # Each unpadded stride-1 convolution trims kernel_size - 1 pixels.
            side -= 2 * (self.kernel_size - 1)
            side //= self.pooling_size
        return side

    def forward(self, input_value):
        """Compute class scores for a batch of images.

        Args:
            input_value: Tensor of shape ``(batch, 1, input_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)`` with non-negative scores.
        """
        x = input_value
        x = self.cnn_layer_2d(x)
        # Flatten every sample to a vector for the fully connected head.
        x = x.view(x.size(0), -1)
        x = self.fully_connected_layer(x)
        return x

    def return_input_output_size(self):
        """Return ``(input_size, output_size)`` for the configured dataset.

        Raises:
            ValueError: If the dataset name is not recognised.
        """
        if self.dataset_name == "MNIST":
            input_size = 28
            output_size = 10
        else:
            raise ValueError("Check dataset_name")
        return input_size, output_size
    