In [206]:
import numpy as np
import cv2
from scipy.spatial.distance import cdist
from scipy import signal
import matplotlib.pyplot as plt
from skimage import data, filters, color, io, feature
from skimage.color import rgb2gray
from skimage.morphology import disk
from skimage.draw import disk
from skimage.filters import gaussian
from skimage import morphology
import warnings
from skimage.segmentation import chan_vese
from skimage.segmentation import active_contour
from scipy import fft, ifft
from skimage.transform import hough_line, probabilistic_hough_line
from sklearn.cluster import KMeans
from skimage.color import label2rgb
from skimage.filters import sobel
from skimage.measure import label
from skimage.segmentation import expand_labels
from skimage.segmentation import felzenszwalb, slic, quickshift, watershed
from skimage.util import img_as_float
from skimage.segmentation import mark_boundaries
from matplotlib.path import Path
from shapely.geometry import Polygon
from scipy.spatial import ConvexHull
import os
from skimage import img_as_ubyte

In [187]:
def calculate_threshold(length, segment_counts):
    """Pixel-count cut-off separating "small" segments from "large" ones.

    With ``mean`` and ``max`` the mean and maximum of ``segment_counts`` and
    ``ratio = mean / max``::

        threshold = (mean + ratio) / length ** ratio

    Parameters
    ----------
    length : int
        Number of distinct segments the counts describe.
    segment_counts : array-like of int
        Pixel count of each segment.

    Returns
    -------
    float
        The threshold value.

    Raises
    ------
    ValueError
        If ``segment_counts`` is empty.
    """
    counts = np.asarray(segment_counts)
    if counts.size == 0:
        raise ValueError("segment_counts must be non-empty")
    mean_count = counts.mean()
    # ratio is in (0, 1]: 1 when all segments are equally large, small when one dominates.
    ratio = mean_count / counts.max()
    return (mean_count + ratio) / length ** ratio

In [188]:
import numpy as np
from scipy.spatial.distance import cdist
from skimage.draw import disk

def replace_segments_circle(segments, small_radius):
    """Merge minority labels inside overlapping circular windows into a dominant label.

    A disk of diameter ``small_radius`` is swept over the label map, from the
    bottom-right corner towards the top-left. Within each disk:

    - labels whose pixel count is below ``calculate_threshold`` are "minority";
    - labels with the highest pixel count are "dominant";
    - every minority label is rewritten to the dominant label numerically
      closest to it (ties go to the smallest dominant label).

    Parameters
    ----------
    segments : np.ndarray
        2-D integer label map. It is modified in place.
    small_radius : int
        Diameter of the sweeping disk; the disk radius is ``small_radius // 2``.

    Returns
    -------
    np.ndarray
        ``segments`` itself, after modification.

    Raises
    ------
    ValueError
        If ``small_radius < 1`` (the sweep step would be zero).
    """
    if small_radius < 1:
        raise ValueError("small_radius must be a positive integer")

    # Previously hard-coded to (512, 512); follow the actual label map instead.
    shape = segments.shape[:2]
    half = small_radius // 2
    step = -small_radius // 2  # parsed as (-small_radius) // 2, as in the original sweep

    for i in range(shape[0] - small_radius, 0, step):
        for j in range(shape[1] - small_radius, 0, step):
            rr, cc = disk((i + half, j + half), half, shape=shape)
            # Index the disk pixels directly instead of building a full-size mask.
            window = segments[rr, cc]
            if window.size == 0:
                continue

            labels, counts = np.unique(window, return_counts=True)
            threshold_value = calculate_threshold(len(labels), counts)
            minority = labels[counts < threshold_value]
            dominant = labels[counts == counts.max()]
            # Distances are computed in float so unsigned label dtypes cannot wrap.
            dominant_as_float = dominant.astype(float)

            for label in minority:
                nearest = dominant[np.argmin(np.abs(dominant_as_float - float(label)))]
                window[window == label] = nearest

            segments[rr, cc] = window

    return segments


