In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from basic_blocks import SetBlock, BasicConv2d,Non_local,Feed_Forward,Attention,SetBlock_feature,Attention_ori,Feed_Forward_cdim,Feed_Forward_ori

def gem(x, p=6.5, eps=1e-6):
    # print('x-',x.shape)
    # print('xpow-',x.clamp(min=eps).pow(p).shape)
    # print(F.avg_pool2d(x.clamp(min=eps).pow(p), (1, x.size(-1))).shape)
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (1, x.size(-1))).pow(1./p)

class GeM(nn.Module):

    def __init__(self, p=6.5, eps=1e-6):
        super(GeM,self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        # print('p-',self.p)
        return gem(x, p=self.p, eps=self.eps)
        
    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'


class BasicConv3d(nn.Module):
    def __init__(self, inplanes, planes, dilation=1, bias=False, **kwargs):
        super(BasicConv3d, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 3, 3), bias=bias, dilation=(dilation, 1, 1), padding=(dilation, 1, 1))

    def forward(self, x):
        out = self.conv1(x.permute(0,2,1,3,4).contiguous())
        out = F.leaky_relu(out, inplace=True)
        return out.permute(0,2,1,3,4).contiguous()


class mix_mlp(nn.Module):
    def __init__(self, inplanes,hidden,planes, thw, norm1=False,norm2=False,short_cut=True,**kwargs):
        super(mix_mlp, self).__init__()



        self.fc1=Feed_Forward_ori(thw,thw*2)
        self.fc2=Feed_Forward_cdim(inplanes,hidden,planes)
        
        self.need_norm1=norm1
        self.need_norm2=norm2
        self.short_cut=short_cut

        if norm1:
            self.norm1=nn.LayerNorm(thw)
        if norm2:
            self.norm2=nn.LayerNorm(inplanes)
        
        if short_cut:
            self.downsample=nn.Linear(inplanes,planes)
        
    def forward(self, x):

        #n*-1 thw c
        if self.short_cut:
            x_ori=x.clone()

        if self.need_norm1:
            x=self.norm1(x)
        x=x.permute(0,2,1).contiguous()
        x=self.fc1(x) # n*-1 c thw
        x=x.permute(0,2,1).contiguous() # n*-1 thw c
        if self.short_cut:
            x=x+x_ori
            x_ori=self.downsample(x)
        if self.need_norm2:
            x=self.norm2(x)
        x=self.fc2(x) # n*-1 thw c2
        if self.short_cut:
            x=x+x_ori
        return x






