In [None]:
import numpy as np
import torch.nn as nn
from scipy import ndimage
import cv2
import matplotlib.pyplot as plt
from datetime import datetime
from scipy.ndimage import gaussian_filter
from skimage import measure
from shapely.geometry import Polygon,Point
from skimage import graph, data, io, segmentation, color
from skimage.draw import polygon
import copy
from skimage.io import imsave, imread
from skimage.measure import regionprops
from skimage.segmentation import find_boundaries
from networkx.linalg import adjacency_matrix
from skimage.segmentation import slic,mark_boundaries
import torch
from torch.nn import functional as F
from sklearn.decomposition import PCA

In [None]:
def mask(phi):  #Making a mask
    """Binarise a 2-D level-set array into a 0/1 mask.

    Parameters
    ----------
    phi : 2-D array exposing ``.shape`` and slice indexing
        Level-set values.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(N-1, M-1)`` holding 1 where ``phi >= 0``
        and 0 elsewhere. As in the original implementation, the last row
        and the last column of ``phi`` are not represented in the output.
    """
    N, M = phi.shape
    # A single vectorised comparison replaces the per-pixel Python loop.
    return np.asarray(phi[:N - 1, :M - 1] >= 0.0, dtype=np.float64)

In [None]:
def DeleteSmall(contours,num): # Delete small contours
    """Drop every contour with fewer than ``num`` points.

    The list is filtered in place (the same list object is returned), so
    callers holding a reference to ``contours`` see the removal too.
    """
    survivors = [c for c in contours if not (len(c) < num)]
    # Slice assignment keeps the caller's list object while replacing its content.
    contours[:] = survivors
    return contours

In [None]:
def changeX_Y(L): #Adjust the contour format
    """Return a copy of ``L`` in which the first two entries of every row are swapped.

    ``L`` itself is left untouched; any further columns are copied unchanged.
    """
    swapped = np.copy(L)
    # Walk the copy and the source in lock-step, writing the swapped pair.
    for dst, src in zip(swapped, L):
        dst[0] = src[1]
        dst[1] = src[0]
    return swapped

In [None]:
def Complete_Contour(contours,N,M): #Complete the contours located at the edges 
    """Append corner points to contours whose bounding box touches two image borders.

    Each contour is converted to an int32 point array, its first two columns
    are swapped with ``changeX_Y``, and its bounding rectangle ``(x, y, w, h)``
    is computed with ``cv2.boundingRect``. For every corner whose two adjacent
    borders are both touched by that rectangle, the corner point is appended.

    Parameters
    ----------
    contours : sequence of (n_i, 2) point arrays
    N : int
        Extent compared against the bottom edge ``y + h``.
    M : int
        Extent compared against the right edge ``x + w``.

    Returns
    -------
    numpy.ndarray
        1-D object array; element ``i`` is the completed int32 contour ``i``.

    Notes
    -----
    * ``np.int`` was removed from NumPy, and ``cv2.boundingRect`` only accepts
      32-bit integer or float point sets, so int32 is used explicitly.
    * The result container is built as an object array up front instead of
      ``np.copy(contours)``, which is not valid for contours of differing
      lengths and cannot hold a grown contour when the lengths are equal.
    """
    completed = np.empty(len(contours), dtype=object)
    for i, contour in enumerate(contours):
        cnt = changeX_Y(np.array(contour, dtype=np.int32))
        x, y, w, h = cv2.boundingRect(cnt)

        at_left = (x == 0)
        at_top = (y == 0)
        at_right = ((x + w) == M)
        at_bottom = ((y + h) == N)

        # Same corner order as the original chain of independent checks.
        corners = []
        if at_left and at_top:
            corners.append([0, 0])
        if at_left and at_bottom:
            corners.append([0, N - 1])
        if at_right and at_top:
            corners.append([M - 1, 0])
        if at_right and at_bottom:
            corners.append([M - 1, N - 1])

        if corners:
            # Keep int32 so the result stays usable with OpenCV drawing calls.
            cnt = np.append(cnt, np.array(corners, dtype=np.int32), axis=0)
        completed[i] = cnt
    return completed

