In [1]:
"""
    clustering.ipynb
    Functions for implementing k-means clustering for quantizing the colors in an image.
    Mustafa Siddiqui
"""
pass

In [1]:
from inc import IP
%matplotlib inline

In [20]:
def gray2rgb(img, nchannels=3):
    ''' Convert a grayscale image to a multi-channel (by default RGB) image.

        The grayscale array is replicated along a new last axis, so every
        output channel holds a copy of the input values.

        @param  img: grayscale image (any array shape)
                nchannels: number of output channels; 3 gives RGB

        @return array of shape img.shape + (nchannels,)
    '''
    # previously nchannels was ignored and 3 was hard-coded
    return IP.np.stack((img,) * nchannels, axis=-1)

In [4]:
class Color:
    ''' An RGB color: holds the red, green and blue channel values
        together with a couple of small helpers.
    '''

    def __init__(self, red, green, blue):
        ''' Store the three channel values exactly as given. '''
        self.red, self.green, self.blue = red, green, blue

    def getDistance(self, anotherColor):
        ''' Euclidean distance between this color and anotherColor on the
            RGB cube, truncated to an int.
        '''
        channelPairs = (
            (self.red, anotherColor.red),
            (self.blue, anotherColor.blue),
            (self.green, anotherColor.green),
        )
        squaredSum = sum((mine - theirs) ** 2 for mine, theirs in channelPairs)
        return int(IP.np.sqrt(squaredSum))

    def __str__(self):
        labelled = (("red", self.red), ("green", self.green), ("blue", self.blue))
        return ", ".join(label + ": " + str(value) for label, value in labelled)

In [5]:
def getColorFromPixel(img, coor):
    ''' Read one pixel of a color image and wrap its first three channel
        values in a Color object.

        @param  img: image indexed as img[coor[0], coor[1], channel]
                coor: sequence whose first two items index the first two
                      axes of img (row, then column for an H x W x C array)

        @return Color whose channels are converted to Python ints
    '''
    row, col = coor[0], coor[1]
    channels = [int(img[row, col, channel]) for channel in range(3)]
    return Color(*channels)

In [6]:
def colorDistinctOrNot(color1, color2, threshold=50):
    ''' Check whether two colors are distinct enough, i.e. whether their
        Euclidean distance on the RGB cube (as returned by getDistance)
        is strictly greater than threshold.

        @param  color1, color2: objects providing getDistance (Color)
                threshold: distance that must be exceeded for the colors
                           to count as distinct; defaults to 50

        @return True if distinct, False otherwise
    '''
    # bool() so callers always get a Python bool, even for numpy scalars
    return bool(color1.getDistance(color2) > threshold)

In [14]:
def chooseRandomColors(img, k, maxAttempts=10000):
    ''' Choose k random colors from the image that are pairwise distinct
        according to colorDistinctOrNot.

        Pixels are sampled uniformly at random. A sampled color is kept
        only if it is distinct from every color chosen so far.

        @param  img: color image indexed as img[row, col, channel]
                k: number of colors to choose
                maxAttempts: number of pixels sampled per color before
                             giving up

        @raise  ValueError if no sufficiently distinct color is found within
                maxAttempts samples (e.g. the image holds fewer than k
                mutually distinct colors)

        @return a set of k Color objects
    '''
    rows, cols, _ = img.shape

    colors = set()
    for _ in range(k):
        for _attempt in range(maxAttempts):
            # randint's upper bound is exclusive, so row 0 / column 0 are
            # included and every pixel can be chosen
            row = IP.np.random.randint(0, rows)
            col = IP.np.random.randint(0, cols)
            color = getColorFromPixel(img, (row, col))

            # all() of an empty set is True, so the first sample is always kept
            if all(colorDistinctOrNot(c, color) for c in colors):
                break
        else:
            # no break: every attempt collided with an already-chosen color
            raise ValueError(
                "could not find %d distinct colors in the image after %d "
                "attempts" % (k, maxAttempts))
        colors.add(color)

    return colors

