In [32]:
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
import numpy as np

# CIFAR-10 arrives as two (images, labels) pairs: the train split and the test split.
cifar_train, cifar_test = cifar10.load_data()
X_train, y_train = cifar_train
X_test, y_test = cifar_test
del cifar_train, cifar_test  # keep only the four split arrays in the namespace

# Short aliases for Keras sub-modules.
layers = tf.keras.layers
models = tf.keras.models
backend = tf.keras.backend

In [26]:
# Scale pixels from uint8 [0, 255] to float32 [0, 1]. Both splits must get the
# same preprocessing; otherwise evaluation sees inputs 255x larger than training.
# The uint8 check keeps this cell idempotent: re-running it does not divide by
# 255 a second time.
if X_train.dtype == np.uint8:
    X_train = X_train.astype('float32') / 255.0
if X_test.dtype == np.uint8:
    X_test = X_test.astype('float32') / 255.0

# Define the dilated convolution block (BatchNorm -> ReLU -> dilated Conv2D)

In [11]:
class Dilated_block(tf.keras.Model):
    """Pre-activation convolution unit: BatchNormalization -> ReLU -> Conv2D.

    The convolution uses 'same' padding and the default stride of 1, so the
    spatial size of the input is preserved for any dilation rate; only the
    channel count changes (to `filters`).

    Args:
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size (default 3).
        rate: dilation rate of the convolution (default 1, i.e. no dilation).
        **kwargs: forwarded to `tf.keras.Model` (e.g. `name`).
    """

    def __init__(self, filters, kernel_size = 3, rate = 1, **kwargs):
        super(Dilated_block, self).__init__(**kwargs)

        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.conv1 = layers.Conv2D(filters, kernel_size, padding = 'same', dilation_rate = rate)

    def call(self, inputs, training=None):
        """Apply BN -> ReLU -> (dilated) convolution to `inputs`.

        `training` is forwarded explicitly to BatchNormalization so it uses
        batch statistics while training and its moving averages at inference.
        """
        x = self.bn1(inputs, training=training)
        x = self.relu(x)
        return self.conv1(x)

# Define the DenseASPP head (densely connected dilated branches + 1x1 classifier + upsampling)

In [20]:
class DenseASPP(tf.keras.Model):
    """DenseASPP head: densely connected dilated branches.

    Branch ``i`` receives the channel-wise concatenation of the head input and
    the outputs of branches ``0..i-1``. It reduces that tensor with its own 1x1
    ``Dilated_block`` and then applies a 3x3 ``Dilated_block`` with dilation
    ``rates[i]``. Finally the input and every branch output are concatenated,
    projected to ``num_classes`` channels by a 1x1 convolution, and upsampled
    bilinearly by ``upsample_factor`` in height and width.

    Args:
        num_classes: output channels of the final 1x1 convolution.
        rates: dilation rate of each branch, in application order.
            Default ``(3, 6, 12, 18, 24)``.
        bottleneck_filters: output channels of each branch's 1x1 block.
        growth_filters: output channels of each branch's dilated block, i.e.
            how many channels each branch adds to the dense concatenation.
        upsample_factor: integer factor of the final bilinear upsampling.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_classes, rates=(3, 6, 12, 18, 24),
                 bottleneck_filters=256, growth_filters=64,
                 upsample_factor=8, **kwargs):
        super(DenseASPP, self).__init__(**kwargs)

        rates = tuple(rates)  # iterated twice below; materialize generators

        # Each branch needs its OWN 1x1 bottleneck. The branch inputs widen by
        # `growth_filters` channels per branch, and Conv2D/BatchNormalization
        # fix their weights to the channel count of their first call, so a
        # single shared bottleneck fails on the second branch (and would wrongly
        # tie the weights of all branches even if the shapes matched).
        self.bottlenecks = [Dilated_block(bottleneck_filters, 1) for _ in rates]
        self.dilated_convs = [Dilated_block(growth_filters, 3, rate=r) for r in rates]

        self.conv1 = layers.Conv2D(num_classes, 1, strides = 1)
        self.upsam1 = layers.UpSampling2D(size=(upsample_factor, upsample_factor),
                                          interpolation='bilinear')

    def call(self, inputs, training=None):
        """Run the dilated branches densely, classify with a 1x1 conv, upsample.

        ``training`` is forwarded to the sub-blocks so their BatchNormalization
        layers switch between batch statistics and moving averages.
        """
        features = [inputs]
        for bottleneck, dilated_conv in zip(self.bottlenecks, self.dilated_convs):
            # The first branch sees the raw input; later branches see the input
            # concatenated with every previous branch output.
            x = features[0] if len(features) == 1 else backend.concatenate(features)
            x = bottleneck(x, training=training)
            x = dilated_conv(x, training=training)
            features.append(x)

        x = features[0] if len(features) == 1 else backend.concatenate(features)
        x = self.conv1(x)
        return self.upsam1(x)