In [148]:
import torch
import torch.nn as nn
import torchvision
import torchinfo

In [149]:
# Seed torch's global RNG before building anything random. vgg16() is called
# without `weights=`, so its parameters are randomly initialised. Seeding here
# makes those weights, the adapter weights and the random input later in the
# notebook reproducible after "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_org = torchvision.models.vgg16().to(device)

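### torchvisionからモデル（VGG16）をロードする場合の注意点

```python
model = torchvision.models.vgg16()
```

のように `weights` を指定せずに定義すると，学習済みの重みは読み込まれず，パラメータはランダムに初期化される．そのため呼び出すたびに異なるパラメータのモデルになる．比較の際は2度呼び出さずに，1度呼び出して作成したモデルを使って比較する．（学習済みモデルを使う場合は `torchvision.models.vgg16(weights=...)` のように `weights` を指定する．）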



In [150]:

# Layer-by-layer shape summary of the original network for one 3x256x256
# input: each row is labelled with its attribute name and shows input/output
# shapes, expanded three module levels deep.
batch_size = 1
torchinfo.summary(
    model_org,
    input_size=(batch_size, 3, 256, 256),
    col_names=["input_size", "output_size"],
    row_settings=["var_names"],
    depth=3,
)

Layer (type (var_name))                  Input Shape               Output Shape
VGG                                      --                        --
├─Sequential (features)                  [1, 3, 256, 256]          [1, 512, 8, 8]
│    └─Conv2d (0)                        [1, 3, 256, 256]          [1, 64, 256, 256]
│    └─ReLU (1)                          [1, 64, 256, 256]         [1, 64, 256, 256]
│    └─Conv2d (2)                        [1, 64, 256, 256]         [1, 64, 256, 256]
│    └─ReLU (3)                          [1, 64, 256, 256]         [1, 64, 256, 256]
│    └─MaxPool2d (4)                     [1, 64, 256, 256]         [1, 64, 128, 128]
│    └─Conv2d (5)                        [1, 64, 128, 128]         [1, 128, 128, 128]
│    └─ReLU (6)                          [1, 128, 128, 128]        [1, 128, 128, 128]
│    └─Conv2d (7)                        [1, 128, 128, 128]        [1, 128, 128, 128]
│    └─ReLU (8)                          [1, 128, 128, 128]        [1, 128, 128, 128]

In [163]:
class Adapter(nn.Module):
    """Channel-preserving residual adapter built from a 1x1 convolution.

    The forward pass computes ``bn2(conv1(bn1(x)) + x)``. The input is batch-
    normalised and mixed across channels by a pointwise convolution, the
    unmodified input is added back, and the sum is batch-normalised again.

    Args:
        dim: Number of channels of the input; the output has the same count.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute names determine the state_dict keys, so they stay fixed.
        self.bn1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        branch = self.conv1(self.bn1(x))
        # Skip connection: the raw input is added to the adapted branch.
        return self.bn2(branch + x)

In [164]:
class ReconstructNet(nn.Module):
    """Backbone network with an ``Adapter`` inserted inside ``features``.

    ``model.features`` is split at ``split_index``. The adapter is applied to
    the activation between the two halves, and the result then goes through
    the remaining features, ``avgpool``, flattening and ``classifier``.

    Args:
        model: Backbone exposing ``features`` (a sliceable ``nn.Sequential``),
            ``avgpool`` and ``classifier``. When ``None``, the notebook-global
            ``model_org`` is used, which keeps ``ReconstructNet()`` working as
            before.
        split_index: Position in ``model.features`` where the adapter is
            inserted. For the VGG16 used here, ``features[:17]`` outputs 256
            channels (see the summary output).
        adapter_dim: Channel count of the activation at ``split_index``. It
            must equal the output channels of ``model.features[:split_index]``.

    Note:
        The submodules are *shared* with ``model``, not copied. Calling
        ``.to()``, ``.eval()`` or training this network also affects ``model``.
    """

    def __init__(self, model=None, split_index=17, adapter_dim=256):
        super().__init__()
        if model is None:
            # Backward-compatible fallback to the notebook-global backbone.
            model = model_org

        # The extra nn.Sequential wrapping is kept so state_dict keys
        # (e.g. "net_bottom_0.0.0.weight") stay unchanged.
        self.net_bottom_0 = nn.Sequential(
            model.features[:split_index]
        )

        self.adapter = Adapter(adapter_dim)

        self.net_bottom_1 = nn.Sequential(
            model.features[split_index:],
            model.avgpool,
        )

        self.net_top = nn.Sequential(
            model.classifier
        )

    def forward(self, x):
        x = self.net_bottom_0(x)
        x = self.adapter(x)
        x = self.net_bottom_1(x)
        # Collapse (N, C, H, W) to (N, C*H*W) for the fully connected head.
        x = torch.flatten(x, 1)
        return self.net_top(x)

In [165]:
# Build the adapter-augmented network and place it on the selected device.
# nn.Module.to() moves parameters in place and returns the module itself.
model_new = ReconstructNet().to(device)

In [166]:
# Same shape summary for the reconstructed network, to check where the adapter
# sits (between net_bottom_0 and net_bottom_1) and that shapes still line up.
torchinfo.summary(
    model_new,
    input_size=(batch_size, 3, 256, 256),
    col_names=["input_size", "output_size"],
    row_settings=["var_names"],
    depth=3,
)

Layer (type (var_name))                  Input Shape               Output Shape
ReconstructNet                           --                        --
├─Sequential (net_bottom_0)              [1, 3, 256, 256]          [1, 256, 32, 32]
│    └─Sequential (0)                    [1, 3, 256, 256]          [1, 256, 32, 32]
│    │    └─Conv2d (0)                   [1, 3, 256, 256]          [1, 64, 256, 256]
│    │    └─ReLU (1)                     [1, 64, 256, 256]         [1, 64, 256, 256]
│    │    └─Conv2d (2)                   [1, 64, 256, 256]         [1, 64, 256, 256]
│    │    └─ReLU (3)                     [1, 64, 256, 256]         [1, 64, 256, 256]
│    │    └─MaxPool2d (4)                [1, 64, 256, 256]         [1, 64, 128, 128]
│    │    └─Conv2d (5)                   [1, 64, 128, 128]         [1, 128, 128, 128]
│    │    └─ReLU (6)                     [1, 128, 128, 128]        [1, 128, 128, 128]
│    │    └─Conv2d (7)                   [1, 128, 128, 128]        [1, 128, 128, 128]

In [167]:
# One random input fed to both models, so their outputs can be compared directly.
# It is allocated directly on `device`, which avoids creating it on the CPU and
# then copying it over.
data = torch.randn(1, 3, 256, 256, device=device)
print(data.shape)
print(type(data))

torch.Size([1, 3, 256, 256])
<class 'torch.Tensor'>


In [168]:
# Compare the top-1 result (largest output value and its class index) of both
# models on the same input. eval() switches dropout / batch-norm layers (e.g.
# the adapter's BatchNorm2d) to inference behaviour. no_grad() skips building
# the autograd graph, which is not needed for a forward-only comparison.
model_org.eval()
model_new.eval()
with torch.no_grad():
    output_org = model_org(data).max(dim=1)
    output_new = model_new(data).max(dim=1)
print(output_org)
print(output_new)


torch.return_types.max(
values=tensor([0.2009], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([491], device='cuda:0'))
torch.return_types.max(
values=tensor([0.2667], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([458], device='cuda:0'))


In [147]:
# Variant: compare the raw output tensors directly, instead of their max values.
# NOTE(review): this only works if output_org / output_new hold the models'
# output tensors themselves. After `.max(...)` they are (values, indices)
# result tuples, and torch.allclose expects tensors.
# NOTE(review): this cell's execution count (147) is earlier than the cells
# above, so the recorded "True" output may predate the adapter and be stale.
# flag = torch.allclose(output_org,output_new, atol=1e-8)
# print(flag)

True
