# Snake Game AI with Genetic Algorithm

In [None]:
import torch
import torch.nn as nn
import random
import numpy as np
import pickle

## Deep Neural Network

In [None]:
# Population size and network dimensions.
n_snake = 500        # number of networks (snakes) per generation
D_in = 27            # 8 ray directions x (food, wall, body) + 3 one-hot previous action
D_out = 3            # network outputs: [left, straight, right]
hid1 = 16            # width of the first hidden layer
hid2 = 16            # width of the second hidden layer
hid_layer_num = 2    # hidden layers; layers fc1 .. fc{hid_layer_num + 1} exist

class Net(nn.Module):
    """Fully connected policy network: D_in -> hid1 -> hid2 -> D_out.

    The hidden layers use ReLU and the output layer uses a sigmoid, so every output
    lies in (0, 1). All weights and biases are re-initialised uniformly in [-L, L].
    """

    def __init__(self, L=1):
        super().__init__()
        self.fc1 = nn.Linear(D_in, hid1)
        self.fc2 = nn.Linear(hid1, hid2)
        self.fc3 = nn.Linear(hid2, D_out)
        self.relu = nn.ReLU()
        self.sig = nn.Sigmoid()

        # Replace PyTorch's default init. The draw order (fc1 weight, fc1 bias, fc2 ...)
        # is fixed, so seeded runs are reproducible.
        for layer in (self.fc1, self.fc2, self.fc3):
            for tensor in (layer.weight, layer.bias):
                nn.init.uniform_(tensor, -L, L)

    def forward(self, x):
        hidden = self.relu(self.fc2(self.relu(self.fc1(x))))
        return self.sig(self.fc3(hidden))

# make first generation
# Each individual starts as a freshly (randomly) initialised Net with score 0.
now_snake = [Net() for _ in range(n_snake)]
now_score = [0] * n_snake
    
def get_weight(net, fcNum):
    """Return the ``fc<fcNum>.weight`` tensor from ``net.state_dict()``.

    state_dict() returns tensors that share storage with the module's parameters,
    so in-place edits of the result (or of its ``.numpy()`` view) change ``net``.
    """
    return net.state_dict()[f"fc{fcNum}.weight"]


def get_bias(net, fcNum):
    """Return the ``fc<fcNum>.bias`` tensor from ``net.state_dict()`` (shares storage)."""
    return net.state_dict()[f"fc{fcNum}.bias"]

## Snake Game

In [None]:
# Board geometry: the playable area is mapSize x mapSize. mapM adds a one-cell wall
# border, so valid indices run from 0 to mapSize + 1.
mapSize = 20
start_len = 4          # snake start length (20x20:4, 10x10:3)
startX, startY = 8, 10 # 20x20:8,10 / 10x10:4,5

# Mutable game state (reset by new_map, updated by move_snake).
sLen = start_len               # current snake length
headX, headY = startX, startY  # head coordinates
snake_dir = 1                  # heading: up=0, right=1, down=2, left=3

# Cell codes stored in mapM.
BLANK, SNAKE, FOOD, WALL = 0, 1, 2, 3

mapM = [[BLANK] * (mapSize + 2) for _ in range(mapSize + 2)]   # board, indexed mapM[x][y]
sInfo = []   # body coordinates, head first
score = 0    # food eaten in the current game
move = 0     # moves since the last food; reset to 0 when food is eaten

def new_food(rdm=True, ipt=None):   # set new food
    """Place a food item on the board and return its coordinates.

    Args:
        rdm: if True, pick a random BLANK interior cell; otherwise use ``ipt``.
        ipt: ``(x, y)`` position of the food when ``rdm`` is False (used to replay a
            recorded game).

    Returns:
        (foodX, foodY) of the placed food.

    Raises:
        ValueError: if ``rdm`` is False and no coordinates are given.

    Note: random placement retries until it hits a BLANK cell, so it does not return
    when no blank interior cell is left.
    """
    global mapSize, mapM
    if rdm:
        foodX, foodY = 0, 0   # (0, 0) is a wall cell, so the loop runs at least once
        while mapM[foodX][foodY] != BLANK:
            # Interior cells are 1..mapSize inclusive; randrange's stop is exclusive,
            # so the stop must be mapSize + 1 (mapSize alone skipped the last row/column).
            foodX = random.randrange(1, mapSize + 1)
            foodY = random.randrange(1, mapSize + 1)
        mapM[foodX][foodY] = FOOD
    else:
        if not ipt:
            raise ValueError("food coordinates are required when rdm is False")
        foodX, foodY = ipt[0], ipt[1]
        mapM[foodX][foodY] = FOOD
    return foodX, foodY

