# Paper to pytorch examples from Max

In [None]:
import torch
import copy
from collections import OrderedDict
from torchsummary import summary
import torch.nn as nn

## Highwaynet

https://arxiv.org/abs/1505.00387

The concept of a Highway block is that each layer adds a nonlinear transform gate T (e.g. a sigmoid), 
which determines the proportion of the output that comes from the affine transformation H, 
and conversely a carry gate C = 1 - T, which determines how much of the input is passed straight through to the output. 
Note that for this to be valid, the input and output of the highway block must have the same dimensionality.

In [None]:
import torch
import torch.nn as nn


class HighwayBlock(nn.Module):
    """Single highway layer: ``y = H(x) * T(x) + x * (1 - T(x))``.

    ``H`` and ``T`` are both square ``nn.Linear`` maps of width ``input_dim``,
    so the output keeps the trailing dimension of the input. ``T`` is squashed
    through a sigmoid and acts as the transform gate; ``1 - T`` is the carry
    gate.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.H = nn.Linear(input_dim, input_dim)
        self.T = nn.Linear(input_dim, input_dim)
        self.sigmoid = nn.Sigmoid()

        # Start the transform-gate bias at -1 so the layer is initially
        # biased towards carry behaviour (passing the input through).
        self.T.bias.data.fill_(-1.)

    def forward(self, x):
        """Blend the transformed input with the untouched input."""
        # Transform gate in (0, 1); the carry gate is its complement.
        transform_gate = self.sigmoid(self.T(x))
        carry_gate = 1 - transform_gate
        # Affine transformation of the input.
        transformed = self.H(x)
        return transformed * transform_gate + carry_gate * x

In [None]:
# Build a 1024-wide highway block and pass it to torchsummary with input size (1, 1024).
highway = HighwayBlock(1024)
summary(highway,(1,1024))
# Bare expression: as the last line of a notebook cell this displays the module's repr.
highway

## ULMFiT

https://arxiv.org/abs/1801.06146

In [None]:
class ULMFiT(nn.Module):
    """Bidirectional GRU encoder with DropConnect on its hidden-to-hidden weights.

    Args:
        hidden_dim: hidden size of each GRU direction.
        embedding_dim: feature size of the GRU input.
        vocab_size: number of rows of ``embedding`` / outputs of ``lm_head``.
        rnn_layers: number of stacked GRU layers.
        weight_drop: probability of zeroing each hidden-to-hidden weight for
            the duration of one forward pass.

    NOTE(review): ``embedding`` and ``lm_head`` are constructed but never used
    by ``forward``; ``forward`` feeds ``x`` straight into the GRU, so ``x`` must
    already be float features of shape ``(batch, seq, embedding_dim)``.
    ``lm_head`` expects ``hidden_dim`` inputs while the bidirectional GRU emits
    ``2 * hidden_dim`` features, so it cannot be applied to ``forward``'s
    output as-is. Left unchanged to keep the module's parameters stable.

    NOTE(review): the weight mask is applied on every forward pass regardless
    of ``self.training`` and the kept weights are not rescaled.
    """

    def __init__(self, hidden_dim=1150, embedding_dim=400, vocab_size=100000, rnn_layers=3, weight_drop=0.5):
        super(ULMFiT, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, num_layers=rnn_layers, batch_first=True, bidirectional=True)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

        self.rnn_layers = rnn_layers
        self.weight_drop = weight_drop
        # Maps parameter name -> the values that were zeroed by drop_connect().
        self.weight_cache = {}

    def forward(self, x):
        """Run the GRU over ``x`` with randomly dropped hidden-to-hidden weights.

        Returns the GRU output sequence (last dim ``2 * hidden_dim``); the
        final hidden state is discarded.
        """
        self.drop_connect()
        try:
            h, _ = self.rnn(x)
        finally:
            # Always put the weights back, even if the RNN call raised.
            self.restore_weights()
        return h

    def drop_connect(self):
        """Randomly sets H-H weights to zero and caches the dropped values."""
        if self.weight_cache:
            # A previous mask was never undone; restore first so the cached
            # values are not overwritten and lost.
            self.restore_weights()
        self.weight_cache = {}
        for name, param in self.rnn.named_parameters():
            # Matches weight_hh_l<k> and weight_hh_l<k>_reverse for every
            # layer exactly once (a per-layer substring test would match
            # 'weight_hh_l1' inside 'weight_hh_l10').
            if 'weight_hh_l' not in name:
                continue
            weights = param.data
            # Build the mask on the parameter's own device.
            keep = torch.rand(weights.shape, device=weights.device) > self.weight_drop
            # Bool tensors do not support `1 - mask`; invert with `~`.
            dropped = ~keep
            self.weight_cache[name] = weights * dropped.to(weights.dtype)
            # In place, so the parameter keeps its existing storage.
            weights.mul_(keep.to(weights.dtype))

    def restore_weights(self):
        """Add the cached dropped values back; a no-op when nothing is cached."""
        for name, param in self.rnn.named_parameters():
            dropped = self.weight_cache.pop(name, None)
            if dropped is not None:
                param.data.add_(dropped)

In [None]:
# Instantiate the model with its default hyper-parameters.
ulmfit=ULMFiT()
# Bare expression: as the last line of a notebook cell this displays the module's repr.
ulmfit

## VGGNet

https://arxiv.org/abs/1409.1556

In [None]:
class VGG16(nn.Module):
    """VGG-style classifier: 16 3x3 convolutions in five pooled stages, then 3 linear layers.

    Every convolution is followed by a ReLU and every stage by a 2x2 max-pool,
    so the spatial size is divided by 32 before the fully connected head. The
    output is log-probabilities over 1000 classes (``LogSoftmax`` over dim 1).

    Args:
        n_conv: accepted for interface compatibility; not used by this class.
        im_size: height/width of the (square) input images the fully
            connected head is sized for.

    NOTE(review): despite the class name, the stage depths here are 2-2-4-4-4
    (16 conv layers + 3 linear layers) - confirm this is the intended variant.
    """

    # Conv layer attribute names grouped by stage; a max-pool follows each stage.
    _STAGES = (
        ('conv1', 'conv2'),
        ('conv3', 'conv4'),
        ('conv5', 'conv6', 'conv7', 'conv8'),
        ('conv9', 'conv10', 'conv11', 'conv12'),
        ('conv13', 'conv14', 'conv15', 'conv16'),
    )

    def __init__(self, n_conv=16, im_size=224):
        super(VGG16, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1)
        self.conv8 = nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1)

        self.conv9 = nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=1)
        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1)
        self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1)
        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1)

        self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1)
        self.conv14 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1)
        self.conv15 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1)
        self.conv16 = nn.Conv2d(512, 512, kernel_size=3, padding=1, stride=1)

        # Five 2x2 pools shrink each spatial side by 32; 512 channels remain.
        self.im_flatten = (im_size // 32) ** 2 * 512
        self.fc1 = nn.Linear(self.im_flatten, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.out = nn.Linear(4096, 1000)

        self.pool = nn.MaxPool2d(2, stride=2)
        self.relu = nn.ReLU()
        # Named "softmax" for historical reasons; it is a LogSoftmax.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, 1000)`` for an image batch.

        Raises:
            RuntimeError: if the input's spatial size does not match the
                ``im_size`` the fully connected head was built for.
        """
        # Conv stack
        for stage in self._STAGES:
            for layer_name in stage:
                x = self.relu(getattr(self, layer_name)(x))
            x = self.pool(x)

        # Keep the batch axis fixed. Flattening with view(-1, im_flatten)
        # would silently fold extra spatial positions into the batch axis
        # when the input is larger than im_size.
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.out(x)
        y = self.softmax(x)
        return y

