# Skyrmion U-Net training example (mini U-Net)

The first part will involve training a mini U-Net model with a small dataset, making it feasible to train on a CPU within a realistic time.
## Import packages + define utility functions

In [None]:
import tensorflow as tf
import albumentations
from PIL import Image
import numpy as np
import time
import matplotlib.pyplot as plt
import pandas as pd
import glob
import os
def plotfig(l,ltitle=None,nrows=None,ncols=None,dpi=100,s0=1,suptitle=None):
    """Show a list of images on a grid of matplotlib axes.

    Parameters
    ----------
    l : sequence of array-like
        Images handed to ``imshow`` (drawn with the "gray" colormap).
    ltitle : sequence of str, optional
        One title per image, in the same order as ``l``.
    nrows, ncols : int, optional
        Grid shape. A missing value is derived from the number of images;
        if both are missing, a near-square grid is used.
    dpi : int
        Resolution passed to ``plt.subplots``.
    s0 : float
        Edge length of one grid cell; the figure size is ``(ncols*s0, nrows*s0)``.
    suptitle : str, optional
        Title placed above the whole figure.
    """
    n_img = len(l)
    if nrows is None and ncols is None:
        # near-square grid; at least one cell so that an empty list still works
        nc = max(1, int(np.ceil(np.sqrt(n_img))))
        nr = max(1, int(np.ceil(n_img/nc)))
    elif nrows is not None and ncols is not None:
        nr, nc = nrows, ncols
    elif nrows is not None:
        nr = nrows
        nc = max(1, int(np.ceil(n_img/nr)))
    else:
        nc = ncols
        nr = max(1, int(np.ceil(n_img/nc)))

    # squeeze=False always yields a 2-D array of axes, so ravel() also works
    # for a 1x1 grid (plt.subplots would otherwise return a bare Axes object).
    fig,ax = plt.subplots(nrows=nr,ncols=nc,dpi=dpi,figsize=(nc*s0,nr*s0),squeeze=False)
    if suptitle is not None:
        fig.suptitle(suptitle,fontsize=40,y=0.99)
    ax = ax.ravel()
    for axis in ax:
        axis.axis("off")

    # Surplus images (more images than cells) are ignored.
    for i in range(min(n_img,len(ax))):
        ax[i].imshow(l[i],cmap="gray")
        if ltitle is not None:
            ax[i].set_title(ltitle[i])
    fig.tight_layout()

In [None]:
def gpu_available():
    """Return the number of GPUs TensorFlow can see (0, i.e. falsy, if none)."""
    return len(tf.config.list_physical_devices('GPU'))


# With a GPU present, switch Keras to the mixed_float16 policy
if gpu_available():
    print("GPU available")
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

## "mini" U-Net architecture

![](notebook_figures/u_net_architecture_2.png)

In [None]:
# Basic activation layer
class MishLayer(tf.keras.layers.Layer):
    """Keras layer that applies the Mish activation to its input."""

    def call(self, x):
        activated = tf.keras.activations.mish(x)
        return activated
        
# Basic Convolution Block
def conv_block(x, n_channels, param):
    """Apply Conv2D -> BatchNormalization -> Mish to ``x`` and return the result.

    ``param`` provides "kernel_size" and "kernel_initialization"; the
    convolution uses "same" padding and ``n_channels`` filters.
    """
    convolution = tf.keras.layers.Conv2D(
        n_channels,
        kernel_size=param["kernel_size"],
        kernel_initializer=param["kernel_initialization"],
        padding="same",
    )
    features = convolution(x)
    features = tf.keras.layers.BatchNormalization()(features)
    return MishLayer()(features)

# Double Convolution Block used in "encoder" and "bottleneck"
def double_conv_block(x, n_channels, param):
    """Run two consecutive ``conv_block`` stages with ``n_channels`` filters each."""
    for _ in range(2):
        x = conv_block(x, n_channels, param)
    return x

# Downsample block for feature extraction (encoder)
def downsample_block(x, n_channels, param):
    """Encoder stage: double convolution, then 2x2 max pooling and dropout.

    Returns ``(features, pooled)``: the convolution output before pooling
    and the pooled, dropout-regularised tensor.
    """
    features = double_conv_block(x, n_channels, param)
    pool = tf.keras.layers.MaxPool2D(pool_size=(2,2))
    drop = tf.keras.layers.Dropout(param["dropout"])
    return features, drop(pool(features))

