In [72]:
import numpy as np
from PIL import Image

In [73]:
class Pixel:
    def __init__(self, r, g, b):
        self.r = r
        self.g = g
        self.b = b

    def euclid_distance(self, pixel1):
        return np.sqrt((self.r - pixel1.r)**2 + (self.g - pixel1.g)**2 + (self.b - pixel1.b)**2)
    
    def average(self,pixel1):
        return Pixel((self.r+pixel1.r)/2,(self.g+pixel1.g)/2,(self.b+pixel1.b)/2)
    
    def print(self):
        print(f"R: {self.r} G: {self.g} B: {self.b}")
        
    def value(self):
        return [self.r,self.g,self.b]

    

In [74]:
class Picture:
    def __init__(self):
        pass

    def load_pixels(self, path):
        im = Image.open(path)
        self.pixels = list(im.getdata())
        self.width, self.height = im.size
        self.mode = im.mode

    def convert_pixels(self):
        self.pixels = np.array(self.pixels)
        if self.mode == "RGB":
            self.channels = 3
        elif self.mode == "RGBA":
            self.channels = 4
        else:
            self.channels = 1
        self.pixels = self.pixels.reshape(
            (self.height, self.width, self.channels))


In [100]:
class Palette:
    def __init__(self):
        self.palette = []
        self.count = []
        self.cutoff=75

    def generate_palette(self, picture):
        pixels = picture.pixels
        print(pixels.shape)
        for i in range(pixels.shape[0]):
            for j in range(pixels.shape[1]):
                pixel = Pixel(pixels[i][j][0],pixels[i][j][1],pixels[i][j][2])
                min = [-1,self.cutoff]
                for k in range(len(self.palette)):
                    dis = self.palette[k].euclid_distance(pixel)
                    if(dis<=min[1]):
                        min[0] = k
                        min[1] = dis
                if(min[0]!=-1):
                    self.palette[min[0]].average(pixel)
                    self.count[min[0]] = self.count[min[0]]+1
                else:
                    self.palette.append(pixel)
                    self.count.append(1)
    
    def sort_palette(self):
        self.count,self.palette = (list(t) for t in zip(*sorted(zip(self.count,self.palette))))
                    
    def print_palette(self):
        print(self.count)
        self.im = Image.new(mode="RGB",size=(100*len(self.palette),200))
        for i, color in enumerate(self.palette):
            self.im.paste((color.r,color.g,color.b),(i*100, 0, (i+1)*100, 200))
        self.im.show()
            

In [103]:
image = Picture()
image.load_pixels("pexels-photo-1632790.png")
image.convert_pixels()
palette = Palette()
palette.generate_palette(image)
palette.sort_palette()
palette.print_palette()

(277, 183, 4)
[102, 320, 698, 739, 2974, 3092, 3304, 3342, 5707, 14801, 15612]
