In [11]:
from tensorflow.keras.layers import Layer, Conv2D, Input
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
import tensorflow as tf

In [149]:
# SAGAN https://arxiv.org/pdf/1805.08318.pdf
# following https://github.com/taki0112/Self-Attention-GAN-Tensorflow Code2
# Line where the Keras conv layer's kernel shape is defined:
# https://github.com/tensorflow/tensorflow/blob/d5163e15c21874fddb03fedaf2cc6316a590f490/tensorflow/python/keras/layers/convolutional.py#L194
# WIP: this is attention version 1; I would like to implement Google's attention next.
class SelfAttention(Model):
    """Self-attention block for channels-last feature maps (SAGAN style).

    Three 1x1 convolutions project the input into ``f`` and ``g`` (each with
    ``channels // 8`` filters, at least 1) and ``h`` (``channels`` filters).
    Every spatial position attends to every other position, and the attended
    features are added back to the input scaled by the scalar ``gamma``.

    ``gamma`` starts at 0.0, so a freshly built block returns its input
    unchanged.

    Notes:
        * The input must have exactly ``channels`` channels, because the
          attended features are reshaped back to the input's shape.
        * The attention map has shape ``(batch, H*W, H*W)``, so memory grows
          quadratically with the number of spatial positions.
        * WIP: a final 1x1 convolution on the attention output and 2x2
          pooling of ``f``/``h`` were sketched earlier but are not applied.

    Args:
        channels: Number of channels of the incoming feature map.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``channels`` is smaller than 1.
    """

    def __init__(self, channels, **kwargs):
        # Forward kwargs so options such as ``name=`` are honoured rather than
        # silently dropped.
        super(SelfAttention, self).__init__(**kwargs)

        if channels < 1:
            raise ValueError(
                "channels must be a positive integer, got {!r}".format(channels)
            )

        self.channels = channels
        # channels // 8 is 0 for channels < 8; keep at least one filter so the
        # f/g projections stay valid for narrow feature maps.
        self.filters_f = max(channels // 8, 1)
        self.filters_g = max(channels // 8, 1)
        self.filters_h = channels

        self.f = Conv2D(self.filters_f, (1, 1), strides=(1, 1), padding='same')
        self.g = Conv2D(self.filters_g, (1, 1), strides=(1, 1), padding='same')
        self.h = Conv2D(self.filters_h, (1, 1), strides=(1, 1), padding='same')

        # Learnable scale of the attention branch; 0.0 makes the block start
        # out as the identity.
        self.gamma = tf.Variable(0.0, trainable=True, name='gamma')

    @staticmethod
    def _hw_flatten(t):
        """Collapse the spatial axes: (batch, H, W, C) -> (batch, H*W, C)."""
        return K.reshape(t, shape=[K.shape(t)[0], -1, K.shape(t)[-1]])

    def call(self, x):
        """Apply self-attention to ``x`` and return a tensor of the same shape.

        Args:
            x: 4-D channels-last tensor ``(batch, H, W, channels)``.

        Returns:
            ``gamma * attention(x) + x``, reshaped to the shape of ``x``.
        """
        f = self.f(x)
        g = self.g(x)
        h = self.h(x)

        # Pairwise scores between all positions: (batch, H*W, H*W).
        s = tf.matmul(self._hw_flatten(g), self._hw_flatten(f), transpose_b=True)

        # Normalise each row over the positions being attended to.
        beta = K.softmax(s, axis=-1)

        # Weighted sum of the h features: (batch, H*W, channels).
        o = tf.matmul(beta, self._hw_flatten(h))

        # Restore the spatial layout of the input before the residual add.
        o = K.reshape(o, shape=K.shape(x))
        return self.gamma * o + x
        

In [150]:
# Wrap the attention block in a functional model that takes a 32x32 feature
# map with 64 channels.
inputs = Input(shape=(32, 32, 64))
attention = SelfAttention(64)
x = attention(inputs)
model = Model(inputs=inputs, outputs=x)

In [151]:
# Print the per-layer output shapes and parameter counts of the wrapper model.
model.summary()

Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_48 (InputLayer)        [(None, 32, 32, 64)]      0         
_________________________________________________________________
self_attention_45 (SelfAtten (None, 32, 32, 64)        5201      
Total params: 5,201
Trainable params: 5,201
Non-trainable params: 0
_________________________________________________________________


In [152]:
import numpy as np

In [153]:
# All-zero dummy batch: a single 32x32 sample with 64 channels.
test = np.zeros(shape=(1, 32, 32, 64))

In [154]:
# Forward pass on the dummy batch. gamma is initialised to 0.0, so the
# attention term is scaled away and the output should equal the all-zero input.
prediction = model.predict(test)
prediction

array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
    