# Upsample block for the decoder
def upsample_block(x, conv_features, n_channels, param):
    """Decoder stage: stride-2 transposed convolution, concatenation with
    ``conv_features``, dropout, then a double convolution with ``n_channels`` filters.
    """
    # The transposed convolution uses n_channels * "upsample_channel_multiplier" filters
    up_channels = n_channels*param["upsample_channel_multiplier"]
    upsampled = tf.keras.layers.Conv2DTranspose(
        up_channels, param["kernel_size"], strides=(2,2), padding='same'
    )(x)
    merged = tf.keras.layers.concatenate([upsampled, conv_features])
    merged = tf.keras.layers.Dropout(param["dropout"])(merged)
    return double_conv_block(merged, n_channels, param)

# Create the model
def get_unet(param):
    """Assemble the U-Net described by the dictionary ``param``.

    Keys read here: "input_shape" (tuple, a channel axis of size 1 is appended),
    "n_depth", "filter_multiplier", "n_class" and "name". The building blocks
    additionally read "kernel_size", "kernel_initialization", "dropout" and
    "upsample_channel_multiplier".

    Returns a ``tf.keras.Model`` whose last layer is a 1x1 softmax convolution
    with ``n_class`` channels, computed in float32.
    """
    depth = param["n_depth"]
    base_filters = param["filter_multiplier"]

    inputs = tf.keras.layers.Input(shape=param["input_shape"]+(1,))

    # Encoder: the filter count doubles per level; keep the pre-pooling features
    skips = []
    x = inputs
    for level in range(depth):
        skip, x = downsample_block(x, (2**level)*base_filters, param)
        skips.append(skip)

    # Bottleneck
    x = double_conv_block(x, (2**depth)*base_filters, param)

    # Decoder: walk the levels back up, deepest skip connection first
    for level in reversed(range(depth)):
        x = upsample_block(x, skips[level], (2**level)*base_filters, param)

    head = tf.keras.layers.Conv2D(param["n_class"], (1,1), padding="same", activation = "softmax",dtype='float32')
    return tf.keras.Model(inputs, head(x), name=param["name"])

# Build the mini U-Net: depth 1, 5 base filters, 8x8 kernels, 3 output classes,
# for 512x672 inputs (get_unet appends the single channel axis)
model = get_unet({"name":"unet","input_shape": (512,672), "n_class":3,"filter_multiplier":5,"n_depth":1,
                  "kernel_initialization":"he_normal","dropout":0.01,"kernel_size":(8,8),"upsample_channel_multiplier":2})

In [None]:
# Print the layer-by-layer overview of the current model
model.summary()

## Dataset processing for training

In [None]:
# Load the semicolon-separated dataset index; later cells read its
# source_id, img_fn and label_fn columns
dataset_table = pd.read_csv("dataset/table.csv",sep=";")
dataset_table

