In [114]:
from collections import deque
import numpy as np
import pandas as pd
import seaborn as sns
import random
import pickle

import cv2
import PIL
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb

In [115]:
MEM_SIZE = 8000
MIN_MEM_SIZE = 1000
DISCOUNT = 0.94
EPSILON_START = 1
EPSILON_END = 0.05
EPSILON_STOP = 300

SIMULATE_EVERY = 4

EPISODES = 50
BATCH_SIZE = 164
LEARNING_RATE = 1e-3

In [116]:
%run tetris-environment.ipynb

In [117]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device
device = 'cpu'

In [118]:
class Decaying():
    def __init__(self, start, end, duration):
        self.per_step = (end - start) / duration
        self.duration = duration
        self.current_step = 0
        self.current = start
        

    def step(self):
        if self.current_step < self.duration:
            self.current += self.per_step
            self.current_step += 1
            
        return self.current
    
    def get(self):
        return self.current

In [119]:
# layer initialisation
def init_linear_layer(m, method):
    torch.nn.init.xavier_normal_(m.weight, nn.init.calculate_gain(method))
    torch.nn.init.constant_(m.bias, 0)
    return m

In [120]:
model = nn.Sequential(
    init_linear_layer(nn.Linear(20*10, 128), 'relu'),
    nn.ReLU(),
    init_linear_layer(nn.Linear(128, 64), 'relu'),
    nn.ReLU(),
    init_linear_layer(nn.Linear(64, 1), 'linear')
)

model.load_state_dict(torch.load('models/run-1.pt'))
model = model.to(device)

# model = nn.Sequential(
#     init_linear_layer(nn.Linear(20*10, 128), 'relu'),
#     nn.ReLU(),
#     init_linear_layer(nn.Linear(128, 64), 'relu'),
#     nn.ReLU(),
#     init_linear_layer(nn.Linear(64, 32), 'relu'),
#     nn.ReLU(),
#     init_linear_layer(nn.Linear(32, 1), 'linear')
# )

# model = model.to(device)

In [121]:
def get_best_state(states, use_epsilon=True):
    if not use_epsilon or random.random() > epsilon.current:
        q_values = model(states)
        return torch.argmax(q_values)
    else:
        return random.choice(range(len(states)))

In [122]:
env = TetrisEnv()

### Fill the replay buffer by playing games

In [123]:
replay_buffer = []

def to_torch(state):
    return torch.from_numpy(states.reshape(states.shape[0], -1)).float()

with tqdm(total=MIN_MEM_SIZE/20) as pbar:
    while len(replay_buffer) < MIN_MEM_SIZE:
        env.reset()
        pbar.update(1)

        # play moves until game over
        while True:
            states, scores, clears, dones = env.get_next_states()

            chosen_index = get_best_state(torch.from_numpy(states.reshape(-1, 200)).float().to(device))

            replay_buffer.append((env.get_current_state(), states[chosen_index], scores[chosen_index], dones[chosen_index]))

            if dones[chosen_index]:
                break
            else:
                env.step(states[chosen_index], clears[chosen_index], scores[chosen_index])

# states, _, _, _, = env.get_next_states()
# states = torch.from_numpy(states.reshape(states.shape[0], -1)).float()

# model(states)

  0%|          | 0/50.0 [00:00<?, ?it/s]

In [124]:
# with gzip.GzipFile('buffer.bin', 'wb', compresslevel=1) as gzipFile:
#     pickle.dump(replay_buffer, gzipFile)

### Fill the replay buffer by loading from file

### Training loop

In [125]:
wandb.init(project='tetris-dqn', config={
    'learning-rate': LEARNING_RATE,
    'batch-size': BATCH_SIZE,
    'replay-max-size': MEM_SIZE,
    'replay-min-size': MIN_MEM_SIZE,
    'discound-factor': DISCOUNT
})

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
game/all_clears,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
game/doubles,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁
game/moves,▄▃▄▃▄▃▅▅▄▅▆▄▆▄▄▆▅▆▅▅▆▆▅▆▅▃▇▄▁▄▂▃▁█▃▃▃▃▄▄
game/quads,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
game/score,▁▁▁▁▁▁▂▁▂▁▂▁▂▁▂▄▁▂▂▂▂▂▁▄▂▁▅▃▁▁▁▁▁█▁▁▁▁▁▁
game/singles,▁▁▁▁▁▁▂▁▂▁▃▁▃▁▂▃▁▃▂▂▃▃▁▃▃▁▃▃▁▁▁▁▁█▁▁▁▁▁▁
game/triples,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
game/tspins,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
training/loss,▃▂▁▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▃▁▂▂▂▂▂█▂▂▂▁▂▁▁

0,1
game/all_clears,0.0
game/doubles,0.0
game/moves,27.0
game/quads,0.0
game/score,100.0
game/singles,1.0
game/triples,0.0
game/tspins,0.0
training/loss,107.84496


In [126]:
# %load_ext line_profiler
# %prun train(replay_buffer)

epsilon = Decaying(1, 0.03, 4000)
discount = Decaying(0.7, 0.94, 4000)

