<a href="https://colab.research.google.com/github/ayyucedemirbas/sobel_conv2d/blob/main/sobel_convolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
import numpy as np

In [2]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = tf.keras.utils.to_categorical(y_train), tf.keras.utils.to_categorical(y_test)

In [3]:
sobel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=np.float32).reshape((3, 3, 1, 1))
sobel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32).reshape((3, 3, 1, 1))

In [4]:
class SobelLayer(layers.Layer):
    def __init__(self):
        super(SobelLayer, self).__init__()
        # Sobel filters for all input channels
        self.sobel_x = tf.constant(sobel_x, dtype=tf.float32)
        self.sobel_y = tf.constant(sobel_y, dtype=tf.float32)

    def call(self, inputs):
        # Expand Sobel filters to match the input's channel dimension
        input_channels = inputs.shape[-1]
        sobel_x_filter = tf.tile(self.sobel_x, [1, 1, input_channels, 1])
        sobel_y_filter = tf.tile(self.sobel_y, [1, 1, input_channels, 1])

        # Apply Sobel filters independently to each channel
        sobel_x_output = tf.nn.depthwise_conv2d(inputs, sobel_x_filter, strides=[1, 1, 1, 1], padding='SAME')
        sobel_y_output = tf.nn.depthwise_conv2d(inputs, sobel_y_filter, strides=[1, 1, 1, 1], padding='SAME')

        # Compute the magnitude of the gradients
        sobel_output = tf.sqrt(tf.square(sobel_x_output) + tf.square(sobel_y_output))

        # Concatenate Sobel output with the original inputs
        return tf.concat([sobel_output, inputs], axis=-1)

def create_model():
    inputs = layers.Input(shape=(32, 32, 3))
    x = SobelLayer()(inputs)
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(10, activation='softmax')(x)

    model = models.Model(inputs, outputs)
    return model

In [5]:
model = create_model()
model.trainable = True

In [6]:
i=0
for layer in model.layers:
    if i<2:
      #print("hello")
      layer.trainable = False
      i+=1

In [7]:
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [8]:
model.fit(x_train, y_train, batch_size=64, epochs=100, validation_data=(x_test, y_test))

Epoch 1/100
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 10ms/step - accuracy: 0.2950 - loss: 1.8897 - val_accuracy: 0.4524 - val_loss: 1.4960
Epoch 2/100
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.4862 - loss: 1.4176 - val_accuracy: 0.5352 - val_loss: 1.2861
Epoch 3/100
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.5406 - loss: 1.2703 - val_accuracy: 0.5714 - val_loss: 1.1912
Epoch 4/100
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.5809 - loss: 1.1716 - val_accuracy: 0.6005 - val_loss: 1.1158
Epoch 5/100
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.6146 - loss: 1.0808 - val_accuracy: 0.6287 - val_loss: 1.0520
Epoch 6/100
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 5ms/step - accuracy: 0.6428 - loss: 1.0097 - val_accuracy: 0.6392 - val_loss: 1.0188
Epoch 7/100
[1m782/

<keras.src.callbacks.history.History at 0x7ef9b0511120>

In [9]:
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc:.4f}")

313/313 - 1s - 4ms/step - accuracy: 0.7430 - loss: 1.7973
Test accuracy: 0.7430