### Data generator

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Serve (image, label) batches for training.

    Parameters
    ----------
    images, labels : arrays
        Must agree in their first three axes; ``labels.shape[0]`` is the
        number of samples.
    batch_size : int
        Number of samples per batch; an incomplete last batch is dropped.
    n_class : int
        Only used by the label-smoothing formula.
    smoothing : bool
        If true, label entries equal to 0 and 1 are replaced by smoothed
        values (factor 0.2).
    shuffle : bool
        Shuffle the sample order at construction and in ``on_epoch_end``.
    aug : callable, optional
        Called as ``aug(image=..., mask=...)`` per sample; must return a
        mapping with the keys "image" and "mask".

    Raises
    ------
    ValueError
        If the first three axes of ``images`` and ``labels`` differ.
    """
    def __init__(self, images, labels, batch_size, n_class=3, smoothing=False,shuffle=True, aug=None):
        super().__init__()
        self.n_class = n_class
        if labels.shape[:3] != images.shape[:3]:
            raise ValueError(f"Shape not fit: images {images.shape[:3]} vs. labels {labels.shape[:3]}")

        self.len_data = labels.shape[0]
        self.shape_data = labels.shape[1:3]
        self.labels = labels
        self.images = images

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.smoothing = smoothing

        self.aug = aug
        # initialise self.indexes (and shuffle once, if requested)
        self.on_epoch_end()

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it if ``shuffle`` is set."""
        self.indexes = np.arange(self.len_data)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Return the number of complete batches per epoch."""
        return int(np.floor(self.len_data/ self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X/255, y)`` with ``y`` cast to float16."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Indexing with an integer array already returns new arrays, so the
        # in-place augmentation below never touches the stored data.
        X,y = self.images[indexes], self.labels[indexes]
        if self.aug is not None:
            self.data_augmentation(X, y)
        y = y.astype(dtype=np.float16)
        if self.smoothing:
            # zeros first: the smoothed value of 0 is not 1, so the second
            # assignment only hits the original ones
            y[y==0] = self.smooth_labels(0.0,factor=0.2)
            y[y==1] = self.smooth_labels(1.0,factor=0.2)
        return X/255,y

    def smooth_labels(self, labels, factor=0.1):
        """Return ``labels*(1-factor) + factor/n_class``."""
        return labels*(1 - factor)+(factor / self.n_class)

    def data_augmentation(self, X, y):
        """Augment every sample of the batch in place (image and mask together)."""
        # Iterate over the actual batch length, not self.batch_size, so that a
        # shorter batch does not raise an IndexError.
        for i in range(len(X)):
            augmented = self.aug(image=X[i],mask=y[i])
            X[i] = augmented["image"]
            y[i] = augmented["mask"]

### Augmentation

In [None]:
def get_augmentation():
    """Return the albumentations pipeline used for the mini U-Net.

    Random flips (p=0.5 each), followed by Gaussian noise, shift/scale/rotate
    and brightness/contrast changes, the latter three always applied (p=1).
    """
    transforms = [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.GaussNoise(p=1,var_limit=20**2),
        albumentations.ShiftScaleRotate(shift_limit=0.5, scale_limit=(-0.15,0.4), rotate_limit=90, p=1),
        albumentations.RandomBrightnessContrast(brightness_limit=0.25,contrast_limit=0.25,p=1),
    ]
    return albumentations.Compose(transforms, p=1)

# Pipeline instance used for the augmentation examples in the next cell
aug = get_augmentation()

In [None]:
# Collect the first image of each selected source together with an augmented
# version of it, then plot all originals followed by all augmented images
example_sources = [7, 8, 10, 11, 13, 23, 135]
lexample_img = []
lexample_aug = []
for ix,subtable in dataset_table.groupby("source_id"):
    if ix in example_sources:
        img = np.array(Image.open(subtable.iloc[0].img_fn))
        lexample_img.append(img)
        lexample_aug.append(aug(image=img)["image"])
plotfig(lexample_img+lexample_aug,nrows=2,dpi=500)

### Splitting of dataset

In [None]:
# Row labels of dataset_table for each split, collected source by source
# in ascending source_id order
train_ix = np.array([
    row
    for source in sorted({10, 11, 20, 23, 7, 135})
    for row in dataset_table.loc[dataset_table.source_id == source].index.to_numpy()
])
test_ix = np.array([
    row
    for source in sorted({6, 9})
    for row in dataset_table.loc[dataset_table.source_id == source].index.to_numpy()
])
val_ix = np.array([
    row
    for source in sorted({8, 13})
    for row in dataset_table.loc[dataset_table.source_id == source].index.to_numpy()
])

def group_after_source(lix,ncolmax=6,suptitle=None):
    """Plot the first image of every source that occurs in the rows ``lix``.

    Parameters
    ----------
    lix : sequence of int
        Positional row indices into ``dataset_table``.
    ncolmax : int
        Number of grid columns handed to ``plotfig`` (previously ignored in
        favour of a hard-coded 6).
    suptitle : str, optional
        Figure title.
    """
    table = dataset_table.iloc[lix]
    images, titles = [], []
    # a single groupby pass yields the image and its title together
    for source_id, rows in table.groupby("source_id"):
        images.append(np.array(Image.open(rows.iloc[0].img_fn)))
        titles.append(str(source_id))
    plotfig(images,titles,suptitle=suptitle,ncols=ncolmax,dpi=100,s0=5)
# One overview figure per split
for split_rows, split_title in ((train_ix, "Training set"), (test_ix, "Test set"), (val_ix, "Validation set")):
    group_after_source(split_rows, suptitle=split_title)

### Loading images and labels

#### Comment

Labels for U-Net model:
- Skyrmions (RGB-label: red (255,0,0))
- Defects (RGB-label: green (0,255,0))
- FM background (RGB-label: blue (0,0,255))
- Non-FM background (RGB-label: yellow (255,255,0))
- Boundary non-FM/FM background (RGB-label: cyan (0,255,255))

The 3-class U-Net model predicts skyrmions (RGB-label: red (255,0,0)), defects (RGB-label: green (0,255,0)), and background (RGB-label: blue (0,0,255)). The background class merges the ferromagnetic background, the non-ferromagnetic background, and the boundary between them. Each RGB label is converted to a class index for training.

In [None]:
def trafo_rgb_to_channel(I):
    """Convert an RGB label image into a uint8 map of class indices.

    Pixels with only the red channel >= 128 become class 0 (skyrmion), pixels
    with only the green channel >= 128 become class 1 (defect), and every
    other pixel becomes class 2 (background).
    """
    red, green, blue = (I[:,:,channel] for channel in range(3))
    is_skyrmion = (red>=128)&(green<128)&(blue<128)
    is_defect = (red<128)&(green>=128)&(blue<128)
    # start from "background everywhere" and overwrite the two other classes
    classes = np.full(I.shape[:2],2,dtype=np.uint8)
    classes[is_defect] = 1
    classes[is_skyrmion] = 0
    return classes
    
def load_img_label_data(lix):
    """Load the images and class-index labels of the table rows ``lix``.

    Parameters
    ----------
    lix : sequence of int
        Positional row indices into ``dataset_table``.

    Returns
    -------
    (images, labels) : tuple of np.ndarray
        Images as read from ``img_fn``; labels read from ``label_fn`` and
        converted with ``trafo_rgb_to_channel``.
    """
    rows = dataset_table.iloc[lix]

    images = []
    for fn in rows.img_fn:
        # the context manager closes the file as soon as the pixels are copied
        with Image.open(fn) as im:
            images.append(np.array(im))

    labels = []
    for fn in rows.label_fn:
        with Image.open(fn) as im:
            labels.append(trafo_rgb_to_channel(np.array(im)))

    return np.array(images), np.array(labels)

# Load images and class-index labels for each split
train_img,train_label = load_img_label_data(train_ix)
val_img,val_label = load_img_label_data(val_ix)
test_img,test_label = load_img_label_data(test_ix)
# Shape of a single image (all axes after the sample axis)
img_size = test_img.shape[1:]

## Training
### Loss & Metric

U-Net can be described as follows:

$$ \mathbf{z} = \mathbf{f}_{\vec{\theta}}(\mathbf{x}) $$

where $\vec{\theta}$ is a high-dimensional vector containing all trainable parameters, $\mathbf{z}$ is the 3-dimensional output tensor (two pixel coordinates and one class index), and $\mathbf{x}$ is the input Kerr microscopy image. The function $\mathbf{f}$ encapsulates all operations of the network, such as convolution, max pooling, up-convolution, and batch normalization.

$\mathbf{z}$ is converted to probabilities using softmax:

$$ p_{(x,y),i} = \frac{\exp(\mathbf{z}_{(x,y),i})}{\sum_{j\in \mathrm{classes}} \exp(\mathbf{z}_{(x,y),j})} $$

The output mask $\mathbf{m}$ is then determined by:

$$ \mathbf{m}_{(x,y)} = \mathrm{arg\,max}_i\; p_{(x,y),i}$$

The cross-entropy loss (with averaging over pixels) is given by:

$$ L = - \frac{1}{\sum_{x,y} 1} \sum_{x,y} \sum_{i=1}^3 \;w_i\;(p_{\mathrm{ground-truth}})_{(x,y),i} \log(p_{(x,y),i}) $$

where $\sum_{x,y}$ runs over the pixels of all images in a batch, and $w_i$ is the weight of class $i$.

Training aims to find parameters $\vec{\theta}$ that minimize the loss $L$; at a minimum, $\nabla_{\vec{\theta}} L = 0$. The simplest approach is gradient descent with learning rate $\epsilon$:

$$ \vec{\theta}_{i+1} = \vec{\theta}_{i} - \epsilon \left. \nabla_{\vec{\theta}} L \right|_{\vec{\theta}_i} $$

In this notebook, however, we use a more advanced optimizer (Adam).

The gradient $\nabla_{\vec{\theta}} L$ can be calculated using backpropagation, which relies on the chain rule for derivatives.

In [None]:
def get_cross_entropy_loss(weight):
    """Build a class-weighted cross-entropy loss.

    Parameters
    ----------
    weight : sequence of float
        One weight per class, broadcast along the last axis.

    Returns
    -------
    callable
        ``loss(ytrue, ypred)``: ``ypred`` is renormalised along the last axis
        and clipped to ``[eps, 1-eps]``; the weighted cross-entropy is summed
        over the last axis and averaged over all remaining axes.
    """
    weight = tf.convert_to_tensor(weight,dtype=np.float32)
    eps = tf.keras.backend.epsilon()
    def loss(ytrue,ypred):
        # TensorFlow does not promote dtypes implicitly: bring labels (the
        # data generator emits float16) and predictions to the weight dtype.
        ytrue = tf.cast(ytrue,weight.dtype)
        ypred = tf.cast(ypred,weight.dtype)
        p = ypred/tf.math.reduce_sum(ypred,axis=-1,keepdims=True)
        # clipping keeps log(p) finite
        p = tf.clip_by_value(p,eps,1-eps)
        return -tf.math.reduce_mean(tf.math.reduce_sum(weight*ytrue*tf.math.log(p),axis=-1))
    return loss

In [None]:
def get_TF_PN(y_true,y_pred,ix0):
    """Return the fractions (TP, TN, FP, FN) of elements for class ``ix0``.

    An element counts as "positive" where the label equals ``ix0``. Each
    value is the mean of the corresponding boolean mask, as float64.
    """
    true_pos = y_true==ix0
    pred_pos = y_pred==ix0
    true_neg = tf.math.logical_not(true_pos)
    pred_neg = tf.math.logical_not(pred_pos)

    def fraction(mask_a, mask_b):
        both = tf.math.logical_and(mask_a,mask_b)
        return tf.math.reduce_mean(tf.cast(both,dtype=np.float64))

    TP = fraction(true_pos,pred_pos)
    TN = fraction(true_neg,pred_neg)
    FP = fraction(true_neg,pred_pos)
    FN = fraction(true_pos,pred_neg)
    return TP,TN,FP,FN

def get_mcc_from_TF_PN(TP,TN,FP,FN):
    """Matthews correlation coefficient from TP/TN/FP/FN; 0 where the denominator is 0."""
    numerator = TP * TN - FP * FN
    denom = tf.keras.ops.sqrt((TP + FN) * (FP + TN) * (FP + TP) * (FN + TN))
    zero = tf.constant(0, dtype=tf.float64)
    return tf.where(tf.equal(denom, 0), zero, numerator / denom)

def get_mcc(x,y,ix0):
    """Matthews correlation coefficient of class ``ix0`` (true labels ``x``, predictions ``y``) as a NumPy value."""
    rates = get_TF_PN(x,y,ix0)
    return get_mcc_from_TF_PN(*rates).numpy()

#Matthews correlation coefficient
def get_mcc_cpu(a,b,ix0):
    """NumPy implementation of the Matthews correlation coefficient for class ``ix0``.

    ``a`` holds the true labels and ``b`` the predicted labels. Returns
    ``np.nan`` when the denominator is zero.
    """
    true_pos = (a==ix0).flatten()
    pred_pos = (b==ix0).flatten()
    true_neg = np.invert(true_pos)
    pred_neg = np.invert(pred_pos)
    n_total = len(true_pos)

    TP = np.sum(true_pos & pred_pos)/n_total
    TN = np.sum(true_neg & pred_neg)/n_total
    FP = np.sum(true_neg & pred_pos)/n_total
    FN = np.sum(true_pos & pred_neg)/n_total

    radicand = (TP+FN)*(FP+TN)*(FP+TP)*(FN+TN)
    if radicand==0:
        return np.nan
    return (TP*TN-FP*FN)/(np.sqrt(radicand))

class MCC(tf.keras.metrics.Metric):
    """Keras metric: Matthews correlation coefficient of class ``ix0``.

    The state holds running means of the per-batch TP/TN/FP/FN fractions and
    the number of batches seen so far; ``sample_weight`` is ignored.
    """
    def __init__(self, ix0, name="MCC", **kwargs):
        super().__init__(name=name,dtype=tf.float64,**kwargs)
        self.ix0 = ix0

        def scalar(var_name, dtype=tf.float64):
            return self.add_variable(shape=(),name=var_name, initializer="zeros",dtype=dtype)

        self.tp = scalar("TP")
        self.tn = scalar("TN")
        self.fp = scalar("FP")
        self.fn = scalar("FN")
        self.tot = scalar("TOT", tf.int64)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Fold one batch into the running means (class = argmax over the last axis)."""
        class_true = tf.argmax(y_true,axis=-1)
        class_pred = tf.argmax(y_pred,axis=-1)
        batch_rates = get_TF_PN(class_true,class_pred,self.ix0)
        n_seen = tf.cast(self.tot+1,tf.float64)
        # incremental mean: mean += (value - mean) / n
        for running, batch_rate in zip((self.tp,self.tn,self.fp,self.fn),batch_rates):
            running.assign_add((batch_rate-running)/n_seen)
        self.tot.assign_add(1)

    def reset_state(self):
        """Set all state variables back to zero."""
        for state in (self.tp,self.tn,self.fp,self.fn,self.tot):
            state.assign(0)

    def result(self):
        """Return the MCC computed from the running means."""
        return get_mcc_from_TF_PN(self.tp,self.tn,self.fp,self.fn)