In [None]:
def DeleteIn(frontground,contours,im): #Remove small contour units within large contours
    """Black out the interior of every contour in ``im``.

    A uint8 keep-mask of ones is created, each contour polygon is filled
    with 0 on it, and ``im`` is AND-ed with itself under that mask.

    The incoming ``frontground`` value is never read; the argument only
    exists for call-compatibility. ``im`` must be 3-dimensional.
    """
    height, width, _ = im.shape
    keep = np.ones((height, width), np.uint8)
    for contour in contours:
        cv2.fillPoly(keep, [contour], (0))
    return cv2.bitwise_and(im, im, mask=keep)

In [None]:
def TagArea(contours,im): #Make Tag
    """Rasterise contours into an int32 label array shaped like ``im``.

    Contour ``i`` (0-based) is filled with the label ``i + 1``; later
    contours overwrite earlier ones where they overlap. Because the output
    is created with ``np.zeros_like(im)`` it keeps the trailing axis of
    ``im``, and the 2-D fill mask writes the label across that whole axis.
    """
    rows, cols, _ = im.shape
    labels = np.zeros_like(im, np.int32)
    for label_value, contour in enumerate(contours, start=1):
        filled = np.zeros((rows, cols), np.int32)
        cv2.fillPoly(filled, [contour], (1))
        labels[filled == 1] = label_value
    return labels

In [None]:
# Switch matplotlib to interactive mode and create figures number 1 and 2,
# keeping module-level handles to both.
plt.ion()
fig1 = plt.figure(1)
fig2 = plt.figure(2)

In [None]:
def show_fig2(phi, img): #Draw contours
    """Plot the 0.5-level contours of ``phi`` in red on top of ``img``.

    Contours with fewer than 20 points are discarded. The current figure is
    cleared first, then a new 9x9 inch figure is created for the overlay.
    """
    plt.clf()
    kept = DeleteSmall(measure.find_contours(phi, 0.5), 20)
    _, axes = plt.subplots(figsize=(9, 9))

    axes.imshow(img)

    for contour in kept:
        # Column 1 is plotted on the horizontal axis, column 0 on the vertical.
        horizontal, vertical = contour[:, 1], contour[:, 0]
        axes.plot(horizontal, vertical, linewidth=1.5, color='r')
    plt.axis('off')
    plt.show()

In [None]:
def find(a, b, f): #Replace the parts that are less than 0
    """Return a clone of ``b`` with entries taken from ``a`` wherever ``f < 0``.

    Neither ``a`` nor ``b`` is modified.
    """
    result = b.clone()
    negative = f < 0
    result[negative] = a[negative]
    return result

In [None]:
def guassian_blur(img, kernel_size, sigma): #Gaussian filtering 
    """Blur a 2-D tensor with a normalised ``kernel_size`` x ``kernel_size`` Gaussian.

    Parameters
    ----------
    img : torch.Tensor
        2-D tensor; it is given batch and channel axes for ``F.conv2d`` and
        those axes are removed again before returning.
    kernel_size : int
        Side length of the square kernel.
    sigma : float or single-element torch.Tensor
        Standard deviation of the Gaussian. It is read as a plain number, so
        no gradient flows through it.

    Returns
    -------
    torch.Tensor
        Blurred tensor. Borders are zero-padded by ``kernel_size // 2``, so
        odd kernel sizes keep the input shape.

    Notes
    -----
    Changes over the previous version:
    * ``sigma`` may be a Python number. Calling ``torch.exp`` on a plain
      float raised ``TypeError`` before.
    * ``sigma`` tensors of any single-element shape (e.g. ``[[0.1]]``) are
      reduced to a scalar before the kernel is built.
    * Padding follows ``kernel_size`` rather than being fixed at 1
      (identical for the 3x3 kernels used in this notebook).
    """
    device = img.device
    # Build the kernel in the image dtype so conv2d does not hit a dtype mismatch.
    dtype = img.dtype if img.is_floating_point() else torch.get_default_dtype()
    sigma_value = float(sigma)

    # The 2-D Gaussian is separable: build the 1-D profile and take its outer product.
    offsets = torch.arange(kernel_size, dtype=dtype, device=device) - kernel_size // 2
    profile = torch.exp(-0.5 * offsets ** 2 / sigma_value ** 2)
    kernel = profile[:, None] * profile[None, :]
    kernel = kernel / kernel.sum()

    blurred = F.conv2d(img.unsqueeze(0).unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0),
                       padding=kernel_size // 2)
    blur = blurred.squeeze(0).squeeze(0)
    return blur

