In [1]:
basis_dir = "./"
try:
    import google.colab
    in_colab = True
    !git clone https://github.com/kfjml/Skyrmion-U-Net-Hands-On-Session
    !pip install pip==23.3.1 tensorflow[and-cuda]==2.16.1 albumentations==1.4.3 matplotlib==3.8.4 pandas==2.2.1 chardet==5.2.0 ipympl==0.9.3 ipywidgets==7.7.1 opencv-python-headless==4.9.0.80 wget==3.2
    basis_dir = "/content/Skyrmion-U-Net-Hands-On-Session/"
    from google.colab import output
    output.enable_custom_widget_manager()
except:
    in_colab = False
    
import tensorflow as tf
from PIL import Image
import numpy as np
import cv2
import scipy.spatial
import glob
import io
import pandas as pd
import ipywidgets
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
plt.style.use('seaborn-v0_8-dark')

def is_gpu_available():
    """Return the number of GPU devices TensorFlow can see (truthy when at least one exists)."""
    return len(tf.config.list_physical_devices('GPU'))

class MishLayer(tf.keras.layers.Layer):
    """Stateless Keras layer applying ``tf.keras.activations.mish`` element-wise.

    Passed as a custom object when loading saved U-Net models.
    """

    def call(self, x):
        activation = tf.keras.activations.mish
        return activation(x)
        
class SkyUNetModel:
    """Builder for the skyrmion U-Net using the Keras functional API.

    ``param`` must provide: name, input_shape (tuple, e.g. (512, 512)),
    n_class, filter_multiplier, n_depth, kernel_initialization, dropout,
    kernel_size and upsample_channel_multiplier.

    Layers are created in a fixed order. Keep that order, because weights of a
    loaded model are copied into this architecture with ``set_weights``.
    """

    class MishLayer(tf.keras.layers.Layer):
        """Mish activation layer used inside the network."""

        def call(self, x):
            activation = tf.keras.activations.mish
            return activation(x)

    def __init__(self,param):
        self.param = param

    def __repr__(self):
        return "<SkyUNet: {}>".format(self.param["name"])

    def get_conv_block(self,x, n_channels):
        """Conv2D ("same" padding) -> BatchNormalization -> Mish."""
        conv = tf.keras.layers.Conv2D(
            n_channels,
            kernel_size=self.param["kernel_size"],
            kernel_initializer=self.param["kernel_initialization"],
            padding="same",
        )
        hidden = conv(x)
        hidden = tf.keras.layers.BatchNormalization()(hidden)
        return self.MishLayer()(hidden)

    def get_double_conv_block(self, x, n_channels):
        """Two stacked convolution blocks (used in the encoder and bottleneck)."""
        for _ in range(2):
            x = self.get_conv_block(x, n_channels)
        return x

    def get_downsample_block(self, x, n_channels):
        """Encoder stage.

        Returns ``(features, pooled)``: the pre-pooling features for the skip
        connection, and the 2x2 max-pooled, dropout-regularised tensor.
        """
        features = self.get_double_conv_block(x, n_channels)
        pooled = tf.keras.layers.MaxPool2D(pool_size=(2,2))(features)
        pooled = tf.keras.layers.Dropout(self.param["dropout"])(pooled)
        return features, pooled

    def get_upsample_block(self, x, conv_features, n_channels):
        """Decoder stage.

        Applies a stride-2 transposed convolution, concatenates the skip
        features, then applies dropout and a double convolution block.
        """
        upsampled = tf.keras.layers.Conv2DTranspose(
            n_channels*self.param["upsample_channel_multiplier"],
            self.param["kernel_size"],
            strides=(2,2),
            padding='same',
        )(x)
        merged = tf.keras.layers.concatenate([upsampled, conv_features])
        merged = tf.keras.layers.Dropout(self.param["dropout"])(merged)
        return self.get_double_conv_block(merged, n_channels)

    def get_model(self):
        """Assemble and return the full ``tf.keras.Model`` (single-channel input, softmax output)."""
        depth = self.param["n_depth"]
        base = self.param["filter_multiplier"]
        inputs = tf.keras.layers.Input(shape=self.param["input_shape"]+(1,))

        # Encoder: channel count doubles at every level; keep the skip features.
        skips = []
        hidden = inputs
        for level in range(depth):
            skip, hidden = self.get_downsample_block(hidden, base*2**level)
            skips.append(skip)

        hidden = self.get_double_conv_block(hidden, base*2**depth)

        # Decoder: walk the levels back up, pairing each with its skip connection.
        for level in reversed(range(depth)):
            hidden = self.get_upsample_block(hidden, skips[level], base*2**level)

        # float32 output layer regardless of any global mixed-precision policy.
        outputs = tf.keras.layers.Conv2D(self.param["n_class"], (1,1), padding="same", activation="softmax", dtype='float32')(hidden)
        return tf.keras.Model(inputs, outputs, name=self.param["name"])

