In [None]:
import torch
import torch.nn as nn
import sys
sys.path.append('..')
from data_utils import AsocaDataset
import numpy as np

In [None]:
class InvertedResidual(nn.Module):
    """3-D MobileNetV2-style inverted-residual block.

    Branch: 1x1x1 pointwise expansion -> BatchNorm -> ReLU ->
    depthwise ``kernel_size`` conv (optionally strided) -> BatchNorm -> ReLU ->
    1x1x1 pointwise projection -> BatchNorm (no activation).

    When ``stride == 1`` and ``in_dim == out_dim`` the block input is added to
    the branch output (the identity shortcut the block is named after).

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: depthwise kernel size; odd values keep the spatial size
            unchanged at stride 1.
        stride: stride of the depthwise conv.
        expand_ratio: hidden channels = ``in_dim * expand_ratio``.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, expand_ratio=1):
        super().__init__()

        self.kernel_size = kernel_size
        self.stride = stride
        # The identity shortcut is only shape-compatible when neither the
        # spatial size nor the channel count changes.
        self.use_residual = stride == 1 and in_dim == out_dim

        hidden_dim = in_dim * expand_ratio

        self.conv = nn.Sequential(
            nn.Conv3d(in_dim, hidden_dim, kernel_size=1),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),

            # kernel_size // 2 is "same" padding for odd kernels; a fixed
            # padding=1 was only correct for kernel_size=3.
            nn.Conv3d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=hidden_dim),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),

            nn.Conv3d(hidden_dim, out_dim, kernel_size=1, bias=False),
            nn.BatchNorm3d(out_dim),
        )

    def trim_input(self, x):
        """Center-crop the last three dims, removing ``size // 4`` from each side.

        A dim of size ``s`` keeps ``s - 2 * (s // 4)`` elements; dims smaller
        than 4 are left unchanged.

        Args:
            x: tensor whose last three dims are cropped.

        Returns:
            The cropped view of ``x``.
        """
        *_, depth, height, width = x.shape
        cr_d, cr_h, cr_w = depth // 4, height // 4, width // 4
        # Explicit end indices: ``cr:-cr`` would give an empty slice when cr == 0.
        return x[..., cr_d:depth - cr_d, cr_h:height - cr_h, cr_w:width - cr_w]

    def forward(self, x):
        """Apply the conv branch, adding the input back when shapes allow."""
        out = self.conv(x)
        if self.use_residual:
            out = out + x
        return out

In [None]:
class MobileNetV2(nn.Module):
    """Truncated 3-D MobileNetV2-style stack of ``InvertedResidual`` blocks.

    Each row of ``layer_params`` is ``(t, c, n, s)``: expansion ratio, output
    channels, number of blocks in the stage, and the stride of the stage's
    first block (the remaining ``n - 1`` blocks use stride 1).

    Args:
        in_dim: number of input channels.
        out_dim: currently unused -- the output always has the channel count
            of the last stage (32). NOTE(review): add a projection head if
            ``out_dim`` is meant to control the output channels.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        layer_params = [
            # t, c, n, s
            [1, 16, 1, 2],
            [6, 24, 1, 2],
            [6, 32, 1, 2],
        ]

        # Plain list kept for inspection; the modules are registered (and
        # their parameters tracked) through ``self.model`` below.
        self.layers = []
        in_channels = in_dim
        for expand_ratio, out_channels, n_blocks, stride in layer_params:
            for block_idx in range(n_blocks):
                self.layers.append(InvertedResidual(
                    in_channels,
                    out_channels,
                    expand_ratio=expand_ratio,
                    # Only the first block of a stage downsamples.
                    stride=stride if block_idx == 0 else 1))
                in_channels = out_channels

        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through every block in order and return the final feature map."""
        return self.model(x)

In [None]:
# HDF5 dataset file, relative to this notebook's working directory.
DATASET_PATH = '../dataset/asoca-64.hdf5'
ds = AsocaDataset(DATASET_PATH)

In [None]:
# Index the first five samples once instead of evaluating ds[:5] twice
# (NOTE(review): each indexing call may re-read the HDF5 file).
# unsqueeze(1) inserts a channel axis at dim 1 for Conv3d's (N, C, D, H, W)
# input layout -- presumably the stored volumes are (N, D, H, W); confirm.
batch = ds[:5]
x, t = batch[0].unsqueeze(1), batch[1]

In [None]:
# Single-channel input. NOTE(review): MobileNetV2 does not currently use
# out_dim, so the output has the last stage's channel count regardless.
model = MobileNetV2(in_dim=1, out_dim=1)

In [None]:
# Shape check for the sample batch.
# NOTE(review): no .eval() is called, so the model runs in training mode and
# this forward pass also updates the BatchNorm running statistics.
out = model(x)
out.shape