# Deep learning-based automated rock classification via high-resolution drone-captured core sample imagery
***
### Domenico M. Crisafulli, Misael M. Morales, and Carlos Torres-Verdin
#### The University of Texas at Austin, 2024
***

## Build and Train NN-classifier
| Class             | Old   | New   |
| ---               | ---   | ---   |
| Background        | 0     | 0     |
| Sandstone type 1  | 1     | 2     |
| Shaly Rock        | 2     | 3     |
| Sandstone type 2  | 3     | 4     |
| Carbonate         | 4     | 5     |
| Shale             | 5     | 6     |
| Sandstone type 3  | 6     | 7     |
| Box               | 10    | 1     |

In [None]:
import os
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import tensorflow as tf
import keras
import keras.backend as K
from keras import Model
from keras import layers
from keras import optimizers
from keras.applications.resnet import ResNet50, preprocess_input

def check_tf_gpu():
    """Print a summary of the TensorFlow build and the visible devices.

    Reports the TensorFlow version, the number of physical devices, whether
    the build has CUDA support (with the CUDA/cuDNN versions recorded in the
    build info), and one line per physical device.

    The function is safe on CPU-only setups: build-info keys that are absent
    are shown as ``'N/A'`` instead of raising ``KeyError``, and the device
    list is iterated rather than indexed, so a single-device machine no
    longer raises ``IndexError``.

    Returns:
        None
    """
    build_info = tf.sysconfig.get_build_info()
    # The CUDA-related keys may be missing (e.g. non-GPU builds), so never
    # index the build-info dictionary directly.
    cuda = build_info.get("cuda_version", "N/A")
    cudnn = build_info.get("cudnn_version", "N/A")
    devices = tf.config.list_physical_devices()
    print('-'*60)
    print('----------------------- VERSION INFO -----------------------')
    print('TF version: {} | # Device(s) available: {}'.format(tf.__version__, len(devices)))
    print('TF Built with CUDA? {} | CUDA: {} | cuDNN: {}'.format(tf.test.is_built_with_cuda(), cuda, cudnn))
    # List every device instead of assuming exactly two (CPU + one GPU).
    for device in devices:
        print(device)
    print('-'*60+'\n')
    return None

check_tf_gpu()

In [None]:
# Load the image array and the matching label array from disk and report
# their shapes.
images_path, labels_path = 'data/x_images.npy', 'data/y_images.npy'
X_data = np.load(images_path)
y_data = np.load(labels_path)
print('X: {} | y: {}'.format(X_data.shape, y_data.shape))

In [None]:
def DeeplabV3Plus(image_size: int, num_classes: int) -> keras.Model:
    """Build a DeepLabV3+-style segmentation model on a ResNet50 backbone.

    The input is a square, 3-channel image tensor of shape
    ``(image_size, image_size, 3)``. It is passed through the ResNet
    ``preprocess_input`` function and a ResNet50 loaded with ImageNet weights
    (without its classification top). Features from ``conv4_block6_2_relu``
    go through a dilated spatial pyramid pooling block, are upsampled, merged
    with 1x1-projected features from ``conv2_block3_2_relu``, refined by two
    convolution blocks and upsampled by ``image_size // <feature size>``.

    Args:
        image_size: Side length of the square input image.
        num_classes: Number of filters (output channels) of the final 1x1
            convolution.

    Returns:
        A ``keras.Model`` mapping the image input to a ``num_classes``-channel
        map. No activation is passed to the final convolution.
        NOTE(review): the training cell compiles with the string loss
        "binary_crossentropy"; confirm that un-activated outputs are intended.
    """
    def convolution_block(block_input, num_filters=256, kernel_size=3, dilation_rate=1, use_bias=False):
        # Conv2D ("same" padding, He-normal init) -> BatchNormalization -> ReLU.
        x = layers.Conv2D(num_filters, kernel_size=kernel_size, dilation_rate=dilation_rate, padding="same", 
                          use_bias=use_bias, kernel_initializer=keras.initializers.HeNormal())(block_input)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        return x

    def DilatedSpatialPyramidPooling(dspp_input):
        # Image-level branch: average-pool with a window equal to the input's
        # two spatial dimensions, apply a 1x1 conv block, then upsample back
        # to the original spatial size.
        dims = dspp_input.shape
        x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
        x = convolution_block(x, kernel_size=1, use_bias=True)
        out_pool = layers.UpSampling2D(
            size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear")(x)
        # Parallel conv branches over the same input with dilation rates
        # 1, 6, 12 and 18.
        out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
        out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
        out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
        out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
        # Concatenate all branches along the channel axis and fuse with a 1x1 conv block.
        x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
        output = convolution_block(x, kernel_size=1)
        return output

    model_input = keras.Input(shape=(image_size, image_size, 3))
    
    # Apply the ResNet input preprocessing inside the model graph, then build
    # the backbone on top of the preprocessed tensor.
    preprocessed = preprocess_input(model_input)
    resnet50 = ResNet50(weights="imagenet", include_top=False, input_tensor=preprocessed)

    # Deeper backbone features feed the pyramid pooling block.
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)

    # Upsample the pooled features by (image_size // 4) // current size so
    # they can be concatenated with the shallower features below.
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # Shallower backbone features, projected to 48 channels with a 1x1 conv block.
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    # Final upsampling by image_size // current size.
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # 1x1 convolution producing num_classes channels; no activation is specified here.
    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
    return keras.Model(inputs=model_input, outputs=model_output)