class SkyUNet:
    """Tiled inference wrapper around a trained skyrmion U-Net.

    Loads a Keras model from file, predicts per-pixel class indices on
    512x512 tiles, and stitches the tiles back into the full image. Returned
    labels use the convention (skyrmion: 0, defect: 1, background: 2); for
    ``model_ver > 1`` the model's own indices 1 and 2 are swapped to match.
    """

    def __init__(self):
        self.model = None    # loaded tf.keras model (None until set_model is called)
        self.model_ver = 2   # version tag of the loaded model; > 1 triggers the class-index swap
        self.fn_model = ""   # file name of the loaded model, used to avoid reloading

    def set_model(self,fn_model,model_ver=2):
        """Load the model stored at ``fn_model`` unless it is already loaded.

        Without a GPU, the weights are copied into a freshly built network
        (pure float32 policy) with the hard-coded architecture below.
        """
        if self.fn_model == fn_model:
            return
        loaded = tf.keras.models.load_model(fn_model,compile=False,custom_objects={'MishLayer': MishLayer})
        if not is_gpu_available():
            # Identical architecture, built under the default float32 policy.
            architecture = {
                "name": "unet",
                "input_shape": (512,512),
                "n_class": 3,
                "filter_multiplier": 16,
                "n_depth": 4,
                "kernel_initialization": "he_normal",
                "dropout": 0.1,
                "kernel_size": (3,3),
                "upsample_channel_multiplier": 8,
            }
            cpu_model = SkyUNetModel(architecture).get_model()
            cpu_model.set_weights(loaded.weights)
            loaded = cpu_model
        self.model, self.model_ver, self.fn_model = loaded, model_ver, fn_model

    def predict(self,x,batch_size = 5,normalize_255=False):
        """Predict the arg-max class index for every pixel of the stacked tiles ``x``.

        Batches of ``batch_size`` tiles are sent through the model (forced to 1
        without a GPU). If ``normalize_255`` is set, the input is divided by 255 first.
        Returns a uint8 array of the same shape as ``x``.
        """
        if not is_gpu_available():
            batch_size = 1
        total = len(x)
        n_batches = int(np.ceil(total/batch_size))
        batches = [np.arange(k*batch_size, min((k+1)*batch_size, total)) for k in range(n_batches)]
        ylabel = np.zeros(x.shape,dtype=np.uint8)
        progbar = tf.keras.utils.Progbar(n_batches)
        for step, idx in enumerate(batches):
            progbar.update(step)
            batch = x[idx]
            if normalize_255:
                batch = batch/255
            ylabel[idx] = self.model.predict(batch,verbose=False).argmax(-1)
        progbar.update(n_batches,finalize=True)
        if self.model_ver > 1:
            # Models of 2023 use (skyrmion:0, background:1, defects:2); swap 1 <-> 2 via a
            # temporary value to get (skyrmion:0, defects:1, background:2).
            for src, dst in ((1, 5), (2, 1), (5, 2)):
                ylabel[ylabel == src] = dst
        return ylabel

    def __call__(self,img):
        """Predict the label image of a 2-D image of arbitrary size.

        Each 512x512 tile is padded with ones before prediction, and the
        per-tile labels are cropped back when stitching.
        """
        tile = 512
        height, width = img.shape
        spans = [
            ((r*tile, min((r+1)*tile, height)), (c*tile, min((c+1)*tile, width)))
            for r in range(int(np.ceil(height/tile)))
            for c in range(int(np.ceil(width/tile)))
        ]
        tiles = []
        for (y0, y1), (x0, x1) in spans:
            patch = img[y0:y1, x0:x1]
            padded = np.ones((tile, tile))
            padded[:min(tile, patch.shape[0]), :min(tile, patch.shape[1])] = patch
            tiles.append(padded)
        predicted = self.predict(np.array(tiles))

        pred_label = np.zeros((height, width), dtype=predicted.dtype)
        for k, ((y0, y1), (x0, x1)) in enumerate(spans):
            pred_label[y0:y1, x0:x1] = predicted[k, :y1-y0, :x1-x0]
        return pred_label

    @staticmethod
    def trafo_channel_to_rgb(I):
        """Map class indices 0/1/2 to pure red/green/blue uint8 RGB colours."""
        palette = np.eye(3, dtype=np.uint8)*np.uint8(255)
        return palette[I]

    @staticmethod
    def trafo_rgb_to_channel(I):
        """Map an RGB label image back to class indices.

        Predominantly red -> 0 (skyrmion), predominantly green -> 1 (defect),
        and anything else -> 2 (background).
        """
        R, G, B = (I[:, :, c] for c in range(3))
        skyrmion_mask = (R>=128)&(G<128)&(B<128)
        defect_mask = (R<128)&(G>=128)&(B<128)
        Q = np.full(I.shape[:2], 2, dtype=np.uint8)
        Q[skyrmion_mask] = 0
        Q[defect_mask] = 1
        return Q

class ImageEditor:
    """Crop, min-max normalise, optionally invert and intensity-clip an image.

    Parameters are stored on the instance and applied each time the editor
    is called on an image.
    """

    def __init__(self,invert=False,intensity_clip=(0,1),xc=None,yc=None):
        self.invert = invert                  # if True, use 1 - normalised intensity
        self.intensity_clip = intensity_clip  # (v0, v1) window in normalised intensity
        self.xc = xc                          # (start, stop) column crop in pixels, or None
        self.yc = yc                          # (start, stop) row crop in pixels, or None

    def set_invert(self,invert):
        self.invert = invert

    def set_intensity_clip(self,intensity_clip):
        self.intensity_clip = intensity_clip

    def set_xc(self,xc):
        self.xc = xc

    def set_yc(self,yc):
        self.yc = yc

    def get_invert(self):
        return self.invert

    def get_intensity_clip(self):
        return self.intensity_clip

    def get_xc(self):
        return self.xc

    def get_yc(self):
        return self.yc

    def __call__(self,x):
        """Apply the current settings to ``x`` (2-D, or N-D with channels averaged on the last axis).

        Returns a dict with:
          - "clipimg": normalised image clipped to [v0, v1]
          - "cliprescaleimg": the window [v0, v1] linearly rescaled to [0, 1] and clipped
          - "histinfo": the flattened normalised (unclipped) image
        A crop without dynamic range (constant, or empty) is normalised to
        zeros instead of NaN. A zero-width window (v0 == v1) rescales to a
        step: 1 above v0, 0 otherwise (the limit of the linear map).
        """
        img = x.copy()
        if len(img.shape)>2:
            img = np.mean(img,axis=-1)

        if self.xc is not None:
            x0,x1 = self.xc
            x0,x1 = max(x0,0),min(x1,img.shape[1])
            img = img[:,int(np.floor(x0)):int(np.ceil(x1))]

        if self.yc is not None:
            y0,y1 = self.yc
            y0,y1 = max(y0,0),min(y1,img.shape[0])
            img = img[int(np.floor(y0)):int(np.ceil(y1)),:]

        # Min-max normalisation. np.min raises on an empty crop, and a constant
        # image would give 0/0 = NaN, so both cases map to zeros.
        if img.size > 0:
            lo,hi = np.min(img),np.max(img)
        else:
            lo = hi = 0
        span = hi-lo
        if span > 0:
            img = (img-lo)/span
        else:
            img = np.zeros(img.shape)

        if self.invert:
            img = 1-img
        v0,v1 = self.intensity_clip
        output = {}
        output["clipimg"] = np.clip(img,v0,v1)
        if v1 == v0:
            # Avoid 1/0: a zero-width window becomes a threshold at v0.
            output["cliprescaleimg"] = (img > v0).astype(float)
        else:
            output["cliprescaleimg"] = np.clip(1/(v1-v0)*(img-v0),0,1)
        output["histinfo"] = img.flatten()
        return output