In [None]:
def torch_gradient_like_np(tensor): #Calculate the gradient of a tensor
    """Forward-difference gradient of a 2-D tensor, zero-padded to the input shape.

    Parameters
    ----------
    tensor : torch.Tensor
        2-D input.

    Returns
    -------
    (dy, dx) : tuple of torch.Tensor
        ``dy[i, j] = tensor[i + 1, j] - tensor[i, j]`` with a zero last row;
        ``dx[i, j] = tensor[i, j + 1] - tensor[i, j]`` with a zero last column.

    Raises
    ------
    ValueError
        If ``tensor`` is not 2-dimensional.

    Notes
    -----
    These are one-sided (forward) differences, not central differences, so
    despite the function name the values differ from ``np.gradient``.
    """
    if tensor.dim() != 2:
        raise ValueError("Input tensor must be 2D.")
    # Forward differences along rows (dy) and columns (dx).
    dy = tensor[1:, :] - tensor[:-1, :]
    dx = tensor[:, 1:] - tensor[:, :-1]

    # Pad with zeros in the input's own dtype/device to restore the original shape.
    dy_padded = torch.cat((dy, tensor.new_zeros(1, tensor.shape[1])), dim=0)
    dx_padded = torch.cat((dx, tensor.new_zeros(tensor.shape[0], 1)), dim=1)

    return dy_padded, dx_padded