### Initialization and performing the training

In [None]:
# Name of this training run and the folder that receives its outputs
modeln = "model_1"
folder = "./training/"

if not os.path.exists(folder):
    os.makedirs(folder)

# File the checkpoint callback writes the model to
modelfn = folder+modeln+".keras"

# Fresh mini U-Net (same parameters as in the architecture cell)
model = get_unet({"name":"unet","input_shape": (512,672), "n_class":3,"filter_multiplier":5,"n_depth":1,
                  "kernel_initialization":"he_normal","dropout":0.01,"kernel_size":(8,8),"upsample_channel_multiplier":2})

# Class 0 (skyrmions, see trafo_rgb_to_channel) is weighted 6x in the loss; MCC(0) tracks the same class
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.005),loss=get_cross_entropy_loss([6,1,1]),metrics=["accuracy",MCC(0)])

batch_size = 5
# Rows of the identity matrix: basis[label] turns class indices into one-hot vectors
basis = np.eye(3,dtype=np.uint8)
dg_train = DataGenerator(train_img,basis[train_label],batch_size,smoothing=False,aug=get_augmentation())
# NOTE(review): the validation generator also augments and shuffles, so val_loss is
# measured on randomly transformed images - confirm this is intended
dg_val = DataGenerator(val_img,basis[val_label],batch_size,smoothing=False,aug=get_augmentation())

