In [1]:
import numpy as np
import matplotlib.pyplot as plt

import cv2
import os
import time
import pyautogui
import math

from enum import Enum

# Custom Classes

In [10]:
class Position:
    """A (row, column) cell coordinate on the board.

    The constructor takes ``(y, x)`` (row first), while ``str()`` prints ``(x,y)``.
    Positions compare equal and hash by both coordinates, and order
    lexicographically by row, then column (i.e. row-major order).
    """

    def __init__(self, y, x):
        self.y = y  # row index
        self.x = x  # column index

    def __eq__(self, other):
        if not isinstance(other, Position):
            # Let Python fall back to identity comparison, so e.g. ``pos != None``
            # evaluates to True instead of raising AttributeError.
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __lt__(self, other):
        if not isinstance(other, Position):
            return NotImplemented
        # Lexicographic (row, column) comparison: a valid strict ordering that
        # agrees with row-major scan order and makes sorting well-defined.
        return (self.y, self.x) < (other.y, other.x)

    def __str__(self):
        return f"({self.x},{self.y})"

    def __hash__(self):
        # Must stay consistent with __eq__ (same coordinates -> same hash).
        return hash((self.x, self.y))

In [35]:
class Label(Enum):
    """Class assigned to each board tile.

    The integer value is also the gallery sub-folder name and the index used
    for per-label lists built from the gallery.
    """
    BG = 0      # background / empty tile
    APPL = 1    # apple
    BODY = 2    # snake body segment
    HEAD = 3    # snake head

# Label names in definition order, and how many labels there are.
label_str = [member.name for member in Label]
n_label = len(label_str)

# Board dimensions in tiles: rows x columns.
nrows, ncols = 15, 17

# Methods

## Image

