In [1]:
class Fundus():
    """A PIL image (e.g. a fundus photograph) plus color-analysis helpers.

    An instance is built either from an image file path or from an array of
    pixels (see ``image_from_pixels``). The constructor caches the pixel
    count, the image size, the band names and the color palette with counts.
    """

    # Constructor
    def __init__(self, initarg=False, **kwargs):
        """Build a Fundus from an image path or from a numpy pixel array.

        Args:
            initarg: path to an image file (``str`` or ``os.PathLike``), or a
                numpy array of pixels handed to ``image_from_pixels``.
            **kwargs: forwarded to ``image_from_pixels`` (needs ``w`` and ``h``).

        Raises:
            TypeError: if ``initarg`` is neither a path nor a numpy array.
        """
        import os  # only needed for the PathLike check

        if isinstance(initarg, (str, os.PathLike)):
            self.image = self.image_from_file(initarg)
        elif isinstance(initarg, np.ndarray):
            self.image = self.image_from_pixels(initarg, **kwargs)
        else:
            # Fail early with a clear message instead of an AttributeError below.
            raise TypeError(
                "Fundus expects an image path or a numpy array of pixels, "
                f"got {type(initarg).__name__}"
            )

        # Cached image properties
        self.npixels = np.prod(self.image.size)
        self.size = self.image.size
        self.c = self.image.getbands()
        self.c_size = len(self.c)
        self.palette, self.counts = self.get_palette()

    # Constructor from file
    def image_from_file(self, path):
        """Open the image stored at ``path`` with PIL."""
        return Image.open(path)

    # Constructor from pixels
    def image_from_pixels(self, pixels, **kwargs):
        """Build an RGB image from a pixel array.

        The pixels are resized to shape ``(w, h, 3)`` and cast to uint8.
        NOTE(review): PIL's ``Image.fromarray`` reads the first axis as rows,
        so ``w`` is effectively the height; the notebook accordingly passes
        ``w=size[1]`` and ``h=size[0]``. ``np.resize`` also silently repeats or
        truncates data if the pixel count does not match ``w * h``.
        """
        arr = np.resize(pixels, (kwargs["w"], kwargs["h"], 3)).astype(np.uint8)
        return Image.fromarray(arr)

    # Get numpy array for the image
    def get_array(self):
        """Return the image as a numpy array."""
        return np.asanyarray(self.image)

    # Get channel numpy arrays
    def get_channels_asarray(self):
        """Return the three bands of a 3-band image as separate 2-D arrays."""
        r, g, b = self.image.split()
        return np.asanyarray(r), np.asanyarray(g), np.asanyarray(b)

    # Get a specific channel as an array
    def get_channel(self, channel):
        """Return one band (e.g. ``"R"``) as a 2-D array."""
        return np.asanyarray(self.image.getchannel(channel))

    # Transform the image to a list of pixels
    def get_pixels(self):
        """Return the image as an ``(npixels, c_size)`` array, one row per pixel.

        The result is a fresh, writable copy, so callers may modify it in place.
        """
        # reshape fails loudly on a size mismatch (np.resize would silently repeat
        # data); copy() keeps the result writable, because the array exposed by a
        # PIL image may be read-only.
        return self.get_array().reshape(self.npixels, self.c_size).copy()

    # Ignore black pixels from the list of pixels
    def ignore_black_in_channel(self):
        """Return the R, G and B values of all pixels with zeros dropped.

        Each channel is filtered independently, so the three arrays can have
        different lengths and are no longer aligned pixel by pixel.
        """
        R, G, B = self.get_pixels().T
        return R[R != 0], G[G != 0], B[B != 0]

    # Get a unique list of pixels and their counts (palette)
    def get_palette(self):
        """Return the unique pixel colors and the number of pixels with each.

        Returns:
            ``(palette, counts)``: ``palette`` is a ``(ncolors, c_size)`` array of
            unique colors in lexicographic row order, and ``counts[i]`` is the
            number of pixels whose color is ``palette[i]``.
        """
        # np.unique(axis=0) already sorts the rows. Do not np.sort(..., axis=0)
        # first: that sorts every channel separately and mixes channels across
        # pixels, producing colors that do not occur in the image.
        return np.unique(self.get_pixels(), return_counts=True, axis=0)

    # Plot color counts
    def plot_counts(self):
        """Plot the pixel count of every palette color, in palette order and sorted."""
        # Configure subplots
        f, ax = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
        sns.despine(left=True)

        # Plot
        # TODO: change xlabels to colors
        sns.lineplot(data=self.counts, ax=ax[0])
        sns.lineplot(data=np.sort(self.counts), ax=ax[1])

        # Title
        ax[0].set_title("Color sorted")
        ax[1].set_title("Count sorted")

    def color_bar(self, colors):
        """Draw a horizontal bar with one cell per color, labelled with its hex code.

        Args:
            colors: sequence of RGB colors with components in [0, 1].
        """
        cmap = mpl.colors.ListedColormap(colors)

        # Hex code of every color, used as tick labels
        hexcol = np.array([mpl.colors.rgb2hex(x) for x in colors])

        # Get figure
        fig, ax = plt.subplots(1, 1, figsize=(10, 2))
        fig.subplots_adjust(bottom=0.25)

        # One unit-wide cell per color: boundaries 0, 1, ..., N
        bounds = range(cmap.N + 1)
        norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

        # One tick at the center of each cell, so the N hex labels line up with
        # the N colors (the default locator puts ticks on the N + 1 boundaries).
        centers = np.arange(cmap.N) + 0.5

        # Plot color bar
        bar = mpl.colorbar.ColorbarBase(ax=ax,
                                        cmap=cmap,
                                        norm=norm,
                                        boundaries=bounds,
                                        extend="neither",
                                        ticks=centers,
                                        ticklocation="top",
                                        drawedges=False,
                                        spacing="uniform",
                                        filled=True,
                                        orientation="horizontal")
        bar.set_ticklabels(hexcol)

    def plot_palette(self):
        """Draw a color bar with every color of the palette."""
        # Transform 0-255 RGB to 0-1 RGB
        colors = self.palette / 255

        # Plot color bar
        self.color_bar(colors)

    def plot_cbar(self, nc=5):
        """Draw a color bar with the ``nc`` most frequent palette colors."""
        # Transform 0-255 RGB to 0-1 RGB, most frequent colors first
        colors = self.palette[np.argsort(self.counts)[::-1][0:nc]] / 255

        # Plot color bar
        self.color_bar(colors)

    # Plots histogram of all color channels separately and all together
    def plot_histogram(self, ignore=True):
        """Plot the R, G, B value distributions separately and overlaid.

        Args:
            ignore: if True, zero values are dropped from each channel first.
        """
        # NOTE(review): sns.set changes the global seaborn style, and sns.distplot
        # is deprecated since seaborn 0.11 (histplot/kdeplot replace it).
        sns.set(style="white", palette="muted", color_codes=True)
        fig, ax = plt.subplots(2, 2, figsize=(12, 12), sharex=True)
        sns.despine(left=True)

        # Remove black color
        if ignore:
            R, G, B = self.ignore_black_in_channel()
        else:
            R, G, B = self.get_pixels().T

        # Plot 00, 01 and 10 (separate channels)
        sns.distplot(R, color="r", ax=ax[0, 0])
        sns.distplot(G, color="g", ax=ax[0, 1])
        sns.distplot(B, color="b", ax=ax[1, 0])

        # Plot the 3 distributions together
        sns.distplot(R, color="r", hist=False, kde_kws={"shade": True}, ax=ax[1, 1])
        sns.distplot(G, color="g", hist=False, kde_kws={"shade": True}, ax=ax[1, 1])
        sns.distplot(B, color="b", hist=False, kde_kws={"shade": True}, ax=ax[1, 1])

    def plot_lines(self):
        """Plot sorted per-channel values of all pixels and of the unique colors."""
        fig, ax = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
        sns.despine(left=True)

        # Each channel is sorted independently: this shows per-channel value
        # distributions, not actual pixel colors.
        pixels_sorted = np.sort(self.get_pixels(), axis=0)
        ax[0].plot(pixels_sorted)
        ax[0].set_title("All pixels")

        # Palette cached by the constructor
        ax[1].plot(self.palette)
        ax[1].set_title("Unique pixels")

    def get_summary(self):
        """Return mean, std, var and range of each band over the unique colors."""
        df_palette = pd.DataFrame(self.palette, columns=self.c)
        df_summary = pd.DataFrame()
        df_summary["mean"] = df_palette.mean()
        df_summary["std"] = df_palette.std()
        df_summary["var"] = df_palette.var()
        df_summary["max-min"] = df_palette.max() - df_palette.min()

        return df_summary

    def replace_pixels(self, colors2replace, replacement=(0, 0, 0)):
        """
        Take a list of color codes in RGB 0-255 format and return the pixel
        list with every pixel of one of those colors replaced by a given color
        (default black). The image itself is not modified.

        Args:
            colors2replace: iterable of colors, each with ``c_size`` components.
            replacement: color written over the matching pixels.

        Returns:
            An ``(npixels, c_size)`` array.
        """
        original = self.get_pixels()
        pixels = original.copy()
        for color in colors2replace:
            # Match against the unmodified pixels so an already-replaced pixel
            # is never matched again.
            pixels[(original == color).all(axis=1)] = replacement

        return pixels
        