# Halve the learning rate after 2 epochs without val_loss improvement and keep only the best model on disk.
# NOTE(review): the generators already batch the data; confirm that passing batch_size to fit() is accepted/needed
history = model.fit(x=dg_train,validation_data=dg_val,epochs=15,verbose=1,batch_size=batch_size,callbacks=[
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',mode='min',factor=0.5, patience=2, min_lr=1e-6, verbose=1),
tf.keras.callbacks.ModelCheckpoint(modelfn, verbose=1,monitor='val_loss',mode='min',save_best_only=True,save_weights_only=False)])

# Store the per-epoch metrics next to the model
hist_df = pd.DataFrame(history.history)  
hist_df.to_csv(folder+modeln+"_history.csv")

In [None]:
# One figure for the validation curves, then one for the training curves;
# the prefix selects the matching columns of hist_df
for figure_title, prefix in (("Training history - Validation Set", "val_"),
                             ("Training history - Trainings Set", "")):
    fig,ax = plt.subplots(ncols=3,figsize=(10,3.5),dpi=300)
    fig.suptitle(figure_title,fontsize=14)
    for i, (column, caption) in enumerate((("loss", "Cross-entropy loss"),
                                           ("accuracy", "Accuracy"),
                                           ("MCC", "MCC"))):
        ax[i].plot(hist_df[prefix+column])
        ax[i].set_xlabel("Epoch")
        ax[i].set_ylabel(caption)
        ax[i].set_title(caption)
    fig.tight_layout()