# start new game: make new map
def new_map(food_rdm=True, food_input=[]):
    """Reset the board and all game state to the starting position.

    Clears mapM in place, draws the wall border, lays the snake out with its head at
    (startX, startY) and its body extending towards smaller x (the snake starts
    heading right), then places the first food via new_food.

    Args:
        food_rdm, food_input: forwarded to new_food (random placement, or a recorded
            position when replaying a game).

    Returns:
        (fX, fY): position of the first food.
    """
    global mapM, mapSize, BLANK, WALL, SNAKE, sLen, headX, headY, sInfo
    global score, move, snake_dir, start_len, startX, startY

    headX, headY, sLen = startX, startY, start_len
    score, move, snake_dir = 0, 0, 1

    side = mapSize + 2
    border = side - 1

    # Wipe every cell in place.
    for row in mapM[:side]:
        row[:side] = [BLANK] * side

    # Wall border on all four edges.
    for k in range(side):
        mapM[0][k] = mapM[border][k] = WALL
        mapM[k][0] = mapM[k][border] = WALL

    # Head first, then the remaining body cells to its left.
    sInfo = [(headX - k, headY) for k in range(sLen)]
    for x, y in sInfo:
        mapM[x][y] = SNAKE

    return new_food(food_rdm, food_input)

def move_snake(dr, food_rdm=True, food_input=[]):
    """Advance the snake one cell in direction ``dr``, updating the board and game state.

    Args:
        dr: up=0, right=1, down=2, left=3.
        food_rdm, food_input: forwarded to new_food when food is eaten (random placement,
            or a recorded position for replays). food_input is only read, never mutated.

    Returns:
        (hit, [fX, fY]): ``hit`` is True when the game is over. That happens when the next
        cell is a wall or part of the body, or when the snake has made more than
        ``move_max`` moves without eating. ``[fX, fY]`` is the new food position if food
        was eaten this step, otherwise ``[-1, -1]``.
    """
    global mapSize, WALL, FOOD, SNAKE, BLANK, mapM, headX
    global headY, sLen, score, snake_dir, move, sInfo

    # Starvation limit: the maximum number of moves allowed between two foods.
    if mapSize == 10:
        move_max = 200
    else:
        # Tuned for 20x20 and used as the fallback for every other size. Previously any
        # size other than 10 or 20 left move_max undefined (UnboundLocalError).
        move_max = max([100, score*15-200])
        move_max = min([move_max, mapSize*mapSize])

    hit = False
    fX, fY = -1, -1

    dx = [0, 1, 0, -1]
    dy = [-1, 0, 1, 0]
    nextX = headX + dx[dr]
    nextY = headY + dy[dr]
    nextM = mapM[nextX][nextY]
    # The tail has not moved yet at this point, so stepping onto the current tail
    # cell also counts as a collision.
    if nextM == WALL or nextM == SNAKE:
        hit = True
    else:
        snake_dir = dr
        sInfo.insert(0, (nextX, nextY))
        mapM[nextX][nextY] = SNAKE
        if nextM == BLANK:
            # Plain move: drop the tail so the length stays the same.
            tail = sInfo.pop()
            mapM[tail[0]][tail[1]] = BLANK
            move = move + 1
            if move > move_max:
                hit = True
        else:  # FOOD: keep the tail (grow by one) and place the next food
            sLen = sLen + 1
            fX, fY = new_food(food_rdm, food_input)
            score = score + 1
            move = 0
        headX = nextX
        headY = nextY
    return hit, [fX, fY]

In [None]:
def detect(dr, obj, mode = 0):
    """Look along one of 8 rays from the snake's head for the first cell of type ``obj``.

    Args:
        dr: ray direction, 0 (North) to 7 (Northwest), clockwise.
        obj: cell code to look for (FOOD, WALL or SNAKE).
        mode: distance = 0, binary = 1.

    Returns:
        mode 0: 1/d, where d is the distance in cells to the first ``obj`` cell.
        mode 1: 1 if found.
        Either mode returns 0 when a wall is reached before ``obj``. When ``obj`` is
        WALL, the wall is always found.

    Raises:
        ValueError: for any other ``mode`` (previously returned None).
    """
    global headX, headY, mapM, WALL
    if mode not in (0, 1):
        raise ValueError("mode must be 0 (distance) or 1 (binary)")
    dx = [0, 1, 1, 1, 0, -1, -1, -1]
    dy = [-1, -1, 0, 1, 1, 1, 0, -1]
    cnt = 1
    val = mapM[headX+cnt*dx[dr]][headY+cnt*dy[dr]]
    while val != obj:
        # A wall ends the ray. Checking before stepping also covers a wall directly
        # next to the head, which used to fall through and report obj at distance 1.
        if val == WALL:
            return 0
        cnt = cnt + 1
        val = mapM[headX+cnt*dx[dr]][headY+cnt*dy[dr]]
    if mode == 0:
        return 1.0/cnt
    return 1