class MaskAnalysis:
    """Connected-component analysis of the skyrmion class (label 0) of a predicted mask.

    Component areas are converted to an equivalent radius r = sqrt(A/pi) and
    optionally filtered by a radius range. The results are drawn into a
    three-panel matplotlib figure: coloured component mask, contours over
    the image, and a radius histogram with a Gaussian overlay.
    """

    def __init__(self,nbins=40):
        self.pred_label = None
        self.img = None
        self.min_max_radius_range = None  # (min, max) radius filter, or None for no filtering
        self.min_max_radius = None        # (min, max) radius over all components of the last analysis
        self.nbins = nbins                # number of histogram bin edges (passed to np.linspace)
        # Placeholder data so the figure can be laid out before any mask is analysed.
        self.radiusl = np.random.rand(30)
        self.binsl = np.linspace(0,1,self.nbins)
        with plt.ioff():
            self.fig = plt.figure(dpi=100,figsize=(10,6))
        self.fig.canvas.header_visible = False
        self.fig.canvas.toolbar_visible = False
        gs = self.fig.add_gridspec(2,2,height_ratios=[1,0.3])
        self.ax = [self.fig.add_subplot(gs[0,0]),self.fig.add_subplot(gs[0,1]),self.fig.add_subplot(gs[1,:])]
        self.im1 = self.ax[0].imshow(np.random.rand(20,20,3))
        self.im2 = self.ax[1].imshow(np.random.rand(20,20),cmap="gray",vmin=0,vmax=1,zorder=-900)
        self.lcontobj =  []  # artists removed and redrawn on every analysis_2 call
        self.ax[1].set_title("Skyrmion contours")
        self.ax[0].set_title("Single skyrmion mask")
        self.ax[0].grid(False)
        self.ax[1].grid(False)
        _,_,histobj = self.ax[2].hist(self.radiusl,bins=self.binsl)
        self.lcontobj += histobj
        self.meanrhist =  self.ax[2].axvline(0,color="red",lw=3,label="Mean value")
        self.lcontobj += [self.meanrhist]
        self.ax[2].legend()
        self.ax[2].set_xlabel(r"Radius $r=\sqrt{A/\pi}$ [pixel]")
        self.ax[2].set_ylabel(r"Probability")
        self.ax[2].set_title("Skyrmion radius statistic")
        self.fig.tight_layout()

    def get_min_max_radius(self):
        return self.min_max_radius

    def set_min_max_radius_select_range(self,range):
        """Set the (min, max) radius filter and re-run the filtering if a mask is loaded."""
        self.min_max_radius_range = range
        if self.pred_label is not None: self.analysis_2()

    def get_min_max_radius_select_range(self):
        return self.min_max_radius_range

    def __call__(self,pred_label,img):
        """Analyse ``pred_label`` and show ``img`` underneath the contours."""
        self.pred_label = pred_label
        self.img = img
        self.analysis_1()
        self.analysis_2()

    def analysis_1(self):
        """Label the 8-/4-connected components of the skyrmion class with OpenCV."""
        _,self.init_labels,self.init_stats,self.init_posl = cv2.connectedComponentsWithStats((self.pred_label==0).astype(np.uint8),cv2.CV_32S)

    def analysis_2(self):
        """Filter components by radius, relabel them and redraw the figure.

        Kept components are relabelled 1..N in order of decreasing area. Sets
        ``labels``, ``cont``, ``areal``, ``posl`` (centroids) and ``radiusl``.
        Works when no skyrmion component exists; ``min_max_radius`` is then None.
        """
        labels,stats,posl = self.init_labels,self.init_stats,self.init_posl
        # Row 0 of connectedComponentsWithStats is the background component.
        ixl,areal,posl = np.arange(1,len(stats)),stats[1:,cv2.CC_STAT_AREA],posl[1:]
        radiusl = np.sqrt(areal/np.pi)
        # np.min/np.max raise on an empty array, i.e. when the mask has no skyrmion.
        self.min_max_radius = (np.min(radiusl),np.max(radiusl)) if len(radiusl) > 0 else None

        if self.min_max_radius_range is not None:
            rmin,rmax = self.min_max_radius_range[0],self.min_max_radius_range[1]
            keep = np.logical_and(radiusl>=rmin,radiusl<=rmax)
            ixl,areal,posl = ixl[keep],areal[keep],posl[keep]

        # Largest component first.
        order = np.argsort(areal)[::-1]
        ixl,self.areal,self.posl = ixl[order],areal[order],posl[order]
        # Map original component ids to 1..N; dropped components and background become 0.
        ixlabel = np.zeros(np.max(labels)+1,dtype=np.int64)
        ixlabel[ixl] = np.arange(1,1+len(ixl))
        labels = self.labels = ixlabel[labels]

        # Row 0 (white) is the background, plus one random colour per kept component.
        colmap = np.vstack((np.ones(3),0.87*np.random.rand(np.max(labels),3)))
        self.cont,_ = cv2.findContours((labels!=0).astype(np.uint8),cv2.RETR_LIST,cv2.CHAIN_APPROX_NONE)
        self.radiusl = np.sqrt(areal/np.pi)

        # Remove the artists of the previous draw.
        while len(self.lcontobj)>0:
            self.lcontobj.pop().remove()

        for contour in self.cont:
            lx,ly = list(contour[:,0,0]),list(contour[:,0,1])
            self.lcontobj.append(self.ax[1].plot(lx+[lx[0]],ly+[ly[0]],color="r",lw=0.5)[0])
        self.im1.set_data(colmap[labels])
        self.im1.set_extent((0,self.pred_label.shape[1],self.pred_label.shape[0],0))
        self.im2.set_data(self.img)
        self.im2.set_extent((0,self.pred_label.shape[1],self.pred_label.shape[0],0))

        if len(self.radiusl) > 0:
            rlow,rhigh = np.min(self.radiusl),np.max(self.radiusl)
            # Upper edge nudged above the maximum so the bins are strictly increasing.
            self.binsl = np.linspace(rlow,rhigh*(1+1e-12),self.nbins)
            hb,xb,bobj = self.ax[2].hist(self.radiusl,bins=self.binsl,color="#1f77b4",density=True)
            self.lcontobj += bobj
            self.ax[2].set_ylim(0,np.max(hb)*1.1)
            # Identical limits are rejected with a warning; pad when all radii are equal.
            pad = 0 if rhigh > rlow else 0.5
            self.ax[2].set_xlim(rlow-pad,rhigh+pad)
            mrl = np.mean(self.radiusl)
            msl = np.std(self.radiusl)
            meanrhist =  self.ax[2].axvline(mrl,color="red",lw=3,label=fr"Mean value: (${mrl:.2f}\pm {msl:.2f}$) pixel")
            self.lcontobj += [meanrhist]
            if msl > 0:
                # Gaussian with the sample mean/std; undefined for zero spread.
                t = np.linspace(rlow,rhigh,100)
                pl = self.ax[2].plot(t,1/np.sqrt(2*np.pi*msl**2)*np.exp(-(t-mrl)**2/(2*msl**2)),color="red")
                self.lcontobj += pl

        self.ax[2].legend()
        self.fig.tight_layout()

    def get_datatable(self):
        """Return the kept components' centroids and areas as a DataFrame (largest first)."""
        return pd.DataFrame({"pos_x [pixel]":self.posl[:,0],"pos_y [pixel]":self.posl[:,1],"area [pixel*pixel]":self.areal})

    def get_posl(self):
        return self.posl