hist_df

### Testing

In [None]:
#Predict the label based on Kerr images and the U-Net model.
def predict(x,modelfn,batch_size=5,normalize_255=True):
    """Predict a class-index map for every image in ``x``.

    The model is loaded from ``modelfn`` (uncompiled) on every call and applied
    batch by batch; with ``normalize_255`` the inputs are divided by 255 first.
    Returns a uint8 array with the shape of ``x`` holding the argmax over the
    last axis of the model output.
    """
    model = tf.keras.models.load_model(modelfn,compile=False,custom_objects={'MishLayer': MishLayer})
    n_samples = len(x)
    n_batches = int(np.ceil(n_samples/batch_size))
    ylabel = np.zeros(x.shape,dtype=np.uint8)
    progbar = tf.keras.utils.Progbar(n_batches)
    for batch_no in range(n_batches):
        progbar.update(batch_no)
        rows = slice(batch_no*batch_size,min((batch_no+1)*batch_size,n_samples))
        batch = x[rows]
        if normalize_255:
            batch = batch/255
        ylabel[rows] = model.predict(batch,verbose=False).argmax(-1)
    progbar.update(n_batches,finalize=True)
    return ylabel

def trafo_channel_to_rgb(I):
    """Map class indices 0..4 to the RGB colours red, green, blue, yellow and cyan (uint8)."""
    palette = np.array(
        [
            [255, 0, 0],
            [0, 255, 0],
            [0, 0, 255],
            [255, 255, 0],
            [0, 255, 255],
        ],
        dtype=np.uint8,
    )
    return palette[I]

In [None]:
# MCC of the skyrmion class (index 0) on the held-out test split.
# The message announces the test set, so evaluate test_img/test_label
# (the validation data was used here before).
test_pred = predict(test_img,modelfn,5)
print("Pixelwise Matthews correlation coefficient on test set (true=skyrmion,false=defect,background)",get_mcc(test_label,test_pred,0))
# Same metric on all three splits together; the total_* arrays share one ordering (test, val, train)
total_ix = np.hstack((test_ix,val_ix,train_ix))
total_img = np.vstack((test_img,val_img,train_img))
total_label = np.vstack((test_label,val_label,train_label))
total_pred = predict(total_img,modelfn,5)
print("Pixelwise Matthews correlation coefficient on complete set (true=skyrmion,false=defect,background)",get_mcc(total_label,total_pred,0))

### Prediction examples
Show the prediction for one frame from each source video that occurs in the dataset.