In [189]:
def compute_segment_centers(original_segments):
    """Locate the rounded centroid of every non-negative label in a label map.

    Parameters
    ----------
    original_segments : np.ndarray
        Integer label image; labels < 0 are skipped.

    Returns
    -------
    segment_centers : np.ndarray
        One row per non-negative label, in ascending label order, holding the
        rounded mean index of that label's pixels. Empty when no label qualifies.
    segment_dict : dict
        Maps each center tuple to its label. When two labels round to the same
        center, the first (smallest) label is kept.
    """
    labels = np.unique(original_segments)
    center_list = []
    center_to_label = {}
    for label in labels[labels >= 0]:
        coords = np.array(np.where(original_segments == label))
        center = tuple(np.round(coords.mean(axis=1)).astype(int))
        center_list.append(center)
        center_to_label.setdefault(center, label)
    return np.array(center_list), center_to_label

def split_centers(original_segments):
    """Split segment centers into "small" (to replace) and "large" (replacement) groups.

    A segment is small when its pixel count is strictly below
    ``calculate_threshold(number_of_segments, counts)`` and large when strictly
    above it. Segments exactly at the threshold appear in neither group.

    Parameters
    ----------
    original_segments : np.ndarray
        Integer label image.

    Returns
    -------
    centers_to_replace : np.ndarray
        Centers (rows of 2 ints) of small segments, in ascending label order.
    centers_replacement : np.ndarray
        Centers of large segments, in ascending label order.
    segment_dict : dict
        Center tuple -> label mapping, as returned by ``compute_segment_centers``.
    """
    centers, segment_dict = compute_segment_centers(original_segments)
    unique_segments, segment_counts = np.unique(original_segments, return_counts=True)
    # calculate_threshold expects the *number* of segments. The label array used
    # previously produced a per-label threshold array (with a division by zero for
    # label 0) that was then compared position-wise against differently ordered counts.
    threshold = calculate_threshold(len(unique_segments), segment_counts)

    # compute_segment_centers emits exactly one center per label >= 0, in ascending
    # order, so pair centers with the counts of those labels only.
    counts_with_center = segment_counts[unique_segments >= 0]

    centers_to_replace = [c for c, n in zip(centers, counts_with_center) if n < threshold]
    centers_replacement = [c for c, n in zip(centers, counts_with_center) if n > threshold]

    return np.array(centers_to_replace), np.array(centers_replacement), segment_dict


In [190]:
import numpy as np
import matplotlib.pyplot as plt

def plot_segment_centers(original_segments):
    """Show the label map with every segment center marked by its split_centers role.

    Marker styles:

    - red crosses: centers of "small" segments (to be replaced);
    - blue circles: centers of "large" replacement segments;
    - viridis colour: centers whose segment size equals the threshold exactly.

    Parameters
    ----------
    original_segments : np.ndarray
        2-D integer label image; labels < 0 are not drawn.
    """
    centers_to_replace, centers_replacement, _ = split_centers(original_segments)
    centers, _ = compute_segment_centers(original_segments)
    # compute_segment_centers yields one center per label >= 0, in ascending order,
    # so pair centers with exactly those labels (negative labels would shift the zip).
    labels = np.unique(original_segments)
    labels = labels[labels >= 0]
    small_centers = {tuple(c) for c in centers_to_replace.tolist()}
    large_centers = {tuple(c) for c in centers_replacement.tolist()}

    plt.figure(figsize=(10, 10))
    # `np.float` was removed in NumPy 1.24; builtin float is the same float64 dtype.
    colored_segments = np.zeros_like(original_segments, dtype=float)
    labelled = set()

    for seg, center in zip(labels, centers):
        # Every pixel of `seg` holds the value seg, so the per-segment mean is seg itself.
        shade = seg / 255
        colored_segments[original_segments == seg] = plt.cm.viridis(shade)[1]

        key = tuple(center.tolist())
        if key in small_centers:
            style = {'color': 'red', 'marker': 'x', 'label': 'to replace'}
        elif key in large_centers:
            style = {'color': 'blue', 'marker': 'o', 'label': 'replacement'}
        else:
            style = {'color': plt.cm.viridis(shade), 'label': 'at threshold'}
        # Label each category once; matplotlib hides labels starting with '_'.
        if style['label'] in labelled:
            style['label'] = '_nolegend_'
        else:
            labelled.add(style['label'])
        plt.scatter(center[1], center[0], s=100, **style)

    plt.imshow(colored_segments)
    if labelled:
        plt.legend()
    plt.title('Segment Centers')
    plt.axis('off')
    plt.show()