In [None]:
def GLFIF(Img, LImg, u0, sigma, lambda1, lambda2, alpha1, alpha2, g):
    """Perform one update step of the membership function ``u0``.

    A candidate membership ``un`` is computed from global (``c1``, ``c2``)
    and locally blurred (``s1``, ``s2``) weighted means. The resulting change
    ``deltaF`` is evaluated per pixel, and the candidate is accepted only
    where ``deltaF < 0``. The result is smoothed with a 3x3 Gaussian.

    Parameters
    ----------
    Img : torch.Tensor
        Image used for the current-state statistics and data terms.
    LImg : torch.Tensor
        Image used in the ``(LImg - ...) * LImg`` terms and for the
        candidate-state statistics (the visible caller passes ``Img`` for both).
    u0 : torch.Tensor
        Current membership values.
    sigma :
        Standard deviation forwarded to ``guassian_blur``.
    lambda1, lambda2, alpha1, alpha2 :
        Weights of the terms in ``un`` and ``deltaF``.
    g : torch.Tensor
        Per-pixel weight multiplied into every term of ``deltaF``.

    Returns
    -------
    torch.Tensor
        Updated, smoothed membership function.

    Notes
    -----
    No guard exists against zero denominators (``torch.sum(u1)``, ``Ku1`` ...).
    """
    # Squared memberships of the two regions.
    u1 = u0**2
    u2 = (1 - u0)**2

    Iu1 = Img * u1
    Iu2 = Img * u2

    # Global weighted means of the image under each membership.
    c1 = torch.sum(Iu1) / torch.sum(u1)
    c2 = torch.sum(Iu2) / torch.sum(u2)


    # Local weighted means: blurred weighted image over blurred weights.
    Ku1 = guassian_blur(u1, 3, sigma)
    Ku2 = guassian_blur(u2, 3, sigma)

    KI1 = guassian_blur(Iu1, 3, sigma)
    KI2 = guassian_blur(Iu2, 3, sigma)

    s1 = KI1 / Ku1
    s2 = KI2 / Ku2

    # Terms for the current state, built from global (F3) and local (F4) means.
    kim = (c1 * u1) + (c2 * u2)
    DcH = (LImg - kim) * LImg
    F3_old = DcH

    sim = (s1 * u1) + (s2 * u2)
    DsH = (LImg - sim) * LImg
    F4_old = DsH

    # Candidate membership.
    un = 1 / (1 + (lambda1 * (Img - c1)**2 + (alpha1 * s2 + alpha2 * c2)) / (
            lambda2 * (Img - c2)**2 + (alpha1 * s1 + alpha2 * c1)))

    un1 = un**2
    un2 = (1 - un)**2

    delta_u1 = un1 - u1
    delta_u2 = un2 - u2

    delta_F1 = un1 * delta_u1 * ((Img - c1)**2 / (un1 + delta_u1))
    delta_F2 = un2 * delta_u2 * ((Img - c2)**2 / (un2 + delta_u2))

    # Same statistics recomputed for the candidate state (note: uses LImg).
    NIu1 = LImg * un1
    NIu2 = LImg * un2

    Nc1 = torch.sum(NIu1) / torch.sum(un1)
    Nc2 = torch.sum(NIu2) / torch.sum(un2)

    # (The unused NK1/NK2 quotients of the previous version were removed.)

    Nkim = un1 * Nc1 + un2 * Nc2
    F3_new = (LImg - Nkim) * LImg

    NKu1 = guassian_blur(un1, 3, sigma)
    NKu2 = guassian_blur(un2, 3, sigma)

    NKI1 = guassian_blur(NIu1, 3, sigma)
    NKI2 = guassian_blur(NIu2, 3, sigma)

    Ns1 = NKI1 / NKu1
    Ns2 = NKI2 / NKu2

    Nsim = un1 * Ns1 + un2 * Ns2
    F4_new = (LImg - Nsim) * LImg

    # Per-pixel change, weighted by g.
    deltaF = lambda1 * delta_F1 * g + lambda2 * delta_F2 * g + alpha1 * (F3_new - F3_old) * g + alpha2 * (F4_new - F4_old) * g
    # Accept the candidate only where deltaF is negative; keep u0 elsewhere.
    u = find(un, u0, deltaF)
    # Perform Gaussian filtering on 'u' for smoothing
    u = guassian_blur(u, 3, sigma)
    
    return u

In [None]:
def change_lsf(Img, initial_lsf, iter_num, sigma, lambda1, lambda2, alpha1, alpha2): #Perform GLFIF
    """Evolve ``initial_lsf`` on ``Img`` by running ``GLFIF`` ``iter_num`` times.

    Parameters
    ----------
    Img : torch.Tensor
        2-D grayscale image whose maximum must exceed 1 (i.e. not
        normalised to [0, 1]).
    initial_lsf : torch.Tensor
        Starting function, same shape as ``Img``; it is cloned, not modified.
    iter_num :
        Number of iterations (passed to ``range``).
    sigma, lambda1, lambda2, alpha1, alpha2 :
        Forwarded unchanged to ``GLFIF``.

    Returns
    -------
    torch.Tensor
        The function after the final iteration.

    Raises
    ------
    Exception
        If ``Img`` is not 2-D, if the shapes differ, or if ``max(Img) <= 1``.
    """

    if Img.dim() != 2:
        raise Exception("Please enter a grayscale image.")

    if Img.shape != initial_lsf.shape:
        raise Exception("The input image size must match the initial_lsf.")

    # The check rejects images whose maximum is <= 1. The previous message
    # ("must be between 0 and 1") stated the opposite of what is enforced.
    if torch.max(Img) <= 1:
        raise Exception("The grayscale values must not be normalized to [0, 1] (the maximum must exceed 1)!")

    img_smooth = guassian_blur(Img, 3, sigma) 

    
    # Edge indicator: close to 1 in flat regions, small where the smoothed
    # image changes quickly.
    dy, dx = torch_gradient_like_np(img_smooth)
    f = dy**2 + dx**2
    g = 1 / (1 + f) 
    phi = initial_lsf.clone()

    for _ in range(iter_num):
        # Img is passed for both the Img and LImg arguments of GLFIF.
        phi = GLFIF(Img, Img, phi, sigma, lambda1, lambda2, alpha1, alpha2, g)
    
    return phi

