# Design attention head

In [1]:
import torch
from torch import nn
from model import HRSeg, SelfAttention
import torch.nn.functional as F
from utils import count_parameters



In [2]:
# Dummy multi-scale feature maps used to exercise the attention head:
# one (channels, spatial size) pair per level, finest resolution first.
# The generator draws the tensors in the same order as four separate
# torch.rand calls would, so the RNG stream is consumed identically.
x1, x2, x3, x4 = (
    torch.rand(1, channels, size, size)
    for channels, size in ((64, 72), (128, 36), (320, 18), (512, 9))
)

In [10]:
class AttentionHead(nn.Module):
    """Fuse four multi-scale feature maps into one single-channel map.

    Each of the three coarser inputs is brought to 64 channels at the
    spatial size of ``x1`` by transposed convolutions (kernel == stride, so
    every layer multiplies height/width by its stride) followed by a
    ``SelfAttention(64)`` block. The aligned maps are summed with ``x1`` and
    a final transposed convolution upsamples the result by 4 into a
    1-channel map.

    NOTE(review): element-wise summation is the fusion chosen here because
    ``upConv_out`` expects 64 input channels; confirm it matches the intended
    design. ``SelfAttention`` is assumed to preserve the shape of its input,
    as the shape comments on each branch indicate.
    """

    def __init__(self) -> None:
        super().__init__()

        # 512-channel input, upsampled x4 then x2 (e.g. 9 -> 36 -> 72).
        self.upConv4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=64, kernel_size=4, stride=4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
            nn.ReLU(),
            SelfAttention(64)
        ) # [1, 64, 72, 72]

        # 320-channel input, upsampled x2 twice (e.g. 18 -> 36 -> 72).
        self.upConv3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=320, out_channels=64, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
            nn.ReLU(),
            SelfAttention(64)
        ) # [1, 64, 72, 72]

        # 128-channel input, upsampled x2 once (e.g. 36 -> 72).
        self.upConv2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2),
            nn.ReLU(),
            SelfAttention(64)
        ) # [1, 64, 72, 72]

        # Fused 64-channel map -> 1 channel, upsampled x4 (e.g. 72 -> 288).
        self.upConv_out = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4, stride=4)

    def forward(self, x1, x2, x3, x4):
        """Combine the four feature maps and return the upsampled 1-channel map.

        Args:
            x1: 64-channel map at the target resolution, used as-is.
            x2: 128-channel map at half the resolution of ``x1``.
            x3: 320-channel map at a quarter of the resolution of ``x1``.
            x4: 512-channel map at an eighth of the resolution of ``x1``.

        Returns:
            The output of ``upConv_out`` applied to the summed maps: one
            channel at 4x the height/width of ``x1``.
        """
        # Previously only upConv4(x4) was computed and returned, so x1..x3
        # and three of the four registered sub-modules were never used.
        fused = x1 + self.upConv2(x2) + self.upConv3(x3) + self.upConv4(x4)
        return self.upConv_out(fused)

In [16]:
# Build HRSeg with no constructor arguments. The recorded cell output shows
# the encoder state dict being loaded from pretrained_pth/mit_b2.pth.
model = HRSeg()

Loaded state dict for encoder: pretrained_pth/mit_b2.pth


In [23]:
# Smoke test: push one random 3-channel 288x288 image through the encoder,
# then unpack the encoder outputs as positional arguments to the attention head.
encoded = model.encoder(torch.rand(1, 3, 288, 288))
att_map = model.att_head(*encoded)
# Display the shape; the recorded output is [1, 1, 288, 288].
att_map.shape

torch.Size([1, 1, 288, 288])

# Check dataloader

In [1]:
from dataloader import get_train_loader
# Training loader over a single dataset root, one sample per batch, with
# the inner/outer sizes passed through to get_train_loader.
trainloader = get_train_loader(
    train_roots=['./dataset/TrainDataset'],
    batchsize=1,
    inner_size=352,
    outer_size=528,
)

In [2]:
# Fetch the first item from a fresh iterator over the train loader.
res = next(iter(trainloader))

In [3]:
# List the fields of the fetched item; the recorded output shows
# 'image', 'inner_image', 'mask' and 'slice'.
res.keys()

dict_keys(['image', 'inner_image', 'mask', 'slice'])

In [4]:
# Shape of the 'image' entry (recorded output: [1, 3, 352, 352]).
res['image'].shape

torch.Size([1, 3, 352, 352])

In [5]:
# Shape of the 'inner_image' entry (recorded output: [1, 3, 352, 352]).
res['inner_image'].shape

torch.Size([1, 3, 352, 352])

In [6]:
# Shape of the 'mask' entry (recorded output: [1, 1, 528, 528]).
res['mask'].shape

torch.Size([1, 1, 528, 528])