In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import math

In [2]:
class LayerNorm(nn.Module):

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

In [3]:
class Multiattention(nn.Module):
    def __init__(self, qk_channels, v_channels, multi_x8 = 1, # number of "heads"
                                             keys_queries_projection_dim_x8 = 1,
                                             values_projection_dim_x8 = 1):
        super().__init__()
        
        self.q_projector = nn.Conv1d(qk_channels,keys_queries_projection_dim_x8*8*multi_x8*8, (1,), bias=False)
        self.k_projector = nn.Conv1d(qk_channels,keys_queries_projection_dim_x8*8*multi_x8*8, (1,), bias=False)
        self.v_projector = nn.Conv1d(v_channels,values_projection_dim_x8*8*multi_x8*8, (1,), bias=False)
        
        self.v_reprojector = nn.Conv1d(values_projection_dim_x8*8*multi_x8*8, v_channels, (1,), bias=False)
        
        self.softmax = nn.Softmax(dim=3)
        
        self.multi_x8 = multi_x8
        self.keys_queries_projection_dim_x8 = keys_queries_projection_dim_x8
            
    def forward(self, Q_source, K_source, V_source):
        batch_size = V_source.size()[0]
        dict_length = V_source.size()[1]
        input_channels = V_source.size()[-1]
        
        #keys_projection_coeffs = Variable(torch.ones(1,1,K_source.size()[-1],keys_queries_projection_dim_x8*8*multi_x8*8))
        #queries_projection_coeffs = Variable(torch.ones(1,1,Q_source.size()[-1],keys_queries_projection_dim_x8*8*multi_x8*8))
        #values_projection_coeffs = Variable(torch.ones(1,1,input_channels,values_projection_dim_x8*8*multi_x8*8))
        
        def _project(tensor_to_project, projector):
            _ = tensor_to_project
            _ = _.permute(0,2,1)
            _ = projector(_)
            _ = _.permute(0,2,1)
            _ = _.contiguous()
            _ = _.view(tensor_to_project.size()[0],tensor_to_project.size()[1],self.multi_x8*8,-1)
            _ = _.permute(0,2,1,3)
            return _
            
        keys_projections = _project(K_source, self.k_projector) # [batch,heads,sequence,channels]
        queries_projections = _project(Q_source, self.q_projector) # [batch,heads,sequence,channels]
        values_projections = _project(V_source, self.v_projector) # [batch,heads,sequence,channels]
        
        _ = keys_projections.permute(0,1,3,2) # [batch,heads,channels,sequence]
        _ = queries_projections.matmul(_) # [batch,heads,sequence,sequence]
        _ = torch.div(_, math.sqrt(self.keys_queries_projection_dim_x8*8))
        attention_softmaxed = self.softmax(_) # [batch,heads,sequence,sequence]
        
        # [batch,heads,sequence,channels]
        projected_multiselfattended_input = attention_softmaxed.matmul(values_projections)
        
        _ = projected_multiselfattended_input
        _ = _.permute(0,2,1,3)
        _ = _.contiguous()
        _ = _.view(_.size()[0], _.size()[1],-1) # [batch,sequence,channels]
        _ = _.permute(0,2,1)
        _ = self.v_reprojector(_)
        _ = _.permute(0,2,1)
        reprojected_multiselfattended_input = _
        
        return reprojected_multiselfattended_input
    
class FeedForward(nn.Module):
    def __init__(self, input_channels, ff_projection_dim_x8=None):
        super().__init__()
                     
        if ff_projection_dim_x8 is None:
            ff_projection_dim_x8 = input_channels // 8 # default to same number of channels as input
            
        self.projector = nn.Conv1d(input_channels, ff_projection_dim_x8*8, (1,), bias=True)
        self.reprojector = nn.Conv1d(ff_projection_dim_x8*8, input_channels, (1,), bias=False)
            
    def forward(self, input_tensor):
        _ = input_tensor
        _ = _.permute(0,2,1)
        _ = self.projector(_)
        _ = _.permute(0,2,1)
        _ = nn.functional.relu(_)
        _ = _.permute(0,2,1)
        _ = self.reprojector(_)
        _ = _.permute(0,2,1)
        
        return _     
        
class EncoderBlock(nn.Module):
    def __init__(self,
                 input_channels,
                 ff_projection_dim_x8=None,
                 multi_x8=1,
                 keys_queries_projection_dim_x8=1,
                 values_projection_dim_x8=1):
        super().__init__()
        self.ff = FeedForward(input_channels = input_channels,
                              ff_projection_dim_x8 = ff_projection_dim_x8)
        
                        
        self.multiattention_head_with_projections = Multiattention(qk_channels=input_channels,
                                                                   v_channels=input_channels,
                                                                   multi_x8=multi_x8,
                                                                   keys_queries_projection_dim_x8=keys_queries_projection_dim_x8,
                                                                   values_projection_dim_x8=values_projection_dim_x8)
        
        
        self.attention_layer_norm = LayerNorm(features=input_channels)
        self.ff_layer_norm = LayerNorm(features=input_channels)
        
    def forward(self, input_tensor):

        
        enc_attention = self.multiattention_head_with_projections(Q_source=input_tensor,
                                                                  V_source=input_tensor,
                                                                  K_source=input_tensor) 

        attention_with_res = input_tensor.add(enc_attention)
        
        normed_attention_with_res = self.attention_layer_norm(attention_with_res)

        _ = self.ff(normed_attention_with_res)
        ff_with_res = _.add(normed_attention_with_res)
        
        normed_ff_with_res = self.ff_layer_norm(ff_with_res)
        
        return normed_ff_with_res
    
