In [None]:
import numpy as np
import random
import time
import pygame
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import dqnmodel

from tetris.Piece import Piece
from tetris.EnvironmentRendered import TetrisEnvRendered

import gui.gui

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
# NOTE(review): whether DQNModel can load a checkpoint that was saved on a GPU
# while running on CPU depends on how it calls torch.load - confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the DQN wrapper from the saved checkpoint path and move it to `device`.
model = dqnmodel.DQNModel(100, 'models/test-run.pt')
model.to(device)

In [None]:
def get_best_state(states):
    """Return the index of the candidate the Q-network rates highest.

    `states` is passed unchanged to `model.model`; the result is the arg-max
    over the network output (taken over the flattened output, as a tensor).
    """
    # Choosing the next state uses the q-network, not the target network.
    predicted_q = model.model(states)
    return predicted_q.argmax()

In [None]:
# Shared Tetris environment instance; render_run() resets and steps it.
env = TetrisEnvRendered()

## Without hold

In [None]:
# runs = []
# def render_run():
#     model.model.eval()
#     with torch.no_grad():
#         env.reset()
#         states_to_render = []

#         while True:
#             states, states_pretty, scores, clears, dones, moves = env.get_next_states()

#             chosen_index = get_best_state(torch.from_numpy(states.reshape(-1, 1, 20, 10)).float().to(device))

#             states_to_render.append((states_pretty[chosen_index], moves[chosen_index], env.current_piece, env.get_next_queue(), clears[chosen_index]))

#             if dones[chosen_index]:
#                 print(f'Score: {env.score}')
#                 print(f'Clears: {env.clears}, t-spins: {env.tspins}, alll_clears: {env.all_clears}')
#                 break
#             else:
#                 env.step(states[chosen_index], states_pretty[chosen_index], clears[chosen_index], scores[chosen_index])

#         runs.append(states_to_render)
    
# for i in range(1):
#     render_run()

## With hold

In [None]:
# Recorded episodes: render_run() appends one list of per-step tuples per call.
runs = []
def render_run():
    """Play one greedy episode (hold allowed) and record it for playback.

    On every step the candidate placements for the current piece and for the
    hold action are scored together by the Q-network, the arg-max candidate
    is selected, and a tuple
    ``(board, moves, piece, next_queue, hold_piece, clears)`` is recorded.
    When the episode ends the list of tuples is appended to the module-level
    ``runs`` list.

    Side effects: puts ``model.model`` into eval mode, resets and advances
    ``env``, prints the final score and clear statistics, appends to ``runs``.
    """
    model.model.eval()
    with torch.no_grad():
        env.reset()
        states_to_render = []

        while True:
            states_curr, states_pretty_curr, scores_curr, clears_curr, dones_curr, moves_curr = env.get_next_states()
            states_hold, states_pretty_hold, scores_hold, clears_hold, dones_hold, moves_hold = env.get_next_states(use_hold=True)

            # Hold candidates are appended after the current-piece candidates,
            # so an index >= len(states_curr) means "hold first".
            # NOTE(review): scores/dones are joined with `+`, which assumes
            # they are Python lists (not ndarrays) - confirm.
            states = np.concatenate([states_curr, states_hold]) if states_hold is not None else states_curr
            states_pretty = np.concatenate([states_pretty_curr, states_pretty_hold]) if states_pretty_hold is not None else states_pretty_curr
            scores = scores_curr + scores_hold if scores_hold is not None else scores_curr
            clears = np.concatenate([clears_curr, clears_hold]) if clears_hold is not None else clears_curr
            dones = dones_curr + dones_hold if dones_hold is not None else dones_curr
            moves = np.concatenate([moves_curr, moves_hold]) if moves_hold is not None else moves_curr

            batch = torch.from_numpy(states.reshape(-1, 1, 20, 10)).float().to(device)
            # Convert the 0-d (possibly CUDA) tensor to a plain int once, so
            # the list/ndarray indexing below does not repeatedly rely on
            # implicit tensor-to-int conversion.
            chosen_index = int(get_best_state(batch))

            used_hold = chosen_index >= len(states_curr)
            if used_hold:
                # Captured BEFORE env.hold(): the piece to animate is the held
                # piece, or the bag's next piece when the hold slot is empty;
                # the current piece is recorded as the hold entry.
                played_piece = env.bag.peek_piece() if env.hold_piece is None else env.hold_piece
                states_to_render.append((
                    states_pretty[chosen_index],
                    moves[chosen_index],
                    played_piece,
                    env.get_next_queue(),
                    env.current_piece,
                    clears[chosen_index],
                ))
                env.hold()
            else:
                states_to_render.append((
                    states_pretty[chosen_index],
                    moves[chosen_index],
                    env.current_piece,
                    env.get_next_queue(),
                    env.hold_piece,
                    clears[chosen_index],
                ))

            if dones[chosen_index]:
                print(f'Score: {env.score}')
                print(f'Clears: {env.clears}, t-spins: {env.tspins}, alll_clears: {env.all_clears}')
                break

            env.step(states[chosen_index], states_pretty[chosen_index], clears[chosen_index], scores[chosen_index])

        runs.append(states_to_render)
    
# Record one episode; the result is appended to `runs`.
render_run()

In [None]:
# # store run
# with open('runs/530k.pkl', 'wb') as file:
#     pickle.dump(runs[0], file, pickle.HIGHEST_PROTOCOL)

# # load run
# with open('runs/396k.pkl', 'rb') as file:
#     run = pickle.load(file)

In [None]:
# Replay the first recorded run in the pygame GUI, one frame per move group.
g = gui.gui.Gui(sleep=33)

# Board drawn while the first piece is animating: every cell set to -1.
last_state = np.full((20, 10), -1.0)

# Set when the user closes the window. Playback must stop at that point,
# because drawing or polling events after pygame.quit() raises an error.
window_closed = False

for state, moves, piece, queue, hold, clears in runs[0]:
    p = Piece(piece)

    for move in moves:
        # Each entry is a comma-separated list of move codes applied before
        # the next frame: ml/mr change pos[1], sd/mu change pos[0], and
        # rr/rl step the rotation index modulo the piece's rotation count.
        for m in move.split(','):
            if m == 'ml':
                p.pos[1] -= 1
            elif m == 'mr':
                p.pos[1] += 1
            elif m == 'sd':
                p.pos[0] += 1
            elif m == 'mu':
                p.pos[0] -= 1
            elif m == 'rr':
                p.rot = (p.rot + 1) % p.pdata.num_rot
            elif m == 'rl':
                p.rot = (p.rot - 1) % p.pdata.num_rot

        g.draw(last_state, p, queue, hold)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                window_closed = True
        if window_closed:
            break

    if window_closed:
        break

    # Sound for the placement. Judging by the sounds played, clears[0] is the
    # number of cleared lines, clears[1] a t-spin flag, clears[2] an
    # all-clear flag.
    if clears[0] == 4:
        g.play_sound_quad()
    elif clears[1] and clears[0] > 0:
        g.play_sound_tspin()
    elif clears[2]:
        pass # all-clear
    elif clears[0] > 0:
        g.play_sound_clear()

    # The board after this placement is the background for the next piece.
    last_state = state

if window_closed:
    pygame.quit()

In [None]:
# Shut pygame down once playback is finished.
pygame.quit()