In [None]:
from collections import deque
import math
import numpy as np
import pandas as pd
import random
import pickle
import glob

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb

from tetris.Environment import TetrisEnv

import util.decaying
import dqnmodel

In [None]:
# Directory that ProReplayDataset scans for recorded pro-player games.
PRO_REPLAY_DIRECTORY = 'pro-replays'

### Hyperparameters

In [None]:
# Replay memory: the training loop trims the self-play buffer once it grows past
# MEM_SIZE; the warm-up cell keeps playing games until MIN_MEM_SIZE transitions exist.
MEM_SIZE = 10000
MIN_MEM_SIZE = 1000

# Arguments (start, end, duration) handed to util.decaying.DecayingDiscount.
DISCOUNT_START = 0.8
DISCOUNT_END = 0.94
DISCOUNT_DURATION = 4000

# Arguments (start, end, duration) handed to util.decaying.DecayingLinear for epsilon.
EPSILON_START = 0.5
EPSILON_END = 0.08
EPSILON_DURATION = 3000

# UPDATE_TARGET_EVERY is passed to the DQNModel constructor.
# A self-play game is simulated whenever episode % SIMULATE_EVERY == 1.
# USE_PRO_PLAY_CHANCE is the per-step probability of training on a pro-replay batch.
UPDATE_TARGET_EVERY = 100
SIMULATE_EVERY = 5
USE_PRO_PLAY_CHANCE = 0.2

# Number of training-loop iterations (one optimiser step each) and samples per batch.
EPISODES = 6000
BATCH_SIZE = 164

# AdamW initial learning rate, and the gamma / step_size given to StepLR.
LEARNING_RATE_START = 3e-3
LEARNING_RATE_GAMMA = 0.9
LEARNING_RATE_STEP = 300

# NOTE(review): the training loop rebinds this name on every iteration
# (random draw against USE_PRO_PLAY_CHANCE), so the value set here is overwritten.
use_pro_replays = True

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The bare expression on the last line shows the selection as the cell output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# DQNModel is project code; this notebook uses its .model (q-network),
# .target_model and .step() members.
model = dqnmodel.DQNModel(UPDATE_TARGET_EVERY)
model.to(device)

In [None]:
# Exploration-rate and discount-factor schedules; get_best_state reads epsilon.get(),
# and the training loop reads discount.get() and calls .step() on both.
epsilon = util.decaying.DecayingLinear(EPSILON_START, EPSILON_END, EPSILON_DURATION)
discount = util.decaying.DecayingDiscount(DISCOUNT_START, DISCOUNT_END, DISCOUNT_DURATION)

In [None]:
def get_best_state(states, use_epsilon=True):
    """Choose which candidate next state to play, epsilon-greedily.

    Args:
        states: batch of candidate next states in the layout the q-network
            accepts (callers pass a float tensor reshaped to (-1, 1, 20, 10)).
        use_epsilon: when False the greedy choice is always taken; when True a
            random candidate is picked with probability given by epsilon.get().

    Returns:
        int: position of the chosen candidate within `states`.
    """
    if not use_epsilon or random.random() > epsilon.get():
        # use the q-network (not the target network) for choosing the next state;
        # action selection needs no gradients, so skip building the autograd graph
        with torch.no_grad():
            q_values = model.model(states)
        # .item() yields a plain int so both branches return the same type
        return torch.argmax(q_values).item()
    else:
        return random.randrange(len(states))

In [None]:
# Single Tetris environment shared by the warm-up cell and the training loop.
env = TetrisEnv()

### Fill the replay buffer by playing games

In [None]:
# Self-play transitions, stored as (current_state, chosen_next_state, score, done) tuples.
replay_buffer = []

def to_torch(state):
    """Flatten a batch of boards into a 2-D float tensor.

    Args:
        state: numpy array whose first axis indexes the boards.

    Returns:
        torch.Tensor of dtype float32 and shape (state.shape[0], -1).
    """
    # operate on the argument itself; the previous body read the notebook-level
    # `states` variable and ignored what the caller passed in
    return torch.from_numpy(state.reshape(state.shape[0], -1)).float()

