<a href="https://colab.research.google.com/github/maxmatical/pytorch-projects/blob/master/self_attention_inception_time.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import fastai
from fastai.vision import *
# from models import *


In [0]:

act_fn = nn.ReLU(inplace=True)  # module-level default activation (in-place ReLU)


def conv(ni, nf, ks=3, stride=1, bias=False):
    """Build an `nn.Conv1d` going from `ni` to `nf` channels.

    Padding is `ks // 2`, so an odd kernel with stride 1 keeps the sequence length.
    """
    half_kernel = ks // 2
    return nn.Conv1d(
        in_channels=ni,
        out_channels=nf,
        kernel_size=ks,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )



class AdaptiveConcatPool1d(nn.Module):
    "Concatenates the outputs of `AdaptiveMaxPool1d` and `AdaptiveAvgPool1d` along dim 1."
    def __init__(self, sz:Optional[int]=None):
        "`sz` is the pooled length handed to both pools; a falsy `sz` (e.g. `None`) means 1."
        super().__init__()
        self.output_size = sz if sz else 1
        self.ap = nn.AdaptiveAvgPool1d(self.output_size)
        self.mp = nn.AdaptiveMaxPool1d(self.output_size)

    def forward(self, x):
        # Max-pooled features come first, average-pooled ones second.
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)


class Shortcut(Module):
    "Merge a shortcut with the result of the module by adding them. Adds Conv, BN and ReLU"
    # NOTE(review): `Module` is the name provided by `from fastai.vision import *`;
    # presumably it removes the need for an explicit `super().__init__()` - confirm.
    def __init__(self, ni, nf, act_fn=act_fn):
        """
        ni:     channels of the shortcut input (`x.orig` in `forward`).
        nf:     channels after the 1x1 projection; must match the main-path `x`.
        act_fn: activation applied to the sum (defaults to the module-level `act_fn`).
        """
        self.act_fn = act_fn
        self.conv = conv(ni, nf, 1)
        self.bn = nn.BatchNorm1d(nf)

    def forward(self, x):
        # `x.orig` must be attached by the enclosing container (fastai's
        # `SequentialEx` does this); it is projected by conv + BN and then added.
        # Use the activation stored on the instance so the `act_fn` constructor
        # argument is honoured instead of always using the module-level default.
        return self.act_fn(x + self.bn(self.conv(x.orig)))

class InceptionModule(Module):
    "An inception module for TimeSeries, based on https://arxiv.org/pdf/1611.06455.pdf"
    def __init__(self, ni, nb_filters=32, kss=[39, 19, 9], bottleneck_size=32, stride=1):
        """
        ni:              input channels.
        nb_filters:      output channels of every branch.
        kss:             kernel size(s) of the convolution branches (only read, never mutated).
        bottleneck_size: channels of the 1x1 bottleneck; it is skipped when this is 0 or `ni` is 1.
        stride:          stride of the bottleneck conv and of the max-pool branch.

        The output has `(len(kss)+1)*nb_filters` channels.
        """
        kss = listify(kss)  # normalise once so a single int also works for the BN width
        # One flag drives both the bottleneck and the branch input width, so the
        # two can never disagree (previously `>0` vs `>1` broke bottleneck_size=1).
        use_bottleneck = bottleneck_size > 0 and ni > 1
        self.bottleneck = conv(ni, bottleneck_size, 1, stride) if use_bottleneck else noop
        branch_in = bottleneck_size if use_bottleneck else ni
        self.convs = nn.ModuleList([conv(branch_in, nb_filters, ks) for ks in kss])
        # Max-pool branch works on the raw input, hence `ni` channels.
        self.conv_bottle = nn.Sequential(nn.MaxPool1d(3, stride, padding=1), conv(ni, nb_filters, 1))
        self.bn_relu = nn.Sequential(nn.BatchNorm1d((len(kss)+1)*nb_filters), nn.ReLU())

    def forward(self, x):
        bottled = self.bottleneck(x)
        branches = [c(bottled) for c in self.convs]
        branches.append(self.conv_bottle(x))
        # Concatenate all branches on the channel axis, then BN + ReLU.
        return self.bn_relu(torch.cat(branches, dim=1))

