# Notebook 2) Skyrmion U-Net prediction / inference

This notebook demonstrates how to perform prediction/inference using existing Skyrmion U-Net models. It can also be used later for research. A GUI for prediction and analysis can be found in `3_Editor.ipynb`.

## 0. Configure notebook & import packages

In [None]:
try:
    import google.colab
    in_colab = True
    ![ ! -d "AI-Magnetism-Session-Regensburg-2025" ] ||  [ ! -d "AI-Magnetism-Session-Regensburg-2025/.git" ] && git clone https://github.com/kfjml/AI-Magnetism-Session-Regensburg-2025
    ! pip install "numba>=0.61.0,<0.62" "tensorflow[and-cuda]>=2.16.2,<3" "albumentations>=2.0.4,<3"  "pandas>=2.2.2,<3" "chardet>=5.2.0,<6" "opencv-python-headless>=4.11.0.86,<5" "wget>=3.2,<4" "pyyaml>=6.0.2, <7" "pillow>=11.1.0, <12"
    ! pip install "ipympl>=0.9.6" "ipywidgets>=7.7.1" "matplotlib>=3.10.0,<4"
    basis_dir = "/content/AI-Magnetism-Session-Regensburg-2025/"
    from google.colab import output
    output.enable_custom_widget_manager()
except ImportError:
    basis_dir = "./"
    in_colab = False
    
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
import scipy.spatial
import glob
import io
import pandas as pd
import ipywidgets
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable

### If you are running this notebook in **Google Colab**, after executing the first cell (cell above), go to **Runtime → Restart session**, then rerun the first cell. After that, you can execute the cells below. This is necessary because some required packages are installed in Google Colab and need a restart to take effect.

### Check if GPU is available

In [None]:
gpu_available = lambda : len(tf.config.list_physical_devices('GPU'))
if gpu_available(): print("GPU is available")

## 1. U-Net Architecture of the Trained and Available Models

The U-Net models in this repository were trained to predict on 512x512 Kerr microscopy images.

In this section we reuse code from the first notebook. For an explanation of the code, please refer to `1_Training_tutorial.ipynb`.



In [None]:
#define plot function for figures
def plotfig(fn,dpi):
    fig,ax = plt.subplots(dpi=dpi)
    ax.imshow(plt.imread(fn))
    ax.axis("off")
    
plotfig(basis_dir+"notebook_figures/u_net_architecture_1.png",420)

### 1.2 Define U-Net architecture

In [None]:
# Basic activation layer
class MishLayer(tf.keras.layers.Layer):
    def call(self, x):
        return tf.keras.activations.mish(x)

# Basic Convolution Block
def conv_block(x, n_channels, param):
    x = tf.keras.layers.Conv2D(n_channels, kernel_size=param["kernel_size"],kernel_initializer=param["kernel_initialization"],padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x) 
    x = MishLayer()(x)
    return x

# Double Convolution Block used in "encoder" and "bottleneck"
def double_conv_block(x, n_channels, param):
    x = conv_block(x,n_channels,param)
    x = conv_block(x,n_channels,param)
    return x

# Downsample block for feature extraction (encoder)
def downsample_block(x, n_channels, param):
    f = double_conv_block(x, n_channels, param)
    p = tf.keras.layers.MaxPool2D(pool_size=(2,2))(f)
    p = tf.keras.layers.Dropout(param["dropout"])(p)
    return f, p

# Upsample block for the decoder
def upsample_block(x, conv_features, n_channels, param):
    x = tf.keras.layers.Conv2DTranspose(n_channels*param["upsample_channel_multiplier"], param["kernel_size"], strides=(2,2), padding='same')(x)
    x = tf.keras.layers.concatenate([x, conv_features])
    x = tf.keras.layers.Dropout(param["dropout"])(x)
    x = double_conv_block(x, n_channels, param)
    return x

# Create the model
def get_unet(param):
    input = tf.keras.layers.Input(shape=param["input_shape"]+(1,))
    next_input = input
    
    l_residual_con = []
    for i in range(param["n_depth"]):
        residual_con,next_input = downsample_block(next_input, (2**i)*param["filter_multiplier"],param)
        l_residual_con.append(residual_con)

    next_input = double_conv_block(next_input, (2**param["n_depth"])*param["filter_multiplier"],param)

    for i in range(param["n_depth"]):
        next_input = upsample_block(next_input, l_residual_con[param["n_depth"]-1-i], (2**(param["n_depth"]-1-i))*param["filter_multiplier"],param)

    output = tf.keras.layers.Conv2D(param["n_class"], (1,1), padding="same", activation = "softmax",dtype='float32')(next_input)    
    
    return tf.keras.Model(input, output, name=param["name"])
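
To make the layer sizes produced by `get_unet` concrete, the following minimal sketch tabulates the spatial size and channel count at each stage. The helper `unet_level_summary` is hypothetical and not part of the repository. The defaults `n_depth=4` and `filter_multiplier=16` are assumed from the parameter set passed to `get_unet` inside the `predict` function of this notebook. The sketch only mirrors the arithmetic of `get_unet` and builds no layers.

```python
def unet_level_summary(input_size=512, n_depth=4, filter_multiplier=16):
    """Return (stage, spatial_size, n_channels) for each U-Net stage (hypothetical helper)."""
    # Each 2x2 max pooling halves the size, and the skip connections are
    # concatenated with the upsampled features, so the sizes must match exactly.
    if input_size % 2**n_depth != 0:
        raise ValueError("input_size must be divisible by 2**n_depth")
    rows = []
    # Encoder: the channel count doubles per level, the size halves per level
    for i in range(n_depth):
        rows.append((f"encoder_{i}", input_size // 2**i, 2**i * filter_multiplier))
    # Bottleneck: smallest spatial size, largest channel count
    rows.append(("bottleneck", input_size // 2**n_depth, 2**n_depth * filter_multiplier))
    # Decoder: mirrors the encoder from the deepest level back to level 0
    for i in range(n_depth):
        level = n_depth - 1 - i
        rows.append((f"decoder_{level}", input_size // 2**level, 2**level * filter_multiplier))
    return rows

for stage, size, channels in unet_level_summary():
    print(f"{stage:>10}: {size}x{size} pixels, {channels} channels")
```

With these assumed defaults, the encoder runs from 16 channels at 512x512 down to a 256-channel bottleneck at 32x32, and the decoder mirrors it back up.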

### 1.3 Define Segmentation Mask Index ↔ RGB Conversion Function

The Kerr micrographs are labeled with a segmentation mask. The segmentation mask in the dataset consists of five distinct classes:

- **Skyrmions** - RGB label: red [1, 0, 0]
- **Defects** - RGB label: green [0, 1, 0]
- **Ferromagnetic (FM) background** - RGB label: blue [0, 0, 1]
- **Non-ferromagnetic (non-FM) background** - RGB label: yellow [1, 1, 0]
- **Boundary between non-ferromagnetic and ferromagnetic background** - RGB label: cyan [0, 1, 1]


The **3-class U-Net** model predicts:

- **Skyrmions** - RGB label: red [255, 0, 0]  
- **Defects** - RGB label: green [0, 255, 0]
- **Background** - RGB label: blue [0, 0, 255]  
   - The background class includes:
     - The ferromagnetic (FM) background  
     - The non-ferromagnetic (non-FM) background  
     - The boundary between ferromagnetic and non-ferromagnetic backgrounds

For the **2023 model**, the class indices are:  

- **Skyrmions:** 0  
- **Defects:** 2
- **Background:** 1  

For the **2022 model**, the class indices are:  

- **Skyrmions:** 0  
- **Background:** 1  
- **Defects:** 2

For the **2022 inversion model**, the class indices are:  

- **Skyrmions:** 0  
- **Defects:** 1  
- **Background:** 2  
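
Because the models use different class indices, it can help to map every prediction to one common convention before evaluating it. The sketch below does this with a lookup table, shown for the 2023 model and the 2022 inversion model as listed above. The names `TO_COMMON` and `to_common_indices` are hypothetical and not part of the repository. The common convention (skyrmion: 0, defect: 1, background: 2) is the one used by the conversion functions in this notebook.

```python
import numpy as np

# Lookup tables: position = class index of the model, value = common index
# (skyrmion: 0, defect: 1, background: 2). Hypothetical helper names.
TO_COMMON = {
    "2023": np.array([0, 2, 1], dtype=np.uint8),      # background 1 -> 2, defects 2 -> 1
    "2022_inv": np.array([0, 1, 2], dtype=np.uint8),  # already in the common convention
}

def to_common_indices(predicted_label, model_key):
    """Remap an integer label map to the common class indices by array indexing."""
    return TO_COMMON[model_key][predicted_label]

pred = np.array([[0, 1], [2, 1]], dtype=np.uint8)
print(to_common_indices(pred, "2023"))
```

Indexing with a lookup table remaps all pixels in one step and needs no temporary placeholder value.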

In [None]:
def trafo_channel_to_rgb(I):
    basis = np.array([[255,0,0],[0,255,0],[0,0,255]],dtype=np.uint8)
    return basis[I]

def trafo_rgb_to_channel(I):
    Q = np.zeros((I.shape[0],I.shape[1]),dtype=np.uint8)
    R,G,B = I[:,:,0],I[:,:,1],I[:,:,2]
    skyrmion_mask = (R>=128)&(G<128)&(B<128)
    defect_mask = (R<128)&(G>=128)&(B<128)
    bck_mask = ~(skyrmion_mask|defect_mask)
    Q[skyrmion_mask] = 0
    Q[defect_mask] = 1
    Q[bck_mask] = 2
    return Q
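
As a quick check of how the five label colors collapse onto three classes, the following sketch applies the two conversion functions to one pixel of each color. The functions are restated compactly so that the block runs on its own. The test image `five_class` is made up for illustration, assuming 8-bit labels with values 0 or 255.

```python
import numpy as np

def trafo_channel_to_rgb(I):
    basis = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8)
    return basis[I]

def trafo_rgb_to_channel(I):
    R, G, B = I[:, :, 0], I[:, :, 1], I[:, :, 2]
    skyrmion_mask = (R >= 128) & (G < 128) & (B < 128)
    defect_mask = (R < 128) & (G >= 128) & (B < 128)
    Q = np.full(I.shape[:2], 2, dtype=np.uint8)  # everything else is background
    Q[skyrmion_mask] = 0
    Q[defect_mask] = 1
    return Q

# One pixel per label color: red, green, blue, yellow, cyan (illustrative data)
five_class = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255],
                        [255, 255, 0], [0, 255, 255]]], dtype=np.uint8)
idx = trafo_rgb_to_channel(five_class)
print(idx)

# Index -> RGB -> index returns the same index map
assert np.array_equal(trafo_rgb_to_channel(trafo_channel_to_rgb(idx)), idx)
```

Blue, yellow and cyan all end up in the background class, which matches the three-class definition above.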

### 1.4 Define Matthews correlation coefficient (MCC) Code

For an explanation, please refer to the notebook `1_Training_and_prediction_tutorial.ipynb`.

In [None]:
def get_TF_PN(y_true,y_pred,ix0):
    m1,m2 = y_true==ix0,y_pred==ix0
    im1,im2 = tf.math.logical_not(m1),tf.math.logical_not(m2)
    TP = tf.math.reduce_mean(tf.cast(tf.math.logical_and(m1,m2),dtype=np.float64))
    TN = tf.math.reduce_mean(tf.cast(tf.math.logical_and(im1,im2),dtype=np.float64))
    FP = tf.math.reduce_mean(tf.cast(tf.math.logical_and(im1,m2),dtype=np.float64))
    FN = tf.math.reduce_mean(tf.cast(tf.math.logical_and(m1,im2),dtype=np.float64))
    return TP,TN,FP,FN

def get_mcc_from_TF_PN(TP,TN,FP,FN):
    denom = tf.keras.ops.sqrt((TP + FN) * (FP + TN) * (FP + TP) * (FN + TN))
    val = (TP * TN - FP * FN) / denom
    return  tf.where(tf.equal(denom, 0), tf.constant(0, dtype=tf.float64), val)

def get_mcc(y_true,y_pred,class_index):
    #MCC for one class (class_index) against all other classes
    return get_mcc_from_TF_PN(*get_TF_PN(y_true,y_pred,class_index)).numpy()
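
For a single class treated as "positive" and all other classes as "negative", the coefficient computed above is $\mathrm{MCC} = (TP \cdot TN - FP \cdot FN) / \sqrt{(TP+FN)(FP+TN)(FP+TP)(FN+TN)}$. The following NumPy sketch follows the same steps on a tiny made-up example, so the result can be checked by hand. The function `mcc_one_vs_rest` is a hypothetical stand-in for `get_mcc`, not code from the repository.

```python
import numpy as np

def mcc_one_vs_rest(y_true, y_pred, class_index):
    """Matthews correlation coefficient of one class against all others."""
    t, p = (y_true == class_index), (y_pred == class_index)
    # Fractions of pixels in each cell of the confusion matrix
    TP = np.mean(t & p)
    TN = np.mean(~t & ~p)
    FP = np.mean(~t & p)
    FN = np.mean(t & ~p)
    denom = np.sqrt((TP + FN) * (FP + TN) * (FP + TP) * (FN + TN))
    # Same convention as above: return 0 if the denominator vanishes
    return 0.0 if denom == 0 else float((TP * TN - FP * FN) / denom)

y_true = np.array([0, 0, 1, 2])  # illustrative labels
y_pred = np.array([0, 1, 1, 2])  # one skyrmion pixel is missed
print(mcc_one_vs_rest(y_true, y_pred, 0))
```

Here TP = 1/4, TN = 1/2, FP = 0 and FN = 1/4, which gives 1/√3 ≈ 0.577. A perfect prediction gives 1.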

## 2. Function for Prediction with U-Net model

Once again, as in the other notebook, here is a function for prediction with the U-Net model:

In [None]:
#Predict the label based on Kerr images and the U-Net model.
def predict(x,fn_model,batch_size=5,normalize_255=True):
    #load U-Net model
    model = tf.keras.models.load_model(fn_model,compile=False,custom_objects={'MishLayer': MishLayer})
    
    if not gpu_available():
        #create an identical model, but with a pure float32 policy, and copy the trained weights
        batch_size = 1
        nmodel = get_unet({"name":"unet","input_shape": (512,512), "n_class":3,"filter_multiplier":16,"n_depth":4,
                  "kernel_initialization":"he_normal","dropout":0.1,"kernel_size":(3,3),"upsample_channel_multiplier":8})
        nmodel.set_weights(model.get_weights())
        model = nmodel
    
    n = int(np.ceil(len(x)/batch_size))
    lix = [np.array(range(j*batch_size,min((j+1)*batch_size,len(x)))) for j in range(n)]
    ylabel = np.zeros(x.shape,dtype=np.uint8)
    progbar = tf.keras.utils.Progbar(n)
    for i in range(n):            
        progbar.update(i)
        input = x[lix[i]]
        if normalize_255:
            input = input/255
        ylabel[lix[i]] = model.predict(input,verbose=False).argmax(-1)
    progbar.update(n,finalize=True)
    return ylabel
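
The batching inside `predict` can be looked at in isolation. The sketch below reproduces only the index arithmetic: the number of batches is rounded up, and the last batch is shorter when the number of images is not a multiple of `batch_size`. The helper name `batch_indices` is hypothetical.

```python
import math

def batch_indices(n_items, batch_size):
    """Split range(n_items) into consecutive batches of at most batch_size."""
    n_batches = math.ceil(n_items / batch_size)
    return [list(range(j * batch_size, min((j + 1) * batch_size, n_items)))
            for j in range(n_batches)]

print(batch_indices(12, 5))
```

For 12 images and a batch size of 5 this yields three batches with 5, 5 and 2 images.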

## 3. Make Predictions Using Data Available in This Repository

Load the file names of the dataset:

In [None]:
dataset = pd.read_csv(basis_dir+"dataset/table.csv",sep=";")
dataset["img_fn"] = dataset.apply(lambda x:basis_dir+x["img_fn"],axis=1)
dataset["label_fn"] = dataset.apply(lambda x:basis_dir+x["label_fn"],axis=1)

fnimg,fnlabel = list(dataset.img_fn.to_numpy()),list(dataset.label_fn.to_numpy())
dataset

In the following, we will explore the pretrained models in this repository using example predictions. The models were trained on an input size of **512x512 pixels**, so we crop the images accordingly in the following steps.

Because the network consists only of convolution, pooling and upsampling layers, other (including much larger) input sizes can be handled by redefining the U-Net with a different `input_shape` and importing the trained weights; prediction then still works. Not every size is supported: with four pooling stages, for example, the height and width must be divisible by 16. If a specific pixel size is not supported, the image can be divided into tiles. This is not entirely straightforward, as there are better and worse ways to define the tiles for prediction. After performing predictions on the tiles, they can be stitched back together. This method is also implemented in the next notebook: `3_Editor.ipynb`.
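
The tiling idea can be sketched as follows. This is a deliberately simple variant with non-overlapping tiles and reflection padding. It is an assumption for illustration and not the method implemented in `3_Editor.ipynb`. The helper names `split_into_tiles` and `stitch_tiles` are hypothetical.

```python
import numpy as np

def split_into_tiles(img, tile=512):
    """Pad a 2D image by reflection to a multiple of `tile` and cut it into tiles."""
    h, w = img.shape
    pad_h, pad_w = (-h) % tile, (-w) % tile
    padded = np.pad(img, ((0, pad_h), (0, pad_w)), mode="reflect")
    ny, nx = padded.shape[0] // tile, padded.shape[1] // tile
    # (ny*tile, nx*tile) -> (ny, nx, tile, tile) -> (ny*nx, tile, tile)
    tiles = padded.reshape(ny, tile, nx, tile).swapaxes(1, 2).reshape(ny * nx, tile, tile)
    return tiles, (ny, nx), (h, w)

def stitch_tiles(tiles, grid, original_shape):
    """Inverse of split_into_tiles: reassemble the tiles and crop the padding."""
    ny, nx = grid
    tile = tiles.shape[1]
    full = tiles.reshape(ny, nx, tile, tile).swapaxes(1, 2).reshape(ny * tile, nx * tile)
    h, w = original_shape
    return full[:h, :w]

# Small demonstration with 4x4 tiles on a 6x9 test image
demo = np.arange(6 * 9, dtype=np.uint8).reshape(6, 9)
tiles, grid, shape = split_into_tiles(demo, tile=4)
print(tiles.shape, grid)
assert np.array_equal(stitch_tiles(tiles, grid, shape), demo)
```

In practice, the array `tiles` would be passed to `predict` with the default tile size of 512, and the returned label tiles would be stitched with `stitch_tiles`. Hard tile borders can produce artifacts near the edges, so overlapping tiles are a common refinement.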


### 3.1 Model 2023

We now make a prediction with the 2023 model. Change the variable **ix** to predict a different image from the dataset in this repository.

In [None]:
#load image and label
ix = 129
img = np.array(Image.open(fnimg[ix]))
label = np.array(Image.open(fnlabel[ix]))
#cut to 512x512
img,label = img[:512,:512],label[:512,:512]
#Predict label with U-Net
predicted_label = predict(np.array([img]),basis_dir+"models/2023_model.keras",batch_size=1)[0]
#Swap class indices 1 and 2: the 2023 model uses (skyrmion:0, background:1, defects:2), while the functions here expect (skyrmion:0, defects:1, background:2)
predicted_label[predicted_label==1] = 9
predicted_label[predicted_label==2] = 1
predicted_label[predicted_label==9] = 2
#Evaluate the prediction with the Matthews correlation coefficient
print("Pixelwise Matthews correlation coefficient (true=skyrmion,false=defect,background)",get_mcc(trafo_rgb_to_channel(label),predicted_label,0))

In [None]:
plt.close("all")
%matplotlib inline

fig,ax = plt.subplots(ncols=3,dpi=300,figsize=(15,3.5))
ax[0].imshow(img,cmap="gray")
ax[1].imshow(label)
ax[2].imshow(trafo_channel_to_rgb(predicted_label))
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=matplotlib.colors.BoundaryNorm([0,1,2,3],3), 
                        cmap=matplotlib.colors.ListedColormap([(0,0,1),(0,1,0),(1,0,0)])),
                        ax=ax[2], ticks=[0.5, 1.5, 2.5])
cbar.set_ticklabels(['Background', 'Defects', 'Skyrmions'])
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=matplotlib.colors.BoundaryNorm([0,1,2,3,4,5],5), 
                    cmap=matplotlib.colors.ListedColormap([(0, 1, 1), (1, 1, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)])),
                    ax=ax[1], ticks=[0.5, 1.5, 2.5,3.5,4.5])
cbar.set_ticklabels(['FM/non-FM boundary', 'Non-FM background', 'FM background', 'Defects', 'Skyrmions'])

ax[0].set_title("Kerr image")
ax[1].set_title("Ground truth")
ax[2].set_title("Predicted label")
fig.tight_layout()

### 3.2 Model 2022
Now we will try another model:

In [None]:
#load image and label
ix = 78
img = np.array(Image.open(fnimg[ix]))
label = np.array(Image.open(fnlabel[ix]))
#cut to 512x512
img,label = img[:512,:512],label[:512,:512]
#Predict label with U-Net
predicted_label = predict(np.array([img]),basis_dir+"models/2022_model.keras",batch_size=1)[0]
#Evaluate the prediction with the Matthews correlation coefficient
print("Pixelwise Matthews correlation coefficient (true=skyrmion,false=defect,background)",get_mcc(trafo_rgb_to_channel(label),predicted_label,0))

In [None]:
plt.close("all")
%matplotlib inline

fig,ax = plt.subplots(ncols=3,dpi=300,figsize=(15,3.5))
ax[0].imshow(img,cmap="gray")
ax[1].imshow(label)
ax[2].imshow(trafo_channel_to_rgb(predicted_label))
ax[0].set_title("Kerr image")
ax[1].set_title("Ground truth")
ax[2].set_title("Predicted label")
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=matplotlib.colors.BoundaryNorm([0,1,2,3],3), 
                        cmap=matplotlib.colors.ListedColormap([(0,0,1),(0,1,0),(1,0,0)])),
                        ax=ax[2], ticks=[0.5, 1.5, 2.5])
cbar.set_ticklabels(['Background', 'Defects', 'Skyrmions'])
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=matplotlib.colors.BoundaryNorm([0,1,2,3,4,5],5), 
                    cmap=matplotlib.colors.ListedColormap([(0, 1, 1), (1, 1, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)])),
                    ax=ax[1], ticks=[0.5, 1.5, 2.5,3.5,4.5])
cbar.set_ticklabels(['FM/non-FM boundary', 'Non-FM background', 'FM background', 'Defects', 'Skyrmions'])
fig.tight_layout()

### 3.3 Model Inversion 2022

Finally, the 2022 inversion model, which works with Kerr micrographs of both normal and inverted intensity (inverted meaning bright skyrmions on a dark background):

In [None]:
#load image and label
ix = 1
img = 255-np.array(Image.open(fnimg[ix]))
label = np.array(Image.open(fnlabel[ix]))
#cut to 512x512
img,label = img[:512,:512],label[:512,:512]
#Predict label with U-Net
predicted_label = predict(np.array([img]),basis_dir+"models/2022_model_inv.keras",batch_size=1)[0]
#Evaluate the prediction with the Matthews correlation coefficient
print("Pixelwise Matthews correlation coefficient (true=skyrmion,false=defect,background)",get_mcc(trafo_rgb_to_channel(label),predicted_label,0))

In [None]:
plt.close("all")
%matplotlib inline

fig,ax = plt.subplots(ncols=3,dpi=300,figsize=(15,3.5))
ax[0].imshow(img,cmap="gray")
ax[1].imshow(label)
ax[2].imshow(trafo_channel_to_rgb(predicted_label))
ax[0].set_title("Kerr image")
ax[1].set_title("Ground truth")
ax[2].set_title("Predicted label")

cbar = fig.colorbar(plt.cm.ScalarMappable(norm=matplotlib.colors.BoundaryNorm([0,1,2,3],3), 
                        cmap=matplotlib.colors.ListedColormap([(0,0,1),(0,1,0),(1,0,0)])),
                        ax=ax[2], ticks=[0.5, 1.5, 2.5])
cbar.set_ticklabels(['Background', 'Defects', 'Skyrmions'])
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=matplotlib.colors.BoundaryNorm([0,1,2,3,4,5],5), 
                    cmap=matplotlib.colors.ListedColormap([(0, 1, 1), (1, 1, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)])),
                    ax=ax[1], ticks=[0.5, 1.5, 2.5,3.5,4.5])
cbar.set_ticklabels(['FM/non-FM boundary', 'Non-FM background', 'FM background', 'Defects', 'Skyrmions'])

fig.tight_layout()

## 4. Additional information

Further information on the Skyrmion U-Net can be found in the paper: Labrie-Boulay et al., *Phys. Rev. Applied* **21**, 014014 (2024). The complete training data and models (the models are also included in this repository) can be found in the Zenodo repository by Winkler et al. [https://zenodo.org/records/10997175](https://zenodo.org/records/10997175) (2024).