In [1]:
import torch
import torch.nn as nn
import os,sys
sys.path.append("../src")
from modules import *

  from .autonotebook import tqdm as notebook_tqdm


In [27]:
class TCRN(nn.Module):
    """Conv1d -> TCN -> GRU -> MLP classifier over a multi-channel time/frequency input.

    The input is expected as ``[B, C, T, F]`` (batch, channels, time, frequency).
    Channels and frequency bins are flattened into one feature axis, encoded
    per time step, passed through a 2-layer GRU and two linear layers, squashed
    by the final activation and finally averaged over time, giving
    ``[B, class_num]``.

    Args:
        c_in: Number of input channels ``C``. Together with ``n_freq`` it sizes
            the first convolution (``c_in * n_freq`` input features).
        class_num: Number of output classes.
        last_activation: ``"Softmax"`` selects a softmax over the class axis;
            ``"Sigmoid"`` or any other value selects an element-wise sigmoid.
        n_freq: Number of frequency bins ``F`` per channel. The default of 257
            with ``c_in=4`` reproduces the previously hard-coded ``257*4``.
        verbose: When True (default, matching the notebook's recorded output)
            ``forward`` prints the tensor shape after every stage. Pass False
            for training/inference loops.

    Note:
        ``TCN`` comes from the project-local ``modules`` package; in this
        notebook's run it maps ``[B, 256, T]`` to ``[B, 256, T]``.
    """

    def __init__(self, c_in, class_num=10, last_activation="Sigmoid", n_freq=257, verbose=True):
        super().__init__()

        self.class_num = class_num
        self.verbose = verbose

        # Channels and frequency bins are flattened together in forward(), so
        # the encoder input width is their product (was hard-coded to 257*4,
        # which silently ignored c_in).
        self.enc_1 = nn.Conv1d(c_in * n_freq, 256, 1)

        self.enc_2 = TCN(
            c_in=256,
            c_out=256,
            TCN_activation="PReLU"
        )

        # Inter-layer dropout is only active in train() mode.
        self.rnn_1 = nn.GRU(
            input_size=256,
            hidden_size=256,
            num_layers=2,
            dropout=0.3,
            bidirectional=False,
            batch_first=True
        )

        self.fc_1 = nn.Sequential(
            nn.Linear(256, 256, bias=True),
            nn.PReLU()
        )
        self.fc_2 = nn.Sequential(
            nn.Linear(256, class_num, bias=True),
        )

        if last_activation == "Softmax":
            # dim must be explicit: the deprecated implicit choice resolves to
            # dim 0 for the 3-D [B, T, class] tensor used here, i.e. it would
            # normalise across the batch instead of across classes.
            self.last_activation = nn.Softmax(dim=-1)
        else:
            # "Sigmoid" and every unrecognised name fall back to a sigmoid.
            self.last_activation = nn.Sigmoid()

    def _trace(self, message):
        """Print ``message`` only when shape tracing is enabled."""
        if self.verbose:
            print(message)

    def forward(self, x):
        """Map a ``[B, C, T, F]`` input to time-averaged class scores ``[B, class_num]``."""
        # [B, C, T, F]
        self._trace(x.shape)

        # [B, C, F, T]
        x = torch.permute(x, (0, 1, 3, 2))

        # [B, C*F, T]
        x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2], x.shape[3]))
        self._trace(x.shape)

        x = self.enc_1(x)
        self._trace("enc_1 : {}".format(x.shape))

        x = self.enc_2(x)
        self._trace("enc_2 : {}".format(x.shape))

        # [B, T, 256] -- the GRU is batch_first
        x = torch.permute(x, (0, 2, 1))
        self._trace(x.shape)

        # nn.GRU returns (output, h_n); keep the per-step outputs only.
        x = self.rnn_1(x)[0]
        self._trace("rnn_1 : {}".format(x.shape))

        x = self.fc_1(x)
        self._trace("fc_1 : {}".format(x.shape))

        x = self.fc_2(x)
        self._trace("fc_2 : {}".format(x.shape))

        azimuth_output = self.last_activation(x)

        # Average the per-frame scores over the time axis.
        pred = azimuth_output.mean(1)
        return pred
# Smoke test: run one random batch (laid out as [B, C, T, F]) through a freshly
# initialised TCRN and show the shape and values of the result.
m = TCRN(c_in=4)
x = torch.rand(2, 4, 320, 257)
y = m(x)
print("output", y.shape, y, sep="\n")

torch.Size([2, 4, 320, 257])
torch.Size([2, 1028, 320])
enc_1 : torch.Size([2, 256, 320])
enc_2 : torch.Size([2, 256, 320])
torch.Size([2, 320, 256])
rnn_1 : torch.Size([2, 320, 256])
fc_1 : torch.Size([2, 320, 256])
fc_2 : torch.Size([2, 320, 10])
output
torch.Size([2, 10])
tensor([[0.4997, 0.4757, 0.4824, 0.4975, 0.4920, 0.5091, 0.4875, 0.5034, 0.4963,
         0.4884],
        [0.4970, 0.4763, 0.4829, 0.4985, 0.4927, 0.5073, 0.4857, 0.5026, 0.4953,
         0.4871]], grad_fn=<MeanBackward1>)