In [191]:
def distance(point1, point2):
    """Euclidean distance using only the first two coordinates of each point."""
    d_first = point1[0] - point2[0]
    d_second = point1[1] - point2[1]
    return np.sqrt(d_first ** 2 + d_second ** 2)

In [192]:
def check_inside(img, seg_large, seg_small, seed, no_seed):
    """Test whether ``no_seed`` lies within the inner radius of segment ``seg_large``.

    The inner radius is the smallest distance from ``seed`` to a vertex of the
    convex hull of the pixels labelled ``seg_large``. The result is
    ``|seed - no_seed| <= inner_radius``.

    Parameters
    ----------
    img : np.ndarray
        2-D label image.
    seg_large : int
        Label whose convex hull defines the radius.
    seg_small : int
        Unused. Kept for interface compatibility: the original built a hull for
        this segment and discarded it. That crashed on 1-2 pixel or collinear segments.
    seed, no_seed : sequence of two ints
        (row, col) points.

    Returns
    -------
    bool-like
        ``numpy.bool_``.

    Raises
    ------
    scipy.spatial.QhullError
        If the ``seg_large`` pixels form a degenerate hull (fewer than 3 points or collinear).
    """
    points = np.column_stack(np.nonzero(img == seg_large))
    hull = ConvexHull(points)
    # In 2-D the vertices of all hull simplices are exactly hull.vertices.
    vertices = points[hull.vertices]
    d_row = vertices[:, 0] - seed[0]
    d_col = vertices[:, 1] - seed[1]
    inner_radius = np.min(np.sqrt(d_row ** 2 + d_col ** 2))
    return (seed[0] - no_seed[0]) ** 2 + (seed[1] - no_seed[1]) ** 2 <= inner_radius ** 2

In [193]:
def segment_radius(img, seg, center):
    """Mean distance from ``center`` to the convex-hull vertices of segment ``seg``.

    Parameters
    ----------
    img : np.ndarray
        2-D label image.
    seg : int
        Label of the segment to measure.
    center : sequence of two numbers
        (row, col) reference point.

    Returns
    -------
    float or int
        The mean vertex distance. Returns ``1`` when the hull would be degenerate:
        the label is absent, has fewer than 3 pixels, or all its pixels are
        collinear in any direction (the original only caught horizontal or
        vertical lines and crashed on diagonal ones).
    """
    points = np.column_stack(np.nonzero(img == seg))
    if len(points) < 3 or np.linalg.matrix_rank(points - points[0]) < 2:
        return 1

    hull = ConvexHull(points)
    vertices = points[hull.vertices]
    radii = np.sqrt((vertices[:, 0] - center[0]) ** 2 + (vertices[:, 1] - center[1]) ** 2)
    return np.mean(radii)

In [194]:
def finding_the_strongest_pleasures(segments, centers_to_replace, centers_replacement, segment_dict):
    """Assign each small segment to the large segment that attracts it most strongly.

    Uses a gravity-like score, where ``w`` is a segment's pixel count and ``d`` the
    distance between centers::

        F = w_small * w_large / d**2

    A small segment's "self force" is ``w_small**2 / R**2``, with ``R`` from
    ``segment_radius``. The small segment goes to the large center with the highest
    ``F``, provided that ``F`` exceeds the self force. Otherwise it stays unassigned.

    Parameters
    ----------
    segments : np.ndarray
        2-D label image.
    centers_to_replace : list of tuple
        Centers of small segments (keys of ``segment_dict``).
    centers_replacement : list of tuple
        Centers of large segments (keys of ``segment_dict``).
    segment_dict : dict
        Center tuple -> label.

    Returns
    -------
    dict
        ``tuple(large_center) -> list of small centers`` assigned to it.
    """
    unique_segments, segment_counts = np.unique(segments, return_counts=True)
    segment_weight = dict(zip(unique_segments, segment_counts))
    center_weight = {center: segment_weight[label] for center, label in segment_dict.items()}
    strongest_centers = {tuple(center): [] for center in centers_replacement}

    for center_tr in centers_to_replace:
        weight_tr = center_weight[center_tr]
        radius = segment_radius(segments, segment_dict[center_tr], center_tr)
        self_force = (weight_tr * weight_tr) / (radius ** 2)

        best_force = 0
        best_center = None
        for center_r in centers_replacement:
            dist = distance(center_r, center_tr)
            force = (weight_tr * center_weight[center_r]) / (dist ** 2)
            if force > self_force and force > best_force:
                best_force = force
                best_center = center_r

        # `is not None` instead of `!= None`: the latter is elementwise for array centers.
        if best_center is not None:
            strongest_centers[tuple(best_center)].append(center_tr)

    return strongest_centers

