In [3]:

from segmentation_models_pytorch.encoders._base import EncoderMixin
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch
import torchxrayvision as xrv

from models.adaptor import Adaptor

In [111]:
class _ResNetAE(nn.Module):
    """ResNet-style autoencoder with an adaptor applied to the pooled bottleneck.

    Pipeline (see ``forward``): ``encode`` -> spatial mean (``get_global_features``)
    -> ``fusion`` (the adaptor) -> ``upsample`` (1x1 -> 3x3 map) -> ``decode``.

    Args:
        adaptor: module applied to the pooled encoder features. Its output is
            reshaped to ``(N, C, 1, 1)`` and fed to ``uplayer0``, which expects
            ``C == 768``.
        downblock: residual block class for the encoder stages; must expose an
            ``expansion`` attribute.
        upblock: block class passed through to ``_UpBlock`` for the decoder stages.
        num_layers: number of blocks per encoder stage (``layer1``..``layer4``);
            reused in reverse order for the decoder stages.
        n_classes: number of output channels of the final 1x1 transposed conv.
    """

    def __init__(self, adaptor, downblock, upblock, num_layers, n_classes):
        super().__init__()

        # Running channel count, advanced by _make_downlayer / _make_up_block.
        self.in_channels = 64

        # Stem: single-channel input.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_downlayer(downblock, 64, num_layers[0])
        self.layer2 = self._make_downlayer(downblock, 128, num_layers[1], stride=2)
        self.layer3 = self._make_downlayer(downblock, 256, num_layers[2], stride=2)
        self.layer4 = self._make_downlayer(downblock, 128, num_layers[3], stride=6)

        self.adaptor = adaptor

        # Lifts the 1x1, 768-channel adaptor output to a 3x3, 512-channel map.
        self.uplayer0 = nn.ConvTranspose2d(768, 512, kernel_size=3, stride=1, bias=False)

        # Decoder stages mirror the encoder stages in reverse order.
        self.uplayer1 = self._make_up_block(upblock, 128, num_layers[3], stride=6)
        self.uplayer2 = self._make_up_block(upblock, 64, num_layers[2], stride=2)
        self.uplayer3 = self._make_up_block(upblock, 32, num_layers[1], stride=2)
        self.uplayer4 = self._make_up_block(upblock, 16, num_layers[0], stride=2)

        upsample = nn.Sequential(
            nn.ConvTranspose2d(self.in_channels, 64, kernel_size=1, stride=2, bias=False, output_padding=1),
            nn.BatchNorm2d(64),
        )
        # NOTE(review): DeconvBottleneck is defined outside this view.
        self.uplayer_top = DeconvBottleneck(self.in_channels, 64, 1, 2, upsample)

        self.conv1_1 = nn.ConvTranspose2d(64, n_classes, kernel_size=1, stride=1, bias=False)

    def _make_downlayer(self, block, init_channels, num_layer, stride=1):
        """Build one encoder stage of ``num_layer`` blocks and advance ``self.in_channels``.

        A 1x1 conv + BatchNorm projection is attached to the first block whenever the
        stride or the channel count changes, so the residual shapes match.
        """
        downsample = None
        if stride != 1 or self.in_channels != init_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, init_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(init_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, init_channels, stride, downsample))
        self.in_channels = init_channels * block.expansion
        for i in range(1, num_layer):
            layers.append(block(self.in_channels, init_channels))

        return nn.Sequential(*layers)

    def _make_up_block(self, block, init_channels, num_layer, stride=1):
        """Build one decoder stage via ``_UpBlock`` and advance ``self.in_channels``.

        NOTE(review): ``_UpBlock`` is defined outside this view; it is expected to
        expose the stage's output channel count as ``in_channels``.
        """
        up_block = _UpBlock(self.in_channels, block, init_channels, num_layer, stride)
        self.in_channels = up_block.in_channels
        return up_block

    def encode(self, x, check_resolution=True):
        """Run the stem and the four encoder stages.

        Returns:
            ``(x, [feat1, feat2, feat3, feat4])``: the final feature map and the
            output of every stage (used as skip connections by ``decode``).

        Raises:
            ValueError: if ``check_resolution`` is true, a ``weights_metadata`` dict
                has been attached to the model (this class never sets it), and the
                input's height/width differ from its ``'resolution'``.
        """
        if check_resolution and hasattr(self, 'weights_metadata'):
            resolution = self.weights_metadata['resolution']
            if x.shape[2] != resolution or x.shape[3] != resolution:
                raise ValueError("Input size ({}x{}) is not the native resolution ({}x{}) for this model. Set check_resolution=False on the encode function to override this error.".format(x.shape[2], x.shape[3], resolution, resolution))

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = feat1 = self.layer1(x)
        x = feat2 = self.layer2(x)
        x = feat3 = self.layer3(x)
        x = feat4 = self.layer4(x)
        return x, [feat1, feat2, feat3, feat4]

    def fusion(self, x):
        """Apply the adaptor to the pooled encoder features."""
        x = self.adaptor(x)
        return x

    def upsample(self, x):
        """Lift the 1x1 fused feature map to the decoder's 3x3 starting map."""
        x = self.uplayer0(x)
        return x

    def decode(self, x, skips, image_size=(1, 1, 224, 224)):
        """Decode ``x`` back to image space using the encoder skips.

        Args:
            x: output of ``upsample``.
            skips: the list returned by ``encode`` (stage outputs 1..4).
            image_size: full input size; passed as ``output_size`` to the final
                transposed conv so the output spatial size matches the input.
                A tuple default avoids a shared mutable default argument.
        """
        x = self.uplayer1(x, skips[3])
        x = self.uplayer2(x, skips[2])
        x = self.uplayer3(x, skips[1])
        x = self.uplayer4(x, skips[0])
        # NOTE(review): the whole skip list is handed to uplayer_top — confirm
        # DeconvBottleneck's forward signature accepts it.
        x = self.uplayer_top(x, skips)

        x = self.conv1_1(x, output_size=image_size)
        return x

    def get_global_features(self, x):
        """Average a ``(N, C, *spatial)`` feature map over all spatial positions -> ``(N, C)``."""
        return torch.flatten(x, start_dim=2).permute((0, 2, 1)).mean(1)

    def forward(self, x):
        """Full pass; returns a dict with intermediate features and the reconstruction.

        Keys: ``pre_adaptor`` (final encoder map), ``skips``, ``global_features``
        (spatially pooled), ``multimodal_features`` (adaptor output) and ``out``.
        """
        ret = {}
        ret["pre_adaptor"], ret['skips'] = self.encode(x)
        ret['global_features'] = z = self.get_global_features(ret["pre_adaptor"])
        ret["multimodal_features"] = z = self.fusion(z)
        ret["out"] = self.decode(self.upsample(z.unsqueeze(2).unsqueeze(3)), ret['skips'], x.size())

        return ret

