In [25]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
# Load the raw dataset. Column 1 holds the per-row board arrays (x) and column 0 the
# per-row targets (y). The array is presumably an object array, hence allow_pickle=True.
# NOTE(review): allow_pickle can execute code while loading; only use trusted files.
# Named `board_grade_data` (not `data`) so it is not clobbered by, and does not
# clobber, the `torch.utils.data as data` alias imported in the next cell.
board_grade_data = np.load('board_grade.npy', allow_pickle=True)
x, y = board_grade_data[:, 1], board_grade_data[:, 0]

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Stack the per-row arrays into dense arrays. The reshapes fix the layout: boards are
# (N, 18, 11, 2) and targets are (N, 7) (presumably one-hot grade classes -- TODO confirm).
all_boards = np.stack(x).reshape(-1, 18, 11, 2)
all_classes = np.stack(y).reshape(-1, 7)

# Convert data to PyTorch tensors
x_tensor = torch.tensor(all_boards, dtype=torch.float32)
y_tensor = torch.tensor(all_classes, dtype=torch.float32)

# Create the dataset and hold out VAL_FRACTION of it for validation. The split uses a
# seeded generator so the same rows land in the validation set on every run.
VAL_FRACTION = 0.2
SPLIT_SEED = 0

dataset = data.TensorDataset(x_tensor, y_tensor)
val_size = int(len(dataset) * VAL_FRACTION)
train_size = len(dataset) - val_size  # reuse val_size so the two always sum to len(dataset)
train_dataset, val_dataset = data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [65]:
# Width of a flattened board (rows x columns x hold channels) and of the latent space.
LATENT_DIM = 40
INPUT_DIM = 2 * 11 * 18

class MLP_Encoder(nn.Module):
    """Two-hidden-layer MLP encoder for the VAE.

    Maps a batch of flattened inputs of shape (batch, input_dim) to two tensors of
    shape (batch, latent_dim): the latent mean and a second head that the VAE uses as
    the log-variance (VAE.forward exponentiates it).

    Args:
        input_dim: size of each flattened input vector.
        latent_dim: size of the latent space.
        hidden_dims: widths of the two hidden layers; defaults to the original (160, 80).
    """

    def __init__(self, input_dim, latent_dim, hidden_dims=(160, 80)):
        super().__init__()
        first_width, second_width = hidden_dims

        # Layers are created in the original order so seeded initialisation is unchanged.
        self.input_fc = nn.Linear(input_dim, first_width)
        # NOTE(review): registered but never applied in forward(); kept so the module's
        # attributes are unchanged. Apply it in forward() if regularisation is intended.
        self.dropout = nn.Dropout(0.5)
        self.hidden_fc = nn.Linear(first_width, second_width)
        self.mean_fc = nn.Linear(second_width, latent_dim)
        # Produces the log-variance, not the variance; the attribute name is kept so
        # state_dict keys stay compatible with existing checkpoints.
        self.var_fc = nn.Linear(second_width, latent_dim)
        # nn.Module.__init__ already initialises `self.training`; train()/eval() manage it.

    def forward(self, x):
        """Return ``(mean, log_var)`` for the input batch ``x``."""
        hidden = F.mish(self.input_fc(x))
        hidden = F.mish(self.hidden_fc(hidden))
        mean = self.mean_fc(hidden)
        log_var = self.var_fc(hidden)
        return mean, log_var

class MLP_Decoder(nn.Module):
    """Two-hidden-layer MLP decoder for the VAE.

    Maps latent vectors of shape (batch, latent_dim) to (batch, output_dim) values
    squashed into [0, 1] by a sigmoid.

    Args:
        latent_dim: size of the latent space.
        output_dim: size of each reconstructed (flattened) output vector.
        hidden_dims: widths of the two hidden layers; defaults to the original (80, 160).
    """

    def __init__(self, latent_dim, output_dim, hidden_dims=(80, 160)):
        super().__init__()
        first_width, second_width = hidden_dims

        # Layers are created in the original order so seeded initialisation is unchanged.
        self.input_fc = nn.Linear(latent_dim, first_width)
        # NOTE(review): registered but never applied in forward(); kept so the module's
        # attributes are unchanged. Apply it in forward() if regularisation is intended.
        self.dropout = nn.Dropout(0.5)
        self.hidden_fc = nn.Linear(first_width, second_width)
        self.output_fc = nn.Linear(second_width, output_dim)

    def forward(self, x):
        """Decode latent batch ``x`` into sigmoid outputs of width ``output_dim``."""
        hidden = F.mish(self.input_fc(x))
        hidden = F.mish(self.hidden_fc(hidden))
        return torch.sigmoid(self.output_fc(hidden))