In [180]:

# def train():
criterion = nn.HuberLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

training_loss = []
training_scores = []

for episode in tqdm(range(2000)): # tqdm(range(EPISODES)):
    # play another game
    if episode % SIMULATE_EVERY == 0:
        env.reset()
        while True:
            states, scores, clears, dones = env.get_next_states()

            chosen_index = get_best_state(torch.from_numpy(states.reshape(-1, 200)).float().to(device))

            replay_buffer.append((env.get_current_state(), states[chosen_index], scores[chosen_index], dones[chosen_index]))

            if dones[chosen_index]:
                training_scores.append({'epoch': episode, 'score': env.score})
                break
            else:
                env.step(states[chosen_index], clears[chosen_index], scores[chosen_index])

        if len(replay_buffer) > MEM_SIZE:
            replay_buffer = replay_buffer[int(MEM_SIZE/10):]
            
        wandb.log({'game/score': env.score,
                   'game/singles': env.clears[0], 'game/doubles': env.clears[1], 'game/triples': env.clears[2], 'game/quads': env.clears[3],
                   'game/tspins': env.tspins, 'game/all_clears': env.all_clears, 'game/moves': env.moves })


    # take sample from replay memory
    batch = random.sample(replay_buffer, BATCH_SIZE)

    current_states = torch.from_numpy(np.array([s[0].reshape(200) for s in batch])).float().to(device)
    next_states = torch.from_numpy(np.array([s[1].reshape(200) for s in batch])).float().to(device)

    next_q_values = model(next_states)
    y_hat = model(current_states)

    # calculate expected q value
    y = []
    for i, (state, _, score, done) in enumerate(batch):
        if not done:
            new_q = score + discount.current * next_q_values[i]
        else:
            new_q = score

        y.append(new_q)

    # fit the model to the expected q value
    loss = criterion(y_hat, torch.tensor(y).reshape(BATCH_SIZE, 1).to(device))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    epsilon.step()
    discount.step()

    wandb.log({'training/loss': loss.item()})
    training_loss.append({'epoch': episode, 'loss': loss.item()})
        # print(loss)
    
# sns.lineplot(data=pd.DataFrame(training_loss), x='epoch', y='loss')
# sns.lineplot(data=pd.DataFrame(training_scores), x='epoch', y='score')

  0%|          | 0/2000 [00:00<?, ?it/s]

In [178]:
# sns.lineplot(data=pd.DataFrame(training_scores), x='epoch', y='score')
print(discount.current)
print(epsilon.current)

0.9399999999997959
0.029999999999945952


In [129]:

# env.reset()
# print(env.current_piece)

# while True:
#     states, scores, clears, dones = env.get_next_states()

#     chosen_index = get_best_state(torch.from_numpy(states.reshape(-1, 200)).float())

#     env._print_state(states[chosen_index])
#     print()
    
#     # replay_buffer.append((env.get_current_state(), states[chosen_index], scores[chosen_index], dones[chosen_index]))

#     if dones[chosen_index]:
#         print(f'Score: {env.score}')
#         print(f'Clears: {env.clears}, t-spins: {env.tspins}, alll_clears: {env.all_clears}')
#         break
#     else:
#         env.step(states[chosen_index], clears[chosen_index], scores[chosen_index])

In [242]:
%matplotlib inline

COLORS = {
    0: (0,0,0),
    1: (255,255,255)
    # 0: (128, 0, 128),
    # 1: (255, 127, 0),
    # 2: (0, 0, 255),
    # 3: (255, 255, 0),
    # 4: (0, 255, 255),
    # 5: (0, 255, 0),
    # 6: (255, 0, 0)
}
    
def render_run():
    def gen_image(state):
        img = [COLORS[cell] for row in state for cell in row]
        img = np.array(img).reshape(20, 10, 3).astype(np.uint8)
        img = img[..., ::-1] # Convert RRG to BGR (used by cv2)
        img = Image.fromarray(img, 'RGB')

        img = img.resize((10 * 25, 20 * 25), Image.Resampling.NEAREST)
        
        return img
    
    env.reset()

    frames = []
    
    while True:
        states, scores, clears, dones = env.get_next_states()

        chosen_index = get_best_state(torch.from_numpy(states.reshape(-1, 200)).float(), False)

        frames.append(gen_image(states[chosen_index]))

        # replay_buffer.append((env.get_current_state(), states[chosen_index], scores[chosen_index], dones[chosen_index]))

        if dones[chosen_index]:
            print(f'Score: {env.score}')
            print(f'Clears: {env.clears}, t-spins: {env.tspins}, alll_clears: {env.all_clears}')
            break
        else:
            env.step(states[chosen_index], clears[chosen_index], scores[chosen_index])
    
    frames[0].save('game15.gif', format='GIF', append_images=frames, save_all=True, duration=300, loop=0)
    
render_run()

Score: 1000
Clears: [10, 0, 0, 0], t-spins: 0, alll_clears: 0


In [198]:
torch.save(model.state_dict(), 'models/run-4.pt')