class DecoderBlock(nn.Module):
    def __init__(self,
                 input_channels,
                 ff_projection_dim_x8=None,
                 multi_x8=1,
                 keys_queries_projection_dim_x8=1,
                 values_projection_dim_x8=1):
        super().__init__()
        self.ff = FeedForward(input_channels = input_channels,
                              ff_projection_dim_x8 = ff_projection_dim_x8)
        
        self.dec_multiattention_head_with_projections = Multiattention(qk_channels=input_channels,
                                                                       v_channels=input_channels,
                                                                       multi_x8=multi_x8,
                                                                       keys_queries_projection_dim_x8=keys_queries_projection_dim_x8,
                                                                       values_projection_dim_x8=values_projection_dim_x8)
        
        self.enc_multiattention_head_with_projections = Multiattention(qk_channels=input_channels,
                                                                       v_channels=input_channels,
                                                                       multi_x8=multi_x8,
                                                                       keys_queries_projection_dim_x8=keys_queries_projection_dim_x8,
                                                                       values_projection_dim_x8=values_projection_dim_x8)

        self.dec_attention_layer_norm = LayerNorm(features=input_channels)
        self.enc_attention_layer_norm = LayerNorm(features=input_channels)
        self.ff_layer_norm = LayerNorm(features=input_channels)
        
        
    def forward(self, input_tensor, encoder_output):
        dec_attention = self.dec_multiattention_head_with_projections(Q_source=input_tensor,
                                                                      K_source=input_tensor,
                                                                      V_source=input_tensor) 

        dec_attention_with_res = input_tensor.add(dec_attention)
        
        normed_dec_attention_with_res = self.dec_attention_layer_norm(dec_attention_with_res)
        
        enc_attention = self.enc_multiattention_head_with_projections(Q_source=normed_dec_attention_with_res,
                                                                      K_source=encoder_output,
                                                                      V_source=encoder_output) 
        
        enc_attention_with_res = normed_dec_attention_with_res.add(enc_attention)
        
        normed_enc_attention_with_res = self.enc_attention_layer_norm(enc_attention_with_res)

        _ = self.ff(normed_enc_attention_with_res)
        ff_with_res = _.add(normed_enc_attention_with_res)
        
        normed_ff_with_res = self.ff_layer_norm(ff_with_res)
        
        return normed_ff_with_res
    
class Net(nn.Module):
    def __init__(self, input_channels, blocks=6):
        super().__init__()
        
        self.encoder_blocs = nn.ModuleList([EncoderBlock(input_channels=input_channels,
                                                         ff_projection_dim_x8=256,
                                                         multi_x8=1,
                                                         keys_queries_projection_dim_x8=8,
                                                         values_projection_dim_x8=8) for i in range(blocks)])
        self.decoder_blocs = nn.ModuleList([DecoderBlock(input_channels=input_channels,
                                                         ff_projection_dim_x8=256,
                                                         multi_x8=1,
                                                         keys_queries_projection_dim_x8=8,
                                                         values_projection_dim_x8=8) for i in range(blocks)])
        
    def forward(self, input_tensor):
        _,__ = input_tensor,input_tensor
        
        for i in range(len(self.encoder_blocs)):
            _ = self.encoder_blocs[i](_)
            __ = self.decoder_blocs[i](__,_)
            
        return _.add(__)

In [4]:
torch.backends.cudnn.benchmark = True

In [5]:
%%time
rand = np.random.randn(5, 1024, 512)

net = Net(512, blocks=7)

with torch.cuda.device(0):
    rand = Variable(torch.Tensor(rand)).cuda().half()
    net.cuda().half()
    
    # these are automatically cuda
    param_master_copy = [param.clone().type(torch.cuda.FloatTensor).detach() for param in net.parameters()]
    for param in param_master_copy:
        param.requires_grad = True
    optimizer = torch.optim.SGD(param_master_copy, lr=.01,momentum=.0, 	weight_decay=.0)


CPU times: user 1.58 s, sys: 496 ms, total: 2.08 s
Wall time: 2.1 s