# Warm-up: self-play whole games until the buffer holds at least MIN_MEM_SIZE transitions.
# The bar advances once per started game; its total is MIN_MEM_SIZE/20.
with tqdm(total=MIN_MEM_SIZE/20) as pbar:
    while len(replay_buffer) < MIN_MEM_SIZE:
        env.reset()
        pbar.update(1)

        # keep placing pieces until the chosen placement ends the game
        game_over = False
        while not game_over:
            states, scores, clears, dones = env.get_next_states()

            candidates = torch.from_numpy(states.reshape(-1, 1, 20, 10)).float().to(device)
            chosen_index = get_best_state(candidates)

            transition = (env.get_current_state(), states[chosen_index], scores[chosen_index], dones[chosen_index])
            replay_buffer.append(transition)

            game_over = dones[chosen_index]
            if not game_over:
                env.step(states[chosen_index], clears[chosen_index], scores[chosen_index])


### Load pro-player replays

In [None]:
class ProReplayDataset(torch.utils.data.Dataset):
    """Pro-player transitions, fully loaded into memory at construction time.

    Every item is a (current_board, next_board, score, done) tuple in which both
    boards are 20x10 tensors that have already been moved to `device`.
    """

    def __init__(self, path, train):
        """Read every replay file under `path` and cache its rows.

        Args:
            path: directory searched for replay files.
            train: stored on the instance; not otherwise used by this class.
        """
        self.path = path
        # NOTE(review): files are matched by a .json suffix but parsed with
        # pd.read_csv — confirm the on-disk format
        self.file_list = glob.glob(f'{path}/*.json')
        self.train = train

        def load_board(cells):
            # one character per cell -> 20x10 integer tensor on the target device
            return torch.tensor([int(ch) for ch in cells]).reshape((20, 10)).to(device)

        # keep the complete dataset in memory so later indexing is cheap
        self.buffer = []
        for _, replay_file in tqdm(enumerate(self.file_list)):
            frame = pd.read_csv(replay_file)
            self.buffer.extend(
                (load_board(rec.current), load_board(rec.next), rec.score, rec.done)
                for _, rec in frame.iterrows()
            )

    def __len__(self):
        """Number of cached transitions."""
        return len(self.buffer)

    def __getitem__(self, idx):
        """Return the cached transition at position `idx`."""
        return self.buffer[idx]


# Load every pro replay once, then serve shuffled mini-batches of BATCH_SIZE.
# The iterator lives at notebook level so the training loop can keep drawing
# from it and re-create it once it is exhausted.
replay_buffer_dataset = ProReplayDataset(PRO_REPLAY_DIRECTORY, True)
replay_buffer_loader = torch.utils.data.DataLoader(dataset=replay_buffer_dataset, batch_size=BATCH_SIZE, shuffle=True)
replay_buffer_iter = iter(replay_buffer_loader)

### Training loop

In [None]:
# Start a Weights & Biases run in the 'tetris-dqn' project and record the
# notebook's hyperparameters as the run config.
wandb.init(project='tetris-dqn', config={
    'learning-rate-start': LEARNING_RATE_START,
    'learning-rate-gamma': LEARNING_RATE_GAMMA,
    'learning-rate-step': LEARNING_RATE_STEP,
    
    'batch-size': BATCH_SIZE,
    
    'replay-max-size': MEM_SIZE,
    'replay-min-size': MIN_MEM_SIZE,
    
    'epsilon-start': EPSILON_START,
    'epsilon-end': EPSILON_END,
    'epsilon-duration': EPSILON_DURATION,
    
    'discount-start': DISCOUNT_START,
    'discount-end': DISCOUNT_END,
    'discount-duration': DISCOUNT_DURATION,
    
    'pro-play-chance': USE_PRO_PLAY_CHANCE,
    'simulate-every': SIMULATE_EVERY,
    'update-target-every': UPDATE_TARGET_EVERY,
})

In [None]:
# Fresh schedule objects for the training run, built with the same arguments as the
# earlier cell; presumably this discards any progress if those objects were stepped — confirm.
epsilon = util.decaying.DecayingLinear(EPSILON_START, EPSILON_END, EPSILON_DURATION)
discount = util.decaying.DecayingDiscount(DISCOUNT_START, DISCOUNT_END, DISCOUNT_DURATION)