In [2]:
def open_img(img_path, box=(607, 302, 1295, 910)):
    """Load an image from disk and crop it to the game-board region.

    Args:
        img_path: Path of the image file, read with ``cv2.imread`` (BGR order).
        box: ``(left, upper, right, lower)`` pixel bounds of the crop. The default
            is the board region hard-coded for this notebook's screenshots
            (presumably tied to one screen layout — confirm for other setups).

    Returns:
        The cropped image as a NumPy view into the loaded array.

    Raises:
        FileNotFoundError: if the file cannot be read. ``cv2.imread`` returns
            ``None`` instead of raising, which would otherwise surface later as
            an unhelpful ``TypeError`` when slicing.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {img_path}")

    left, upper, right, lower = box
    return img[upper:lower, left:right]

In [3]:
def crop_img(img):
    """Keep only the centre of a tile: rows and columns 9..30 inclusive.

    For a tile of at least 31x31 pixels this is a 22x22 window. Any trailing
    axes (e.g. colour channels) are kept, and the result is a view of ``img``.
    """
    top = left = 9
    bottom = right = 31
    return img[top:bottom, left:right]

In [4]:
def show_img(img):
    """Show ``img`` in an OpenCV window and block until any key is pressed."""
    window_name = 'windows'
    cv2.imshow(window_name, img)
    cv2.waitKey(0)  # 0 = wait indefinitely
    cv2.destroyAllWindows()

## Histogram

In [5]:
def calc_hist(img, bins=128):
    """Compute a min-max normalised intensity histogram for each colour channel.

    Args:
        img: Image array. Channels are split with ``cv2.split`` and at most the
            first three are used (B, G, R for images loaded with ``cv2.imread``).
        bins: Number of histogram bins spanning the intensity range [0, 256).

    Returns:
        A list with one ``float32`` histogram per channel (up to three), each
        rescaled so its smallest bin is 0 and its largest is 1.
    """
    hist_arr = []
    # Slice to three channels: the previous zip with ("b", "g", "r") had the
    # same cap, so e.g. an alpha channel is still ignored.
    for chan in cv2.split(img)[:3]:
        hist = cv2.calcHist([chan], [0], None, [bins], [0, 256])
        norm_hist = cv2.normalize(hist, None, alpha=0, beta=1,
                                  norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        hist_arr.append(norm_hist)

    return hist_arr

In [6]:
def calc_all_hist(labeled_folder_path="../img_gallery"):
    """Load every labelled reference tile and pre-compute its histograms.

    Reference images are read from ``<labeled_folder_path>/<label value>/``,
    one sub-folder per ``Label`` member (e.g. ``../img_gallery/0/`` for BG).

    Args:
        labeled_folder_path: Root folder of the labelled gallery.

    Returns:
        ``(hist_all_label, img_all_label)``. Both are lists in ``Label``
        definition order (so index == label value here), each holding one entry
        per reference image. ``hist_all_label[l][k]`` is the ``calc_hist`` result
        for the cropped tile ``img_all_label[l][k]``.
    """
    hist_all_label = []
    img_all_label = []

    for label in Label:
        labeled_path = os.path.join(labeled_folder_path, str(label.value))

        hist_label = []
        img_label = []
        # Sort so image numbers (reported by predict_img_label) are stable
        # between runs; os.listdir returns entries in arbitrary order.
        for filename in sorted(os.listdir(labeled_path)):
            img = cv2.imread(os.path.join(labeled_path, filename))
            if img is None:
                # Not a readable image (sub-directory, stray metadata file, ...).
                continue
            img_crp = crop_img(img)

            hist_label.append(calc_hist(img_crp))
            img_label.append(img_crp)
        hist_all_label.append(hist_label)
        img_all_label.append(img_label)

    return hist_all_label, img_all_label

## Image processing

In [37]:
def preprocessing(img, bg_path="../img_raw/bg.png", bg=None):
    """Black out the static background of a board image.

    Args:
        img: BGR board image, presumably cropped like ``open_img`` crops the
            background so both have the same size (required by ``cv2.subtract``).
        bg_path: Background image loaded with ``open_img`` when ``bg`` is None.
        bg: Optional pre-loaded, already-cropped background. Pass it when
            processing many frames to avoid re-reading the file on every call.

    Returns:
        ``img`` with every pixel outside the foreground mask set to 0.
    """
    if bg is None:
        bg = open_img(bg_path)

    # Saturating subtraction: negative differences clip to 0, so only channels
    # brighter than the background contribute to the mask.
    img_sub = cv2.subtract(img, bg)

    # Foreground mask: grey-level difference above 5 -> 255, otherwise 0.
    img_gray = cv2.cvtColor(img_sub, cv2.COLOR_BGR2GRAY)
    _, mask = cv2.threshold(img_gray, 5, 255, cv2.THRESH_BINARY)

    # Erode then dilate (a morphological opening) to remove small specks of
    # noise while restoring the extent of larger regions.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=2)
    mask = cv2.dilate(mask, kernel, iterations=2)

    return cv2.bitwise_and(img, img, mask=mask)

In [8]:
def normalize_img_size(img, nrows=15, ncols=17, size=40):
    """Cut the board image into a grid of ``size`` x ``size`` tiles and centre-crop each.

    Board cells are not a whole number of pixels wide, so tile origins advance
    by alternating steps. Rows step ``size`` after even rows and ``size + 1``
    after odd rows. Columns step ``size + 1`` after even columns and ``size``
    after odd columns. Each tile is then reduced with ``crop_img``.

    Args:
        img: Board image (e.g. the output of ``preprocessing``).
        nrows: Number of tile rows.
        ncols: Number of tile columns.
        size: Edge length in pixels of the window cut at each tile origin.

    Returns:
        ``nrows`` lists of ``ncols`` cropped tiles; ``tiles[i][j]`` is row ``i``,
        column ``j``.
    """
    img_tiles = []

    y_start = 0
    for i in range(nrows):
        img_row = []
        x_start = 0
        for j in range(ncols):
            img_i = img[y_start:y_start + size, x_start:x_start + size]

            # Crop for prediction: use only the pixels in the middle of the tile.
            img_row.append(crop_img(img_i))

            x_start += size + 1 - (j % 2)  # even column: size+1, odd column: size

        img_tiles.append(img_row)
        y_start += size + (i % 2)  # even row: size, odd row: size+1

    return img_tiles

In [9]:
def predict_img_label(img, hist_all_label, threshold, threshold_bg=0.9):
    """Classify a tile by histogram correlation against the labelled gallery.

    Each gallery histogram is compared per channel with ``cv2.HISTCMP_CORREL``.
    A gallery image qualifies only if the mean of its three channel correlations
    is strictly above the label's threshold and no single channel falls below
    it. The label's threshold is ``threshold_bg`` for BG and ``threshold`` for
    every other label. The qualifying image with the highest mean correlation
    wins.

    Args:
        img: Tile image, e.g. one entry of ``normalize_img_size``'s output.
        hist_all_label: Gallery histograms as returned by ``calc_all_hist``,
            indexed ``[label value][image number][channel]``.
        threshold: Minimum correlation for non-background labels.
        threshold_bg: Minimum correlation for the BG label.

    Returns:
        ``(label value, image number, mean correlation)`` of the best match, or
        ``(0, -1, 0)`` (BG, no image) when no gallery image qualifies.
    """
    hist_img = calc_hist(img)

    best = None  # (mean correlation, label value, image number)
    for label_no, label_hists in enumerate(hist_all_label):
        label_threshold = threshold_bg if label_no == Label.BG.value else threshold
        for hist_no, ref_hist in enumerate(label_hists):
            diffs = [cv2.compareHist(hist_img[chan], ref_hist[chan], cv2.HISTCMP_CORREL)
                     for chan in range(3)]
            avg_diff = sum(diffs) / 3

            passes = avg_diff > label_threshold and not any(d < label_threshold for d in diffs)
            if passes and (best is None or avg_diff > best[0]):
                best = (avg_diff, label_no, hist_no)

    if best is None:
        return 0, -1, 0

    max_coeff, matched_label_no, matched_hist_no = best
    return matched_label_no, matched_hist_no, max_coeff

## A-star helpers (board queries: food position, body centre of mass, body extent)

In [11]:
def food_position(cells, nrows, ncols):
    """Return the first APPL cell in row-major order as a Position, or None if absent."""
    apple_cells = (
        Position(row, col)
        for row in range(nrows)
        for col in range(ncols)
        if cells[row, col] == Label.APPL.value
    )
    return next(apple_cells, None)

In [12]:
def center_of_mass(cells, nrows, ncols):
    """Mean (row, column) of all BODY cells in the ``nrows`` x ``ncols`` board.

    Returns:
        A ``Position`` with float coordinates, or ``None`` when there are no
        BODY cells (instead of dividing by zero).
    """
    body = [(i, j)
            for i in range(nrows)
            for j in range(ncols)
            if cells[i, j] == Label.BODY.value]
    if not body:
        return None

    n = len(body)
    cm_y = sum(i for i, _ in body) / n
    cm_x = sum(j for _, j in body) / n

    return Position(cm_y, cm_x)

In [13]:
def center_of_mass_new(cells, nrows, ncols, sample_every=5):
    """Mean (row, column) of a subsample of the BODY cells.

    BODY cells are collected in row-major scan order, and every
    ``sample_every``-th one is kept, starting with the ``sample_every``-th
    (the 5th, 10th, 15th, ... with the default).

    Args:
        cells: Board label grid indexed ``cells[row, col]``.
        nrows: Number of rows to scan.
        ncols: Number of columns to scan.
        sample_every: Sampling stride; must be >= 1.

    Returns:
        A ``Position`` with float coordinates, or ``None`` when no cell is
        sampled (fewer than ``sample_every`` BODY cells).
    """
    body = [(i, j)
            for i in range(nrows)
            for j in range(ncols)
            if cells[i, j] == Label.BODY.value]
    samples = body[sample_every - 1::sample_every]

    n = len(samples)
    if n == 0:
        return None

    cm_y = sum(i for i, _ in samples) / n
    cm_x = sum(j for _, j in samples) / n

    return Position(cm_y, cm_x)

In [14]:
def find_min_max_body_pos(cells, nrows, ncols, get_min):
    """Return the smallest or largest BODY cell in (row, column) order.

    Args:
        cells: Board label grid indexed ``cells[row, col]``.
        nrows: Number of rows to scan.
        ncols: Number of columns to scan.
        get_min: True for the minimum (top-most, then left-most) BODY cell,
            False for the maximum (bottom-most, then right-most).

    Returns:
        A ``Position``. When there are no BODY cells, the sentinel
        ``Position(-1, -1)`` is returned for either request.
    """
    body = [(i, j)
            for i in range(nrows)
            for j in range(ncols)
            if cells[i, j] == Label.BODY.value]
    if not body:
        return Position(-1, -1)

    # Tuples compare lexicographically: by row first, then by column.
    y, x = min(body) if get_min else max(body)
    return Position(y, x)

## Predict labels

In [1]:
def get_cells(img, Label, threshold, hist_all_label=None):
    """Classify every tile of a board image and locate the head and the apple.

    Args:
        img: BGR board image (cropped like ``open_img`` output).
        Label: The ``Label`` enum (kept as a parameter for backward
            compatibility; it shadows the module-level name).
        threshold: Correlation threshold forwarded to ``predict_img_label``.
        hist_all_label: Optional pre-computed gallery histograms from
            ``calc_all_hist``. Pass them when processing many frames; if None,
            the gallery is re-read from disk on this call.

    Returns:
        ``(cells, head_pos, apple_pos)``. ``cells`` is an ``(nrows, ncols)``
        int8 grid of label values. ``head_pos`` is the HEAD tile with the highest
        correlation (None if no HEAD was found); all other HEAD tiles are
        relabelled BODY. ``apple_pos`` is the last APPL tile in row-major order,
        or None.
    """
    img_tiles = normalize_img_size(preprocessing(img))

    if hist_all_label is None:
        hist_all_label, _ = calc_all_hist()

    cells = np.zeros((nrows, ncols), dtype=np.int8)
    head_candidates = []  # (correlation, Position) for each tile classified as HEAD
    apple_pos = None

    for i in range(nrows):
        for j in range(ncols):
            label_no, _, cell_coeff = predict_img_label(img_tiles[i][j], hist_all_label, threshold)
            cells[i, j] = label_no
            if label_no == Label.HEAD.value:
                head_candidates.append((cell_coeff, Position(i, j)))
            elif label_no == Label.APPL.value:
                apple_pos = Position(i, j)

    head_pos = None
    if head_candidates:
        # max() returns the first of equal maxima, like a strict ">" scan.
        _, head_pos = max(head_candidates, key=lambda cand: cand[0])
        # There is only one head: demote the weaker HEAD matches to BODY.
        for _, pos in head_candidates:
            if pos is not head_pos:
                cells[pos.y, pos.x] = Label.BODY.value

    return cells, head_pos, apple_pos

## Pathfinding

In [1]:
def find_optimal_path(cells, head_pos, apple_pos, cm_pos_list, weights):
    """Plan a path from the head to the apple with ``astar`` and mark it on a board copy.

    Args:
        cells: Board label grid indexed ``cells[row, col]``; not modified.
        head_pos: ``Position`` of the snake's head, or None if not detected.
        apple_pos: ``Position`` of the apple, or None if not detected.
        cm_pos_list: Forwarded unchanged to ``astar``.
        weights: Forwarded unchanged to ``astar``.

    Returns:
        ``(new_cells, path)``. ``new_cells`` is an int8 copy of the
        ``(nrows, ncols)`` board where every BG cell on ``path`` is set to -1.
        ``path`` is whatever ``astar`` returned. Returns ``(None, None)`` when
        the head or the apple is unknown.
    """
    if head_pos is None or apple_pos is None:
        return None, None

    start_pos = Position(head_pos.y, head_pos.x)

    # Independent int8 copy so marking the path never touches the caller's board.
    new_cells = np.array(cells[:nrows, :ncols], dtype=np.int8)

    path = astar(new_cells, start_pos, apple_pos, cm_pos_list, weights)

    # NOTE(review): astar is defined outside this section; if it returns None
    # when no route exists, skip marking instead of raising — confirm its contract.
    if path is not None:
        for pos in path:
            if new_cells[pos.y, pos.x] == Label.BG.value:
                new_cells[pos.y, pos.x] = -1  # -1 = path marker (no Label uses it)

    return new_cells, path

# Executions

In [21]:
# Sanity check: label names in enum order (reproduces the recorded output below).
label_str

['BG', 'APPL', 'BODY', 'HEAD']