class PosAnalysis:
    """Neighbour analysis of skyrmion centre positions.

    A Delaunay triangulation and a Voronoi diagram are built from ``posl``
    (rows of x/y positions). Triangles with an interior angle below a
    configurable minimum are discarded. The figure shows the triangulation,
    the bounded Voronoi cells and a histogram of the kept neighbour distances.
    """

    def __init__(self):
        self.posl = None
        self.min_angle_range = 0  # minimum interior triangle angle in degrees
        # Results of the last analysis, initialised so analysis_2 never sees missing attributes.
        self.voronoi = None
        self.delaunay = None
        self.lcon = []
        self.lcon1 = []
        self.distance = []
        with plt.ioff():
            self.fig = plt.figure(dpi=100,figsize=(10,10))
        self.fig.canvas.header_visible = False
        self.fig.canvas.toolbar_visible = False
        gs = self.fig.add_gridspec(2,1,height_ratios=[2,0.5])
        self.ax = [self.fig.add_subplot(gs[0]),self.fig.add_subplot(gs[1])]
        self.lobj = []  # artists removed and redrawn on every analysis_2 call

        self.ax[0].set_title("Delaunay triangulation result")
        self.ax[0].grid(False)
        self.ax[1].set_xlabel(r"Skyrmion-Skyrmion distance [pixel]")
        self.ax[1].set_ylabel(r"Probability")
        self.ax[1].set_title("Skyrmion-Skyrmion statistic")
        self.im1 = self.ax[0].imshow(np.zeros((20,20)),cmap="gray",vmin=0,vmax=1)
        self.fig.tight_layout()

    def set_min_angle_select_range(self,v):
        """Set the minimum triangle angle (degrees) and redraw if positions are loaded."""
        self.min_angle_range = v
        if self.posl is not None: self.analysis_2()

    def get_min_angle_select_range(self):
        return self.min_angle_range

    def __call__(self,posl,img):
        """Analyse the positions ``posl`` and show them on top of ``img``."""
        self.posl = posl
        self.img = img
        self.analysis_1()
        self.analysis_2()

    def analysis_1(self):
        """Build the Voronoi diagram and Delaunay triangulation of ``posl``.

        Both are set to None when they cannot be built: fewer than three
        points, or a degenerate (e.g. collinear) configuration that Qhull rejects.
        """
        self.voronoi = None
        self.delaunay = None
        if len(self.posl) >= 3:
            try:
                self.voronoi = scipy.spatial.Voronoi(self.posl)
                self.delaunay = scipy.spatial.Delaunay(self.posl)
            except scipy.spatial.QhullError:
                self.voronoi = None
                self.delaunay = None

    def analysis_2(self):
        """Filter triangles by minimum angle and redraw the figure.

        A triangle is kept when all its interior angles are >= min_angle_range.
        Edges of kept triangles go to ``lcon`` (drawn red). Edges that belong
        only to rejected triangles go to ``lcon1`` (drawn green). ``distance``
        holds one length per unique kept edge.
        """
        while len(self.lobj)>0:
            self.lobj.pop().remove()

        self.lcon,self.lcon1 = [],[]
        self.distance = []
        if (self.delaunay is not None) and (self.voronoi is not None):
            # Thin triangles occur only at the image boundary; they do not represent
            # real skyrmion distances.
            min_angle = np.pi/180*self.min_angle_range
            points = self.delaunay.points
            lcon,lcon1 = [],[]
            for simplex in self.delaunay.simplices:
                n = len(simplex)
                ok = True
                for i in range(n):
                    v1 = points[simplex[(i+1)%n]]-points[simplex[i]]
                    v2 = points[simplex[(i-1)%n]]-points[simplex[i]]
                    angle = np.arccos(np.clip(np.dot(v1,v2)/np.sqrt(np.dot(v1,v1)*np.dot(v2,v2)),-1,1))
                    if not (min_angle<=angle):
                        ok = False
                        break
                edges = [(min(simplex[i],simplex[(i+1)%n]),max(simplex[i],simplex[(i+1)%n])) for i in range(n)]
                (lcon if ok else lcon1).extend(edges)

            self.lcon = list(set(lcon))
            self.lcon1 = list(set(lcon1)-set(lcon))
            # One distance per unique edge. Previously, edges shared by two kept
            # triangles were counted twice; now the histogram matches the red edges.
            self.distance = [np.linalg.norm(self.posl[a]-self.posl[b]) for a,b in self.lcon]

            for region in self.voronoi.regions:
                # Skip empty regions and regions with a vertex outside the diagram (index -1).
                if len(region)==0 or -1 in region:
                    continue
                closed = list(region)+[region[0]]
                self.lobj.append(self.ax[0].plot([self.voronoi.vertices[i,0] for i in closed],[self.voronoi.vertices[i,1] for i in closed],lw=1.5,color="b")[0])

            for a,b in self.lcon:
                self.lobj.append(self.ax[0].plot([self.posl[a,0],self.posl[b,0]],[self.posl[a,1],self.posl[b,1]],color="r",lw=1)[0])

            for a,b in self.lcon1:
                self.lobj.append(self.ax[0].plot([self.posl[a,0],self.posl[b,0]],[self.posl[a,1],self.posl[b,1]],color="g",lw=1)[0])

        self.im1.set_data(self.img)
        self.im1.set_extent((0,self.img.shape[1],self.img.shape[0],0))

        if len(self.distance)>0:
            yh,xh,histobj = self.ax[1].hist(self.distance,color='#1f77b4',density=True,bins=40)
            mean,sigma = np.mean(self.distance),np.std(self.distance)
            if sigma > 0:
                # Gaussian with the sample mean/std; undefined for zero spread.
                t = np.linspace(np.min(self.distance),np.max(self.distance),100)
                pl = self.ax[1].plot(t,1/np.sqrt(2*np.pi*sigma**2)*np.exp(-(t-mean)**2/(2*sigma**2)),color="red")[0]
                self.lobj.append(pl)
            histmean = self.ax[1].axvline(mean,color="red",lw=3,label=fr"Mean value: (${mean:.2f}\pm {sigma:.2f}$) pixel")
            self.lobj += [histobj,histmean]
            self.ax[1].set_ylim(0,np.max(yh)*1.2)
            self.ax[1].set_xlim(np.min(xh),np.max(xh))
        self.ax[1].legend()
        self.fig.tight_layout()