In [195]:
def region_growing_inside(img, seed, no_seeds, segment_dict):
    """Absorb small segments whose center lies inside the seed segment's inner radius.

    Each candidate center is checked in order with ``check_inside``. Absorbed
    segments are relabelled in ``img`` (in place) with the seed's label. Because
    ``img`` is updated as the loop goes, later checks see earlier absorptions.

    Returns
    -------
    tuple
        ``(img, remaining)``: the mutated image and the candidates that were not absorbed.
    """
    remaining = []
    for candidate in no_seeds:
        absorbed = check_inside(img, segment_dict[seed], segment_dict[candidate], seed, candidate)
        if not absorbed:
            remaining.append(candidate)
            continue
        img[img == segment_dict[candidate]] = segment_dict[seed]
    return img, remaining

In [196]:
def region_growing(img, center_seed, seed, strongest_centers, segment_dict):
    """Grow a boolean region from the segment identified by ``center_seed``.

    The region starts as every pixel carrying the seed segment's label, plus
    ``seed``. It then spreads over 8-connected neighbours whose label is any of:

    - equal to the label of the pixel being expanded;
    - the label of a small segment assigned to this seed in
      ``strongest_centers[center_seed]``;
    - 0.

    Parameters
    ----------
    img : np.ndarray
        2-D label image (not modified).
    center_seed : tuple
        Center key of the seed segment in ``segment_dict``.
    seed : tuple of two ints
        (row, col) pixel of the seed segment.
    strongest_centers : dict
        Seed center -> list of small-segment centers it may absorb.
    segment_dict : dict
        Center tuple -> label. Read only. The original wrote every grown pixel into
        this dict, and those pixel keys could overwrite real center keys.

    Returns
    -------
    np.ndarray of bool
        Mask with the same shape as ``img``.
    """
    neighbour_offsets = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, 1), (1, 1), (1, -1), (-1, -1)]
    height, width = img.shape
    absorbable = {segment_dict[no_center] for no_center in strongest_centers[center_seed]}

    region = np.zeros(img.shape, dtype=bool)
    # Start from *all* seed-segment pixels. The original kept only an arbitrary one,
    # and pre-marked `seed` without expanding it, which could block growth.
    rows, cols = np.nonzero(img == segment_dict[center_seed])
    region[rows, cols] = True
    region[seed] = True
    stack = list(zip(rows.tolist(), cols.tolist()))
    stack.append((int(seed[0]), int(seed[1])))

    while stack:
        r, c = stack.pop()
        current = img[r, c]
        for dr, dc in neighbour_offsets:
            nr, nc = r + dr, c + dc
            if not (0 <= nr < height and 0 <= nc < width) or region[nr, nc]:
                continue
            value = img[nr, nc]
            if value == current or value in absorbable or value == 0:
                region[nr, nc] = True
                stack.append((nr, nc))
    return region


In [197]:
def centers_n_seeds(segments):
    """Prepare everything one region-growing pass needs.

    Returns
    -------
    tuple
        ``(segment_dict, strongest_centers, seeds, center_seeds, no_seeds)``:

        - ``segment_dict``: center tuple -> label;
        - ``strongest_centers``: large center -> small centers it attracts;
        - ``seeds``: one randomly chosen (row, col) pixel inside each large segment,
          drawn with the global ``np.random`` state;
        - ``center_seeds``: large-segment centers as tuples;
        - ``no_seeds``: small-segment centers as tuples.
    """
    to_replace, replacement, segment_dict = split_centers(segments)
    no_seeds = [tuple(row) for row in to_replace.tolist()]
    center_seeds = [tuple(row) for row in replacement.tolist()]
    strongest_centers = finding_the_strongest_pleasures(segments, no_seeds, center_seeds, segment_dict)

    pixel_sets = [np.where(segments == segment_dict[center]) for center in center_seeds]
    picks = [np.random.choice(range(len(pixels[0]))) for pixels in pixel_sets]
    seeds = [(pixels[0][k], pixels[1][k]) for pixels, k in zip(pixel_sets, picks)]

    return segment_dict, strongest_centers, seeds, center_seeds, no_seeds