In [None]:
def Level_Set(img, iter_num, sigma, lambda1, lambda2, alpha1, alpha2): #Params of GLFIF
    """Bundle the keyword arguments expected by ``change_lsf``.

    The initial function is a float32 tensor shaped like ``img`` and placed
    on its device, filled with 0.3 except for the single element ``[0, 0]``,
    which is set to 0.7. All other arguments are passed through unchanged.
    """
    rows, cols = img.shape
    seed = torch.full((rows, cols), 0.3, dtype=torch.float32).to(img.device)
    seed[:1, :1] = 0.7

    return dict(
        Img=img,
        initial_lsf=seed,
        iter_num=iter_num,
        sigma=sigma,
        lambda1=lambda1,
        lambda2=lambda2,
        alpha1=alpha1,
        alpha2=alpha2,
    )

In [None]:
def BLS(img,output): #Calculate the segmentation result of Adaptive GLFIF
    """Run the level-set evolution on ``img`` with weights taken from ``output``.

    ``output`` is squeezed and its first four entries are used, in order,
    as ``lambda1``, ``lambda2``, ``alpha1`` and ``alpha2``. The iteration
    count (50) and sigma (0.1) are fixed and passed as 1x1 tensors.
    """
    weights = output.squeeze()
    iterations = torch.tensor([[50]])
    blur_sigma = torch.tensor([[0.1]])
    lambda1, lambda2, alpha1, alpha2 = (weights[k].float() for k in range(4))

    kwargs = Level_Set(img, iterations, blur_sigma, lambda1, lambda2, alpha1, alpha2)
    return change_lsf(**kwargs)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
import torchvision.models as models
# Number of input channels of the replaced first ResNet convolution.
input_channels = 1
# Passed to AGAModel, whose constructor accepts but does not use them.
input_height = 256
input_width = 256
#Model Set
# Number of stacked transformer encoder layers.
num_layers = 2
# Number of attention heads per encoder layer.
num_heads = 4
# Used as the encoder layer's model size, its feed-forward size and the
# input size of the first Linear layer of the parameter head.
hidden_dim = 2048
# Dropout probability of the encoder layers.
dropout_rate = 0.1

class AGAModel(nn.Module):  #The model returns Adaptive GLFIF results.
    """Predict four level-set weights from an image and run the level set with them.

    Pipeline of ``forward``: ResNet-50 backbone (first convolution replaced to
    take ``input_channels`` channels, final ``fc`` replaced by identity) ->
    transformer encoder -> MLP head ending in ``Softplus`` that outputs four
    values -> ``BLS`` run on the squeezed input image.

    The attribute names ``feature_extractor``, ``transformer_encoder`` and
    ``parameter_generator`` are the keys under which weights are stored in a
    state dict, so they must not be renamed.

    ``input_height`` and ``input_width`` are accepted for call-compatibility
    but are not used.
    """

    def __init__(self, input_channels, input_height, input_width, num_layers, num_heads, hidden_dim, dropout_rate):
        super(AGAModel, self).__init__()
        
        resnet50 = models.resnet50(pretrained=False)
        # Swap the stem convolution so the backbone accepts `input_channels` channels.
        resnet50.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop the classification layer; the backbone then returns its pooled features.
        resnet50.fc = nn.Identity()
        self.feature_extractor = resnet50
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=hidden_dim, dropout=dropout_rate),
            num_layers=num_layers
        )
        
        # Softplus keeps the four generated weights strictly positive.
        self.parameter_generator = nn.Sequential(
            nn.Linear(hidden_dim, 256),  
            nn.ReLU(),
            nn.Linear(256, 32), 
            nn.ReLU(),
            nn.Linear(32, 4), 
            nn.Softplus()
        )
        
    def forward(self, x):
        """Return the level-set result for the image batch ``x``.

        NOTE(review): ``x.squeeze()`` and the indexing inside ``BLS`` appear to
        assume a batch containing a single one-channel image — confirm before
        using larger batches.
        """
        features = self.feature_extractor(x)
        # Insert a middle axis of length 1: (B, C) -> (B, 1, C).
        features = features.view(features.size(0), -1, features.size(1))
        transformer_output = self.transformer_encoder(features)
        parameters = self.parameter_generator(transformer_output) 
        Im = x.squeeze()
        phi = BLS(Im,parameters)
        return phi