%matplotlib widget
class ImageEditorGUI(ipywidgets.HBox):
    """Interactive crop / inversion / intensity-clip editor for a grey-scale image.

    The left column holds the crop and intensity sliders, the inversion
    checkbox, the zoom controls and an intensity histogram. The right side
    shows the edited image with a colour bar. ``update_event`` (if given) is
    called after every refresh.
    """

    def __init__(self,update_event=None,reset_event=None):
        super().__init__()
        with plt.ioff():
            self.fig1,self.ax1 = plt.subplots(dpi=100)
            self.fig2,self.ax2 = plt.subplots(dpi=100,figsize=(2,2))
            self.fig1.canvas.header_visible = False
            self.fig2.canvas.header_visible = False
            self.fig1.canvas.toolbar_visible = False
            self.fig2.canvas.toolbar_visible = False

        self.update_event = update_event
        self.reset_event = reset_event  # stored only; not used inside this class
        self.ax1.set_xlabel("x")
        self.ax1.set_ylabel("y")
        self.ax1.grid(False)
        # Grey colour map whose lowest/highest entries are red/blue, so clipped pixels stand out.
        cmg = matplotlib.colormaps.get_cmap("gray")
        clis = cmg(np.arange(cmg.N))
        clis[0] = [1,0,0,1]
        clis[-1] = [0,0,1,1]
        self.ncmap = matplotlib.colors.ListedColormap(clis)
        self.imgedit = ImageEditor()
        self.img = np.random.rand(20,20)  # placeholder until an image is set
        self.im1 = self.ax1.imshow(self.img,cmap=self.ncmap,vmin=0,vmax=1)
        self.cax = make_axes_locatable(self.ax1).append_axes("right",size = "2%",pad=0.03)
        self.colorbar = plt.colorbar(self.im1,cax=self.cax)
        self.bins = np.linspace(0,1,40)
        _,_,self.hist = self.ax2.hist(self.img.flatten(),bins=self.bins,density=True)

        self.ax2.set_xlim(0,1)
        self.ax2.set_xlabel("Intensity")
        self.lowerl = self.ax2.axvline(0,color="red")
        self.higherl = self.ax2.axvline(1,color="blue")
        self.fig2.tight_layout()

        self.int_slider = ipywidgets.FloatRangeSlider(description="Intensity clip",step=0.01)
        self.yc_slider = ipywidgets.IntRangeSlider(description="y-crop")
        self.xc_slider = ipywidgets.IntRangeSlider(description="x-crop")
        self.inv_check = ipywidgets.Checkbox(description="Inversion")
        self.zoom_slider1 = ipywidgets.FloatSlider(min=0.2,max=5,value=1.5,description="Plot zoom")
        self.zoom_slider2 = ipywidgets.FloatSlider(min=0.2,max=5,value=1.5,description="Hist. zoom")
        self.rzoom_button = ipywidgets.Button(description="Reset zoom")
        self.rzoom_button.on_click(self.event_resetzoom)

        self.children = [ipywidgets.VBox([self.xc_slider,self.yc_slider,self.inv_check,self.int_slider,self.zoom_slider1,self.zoom_slider2,self.rzoom_button,self.fig2.canvas],layout=ipywidgets.Layout(width="30%",align_items="center")),self.fig1.canvas]

        self.widget_observe()

    def event_resetzoom(self,x):
        self.zoom_slider1.value = 1.5
        self.zoom_slider2.value = 1.5

    def event_intslider(self,x):
        self.imgedit.set_intensity_clip(x.new)
        self.update()

    def event_xcslider(self,x):
        self.imgedit.set_xc(x.new)
        self.update()

    def event_ycslider(self,x):
        self.imgedit.set_yc(x.new)
        self.update()

    def event_invcheck(self,x):
        self.imgedit.set_invert(x.new)
        self.update()

    def event_zoomslider1(self,x):
        self.fig1.set_dpi(100*x.new)
        self.update()

    def event_zoomslider2(self,x):
        self.fig2.set_dpi(100*x.new)
        self.update()

    def widget_observe(self):
        self.int_slider.observe(self.event_intslider,"value")
        self.yc_slider.observe(self.event_ycslider,"value")
        self.xc_slider.observe(self.event_xcslider,"value")
        self.inv_check.observe(self.event_invcheck,"value")
        self.zoom_slider1.observe(self.event_zoomslider1,"value")
        self.zoom_slider2.observe(self.event_zoomslider2,"value")

    def widget_unobserve(self):
        self.int_slider.unobserve(self.event_intslider,"value")
        self.yc_slider.unobserve(self.event_ycslider,"value")
        self.xc_slider.unobserve(self.event_xcslider,"value")
        self.inv_check.unobserve(self.event_invcheck,"value")
        self.zoom_slider1.unobserve(self.event_zoomslider1,"value")
        self.zoom_slider2.unobserve(self.event_zoomslider2,"value")

    def __call__(self,img,canvasdraw=True):
        """Load a new image and reset all editor settings."""
        self.img = img
        self.init_gui(canvasdraw)

    def init_gui(self,canvasdraw=True):
        """Reset the editor to the full image, full intensity range and no inversion."""
        self.imgedit.set_xc((0,self.img.shape[1]))
        self.imgedit.set_yc((0,self.img.shape[0]))
        self.imgedit.set_intensity_clip((0,1))
        self.imgedit.set_invert(False)

        # Sync the widgets without triggering their callbacks.
        self.widget_unobserve()
        self.int_slider.min,self.int_slider.max = self.imgedit.get_intensity_clip()
        self.xc_slider.min,self.xc_slider.max = self.imgedit.get_xc()
        self.yc_slider.min,self.yc_slider.max = self.imgedit.get_yc()
        self.int_slider.value = self.imgedit.get_intensity_clip()
        self.xc_slider.value = self.imgedit.get_xc()
        self.yc_slider.value = self.imgedit.get_yc()
        self.inv_check.value = self.imgedit.get_invert()
        self.zoom_slider1.value = 1.5
        self.zoom_slider2.value = 1.5
        self.widget_observe()

        self.fig1.set_dpi(150)
        self.fig2.set_dpi(150)
        self.update(canvasdraw)

    def update(self,canvasdraw=True):
        """Re-apply the editor settings and refresh the image, colour bar and histogram.

        ``canvasdraw`` is accepted for interface compatibility and is unused.
        """
        output = self.imgedit(self.img)
        editimg,histinfo = output["clipimg"],output["histinfo"]

        self.im1.set_data(editimg)
        xc,yc = self.imgedit.get_xc(),self.imgedit.get_yc()
        self.im1.set_extent((xc[0],xc[1],yc[1],yc[0]))

        v0,v1 = self.imgedit.get_intensity_clip()
        self.lowerl.set_xdata([v0,v0])
        self.higherl.set_xdata([v1,v1])
        self.ax1.set_aspect(1)

        self.im1.set_clim(v0,v1)
        self.colorbar.update_normal(self.im1)

        # A density histogram of no finite values is 0/0 = NaN (empty crop or NaN
        # image). Draw such bins as empty instead of passing NaN to set_ylim,
        # which raises.
        nhist = np.nan_to_num(np.histogram(histinfo,self.bins,density=True)[0])
        self.nhist = nhist
        for bar,height in zip(self.hist,nhist):
            bar.set_height(height)
        peak = np.max(nhist)
        self.ax2.set_ylim(0,peak*1.1 if peak > 0 else 1)

        if self.update_event is not None:
            self.update_event()

    def get_editimg(self):
        """Return the edited image with the intensity window rescaled to [0, 1]."""
        return self.imgedit(self.img)["cliprescaleimg"]

    def get_config(self):
        return {"intensity_clip":self.imgedit.get_intensity_clip(),
                "xc":self.imgedit.get_xc(),
                "yc":self.imgedit.get_yc(),
                "invert":self.imgedit.get_invert() }