In [201]:
def iterate_region_growing(img, scale=100, sigma=0.4, min_size=100, iterations=30):
    """Felzenszwalb over-segmentation followed by repeated region-growing merges.

    Each iteration grows every large segment over its assigned small segments and
    builds a new label map from the grown regions. Earlier seeds win on overlap,
    and pixels reached by no region become 0. The loop stops early when:

    - no small segments remain;
    - the segment count did not change since the previous iteration;
    - fewer than 10 segments remain.

    Parameters
    ----------
    img : np.ndarray
        Image passed to ``felzenszwalb``.
    scale, sigma, min_size : float, float, int
        ``felzenszwalb`` parameters.
    iterations : int
        Maximum number of merge iterations.

    Returns
    -------
    np.ndarray
        Final integer label map.
    """
    segments = felzenszwalb(img, scale=scale, sigma=sigma, min_size=min_size)
    segment_dict, strongest_centers, seeds, center_seeds, no_seeds = centers_n_seeds(segments)
    print("Iteration 0: number of segments =", len(np.unique(segments)))

    previous_count = 0
    for iteration in range(1, iterations + 1):
        result = np.zeros_like(segments)
        for center_seed, seed in zip(center_seeds, seeds):
            grown = region_growing(segments, center_seed, seed, strongest_centers, segment_dict)
            # Only claim pixels no earlier seed has taken.
            result[grown & (result == 0)] = segment_dict[center_seed]

        segments = result
        # Computed once per iteration. The original also called centers_n_seeds on the
        # same array a line earlier and discarded the result.
        segment_dict, strongest_centers, seeds, center_seeds, no_seeds = centers_n_seeds(segments)
        segment_count = len(np.unique(segments))
        print(f"Iteration {iteration}: number of segments =", segment_count)

        if len(no_seeds) == 0 or segment_count == previous_count or segment_count < 10:
            break
        previous_count = segment_count

    return segments

In [199]:
DATA_DIR = "/content/drive/MyDrive/Grayscale_data"
N_IMAGES = 20


def _sigmoid_contrast(image):
    """Map intensities v in [0, 255] to round(255 / (1 + exp(-v / 255))), keeping the dtype.

    The input is cast to float first. The original negated uint8 scalars, which wrap
    around (-1 -> 255), so it produced a non-monotone mapping (0 -> 128, 1 -> 69,
    255 -> 127) instead of the intended sigmoid.
    """
    scaled = image.astype(np.float64) / 255
    return np.round(255 / (1 + np.exp(-scaled))).astype(image.dtype)


images = []
for idx in range(1, N_IMAGES + 1):
    image = io.imread(f"{DATA_DIR}/tm{idx}_1_1.png")
    images.append(_sigmoid_contrast(image))

In [207]:
def _labels_to_uint8(labels):
    """Encode an integer label map as evenly spaced uint8 grey levels.

    ``img_as_ubyte`` treats an int64 label map as an intensity image. If the largest
    label is 256 or more, it rescales from 63 bits to 8 bits and every label becomes 0.
    Instead, labels are replaced by their rank and spread over 0..255. Ranks stay
    distinct as long as there are at most 256 labels.
    """
    uniques, ranks = np.unique(labels, return_inverse=True)
    ranks = ranks.reshape(labels.shape)
    step = 255 / max(len(uniques) - 1, 1)
    return np.round(ranks * step).astype(np.uint8)


final_results = []
folder_path = '/content/sample_data/output/'
os.makedirs(folder_path, exist_ok=True)

for i, image in enumerate(images):
    file_name = f"seg{i+1}_1_1.png"
    print(f"Image_{i+1}")
    segmented_img = _labels_to_uint8(iterate_region_growing(image))
    io.imsave(os.path.join(folder_path, file_name), segmented_img)
    final_results.append(segmented_img)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].imshow(image, cmap='gray')
    axes[0].set_title('Original Image')

    axes[1].imshow(segmented_img, cmap='viridis')
    axes[1].set_title('Segmented Image')

    plt.show()

Output hidden; open in https://colab.research.google.com to view.