In [None]:
class _ResNetAEEncoder(nn.Module):
    """ResNet-style autoencoder with an adaptor applied to the pooled bottleneck.

    NOTE(review): despite its name this is currently a copy of ``_ResNetAE``,
    decoder included; trim it down if only the encoder is wanted.

    Args:
        adaptor: module applied to the pooled encoder features. Its output is
            reshaped to ``(N, C, 1, 1)`` and fed to ``uplayer0``, which expects
            ``C == 768``.
        downblock: residual block class for the encoder stages; must expose an
            ``expansion`` attribute.
        upblock: block class passed through to ``_UpBlock`` for the decoder stages.
        num_layers: number of blocks per encoder stage (``layer1``..``layer4``);
            reused in reverse order for the decoder stages.
        n_classes: number of output channels of the final 1x1 transposed conv.
    """

    def __init__(self, adaptor, downblock, upblock, num_layers, n_classes):
        # Previously super(_ResNetAE, self): the wrong class, which raises a
        # TypeError because this instance is not a _ResNetAE.
        super().__init__()

        # Running channel count, advanced by _make_downlayer / _make_up_block.
        self.in_channels = 64

        # Stem: single-channel input.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_downlayer(downblock, 64, num_layers[0])
        self.layer2 = self._make_downlayer(downblock, 128, num_layers[1], stride=2)
        self.layer3 = self._make_downlayer(downblock, 256, num_layers[2], stride=2)
        self.layer4 = self._make_downlayer(downblock, 128, num_layers[3], stride=6)

        self.adaptor = adaptor

        # Lifts the 1x1, 768-channel adaptor output to a 3x3, 512-channel map.
        self.uplayer0 = nn.ConvTranspose2d(768, 512, kernel_size=3, stride=1, bias=False)

        # Decoder stages mirror the encoder stages in reverse order.
        self.uplayer1 = self._make_up_block(upblock, 128, num_layers[3], stride=6)
        self.uplayer2 = self._make_up_block(upblock, 64, num_layers[2], stride=2)
        self.uplayer3 = self._make_up_block(upblock, 32, num_layers[1], stride=2)
        self.uplayer4 = self._make_up_block(upblock, 16, num_layers[0], stride=2)

        upsample = nn.Sequential(
            nn.ConvTranspose2d(self.in_channels, 64, kernel_size=1, stride=2, bias=False, output_padding=1),
            nn.BatchNorm2d(64),
        )
        # NOTE(review): DeconvBottleneck is defined outside this view.
        self.uplayer_top = DeconvBottleneck(self.in_channels, 64, 1, 2, upsample)

        self.conv1_1 = nn.ConvTranspose2d(64, n_classes, kernel_size=1, stride=1, bias=False)

    def _make_downlayer(self, block, init_channels, num_layer, stride=1):
        """Build one encoder stage of ``num_layer`` blocks and advance ``self.in_channels``.

        A 1x1 conv + BatchNorm projection is attached to the first block whenever the
        stride or the channel count changes, so the residual shapes match.
        """
        downsample = None
        if stride != 1 or self.in_channels != init_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, init_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(init_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, init_channels, stride, downsample))
        self.in_channels = init_channels * block.expansion
        for i in range(1, num_layer):
            layers.append(block(self.in_channels, init_channels))

        return nn.Sequential(*layers)

    def _make_up_block(self, block, init_channels, num_layer, stride=1):
        """Build one decoder stage via ``_UpBlock`` and advance ``self.in_channels``.

        NOTE(review): ``_UpBlock`` is defined outside this view; it is expected to
        expose the stage's output channel count as ``in_channels``.
        """
        up_block = _UpBlock(self.in_channels, block, init_channels, num_layer, stride)
        self.in_channels = up_block.in_channels
        return up_block

    def encode(self, x, check_resolution=True):
        """Run the stem and the four encoder stages.

        Returns:
            ``(x, [feat1, feat2, feat3, feat4])``: the final feature map and the
            output of every stage (used as skip connections by ``decode``).

        Raises:
            ValueError: if ``check_resolution`` is true, a ``weights_metadata`` dict
                has been attached to the model (this class never sets it), and the
                input's height/width differ from its ``'resolution'``.
        """
        if check_resolution and hasattr(self, 'weights_metadata'):
            resolution = self.weights_metadata['resolution']
            if x.shape[2] != resolution or x.shape[3] != resolution:
                raise ValueError("Input size ({}x{}) is not the native resolution ({}x{}) for this model. Set check_resolution=False on the encode function to override this error.".format(x.shape[2], x.shape[3], resolution, resolution))

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = feat1 = self.layer1(x)
        x = feat2 = self.layer2(x)
        x = feat3 = self.layer3(x)
        x = feat4 = self.layer4(x)
        return x, [feat1, feat2, feat3, feat4]

    def fusion(self, x):
        """Apply the adaptor to the pooled encoder features."""
        x = self.adaptor(x)
        return x

    def upsample(self, x):
        """Lift the 1x1 fused feature map to the decoder's 3x3 starting map."""
        x = self.uplayer0(x)
        return x

    def decode(self, x, skips, image_size=(1, 1, 224, 224)):
        """Decode ``x`` back to image space using the encoder skips.

        Args:
            x: output of ``upsample``.
            skips: the list returned by ``encode`` (stage outputs 1..4).
            image_size: full input size; passed as ``output_size`` to the final
                transposed conv so the output spatial size matches the input.
                A tuple default avoids a shared mutable default argument.
        """
        x = self.uplayer1(x, skips[3])
        x = self.uplayer2(x, skips[2])
        x = self.uplayer3(x, skips[1])
        x = self.uplayer4(x, skips[0])
        # NOTE(review): the whole skip list is handed to uplayer_top — confirm
        # DeconvBottleneck's forward signature accepts it.
        x = self.uplayer_top(x, skips)

        x = self.conv1_1(x, output_size=image_size)
        return x

    def get_global_features(self, x):
        """Average a ``(N, C, *spatial)`` feature map over all spatial positions -> ``(N, C)``."""
        return torch.flatten(x, start_dim=2).permute((0, 2, 1)).mean(1)

    def forward(self, x):
        """Full pass; returns a dict with intermediate features and the reconstruction.

        Keys: ``pre_adaptor`` (final encoder map), ``skips``, ``global_features``
        (spatially pooled), ``multimodal_features`` (adaptor output) and ``out``.
        """
        ret = {}
        ret["pre_adaptor"], ret['skips'] = self.encode(x)
        ret['global_features'] = z = self.get_global_features(ret["pre_adaptor"])
        ret["multimodal_features"] = z = self.fusion(z)
        ret["out"] = self.decode(self.upsample(z.unsqueeze(2).unsqueeze(3)), ret['skips'], x.size())

        return ret