# Build the model from the constants above and move it to `device`.
# NOTE(review): no net.eval() call is visible in this notebook, so the module
# presumably stays in training mode (dropout active) during inference — confirm
# whether that is intended.
net = AGAModel(input_channels, input_height, input_width, num_layers, num_heads, hidden_dim, dropout_rate).to(device)

In [None]:
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on
# a GPU can also be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
net.load_state_dict(torch.load("./model/model_final.pth", map_location=device)) #Load the pretrained model

In [None]:
def show(contours, img): #show result
    """Display ``img`` in a new 8x6 inch figure with every contour drawn in red."""
    _, axes = plt.subplots(figsize=(8, 6))

    axes.imshow(img)

    for outline in contours:
        # Column 1 is plotted on the horizontal axis, column 0 on the vertical.
        horizontal, vertical = outline[:, 1], outline[:, 0]
        axes.plot(horizontal, vertical, linewidth=1.2, color='r')
    plt.axis('off')
    plt.show()

In [None]:
from skimage import filters, draw, measure, color
from skimage.morphology import remove_small_objects
from skimage import io, measure, morphology
from scipy.ndimage import label
from sklearn.decomposition import FastICA
from skimage import filters
def img_pre_ls(src,net): #Two-step segmentation algorithm
    """Segment the image at ``src`` in two steps and return a label map.

    Step 1: the image is converted to LAB, reduced to one channel with PCA,
    rescaled to 0..255 and passed through ``net``; the 0.5-level contours of
    the result are filled with labels ``1..k``.
    Step 2: the pixels not labelled in step 1 are thresholded with Otsu's
    method on the first RGB channel, small objects are removed and the
    remaining connected components get labels ``k+1..``.

    Parameters
    ----------
    src : str
        Path of the image file.
    net : callable
        Model mapping a ``(1, 1, H, W)`` float tensor on the module-level
        ``device`` to a 2-D tensor.

    Returns
    -------
    tuple
        ``(tag, contours, phi_padded, T, binary_image_cleaned, contours_low)``:
        final label map, contours of the padded result, the result padded by
        one pixel of zeros, a copy of the step-1 label map, the cleaned
        binary mask of step 2, and the contours of the unpadded result.

    Raises
    ------
    FileNotFoundError
        If the image cannot be read.

    Notes
    -----
    Changes over the previous version: padded buffers follow the image size
    instead of being fixed at 258x258, the duplicated preview plot and unused
    locals are gone, an unreadable file is reported explicitly, and inference
    runs under ``torch.no_grad()``.
    """
    image = cv2.imread(src)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Could not read image: {}".format(src))
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    #Reduce the dimensionality of the image.
    reshaped_img = img_lab.reshape(-1, 3)
    pca = PCA(n_components=1)
    im = pca.fit_transform(reshaped_img).reshape(img.shape[:2])
    normalized_img = (im - np.min(im)) / (np.max(im) - np.min(im))
    standardized_img = (normalized_img * 255).astype(np.uint8)

    plt.axis('off')
    plt.imshow(standardized_img, cmap='gray')
    plt.show()
    standardized_img = torch.tensor(standardized_img).unsqueeze(0).unsqueeze(0).float()
    standardized_img = standardized_img.to(device)
    #Adaptive GLFIF Segmentation.
    # The output is detached right away, so no autograd graph is needed.
    with torch.no_grad():
        phi = net(standardized_img)
    phi = phi.cpu().detach().numpy()
    show_fig2(phi,img)
    contours_low = measure.find_contours(phi, 0.5)
    contours_low = DeleteSmall(contours_low,20)
    phi_padded = np.pad(phi, pad_width=1, mode='constant', constant_values=0) #padded for get contours

    contours = measure.find_contours(phi_padded, 0.5)
    contours = DeleteSmall(contours,20)

    print(len(contours))
    # Step-1 labels: fill each contour polygon with its 1-based index.
    tag = np.zeros_like(phi_padded, dtype=int)
    for i, contour in enumerate(contours):
        rr, cc = draw.polygon(contour[:, 0], contour[:, 1])
        tag[rr, cc] = i + 1  
    
    T = np.copy(tag)

    # First RGB channel, padded by one pixel so it lines up with `tag`.
    Gray = np.pad(img[:, :, 0], pad_width=1, mode='constant', constant_values=0)

    # Step 2 only looks at pixels that step 1 left unlabelled.
    Gray[tag != 0] = 0
    #Otsu menthod
    threshold = filters.threshold_otsu(Gray)
    binary_image = Gray > threshold
    binary_image[tag != 0] = 0
    plt.imshow(binary_image, cmap='gray')
    plt.show()
    max_label = np.max(tag)
    binary_image_cleaned = remove_small_objects(binary_image, min_size=20)
    plt.axis('off')
    plt.imshow(binary_image_cleaned, cmap='gray')
    plt.show()
    labeled_background, _ = morphology.label(binary_image_cleaned, return_num=True, connectivity=2)
    # Shift step-2 labels past the step-1 labels so the two sets do not collide.
    labeled_background[labeled_background > 0] += max_label
    tag[labeled_background != 0] = labeled_background[labeled_background != 0] #Final tag
    print(np.unique(tag))
    print(len(np.unique(tag)))
    plt.imshow(tag, cmap='jet')
    plt.colorbar()
    plt.show()

    return tag,contours,phi_padded,T,binary_image_cleaned,contours_low