class UNetPredictionGUI(ipywidgets.VBox):
    """Widget that runs a selected U-Net model on an image and shows the predicted label.

    Each value of ``modeldic`` is indexed as (model file, dropdown label,
    model version). Key ``0`` is the initially selected model.
    ``pred_event`` (if given) is called after each prediction.
    """

    def __init__(self,modeldic,pred_event=None):
        super().__init__()
        self.modeldic = modeldic
        self.cmodel = modeldic[0]  # currently selected model entry
        self.skyunet = SkyUNet()

        self.pred_event = pred_event
        self.button_predict = ipywidgets.Button(description="Predict")
        self.model_dropdown = ipywidgets.Dropdown(options={b[1]:a for a,b in modeldic.items()})
        self.zoomslider = ipywidgets.FloatSlider(min=0.2,max=5,value=1.5,description="Plot zoom")
        self.resetzoom = ipywidgets.Button(description="Reset zoom")
        self.out = ipywidgets.Output()

        with plt.ioff():
            self.fig,self.ax = plt.subplots(ncols=2,dpi=100,figsize=(10,5))
            self.fig.canvas.header_visible = False
            self.fig.canvas.toolbar_visible = False

        self.img = np.random.rand(20,20)  # placeholder until an image is set
        self.pred = None    # RGB rendering of the last prediction
        self.predix = None  # class-index label of the last prediction
        self.im1 = self.ax[0].imshow(self.img,cmap="gray",vmin=0,vmax=1)
        self.im2 = self.ax[1].imshow(np.ones((20,20,3)))
        self.ax[0].set_title("Kerr image")
        self.ax[1].set_title("Predicted label")
        self.ax[0].grid(False)
        self.ax[1].grid(False)
        self.textobj = None
        self._show_hint()
        self.children = [ipywidgets.HBox([self.zoomslider,self.resetzoom]),ipywidgets.HBox([self.model_dropdown,self.button_predict]),self.out,self.fig.canvas]

        self.resetzoom.on_click(self.event_resetzoom)
        self.button_predict.on_click(self.event_predict)
        self.widget_observe()
        self.start_out()

    def _show_hint(self):
        """Draw the 'click Predict' hint text in the prediction panel."""
        self.textobj = self.ax[1].text(0,self.img.shape[0]/2,"Please click the \"Predict\" button to make a prediction with the U-Net :-)",fontsize=20,wrap=True)

    def get_config(self):
        fnmodel = self.cmodel[0]
        return {"unet_model":[ele for ele in self.modeldic.values() if ele[0]==fnmodel][0]}

    def start_out(self):
        with self.out:
            self.out.clear_output()
            print("",end="\r")

    def event_resetzoom(self,x):
        self.zoomslider.value = 1.5

    def event_predict(self,x):
        """Load the selected model (if needed), predict the label of the current image and redraw."""
        self.start_out()
        with self.out:
            modelt = self.cmodel
            self.skyunet.set_model(modelt[0],modelt[2])
            self.predix = self.skyunet(self.img)
            self.pred = self.skyunet.trafo_channel_to_rgb(self.predix)
            if self.pred_event is not None:
                self.pred_event()

        self.update()

    def __call__(self,img,canvasdraw=True):
        """Set a new image and discard any prediction made for the previous one."""
        self.img = img
        self.textobjremove()
        self.pred = None
        self.predix = None
        self.init_gui(canvasdraw)

    def set_model(self,x):
        self.cmodel = self.modeldic[x.new]

    def widget_observe(self):
        self.zoomslider.observe(self.event_zoomslider,"value")
        self.model_dropdown.observe(self.set_model,"value")

    def widget_unobserve(self):
        self.zoomslider.unobserve(self.event_zoomslider,"value")
        self.model_dropdown.unobserve(self.set_model,"value")

    def event_zoomslider(self,x):
        self.fig.set_dpi(100*x.new)

    def textobjremove(self):
        """Remove the hint text if it is currently shown."""
        if self.textobj is not None:
            self.textobj.remove()
            self.textobj = None

    def update(self,canvasdraw=True):
        """Redraw the image and prediction panels (``canvasdraw`` is unused)."""
        self.im1.set_data(self.img)
        # Row 0 at the top with pixel coordinates, like the other image panes.
        self.im1.set_extent((0,self.img.shape[1],self.img.shape[0],0))
        self.textobjremove()
        if self.pred is None:
            pred = np.ones((self.img.shape[0],self.img.shape[1],3))
        else:
            pred = self.pred

        self.im2.set_data(pred)
        self.im2.set_extent((0,pred.shape[1],pred.shape[0],0))
        self.fig.tight_layout()
        if self.pred is None:
            self._show_hint()

    def get_prediction(self):
        """Return the class-index label of the last prediction, or None if there is none."""
        return self.predix

    def init_gui(self,canvasdraw=True):
        self.update(canvasdraw)

    def reset_gui(self):
        """Select model 0, reset zoom and discard the prediction."""
        self.widget_unobserve()
        self.model_dropdown.value = 0
        # The dropdown observer is detached, so keep the selected model in sync by hand.
        self.cmodel = self.modeldic[0]
        self.zoomslider.value = 1.5
        self.fig.set_dpi(150)
        self.textobjremove()
        self.pred = None
        self.predix = None
        self.widget_observe()
        self.update()

class MaskAnalysisGUI(ipywidgets.VBox):
    """Widget wrapping ``MaskAnalysis`` with zoom control and min./max. radius sliders.

    ``event_maskanalysis`` (if not None) is called after the radius filter
    has been (re)applied.
    """

    def __init__(self,event_maskanalysis):
        super().__init__()
        self.label = None
        self.event_maskanalysis = event_maskanalysis
        self.zoomslider = ipywidgets.FloatSlider(min=0.2,max=5,value=1.5,description="Plot zoom")
        self.resetzoom = ipywidgets.Button(description="Reset zoom")

        self.analysis = MaskAnalysis()
        self.analysis.fig.set_visible(False)

        # Log-scaled radius sliders (min/max are exponents of 10).
        self.radius_slider_min = ipywidgets.FloatLogSlider(min=np.log10(0.5),max=np.log10(100),value=0.5,readout_format=".1f",description="min. radius",step=0.01)
        self.radius_slider_min.layout.width="40%"

        self.radius_slider_max = ipywidgets.FloatLogSlider(min=np.log10(0.5),max=np.log10(100),value=100,readout_format=".1f",description="max. radius",step=0.01)
        self.radius_slider_max.layout.width="40%"

        self.children = [ipywidgets.HBox([self.zoomslider,self.resetzoom]),self.radius_slider_min,self.radius_slider_max,self.analysis.fig.canvas]
        self.resetzoom.on_click(self.event_resetzoom)
        self.widget_observe()

    def event_resetzoom(self,x):
        self.zoomslider.value = 1.5

    def __call__(self,label,img,canvasdraw=True):
        """Analyse the label image ``label`` and show ``img`` underneath."""
        self.label = label
        self.img = img
        self.init_gui(canvasdraw)

    def event_radius_max_sel(self,x):
        self.event_radius_sel_2()

    def event_radius_min_sel(self,x):
        self.event_radius_sel_2()

    def event_radius_sel_2(self):
        """Apply the current slider values as the radius filter and notify the listener."""
        self.analysis.set_min_max_radius_select_range((self.radius_slider_min.value,self.radius_slider_max.value))
        if self.event_maskanalysis is not None:
            self.event_maskanalysis()

    def widget_observe(self):
        self.zoomslider.observe(self.event_zoomslider,"value")
        self.radius_slider_max.observe(self.event_radius_max_sel,"value")
        self.radius_slider_min.observe(self.event_radius_min_sel,"value")

    def widget_unobserve(self):
        """Detach the slider observers; handlers that are not attached are skipped."""
        for widget,handler in ((self.zoomslider,self.event_zoomslider),
                               (self.radius_slider_max,self.event_radius_max_sel),
                               (self.radius_slider_min,self.event_radius_min_sel)):
            try:
                widget.unobserve(handler,"value")
            except ValueError:
                # traitlets' unobserve raises ValueError for a handler that is not registered.
                pass

    def event_radius_sel(self,x):
        pass

    def event_zoomslider(self,x):
        self.analysis.fig.set_dpi(100*x.new)

    def init_gui(self,canvasdraw=True):
        """Run the analysis and initialise the radius sliders for the current label image."""
        self.analysis(self.label,self.img)
        self.analysis.fig.set_visible(True)
        # Largest meaningful radius: a disc covering the whole image, rounded up to 0.1 pixel.
        v1 = np.ceil(np.sqrt(self.label.shape[0]*self.label.shape[1]/np.pi)*10)/10

        self.widget_unobserve()
        self.radius_slider_min.min = np.log10(0.5)
        self.radius_slider_min.max = np.log10(v1)
        self.radius_slider_max.min = np.log10(0.5)
        self.radius_slider_max.max = np.log10(v1)
        self.radius_slider_max.value = v1
        self.radius_slider_min.value = 2
        self.widget_observe()
        # Filter with the values the sliders actually display (they may clamp them).
        # Previously the filter used (0.5, v1) while the min slider showed 2.
        self.analysis.set_min_max_radius_select_range((self.radius_slider_min.value,self.radius_slider_max.value))
        if self.event_maskanalysis is not None:
            self.event_maskanalysis()

    def reset_gui(self):
        """Hide the figure, reset zoom and forget the current label image."""
        self.widget_unobserve()
        self.zoomslider.value = 1.5
        self.analysis.fig.set_visible(False)
        self.analysis.fig.set_dpi(150)
        self.label = None
        self.img = None
        self.widget_observe()

    def get_posl(self):
        return self.analysis.get_posl()

    def get_config(self):
        if self.label is None:
            return None
        return {"radius_range":self.analysis.get_min_max_radius_select_range()}


