In [9]:
import numpy as np
from PIL import Image
from math import floor
import matplotlib.pyplot as plt

In [11]:
class Pixel:
    """A single RGB colour with derived hue and luma.

    Parameters
    ----------
    r, g, b : int or float
        Channel values on a 0-255 scale. They are stored unchanged.

    Attributes
    ----------
    r, g, b
        The channel values as given.
    hue : int
        Hue in whole degrees, in the range ``[0, 360)``. Achromatic colours
        (``r == g == b``) have no defined hue and are given 0.
    lum : float
        Luma of the 0-1 normalised channels, using the 0.299/0.587/0.114
        channel weights.
    c : int
        Sample count used by :meth:`sum` / :meth:`div`; starts at 1.
    """

    def __init__(self, r, g, b):
        self.r = r
        self.g = g
        self.b = b
        # Start the accumulator count at one sample so sum()/div() work
        # (previously `c` was never initialised and sum() raised AttributeError).
        self.c = 1
        rn, gn, bn = r / 255, g / 255, b / 255
        minimum = min(rn, gn, bn)
        maximum = max(rn, gn, bn)
        chroma = maximum - minimum
        if chroma == 0:
            # Grey: hue is undefined. Avoid dividing by zero, which produced
            # NaN for numpy inputs and then made floor() raise.
            sextant = 0.0
        elif maximum == rn:
            sextant = (gn - bn) / chroma
        elif maximum == gn:
            sextant = 2 + (bn - rn) / chroma
        else:
            sextant = 4 + (rn - gn) / chroma
        # Convert to degrees *before* flooring (the old code scaled by 60 and
        # then overwrote the result with the floor of the unscaled value) and
        # wrap negative hues into [0, 360) so pure red is 0, not 360.
        self.hue = floor(sextant * 60) % 360
        self.lum = 0.299 * rn + 0.587 * gn + 0.114 * bn

    def hue_distance(self, pixel1):
        """Return the circular hue distance to ``pixel1`` in degrees (0-180)."""
        diff = abs(pixel1.hue - self.hue)
        return min(diff, 360 - diff)

    def euclid_distance(self, pixel1):
        """Return the Euclidean distance to ``pixel1`` in RGB space.

        Channels are promoted to float first so unsigned integer inputs
        (e.g. numpy ``uint8``) cannot wrap around on subtraction.
        """
        dr = float(self.r) - float(pixel1.r)
        dg = float(self.g) - float(pixel1.g)
        db = float(self.b) - float(pixel1.b)
        return np.sqrt(dr ** 2 + dg ** 2 + db ** 2)

    def average(self, pixel1):
        """Return a new Pixel whose channels are the mean of this one and ``pixel1``."""
        return Pixel((float(self.r) + float(pixel1.r)) / 2,
                     (float(self.g) + float(pixel1.g)) / 2,
                     (float(self.b) + float(pixel1.b)) / 2)

    def sum(self, pixel1):
        """Return a new Pixel holding the channel-wise sum with ``pixel1``.

        ``pixel1`` counts as one additional sample: this pixel's count ``c``
        is incremented and copied to the result. That makes
        ``acc = acc.sum(p)`` chains followed by ``acc.div()`` yield the mean.
        """
        self.c += 1
        total = Pixel(float(self.r) + float(pixel1.r),
                      float(self.g) + float(pixel1.g),
                      float(self.b) + float(pixel1.b))
        total.c = self.c
        return total

    def div(self):
        """Divide the channels in place by the sample count ``c``.

        This turns a channel total built up with :meth:`sum` into a mean
        colour. ``hue`` and ``lum`` are not recomputed.
        """
        self.r, self.g, self.b = self.r / self.c, self.g / self.c, self.b / self.c

    def print(self):
        """Print the channel values."""
        print(f"R: {self.r} G: {self.g} B: {self.b}")

    def value(self):
        """Return the channels as a ``[r, g, b]`` list."""
        return [self.r, self.g, self.b]

    def __lt__(self, pixel1):
        """Order pixels by *descending* red channel.

        NOTE(review): this is deliberately kept inverted (``>``) for
        compatibility, so ``sorted`` on Pixels yields highest red first.
        """
        return self.r > pixel1.r

    