In [8]:
def getMinKey(distances):
    ''' Takes in a hashmap with Color objects as keys and the (pre-computed)
        distance on the RGB cube to some pixel as values, and returns the
        key with the minimum distance.

        Ties are resolved in favour of the key encountered first.

        @return Color object (the key with the smallest value)
    '''
    if not distances:
        # NOTE: kept for backward compatibility; this fallback is not a key
        # of the (empty) mapping, so callers indexing with it will fail
        return Color(0, 0, 0)

    # min() keeps the first minimal key and, unlike a fixed upper-bound
    # sentinel, works for arbitrarily large distances
    return min(distances, key=distances.get)

In [15]:
def getIdealCentroids(I, centroidDict, tolerance=50, maxIterations=100):
    ''' Run the k-means algorithm on an image, starting from the given
        centroids (colors).

        Each iteration assigns every pixel to its nearest centroid (Euclidean
        distance on the RGB cube), then moves each centroid to the per-channel
        mean of the pixels assigned to it. Iteration stops once no centroid
        moves by more than tolerance, or after maxIterations iterations.

        @param  I: original colored image, indexed as I[row, col, channel]
                centroidDict: hashmap whose keys are the initial centroids
                              (Color objects); its values are not used and
                              it is not modified
                tolerance: a centroid moving farther than this triggers
                           another iteration
                maxIterations: upper bound on the number of iterations

        @return hashmap with the final centroids (Color objects) as keys,
                each mapped to an empty list
    '''
    l, w, _ = I.shape
    centroids = list(centroidDict)

    for _ in range(maxIterations):
        # per-centroid running totals: [redSum, greenSum, blueSum, pixelCount];
        # these are reset for every centroid on every iteration
        totals = {c: [0, 0, 0, 0] for c in centroids}

        # assignment step: attach every pixel to its nearest centroid
        for i in range(l):
            for j in range(w):
                pixelColor = getColorFromPixel(I, (i, j))
                distances = {c: pixelColor.getDistance(c) for c in centroids}
                acc = totals[getMinKey(distances)]
                acc[0] += pixelColor.red
                acc[1] += pixelColor.green
                acc[2] += pixelColor.blue
                acc[3] += 1

        # update step: move each centroid to the mean of its pixels
        newCentroids = []
        changeInCentroids = False
        for centroid in centroids:
            redSum, greenSum, blueSum, count = totals[centroid]
            if count == 0:
                # no pixel chose this centroid; keep it rather than divide by 0
                newCentroids.append(centroid)
                continue

            newCentroid = Color(int(redSum / count),
                                int(greenSum / count),
                                int(blueSum / count))
            newCentroids.append(newCentroid)

            if centroid.getDistance(newCentroid) > tolerance:
                changeInCentroids = True

        centroids = newCentroids
        if not changeInCentroids:
            break

    return {c: [] for c in centroids}

In [17]:
def quantizeImage(img, k):
    ''' Quantize a colored image by reducing the number of colors in it
        to k, using k-means clustering on the RGB cube.

        The input image is not modified; a quantized copy is returned.

        @param  img: colored image, indexed as img[row, col, channel]
                k: number of colors in the result (must be at least 1)

        @raise  ValueError if k < 1

        @return color-quantized copy of img
    '''
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))

    I = IP.np.copy(img)
    l, w, _ = I.shape

    # start k-means from k random, mutually distinct colors of the image;
    # the lists stored as values are part of the getIdealCentroids interface
    centroidDict = {c: [] for c in chooseRandomColors(I, k)}
    centroids = list(getIdealCentroids(I, centroidDict))

    # many pixels share a color: remember each color's nearest centroid
    # instead of recomputing k distances for every single pixel
    nearest = {}
    for i in range(l):
        for j in range(w):
            pixelColor = getColorFromPixel(I, (i, j))
            rgb = (pixelColor.red, pixelColor.green, pixelColor.blue)

            minKey = nearest.get(rgb)
            if minKey is None:
                distances = {c: pixelColor.getDistance(c) for c in centroids}
                minKey = getMinKey(distances)
                nearest[rgb] = minKey

            # replace the pixel with its centroid color
            I[i, j, 0] = minKey.red
            I[i, j, 1] = minKey.green
            I[i, j, 2] = minKey.blue

    return I