In [2]:
# Read the RGB value of one probe pixel; PIL's getpixel takes (x, y) = (column, row).
# NOTE(review): `original` is not defined in any visible cell (the saved run raised
# NameError). Create it first, e.g. `original = Fundus("<path to fundus image>")`.
original.image.getpixel((892, 426))

NameError: name 'original' is not defined

In [None]:
# Check the same probe pixel through numpy indexing ([row, column]), channel by channel
probe_row, probe_col = 426, 892
red, green, blue = original.get_channels_asarray()  # split the image once, not three times
for channel in (red, green, blue):
    print(channel[probe_row, probe_col])

original.get_array()[probe_row, probe_col]

In [None]:
# Color bar of the 5 most frequent colors (same computation as the inline version, via the class helper)
original.plot_cbar(nc=5)

In [None]:
# Palette of the original image sorted by color frequency, most frequent first
palette_sorted = original.palette[np.argsort(original.counts)[::-1]]

In [None]:
# Blacken every pixel whose color is one of the 5 most frequent (usually the gray background)
most_frequent_colors = palette_sorted[:5]
print(len(palette_sorted))
arr_new = original.replace_pixels(most_frequent_colors, [0, 0, 0])

In [None]:
# Rebuild an image from the blackened pixel list and display it.
# image_from_pixels reshapes to (w, h, 3), which PIL reads as (rows, columns),
# so `w` receives the height and `h` the width.
img_width, img_height = original.size
new = Fundus(arr_new, w=img_height, h=img_width)
new.image

