In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


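### Aggregation test

Compare two ways of fusing the multi-scale feature pyramid by output shape and GPU latency:

- `AggOnce`: fuse all levels at once, then downsample.
- `AggTwice`: fuse the coarse levels first, then the fine levels.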

In [2]:
class Aggregator(nn.Module):
    """Fuse a target feature map with a list of auxiliary feature maps.

    Each input gets its own 1x1 convolution projecting it to ``reduce_dim``
    channels. Auxiliary projections are bilinearly resized
    (``align_corners=True``) to the target's spatial size. Everything is then
    concatenated along the channel axis, and a 1x1 conv (no bias) +
    BatchNorm + ReLU mixes the result back down to ``reduce_dim`` channels.

    Args:
        reduce_dim: channel count of every projection and of the output.
        chs: input channel counts; ``chs[0]`` belongs to ``target`` and
            ``chs[i + 1]`` to ``feats[i]``.
    """

    def __init__(self, reduce_dim, chs):
        super(Aggregator, self).__init__()

        # One projection per input, in the same order as ``chs``.
        self.layers = nn.ModuleList([nn.Conv2d(in_ch, reduce_dim, 1) for in_ch in chs])

        concat_ch = reduce_dim * len(chs)
        self.aggregator = nn.Sequential(
            nn.Conv2d(concat_ch, reduce_dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(reduce_dim),
            nn.ReLU(inplace=True),
        )

    def forward(self, target, feats):
        """Return the fused map, with ``reduce_dim`` channels at ``target``'s H x W."""
        _, _, height, width = target.size()  # target is B, C, H, W

        projected = [self.layers[0](target)]
        for idx, feat in enumerate(feats, start=1):
            resized = F.interpolate(self.layers[idx](feat), (height, width),
                                    mode='bilinear', align_corners=True)
            projected.append(resized)

        return self.aggregator(torch.cat(projected, dim=1))


class Down(nn.Module):
    """Downsampling stage that keeps the channel width at ``reduce_dim``.

    It is a 3x3 stride-2 convolution with no padding, so each spatial size S
    becomes floor((S - 3) / 2) + 1. That is followed by a 1x1 convolution.
    Each convolution is bias-free and followed by BatchNorm and ReLU.
    """

    def __init__(self, reduce_dim=64):
        super(Down, self).__init__()

        def conv_bn_relu(kernel_size, stride):
            # conv -> BN -> ReLU triple at constant width
            return [
                nn.Conv2d(reduce_dim, reduce_dim, kernel_size=kernel_size, stride=stride, bias=False),
                nn.BatchNorm2d(reduce_dim),
                nn.ReLU(inplace=True),
            ]

        self.down = nn.Sequential(*conv_bn_relu(3, 2), *conv_bn_relu(1, 1))

    def forward(self, x):
        return self.down(x)

class AggOnce(nn.Module):
    """Single-pass pyramid fusion followed by downsampling.

    One ``Aggregator`` fuses every pyramid level onto the resolution of
    ``x[0]``. The fused map is then passed through two ``Down`` stages.

    Args:
        reduce_dim: channel width of the fused maps.
        chs: channel count of each pyramid level in input order (``x[i]`` has
            ``chs[i]`` channels). The default is a tuple rather than a list, so
            the default object cannot be shared and mutated across instances.
    """

    def __init__(self, reduce_dim=64, chs=(32, 56, 160, 448, 448, 448)):
        super(AggOnce, self).__init__()

        self.agg = Aggregator(reduce_dim, chs)

        self.down = nn.Sequential(Down(reduce_dim),
                                  Down(reduce_dim))

    def forward(self, x):
        """Fuse the pyramid ``x`` (a sequence of B, C, H, W maps).

        Returns:
            (f4, f16): the fused map at ``x[0]``'s resolution, and that map
            after the two ``Down`` stages.
        """
        f4 = self.agg(x[0], x[1:])
        f16 = self.down(f4)
        return f4, f16

class AggTwice(nn.Module):
    """Two-stage pyramid fusion: coarse levels first, then fine levels.

    ``agg16`` fuses ``x[3:]`` onto ``x[2]``. ``agg4`` then fuses ``x[1]`` and
    the ``agg16`` result onto ``x[0]``.

    Args:
        reduce_dim: channel width of both fused maps.
        chs: channel count of each pyramid level in input order (``x[i]`` has
            ``chs[i]`` channels). ``chs[2:]`` feed ``agg16``; ``chs[:2]`` feed
            ``agg4``.
    """

    def __init__(self, reduce_dim=64, chs=(32, 56, 160, 448, 448, 448)):
        super(AggTwice, self).__init__()

        chs = tuple(chs)
        self.agg16 = Aggregator(reduce_dim, chs[2:])
        # agg4's last input is f16, the agg16 output. It has reduce_dim channels,
        # so its entry must follow reduce_dim instead of a fixed 64.
        self.agg4 = Aggregator(reduce_dim, chs[:2] + (reduce_dim,))

    def forward(self, x):
        """Fuse the pyramid ``x`` (levels 4 8 16 32 64 global, per the original note).

        Returns:
            (f4, f16): the fused map at ``x[0]``'s resolution and the coarse
            fused map at ``x[2]``'s resolution.
        """
        f16 = self.agg16(x[2], x[3:])
        f4 = self.agg4(x[0], [x[1], f16])
        return f4, f16


In [3]:
# Synthetic multi-scale pyramid for a 900x1600 image. It has five maps at
# strides 4, 8, 16, 32 and 64 (channels 32, 56, 160, 448, 448) plus a 1x1 map
# with 448 channels, which is the channel layout AggOnce/AggTwice are built for.
fh, fw = 900, 1600
# Fall back to CPU so this cell runs without a GPU (the timing cells still need CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

inputs = []
for level, ch in enumerate([32, 56, 160, 448, 448]):
    stride = 2 ** (level + 2)
    inputs.append(torch.randn((1, ch, int(fh / stride), int(fw / stride))).to(device))
inputs.append(torch.randn((1, 448, 1, 1)).to(device))

for feat in inputs:  # `feat`, not `input`, to avoid shadowing the builtin
    print(feat.shape)

torch.Size([1, 32, 225, 400])
torch.Size([1, 56, 112, 200])
torch.Size([1, 160, 56, 100])
torch.Size([1, 448, 28, 50])
torch.Size([1, 448, 14, 25])
torch.Size([1, 448, 1, 1])


In [4]:
# Build AggOnce and check its output shapes on the synthetic pyramid.
aggonce = AggOnce().to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    f4, f16 = aggonce(inputs)  # AggOnce returns (f4, f16)
print(f4.shape)
print(f16.shape)

torch.Size([1, 64, 225, 400])
torch.Size([1, 64, 55, 99])


In [5]:
# Build AggTwice and check its output shapes on the synthetic pyramid.
aggtwice = AggTwice().to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    f4, f16 = aggtwice(inputs)  # AggTwice returns (f4, f16)
print(f4.shape)
print(f16.shape)

torch.Size([1, 64, 225, 400])
torch.Size([1, 64, 56, 100])


In [6]:
# Time AggOnce's forward pass on the GPU with CUDA events.
# Set `model = aggtwice` to time the other variant.
model = aggonce.eval()  # inference mode: BatchNorm uses its running stats instead of updating them

starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 1000
warmup_iters = 10
timings = np.zeros(repetitions)

with torch.no_grad():
    # GPU warm-up in the same no-grad mode as the measured runs.
    for _ in range(warmup_iters):
        _ = model(inputs)
    torch.cuda.synchronize()

    for rep in range(repetitions):
        starter.record()
        _ = model(inputs)
        ender.record()
        torch.cuda.synchronize()  # ender must complete before elapsed_time is read
        timings[rep] = starter.elapsed_time(ender)  # milliseconds

avg_time = timings.mean()

print(f'inference time: {avg_time:.2f} ms  {1000/avg_time:.2f} fps')

inference time: 2.73 ms  365.93 fps


In [None]:
# Recorded benchmark results:
# aggtwice: inference time: 2.16 ms  463.39 fps
# aggonce : inference time: 2.76 ms  362.67 fps

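### Combine test

Compare two ways of combining a voxel feature map with a matching feature map (both 64 × 128 × 128) by GPU latency:

- `ReduceAndCombine`: pool to 16 × 16, then apply transformer decoder layers.
- `Combine`: apply Swin transformer blocks at full resolution.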

In [16]:
class ReduceAndCombine(nn.Module):
    """Exchange information between two feature maps at a pooled resolution.

    Both inputs are average-pooled to ``pool_size`` (e.g. 128x128 -> 16x16)
    and flattened to token sequences. Each is then refined by a
    ``TransformerDecoderLayer`` that cross-attends to the other:

    * ``tfV2M``: matching tokens (query) attend to voxel tokens (memory);
    * ``tfM2V``: voxel tokens (query) attend to matching tokens (memory).

    The decoder outputs are added back to the pooled maps.

    NOTE(review): TransformerDecoderLayer already applies its own internal
    residual connections, so the outer addition counts the input twice.
    Confirm that this is intended.

    Args:
        d_model: channel count of both inputs and transformer width (default 64).
        nhead: number of attention heads (default 8).
        pool_size: spatial size after pooling (default (16, 16)).
    """

    def __init__(self, d_model=64, nhead=8, pool_size=(16, 16)):
        super(ReduceAndCombine, self).__init__()

        self.reduceM = nn.AdaptiveAvgPool2d(pool_size)  # e.g. 128x128 -> 16x16
        self.reduceV = nn.AdaptiveAvgPool2d(pool_size)  # e.g. 128x128 -> 16x16

        self.tfM2V = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.tfV2M = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)

    def forward(self, fromVoxel, fromMatching):
        """Cross-update the two maps.

        Args:
            fromVoxel, fromMatching: (B, C, H, W) maps. C must equal
                ``d_model``; that is assumed here, not checked.

        Returns:
            (reducedV, reducedM), each of shape (B, C, pool_h * pool_w).
        """
        # (B, C, h, w) -> (B, C, L)
        reducedM = self.reduceM(fromMatching).flatten(2)
        reducedV = self.reduceV(fromVoxel).flatten(2)

        # batch_first layers take and return (B, L, C) token sequences.
        tokensM = reducedM.transpose(1, 2)
        tokensV = reducedV.transpose(1, 2)

        # Both updates read the pre-update tokens, so the exchange is symmetric.
        # Results are transposed back, not reshaped, so channels stay aligned.
        updateM = self.tfV2M(tokensM, tokensV).transpose(1, 2)
        updateV = self.tfM2V(tokensV, tokensM).transpose(1, 2)

        return reducedV + updateV, reducedM + updateM


from timm.models.swin_transformer import SwinTransformerBlock


class Combine(nn.Module):
    """Refine two full-resolution feature maps with Swin transformer blocks.

    Each input is flattened to (B, H*W, C) tokens and passed through its own
    ``SwinTransformerBlock`` (dim 64, input resolution 128x128, window 8). The
    result is added back to the input.

    NOTE(review): each block sees only its own stream, so despite the
    ``M2V``/``V2M`` names no information flows between the two maps here.
    Confirm that this is intended.
    NOTE(review): the expected input layout of SwinTransformerBlock
    ((B, L, C) vs (B, H, W, C)) and whether it already adds its own residual
    depend on the installed timm version. Verify this.
    """

    def __init__(self):
        super(Combine, self).__init__()

        # Other SwinTransformerBlock keywords (num_heads, head_dim, shift_size,
        # mlp_ratio, qkv_bias, drop, attn_drop, drop_path, act_layer,
        # norm_layer) are left at their defaults.
        self.tfM2V = SwinTransformerBlock(dim=64, input_resolution=(128, 128), window_size=8)
        self.tfV2M = SwinTransformerBlock(dim=64, input_resolution=(128, 128), window_size=8)

    def forward(self, fromVoxel, fromMatching):
        """Refine both maps.

        Args:
            fromVoxel, fromMatching: (B, C, H, W) maps.

        Returns:
            (reducedV, reducedM), each of shape (B, C, H*W).
        """
        # (B, C, H, W) -> (B, C, H*W)
        reducedM = fromMatching.reshape(fromMatching.shape[0], fromMatching.shape[1], -1)
        reducedV = fromVoxel.reshape(fromVoxel.shape[0], fromVoxel.shape[1], -1)

        # Each stream goes through its own block, so the two maps do not share weights.
        reducedM = reducedM + self.tfM2V(reducedM.permute(0, 2, 1)).permute(0, 2, 1)
        reducedV = reducedV + self.tfV2M(reducedV.permute(0, 2, 1)).permute(0, 2, 1)

        return reducedV, reducedM

In [17]:
# One random 64-channel 128x128 map, used below as both the voxel and the matching input.
device = torch.device("cuda")
inputs = torch.randn((1, 64, 128, 128)).to(device)

# Instantiate both fusion variants on the GPU (RAC first, then C).
RAC, C = ReduceAndCombine().to(device), Combine().to(device)

In [7]:
with torch.no_grad():  # smoke run only; skip building an autograd graph
    out1, out2 = RAC(inputs, inputs)

In [14]:
with torch.no_grad():  # smoke run only; skip building an autograd graph
    out1, out2 = C(inputs, inputs)

torch.Size([1, 16384, 64])


In [18]:
# Time Combine's forward pass on the GPU with CUDA events.
# Set `model = RAC` to time ReduceAndCombine instead.
model = C.eval()  # inference mode, so any dropout layers are disabled

starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 1000
warmup_iters = 10
timings = np.zeros(repetitions)

with torch.no_grad():
    # GPU warm-up in the same no-grad mode as the measured runs.
    for _ in range(warmup_iters):
        out1, out2 = model(inputs, inputs)
    torch.cuda.synchronize()

    for rep in range(repetitions):
        starter.record()
        out1, out2 = model(inputs, inputs)
        ender.record()
        torch.cuda.synchronize()  # ender must complete before elapsed_time is read
        timings[rep] = starter.elapsed_time(ender)  # milliseconds

avg_time = timings.mean()

print(f'inference time: {avg_time:.2f} ms  {1000/avg_time:.2f} fps')

inference time: 2.35 ms  424.68 fps


In [None]:
# Recorded benchmark results:
# ReduceAndCombine: inference time: 2.37 ms  422.65 fps
# Combine:          inference time: 2.35 ms  424.68 fps