In [95]:
import torch
import torch.nn as nn
from torchvision import models

In [96]:
class SA(nn.Module):
    """Attention over the spatial positions of a 2-D feature map.

    Three 1x1 convolutions project the input to a query, a key and a value.
    Every pair of spatial positions gets an affinity (dot product of their
    query/key vectors) that is squashed with a sigmoid, and the value map is
    then mixed across positions with those affinities.

    Args:
        i_ch: Number of channels of the incoming feature map. The output has
            the same number of channels.
        o_ch: Number of channels of the query/key projections. Defaults to
            1000, the value that used to be hard-coded.
    """

    def __init__(self, i_ch, o_ch=1000):
        super().__init__()
        self.i_ch = i_ch
        self.o_ch = o_ch
        # Creation order (q, k, v) is kept so seeded initialisation is unchanged.
        self.q = nn.Conv2d(i_ch, self.o_ch, 1)
        self.k = nn.Conv2d(i_ch, self.o_ch, 1)
        self.v = nn.Conv2d(i_ch, i_ch, 1)

    def forward(self, input):
        """Apply the attention to a feature map.

        Args:
            input: Tensor of shape (batch, i_ch, h, w).

        Returns:
            Tensor of shape (batch, i_ch, h, w).
        """
        b_size, h = input.shape[0], input.shape[2]

        # The convolutions never modify their input, so no defensive copies are
        # needed. `reshape` (instead of `view`) also copes with conv outputs
        # that are not laid out contiguously, e.g. channels_last tensors.
        query = self.q(input).reshape(b_size, self.o_ch, -1).permute(0, 2, 1)  # b, h*w, o_ch
        key = self.k(input).reshape(b_size, self.o_ch, -1)  # b, o_ch, h*w
        value = self.v(input).reshape(b_size, self.i_ch, -1)  # b, i_ch, h*w

        # Position-to-position gate; note this matrix is (h*w) x (h*w) per sample.
        alpha = torch.sigmoid(torch.bmm(query, key))  # b, h*w, h*w

        o = torch.bmm(value, alpha)  # b, i_ch, h*w
        return o.reshape(b_size, self.i_ch, h, -1)

class CA(nn.Module):
    """Parameter-free attention over the channels of a 2-D feature map.

    The channel-by-channel Gram matrix of the flattened feature map is passed
    through a sigmoid and used to mix the channels of that same feature map.

    Args:
        i_ch: Number of channels of the incoming feature map.
    """

    def __init__(self, i_ch):
        super().__init__()
        self.i_ch = i_ch

    def forward(self, input):
        """Apply the channel mixing.

        Args:
            input: Tensor of shape (batch, i_ch, h, w).

        Returns:
            Tensor of shape (batch, i_ch, h, w).
        """
        b_size, h = input.shape[0], input.shape[2]

        # Query, key and value are all the same flattened map, so one reshape
        # replaces three full copies. `reshape` also accepts layouts (such as
        # channels_last) on which `view` would raise.
        flat = input.reshape(b_size, self.i_ch, -1)  # b, c, h*w
        flat_t = flat.permute(0, 2, 1)  # b, h*w, c

        gamma = torch.sigmoid(torch.bmm(flat, flat_t))  # b, c, c
        r = torch.bmm(flat_t, gamma)  # b, h*w, c
        r = r.permute(0, 2, 1)  # b, c, h*w
        return r.reshape(b_size, self.i_ch, h, -1)

