In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from basicsr.archs.pwcnet_arch import FlowGenerator
from einops import rearrange
from DAM import DAModule
import math

class Cross_attention(nn.Module):
    def __init__(self, dim_q, dim_k, dim_v, dim ,num_heads, bias):
        super(Cross_attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.wq = nn.Conv2d(dim_q, dim, kernel_size=1, bias=bias)
        self.wk = nn.Conv2d(dim_k, dim, kernel_size=1, bias=bias)
        self.wv = nn.Conv2d(dim_v, dim, kernel_size=1, bias=bias)

        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        


    def forward(self, q, k, v):
        b,c,h,w = q.shape

        q, k, v = self.wq(q), self.wk(k), self.wv(v) 
        # print(q.shape)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        # print("Rearrange q shape: ", q.shape)
        # print("Rearrange k shape: ", k.shape)
        # print("Rearrange v shape: ", v.shape)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        
        print("out",out.shape)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

class single_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),)

    def forward(self, x):
        return self.conv(x)


class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                  nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True))

    def forward(self, x):
        return self.conv(x)


class double_conv_down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                  nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True))

    def forward(self, x):
        return self.conv(x)


class double_conv_up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(double_conv_up, self).__init__()
        self.conv = nn.Sequential(nn.UpsamplingNearest2d(scale_factor=2),
                                  nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                  nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
                                  nn.LeakyReLU(negative_slope=0.2, inplace=True))

    def forward(self, x):
        return self.conv(x)

class Unet(nn.Module):
    def __init__(self, in_ch, feat_ch, out_ch):
        super().__init__()
        self.conv_in = single_conv(in_ch, feat_ch)

        self.conv1 = double_conv_down(feat_ch, feat_ch)
        self.conv2 = double_conv_down(feat_ch, feat_ch)
        self.conv3 = double_conv(feat_ch, feat_ch)
        self.conv4 = double_conv_up(feat_ch, feat_ch)
        self.conv5 = double_conv_up(feat_ch, feat_ch)
        self.conv6 = double_conv(feat_ch, out_ch)

    def forward(self, x):
        feat0 = self.conv_in(x)    # H, W
        feat1 = self.conv1(feat0)   # H/2, W/2
        feat2 = self.conv2(feat1)    # H/4, W/4
        feat3 = self.conv3(feat2)    # H/4, W/4
        feat3 = feat3 + feat2     # H/4
        feat4 = self.conv4(feat3)    # H/2, W/2
        feat4 = feat4 + feat1    # H/2, W/2
        feat5 = self.conv5(feat4)   # H
        feat5 = feat5 + feat0   # H
        feat6 = self.conv6(feat5)

        return feat0, feat1, feat2, feat3, feat4, feat6