In [None]:
# Loss, optimiser and learning-rate schedule for the q-network (model.model).
criterion = nn.MSELoss() # HuberLoss()
optimizer = torch.optim.AdamW(model.model.parameters(), lr=LEARNING_RATE_START)
# `verbose` is left at its default (False); the argument is deprecated in recent PyTorch releases
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LEARNING_RATE_STEP, gamma=LEARNING_RATE_GAMMA)

# per-step loss and per-game score history, kept in addition to the wandb logs
training_loss = []
training_scores = []

model.model.train()

for episode in tqdm(range(EPISODES)):
    # draw the batch source for this step; note that this rebinds the
    # notebook-level `use_pro_replays` on every iteration
    use_pro_replays = random.random() < USE_PRO_PLAY_CHANCE

    # play another game
    if episode % SIMULATE_EVERY == 1:
        env.reset()
        while True:
            states, scores, clears, dones = env.get_next_states()

            chosen_index = get_best_state(torch.from_numpy(states.reshape(-1, 1, 20, 10)).float().to(device))

            replay_buffer.append((env.get_current_state(), states[chosen_index], scores[chosen_index], dones[chosen_index]))

            if dones[chosen_index]:
                training_scores.append({'epoch': episode, 'score': env.score})
                break
            else:
                env.step(states[chosen_index], clears[chosen_index], scores[chosen_index])

        # once the buffer overflows, drop the oldest tenth of MEM_SIZE in one go
        if len(replay_buffer) > MEM_SIZE:
            replay_buffer = replay_buffer[int(MEM_SIZE/10):]

        wandb.log({'game/score': env.score,
                   'game/singles': env.clears[0], 'game/doubles': env.clears[1], 'game/triples': env.clears[2], 'game/quads': env.clears[3],
                   'game/tspins': env.tspins, 'game/all_clears': env.all_clears, 'game/moves': env.moves })

    # get the batch, consisting of (current_state, next_state, score, done), and
    # turn each of the four fields into a tensor on `device`
    if use_pro_replays:
        batch = next(replay_buffer_iter, None)
        # if we took all entries of the pro-player replays, start over again
        if batch is None:
            replay_buffer_iter = iter(replay_buffer_loader)
            batch = next(replay_buffer_iter, None)

        current_states = batch[0].reshape(-1, 1, 20, 10).float().to(device)
        # field 1 holds the next states (this previously re-read field 0, so the
        # targets were bootstrapped from the current state)
        next_states = batch[1].reshape(-1, 1, 20, 10).float().to(device)
        batch_scores = batch[2].float().to(device)
        batch_dones = batch[3].bool().to(device)
    else:
        # take sample from replay memory
        batch = random.sample(replay_buffer, BATCH_SIZE)

        current_states = torch.from_numpy(np.array([s[0] for s in batch])).reshape(-1, 1, 20, 10).float().to(device)
        next_states = torch.from_numpy(np.array([s[1] for s in batch])).reshape(-1, 1, 20, 10).float().to(device)
        batch_scores = torch.tensor([float(s[2]) for s in batch], dtype=torch.float32, device=device)
        batch_dones = torch.tensor([bool(s[3]) for s in batch], dtype=torch.bool, device=device)

    # get the q-values of the current state
    y_hat = model.model(current_states)

    # calculate expected q-values of the next state using the target-network;
    # the target is a constant for the optimiser, so no graph is needed.
    # assumes both networks emit exactly one value per state, as the (N, 1)
    # reshape of the target below requires
    with torch.no_grad():
        next_q_values = model.target_model(next_states).reshape(-1)
        # terminal transitions keep the plain score, all others bootstrap
        y = torch.where(batch_dones, batch_scores, batch_scores + discount.get() * next_q_values)

    # fit the model to the expected q value
    loss = criterion(y_hat, y.reshape(-1, 1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    epsilon.step()
    discount.step()
    scheduler.step()

    model.step()

    wandb.log({'training/loss': loss.item()})
    training_loss.append({'epoch': episode, 'loss': loss.item()})

    # periodic checkpoint of the q-network weights
    if episode % 100 == 0:
        torch.save(model.model.state_dict(), f'models/run-cnn-{math.floor(episode / 100) + 105}.pt')

In [None]:
# Persist the final q-network weights.
# NOTE(review): the file name hard-codes "after-18000" while EPISODES is 6000 —
# presumably a cumulative count across several runs; confirm before reusing.
torch.save(model.model.state_dict(), 'models/run-cnn-after-18000.pt')