def create_inception(ni, nout, kss=[39, 19, 9], depth=6, bottleneck_size=32, nb_filters=32, head=True):
    """Inception time architecture

    ni:              input channels of the first module.
    nout:            output size of the final linear layer (used only when `head` is truthy).
    kss:             kernel sizes passed to every `InceptionModule`.
    depth:           number of stacked inception modules.
    bottleneck_size: bottleneck width passed to every module.
    nb_filters:      filters per branch; each module outputs `(len(kss)+1)*nb_filters` channels.
    head:            when truthy, append pooling + flatten + linear layers.
    """
    n_ks = len(kss) + 1
    width = n_ks * nb_filters  # channels produced by every inception module
    layers = []
    for d in range(depth):
        # `nb_filters` must be forwarded: the widths computed here assume it,
        # and without it the module silently falls back to its default of 32.
        im = SequentialEx(InceptionModule(ni if d == 0 else width, nb_filters=nb_filters,
                                          kss=kss, bottleneck_size=bottleneck_size))
        # Every third module gets a residual shortcut from its own input.
        if d % 3 == 2: im.append(Shortcut(width, width))
        layers.append(im)
    # The concat pool doubles the channel count, hence `2*width` inputs to the linear layer.
    head = [AdaptiveConcatPool1d(), Flatten(), nn.Linear(2 * width, nout)] if head else []
    return nn.Sequential(*layers, *head)

Tests (output-shape sanity checks)

In [25]:
# Random input of shape (32, 2, 1024), fed to each layer below.
tmp = torch.randn((32, 2, 1024))

# conv
# 2 -> 32 channels, kernel size 3, stride 1.
a = conv(2, 32, 3, 1) 
print(a(tmp).shape)


# Single inception module with the bottleneck disabled (bottleneck_size = 0).
b = InceptionModule(2, 32, kss=[3, 7, 9], bottleneck_size = 0)
print(b(tmp).shape)

# Full network with default settings: 2 input channels, 10 outputs.
c = create_inception(2, 10)
print(c(tmp).shape)


torch.Size([32, 32, 1024])
torch.Size([32, 128, 1024])
torch.Size([32, 10])


Adding self-attention