In [66]:
class VAE(nn.Module):
    """Variational autoencoder wrapping an encoder/decoder pair.

    ``Encoder(x)`` must return ``(mean, log_var)`` and ``Decoder(z)`` maps a latent
    sample back to input space. ``forward`` returns ``(x_hat, mean, log_var)``.
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Draw ``z = mean + var * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

        Despite its name, ``var`` must be the standard deviation: ``forward`` passes
        ``exp(0.5 * log_var)`` and the value is multiplied, not square-rooted.
        """
        # randn_like already matches var's device and dtype, so no global `device` is needed.
        epsilon = torch.randn_like(var)
        return mean + var * epsilon

    def forward(self, x):
        mean, log_var = self.Encoder(x)
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        z = self.reparameterization(mean, std)
        x_hat = self.Decoder(z)
        return x_hat, mean, log_var

In [68]:
from torch.optim import Adam

def loss_function(x, x_hat, mean, log_var):
    """VAE objective (negative ELBO), summed over every element of the batch.

    Reconstruction term: binary cross-entropy between the decoder outputs ``x_hat``
    and the targets ``x``. Regulariser: closed-form KL divergence between
    N(mean, exp(log_var)) and the standard normal prior.
    """
    recon_term = F.binary_cross_entropy(x_hat, x, reduction='sum')
    kl_term = -0.5 * (1 + log_var - mean ** 2 - torch.exp(log_var)).sum()
    return recon_term + kl_term


optimizer = Adam(model.parameters(), lr=1e-3)

In [71]:
import time
from tqdm import tqdm

print("Start training VAE...")

encoder = MLP_Encoder(input_dim=INPUT_DIM, latent_dim=LATENT_DIM)
decoder = MLP_Decoder(latent_dim=LATENT_DIM, output_dim=INPUT_DIM)

model = VAE(Encoder=encoder, Decoder=decoder).to(device)

