In [1]:
from typing import Dict
import torch
import torch.nn as nn
from collections import OrderedDict

import math

class HardAttentionLayer(nn.Module):
    def __init__(self, 
            hidden_size : int,
            attention_size : int, 
            N : int = 1, # number of elements to select
            temperature : float = 1.0,
            encoding : nn.Module = None
        ):
        super(HardAttentionLayer, self).__init__()
        
        self.temperature = temperature
        
        self.pe = PositionalEncoding(hidden_size, 100)
        
        self.N = N

        self.Q = nn.Linear(hidden_size, attention_size * N, bias = False)
        self.K = nn.Linear(hidden_size, attention_size * N, bias = False)

    def forward(self, x):
        batch_size, sequence_length, hidden_size = x.size()

        # encode the sequence with positional encoding
        pos_emb = self.pe(x)

        # calculate the query and key
        Q = self.Q(pos_emb)
        K = self.K(pos_emb)
        
        Q = Q.reshape( batch_size, sequence_length, self.N, -1 ).transpose(1, 2)
        K = K.reshape( batch_size, sequence_length, self.N, -1 ).transpose(1, 2)
        
        attention = torch.einsum( "bnsh,bnth -> bnst", Q, K ) / math.sqrt( hidden_size )
        attention = torch.sum(attention, dim=-1) / sequence_length

        # attention shape : (batch_size * N, sequence_length)
        logits = attention.reshape( batch_size * self.N, sequence_length )                
        # apply the Gumbel-Softmax trick to select the N most important elements
        alphas = torch.nn.functional.gumbel_softmax(logits, tau=self.temperature, hard=True)
        alphas = alphas.reshape( batch_size, self.N, sequence_length )
        
        # select N elements from the sequence x using alphas
        x = torch.einsum( "bns, bsh -> bnh", alphas, x )
        
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, hidden_size, max_len=5000):
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, hidden_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        # shape x : (batch_size, seq_len, hidden_size)
        x = x + self.pe[:, : x.size(1)]
        return x


In [53]:
class EpochEncoder( nn.Module ):
    def __init__( self, 
        hidden_size : int,
        attention_size : int,
        N : int = 1, # number of elements to select
        temperature : float = 1.0
        ):
        super(EpochEncoder, self).__init__()
        self.hidden_size = hidden_size
        
        # we need to extract frequency-related features from the signal
        self.conv1 = nn.Sequential( OrderedDict([
            ("conv1", nn.Conv1d(N, 32, 5)),
            ("relu1", nn.ReLU()),
            ("maxpool1", nn.MaxPool1d(5)),
            ("conv2", nn.Conv1d(32, 64, 5)),
            ("relu2", nn.ReLU()),
            ("maxpool2", nn.MaxPool1d(5)),
            ("conv3", nn.Conv1d(64, 128, 5)),
            ("relu3", nn.ReLU()),
            ("maxpool3", nn.MaxPool1d(5)),
            ("conv4", nn.Conv1d(128, 256, 3)),
            ("relu4", nn.ReLU()),
            ("flatten", nn.Flatten())
        ]))
        
        self.out_size = self.conv1(torch.randn(1, N,  hidden_size)).shape[1]
        
        self.sampler = HardAttentionLayer(hidden_size, attention_size, N, temperature)
        
    def forward( self, x ):
        # x shape : (batch_size, seq_len, n_chan, n_samp)
        batch_size, seq_len, n_chan, n_samp = x.size()

        assert n_samp % self.hidden_size == 0, "Hidden size must be a divisor of the number of samples"

        x = x.reshape( batch_size * seq_len, n_chan*(n_samp//self.hidden_size), -1 )
        print( x.shape )
        x = self.sampler( x ) # shape : (batch_size * seq_len, N, hidden_size)
        
        print( x.shape )
        x = self.conv1( x ) # shape : (batch_size * seq_len, out_size)
        print( x.shape )
        return x

In [54]:
layer = EpochEncoder( 
    hidden_size = 3000 // 5,
    attention_size = 128,
    )

print( layer )

x = torch.randn( 32, 21, 3, 3000 )

y = layer(x)



EpochEncoder(
  (conv1): Sequential(
    (conv1): Conv1d(1, 32, kernel_size=(5,), stride=(1,))
    (relu1): ReLU()
    (maxpool1): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv1d(32, 64, kernel_size=(5,), stride=(1,))
    (relu2): ReLU()
    (maxpool2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv1d(64, 128, kernel_size=(5,), stride=(1,))
    (relu3): ReLU()
    (maxpool3): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv4): Conv1d(128, 256, kernel_size=(3,), stride=(1,))
    (relu4): ReLU()
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (sampler): HardAttentionLayer(
    (pe): PositionalEncoding()
    (Q): Linear(in_features=600, out_features=128, bias=False)
    (K): Linear(in_features=600, out_features=128, bias=False)
  )
)
torch.Size([672, 15, 600])
torch.Size([672, 1, 600])
torch.Size([672, 256])