In [0]:
# conv1d layer (used for self attention layer):
def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False):
    "Build a Kaiming-normal initialised `nn.Conv1d` and wrap it in spectral normalization."
    layer = nn.Conv1d(ni, no, kernel_size=ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if bias:
        # The bias, when present, starts at zero.
        layer.bias.data.zero_()
    normed = nn.utils.spectral_norm(layer)
    return normed


class SelfAttention(nn.Module):
    "Self attention layer for after convolutions."
    def __init__(self, n_channels):
        super().__init__()
        reduced = n_channels // 8  # query/key use an 8x narrower projection
        self.query = conv1d(n_channels, reduced)
        self.key = conv1d(n_channels, reduced)
        self.value = conv1d(n_channels, n_channels)
        # gamma starts at 0, so the attention term contributes nothing at initialisation.
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def forward(self, x):
        # Notation from https://arxiv.org/pdf/1805.08318.pdf
        orig_shape = x.size()
        # Collapse everything after the first two dims into one axis.
        flat = x.view(*orig_shape[:2], -1)
        f = self.query(flat)
        g = self.key(flat)
        h = self.value(flat)
        scores = torch.bmm(f.permute(0, 2, 1).contiguous(), g)
        beta = F.softmax(scores, dim=1)
        attended = torch.bmm(h, beta)
        # Learned residual mix of the attention output and the input.
        out = self.gamma * attended + flat
        return out.view(*orig_shape).contiguous()



class InceptionModule(Module):
    "An inception module for TimeSeries, based on https://arxiv.org/pdf/1611.06455.pdf"
    def __init__(self, ni, nb_filters=32, kss=[39, 19, 9], bottleneck_size=32, stride=1, self_attention = False):
        """
        ni:              input channels.
        nb_filters:      output channels of every branch.
        kss:             kernel size(s) of the convolution branches (only read, never mutated).
        bottleneck_size: channels of the 1x1 bottleneck; it is skipped when this is 0 or `ni` is 1.
        stride:          stride of the bottleneck conv and of the max-pool branch.
        self_attention:  when True, `forward` applies `SelfAttention` to the module output.

        The output has `(len(kss)+1)*nb_filters` channels.
        """
        kss = listify(kss)  # normalise once so a single int also works below
        self.nb_filters = nb_filters
        self.self_attention = self_attention
        # One flag drives both the bottleneck and the branch input width, so the
        # two can never disagree (previously `>0` vs `>1` broke bottleneck_size=1).
        use_bottleneck = bottleneck_size > 0 and ni > 1
        self.bottleneck = conv(ni, bottleneck_size, 1, stride) if use_bottleneck else noop
        branch_in = bottleneck_size if use_bottleneck else ni
        self.convs = nn.ModuleList([conv(branch_in, nb_filters, ks) for ks in kss])
        # Max-pool branch works on the raw input, hence `ni` channels.
        self.conv_bottle = nn.Sequential(nn.MaxPool1d(3, stride, padding=1), conv(ni, nb_filters, 1))
        n_out = (len(kss) + 1) * nb_filters
        self.bn_relu = nn.Sequential(nn.BatchNorm1d(n_out), nn.ReLU())
        # Attention width follows the real output width rather than a hard-coded
        # `nb_filters*4`, which only matched the case of exactly three kernel sizes.
        # The layer is still always constructed (even when unused) so the set of
        # parameters does not depend on `self_attention`.
        self.attn = SelfAttention(n_out)

    def forward(self, x):
        bottled = self.bottleneck(x)
        branches = [c(bottled) for c in self.convs]
        branches.append(self.conv_bottle(x))
        # Concatenate all branches on the channel axis, then BN + ReLU.
        out = self.bn_relu(torch.cat(branches, dim=1))
        if self.self_attention:
            out = self.attn(out)
        return out


def create_inception(ni, nout, kss=[39, 19, 9], depth=6, bottleneck_size=32, nb_filters=32, head=True, self_attention = False):
    """Inception time architecture

    ni:              input channels of the first module.
    nout:            output size of the final linear layer (used only when `head` is truthy).
    kss:             kernel sizes passed to every `InceptionModule`.
    depth:           number of stacked inception modules.
    bottleneck_size: bottleneck width passed to every module.
    nb_filters:      filters per branch; each module outputs `(len(kss)+1)*nb_filters` channels.
    head:            when truthy, append pooling + flatten + linear layers.
    self_attention:  passed to every `InceptionModule`.
    """
    n_ks = len(kss) + 1
    width = n_ks * nb_filters  # channels produced by every inception module
    layers = []
    for d in range(depth):
        # `nb_filters` must be forwarded: the widths computed here assume it,
        # and without it the module silently falls back to its default of 32.
        im = SequentialEx(InceptionModule(ni if d == 0 else width, nb_filters=nb_filters, kss=kss,
                                          bottleneck_size=bottleneck_size, self_attention=self_attention))
        # Every third module gets a residual shortcut from its own input.
        if d % 3 == 2: im.append(Shortcut(width, width))
        layers.append(im)
    # The concat pool doubles the channel count, hence `2*width` inputs to the linear layer.
    head = [AdaptiveConcatPool1d(), Flatten(), nn.Linear(2 * width, nout)] if head else []
    return nn.Sequential(*layers, *head)


In [31]:
# Random input of shape (32, 2, 1024), fed to each layer below.
tmp = torch.randn((32, 2, 1024))

# conv
# 2 -> 32 channels, kernel size 3, stride 1.
a = conv(2, 32, 3, 1) 
print(a(tmp).shape)


# Single inception module with default kernel sizes and self-attention enabled.
b = InceptionModule(2, self_attention = True)
print(b(tmp).shape)

# Full network with self-attention in every module: 2 input channels, 10 outputs.
c = create_inception(2, 10, self_attention = True)
print(c(tmp).shape)


torch.Size([32, 32, 1024])
torch.Size([32, 128, 1024])
torch.Size([32, 10])