def epoch_time(start_time, end_time):
    """Split the time between two timestamps (in seconds) into whole minutes and seconds.

    Both parts are truncated toward zero, e.g. 125.7 s -> (2, 5).
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    return minutes, int(duration - minutes * 60)

# Training configuration.
EPOCHS = 20
BATCH_SIZE = 4

# Only the training data needs shuffling. The validation set is just scored, so a
# fixed order removes one source of run-to-run variation in the reported loss.
train_iterator = data.DataLoader(train_dataset,
                                 shuffle=True,
                                 batch_size=BATCH_SIZE)

valid_iterator = data.DataLoader(val_dataset,
                                 shuffle=False,
                                 batch_size=BATCH_SIZE)

def train(model, iterator, optimizer, device):
    """Run one training epoch and return the mean loss per sample.

    Each batch is ``(boards, labels)``; labels are ignored. Boards are flattened to
    ``(batch, -1)`` and moved to ``device`` before the forward pass. ``loss_function``
    sums over the batch, so the total is divided by the number of samples actually
    seen, which weights a short final batch correctly and avoids depending on a global
    batch size. Returns nan if ``iterator`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    for x, _ in tqdm(iterator, desc="Training", leave=False):
        x = x.view(x.size(0), -1).to(device)

        optimizer.zero_grad()
        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_samples += x.size(0)

    # Average over zero samples is undefined; report nan rather than raising.
    return total_loss / n_samples if n_samples else float("nan")

def evaluate(model, iterator, device):
    """Compute the mean loss per sample over ``iterator`` without updating weights.

    Puts ``model`` in eval mode and disables autograd. Batches are ``(boards, labels)``
    with labels ignored; boards are flattened to ``(batch, -1)``. The summed loss is
    divided by the number of samples seen. Returns nan if ``iterator`` is empty.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0

    # No gradients are computed here, so no optimizer interaction is needed.
    with torch.no_grad():
        for x, _ in tqdm(iterator, desc="Evaluating", leave=False):
            x = x.view(x.size(0), -1).to(device)
            x_hat, mean, log_var = model(x)
            total_loss += loss_function(x, x_hat, mean, log_var).item()
            n_samples += x.size(0)

    return total_loss / n_samples if n_samples else float("nan")
    
# Train for EPOCHS epochs, reporting wall-clock time and train/validation loss each epoch.
for epoch in range(EPOCHS):
    epoch_start = time.monotonic()
    train_loss = train(model, train_iterator, optimizer, device)
    valid_loss = evaluate(model, valid_iterator, device)
    mins, secs = epoch_time(epoch_start, time.monotonic())

    print(f'Epoch: {epoch+1:02} | Epoch Time: {mins}m {secs}s\n'
          f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')

Start training VAE...


                                                                                                                                                                                                                         

Epoch: 01 | Epoch Time: 0m 7s
	Train Loss: 23.029 | Val. Loss: 23.224


                                                                                                                                                                                                                         

Epoch: 02 | Epoch Time: 0m 7s
	Train Loss: 23.043 | Val. Loss: 23.144


                                                                                                                                                                                                                         

Epoch: 03 | Epoch Time: 0m 7s
	Train Loss: 23.029 | Val. Loss: 23.079


                                                                                                                                                                                                                         

Epoch: 04 | Epoch Time: 0m 7s
	Train Loss: 22.970 | Val. Loss: 23.112


                                                                                                                                                                                                                         

Epoch: 05 | Epoch Time: 0m 8s
	Train Loss: 23.004 | Val. Loss: 23.066


                                                                                                                                                                                                                         

Epoch: 06 | Epoch Time: 0m 8s
	Train Loss: 23.008 | Val. Loss: 23.144


                                                                                                                                                                                                                         

Epoch: 07 | Epoch Time: 0m 8s
	Train Loss: 22.951 | Val. Loss: 23.087


                                                                                                                                                                                                                         

Epoch: 08 | Epoch Time: 0m 7s
	Train Loss: 22.961 | Val. Loss: 23.111


                                                                                                                                                                                                                         

KeyboardInterrupt: 

In [72]:
def print_board_bar():
    """Print the 26-dash horizontal rule that frames the board grid."""
    print("-" * 26)

def print_board(board, start_rows=6):
    """Print a board as a text grid, top row first.

    ``board`` is indexed ``[row, col, channel]`` with row 0 printed at the bottom
    (labelled 1). A positive channel-1 cell prints ``S`` in the bottom ``start_rows``
    rows and ``E`` elsewhere; otherwise a positive channel-0 cell prints ``M``; empty
    cells print ``-``. Exactly one symbol is printed per cell, so the grid stays
    aligned even when both channels are set for the same cell.

    Args:
        board: array-like with ``.shape`` (n_rows, n_cols, >=2), e.g. (18, 11, 2).
        start_rows: height of the start zone; defaults to 6 (the original layout).
    """
    n_rows, n_cols = board.shape[0], board.shape[1]
    bar = "-" * (4 + 2 * n_cols)  # 26 for the 11-column board

    print(" " * 3 + "| " + "".join(f"{chr(65 + col)} " for col in range(n_cols)))
    print(bar)
    for row in range(n_rows - 1, -1, -1):
        cells = []
        for col in range(n_cols):
            if board[row, col, 1] > 0:
                cells.append("S" if row < start_rows else "E")
            elif board[row, col, 0] > 0:
                cells.append("M")
            else:
                cells.append("-")
        print(f"{row + 1:2} | " + " ".join(cells) + " ")
    print(bar)

In [108]:
# Sample one latent vector from the N(0, I) prior and decode it into a flat per-cell
# score map; build_problem_from_map turns such a map into a concrete board problem.
@torch.no_grad()
def _decode_prior_sample():
    latent = torch.normal(0, 1, size=(1, LATENT_DIM)).to(device)
    return latent, decoder(latent)

noise, generated_board_map = _decode_prior_sample()

In [109]:
# Turn a decoded hold-score map into a concrete board problem.
def build_problem_from_map(generated_board_map, num_middle_holds, num_start_holds,
                           num_end_holds, n_rows=18, n_cols=11, start_rows=6):
    """Select the highest-scoring cells of a decoded map as the holds of a problem.

    Args:
        generated_board_map: tensor with n_rows * n_cols * 2 elements (e.g. the
            (1, 396) decoder output); it is reshaped to (n_rows, n_cols, 2).
        num_middle_holds: number of channel-0 cells to pick anywhere on the board.
        num_start_holds: number of channel-1 cells to pick in the bottom
            ``start_rows`` rows.
        num_end_holds: number of channel-1 cells to pick in the top row.
        n_rows, n_cols, start_rows: board geometry; defaults match the original board.

    Returns:
        numpy float array of shape (n_rows, n_cols, 2) with 1.0 at the chosen holds.

    Raises:
        RuntimeError (from torch.topk) if a requested count exceeds its candidate cells.

    NOTE(review): middle holds are picked independently of start/end holds, so the
    same cell can be selected in both channels.
    """
    hold_map = generated_board_map.reshape(n_rows, n_cols, 2)
    board = np.zeros((n_rows, n_cols, 2))

    def _top_cells(scores, k):
        # (row_indices, col_indices) of the k highest entries of the 2-D `scores`.
        flat_idx = torch.topk(scores.flatten(), k).indices.cpu().numpy()
        return np.unravel_index(flat_idx, tuple(scores.shape))

    rows, cols = _top_cells(hold_map[:, :, 0], num_middle_holds)
    board[rows, cols, 0] = 1
    rows, cols = _top_cells(hold_map[:start_rows, :, 1], num_start_holds)
    board[rows, cols, 1] = 1
    # End holds come from the top row only; shift row indices back to board coordinates.
    rows, cols = _top_cells(hold_map[n_rows - 1:, :, 1], num_end_holds)
    board[rows + n_rows - 1, cols, 1] = 1
    return board

In [128]:
# The full problem generator: decode a latent sample drawn from the N(0, I) prior into
# a hold-score map, keep the highest-scoring cells as holds, and print the result.
NUM_MIDDLE_HOLDS = 4
NUM_START_HOLDS = 2
NUM_END_HOLDS = 1

# Inference mode. The current MLP_Decoder.forward applies no dropout, but this keeps
# sampling correct if it ever does.
decoder.eval()
with torch.no_grad():
    noise = torch.normal(0, 1, size=(1, LATENT_DIM)).to(device)
    generated_board_map = decoder(noise)
gen_board = build_problem_from_map(
    generated_board_map, NUM_MIDDLE_HOLDS, NUM_START_HOLDS, NUM_END_HOLDS
)
print_board(gen_board)

   | A B C D E F G H I J K 
--------------------------
18 | - - - - - - E - - - - 
17 | - - - - - - - - - - - 
16 | - - - - - - - - - - - 
15 | - - - - M - - - - - - 
14 | - - - - - - - - - - - 
13 | - - - - - - M - - - - 
12 | - - - - - - - - - - - 
11 | - - - - - - - - - - - 
10 | - - - - M - - M - - - 
 9 | - - - - - - - - - - - 
 8 | - - - - - - - - - - - 
 7 | - - - - - - - - - - - 
 6 | - - - - - - - - - - - 
 5 | - - - - - S - - - - - 
 4 | - - - - - - S - - - - 
 3 | - - - - - - - - - - - 
 2 | - - - - - - - - - - - 
 1 | - - - - - - - - - - - 
--------------------------