# input: Net -> output: score, dir_list, food_list
def snake_play_game(net, mode = 0):
    """Play one full game in which ``net`` chooses every move.

    Each step the network receives 8 rays x (food, wall, body) readings from detect(),
    starting at the current heading and going clockwise, plus a one-hot encoding of its
    previous relative action.

    Args:
        net: network mapping the 27 inputs to 3 outputs [left, straight, right].
        mode: passed to detect: distance = 0, binary = 1.

    Returns:
        (score, dir_list, food_list): food eaten, the absolute direction of every move,
        and every food position in placement order (enough to replay the game).
    """
    global score, headX, headY, SNAKE, FOOD, WALL, mapM, snake_dir

    dir_list = []
    food_list = []

    fX, fY = new_map()
    food_list.append([fX, fY])

    hit = False
    # Row h lists the detect() ray directions starting from heading h
    # (up=0 -> ray 0, right=1 -> ray 2, ...) and continuing clockwise.
    detect_order = [[0,1,2,3,4,5,6,7],
                    [2,3,4,5,6,7,0,1],
                    [4,5,6,7,0,1,2,3],
                    [6,7,0,1,2,3,4,5]]

    prev_out = 0
    # Pure inference: no_grad stops autograd from recording a graph on every step.
    with torch.no_grad():
        while not hit:
            prev_dir = [0]*3
            prev_dir[prev_out] = 1
            detect_info = []
            for i in detect_order[snake_dir]:
                detect_info.append(detect(i, FOOD, mode))
                detect_info.append(detect(i, WALL, mode))
                detect_info.append(detect(i, SNAKE, mode))
            detect_info = detect_info + prev_dir
            in_data = torch.tensor(detect_info, dtype=torch.float)
            out_data = net(in_data).tolist() # [left, straight, right]
            prev_out = out_data.index(max(out_data))
            # 0 -> heading - 1 (turn left), 1 -> straight, 2 -> heading + 1 (turn right).
            next_dir = (snake_dir + prev_out + 3) % 4
            dir_list.append(next_dir)
            hit, [fX, fY] = move_snake(next_dir)
            if fX > 0 and fY > 0:
                food_list.append([fX, fY])

    return score, dir_list, food_list

## Crossover and Mutation

In [None]:
# Uniform binary crossover (UBX)
def ubx(p1, p2, p=0.5):
    """Uniform crossover: each gene is swapped between the parents with probability 1 - p.

    Args:
        p1, p2: parent arrays of the same shape (not modified).
        p: probability of keeping a gene in place.

    Returns:
        (c1, c2): new arrays. c1 is p1 with the swapped genes taken from p2; c2 is the
        complement.
    """
    c1 = p1.copy()
    c2 = p2.copy()

    swap = np.random.uniform(0, 1, size=c1.shape) > p
    # Read from the untouched parents. Reading c1 after overwriting it made c2 an
    # exact copy of p2.
    c1[swap] = p2[swap]
    c2[swap] = p1[swap]

    return c1, c2

# Single point binary crossover (SPBX)
def spbx(p1, p2):
    """Single-point crossover of two same-shaped 1-D or 2-D arrays.

    A random cut point is chosen, and the genes before it in row-major order are
    exchanged between the parents.
    - 1-D: cut ``row`` in [0, len); the children swap ``[:row]`` (possibly nothing).
    - 2-D: cut at element (row, col); the children swap every element up to and
      including it.

    Returns:
        (c1, c2): new arrays; p1 and p2 are not modified.
    """
    c1 = p1.copy()
    c2 = p2.copy()

    if p2.ndim == 1:
        # Scalar bound. The old randint(0, shape_tuple)[0] drew from an array-valued
        # high and then indexed the resulting 1-element array.
        row = np.random.randint(0, p2.shape[0])
        c1[:row] = p2[:row]
        c2[:row] = p1[:row]
    else:
        rows, cols = p2.shape
        row = np.random.randint(0, rows)
        col = np.random.randint(0, cols)

        c1[:row, :] = p2[:row, :]
        c2[:row, :] = p1[:row, :]
        c1[row, :col+1] = p2[row, :col+1]
        c2[row, :col+1] = p1[row, :col+1]

    return c1, c2

# Simulated binary crossover (SBX)
def sbx(p1, p2, eta=100):
    """Simulated binary crossover with distribution index ``eta``.

    The children are symmetric around the parents' mean (c1 + c2 == p1 + p2). A larger
    ``eta`` keeps them closer to their parents.

    Returns:
        (c1, c2): new float arrays of the parents' shape.
    """
    u = np.random.random(p1.shape)
    power = 1.0 / (eta + 1)
    # Spread factor: (2u)^power for u <= 0.5, (1 / (2(1 - u)))^power otherwise.
    beta = np.where(u <= 0.5, 2 * u, 1.0 / (2.0 * (1.0 - u))) ** power

    plus, minus = 1 + beta, 1 - beta
    c1 = 0.5 * (plus*p1 + minus*p2)
    c2 = 0.5 * (minus*p1 + plus*p2)
    return c1, c2