class PosAnalysisGUI(ipywidgets.VBox):
    """Notebook panel for the position-analysis step.

    Layout: a plot-zoom slider with a reset button, a minimum-angle slider and
    the canvas of a ``PosAnalysis`` figure. Calling the panel with a position
    list and the edited image (re)runs the analysis and shows the figure.
    """

    def __init__(self):
        super().__init__()
        self.posl = None  # positions from the last __call__, None until then
        self.img = None  # image from the last __call__, None until then
        self.zoomslider = ipywidgets.FloatSlider(min=0.2,max=5,value=1.5,description="Plot zoom")
        self.resetzoom = ipywidgets.Button(description="Reset zoom")

        self.analysis = PosAnalysis()
        self.analysis.fig.set_visible(False)
        # Default minimum angle [deg], re-applied on every (re)initialisation.
        self.angle_standard_range = 20

        self.angle_slider = ipywidgets.FloatSlider(value=self.angle_standard_range,min=0,max=60,readout_format=".1f",description="min. angle [°]",step=0.01)
        self.angle_slider.layout.width="40%"

        self.children = [ipywidgets.HBox([self.zoomslider,self.resetzoom]),self.angle_slider,self.analysis.fig.canvas]
        self.resetzoom.on_click(self.event_resetzoom)
        self.widget_observe()

    def event_resetzoom(self,x):
        """Button handler: restore the default zoom (fires ``event_zoomslider``)."""
        self.zoomslider.value = 1.5

    def __call__(self,posl,img,canvasdraw=True):
        """Store the positions and image, then (re)initialise the panel.

        ``canvasdraw`` is forwarded to ``init_gui``, which does not use it.
        """
        self.posl = posl
        self.img = img
        self.init_gui(canvasdraw)

    def event_angle_sel(self,x):
        """Slider handler: forward the new minimum angle to the analysis."""
        self.analysis.set_min_angle_select_range(x.new)

    def widget_observe(self):
        """Attach the slider ``value`` handlers."""
        self.zoomslider.observe(self.event_zoomslider,"value")
        self.angle_slider.observe(self.event_angle_sel,"value")

    def widget_unobserve(self):
        """Detach the slider handlers, skipping any that are not attached."""
        # traitlets' unobserve raises ValueError when the handler is not
        # registered; only that case is ignored, real errors propagate.
        for widget, handler in ((self.zoomslider, self.event_zoomslider),
                                (self.angle_slider, self.event_angle_sel)):
            try:
                widget.unobserve(handler,"value")
            except (KeyError, ValueError):
                pass

    def event_zoomslider(self,x):
        """Slider handler: figure dpi = 100 * zoom."""
        self.analysis.fig.set_dpi(100*x.new)

    def init_gui(self,canvasdraw=True):
        """Run the position analysis and restore the default minimum angle.

        Handlers are detached while the slider is reset so the analysis is not
        recomputed twice. ``canvasdraw`` is accepted but unused.
        """
        self.widget_unobserve()
        self.analysis(self.posl,self.img)
        self.analysis.set_min_angle_select_range(self.angle_standard_range)
        self.analysis.fig.set_visible(True)
        self.angle_slider.min = 0
        self.angle_slider.max = 60
        self.angle_slider.value = self.angle_standard_range
        self.widget_observe()

    def reset_gui(self):
        """Hide the figure, restore zoom defaults and forget the current data."""
        self.widget_unobserve()
        self.zoomslider.value = 1.5
        self.analysis.fig.set_visible(False)
        self.analysis.fig.set_dpi(150)
        self.posl = None
        self.img = None
        # Previously written without the call parentheses, so the handlers
        # were never reattached after a reset.
        self.widget_observe()

    def get_config(self):
        """Return ``{"angle_range": ...}`` for batch reuse, or None if nothing is loaded."""
        if self.posl is None:
            return None
        return {"angle_range":self.analysis.get_min_angle_select_range()}