In [None]:
# Rows to show: the first table row of every source, restricted to rows that belong to a split
first_rows = set([b.iloc[0].name for a,b in dataset_table.groupby("source_id")])
S = set(np.hstack((test_ix,val_ix,train_ix)))&first_rows
# Checked in this order; the first split containing the row provides the figure title
split_titles = ((train_ix, "Image used in training"),
                (test_ix, "Image used in testing"),
                (val_ix, "Image used in validation"))
for ix0 in range(len(total_img)):
    if total_ix[ix0] not in S: continue

    fig,ax = plt.subplots(ncols=3,figsize=(10,3),dpi=200)
    panels = ((total_img[ix0], {"cmap": "gray"}, "Kerr image"),
              (trafo_channel_to_rgb(total_label[ix0]), {}, "Ground truth"),
              (trafo_channel_to_rgb(total_pred[ix0]), {}, "Predicted label"))
    for i, (picture, imshow_kwargs, caption) in enumerate(panels):
        ax[i].axis("off")
        ax[i].imshow(picture, **imshow_kwargs)
        ax[i].set_title(caption)
    for split_rows, split_title in split_titles:
        if total_ix[ix0] in split_rows:
            fig.suptitle(split_title,y=1.02,fontsize=16)
            break
    fig.tight_layout()

# Skyrmion U-Net training example (large U-Net)

This part shows how to train a large U-Net model, with a large dataset.

## Download of Zenodo skyrmion U-Net repository
Uncomment the following cell (remove the surrounding triple quotes) only if you are going to train the large U-Net. It downloads approximately 1 GB of data.

In [None]:
# Disabled by default: the download code below sits inside a string literal.
# Remove the surrounding triple quotes to download and unpack the dataset.
"""
import wget
import zipfile
import glob

zenodo_folder = "./zenodo_dataset/"
if not os.path.exists(zenodo_folder):
    os.makedirs(zenodo_folder)

wget.download("https://zenodo.org/records/10997175/files/public_unet_skyrmion_dataset.zip",zenodo_folder+"zenodo_unet_skyrmion_dataset.zip")
with zipfile.ZipFile(zenodo_folder+"zenodo_unet_skyrmion_dataset.zip","r") as zip:
    zip.extractall(zenodo_folder)
"""

## Load dataset = images & labels

In [None]:
# Folder of the extracted Zenodo dataset and its semicolon-separated index table
zenodo_dataset_folder = "./zenodo_dataset/public_unet_skyrmion_dataset/"
zenodo_dataset = pd.read_csv(zenodo_dataset_folder+"table.csv",sep=";")

In [None]:
# Images and labels are paired purely by their sorted file names.
# NOTE(review): the split cell below indexes limg/llabel with row labels of zenodo_dataset -
# confirm that the sorted file order matches the table row order
image_pattern = zenodo_dataset_folder+"images/*.png"
label_pattern = zenodo_dataset_folder+"labels/*.png"
lfnimg = sorted(glob.glob(image_pattern))
lfnlabel = sorted(glob.glob(label_pattern))
limg = np.array([np.array(Image.open(fn)) for fn in lfnimg])
llabel = np.array([trafo_rgb_to_channel(np.array(Image.open(fn))) for fn in lfnlabel])
# Shape of a single image (all axes after the sample axis)
img_size = limg.shape[1:]

## Splitting of dataset

In [None]:
# partition.txt: one batch per line, "<kind>;<source_id>;<source_id>;..."
with open(zenodo_dataset_folder+"partition.txt","r") as f:
    S = f.read()
traintest_batch = []
trainonly_batch = []
# splitlines() keeps the last record even if the file has no trailing
# newline (split("\n")[:-1] silently dropped it in that case)
for line in S.splitlines():
    if not line.strip():
        continue
    kind, *source_fields = line.split(";")
    if kind=="train_test_val":
        traintest_batch.append([int(field) for field in source_fields])
    elif kind=="only_training":
        trainonly_batch.append([int(field) for field in source_fields])
    else:
        print("error")


def _rows_of(sources):
    """Sorted row labels of zenodo_dataset for all rows whose source_id is listed in ``sources``."""
    return np.array(sorted([row for source in sources for row in zenodo_dataset[zenodo_dataset["source_id"]==source].index]))


#Which of the train_test_val batches go to training, validation and testing
train_batch,val_batch,test_batch = [0,1,4,5],[2],[3]
#Sources per split: training also receives every source of the only_training batches
train_sources = [source for batch in trainonly_batch for source in batch]+[source for ix in train_batch for source in traintest_batch[ix]]
val_sources = [source for ix in val_batch for source in traintest_batch[ix]]
test_sources = [source for ix in test_batch for source in traintest_batch[ix]]
#Row indices (image and label) per split
train_ix = _rows_of(train_sources)
val_ix = _rows_of(val_sources)
test_ix = _rows_of(test_sources)

