# Machine Learning for SEM Image Segmentation in Materials Science

## *Using a U-Net model to segment microscopy images*

In this tutorial, you will learn how to use a pre-trained U-Net model to segment a scanning electron microscopy (SEM) image of graphene on a substrate.

**Outline:**
1. Import image and model
2. Pre-process image
3. Run the model

**Get started:** Press Shift-Enter to run the code in each cell.

## <ins>Let's begin</ins>

We will first import the relevant Python libraries.

```python
# import relevant libraries

from keras.models import load_model
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')  # suppress warnings in the output
```

## <ins>Import the model and image</ins>

Next, we load the trained model from its `.hdf5` file, open the image and display it. The image is opened as a PIL `Image` object; it is converted to a NumPy array in the pre-processing step.

```python
# paths to the trained U-Net model and the input SEM image
modelp = '/data/tools/imagesegment/models/model_E99_0.974.hdf5'
imagep = '../data/test_kmeans2.tif'
```

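```python
model = load_model(modelp)                 # load the trained Keras model
img_in = Image.open(imagep).convert('L')   # open the image as 8-bit grayscale

plt.imshow(img_in, cmap='gray')            # display the input image
```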
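SEM TIFF files are not always 8-bit grayscale. If `img_in.mode` reports a 16-bit mode such as `'I;16'` instead of `'L'`, the `np.array(img, dtype='uint8')` step in pre-processing casts the values without rescaling them. The sketch below shows one way to bring such an image into the 0 to 255 range first. `to_uint8_gray` is a hypothetical helper, and the min-max rescaling is an assumption: how the images used to train this model were normalised is not stated here.

```python
import numpy as np
from PIL import Image

def to_uint8_gray(img):
    """Hypothetical helper: return a single-channel 8-bit copy of a PIL image."""
    if img.mode in ('I;16', 'I;16B', 'I;16L', 'I', 'F'):
        # 16-bit, 32-bit integer or float data: min-max rescale to 0..255.
        # Assumption: the model's training preprocessing is not documented here.
        arr = np.asarray(img).astype(np.float64)
        lo, hi = arr.min(), arr.max()
        if hi == lo:
            scaled = np.zeros_like(arr)  # flat image: map everything to 0
        else:
            scaled = (arr - lo) / (hi - lo) * 255.0
        return Image.fromarray(np.round(scaled).astype(np.uint8))
    # RGB, RGBA, palette, bilevel or already 8-bit: let Pillow convert to 'L'
    return img.convert('L')

# Example with a synthetic 16-bit image
raw = Image.fromarray(np.array([[0, 1000], [2000, 4000]], dtype=np.uint16))
gray = to_uint8_gray(raw)
print(raw.mode, '->', gray.mode, np.array(gray).tolist())
```

For an image that is already 8-bit grayscale the helper simply returns a converted copy, so it is safe to apply unconditionally.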

## <ins>Pre-process the image</ins>

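The model expects a fixed input format, so we resize the image to 256 × 256 pixels, convert it to an array of 8-bit unsigned integers, and add batch and channel dimensions so that its shape is `(1, 256, 256, 1)`.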

```python
orig_size = img_in.size                 # record the original size as (width, height)
img = img_in.resize((256, 256))         # resize to the 256 x 256 input the model requires
img = np.array(img, dtype='uint8')      # convert to an 8-bit unsigned integer array, shape (256, 256) for a grayscale image
img = img[np.newaxis, ..., np.newaxis]  # add batch and channel axes: shape (1, 256, 256, 1)
```
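The pre-processing steps can be wrapped in a small function and checked on a synthetic image before running the real model. This is a sketch: `preprocess` is a hypothetical helper name, and the random 640 × 480 image only stands in for the SEM data. Note that PIL reports sizes as `(width, height)` while NumPy shapes are `(height, width)`, which is why `orig_size` can be passed straight back to PIL's `resize` later.

```python
import numpy as np
from PIL import Image

def preprocess(img, size=(256, 256)):
    """Hypothetical helper: resize a PIL image and shape it as a (1, H, W, 1) uint8 batch."""
    gray = img.convert('L')                  # make sure there is a single 8-bit channel
    resized = gray.resize(size)              # PIL sizes are (width, height)
    arr = np.array(resized, dtype=np.uint8)  # NumPy shape is (height, width)
    return arr[np.newaxis, ..., np.newaxis]  # add batch and channel axes

# Synthetic stand-in for the SEM image: 640 pixels wide, 480 pixels high
rng = np.random.default_rng(0)
fake = Image.fromarray(rng.integers(0, 256, size=(480, 640)).astype(np.uint8))

batch = preprocess(fake)
print(fake.size, batch.shape, batch.dtype)
```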

## <ins>Run the model</ins>

We run the model to predict a segmentation map and compute the graphene coverage from it.

```python
# Run the prediction. The output is an array where values near 0 mean
# "graphene" and values near 1 mean "not graphene".
pred = model.predict(img)
coverage = 1 - np.mean(pred)         # fraction of the image covered by graphene
pred = (pred > 0.5).astype('uint8')  # threshold to a 0/1 mask (0.5 is an assumed cutoff);
                                     # a bare astype('uint8') would truncate every value below 1 to 0

print('coverage: ', coverage)        # print the coverage
```
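To see how the coverage number relates to the prediction, the sketch below uses a tiny hand-made probability map in place of `model.predict(img)`; the values are invented for illustration. It compares the coverage from the raw probabilities, as computed above, with the coverage from a thresholded 0/1 mask (the 0.5 cutoff is an assumption), and shows why casting probabilities directly to `uint8` loses the mask: every value below 1.0 is truncated to 0.

```python
import numpy as np

# Invented stand-in for model.predict(img): shape (1, H, W, 1), values in [0, 1],
# where values near 0 mean "graphene" and values near 1 mean "not graphene".
pred = np.array([0.1, 0.2, 0.7, 0.9, 0.4, 0.6, 0.05, 0.95]).reshape(1, 2, 4, 1)

# Coverage from the raw probabilities, as in the notebook cell
soft_coverage = 1 - np.mean(pred)

# Coverage from a binary mask; the 0.5 cutoff is an assumption
mask = (pred > 0.5).astype(np.uint8)
hard_coverage = 1 - np.mean(mask)

# Casting probabilities directly truncates them toward zero
truncated = pred.astype(np.uint8)

print('soft coverage:', soft_coverage)   # about 0.51
print('hard coverage:', hard_coverage)   # 0.5
print('unique values after bare astype:', np.unique(truncated))
```

The two coverage estimates agree closely when the model is confident (probabilities near 0 or 1) and diverge when many pixels sit near the cutoff.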

```python
# convert the 0/1 mask to an Image and resize it back to the original size;
# nearest-neighbour resampling keeps the mask binary
new_pred = Image.fromarray(pred[0, ..., 0]).resize(orig_size, Image.NEAREST)

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))  # create a figure with two sub-plots
ax[0].imshow(new_pred, cmap='gray')  # display the output image
ax[1].imshow(img_in, cmap='gray')    # display the original image

# label the images
ax[0].set_title('Output Image')
ax[1].set_title('Input Image')
```
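When the mask is resized back to `orig_size`, the resampling filter matters: recent Pillow versions default to bicubic resampling for grayscale images, which blends neighbouring pixels at region boundaries. With a 0/1 mask stored as `uint8` the blended values are rounded back to 0 or 1, so the effect is mostly a shift of boundary pixels. The sketch below scales a small synthetic mask to 0/255 so the blending becomes visible, and compares nearest-neighbour with bilinear resampling; only nearest-neighbour copies each label unchanged.

```python
import numpy as np
from PIL import Image

# Synthetic 4 x 4 mask: left half 0 ("graphene"), right half 255 ("not graphene")
mask = np.zeros((4, 4), dtype=np.uint8)
mask[:, 2:] = 255

small = Image.fromarray(mask)  # mode 'L'
target = (10, 10)              # (width, height), like orig_size

nearest = np.array(small.resize(target, Image.NEAREST))
bilinear = np.array(small.resize(target, Image.BILINEAR))

print('nearest :', np.unique(nearest))   # only 0 and 255
print('bilinear:', np.unique(bilinear))  # also intermediate gray levels at the boundary
```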