# In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# In [6]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation style channel gating block.

    Each channel of a 4-D input is averaged over its two trailing
    dimensions, the resulting per-channel vector is passed through a
    two-layer bottleneck (Linear -> ReLU -> Linear -> sigmoid), and the
    input is rescaled channel-wise by the resulting gates.

    Args:
        channel_size: Number of channels (size of dim 1) of the input.
        reduce_ratio: Divisor applied to ``channel_size`` to obtain the
            bottleneck width (``channel_size // reduce_ratio``).

    Raises:
        ValueError: If the bottleneck width would be smaller than 1.
    """

    def __init__(self, channel_size: int, reduce_ratio: int) -> None:
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise the assignment of ``self.fc1`` raises AttributeError.
        super().__init__()

        reduced_channel_size = channel_size // reduce_ratio
        if reduced_channel_size < 1:
            raise ValueError(
                "channel_size // reduce_ratio must be at least 1, got "
                f"{channel_size} // {reduce_ratio} = {reduced_channel_size}"
            )
        self.reduced_channel_size = reduced_channel_size

        self.fc1 = nn.Linear(channel_size, reduced_channel_size)
        self.fc2 = nn.Linear(reduced_channel_size, channel_size)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Rescale ``input_tensor`` channel-wise by learned gates.

        Args:
            input_tensor: 4-D tensor whose dim 1 has ``channel_size``
                entries; the two trailing dims are averaged away to
                build the per-channel descriptor.

        Returns:
            A tensor with the same shape as ``input_tensor``, each channel
            multiplied by its gate value in (0, 1).

        Raises:
            ValueError: If ``input_tensor`` is not 4-dimensional.
        """
        # Unpacking four values keeps the implicit "must be 4-D" check.
        batch_size, channel_size, _, _ = input_tensor.shape

        # Squeeze: reduce over both trailing dims directly. Unlike
        # ``view(...).mean(dim=2)`` this also works on non-contiguous
        # inputs (e.g. permuted or channels_last tensors).
        squeezed_tensor = input_tensor.mean(dim=(2, 3))

        # Excitation: bottleneck MLP producing one gate per channel.
        gates = F.relu(self.fc1(squeezed_tensor))
        gates = torch.sigmoid(self.fc2(gates))

        # Broadcast the (batch, channel) gates over the trailing dims.
        return input_tensor * gates.view(batch_size, channel_size, 1, 1)