# Crossover
def crossover(netA, netB, prob, p=0.5):
    """Create two child networks by crossing over every layer of netA and netB.

    For each layer fc1 .. fc{hid_layer_num + 1}, the weight matrix and the bias vector
    each pick an operator independently: ubx if a uniform draw is < prob[0], spbx if it
    is < prob[0] + prob[1], otherwise sbx.

    Args:
        netA, netB: parent networks.
        prob: [p_ubx, p_spbx, p_sbx]; only the first two entries are read.
        p: keep-probability forwarded to ubx.

    Returns:
        (netC, netD): two new Net instances.
    """
    global hid_layer_num

    def pick(draw, x, y):
        # Dispatch on the uniform draw against the cumulative operator probabilities.
        if draw < prob[0]:
            return ubx(x, y, p)
        if draw < prob[1]+prob[0]:
            return spbx(x, y)
        return sbx(x, y)

    netC = Net()
    netD = Net()
    C_dict = netC.state_dict()
    D_dict = netD.state_dict()

    for layer in range(1, hid_layer_num + 2):
        # Both draws happen before either crossover runs.
        tw = np.random.uniform(0, 1)
        tb = np.random.uniform(0, 1)
        w3, w4 = pick(tw, get_weight(netA, layer).numpy(), get_weight(netB, layer).numpy())
        b3, b4 = pick(tb, get_bias(netA, layer).numpy(), get_bias(netB, layer).numpy())

        for kind, left, right in (("weight", w3, w4), ("bias", b3, b4)):
            key = f"fc{layer}.{kind}"
            C_dict[key] = torch.tensor(left)
            D_dict[key] = torch.tensor(right)

    netC.load_state_dict(C_dict)
    netD.load_state_dict(D_dict)
    return netC, netD

# Gaussian mutation
def gaussian_mut(parent, prob, scale=0.2):
    """Return a copy of ``parent`` in which each element, with probability ``prob``,
    gets N(0, 1) * ``scale`` noise added.

    ``parent`` is left unchanged. Arrays obtained via ``tensor.numpy()`` share memory
    with the tensor, so the old in-place update silently altered the source network.
    """
    child = parent.copy()
    mut_idx = np.random.random(parent.shape) < prob
    gauss = np.random.normal(size=parent.shape)
    child[mut_idx] = child[mut_idx] + gauss[mut_idx] * scale
    return child

# Mutation
def mutation(netInput, mut_prob=0.05, mut_type=0, score=0, decay=False):
    """Return a mutated copy of ``netInput``; the input network is not modified.

    Args:
        netInput: network to mutate.
        mut_prob: per-parameter mutation probability.
        mut_type: uniform = 0, gaussian = 1, mixed = 2 (a fair coin picks 0 or 1).
            Uniform replaces each parameter, with probability ``mut_prob``, by the
            matching parameter of a freshly initialised Net (via crossover with ubx).
            Gaussian adds noise through gaussian_mut with probability ``mut_prob``.
        score: only used when ``decay`` is True.
        decay: if True, ``mut_prob`` is halved for every 60 points of ``score``.

    Returns:
        A new Net.

    Raises:
        ValueError: for an unknown ``mut_type``.
    """
    global hid_layer_num
    import copy
    import math

    if decay:
        # mut_prob * 2 ** (-score / 60). ``exp`` was never imported, so decay=True
        # used to raise NameError.
        mut_prob = mut_prob * math.exp(-float(score) * math.log(2) / 60.0)
    if mut_type == 2:
        tmp = np.random.uniform(0, 1)
        mut_type = 1 if tmp > 0.5 else 0
    if mut_type == 0:
        netRandom = Net()
        mutated, _ = crossover(netInput, netRandom, prob = [1,0,0], p = 1-mut_prob)
        return mutated
    if mut_type == 1: # gaussian
        # Work on a deep copy: state_dict()/.numpy() share memory with the parameters,
        # so editing them directly changed netInput and returned the same object.
        mutated = copy.deepcopy(netInput)
        mutated_dict = mutated.state_dict()
        for i in range(1, (hid_layer_num+2)):
            w2 = gaussian_mut(get_weight(mutated, i).numpy(), mut_prob)
            b2 = gaussian_mut(get_bias(mutated, i).numpy(), mut_prob)
            mutated_dict['fc'+str(i)+'.weight'] = torch.tensor(w2)
            mutated_dict['fc'+str(i)+'.bias'] = torch.tensor(b2)
        mutated.load_state_dict(mutated_dict)
        return mutated
    raise ValueError("mut_type must be 0 (uniform), 1 (gaussian) or 2 (mixed)")

# Roulette Wheel
def roulette(fit_list, choose=2):
    """Pick ``choose`` distinct indices with probability proportional to fitness.

    If any fitness equals 0, every value f is first mapped to f * 100 + 1, so that
    zero-fitness individuals keep a small chance. ``fit_list`` itself is not modified;
    it used to be rewritten in place, which leaked into the caller's score list.

    Args:
        fit_list: non-negative fitness values.
        choose: number of distinct indices to return.

    Returns:
        List of ``choose`` distinct indices, in the order they were drawn.

    Raises:
        ValueError: if a fitness is negative, or if fewer than ``choose`` entries have a
            positive weight (sampling would otherwise never finish).
    """
    import bisect
    import itertools

    if choose <= 0:
        return []
    weights = list(fit_list)
    if 0 in weights:
        weights = [w * 100 + 1 for w in weights]
    if any(w < 0 for w in weights):
        raise ValueError("roulette() needs non-negative fitness values")
    if sum(1 for w in weights if w > 0) < choose:
        raise ValueError("not enough individuals with positive fitness to choose from")

    # cum[i] = weights[0] + ... + weights[i], so index i owns [cum[i-1], cum[i]).
    # Built once in O(n) instead of re-summing every prefix (O(n^2)).
    cum = list(itertools.accumulate(weights))
    tot_fit = cum[-1]
    last = len(cum) - 1

    select_idx = []
    while len(select_idx) < choose:
        x = random.uniform(0, tot_fit)
        # Clamp: uniform() may return tot_fit itself, which lies past the last interval.
        idx = min(bisect.bisect_right(cum, x), last)
        if idx not in select_idx:
            select_idx.append(idx)
    return select_idx