In [12]:
class Picture:
    """Pixel data of an image file, loaded with PIL.

    Call :meth:`load_pixels` and then :meth:`convert_pixels` to obtain a
    ``(height, width, channels)`` numpy array in ``self.pixels``.
    """

    def __init__(self):
        # Populated by load_pixels() / convert_pixels().
        self.pixels = None
        self.width = None
        self.height = None
        self.mode = None
        self.channels = None

    def load_pixels(self, path):
        """Read the image at ``path`` into a flat list of pixel values.

        Sets ``pixels`` (one entry per pixel, in the order returned by
        ``Image.getdata``), ``width``, ``height`` and ``mode``.
        """
        # Context manager so the underlying file handle is closed once the
        # data has been read (it previously stayed open).
        with Image.open(path) as im:
            self.pixels = list(im.getdata())
            self.width, self.height = im.size
            self.mode = im.mode

    def convert_pixels(self):
        """Convert ``pixels`` into a ``(height, width, channels)`` numpy array.

        The channel count is taken from the pixel data itself instead of a
        fixed RGB/RGBA/other mode table, so multi-band modes such as "LA"
        reshape correctly. Calling this again on converted data keeps the
        same shape.
        """
        data = np.asarray(self.pixels)
        if data.ndim == 1:
            # Single-band data comes as plain numbers: one channel.
            data = data[:, np.newaxis]
        # The trailing axis holds the bands, both for the flat (n, bands)
        # list form and for an already-reshaped (h, w, bands) array.
        self.channels = data.shape[-1]
        self.pixels = data.reshape((self.height, self.width, self.channels))


In [13]:
class Palette:
    """Greedy hue-based colour palette built from an image.

    Each pixel joins the existing palette entry whose hue is closest to it,
    if that entry is within ``cutoff``. Otherwise the pixel starts a new
    entry. The distance is measured with ``Pixel.hue_distance``.

    Parameters
    ----------
    cutoff : int or float, optional
        Maximum hue distance for a pixel to join an existing entry.
        Defaults to 1.
    """

    def __init__(self, cutoff=1):
        self.palette = []   # representative Pixel of each entry
        self.count = []     # count[k]: number of pixels merged into palette[k]
        self.cutoff = cutoff

    def generate_palette(self, picture):
        """Cluster every pixel of ``picture`` into the palette by hue.

        ``picture.pixels`` must be a 3-D ``(height, width, channels)`` array,
        as produced by ``Picture.convert_pixels``. Images with fewer than
        three channels are treated as grey (channel 0 is used for r, g and b).
        Channels beyond the third (e.g. alpha) are ignored.
        """
        pixels = picture.pixels
        print(pixels.shape)
        grey = pixels.shape[2] < 3
        for i in range(pixels.shape[0]):
            for j in range(pixels.shape[1]):
                px = pixels[i, j]
                if grey:
                    pixel = Pixel(px[0], px[0], px[0])
                else:
                    pixel = Pixel(px[0], px[1], px[2])
                best_index, best_dist = -1, self.cutoff
                for k, entry in enumerate(self.palette):
                    dist = entry.hue_distance(pixel)
                    # `<=` keeps the original tie-breaking: the last entry at
                    # the minimum distance wins.
                    if dist <= best_dist:
                        best_index, best_dist = k, dist
                if best_index == -1:
                    self.palette.append(pixel)
                    self.count.append(1)
                else:
                    self.palette[best_index] = self._merge(
                        self.palette[best_index], self.count[best_index], pixel)
                    self.count[best_index] += 1

    @staticmethod
    def _merge(entry, n, pixel):
        """Return the mean colour of ``n`` pixels averaging ``entry`` plus ``pixel``."""
        # A true running mean. The previous pairwise average gave the newest
        # pixel half the weight no matter how many pixels the entry already
        # held. float() avoids wrap-around with unsigned integer inputs.
        return Pixel((float(entry.r) * n + float(pixel.r)) / (n + 1),
                     (float(entry.g) * n + float(pixel.g)) / (n + 1),
                     (float(entry.b) * n + float(pixel.b)) / (n + 1))

    def _reorder(self, key):
        """Sort ``palette`` and ``count`` together, ascending by ``key(index)`` (stable)."""
        order = sorted(range(len(self.palette)), key=key)
        self.palette = [self.palette[k] for k in order]
        self.count = [self.count[k] for k in order]

    def sort_count(self):
        """Sort entries by ascending pixel count."""
        self._reorder(lambda k: self.count[k])

    def sort_hue(self):
        """Sort entries by ascending hue of their representative colour."""
        self._reorder(lambda k: self.palette[k].hue)

    def sort_lum(self):
        """Sort entries by ascending luma of their representative colour."""
        self._reorder(lambda k: self.palette[k].lum)

    def print_palette(self, swatch_width=100, swatch_height=200):
        """Render the palette as side-by-side swatches and open it with ``Image.show``.

        The rendered image is also stored in ``self.im``. Nothing is drawn
        when the palette is empty.
        """
        if not self.palette:
            return
        self.im = Image.new(mode="RGB",
                            size=(swatch_width * len(self.palette), swatch_height))
        for i, color in enumerate(self.palette):
            rgb = (floor(color.r), floor(color.g), floor(color.b))
            self.im.paste(rgb, (i * swatch_width, 0, (i + 1) * swatch_width, swatch_height))
        self.im.show()
        
            