# From here on, we simplify the blackened image by merging similar colors

In [None]:
# Color palette of the new image (already computed by the constructor, no need to recompute)
new_palette, new_counts = new.palette, new.counts
new_palette

In [None]:
# Dendrogram of Ward linkage over the unique colors (palette) of the blackened image
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title("Dendrograms")
ax.set_xlabel("Palette color index")
ax.set_ylabel("Ward linkage distance")
dend = shc.dendrogram(shc.linkage(new_palette, method='ward'), ax=ax)

In [None]:
# Halve the number of colors: Ward clustering merges the closest colors
# until len(new_palette) // 2 clusters remain.
# Ward linkage only supports the euclidean metric, which is the default, so no
# `affinity`/`metric` argument is passed (`affinity` was removed in scikit-learn 1.4).
n_colors = len(new_palette)
cluster = AgglomerativeClustering(n_clusters=n_colors // 2, linkage='ward')
clustered = cluster.fit_predict(new_palette)
print("Total number of colors before", len(clustered))
print("Total number of colors after ", len(np.unique(clustered)))

In [None]:
# Cluster label of each palette color (same order as new_palette)
clustered

In [None]:
# Cluster labels and how many palette colors fall into each
cluster_labels, cluster_sizes = np.unique(clustered, return_counts=True)
cluster_labels, cluster_sizes

In [None]:
# Recolor every pixel of `new` with the mean color (truncated to int) of its color's cluster.
# Matching uses an unmodified copy, so an already recolored pixel is never recolored
# again (a cluster mean can coincide with a palette color of another cluster).
pixels_before = new.get_pixels()
pixels = pixels_before.copy()

# Mean color of every cluster, expanded to one row per palette color
cluster_means = {
    label: np.mean(new_palette[clustered == label], axis=0, dtype=int)
    for label in np.unique(clustered)
}
replacement_per_color = np.array([cluster_means[label] for label in clustered])

# Pack each 8-bit RGB triple into one integer, so all pixels are matched to palette
# colors with a single searchsorted instead of one full-image scan per color.
rgb_weights = np.array([1 << 16, 1 << 8, 1], dtype=np.int64)
palette_codes = new_palette.astype(np.int64) @ rgb_weights
pixel_codes = pixels_before.astype(np.int64) @ rgb_weights

order = np.argsort(palette_codes)
slot = np.searchsorted(palette_codes, pixel_codes, sorter=order).clip(0, len(order) - 1)
color_index = order[slot]
matched = palette_codes[color_index] == pixel_codes
pixels[matched] = replacement_per_color[color_index[matched]]

In [None]:
# Rebuild an image from the cluster-merged pixels and display it.
# image_from_pixels reshapes to (w, h, 3), which PIL reads as (rows, columns),
# so `w` receives the height (size[1]) and `h` the width (size[0]).
new2 = Fundus(pixels, w=original.size[1], h=original.size[0])
new2.image

In [None]:
# Palette of the merged image (cached by the constructor); stored as new2_counts so new_counts keeps describing `new`
new2_palette, new2_counts = new2.palette, new2.counts

In [None]:
# Number of distinct colors left after merging the clusters
new2_palette.shape[0]