## Prepare for training

In [None]:
# Refresh lists
# Reset every per-generation history list and the previous-generation buffers before
# a new training run. Each name gets its own, independent empty list.
fitness_max, best_dir_list, best_food_list = [], [], []
best_net_list, best_score_list, best_score_acc_list = [], [], []
prev_snake, prev_score = [], []
avg_score_list = []

In [None]:
# Save best Net
# Writes one network from best_net_list to ./snake_v<ver>_<best score>_<gen+1>.pth.
# NOTE(review): gen_end is only defined by the training cell further down, so this cell
# only works after training has run. The index below equals gen when
# len(best_net_list) == gen_end; confirm the intended index when it does not.
ver = 8.5
gen = 1787   # 0-based generation index; the file name uses gen + 1
PATH = './snake_v'+str(ver)+'_'+str(best_score_list[gen])+'_'+str(gen+1)+'.pth'
model = best_net_list[gen - gen_end + len(best_net_list)]
torch.save(model.state_dict(), PATH)

In [None]:
# Load Net
def load_net(PATH):
    """Build a Net and load the weights saved by ``torch.save(net.state_dict(), PATH)``.

    ``map_location="cpu"`` lets checkpoints saved from GPU tensors load on CPU-only
    machines; the Net built here lives on the CPU anyway.
    """
    load_model = Net()
    load_model.load_state_dict(torch.load(PATH, map_location="cpu"))
    return load_model

In [None]:
# Started with loaded Net
# Seeds a new population from load_model: n_snake gaussian-mutated variants (per-parameter
# mutation probability 0.05), with the unmutated load_model as the only member of the
# previous generation.
# NOTE(review): no visible cell assigns load_model at module level (load_net only
# returns it), so run e.g. `load_model = load_net(PATH)` first.
# NOTE(review): this relies on mutation() returning a new network. If it edits
# load_model in place and returns it, every entry of now_snake is the same object.
# NOTE(review): prev_score is not reset here, but the training loop merges
# prev_snake + now_snake with prev_score + now_score. Confirm prev_score has exactly
# one matching entry.
now_snake = []
for i in range(n_snake):
    now_snake.append(mutation(load_model, 0.05, mut_type=1))
prev_snake = []
prev_snake.append(load_model)

In [None]:
# Put loaded Net in current generation
# Replaces the weakest member of the previous generation with load_model. The index
# must come from prev_score, the list being edited; it used to come from now_score.
# NOTE(review): prev_score holds training-loop fitness values, not game scores, and 58
# looks like a game score. Confirm the value to assign.
idx = min(range(len(prev_score)), key=lambda i: prev_score[i])
prev_snake.pop(idx)
prev_snake.append(load_model)
prev_score.pop(idx)
prev_score.append(58)

## Training

In [None]:
#===== TRAINING =====
import time

gen_start = 0
gen_end = 500
selection_type = 1               # elitism=0, roulette=1
crossover_prob = [0.0, 0.7, 0.3] # prob. of ubx, spbx, sbx
mutation_type = 1                # uniform=0, gaussian=1, mixed=2
mutation_decay = False
input_mode = 0                   # distance=0, binary=1

def get_top_idx(a, n=1):
    """Return the indices of the ``n`` largest values of ``a``, in ascending order of value.

    The index of the largest value comes last, and ties keep their original index order.
    Returns [] for n <= 0; the old slice a[-0:] returned every index.
    """
    if n <= 0:
        return []
    return sorted(range(len(a)), key=lambda i: a[i])[-n:]

