In [1]:
import torch
import torch.nn as nn
import os,sys
sys.path.append("../src")
from modules import *

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
class CRFN(nn.Module):
    def __init__(self, c_in, class_num=10, pool_type='avg', pool_size=(2,2), pretrained_path=None,last_activation="Sigmoid"):
        super().__init__()

        self.class_num = class_num
        self.pool_type = pool_type
        self.pool_size = pool_size
        
        self.conv_block0 = ConvBlock(in_channels=c_in, out_channels=64)    # 1: 7, 128     2: 7, 64
        self.gate_block0 = ConvBlock(in_channels=4, out_channels=64)

        
        self.gru = nn.GRU(input_size=64, hidden_size=64, 
            num_layers=3, dropout=0.3, batch_first=True, bidirectional=True)

        self.fc_1 = nn.Sequential(
            nn.Linear(128, 512, bias=True),
            nn.PReLU()
        )
        self.fc_2 = nn.Sequential(
            nn.Linear(512, 128, bias=True),
            nn.PReLU()
        )
        self.fc_3 = nn.Linear(128, class_num, bias=True)

        self.init_weights()

        if last_activation == "Sigmoid" :
            self.last_activation = nn.Sigmoid()
        elif last_activation == "Softmax" : 
            self.last_activation = nn.Softmax()
        else : 
            self.last_activation = nn.Sigmoid()

    def init_weights(self):

        init_gru(self.gru)
        #init_layer(self.azimuth_fc)
        init_layer(self.fc_1)
        init_layer(self.fc_2)
        init_layer(self.fc_3)

    def forward(self, x):
        #pdb.set_trace() 
        '''input: (batch_size, mic_channels, time_steps, mel_bins)'''
        gate = self.gate_block0(x, self.pool_type, pool_size=self.pool_size)
        x = self.conv_block0(x, self.pool_type, pool_size=self.pool_size)
        x = x * torch.sigmoid(gate)
        
        if self.pool_type == 'avg':
            x = torch.mean(x, dim=2)
        elif self.pool_type == 'max':
            (x, _) = torch.max(x, dim=2)
        '''(batch_size, feature_maps, time_steps)'''
        
        x = x.transpose(1,2)
        ''' (batch_size, time_steps, feature_maps):'''

        self.gru.flatten_parameters()
        (x, _) = self.gru(x)
        print(x.shape)
        # 
        x = self.fc_1(x)
        print("fc_0 : {}".format(x.shape))
        x = self.fc_2(x)
        x = self.fc_3(x)
        #azimuth_output = self.azimuth_fc(x)
        '''(batch_size, time_steps, class_num)'''

        # Interpolate
        azimuth_output = self.last_activation(x)
        
        

        pred = azimuth_output.mean(1)
        #prediction = scores.max(-1)[1]
        return pred
m = CRFN(4)
x = torch.rand(2,4,320,257)
y = m(x)
print("output")
print(y.shape)
print(y)

torch.Size([2, 128, 128])
fc_0 : torch.Size([2, 128, 512])
output
torch.Size([2, 10])
tensor([[0.5127, 0.5034, 0.5012, 0.5201, 0.4837, 0.4695, 0.4831, 0.4738, 0.4753,
         0.5272],
        [0.5141, 0.5033, 0.5023, 0.5192, 0.4836, 0.4701, 0.4858, 0.4762, 0.4754,
         0.5257]], grad_fn=<MeanBackward1>)