In [None]:
# Draw a random ordering of all sample indices; the first 77% become the
# training set and the remainder the test set.
n_samples = len(X_data)
idx = np.random.choice(n_samples, size=n_samples, replace=False)
n_train = int(0.77 * len(idx))

# Replicate the last axis three times (the model input declares 3 channels).
# Assumes X_data has a single trailing channel - TODO confirm.
# NOTE: X_data is overwritten, so re-running this cell repeats the axis again.
X_data = np.repeat(X_data, 3, axis=-1)

train_idx, test_idx = idx[:n_train], idx[n_train:]
X_train, X_test = X_data[train_idx], X_data[test_idx]
y_train, y_test = y_data[train_idx], y_data[test_idx]
print('X - train: {} | test: {}'.format(X_train.shape, X_test.shape))
print('y - train: {} | test: {}'.format(y_train.shape, y_test.shape))

In [None]:
# Build the network for 512x512 inputs with a single output channel and
# report its parameter count.
model = DeeplabV3Plus(image_size=512, num_classes=1)
print('# params: {:,}'.format(model.count_params()))

# AdamW with learning rate 1e-3 and weight decay 4e-6.
# NOTE(review): the loss is given by name, while the model's last layer is
# built without an activation - confirm this pairing is intended.
optimizer = optimizers.AdamW(learning_rate=1e-3, weight_decay=4e-6)
model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

# Training settings; 20% of the training data is held out for validation.
fit_options = dict(batch_size=8, epochs=10, validation_split=0.2, shuffle=True, verbose=1)
fit = model.fit(X_train, y_train, **fit_options)

# Persist the trained weights and the per-epoch history for later cells.
model.save_weights('rockClassification.weights.h5')
history_table = pd.DataFrame(fit.history)
history_table.to_csv('fit_history.csv', index=False)

In [None]:
# Plot training and validation accuracy per epoch from the saved history.
losses = pd.read_csv('fit_history.csv')
epochs = losses.index

plt.figure(figsize=(7,5))
# (column, line style, marker, legend label) for each curve, in draw order.
curves = [
    ('accuracy', '-', 'o', 'Accuracy'),
    ('val_accuracy', '--', '.', 'Val Accuracy'),
]
for column, line_style, marker, label in curves:
    plt.plot(epochs, losses[column], ls=line_style, marker=marker, label=label)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(facecolor='lightgrey', edgecolor='k')
plt.grid(True, which='both')
plt.tight_layout()
plt.show()

In [None]:
# Rebuild the same architecture and restore the weights written by the
# training cell, then report the parameter count.
weights_path = 'rockClassification.weights.h5'
model = DeeplabV3Plus(image_size=512, num_classes=1)
model.load_weights(weights_path)
n_params = model.count_params()
print('# params: {:,}'.format(n_params))

In [None]:
# Run inference on both splits: training predictions are kept as raw model
# outputs, test predictions are rounded to the nearest integer.
y_train_pred = model.predict(X_train, verbose=0)
test_outputs = model.predict(X_test, verbose=0)
y_test_pred = test_outputs.round()
print('Pred - train: {} | test: {}'.format(y_train_pred.shape, y_test_pred.shape))

In [None]:
# Show the first ten training samples column by column:
# row 0 = input image, row 1 = label, row 2 = model prediction.
n_cols = 10
fig, axs = plt.subplots(3, n_cols, figsize=(15,5), sharex=True, sharey=True)
for j in range(n_cols):
    column_axes = axs[:, j]
    panels = (X_train[j], y_train[j], y_train_pred[j])
    images = [ax.imshow(panel) for ax, panel in zip(column_axes, panels)]
    # Hide tick marks on every panel of this column.
    for ax in column_axes:
        ax.set(xticks=[], yticks=[])
    # Attach one colorbar per image.
    for image in images:
        plt.colorbar(image, pad=0.04, fraction=0.046)
plt.tight_layout()
plt.show()

***
# END