save_permission = False
t_init = time.time()
for gen in range(gen_start, gen_end):
    # all snakes play game
    t_start = time.time()
    print("Generation %4d" % (gen+1), end=": ")
    best_dir, best_food, best_score = [], [], 0     # initialize
    best_net = now_snake[0]
    # This generation's best fitness lives in the last slot. fitness_max[gen] assumed
    # len(fitness_max) == gen, which is false after reloading saved lists.
    fitness_max.append(1)
    score_sum = 0
    for i in range(n_snake):
        score, dir_list, food_list = snake_play_game(now_snake[i], mode=input_mode)
        score_sum = score_sum + score
        step = len(dir_list)
        # Rewards food eaten (2**score and 500 * score**2.1) plus one point per step,
        # minus a penalty that grows with both step count and score.
        fitness = step + ((2**score) + (score**2.1)*500) - (((.25 * step)**1.3) * (score**1.2))
        now_score[i] = fitness
        if fitness_max[-1] < fitness:
            fitness_max[-1] = fitness
            best_dir = dir_list
            best_food = food_list
            best_net = now_snake[i]
            best_score = score
    best_dir_list.append(best_dir)
    best_food_list.append(best_food)
    best_net_list.append(best_net)
    best_score_list.append(best_score)
    best_score_acc = max(best_score_list)
    best_score_acc_list.append(best_score_acc)
    avg_score = float(score_sum)/float(n_snake)
    avg_score_list.append(avg_score)

    print("%6.3f" % avg_score, end = " / ")
    print("%3d" % best_score, end=" / ")
    print("%3d" % best_score_acc, end=" / ")

    t_game_end = time.time()
    print("%6.3fs / " % (t_game_end - t_start), end="")

    # The previous generation's survivors compete with the new children.
    if prev_snake:  # prev_snake is not empty
        now_snake = prev_snake + now_snake
        now_score = prev_score + now_score

    # selection
    if selection_type == 0: # top 4
        surv_idx = get_top_idx(now_score, 4)
        surv_snake = [now_snake[i] for i in surv_idx]
        surv_score = [now_score[i] for i in surv_idx]

    # Keep the n_snake fittest, sorted ascending (best last).
    top_idx = get_top_idx(now_score, n_snake)
    now_snake = [now_snake[i] for i in top_idx]
    now_score = [now_score[i] for i in top_idx]

    if selection_type == 1: # top (n_snake)
        surv_score = now_score
        surv_snake = now_snake

    for i in range(1,5):
        print("%.3e" % now_score[-i], end=" ")

    # selection, crossover, and mutation
    child_snake = []
    if selection_type == 0:  # top 4 combination (elitism)
        parentA_id = [0, 0, 0, 1, 1, 2]
        parentB_id = [1, 2, 3, 2, 3, 3]
        for j in range(6):
            parentA, parentB = surv_snake[parentA_id[j]], surv_snake[parentB_id[j]]
            for i in range(n_snake // 12):
                childA, childB = crossover(parentA, parentB, prob = crossover_prob)
                childA = mutation(childA, mut_type=mutation_type, decay=mutation_decay)
                childB = mutation(childB, mut_type=mutation_type, decay=mutation_decay)
                child_snake.append(childA)
                child_snake.append(childB)
    elif selection_type == 1:  # roulette wheel
        # Sample from a private copy so roulette() can never rewrite now_score, which
        # becomes prev_score below.
        wheel = list(surv_score)
        sum_time = 0
        for i in range(n_snake // 2):
            roulette_start_time = time.time()
            ids = roulette(wheel)  # get 2 indices by using roulette
            sum_time = sum_time + (time.time() - roulette_start_time)
            parentA, parentB = surv_snake[ids[0]], surv_snake[ids[1]]
            childA, childB = crossover(parentA, parentB, prob = crossover_prob)
            childA = mutation(childA, mut_type=mutation_type, decay=mutation_decay)
            childB = mutation(childB, mut_type=mutation_type, decay=mutation_decay)
            child_snake.append(childA)
            child_snake.append(childB)
        print("/ %.3fs" % sum_time, end=" ")

    # Top up to exactly n_snake children. 12 * (n_snake // 12) (elitism) or
    # 2 * (n_snake // 2) (roulette) can fall short, and a short list made
    # now_snake[i] raise IndexError in the next generation.
    while len(child_snake) < n_snake:
        extra, _ = crossover(surv_snake[-1], surv_snake[-2], prob = crossover_prob)
        child_snake.append(mutation(extra, mut_type=mutation_type, decay=mutation_decay))

    prev_snake = now_snake
    prev_score = now_score
    now_snake = child_snake
    now_score = [0]*n_snake
    print("/ time: %6.3fs" % (time.time() - t_start))

print("Total "+str(gen_end - gen_start)+" Generation ("+str(gen_start+1)+"~"+str(gen_end), end=")")
time_total = time.time() - t_init
# Split into hours, minutes and seconds. Seconds used to subtract only the leftover
# minutes, so runs longer than an hour counted the hours again as seconds.
min_total, sec_total = divmod(time_total, 60.0)
hour_total, min_total = divmod(int(min_total), 60)
print(" / total time: %dh %dm %.3fs" % (hour_total, min_total, sec_total))

save_permission = True

### Save all models and lists (checkpoint for resuming training)

In [None]:
import os

# Checkpoint the whole run (both populations, per-generation best networks and history
# lists) into ./snake_save so the load cell can resume training.
endstr = ""
if save_permission:
    # torch.save/open fail if the directory is missing.
    os.makedirs("./snake_save", exist_ok=True)
    for prefix, nets in (("prev", prev_snake), ("now", now_snake), ("best", best_net_list)):
        for i, model in enumerate(nets):
            PATH = './snake_save/'+prefix+str(i)+endstr+'.pth'
            torch.save(model.state_dict(), PATH)

    # Pickled (binary) data despite the .txt extension; file names match the load cell.
    for name, obj in (("now_score", now_score),
                      ("prev_score", prev_score),
                      ("fitness_max", fitness_max),
                      ("best_dir_list", best_dir_list),
                      ("best_food_list", best_food_list),
                      ("best_score_list", best_score_list),
                      ("best_score_acc_list", best_score_acc_list)):
        with open("./snake_save/"+name+endstr+".txt", "wb") as fp:
            pickle.dump(obj, fp)
save_permission = False

### Save best models

In [None]:
import os

ver = 9
# Save the networks of the (up to) 10 generations with the highest best score as
# ./best_snake_save/snake_v<ver>_<score>_<generation>.pth, best first.
idx = sorted(range(len(best_score_list)),key= lambda i: best_score_list[i])[-10:]
idx.reverse()
os.makedirs('./best_snake_save', exist_ok=True)
for g in idx:
    PATH = './best_snake_save/snake_v'+str(ver)+'_'+str(best_score_list[g])+'_'+str(g+1)+'.pth'
    # best_net_list[g] is the network that scored best_score_list[g]; the training loop
    # appends one entry to each list per generation. prev_snake[-i-1] saved unrelated
    # members of the final population under these names.
    model = best_net_list[g]
    torch.save(model.state_dict(), PATH)

### Load all models

In [None]:
# Restore the checkpoint written by the save cell: both populations, the
# per-generation best networks and all history lists.
endstr = ""
prev_snake = []
now_snake = []
best_net_list = []
for i in range(n_snake):
    PATH = './snake_save/prev'+str(i)+endstr+'.pth'
    model = Net()
    model.load_state_dict(torch.load(PATH))
    prev_snake.append(model)
for i in range(n_snake):
    PATH = './snake_save/now'+str(i)+endstr+'.pth'
    model = Net()
    model.load_state_dict(torch.load(PATH))
    # Build a fresh list; assigning now_snake[i] required an existing list of >= n_snake.
    now_snake.append(model)

# Load lists (pickled by the save cell despite the .txt extension).
# NOTE: pickle.load can execute arbitrary code, so only load files this notebook wrote.
with open("./snake_save/now_score"+endstr+".txt", "rb") as fp:
    now_score = pickle.load(fp)
with open("./snake_save/prev_score"+endstr+".txt", "rb") as fp:
    prev_score = pickle.load(fp)
with open("./snake_save/fitness_max"+endstr+".txt", "rb") as fp:
    fitness_max = pickle.load(fp)
with open("./snake_save/best_dir_list"+endstr+".txt", "rb") as fp:
    best_dir_list = pickle.load(fp)
with open("./snake_save/best_food_list"+endstr+".txt", "rb") as fp:
    best_food_list = pickle.load(fp)
with open("./snake_save/best_score_list"+endstr+".txt", "rb") as fp:
    best_score_list = pickle.load(fp)
with open("./snake_save/best_score_acc_list"+endstr+".txt", "rb") as fp:
    best_score_acc_list = pickle.load(fp)

# One saved best network per generation, aligned with best_score_list.
for i in range(len(best_score_list)):
    PATH = './snake_save/best'+str(i)+endstr+'.pth'
    model = Net()
    model.load_state_dict(torch.load(PATH))
    best_net_list.append(model)

## Plot

In [None]:
import matplotlib.pyplot as plt
from csaps import csaps

# Plot Best Score vs. Generation
xlen = len(best_score_list)
# Fit and evaluate the smoothing spline at the same x positions the raw series is
# plotted at (0 .. xlen-1). linspace(0, xlen, xlen) spaced samples xlen/(xlen-1) apart,
# which stretched the smoothed curve to the right of the data.
x = np.arange(xlen, dtype=float)
xs = np.linspace(0., float(xlen - 1), xlen*5)
y = np.array(best_score_list)
ys = csaps(x, y, xs, smooth=0.02)

plt.rcParams["figure.figsize"] = (12, 6)
plt.plot(range(xlen), best_score_list, range(xlen), best_score_acc_list, xs, ys, 'k')
plt.xlabel('Generation')
plt.ylabel('Best Score')
# One label per plotted line; the smoothed curve previously had none.
plt.legend(["Best score of current generation", "Best score up to current generation",
            "Smoothed best score"], loc='lower right')
plt.title("population = 500 / roulette wheel / spbx 70% + sbx 30% / gaussian mutation 5% / distance input")

# Print the (up to) 20 generations with the highest best score, best first.
ntop = 20
idx = get_top_idx(best_score_list, ntop)
idx.reverse()
for i in idx:
    print(i, end=" ")
print()
for i in idx:
    print(best_score_list[i], end=" ")

## Animation

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML

# Print iterations progress
def printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'):
    """Draw a one-line text progress bar, returning the cursor to the line start (carriage return).

    Args:
        iteration: current step (0 .. total).
        total: number of steps. 0 is treated as already complete instead of
            raising ZeroDivisionError.
        prefix, suffix: text printed before / after the bar.
        decimals: digits after the decimal point of the percentage.
        length: bar width in characters.
        fill: character used for the completed part.

    A newline is printed once iteration == total.
    """
    if total:
        fraction = iteration / float(total)
        filledLength = int(length * iteration // total)
    else:
        fraction = 1.0
        filledLength = length
    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    bar = fill * filledLength + '-' * (length - filledLength)
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r')
    if iteration == total:
        print()

# make animation with dir_list and food_list
# food_cnt: index into food_list of the next food to place during a replay. It lives at
# module level so the nested update() callback can advance it via `global`.
food_cnt = 1
def make_ani(dir_list, food_list, itv = 20):
    """Replay a recorded game and return it as an HTML/JavaScript animation.

    Args:
        dir_list: absolute direction of every move (up=0, right=1, down=2, left=3),
            as returned by snake_play_game.
        food_list: every food position in order; food_list[0] is the initial food.
        itv: delay between frames in milliseconds (FuncAnimation's ``interval``).

    Returns:
        The string produced by ``FuncAnimation.to_jshtml()``.

    Side effects: resets (new_map) and then advances (move_snake) the global game
    state, and prints a progress bar while frames are rendered.
    """
    from matplotlib.animation import FuncAnimation
    import matplotlib
    global food_cnt
    # Effectively removes the size cap on the embedded animation.
    matplotlib.rcParams['animation.embed_limit'] = 2**128
    # dir_list: up=0, right=1, down=2, left=3

    # Start from the recorded initial food instead of a random one.
    new_map(food_rdm=False, food_input=food_list[0])
    food_cnt = 1

    f, ax = plt.subplots(figsize=(6, 6))
    ax.set_xlim(-0.5, mapSize+1.5), ax.set_ylim(-0.5, mapSize+1.5)
    plt.gca().invert_yaxis()
    # mapM is indexed [x][y]; the transpose puts x on the horizontal axis.
    im = plt.imshow(np.array(mapM).T)

    def init():
        im.set_data(np.array(mapM).T)
        printProgressBar(0, len(dir_list), prefix = 'Animation:', suffix = 'Complete', length = 50)
        return [im]

    def update(frame):
        # Frame 0 shows the initial board; frame k applies move k-1.
        global mapM, food_cnt
        if frame > 0:
            # food_list[food_cnt] is evaluated on every call, even when no food is
            # eaten, so clamp the index once all recorded foods have been placed.
            if food_cnt >= len(food_list):
                food_cnt = food_cnt - 1
            hit, fXY = move_snake(dir_list[frame-1], food_rdm=False, food_input=food_list[food_cnt])
            if fXY[0] > 0:   # food eaten: the next recorded food is now on the board
                food_cnt = food_cnt + 1
        im.set_data(np.array(mapM).T)
        printProgressBar(frame, len(dir_list), prefix = 'Animation:', suffix = 'Complete', length = 50)
        return [im]

    ani = FuncAnimation(fig=f, func=update, frames=range(len(dir_list)+1), init_func=init, blit=True, interval=itv)
    # Close the figure so the notebook does not also show a static copy.
    plt.close(f)
    
    return ani.to_jshtml()
    
from IPython.display import display

# Replay the best game of the generation with the highest best score (idx comes from
# the plot cell, best first), show it inline and save it as HTML.
gen = idx[0]
print("Generation "+str(gen+1)+": "+str(len(best_dir_list[gen])+1)+" frames")
# ani_html was never assigned before; keep the rendered animation for both uses.
ani_html = make_ani(best_dir_list[gen], best_food_list[gen], 5)
# A bare HTML(...) only renders when it is the last expression of a cell.
display(HTML(ani_html))
# test_score_list does not exist here; name the file after this generation's best score.
with open("animation_gen"+str(gen+1)+"_"+str(best_score_list[gen])+".html", "w", encoding="utf-8") as Html_file:
    Html_file.write(ani_html)

## Test the model

In [None]:
# Test 1 model for several times
# Play test_num games with the network of the best generation (idx[0] from the plot
# cell), report the top scores and save the highest-scoring game as an HTML animation.
test_num = 500
show_num = 10
test_model = best_net_list[idx[0]]
input_mode = 0

test_score = []
test_dir_list = []
test_food_list = []
for i in range(test_num):
    score, dir_list, food_list = snake_play_game(test_model, mode=input_mode)
    test_dir_list.append(dir_list)
    test_food_list.append(food_list)
    test_score.append(score)

# Best games first. Iterating idx directly also works when test_num < show_num.
idx = get_top_idx(test_score, show_num)
idx.reverse()
print("Index:", end=" ")
for i in idx:
    print(i, end=" ")
print("\nScore:", end=" ")
for i in idx:
    print(test_score[i], end=" ")
print("\nAvg. Score: %.3f" % (float(sum(test_score))/float(test_num)))

ani_idx = idx[0]
print(str(ani_idx)+": "+str(len(test_dir_list[ani_idx])+1)+" frames")
# ani_html was never assigned in this cell; render the best test game before saving it.
ani_html = make_ani(test_dir_list[ani_idx], test_food_list[ani_idx], 5)
with open("animation_test_"+str(test_score[ani_idx])+".html", "w", encoding="utf-8") as Html_file:
    Html_file.write(ani_html)