### My rewrite using the `nn.Sequential` module

I referenced [this tutorial](https://github.com/FrancescoSaverioZuppichini/Pytorch-how-and-when-to-use-Module-Sequential-ModuleList-and-ModuleDict) to rewrite the network using `nn.Sequential`, making it more modular and readable.

In [None]:
def conv_block(in_f, out_f, **kwargs):
    """Return a ``Conv2d -> ReLU`` pair wrapped in an ``nn.Sequential``.

    ``in_f`` / ``out_f`` are the convolution's input / output channel counts;
    any extra keyword arguments are forwarded unchanged to ``nn.Conv2d``.
    """
    convolution = nn.Conv2d(in_f, out_f, **kwargs)
    activation = nn.ReLU()
    return nn.Sequential(convolution, activation)
def conv_stack_block(in_f, out_f, stack_size, **kwargs):
    """Return one pooled stage: ``stack_size`` conv blocks followed by a 2x2 max-pool.

    The first conv block maps ``in_f`` channels to ``out_f``; the remaining
    ``stack_size - 1`` blocks keep ``out_f`` channels. Extra keyword arguments
    are forwarded to every ``conv_block`` call.
    """
    # The channel-changing block is always present, whatever stack_size is.
    layers = [conv_block(in_f, out_f, **kwargs)]
    for _ in range(stack_size - 1):
        layers.append(conv_block(out_f, out_f, **kwargs))
    layers.append(nn.MaxPool2d(2, stride=2))
    return nn.Sequential(*layers)
    

class SequentialVGG16(nn.Module):
    """``nn.Sequential``-based VGG-style classifier.

    Five pooled conv stages (built with ``conv_stack_block``, depths
    2-2-4-4-4) feed a three-layer fully connected head that ends in
    ``LogSoftmax`` over dim 1, i.e. the output is log-probabilities over 1000
    classes.

    Args:
        n_conv: accepted for interface compatibility; not used by this class.
        im_size: height/width of the (square) input images the fully
            connected head is sized for.
    """

    def __init__(self, n_conv=16, im_size=224):
        super(SequentialVGG16, self).__init__()

        self.conv_layers = nn.Sequential(OrderedDict([
            ('convblock1',conv_stack_block(3, 64, stack_size = 2, kernel_size=3, padding=1, stride=1 )),
            ('convblock2',conv_stack_block(64, 128, stack_size = 2, kernel_size=3, padding=1, stride=1 )),
            ('convblock3',conv_stack_block(128, 256, stack_size = 4, kernel_size=3, padding=1, stride=1 )),
            ('convblock4',conv_stack_block(256, 512, stack_size = 4, kernel_size=3, padding=1, stride=1 )),
            ('convblock5',conv_stack_block(512, 512, stack_size = 4, kernel_size=3, padding=1, stride=1 )),
        ]))

        # Five 2x2 pools shrink each spatial side by 32; 512 channels remain.
        self.im_flatten = (im_size // 32) ** 2 * 512

        self.fc_output = nn.Sequential(
            nn.Linear(self.im_flatten, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, 1000),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, 1000)`` for an image batch.

        Raises:
            RuntimeError: if the input's spatial size does not match the
                ``im_size`` the fully connected head was built for.
        """
        x = self.conv_layers(x)
        # Keep the batch axis fixed. Flattening with view(-1, im_flatten)
        # would silently fold extra spatial positions into the batch axis
        # when the input is larger than im_size.
        x = x.view(x.size(0), -1)
        y = self.fc_output(x)
        return y
        
    

In [None]:
# Instantiate both implementations with their default arguments.
vgg_max=VGG16()
vgg_seq=SequentialVGG16()

In [None]:
# Pass the Sequential rewrite to torchsummary with input size (3, 224, 224).
summary(vgg_seq, input_size=(3,224,224))

## Wavenet

https://arxiv.org/abs/1609.03499

In [None]:
class CausalConvBlock(nn.Module):
    """Gated causal 1-D convolution block with a residual connection.

    Computes ``x + final_cnn(tanh(filter_cnn(x)) * sigmoid(gate_cnn(x)))``
    where the two un-padded convolutions are shifted right (zero-padded on the
    left of the last axis) so that output step ``t`` only depends on input
    steps ``<= t`` and the output has the same length as the input.

    Args:
        channels: number of input and output channels of every convolution.
        kernel: kernel size of the gate and filter convolutions.
        dilation: dilation of the gate and filter convolutions (default 1,
            which reproduces the undilated block).
    """

    def __init__(self, channels=32, kernel=3, dilation=1):
        super(CausalConvBlock, self).__init__()
        self.kernel_size = kernel
        self.dilation = dilation
        self.gate_cnn = nn.Conv1d(channels, channels, kernel_size=kernel, dilation=dilation)
        self.filter_cnn = nn.Conv1d(channels, channels, kernel_size=kernel, dilation=dilation)
        self.final_cnn = nn.Conv1d(channels, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Apply the gated causal convolution and add the residual ``x``.

        The last axis of ``x`` is treated as time; it must be long enough for
        the un-padded convolutions to produce at least one output step.
        """
        # Each un-padded conv shortens the sequence by this many steps.
        left_pad = (self.kernel_size - 1) * self.dilation
        # Shift the conv outputs rightward (pad left side). Padding explicitly,
        # rather than slicing zeros out of the conv output, keeps the length
        # right even when the conv output is shorter than the padding.
        gate = nn.functional.pad(self.gate_cnn(x), (left_pad, 0))
        filt = nn.functional.pad(self.filter_cnn(x), (left_pad, 0))

        z = self.tanh(filt) * self.sigmoid(gate)
        z = self.final_cnn(z)
        # Add residual connection
        return z + x

## Transformer (encoder)

https://arxiv.org/abs/1706.03762

In [None]:

# Device string used by the modules below: 'cuda' when CUDA is available, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned scale and shift.

    Args:
        hidden_size: size of the last axis of the inputs.
        eps: constant added to the variance before taking the square root.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply ``weight`` and ``bias``."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Population (biased) variance: mean of the squared deviations.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalised = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalised + self.bias


class ScaledDotAttention(nn.Module):
    """Dot-product attention: ``dropout(softmax(q @ k^T * scale)) @ v``.

    Args:
        scale: factor the raw attention scores are multiplied by.
        drop_p: dropout probability applied to the attention weights.
    """

    def __init__(self, scale, drop_p=0.1):
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(drop_p)

    def forward(self, q, k , v):
        """Attend from ``q`` over ``k`` and return the weighted sum of ``v``."""
        # Pairwise query/key scores; presumably (BATCH, T, T) - depends on caller.
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # Normalise over the key axis, then drop some attention weights.
        weights = self.dropout(self.softmax(scores))
        # Presumably (BATCH, T, d_attn) - depends on caller.
        return torch.matmul(weights, v)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from per-head bias-free linear projections.

    Each head projects ``x`` to queries, keys and values of width
    ``dmodel // num_heads``, runs ``ScaledDotAttention`` on them, and the head
    outputs are concatenated on the last axis and mixed by ``out``.

    The per-head projections live in ``nn.ModuleList`` containers so that they
    are registered sub-modules: they show up in ``parameters()`` and
    ``state_dict()`` and follow ``.to(device)`` together with ``out``. The
    module is therefore created on the default device; move it with ``.to()``.

    Args:
        scale: factor applied to the attention scores (passed to
            ``ScaledDotAttention``).
        dmodel: size of the last axis of the input and output.
        num_heads: number of attention heads; must divide ``dmodel``.

    Raises:
        ValueError: if ``dmodel`` is not divisible by ``num_heads`` (the
            concatenated heads would not match the ``out`` projection).
    """

    def __init__(self, scale, dmodel, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        if dmodel % num_heads != 0:
            raise ValueError(
                f'dmodel ({dmodel}) must be divisible by num_heads ({num_heads})'
            )
        head_dim = dmodel // num_heads
        self.v_proj = nn.ModuleList([nn.Linear(dmodel, head_dim, bias=False) for _ in range(num_heads)])
        self.k_proj = nn.ModuleList([nn.Linear(dmodel, head_dim, bias=False) for _ in range(num_heads)])
        self.q_proj = nn.ModuleList([nn.Linear(dmodel, head_dim, bias=False) for _ in range(num_heads)])
        self.scaled_attention = ScaledDotAttention(scale)

        self.out = nn.Linear(dmodel, dmodel, bias=False)

    def forward(self, x):
        """Self-attend over ``x`` and return a tensor with the same last-axis size."""
        attns = []
        for q_proj, k_proj, v_proj in zip(self.q_proj, self.k_proj, self.v_proj):
            # Argument order matches ScaledDotAttention.forward(q, k, v).
            attns.append(self.scaled_attention(q_proj(x), k_proj(x), v_proj(x)))

        # Concatenate the heads back to width dmodel before the output mix.
        attention = torch.cat(attns, dim=-1)
        return self.out(attention)


class FFBlock(nn.Module):
    """Position-wise feed-forward block: ``Linear -> ReLU -> Linear``.

    Args:
        dmodel: size of the last axis of the input and output.
        dff: width of the hidden layer.
    """

    def __init__(self, dmodel, dff=2048):
        super().__init__()
        self.fc1 = nn.Linear(dmodel, dff)
        self.fc2 = nn.Linear(dff, dmodel)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand to ``dff`` features, apply ReLU, project back to ``dmodel``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)


class EncoderLayer(nn.Module):
    """One encoder layer: self-attention and feed-forward, each with add-and-norm.

    Args:
        scale: factor applied to the attention scores.
        dmodel: size of the last axis of the input and output.
        num_heads: number of attention heads.

    NOTE(review): the same ``layer_norm`` instance (one set of scale/shift
    parameters) is applied after both sub-layers.
    """

    def __init__(self, scale, dmodel=512, num_heads=8):
        super().__init__()
        self.mh_attention = MultiHeadAttention(scale=scale, dmodel=dmodel, num_heads=num_heads)
        self.ffblock = FFBlock(dmodel)
        self.layer_norm = LayerNorm(dmodel)

    def forward(self, x):
        """Apply attention then feed-forward, normalising each residual sum."""
        # "Add and norm" around the attention sub-layer.
        attended = self.layer_norm(x + self.mh_attention(x))
        # "Add and norm" around the feed-forward sub-layer.
        return self.layer_norm(attended + self.ffblock(attended))


class TransformerEncoder(nn.Module):
    """Collection of ``EncoderLayer`` modules applied to the same embeddings.

    The layers are held in an ``nn.ModuleList`` so that they are registered
    sub-modules (visible to ``parameters()``, ``state_dict()`` and ``.to()``).
    Each layer is built with the requested ``dmodel`` and an attention scale
    of ``1 / sqrt(dmodel // num_heads)``, i.e. one over the square root of the
    per-head width.

    Args:
        num_encoders: number of encoder layers.
        dmodel: size of the last axis of the embeddings.
        num_heads: number of attention heads in every layer.

    NOTE(review): ``forward`` feeds the *same* embeddings to every layer and
    sums the results, rather than feeding each layer the previous layer's
    output. This is kept as written - confirm whether a sequential stack was
    intended.
    """

    def __init__(self, num_encoders=6, dmodel=512, num_heads=8):
        super(TransformerEncoder, self).__init__()
        # Pass arguments by keyword: EncoderLayer's first positional parameter
        # is `scale`, not `dmodel`.
        scale = (dmodel // num_heads) ** -0.5
        self.encoders = nn.ModuleList([
            EncoderLayer(scale=scale, dmodel=dmodel, num_heads=num_heads).to(DEVICE)
            for _ in range(num_encoders)
        ])

    def forward(self, embeddings):
        """Return the element-wise sum of every layer's encoding of ``embeddings``."""
        # Run embeddings through the encoders
        enc_outs = [encoder(embeddings) for encoder in self.encoders]
        # Sum all of the stacked transformer encodings
        return torch.sum(torch.stack(enc_outs), dim=0)