In [135]:
class ResNet50(nn.Module):
    """ResNet-50 feature extractor with spatial and channel attention branches.

    The torchvision ResNet-50 is truncated before its last two children
    (global pooling and classifier), so it yields a 2048-channel feature map.
    That map is refined by `SA` and `CA`; each refinement is added back to the
    map through its own learnable scalar (`s` and `c`, both starting at 0, so
    the attention branches contribute nothing at initialisation). The two
    refined maps are concatenated (4096 channels), globally average-pooled and
    classified by a bias-free linear layer.

    Args:
        pretrained: Forwarded to `torchvision.models.resnet50`.
        num_classes: Number of output logits.
        tap: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, pretrained=True, num_classes=3, tap=False):
        super().__init__()
        self.num_classes = num_classes  # torose, vascular, ulcer
        model = models.resnet50(pretrained=pretrained)
        layers = list(model.children())[:-2]  # keep everything up to the feature map
        self.layers = nn.Sequential(*layers)
        self.sa = SA(2048)
        self.ca = CA(2048)
        # Register the scalars as real parameters instead of writing raw
        # tensors into the private `_parameters` dict: optimisers, state_dict
        # and device/dtype conversion then all handle them the supported way,
        # and the attribute can no longer go stale next to the registry entry.
        self.s = nn.Parameter(torch.zeros(1))
        self.c = nn.Parameter(torch.zeros(1))
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4096, self.num_classes, bias=False)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        x = self.extractor(x)
        x = self.gap(x)
        x = x.view(-1, 4096)
        x = self.fc(x)
        return x

    def fc_w(self):
        """Return the weight matrix of the final linear layer."""
        return self.fc.weight

    def extractor(self, x):
        """Return the concatenated attention-refined feature map (4096 channels)."""
        feature = self.layers(x)
        o = self.sa(feature)
        r = self.ca(feature)
        s_feature = feature + self.s * o
        # The channel branch must use the CA output `r`; it previously reused
        # the SA output `o`, which left `r` computed but never applied.
        c_feature = feature + self.c * r
        cat_feature = torch.cat([s_feature, c_feature], dim=1)
        return cat_feature

In [136]:
# Build the network with the constructor defaults (pretrained=True, num_classes=3).
model = ResNet50()

In [137]:
# Random batch of 5 three-channel 512x512 inputs, used only as a smoke-test input.
x = torch.randn((5,3,512,512))

In [138]:
# Forward pass on the random batch; the cell displays the returned logits.
model(x)

tensor([[-0.2098, -0.4528,  0.0946],
        [-0.0941, -0.5150, -0.0354],
        [-0.1479, -0.4829,  0.0367],
        [-0.1245, -0.4903, -0.0226],
        [-0.1671, -0.6127, -0.0078]], grad_fn=<MmBackward>)

In [139]:
# Inspection only: list the attribute names exposed by the model instance.
dir(model)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_name',
 '_load_from_state_dict',
 '_load_state_dict_pre_hooks',
 '_modules',
 '_named_members',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_version',
 'add_module',
 'apply',
 'bfloat16',
 'buffers',
 'c',
 'c',
 'ca',
 'children',
 'cpu',
 'cuda',
 'double',
 'dump_patches',
 'eval',
 'extra_repr',
 'extractor',
 'fc',
 'fc_w',
 'float',
 'forward

In [140]:
# Print the name of every registered parameter to confirm that the scalars
# `s` and `c` are registered alongside the backbone weights.
# Note: `n` and `p` stay bound in the notebook namespace after the loop.
for n,p in model.named_parameters():
    print(n)

s
c
layers.0.weight
layers.1.weight
layers.1.bias
layers.4.0.conv1.weight
layers.4.0.bn1.weight
layers.4.0.bn1.bias
layers.4.0.conv2.weight
layers.4.0.bn2.weight
layers.4.0.bn2.bias
layers.4.0.conv3.weight
layers.4.0.bn3.weight
layers.4.0.bn3.bias
layers.4.0.downsample.0.weight
layers.4.0.downsample.1.weight
layers.4.0.downsample.1.bias
layers.4.1.conv1.weight
layers.4.1.bn1.weight
layers.4.1.bn1.bias
layers.4.1.conv2.weight
layers.4.1.bn2.weight
layers.4.1.bn2.bias
layers.4.1.conv3.weight
layers.4.1.bn3.weight
layers.4.1.bn3.bias
layers.4.2.conv1.weight
layers.4.2.bn1.weight
layers.4.2.bn1.bias
layers.4.2.conv2.weight
layers.4.2.bn2.weight
layers.4.2.bn2.bias
layers.4.2.conv3.weight
layers.4.2.bn3.weight
layers.4.2.bn3.bias
layers.5.0.conv1.weight
layers.5.0.bn1.weight
layers.5.0.bn1.bias
layers.5.0.conv2.weight
layers.5.0.bn2.weight
layers.5.0.bn2.bias
layers.5.0.conv3.weight
layers.5.0.bn3.weight
layers.5.0.bn3.bias
layers.5.0.downsample.0.weight
layers.5.0.downsample.1.weight
layer

In [141]:
# NOTE(review): this rebinds `x`, which previously held the random image batch,
# to a two-element integer tensor on the GPU; it requires a CUDA device.
# Looks like a quick CUDA availability check - confirm, and consider a new name.
x = torch.tensor([1,2]).cuda()

In [142]:
# Move the model to the GPU; `cuda()` returns the module, so the cell displays its repr.
model.cuda()

ResNet50(
  (layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 

In [143]:
# Check whether the scalar `s` is on the GPU after the `model.cuda()` call.
model.s.is_cuda

True

In [133]:
# Inspection only: show the registered submodules via the private `_modules` dict.
# NOTE(review): this cell's execution count (133) is lower than that of the cells
# above, so the displayed output may come from an earlier state of `model`.
model._modules

OrderedDict([('layers',
              Sequential(
                (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
                (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
                (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
                (4): Sequential(
                  (0): Bottleneck(
                    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                    (bn3): BatchNorm2d(256,