In [113]:
# NOTE(review): absolute, machine-specific checkpoint path; consider moving it into a config cell.
ckpt = "/vol/bitbucket/jq619/individual-project/trained_models/pretrain/resnet-ae_clinicalbert/adaptor pretrain/cuxhbf0j/checkpoints/epoch=49-step=88799.ckpt"
adaptor = Adaptor.load_from_checkpoint(ckpt)
model = _ResNetAE(
    adaptor,
    downblock=Bottleneck,
    upblock=DeconvBottleneck,
    num_layers=[3, 4, 23, 2],  # blocks per encoder stage (layer1..layer4)
    n_classes=1,
)

RuntimeError: The size of tensor a (512) must match the size of tensor b (1024) at non-singleton dimension 1

In [10]:
def double_conv(in_channels, out_channels):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size unchanged; channels go
    ``in_channels -> out_channels -> out_channels``.
    """
    stages = []
    for stage_in in (in_channels, out_channels):
        stages.extend([
            nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*stages)


def up_conv(in_channels, out_channels):
    """Learned 2x upsampling: a 2x2 transposed convolution with stride 2."""
    upsampler = nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=2,
        stride=2,
    )
    return upsampler

In [40]:
class ResNetAEUNet(nn.Module):
    """U-Net whose encoder is sliced out of torchxrayvision's ResNetAE ("101-elastic").

    Encoder stages, taken from the first eight children of the autoencoder (shapes
    as recorded in this notebook's output for a 1x1x224x224 input):

        block1: children[0:3] -> 64 ch,   112x112
        block2: children[3:5] -> 256 ch,  56x56
        block3: children[5]   -> 512 ch,  28x28
        block4: children[6]   -> 1024 ch, 14x14
        block5: children[7]   -> 512 ch,  3x3

    ``neck`` maps 3x3 -> 7x7. Four ``up_conv`` + ``double_conv`` stages with skip
    connections and a final ``up_conv`` bring the map back to the input size. A 1x1
    conv then produces ``out_channels`` maps.

    Args:
        pretrained: if False, every ``nn.Conv2d`` and ``nn.BatchNorm2d`` in the model
            is re-initialized, including the ones holding the downloaded ResNetAE
            weights (see ``_weights_init``).
        out_channels: number of output channels.

    NOTE(review): the ResNetAE weights are downloaded regardless of ``pretrained``.
    The full autoencoder (including decoder children not used in ``forward``) stays
    registered as ``self.resnet_ae``. ``neck`` assumes a 3x3 ``block5`` map, so input
    sizes other than 224x224 are untested.
    """

    def __init__(self, pretrained=False, out_channels=1):
        super().__init__()

        self.resnet_ae = xrv.autoencoders.ResNetAE(weights="101-elastic")
        self.encoder_layers = list(self.resnet_ae.children())[:8]
        self.block1 = nn.Sequential(*self.encoder_layers[:3])
        self.block2 = nn.Sequential(*self.encoder_layers[3:5])
        # children 5..7 are unpacked, so they are presumably Sequential stages.
        self.block3 = nn.Sequential(*self.encoder_layers[5])
        self.block4 = nn.Sequential(*self.encoder_layers[6])
        self.block5 = nn.Sequential(*self.encoder_layers[7])

        # 3x3 -> 7x7: (3 - 1) * 2 + 3.
        self.neck = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2)

        # Decoder: in_channels of each double_conv = upsampled channels + skip channels.
        self.up_conv6 = up_conv(512, 512)
        self.conv6 = double_conv(512 + 1024, 512)
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = double_conv(256 + 512, 256)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = double_conv(128 + 256, 128)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = double_conv(64 + 64, 64)
        self.up_conv10 = up_conv(64, 32)
        self.conv10 = nn.Conv2d(32, out_channels, kernel_size=1)

        if not pretrained:
            self._weights_init()

    def _weights_init(self):
        """Re-initialize all Conv2d and BatchNorm2d modules as if untrained.

        Conv2d weights get Kaiming-normal init (fan_out, ReLU). BatchNorm2d gets
        weight=1 and bias=0, and its running statistics are reset as well. Without
        that reset, a "not pretrained" model would keep the downloaded ResNetAE's
        running mean/var for use in eval mode.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                m.reset_running_stats()

    def forward(self, x):
        """Encode ``x`` with the ResNetAE stages and decode with U-Net skip connections."""
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)

        neck = self.neck(block5)

        x = self.up_conv6(neck)

        x = torch.cat([x, block4], dim=1)
        x = self.conv6(x)

        x = self.up_conv7(x)
        x = torch.cat([x, block3], dim=1)
        x = self.conv7(x)

        x = self.up_conv8(x)
        x = torch.cat([x, block2], dim=1)
        x = self.conv8(x)

        x = self.up_conv9(x)
        x = torch.cat([x, block1], dim=1)
        x = self.conv9(x)

        x = self.up_conv10(x)
        x = self.conv10(x)

        return x

In [42]:
model = ResNetAEUNet()
x = torch.ones(1, 1, 224, 224)
# Shape smoke test only: skip building an autograd graph for the full forward pass.
with torch.no_grad():
    out = model(x)
# The U-Net must reproduce the input's spatial size.
assert out.shape[-2:] == x.shape[-2:], out.shape
out.shape

[torch.Size([1, 64, 112, 112]), torch.Size([1, 256, 56, 56]), torch.Size([1, 512, 28, 28]), torch.Size([1, 1024, 14, 14]), torch.Size([1, 512, 3, 3]), torch.Size([1, 512, 7, 7])]


torch.Size([1, 1, 224, 224])