class PosEnSine(nn.Module):
    """
    Code borrowed from DETR: models/positional_encoding.py
    output size: b*(2.num_pos_feats)*h*w
    """

    def __init__(self, num_pos_feats):
        super(PosEnSine, self).__init__()
        self.num_pos_feats = num_pos_feats
        self.normalize = True
        self.scale = 2 * math.pi
        self.temperature = 10000

    def forward(self, x):
        b, c, h, w = x.shape
        not_mask = torch.ones(1, h, w, device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = pos.repeat(b, 1, 1, 1)
        return pos

class MLP(nn.Module):
    """
    conv-based MLP layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.linear1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.linear2 = nn.Conv2d(hidden_features, out_features, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.act(x)
        x = self.linear2(x)
        return x

class Transformer_enc(nn.Module):
    def __init__(self, feature_dims, n_head = 4, mlp_ratio = 2, depth_cat = False):
        super(Transformer_enc, self).__init__()
        dim_q, dim_k, dim_v, feature_dim = feature_dims
        self.pos_enc = PosEnSine(feature_dim//2)
        if depth_cat:
            self.pos_enc = PosEnSine(feature_dim)
        self.att = Cross_attention(dim_q, dim_k, dim_v, feature_dim, n_head, False)
        mlp_hidden = int(feature_dim * mlp_ratio)
        self.FFN = MLP(feature_dim, mlp_hidden)
        self.norm = nn.GroupNorm(1,  feature_dim)
        self.res = double_conv(dim_q, feature_dim)
    def forward(self, q, k, v):
        # print(q.shape)
        b,c,h,w = q.shape
        q_pos = self.pos_enc(q)
        k_pos = self.pos_enc(k)
        # print(q_pos.shape)
        if c > 1:
            att_out = self.att(q + q_pos, k + k_pos, v)
        else:
            att_out = self.att(q, k + k_pos, v)
        # print("Attention output shape: ", att_out.shape)
        first_res = self.res(q) + att_out
        # pass MLP
        mlp_out = self.FFN(first_res)
        second_res = mlp_out + first_res
        out = self.norm(second_res)
        return out

class Alignformer(nn.Module):
    '''
       The implementation of utilizing depth map to guide attention (concat/cross attention)
    '''
    def __init__(self, feature_dims, src_channel, ref_channel, 
                    out_channel, n_head = 4, mlp_ratio = 2, depth_cat = False):
        super(Alignformer, self).__init__()
        feature_dim = feature_dims[-1]
        # print(feature_dim)
        self.DAM = DAModule(in_ch = src_channel, feat_ch = feature_dim, out_ch = src_channel,
                            demodulate = True, requires_grad = True)
        # Define feature extractor
        self.unet_q = Unet(src_channel, feature_dim, feature_dim)
        self.unet_k = Unet(src_channel, feature_dim, feature_dim)
        # Define GAM
        self.trans_unit = nn.ModuleList([
            Transformer_enc(feature_dims, n_head, mlp_ratio,depth_cat = depth_cat),
            Transformer_enc(feature_dims, n_head, mlp_ratio, depth_cat = depth_cat),
            Transformer_enc(feature_dims, n_head, mlp_ratio, depth_cat = depth_cat)
        ])
        # Unet result output
        self.conv0 = double_conv(feature_dim, feature_dim)
        self.conv1 = double_conv_down(feature_dim, feature_dim)
        self.conv2 = double_conv_down(feature_dim, feature_dim)
        self.conv3 = double_conv(feature_dim, feature_dim)
        self.conv4 = double_conv_up(feature_dim, feature_dim)
        self.conv5 = double_conv_up(feature_dim, feature_dim)
        self.conv6 = nn.Sequential(single_conv(feature_dim, feature_dim), 
                        nn.Conv2d(feature_dim, out_channel, 3, 1, 1))
        
    def forward(self, x_l, x_r, dpt_map):
        l_dpt, r_dpt = dpt_map
        x_l, x_r = torch.cat([x_l, l_dpt], dim = 1), torch.cat([x_r, r_dpt], dim = 1)
        src = self.DAM(x_l, x_r)
        q_feature = self.unet_q(x_l)
        k_feature = self.unet_k(x_r)
        outputs = []
    
        for i in range(3):
            # print("Query feature shape:",q_feature[i+3].shape)
            # print("Depth feature shape:",depth_feature_l[i+3].shape)
            # print("Key feature shape:",k_feature[i+3].shape)
            outputs.append(self.trans_unit[i](
                q_feature[i+3], 
                k_feature[i+3], k_feature[i+3]
            ))
        f0 = self.conv0(outputs[2])  # H, W
        f1 = self.conv1(f0)  # H/2, W/2
        f1 = f1 + outputs[1]
        f2 = self.conv2(f1)  # H/4, W/4
        f2 = f2 + outputs[0]
        f3 = self.conv3(f2)  # H/4, W/4
        f3 = f3 + outputs[0] + f2
        f4 = self.conv4(f3)   # H/2, W/2
        f4 = f4 + outputs[1] + f1
        f5 = self.conv5(f4)   # H, W
        f5 = f5 + outputs[2] + f0
        out = self.conv6(f5)

        return out
class Alignformer_depMapG(nn.Module):
    '''
       The implementation of utilizing depth map to guide attention (concat/cross attention)
    '''
    def __init__(self, feature_dims, src_channel, ref_channel, 
                    out_channel, n_head = 4, mlp_ratio = 2, depth_cat = False, module = "cat"):
        super(Alignformer_depMapG, self).__init__()
        feature_dim = feature_dims[-1]
        # print(feature_dim)
        self.DAM = DAModule(in_ch = src_channel, feat_ch = feature_dim, out_ch = src_channel,
                            demodulate = True, requires_grad = True)
        # Define feature extractor
        self.unet_q = Unet(src_channel, feature_dim, feature_dim)
        self.unet_k = Unet(src_channel, feature_dim, feature_dim)
        self.unet_d = Unet(1, feature_dim, feature_dim)
        # Define GAM
        self.trans_unit = nn.ModuleList([
            Transformer_enc(feature_dims, n_head, mlp_ratio,depth_cat = depth_cat),
            Transformer_enc(feature_dims, n_head, mlp_ratio, depth_cat = depth_cat),
            Transformer_enc(feature_dims, n_head, mlp_ratio, depth_cat = depth_cat)
        ])
        self.guide_module = module
        # Unet result output
        self.conv0 = double_conv(feature_dim, feature_dim)
        self.conv1 = double_conv_down(feature_dim, feature_dim)
        self.conv2 = double_conv_down(feature_dim, feature_dim)
        self.conv3 = double_conv(feature_dim, feature_dim)
        self.conv4 = double_conv_up(feature_dim, feature_dim)
        self.conv5 = double_conv_up(feature_dim, feature_dim)
        self.conv6 = nn.Sequential(single_conv(feature_dim, feature_dim), 
                        nn.Conv2d(feature_dim, out_channel, 3, 1, 1))
        
    def forward(self, x_l, x_r, dpt_map):
        l_dpt, r_dpt = dpt_map
        src = self.DAM(x_l, x_r)
        q_feature = self.unet_q(x_l)
        k_feature = self.unet_k(x_r)
        depth_feature_l, depth_feature_r = self.unet_d(l_dpt), self.unet_d(r_dpt)
        outputs = []
        if self.guide_module == "cat":
            for i in range(3):
                # print("Query feature shape:",q_feature[i+3].shape)
                # print("Depth feature shape:",depth_feature_l[i+3].shape)
                # print("Key feature shape:",k_feature[i+3].shape)
                outputs.append(self.trans_unit[i](
                    torch.cat([q_feature[i+3], depth_feature_l[i+3]], dim = 1), 
                    torch.cat([k_feature[i+3], depth_feature_r[i+3]], dim = 1), k_feature[i+3]
                ))
        f0 = self.conv0(outputs[2])  # H, W
        f1 = self.conv1(f0)  # H/2, W/2
        f1 = f1 + outputs[1]
        f2 = self.conv2(f1)  # H/4, W/4
        f2 = f2 + outputs[0]
        f3 = self.conv3(f2)  # H/4, W/4
        f3 = f3 + outputs[0] + f2
        f4 = self.conv4(f3)   # H/2, W/2
        f4 = f4 + outputs[1] + f1
        f5 = self.conv5(f4)   # H, W
        f5 = f5 + outputs[2] + f0
        out = self.conv6(f5)

        return out

In [9]:
import sys
from collections import OrderedDict
sys.path.append("../..")
sys.path.append("..")
from utils.Haze4k import DPDD, LFDOF
from Segmodels.segformer_mit import SegFormer
from torch.utils.data import DataLoader
def convert_pl(path):
    '''
    This function aims to convert PT lightning parameters dictionary into torch load state dict 
    '''
    ckpt = torch.load(path,map_location='cpu')
    new_state_dict = OrderedDict()
    for k in ckpt['state_dict']:
        # print(k)
        #name = 1
        # print(k[:4])
        if k[:4] != 'net.':
        #if 'tiny_unet.' not in k:
            continue
        name = k.replace('net.','')
        # print(name)
        new_state_dict[name] = ckpt['state_dict'][k]
    return new_state_dict
depth_model = SegFormer()
depth_model_syn = SegFormer()
depth_model_real = SegFormer()
depth_path_syn = "/home/fhx/code/archive/models/DEM_Segformer/logs/Seg_Logs/synthetic/checkpoints/Seg-epoch=94-psnr=22.814829-ssim=0.843821.ckpt"
depth_path_real = "/home/fhx/code/archive/models/DEM_Segformer/logs/Seg_Logs/version_0/checkpoints/Seg-epoch=94-iou=0.884012.ckpt"
# checkpoint = torch.load(self.depth_path)
#print(checkpoint["state_dict"])
ckpt_syn = convert_pl(depth_path_syn)
ckpt_real = convert_pl(depth_path_real)
print("Load pretrained depth model: ")
depth_model_syn.load_state_dict(ckpt_syn)
depth_model_real.load_state_dict(ckpt_real)
val_set = DPDD('/home/fhx/Datasets/DPDD_dataset/train 1/train',train=False,size=320,crop=False, name = False)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

Load pretrained depth model: 
crop size 320


In [10]:
from torchvision.utils import save_image

# for i, j in enumerate(val_loader):
#     blur, clear, name = j
#     clear_map = depth_model_real(clear)
#     print(name)
#     clear_name = "save_img/" + name[0]
#     clear_map = (clear_map - clear_map.min()) / (clear_map.max() - clear_map.min())
#     save_image(clear_map,clear_name)


In [8]:
depth

('1P0A2488.png',)

In [11]:
blur, gt, depth = next(iter(val_loader))
from torchvision.utils import save_image

save_image(depth,"depth.png")

# depth_map = depth_model(blur)
# depth_map_syn = depth_model_syn(blur)
# depth_map_real = depth_model_real(blur)
# depth_map_gt = depth_model(gt)
# depth_map_syn_gt = depth_model_syn(gt)
# depth_map_real_gt = depth_model_real(gt)

In [35]:
import torch
import torchvision.utils as tv_utils

# Assuming 'defocus_map' is your actual defocus map tensor
# Normalize to [0, 1]
defocus_map_normalized = (depth_map_real - depth_map_real.min()) / (depth_map_real.max() - depth_map_real.min())
defocus_syn_normalized = (depth_map_syn - depth_map_syn.min()) / (depth_map_syn.max() - depth_map_syn.min())
defocus = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())


defocus_map_normalized_gt = (depth_map_real_gt - depth_map_real_gt.min()) / (depth_map_real_gt.max() - depth_map_real_gt.min())
defocus_syn_normalized_gt = (depth_map_syn_gt - depth_map_syn_gt.min()) / (depth_map_syn_gt.max() - depth_map_syn_gt.min())
defocus_gt = (depth_map_gt - depth_map_gt.min()) / (depth_map_gt.max() - depth_map_gt.min())
# Convert to 8-bit integer (0-255)
# defocus_map_uint8 = (defocus_map_normalized * 255).clamp(0, 255).to(torch.uint8)

# Save as PNG
# save_path = "defocus_map_normalized.png"
# tv_utils.save_image(defocus_map_normalized, save_path)

# print(f"Normalized defocus map saved as {save_path}")


In [36]:
from torchvision.utils import save_image
save_image(blur,"blur.png")
save_image(defocus,"hg_depth_blur.png")
save_image(defocus_syn_normalized,"syn_depth_blur.png")
save_image(defocus_map_normalized,"real_depth_blur.png")
save_image(gt,"gt.png")
save_image(defocus_gt,"hg_depth.png")
save_image(defocus_syn_normalized_gt,"syn_depth.png")
save_image(defocus_map_normalized_gt,"real_depth.png")

In [None]:
a = torch.randn((2, 7, 3,3))
a[:, :3, :, :].shape, a[:, 3:4, :, :].shape, a[:,4:,:,:].shape

(torch.Size([2, 3, 3, 3]), torch.Size([2, 1, 3, 3]), torch.Size([2, 3, 3, 3]))

In [None]:
a[:, :3, :, :]

tensor([[[[-1.0548, -0.0567,  0.9441],
          [-0.6711, -1.3959, -0.2184],
          [ 2.3578,  1.6532, -0.8417]],

         [[ 0.6245, -0.4252,  1.0264],
          [ 0.1983,  2.0613,  1.7486],
          [ 0.2571,  0.9472,  0.7934]],

         [[ 0.9532,  1.5996,  0.5311],
          [ 0.6595,  2.5955, -0.4356],
          [-0.3221,  0.3520,  0.2600]]],


        [[[-0.9736,  0.0792, -1.1424],
          [ 0.9019, -1.0491, -0.3803],
          [ 1.7257, -0.6090,  1.9963]],

         [[ 0.5129,  1.2238, -0.3489],
          [-1.1229, -0.5799,  0.0894],
          [ 0.2790,  0.4487,  0.8890]],

         [[-0.5826,  0.1090, -1.0912],
          [-2.0799,  2.0098,  1.7442],
          [ 0.8457,  2.0026, -0.5016]]]])

In [None]:
# test result
dpt_map = [torch.randn(2,1,224,224), torch.randn(2,1,224,224)]
l,r = torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)
model = Alignformer_depMapG([64, 64, 32, 32], 3,3,3, depth_cat = True)
model(l,r,dpt_map).shape

out torch.Size([2, 4, 8, 3136])
out torch.Size([2, 4, 8, 12544])
out torch.Size([2, 4, 8, 50176])


torch.Size([2, 3, 224, 224])

In [None]:
class Cross_attention(nn.Module):
    def __init__(self, dim_q, dim_k, dim_v, dim ,num_heads, bias):
        super(Cross_attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.wq = nn.Conv2d(dim_q, dim, kernel_size=1, bias=bias)
        self.wk = nn.Conv2d(dim_k, dim, kernel_size=1, bias=bias)
        self.wv = nn.Conv2d(dim_v, dim, kernel_size=1, bias=bias)
        self.wshare = nn.Conv2d(dim, dim, kernel_size=3,  stride=1, padding=1,groups=dim,bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        


    def forward(self, q, k, v):
        b,c,h,w = q.shape

        q, k, v = self.wshare(self.wq(q)), self.wshare(self.wk(k)), self.wshare(self.wv(v)) 
        print(q.shape)
        # print(q.shape)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        # print("Rearrange q shape: ", q.shape)
        # print("Rearrange k shape: ", k.shape)
        # print("Rearrange v shape: ", v.shape)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        
        print("out",out.shape)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
    
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        


    def forward(self, x):
        b,c,h,w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        print(qkv.shape)
        q,k,v = qkv.chunk(3, dim=1)   
        # print(q.shape)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        print(out.shape)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        
        out = self.project_out(out)
        return out

In [None]:
att = Cross_attention(64,64,64,64,8,False)
tst = torch.randn((2,64,128,128))
att(tst,tst,tst).shape

torch.Size([2, 64, 128, 128])
out torch.Size([2, 8, 8, 16384])


torch.Size([2, 64, 128, 128])

In [None]:
att2 = Attention(64,8,True)
att2(tst).shape

torch.Size([2, 192, 128, 128])
torch.Size([2, 8, 8, 16384])


torch.Size([2, 64, 128, 128])

In [2]:
class Transformer_enc_guide(nn.Module):
    '''
        The implementation of guidance utilizing depth map (cross attention)
    '''
    def __init__(self, feature_dims, n_head = 4, mlp_ratio = 2, depth_cat = False):
        super(Transformer_enc_guide, self).__init__()
        self.tf_block_l = Transformer_enc(feature_dims, n_head, mlp_ratio, depth_cat)
        self.tf_block_r = Transformer_enc(feature_dims, n_head, mlp_ratio, depth_cat)
        self.tf_block_all = Transformer_enc(feature_dims, n_head, mlp_ratio, depth_cat)
    
    def forward(self, x_l, x_r, dpt_map):
        # block_1 
        dpt_l, dpt_r = dpt_map
        att_l_out = self.tf_block_l(dpt_l, x_l, x_l)
        att_r_out = self.tf_block_r(dpt_r, x_r, x_r)
        att_out = self.tf_block_all(att_l_out, att_r_out, att_r_out)
        return att_out


class Alignformer_depMapCross(nn.Module):
    '''
       The implementation of utilizing depth map to guide attention (concat/cross attention)
    '''
    def __init__(self, feature_dims, src_channel, ref_channel, 
                    out_channel, n_head = 4, mlp_ratio = 2, depth_cat = False, module = "cat"):
        super(Alignformer_depMapCross, self).__init__()
        feature_dim = feature_dims[-1]
        # print(feature_dim)
        self.DAM = DAModule(in_ch = src_channel, feat_ch = feature_dim, out_ch = src_channel,
                            demodulate = True, requires_grad = True)
        # Define feature extractor
        self.unet_q = Unet(src_channel, feature_dim, feature_dim)
        self.unet_k = Unet(src_channel, feature_dim, feature_dim)
        self.unet_d = Unet(1, feature_dim, feature_dim)
        # Define GAM
        self.trans_unit = nn.ModuleList([
            Transformer_enc_guide(feature_dims, n_head, mlp_ratio,depth_cat = depth_cat),
            Transformer_enc_guide(feature_dims, n_head, mlp_ratio, depth_cat = depth_cat),
            Transformer_enc_guide(feature_dims, n_head, mlp_ratio, depth_cat = depth_cat)
        ])
        self.guide_module = module
        # Unet result output
        self.conv0 = double_conv(feature_dim, feature_dim)
        self.conv1 = double_conv_down(feature_dim, feature_dim)
        self.conv2 = double_conv_down(feature_dim, feature_dim)
        self.conv3 = double_conv(feature_dim, feature_dim)
        self.conv4 = double_conv_up(feature_dim, feature_dim)
        self.conv5 = double_conv_up(feature_dim, feature_dim)
        self.conv6 = nn.Sequential(single_conv(feature_dim, feature_dim), 
                        nn.Conv2d(feature_dim, out_channel, 3, 1, 1))
        
    def forward(self, x_l, x_r, dpt_map):
        l_dpt, r_dpt = dpt_map
        src = self.DAM(x_l, x_r)
        q_feature = self.unet_q(x_l)
        k_feature = self.unet_k(x_r)
        depth_feature_l, depth_feature_r = self.unet_d(l_dpt), self.unet_d(r_dpt)
        outputs = []
        for i in range(3):
            # print("Query feature shape:",q_feature[i+3].shape)
            # print("Depth feature shape:",depth_feature_l[i+3].shape)
            # print("Key feature shape:",k_feature[i+3].shape)
            outputs.append(self.trans_unit[i](
                q_feature[i+3], 
                k_feature[i+3],
                [depth_feature_l[i+3], depth_feature_r[i+3]]
            ))
        f0 = self.conv0(outputs[2])  # H, W
        f1 = self.conv1(f0)  # H/2, W/2
        f1 = f1 + outputs[1]
        f2 = self.conv2(f1)  # H/4, W/4
        f2 = f2 + outputs[0]
        f3 = self.conv3(f2)  # H/4, W/4
        f3 = f3 + outputs[0] + f2
        f4 = self.conv4(f3)   # H/2, W/2
        f4 = f4 + outputs[1] + f1
        f5 = self.conv5(f4)   # H, W
        f5 = f5 + outputs[2] + f0
        out = self.conv6(f5)

        return out

NameError: name 'nn' is not defined

In [None]:
# test result
dpt_map = [torch.randn(2,1,224,224), torch.randn(2,1,224,224)]
l,r = torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)
model = Alignformer_depMapCross([32, 32, 32, 32], 3,3,3, depth_cat = False)
model(l,r,dpt_map).shape

torch.Size([2, 3, 224, 224])

In [None]:
from DCNv4 import modules as opsm
deform_conv=getattr(opsm, 'DCNv4')
class DCNConv(nn.Module):
    """
    DCNv4
    """
    def __init__(self,channels,kernel_size,group,offset_scale=1.0, output_bias=False):
        super(DCNConv, self).__init__()

        self.dcn_conv = deform_conv(
                channels=channels,
                group=group,
                offset_scale=offset_scale,
                dw_kernel_size=kernel_size,
                output_bias=output_bias
            )
    
    def forward(self, inp):
        _,_,H,W = inp.size()
        inp = rearrange(inp, 'b c h w -> b (h w) c')
        out = self.dcn_conv(inp,shape=[H,W])
        out = rearrange(out, 'b (h w) c -> b c h w',h=H,w=W)
        return out

In [None]:
device = torch.device("cuda")
# print(torch.devices)
dcn_test = DCNConv(64, 3, 4).to(device)
img_in = torch.randn((2, 64, 192, 192)).to(device)
dcn_test(img_in).shape

torch.Size([2, 64, 192, 192])

In [None]:
conv = nn.Conv2d(64, 64, kernel_size=3).to(device)
conv(img_in).shape

torch.Size([2, 64, 190, 190])