In [118]:
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import datetime

In [57]:
a = tf.keras.backend.random_bernoulli(shape = (1,), p = 1.)

In [146]:
train = pd.read_csv('train.csv')
X, y = np.reshape(np.array(train.iloc[:,1:]), (-1, 28,28 ,1)), train.iloc[:,0]

In [147]:
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, random_state=42, test_size = .2)

In [148]:
X_train = tf.convert_to_tensor(X_train, dtype = 'float32')
X_val = tf.convert_to_tensor(X_val, dtype = 'float32')
y_train = tf.convert_to_tensor(y_train, dtype = 'float32')
y_val = tf.convert_to_tensor(y_val, dtype = 'float32')

In [152]:
class Spatial_Gating_Unit(tf.keras.layers.Layer):
    def __init__(self, n_patches:int, initial_stddev:float):
        super(Spatial_Gating_Unit, self).__init__()
        self.n_patches = n_patches
        self.initial_stddev = initial_stddev
        
        self.ln = tf.keras.layers.LayerNormalization()
        self.t = tf.keras.layers.Permute((2,1))
        self.Wb = tf.keras.layers.Dense(int(self.n_patches), kernel_initializer = tf.keras.initializers.RandomNormal(stddev=self.initial_stddev), bias_initializer = 'ones')
        self.M = tf.keras.layers.Multiply()
        
    def call(self, X):
        z1, z2 = tf.split(X, 2, axis=-1)
        z2 = self.ln(z2)
        z2 = self.t(z2)
        z2 = self.Wb(z2)
        z2 = self.t(z2)
        X = self.M([z1,z2])
        return X
        
        
class gMLPs_Block(tf.keras.layers.Layer):
    def __init__(self, d_model:int, patch_size:int, initial_stddev:float, d_ffn:int, survival_prob:float):
        super(gMLPs_Block, self).__init__()
        self.d_model = d_model
        self.patch_size = patch_size
        self.initial_stddev = initial_stddev
        self.d_ffn = d_ffn
        self.survival_prob = survival_prob
        
        self.ln = tf.keras.layers.LayerNormalization()
        self.U = tf.keras.layers.Dense(self.d_ffn, activation = 'gelu', kernel_initializer = tf.keras.initializers.lecun_normal())
        self.SGU = Spatial_Gating_Unit(self.patch_size, self.initial_stddev)
        self.V = tf.keras.layers.Dense(self.d_model, kernel_initializer = tf.keras.initializers.glorot_normal())
        
    def call(self, X):
        y = self.ln(X)
        y = self.U(y)
        y = self.SGU(y)
        y = self.V(y)
        y = y * tf.keras.backend.random_bernoulli(shape = (1,), p = self.survival_prob)
        y = X + y
        # U : [batchs, patch_size^2, d_ffn] -> SGU : [batch, patch_size^2, d_ffn/2] -> V : [batch, patch_size^2, d_models] 
        return y
    
class gMLPs(tf.keras.models.Model):
    def __init__(self, d_model:int, d_ffn:int, image_size:int, patch_size:int, n_res_layers:int, n_labels:int,
                 survival_prob = 1., initial_stddev:float = 0.00001, mode:str = 'softmax'):
        super(gMLPs, self).__init__()
        self.d_model = d_model
        self.d_ffn = d_ffn
        self.image_size = image_size
        self.patch_size = patch_size
        if (self.image_size % self.patch_size) != 0:
            raise ValueError('size error')
        self.n_patches = int((tf.square(self.image_size) / tf.square(self.patch_size)).numpy())
        self.n_res_layers = n_res_layers
        self.n_labels = n_labels
        self.initial_stddev = initial_stddev
        if mode not in ['sigmoid','softmax']:
            raise ValueError('mode must be sigmoid or softmax')
        else:
            self.mode = mode
        self.survival_prob = survival_prob
        
        self.patchConv = tf.keras.layers.Conv2D(self.d_model, (self.patch_size, self.patch_size), strides = (self.patch_size, self.patch_size))
        self.reshapeL = tf.keras.layers.Reshape((self.n_patches, self.d_model,))
        self.gMLPBlocks = [gMLPs_Block(self.d_model, self.n_patches, self.initial_stddev, self.d_ffn, self.survival_prob) for x in range(self.n_res_layers)]
        self.gap = tf.keras.layers.GlobalAveragePooling1D()
        self.classifier = tf.keras.layers.Dense(self.n_labels if self.n_labels > 2 else 1, activation = mode, kernel_initializer = tf.keras.initializers.glorot_normal(seed = 42))
        
    def call(self, X):
        X = self.patchConv(X)
        X = self.reshapeL(X)
        for gMLPB in self.gMLPBlocks:
            X = gMLPB(X)
        X = self.gap(X)
        X = self.classifier(X)
        return X

In [153]:
gmlp = gMLPs(256, 1536, 28, 4, 30, 10,
            survival_prob = .95)

In [155]:
gmlp.compile(tf.keras.optimizers.Adam(0.0005),
             loss = tf.keras.losses.SparseCategoricalCrossentropy(),
             metrics = 'accuracy'
            )
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
es = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss', restore_best_weights=True, patience=3)
gmlp.fit(X_train, y_train, validation_data=(X_val, y_val), epochs = 10000,
        callbacks = [es, tensorboard_callback], batch_size = 64)

AlreadyExistsError: Another profiler is running.