In [31]:
class TCRNv2(nn.Module):
    """TCN -> TCN -> GRU -> MLP classifier over a multi-channel time/frequency input.

    Variant that feeds the flattened ``C*F`` features straight into two TCN
    stages (no 1x1 bottleneck convolution) and uses a wider GRU. The input is
    expected as ``[B, C, T, F]``; the output is ``[B, class_num]`` after
    averaging the per-frame scores over time.

    Args:
        c_in: Number of input channels ``C``. Together with ``n_freq`` it sizes
            the encoders and the GRU input (``c_in * n_freq`` features).
        class_num: Number of output classes.
        last_activation: ``"Softmax"`` selects a softmax over the class axis;
            ``"Sigmoid"`` or any other value selects an element-wise sigmoid.
        n_freq: Number of frequency bins ``F`` per channel. The default of 257
            with ``c_in=4`` reproduces the previously hard-coded 1028.
        verbose: When True (default, matching the notebook's recorded output)
            ``forward`` prints the tensor shape after every stage. Pass False
            for training/inference loops.

    Note:
        ``TCN`` comes from the project-local ``modules`` package. In this
        notebook's run ``TCN(c_in=1028, c_out=2056)`` returns 1028 channels,
        which is why the GRU input size equals the TCN ``c_in``.
    """

    def __init__(self, c_in, class_num=10, last_activation="Sigmoid", n_freq=257, verbose=True):
        super().__init__()

        self.class_num = class_num
        self.verbose = verbose

        # Width of the flattened channel*frequency axis built in forward()
        # (was hard-coded to 1028, which silently ignored c_in).
        n_feat = c_in * n_freq

        self.enc_1 = TCN(
            c_in=n_feat,
            c_out=n_feat * 2,
            TCN_activation="PReLU"
        )

        self.enc_2 = TCN(
            c_in=n_feat,
            c_out=n_feat * 2,
            TCN_activation="PReLU"
        )

        # Inter-layer dropout is only active in train() mode.
        self.rnn_1 = nn.GRU(
            input_size=n_feat,
            hidden_size=512,
            num_layers=2,
            dropout=0.3,
            bidirectional=False,
            batch_first=True
        )

        self.fc_1 = nn.Sequential(
            nn.Linear(512, 256, bias=True),
            nn.PReLU()
        )
        self.fc_2 = nn.Sequential(
            nn.Linear(256, class_num, bias=True),
        )

        if last_activation == "Softmax":
            # dim must be explicit: the deprecated implicit choice resolves to
            # dim 0 for the 3-D [B, T, class] tensor used here, i.e. it would
            # normalise across the batch instead of across classes.
            self.last_activation = nn.Softmax(dim=-1)
        else:
            # "Sigmoid" and every unrecognised name fall back to a sigmoid.
            self.last_activation = nn.Sigmoid()

    def _trace(self, message):
        """Print ``message`` only when shape tracing is enabled."""
        if self.verbose:
            print(message)

    def forward(self, x):
        """Map a ``[B, C, T, F]`` input to time-averaged class scores ``[B, class_num]``."""
        # [B, C, T, F]
        self._trace(x.shape)

        # [B, C, F, T]
        x = torch.permute(x, (0, 1, 3, 2))

        # [B, C*F, T]
        x = torch.reshape(x, (x.shape[0], x.shape[1] * x.shape[2], x.shape[3]))
        self._trace(x.shape)

        x = self.enc_1(x)
        self._trace("enc_1 : {}".format(x.shape))

        x = self.enc_2(x)
        self._trace("enc_2 : {}".format(x.shape))

        # [B, T, C*F] -- the GRU is batch_first
        x = torch.permute(x, (0, 2, 1))
        self._trace(x.shape)

        # nn.GRU returns (output, h_n); keep the per-step outputs only.
        x = self.rnn_1(x)[0]
        self._trace("rnn_1 : {}".format(x.shape))

        x = self.fc_1(x)
        self._trace("fc_1 : {}".format(x.shape))

        x = self.fc_2(x)
        self._trace("fc_2 : {}".format(x.shape))

        azimuth_output = self.last_activation(x)

        # Average the per-frame scores over the time axis.
        pred = azimuth_output.mean(1)
        return pred
# Smoke test: run one random batch (laid out as [B, C, T, F]) through a freshly
# initialised TCRNv2 and show the shape and values of the result.
m = TCRNv2(c_in=4)
x = torch.rand(2, 4, 320, 257)
y = m(x)
print("output", y.shape, y, sep="\n")

torch.Size([2, 4, 320, 257])
torch.Size([2, 1028, 320])
enc_1 : torch.Size([2, 1028, 320])
enc_2 : torch.Size([2, 1028, 320])
torch.Size([2, 320, 1028])
rnn_1 : torch.Size([2, 320, 512])
fc_1 : torch.Size([2, 320, 256])
fc_2 : torch.Size([2, 320, 10])
output
torch.Size([2, 10])
tensor([[0.5142, 0.5033, 0.5093, 0.5348, 0.4615, 0.4899, 0.5168, 0.4922, 0.4773,
         0.4975],
        [0.5127, 0.5003, 0.5132, 0.5325, 0.4597, 0.4915, 0.5152, 0.4918, 0.4766,
         0.4921]], grad_fn=<MeanBackward1>)
