In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

Device : cuda


In [8]:
# Feature Encoder : Instance Normalization
# Context Encoder : Batch Normalization


class FCEncoder(nn.Module):
    def __init__(self,encoder_type):
        super().__init__()
        self.normalization = nn.InstanceNorm2d if encoder_type == "Feature" else nn.BatchNorm2d
        # resize ?

        # Convolution layer 7X7 (64)
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3)

        # Res Net Unit (64)
        # Res Net Unit (64)
        self.Res64 = nn.Sequential(
            self.ResUnit_64(),
            self.ResUnit_64()
        )

        # Res Net Unit (128)
        # Res Net Unit (128)
        self.Res128 = nn.Sequential(
            self.ResUnit_128(),
            self.ResUnit_128()
        )

        # Res Net Unit (192)
        # Res Net Unit (192)
        self.Res192 = nn.Sequential(
            self.ResUnit_192(),
            self.ResUnit_192()
        )

        # Convolution layer 3X3 (256)
        self.conv2 = nn.Conv2d(192, 256, 3, stride=2, padding=1)

        self.encoder = nn.Sequential(
            self.conv1,
            self.normalization(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
            self.Res64,
            self.Res128,
            self.Res192
        )

    def forward(self,x):
        logits = self.encoder(x)
        return logits
    
    def ResUnit_64(self):
        return nn.Sequential(
            nn.Conv2d(64, 64, 1, stride=1, padding=1),
            # self.normalization(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1, padding=1),
        #     self.normalization(64),
        #     nn.ReLU(),
        #     nn.Conv2d(64, 128, 1, stride=1, padding=1),
        #     self.normalization(128),
        #     nn.ReLU()
        )

    def ResUnit_128(self):
        return nn.Sequential(
            nn.Conv2d(128, 128, 1, stride=1, padding=1),
            self.normalization(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            self.normalization(128),
            nn.ReLU(),
            nn.Conv2d(128, 192, 1, stride=1, padding=1),
            self.normalization(192),
            nn.ReLU()
        )

    def ResUnit_192(self):
        return nn.Sequential(
            nn.Conv2d(192, 192, 1, stride=1, padding=1),
            self.normalization(192),
            nn.ReLU(),
            nn.Conv2d(192, 192, 3, stride=1, padding=1),
            self.normalization(192),
            nn.ReLU(),
            nn.Conv2d(192, 256, 1, stride=1, padding=1),
            self.normalization(256),
            nn.ReLU()
        )


In [9]:
model = FCEncoder("Feature").to(device)
print(model)

FCEncoder(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (Res64): Sequential(
    (0): Sequential(
      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU()
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (5): ReLU()
      (6): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
      (7): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (8): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU()
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

: 