In [14]:
# Load the test image, cluster its pixels by hue, and render the palette.
IMAGE_PATH = "lenna.png"

image = Picture()
image.load_pixels(IMAGE_PATH)
image.convert_pixels()

palette = Palette()
palette.generate_palette(image)
palette.print_palette()

(512, 512, 3)


In [None]:
class Histogram:
    """Bar chart of the distinct colours of a PIL image, each bar drawn in its colour.

    Parameters
    ----------
    image : PIL.Image.Image
        Image whose colours are counted with ``Image.getcolors``.
    """

    def __init__(self, image):
        self.w, self.h = image.size
        # w*h bounds the number of distinct colours, so getcolors() should not
        # hit its maxcolors limit; `or []` still guards its None return.
        self.colors = image.getcolors(self.w * self.h) or []

    def hexencode(self, rgb):
        """Return ``rgb`` as a ``'#rrggbb'`` string.

        ``rgb`` is a sequence whose first three items are 0-255 integers; any
        further item (such as alpha) is ignored. A single integer is treated
        as a grey level (r = g = b).
        """
        if np.ndim(rgb) == 0:
            rgb = (rgb, rgb, rgb)
        r, g, b = rgb[0], rgb[1], rgb[2]
        return '#%02x%02x%02x' % (r, g, b)

    def generate(self):
        """Plot one bar per distinct colour (height = pixel count) and show it.

        Prints the number of distinct colours first. Returns the Axes.
        """
        print(len(self.colors))
        fig, ax = plt.subplots()
        if self.colors:
            counts = [count for count, _ in self.colors]
            hexes = [self.hexencode(color) for _, color in self.colors]
            # A single bar() call with a colour list, instead of one call per
            # colour (very slow for images with many colours) plus a print()
            # per index that flooded the output.
            ax.bar(range(len(counts)), counts, color=hexes)
        ax.set_xlabel("colour index")
        ax.set_ylabel("pixel count")
        ax.set_title("Colour histogram")
        plt.show()
        return ax