In [None]:
# Path of the image to segment, then run the two-step segmentation on it with
# the loaded model and unpack the six returned results.
src = "./test_image/image_001.png" #test_image
tag,contours,phi_padded,T,binary_image_cleaned,contours_low = img_pre_ls(src,net) #If an error occurs, please execute it again.

In [None]:
    #Show segmentation result.
from skimage.segmentation import find_boundaries
# Reload the source image and recompute its single-channel PCA projection.
image = cv2.imread(src)
image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
flatten_lab = image_lab.reshape(-1, 3)
pca = PCA(n_components=1)
im = pca.fit_transform(flatten_lab).reshape(image.shape[:2])
normalized_img = (im - np.min(im)) / (np.max(im) - np.min(im))
N,M = im.shape
standardized_img = (normalized_img * 255).astype(np.uint8)

# Pad by one pixel on each side so the image lines up with `tag` and `contours`.
# The size follows the image instead of the previous hard-coded 258x258.
new_image = np.pad(image, ((1, 1), (1, 1), (0, 0)), mode='constant', constant_values=0)
boundaries = find_boundaries(tag,mode='thin')

image_with_boundaries = np.copy(new_image)

image_with_boundaries[boundaries] = [255, 0, 0]  

# Three panels: original image, step-1 contours, outline of every final label.
fig, axes = plt.subplots(1, 3, figsize=(27, 9))

axes[0].axis('off')
axes[0].imshow(image)  # RGB data; a colormap would be ignored

axes[1].axis('off')
axes[1].imshow(new_image)
for contour in contours:
    axes[1].plot(contour[:, 1], contour[:, 0], linewidth=1.5, color='r')

axes[2].axis('off')
axes[2].imshow(new_image)
# `region_label` rather than `label`, so the loop does not overwrite the
# `label` function imported from scipy.ndimage.
for region_label in np.unique(tag):
    if region_label == 0:
        continue  
    contours_n = measure.find_contours(tag == region_label, level=0.5)
    for contour in contours_n:
        axes[2].plot(contour[:, 1], contour[:, 0], linewidth=1.5, color='r')