## Initialisation of Augmentation & Data Generator

In [None]:
# Batch size: number of images passed through the U-Net together during training.
# Adjust it to the available RAM/VRAM.
batch_size = 5
    
#stronger augmentation, compared to the mini U-Net
def get_augmentation():
    """Return the albumentations pipeline used for the large U-Net.

    Same transforms as for the mini U-Net, but with wider scale
    (-0.6..0.6), brightness (0.3) and contrast (0.5) limits.
    """
    transforms = [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.GaussNoise(p=1,var_limit=20**2),
        albumentations.ShiftScaleRotate(shift_limit=0.5, scale_limit=(-0.6,0.6), rotate_limit=90, p=1),
        albumentations.RandomBrightnessContrast(brightness_limit=0.3,contrast_limit=0.5,p=1),
    ]
    return albumentations.Compose(transforms, p=1)

# Rows of the identity matrix: basis[label] turns class indices into one-hot vectors
basis = np.eye(3,dtype=np.uint8)
# Data generators for the training and the validation split (both with augmentation)
dg_train = DataGenerator(limg[train_ix],basis[llabel[train_ix]],batch_size,smoothing=False,aug=get_augmentation())
dg_val = DataGenerator(limg[val_ix],basis[llabel[val_ix]],batch_size,smoothing=False,aug=get_augmentation())

## U-Net architecture & generation

![](notebook_figures/u_net_architecture_1.png)

In [None]:
# Name of this training run and the folder that receives its outputs
modeln = "model_2"
folder = "./training/"

if not os.path.exists(folder):
    os.makedirs(folder)

# File the checkpoint callback writes the model to
modelfn = f"{folder}{modeln}.keras"

# Large U-Net: depth 4, 16 base filters, 3x3 kernels, input size taken from the loaded images
unet_param = {
    "name": "unet",
    "input_shape": img_size,
    "n_class": 3,
    "filter_multiplier": 16,
    "n_depth": 4,
    "kernel_initialization": "he_normal",
    "dropout": 0.1,
    "kernel_size": (3,3),
    "upsample_channel_multiplier": 8,
}
model = get_unet(unet_param)

In [None]:
# Print the layer-by-layer overview of the large U-Net
model.summary()

## Training

In [None]:
# Per-class loss weights: 6.2 for class 0, 10 for class 1, 1.2 for class 2; MCC(0) tracks class 0
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss=get_cross_entropy_loss([6.2,10,1.2]),metrics=["accuracy",MCC(0)])
# Halve the learning rate after 3 epochs without val_loss improvement and keep only the best model on disk.
# NOTE(review): the generators already batch the data; confirm that passing batch_size to fit() is accepted/needed
history = model.fit(x=dg_train,validation_data=dg_val,epochs=15,verbose=1,batch_size=batch_size,callbacks=[
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',mode='min',factor=0.5, patience=3, min_lr=1e-6, verbose=1),
tf.keras.callbacks.ModelCheckpoint(modelfn, verbose=1,monitor='val_loss',mode='min',save_best_only=True,save_weights_only=False)])

# Store the per-epoch metrics next to the model
hist_df = pd.DataFrame(history.history)  
hist_df.to_csv(folder+modeln+"_history.csv")

## Testing

In [None]:
# MCC of the skyrmion class (index 0): first on the test split, then on the whole dataset
t_img = limg[test_ix]
t_true = llabel[test_ix]
t_pred = predict(t_img,modelfn,5)
#alternative: get_mcc on gpu, if gpu VRAM is large enough
test_mcc = get_mcc_cpu(t_true,t_pred,0)
print("Pixelwise Matthews correlation coefficient on test set (true=skyrmion,false=defect,background)",test_mcc)
lpred = predict(limg,modelfn,5)
complete_mcc = get_mcc_cpu(llabel,lpred,0)
print("Pixelwise Matthews correlation coefficient on complete set (true=skyrmion,false=defect,background)",complete_mcc)

In [None]:
#Example plot of one prediction; NOTE(review): confirm that index 2890 belongs to the test set
ix0 = 2890
fig,ax = plt.subplots(ncols=3,dpi=300)
ax[0].imshow(limg[ix0],cmap="gray")
for axis, class_map in zip(ax[1:], (llabel[ix0], lpred[ix0])):
    axis.imshow(trafo_channel_to_rgb(class_map))
for axis, caption in zip(ax, ("Kerr image", "Ground truth", "Predicted label")):
    axis.set_title(caption)
fig.tight_layout()