In [6]:
def test():
    for i in range(20):
        net.zero_grad()
        optimizer.zero_grad()

        forward_pass = net(rand)
        print(forward_pass.size())
        loss = forward_pass.abs().mean(2).mean(1).mean(0)
        loss.backward()

        for param_master, param_w_grad in zip(param_master_copy, net.parameters()):
            if param_master.grad is None:
                param_master.grad = torch.nn.Parameter(param_master.data.new().resize_(*param_master.data.size()))
            param_master.grad.data.copy_(param_w_grad.grad.data)

        optimizer.step()

        params = list(net.parameters())
        for i in range(len(params)):
            params[i].data.copy_(param_master_copy[i].data)

        ##learning_rate = 0.01
        ##params = list(net.parameters())
        ##for i in range(len(params)):
        ##    param_master_copy[i].data.sub_(params[i].grad.data * learning_rate)
        ##    params[i].data.copy_(param_master_copy[i].data)

In [7]:
%%time
test()

torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
torch.Size([5, 1024, 512])
CPU times: user 4.3 s, sys: 2.36 s, total: 6.66 s
Wall time: 6.8 s


In [8]:
import seaborn as sns

In [None]:
sns.distplot(param_master_copy,kde=False)
#sns.distplot([param_master.grad.data for param_master in param_master_copy])

In [8]:
%%time
test()

# device=titan5,runs=20,framework=pytorch,passes=both:
#   input=(2, 1024, 512),net(512, blocks=7),FP16: 4.5s
#   input=(2, 1024, 512),net(512, blocks=7),FP32: 6s
#   input=(4, 1023, 511),net(511, blocks=7),FP16: 8s
#   input=(4, 1023, 511),net(511, blocks=7),FP32: 8.5s
#   input=(4, 1024, 512),net(512, blocks=7),FP16: 6s
#   input=(4, 1024, 512),net(512, blocks=7),FP32: 8.4s - max batch_size
#   input=(5, 1023, 511),net(511, blocks=7),FP16: 9s
#   input=(5, 1023, 511),net(511, blocks=7),FP32: OOM
#   input=(5, 1024, 512),net(512, blocks=7),FP16: 7s
#   input=(5, 1024, 512),net(512, blocks=7),FP32: OOM
#   input=(6, 1023, 511),net(511, blocks=7),FP16: 12s
#   input=(6, 1024, 512),net(512, blocks=7),FP16: 9s
#   input=(9, 1023, 511),net(511, blocks=7),FP16: 15.7s
#   input=(9, 1024, 512),net(512, blocks=7),FP16: 11.5s - max batch_size
# device=titan5,runs=20,framework=pytorch,passes=both,benchmark=True:
#   input=(9, 1024, 512),net(512, blocks=7),FP16: 10s
# device=980ti,runs=20,framework=pytorch,passes=both:
#   input=(2, 1024, 512),net(512, blocks=7),FP32: 10s - max batch_size
# device=980ti,runs=20,framework=pytorch,passes=both,benchmark=True:
#   input=(2, 1024, 512),net(512, blocks=7),FP32: 9s
# device=titan5,runs=20,framework=tf,passes=both:
#   input=(2, 1024, 512),net(512, blocks=7),FP16: 4.5s
#   input=(2, 1024, 512),net(512, blocks=7),FP32: 6s
#   input=(4, 1023, 511),net(511, blocks=7),FP16: 10s
#   input=(4, 1023, 511),net(511, blocks=7),FP32: 10s
#   input=(4, 1024, 512),net(512, blocks=7),FP16: 6.5s
#   input=(4, 1024, 512),net(512, blocks=7),FP32: 10s
#   input=(5, 1023, 511),net(511, blocks=7),FP16: 11s
#   input=(5, 1023, 511),net(511, blocks=7),FP32: 11.5s
#   input=(5, 1024, 512),net(512, blocks=7),FP16: 7.5s
#   input=(5, 1024, 512),net(512, blocks=7),FP32: 11.5s - max batch_size
#   input=(6, 1023, 511),net(511, blocks=7),FP16: 13s
#   input=(6, 1024, 512),net(512, blocks=7),FP16: 9s - max batch_size
#   input=(9, 1024, 512),net(512, blocks=7),FP16: OOM
#   input=(9, 1023, 511),net(511, blocks=7),FP16: OOM
# device=980ti,runs=20,framework=tf,passes=both:
#   input=(2, 1024, 512),net(512, blocks=7),FP32: 10s - max batch_size

torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
torch.Size([9, 1024, 512])
CPU times: user 6.48 s, sys: 3.86 s, total: 10.3 s
Wall time: 10.3 s


In [9]:
### net = Net(input_channels=512, blocks=7)
model_parameters = list(filter(lambda p: p.requires_grad, net.parameters()))
params = sum([np.prod(p.size()) for p in model_parameters])
print(params)
#for p in model_parameters:
#    print(p.size())

51444736


In [10]:
del(rand)
del(net)
del(param_master_copy)
del(optimizer)
del(forward_pass)
del(loss)

NameError: name 'forward_pass' is not defined