# Close every pyplot figure left over from earlier cell runs before the GUI below is built.
plt.close(fig="all")
class UNetGUI(ipywidgets.VBox):
    """Top-level notebook GUI for the U-Net workflow.

    Four tabs are chained through callbacks: image editor -> prediction ->
    mask analysis -> position analysis. A single-file upload loads a new image
    into the editor. A multi-file upload replays the current tab settings on
    every uploaded image and offers the results as ``result.zip``.
    """

    def __init__(self):
        super().__init__()
        # index -> (model path, display name, extra value). In batch mode,
        # elements 0 and 2 of the selected tuple are passed to SkyUNet.set_model.
        self.modeldic = {0:(basis_dir+'models/2023_model.keras',"Model 2023",2),1:(basis_dir+'models/2022_model.keras',"Model 2022",1),2:(basis_dir+'models/2022_model_inv.keras',"Model 2022 inverse",1)}
        self.posanalysisgui = PosAnalysisGUI()
        self.maskanalysisgui = MaskAnalysisGUI(self.mask_analysis_update)

        self.predgui = UNetPredictionGUI(self.modeldic,self.pred_update)
        self.editorgui = ImageEditorGUI(self.editor_update,self.reset_gui)
        # Start with the bundled example image in the editor.
        input_img = np.array(Image.open(basis_dir+"example_kerr_microscopy_image.png"))
        self.editorgui(input_img,False)

        self.res_button = ipywidgets.Button(description="Reset")
        self.res_button.on_click(self.reset_gui_button)
        self.fileup = ipywidgets.FileUpload()
        self.fileup.observe(self.event_fileup,"value")

        self.fileupbatch = ipywidgets.FileUpload(multiple=True,description="Upload & Batch Analysis & Download (after parameter setup via GUI)")
        self.fileupbatch.layout.width="60%"
        self.fileupbatch.observe(self.event_filebatch,"value")

        self.tab = ipywidgets.Tab([self.editorgui,self.predgui,self.maskanalysisgui,self.posanalysisgui])
        self.tab.set_title(0,"1) Image Editor")
        self.tab.set_title(1,"2) Prediction")
        self.tab.set_title(2,"3) Mask analysis")
        self.tab.set_title(3,"4) Position analysis")

        # Output area for batch messages and the download link.
        self.out = ipywidgets.Output()
        self.clear_output()

        self.children = [ipywidgets.HBox([self.fileup,self.res_button]),self.tab,ipywidgets.HBox([self.fileupbatch,self.out])]

    def show_output(self,x):
        """Replace the output area's content with a download link to file ``x``."""
        from IPython.display import FileLink,display
        with self.out:
            self.out.clear_output()
            display(FileLink(x))

    def clear_output(self):
        """Empty the output area."""
        with self.out:
            self.out.clear_output()

    def _reset_upload(self,widget,handler):
        """Empty an upload ``widget`` without triggering its ``value`` ``handler``."""
        widget.unobserve(handler,"value")
        try:
            widget.value = ()
        finally:
            widget.observe(handler,"value")

    def event_fileup(self,x):
        """Single-file upload handler: reset the GUI and load the image into the editor."""
        # Ignore empty values (e.g. a programmatic reset of the widget);
        # indexing x.new[0] would raise IndexError.
        if not x.new:
            return
        try:
            self.reset_gui()
            self.editorgui(np.array(Image.open(io.BytesIO(x.new[0].content.tobytes()))))
        finally:
            # Clearing the widget must not re-enter this handler: that would
            # reset the freshly loaded GUI a second time.
            self._reset_upload(self.fileup,self.event_fileup)

    def event_filebatch(self,x):
        """Multi-file upload handler: run the batch analysis on the new files."""
        if not x.new:
            return
        self.process_batch(x.new)

    def process_batch(self,fnl):
        """Analyse every uploaded file with the settings currently set in the GUI.

        Nothing is processed until the mask and position analysis tabs hold a
        configuration; in that case a hint is shown instead. On success a link
        to ``result.zip`` is shown. The batch upload widget is always emptied
        afterwards, even on early return or error.

        Parameters
        ----------
        fnl : sequence of uploaded files
            Items exposing ``.name`` and ``.content`` (a buffer with
            ``tobytes()``), as read from ``FileUpload.value``.
        """
        try:
            config = self.get_config()
            self.clear_output()

            if (config["mask_analysis_editor"] is None) or (config["pos_analysis_editor"] is None):
                with self.out:
                    print("Run the mask and position analysis on an image first; "
                          "their settings are reused for the batch.")
                return

            self.show_output(self._batch_to_zip(fnl,config))
        finally:
            self._reset_upload(self.fileupbatch,self.event_filebatch)

    def _batch_to_zip(self,fnl,config):
        """Process ``fnl`` with ``config`` into ./tmp/, zip it and return the zip path.

        Per file, the raw input, the edited image, the U-Net prediction, the
        mask-analysis data table and both analysis figures are written.
        """
        import os
        import shutil

        imgeditor = ImageEditor(invert=config["img_editor"]["invert"],
                                intensity_clip=config["img_editor"]["intensity_clip"],
                                xc=config["img_editor"]["xc"],
                                yc=config["img_editor"]["yc"])

        sky_unet = SkyUNet()
        unet_model = config["prediction_editor"]["unet_model"]
        sky_unet.set_model(unet_model[0],unet_model[2])

        analysis1 = MaskAnalysis()
        analysis1.set_min_max_radius_select_range(config["mask_analysis_editor"]["radius_range"])
        analysis1.fig.set_dpi(300)  # high resolution for the saved figures

        analysis2 = PosAnalysis()
        analysis2.set_min_angle_select_range(config["pos_analysis_editor"]["angle_range"])
        analysis2.fig.set_dpi(300)

        tmp_folder = "./tmp/"
        # Start from an empty staging folder so stale results are not zipped.
        os.makedirs(tmp_folder,exist_ok=True)
        for ele in glob.iglob(tmp_folder+"*"):
            if os.path.isfile(ele):
                os.remove(ele)
        if os.path.isfile("result.zip"):
            os.remove("result.zip")

        try:
            for i,ele in enumerate(fnl):
                # The index suffix keeps outputs of equally named uploads apart.
                prefix_file = tmp_folder+os.path.splitext(ele.name)[0]+"_ix_"+str(i)
                raw = np.array(Image.open(io.BytesIO(ele.content.tobytes())))
                img = imgeditor(raw)["cliprescaleimg"]
                pred = sky_unet(img)
                predimg = sky_unet.trafo_channel_to_rgb(pred)
                analysis1(pred,img)
                analysis2(analysis1.posl,img)

                plt.imsave(prefix_file+"_input.png",raw,cmap="gray")
                plt.imsave(prefix_file+"_edited.png",img,cmap="gray")
                plt.imsave(prefix_file+"_skyunet_prediction.png",predimg)
                analysis1.get_datatable().to_csv(prefix_file+"_datatable.csv")
                analysis1.fig.savefig(prefix_file+"_mask_analysis.png",bbox_inches="tight")
                analysis2.fig.savefig(prefix_file+"_pos_analysis.png",bbox_inches="tight")
        finally:
            # These figures belong to this batch only; close them so repeated
            # batches do not accumulate open matplotlib figures.
            plt.close(analysis1.fig)
            plt.close(analysis2.fig)

        archive_base = "result"
        shutil.make_archive(archive_base,"zip",tmp_folder)
        # Relative path so FileLink resolves it against the notebook directory.
        return archive_base+".zip"

    def editor_update(self):
        """Editor callback: send the edited image to the prediction tab."""
        self.predgui(self.editorgui.get_editimg())

    def pred_update(self):
        """Prediction callback: send prediction and edited image to the mask analysis tab."""
        self.maskanalysisgui(self.predgui.get_prediction(),self.editorgui.get_editimg())

    def reset_gui_button(self,x):
        """Reset button handler."""
        self.reset_gui()

    def mask_analysis_update(self):
        """Mask analysis callback: send positions and edited image to the position analysis tab."""
        self.posanalysisgui(self.maskanalysisgui.get_posl(),self.editorgui.get_editimg())

    def reset_gui(self):
        """Reset the prediction and analysis tabs and re-initialise the editor."""
        self.predgui.reset_gui()
        self.editorgui.init_gui()
        self.maskanalysisgui.reset_gui()
        self.posanalysisgui.reset_gui()

    def get_config(self):
        """Collect every tab's configuration (an entry is None if that tab has none)."""
        return {"img_editor":self.editorgui.get_config(),
                "prediction_editor":self.predgui.get_config(),
                "mask_analysis_editor":self.maskanalysisgui.get_config(),
                "pos_analysis_editor":self.posanalysisgui.get_config()}

# Build the complete GUI. The bare `gui` expression on the last line of the
# cell makes Jupyter display the widget as the cell output.
gui = UNetGUI()
gui

2024-05-21 16:17:54.022662: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


UNetGUI(children=(HBox(children=(FileUpload(value=(), description='Upload'), Button(description='Reset', style…