def split_patch(x,patch_size_h, patch_size_w, patch_size_t):
    n,t,h,w,c=x.shape
    x=x.view(n,t//patch_size_t,patch_size_t,h//patch_size_h,patch_size_h,w//patch_size_w,patch_size_w,c)
    # n numt pt numh ph numw pw c
    x=x.permute(0,1,3,5,2,4,6,7).contiguous().view(n,-1,patch_size_t,patch_size_h,patch_size_w,c)
    n,numbin,pt,ph,pw,c=x.shape
    return x.view(n,numbin,-1,c)

def reserve_patch(x,patch_size_h, patch_size_w, patch_size_t, h, w, t):
    #n numt*numh*numw pt*ph*pw c
    n,numbin,ppp,c=x.shape
    x=x.view(n, t//patch_size_t, h//patch_size_h, w//patch_size_w, patch_size_t, patch_size_h, patch_size_w,c)
    x=x.permute(0,1,4,2,5,3,6,7).contiguous()
    x=x.view(n,t,h,w,c)
    return x
    


class LocaltemporalAG(nn.Module):
    def __init__(self, inplanes, planes, dilation=1, bias=False, **kwargs):
        super(LocaltemporalAG, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), stride=(3,1,1), bias=bias,padding=(0, 0, 0))

    def forward(self, x):
        x=x.permute(0,4,1,2,3).contiguous() # b,c,t,h,w
        out1 = self.conv1(x)
        out = F.leaky_relu(out1, inplace=True)
        return out.permute(0,2,3,4,1).contiguous()


class BasicConv3d_p(nn.Module):
    def __init__(self, inplanes, planes, kernel=5, bias=False, p=2, FM=False, **kwargs):
        super(BasicConv3d_p, self).__init__()
        self.p = p
        self.fm = FM
        self.convdl = nn.Conv3d(inplanes, planes, kernel_size=(kernel, kernel, kernel), bias=bias, padding=((kernel-1)//2, (kernel-1)//2, (kernel-1)//2))
        self.convdg = nn.Conv3d(inplanes, planes, kernel_size=(kernel, kernel, kernel), bias=bias, padding=((kernel-1)//2, (kernel-1)//2, (kernel-1)//2))
    def forward(self, x):
        x=x.permute(0,2,1,3,4).contiguous()
        n, c, t, h, w = x.size()
        scale = h//self.p
        # print('p-',x.shape,n, c, t, h, w,'scale-',scale)
        feature = list()
        for i in range(self.p):
            temp = self.convdl(x[:,:,:,i*scale:(i+1)*scale,:])
            # print(temp.shape,i*scale,(i+1)*scale)
            feature.append(temp)

        outl = torch.cat(feature, 3)
        # print('outl-',outl.shape)
        outl = F.leaky_relu(outl, inplace=True)

        outg = self.convdg(x)
        outg = F.leaky_relu(outg, inplace=True)
        # print('outg-',outg.shape)
        if not self.fm:
            # print('1-1')
            out = outg + outl
        else:
            # print('1-2')
            out = torch.cat((outg, outl), dim=3)
        out=out.permute(0,2,1,3,4).contiguous()
        return out


class transview_pure(nn.Module):
    def __init__(self):
        super(transview_pure, self).__init__()
        self.batch_frame = None
        self.Gem=GeM()

        self.patch_size1=4
        self.patch_size1_t=3

        
        self.patch_size2_h=8
        self.patch_size2_w=4
        self.patch_size2_t=1

        self.patch_size3_h_1=4
        self.patch_size3_w_1=2
        self.patch_size3_t_1=5

        self.patch_size3_h_2=2
        self.patch_size3_w_2=11
        self.patch_size3_t_2=1



        self.bin_num=16
        self.bin_num2=32
        self.bin_num2_w=2
        self.bin_num_t=2
        _set_in_channels = 1
        _set_channels = [32, 64, 128,256]
        #self.set_layer1 = SetBlock(BasicConv2d(_set_in_channels, _set_channels[0], 5, padding=2))
        #self.set_layer2 = SetBlock(BasicConv2d(_set_channels[0], _set_channels[0], 3, padding=1), True)
        self.mlp_layer1=SetBlock_feature(mix_mlp(_set_in_channels,_set_channels[0],_set_channels[0],4*4*3,False,False,False))
        self.mlp_layer2=SetBlock_feature(mix_mlp(_set_channels[0],_set_channels[1],_set_channels[1],8*4*1,True,True,True))
        


        self.tp1=LocaltemporalAG(_set_channels[0],_set_channels[0])


        self.mlp_layer3_1=SetBlock_feature(mix_mlp(_set_channels[1],_set_channels[2],_set_channels[1],4*2*5,True,True,True))
        self.mlp_layer3_2=SetBlock_feature(mix_mlp(_set_channels[1],_set_channels[2],_set_channels[1],2*11*1,True,True,True))

        #self.c3d3=BasicConv3d_p(_set_channels[1], _set_channels[1],FM=True)
        #self.non_layer1 = SetBlock(Non_local( _set_channels[1],8))
        #self.non_layer2 = SetBlock(Non_local( _set_channels[1],8))

        self.non_layer1 = SetBlock(Attention( _set_channels[1],heads=2,dim_head=_set_channels[1]//4,dropout=0.1))
        self.non_layer2 = SetBlock(Attention( _set_channels[1],heads=2,dim_head=_set_channels[1]//4,dropout=0.1))


        self.fead_forward_layer1=SetBlock(Feed_Forward(_set_channels[1],_set_channels[2],_set_channels[2]))
        #self.feedforward_layer1=nn.Sequential(nn.Conv2d(32,32,kernel_size=(1,1),stride=1,padding=0,bias=True))
        #self.non_layer3 = SetBlock(Non_local(_set_channels[2],8))
        #self.non_layer4 = SetBlock(Non_local(_set_channels[2],8))


        self.non_layer3 = SetBlock(Attention( _set_channels[2],heads=2,dim_head=_set_channels[2]//8,dropout=0.1))
        self.non_layer4 = SetBlock(Attention( _set_channels[2],heads=2,dim_head=_set_channels[2]//8,dropout=0.1))


        self.fead_forward_layer2=SetBlock(Feed_Forward(_set_channels[2],_set_channels[3],_set_channels[3]))
        self.pool2d = nn.MaxPool2d(2)
        
        self.non_layer5 = Attention_ori( _set_channels[3],heads=2,dim_head=_set_channels[3]//4,dropout=0.1)

        self.pos_embedding_patch = nn.Parameter(torch.randn(1,1,64,32*2,22))
        self.pos_embedding_32 = nn.Parameter(torch.randn(1,1,64,32*2,22))
        self.pos_embedding_64 = nn.Parameter(torch.randn(1,1,128,32*2,22))


        self.layer_bin=8
        self.bin_numgl = [32,64]
        
        self.fc_bin = nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.zeros(sum(self.bin_numgl), _set_channels[3], _set_channels[3])))
                
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.normal(m.weight.data, 1.0, 0.02)
                nn.init.constant(m.bias.data, 0.0)

    def frame_max(self, x):
        if self.batch_frame is None:
            return torch.max(x, 1)
        else:
            _tmp = [
                torch.max(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
                ]
            max_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_max_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return max_list, arg_max_list

    def frame_median(self, x):
        if self.batch_frame is None:
            return torch.median(x, 1)
        else:
            _tmp = [
                torch.median(x[:, self.batch_frame[i]:self.batch_frame[i + 1], :, :, :], 1)
                for i in range(len(self.batch_frame) - 1)
                ]
            median_list = torch.cat([_tmp[i][0] for i in range(len(_tmp))], 0)
            arg_median_list = torch.cat([_tmp[i][1] for i in range(len(_tmp))], 0)
            return median_list, arg_median_list

    


    def forward(self, silho, batch_frame=None):

        n = silho.size(0)
        t = silho.size(1)
        x=silho
        n,t,c,h,w=x.shape
        #x = silho.unsqueeze(1)
        del silho
        if t == 1:
            x = x.repeat(1, 6, 1, 1, 1)
        elif t == 2 or t==3:
            x = x.repeat(1, 3, 1, 1, 1)
        #elif t == 3:
        #    x = torch.cat((x,x[:,0:3,:,:,:]),dim=1)
        elif t<6:
            x=x.unsqueeze(1)
            x=x.expand(-1,2,-1,-1,-1,-1).contiguous()
            #print(x.shape)
            x=x.view(n,-1,c,h,w)
        x=x[:,:x.shape[1]-x.shape[1]%6]
        
        
        #print(x.shape)
        #x = self.set_layer1(x)
        # n t c h w
        x=x.permute(0,1,3,4,2).contiguous() #n t h w c
        b,t,h,w,c=x.shape
        x=split_patch(x,self.patch_size1,self.patch_size1,self.patch_size1_t) #n -1 pt*ph*pw c
        x=self.mlp_layer1(x)
        x=reserve_patch(x,self.patch_size1,self.patch_size1,self.patch_size1_t,h,w,t) #n t h w c

        x = self.tp1(x) # n,t/3,32,64,44
        
        b,t,h,w,c=x.shape
        print(x.shape)
        x=split_patch(x,self.patch_size2_h,self.patch_size2_w,self.patch_size2_t) # n -1 pt*ph*pw c
        x=self.mlp_layer2(x)
        x=reserve_patch(x,self.patch_size2_h,self.patch_size2_w,self.patch_size2_t,h,w,t) #n t/3 64 64 44



        b,t,h,w,c=x.shape

        x=x.permute(0,1,4,2,3).contiguous()
        x = x.view(-1,c,h,w)
        x = self.pool2d(x)

        x = x.view(b,t,c,h//2,w//2)
        x=x.permute(0,1,3,4,2).contiguous()

        b,t,h,w,c=x.shape

        x1=x.clone()
        x1=split_patch(x1,self.patch_size3_h_1,self.patch_size3_w_1,self.patch_size3_t_1)
        x1=self.mlp_layer3_1(x1)
        x1=reserve_patch(x1,self.patch_size3_h_1,self.patch_size3_w_1,self.patch_size3_t_1,h,w,t)

        x2=x.clone()
        x2=split_patch(x2,self.patch_size3_h_2,self.patch_size3_w_2,self.patch_size3_t_2)
        x2=self.mlp_layer3_2(x2)
        x2=reserve_patch(X2,self.patch_size3_h_2,self.patch_size3_w_2,self.patch_size3_t_2,h,w,t)





        x=torch.cat([x1,x2],2)
        x=x.permute(0,1,4,2,3).contiguous()
        
        #print(x.shape)
        x=x+self.pos_embedding_32
        n,t,c,h,w=x.shape


        win_size=h//self.bin_num
        x=x.view(n,t,c,self.bin_num,win_size,w).permute(0,1,3,2,4,5)\
            .contiguous().view(n,t*self.bin_num,c,win_size,w) # n tp c h/p w   

        x = self.non_layer1(x) # 
        #x = self.non_layer2(x) #
        x = self.fead_forward_layer1(x)
        x = x.view(n,t,self.bin_num,c*2,win_size,w).permute(0,1,3,2,4,5).contiguous().view(n,t,c*2,h,w)


        x=x+self.pos_embedding_64
        n,t,c,h,w=x.shape
        win_size_h=h//self.bin_num2
        win_size_w=w//self.bin_num2_w
        t_size=t//self.bin_num_t

        x_2=x.view(n,self.bin_num_t,t_size,c,self.bin_num2,win_size_h,self.bin_num2_w,win_size_w).permute(0,1,4,6,3,2,5,7)\
            .contiguous().view(n,self.bin_num_t*self.bin_num2*self.bin_num2_w,c,t_size*win_size_h,win_size_w)

        #x_2=x.view(n,t,c,self.bin_num,win_size,w).permute(0,1,3,2,4,5)\
        #    .contiguous().view(n,t*self.bin_num,c,win_size,w) # n tp c h/p w   

        x_2=self.non_layer3(x_2)
        x_2=self.non_layer4(x_2)

        x_2 = self.fead_forward_layer2(x_2) 
        x_2 = x_2.view(n,self.bin_num_t,self.bin_num2,self.bin_num2_w,c*2,t_size,win_size_h,win_size_w).permute(0,1,5,4,2,6,3,7).contiguous()
        x_2 = x_2.view(n,t,c*2,h,w)
        #x_2 = x_2.view(n,t,self.bin_num,c*2,win_size,w).permute(0,1,3,2,4,5).contiguous().view(n,t,c*2,h,w)
        # n t 128 64 22

        #x_2=torch.max(x_2,1)[0] # n 128 64 22
        #x_2=torch.cat(x_2.mean(1),x.max(1)[0])
        x_2=x_2.mean(1)

        _, c2d, _, _ = x_2.size()

        feature = list()
        for num_bin in self.bin_numgl:
            z = x_2.view(n, c2d, num_bin, -1).contiguous()
            # z1 = z.mean(3) + z.max(3)[0]
            # print('z1-',z1.shape)
            z2 = self.Gem(z).squeeze(-1)
            # print('z2-',z2.shape)
            feature.append(z2)
        feature = torch.cat(feature, 2).permute(2, 0, 1).contiguous()
        #print('feature',feature.shape)
        feature = feature.matmul(self.fc_bin) # 96 n 256
        #feature = feature.permute(1, 2, 0).contiguous()
        #print('feature',feature.shape)

        #feature = self.non_layer5(feature)
        feature = feature.permute(1, 2, 0).contiguous()
        return feature

In [2]:
net=transview_pure()



In [3]:
x=torch.randn(1,30,1,64,44)
out=net(x)

torch.Size([1, 10, 64, 44, 32])


RuntimeError: Given normalized_shape=[64], expected input with shape [*, 64], but got input of size[880, 32, 32]