# Imports

In [1]:
import torchvision
import torch
import numpy as np
from collections import defaultdict
import torch.nn as nn

import pandas as pd
from sklearn.model_selection import train_test_split
from torchsummary import summary
import tqdm
import random
import math

In [2]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

Device:  cuda


In [3]:
!unzip MNIST_CSV.zip

# Dataloading

In [4]:
train_set = pd.read_csv("mnist_train.csv")
test_images = pd.read_csv("mnist_test.csv")

train_images, val_images, train_labels, val_labels = train_test_split(train_set.iloc[:, 1:],
                                                                     train_set.iloc[:, 0],
                                                                     test_size=0.3)

train_images.reset_index(drop=True, inplace=True)
val_images.reset_index(drop=True, inplace=True)
train_labels.reset_index(drop=True, inplace=True)
val_labels.reset_index(drop=True, inplace=True)

train_images = train_images.to_numpy(dtype='float32')
train_labels = train_labels.to_numpy(dtype='float32')

val_images = val_images.to_numpy()
val_labels = val_labels.to_numpy()

test_images = test_images.to_numpy()

In [5]:
class JigsawDataset(torch.utils.data.Dataset):
    def __init__(self, images, labels, permutations=10, img_transformer=None):
        self.images = images
        self.labels = labels
        self.permutations = permutations

        self.N = len(self.images)
        self.grid_size = 3

    def __retrieve_permutations(self):
        nums = range(self.grid_size * self.grid_size)

        return np.random.permutation(nums)

    def __get_image(self, index):
        return self.images[index]

    def __get_tiles(self, index):
        img = self.__get_image(index).reshape(28, 28)
        tiles = np.zeros((9, 9, 9), dtype='float32')

        for i in range(3):
            for j in range(3):
                tiles[i * 3 + j] = img[i*9:i*9+9, j*9:j*9+9]

        return tiles

    def __getitem__(self, index):
        n_grids = self.grid_size ** 2
        tiles = self.__get_tiles(index)

        order = self.__retrieve_permutations()

        data = [torch.from_numpy(tiles[order[t]]) for t in range(n_grids)]

        item = torch.stack(data, 0)
        return item, order, int(self.labels[index])

    def __len__(self):
        return self.N

In [6]:
class JigsawDatasetTest(torch.utils.data.Dataset):
    def __init__(self, images, permutations=10, img_transformer=None):
        self.images = images
        self.permutations = permutations

        self.N = len(self.images)
        self.grid_size = 3

    def __retrieve_permutations(self):
        nums = range(self.grid_size * self.grid_size)

        return np.random.permutation(nums)

    def __get_image(self, index):
        return self.images[index]

    def __get_tiles(self, index):
        img = self.__get_image(index).reshape(28, 28)
        tiles = np.zeros((9, 9, 9), dtype='float32')

        for i in range(3):
            for j in range(3):
                tiles[i * 3 + j] = img[i*9:i*9+9, j*9:j*9+9]

        return tiles

    def __getitem__(self, index):
        n_grids = self.grid_size ** 2
        tiles = self.__get_tiles(index)

        order = self.__retrieve_permutations()
        
        data = [torch.from_numpy(tiles[order[t]]) for t in range(n_grids)]

        item = torch.stack(data, 0)
        return item, order

    def __len__(self):
        return self.N

In [7]:
train_data = JigsawDataset(train_images, train_labels)
val_data = JigsawDataset(val_images, val_labels)
test_data = JigsawDatasetTest(test_images)

train_loader = torch.utils.data.DataLoader(
        dataset     = train_data,
        num_workers = 2,
        batch_size  = 1,
        pin_memory  = True,
        shuffle     = True,
)

val_loader = torch.utils.data.DataLoader(
        dataset     = val_data,
        num_workers = 1,
        batch_size  = 1,
        pin_memory  = True,
        shuffle     = False,
    )

test_loader = torch.utils.data.DataLoader(
        dataset     = test_data,
        num_workers = 1,
        batch_size  = 1,
        pin_memory  = True,
        shuffle     = False,
    )

print("Train dataset samples = {}, batches = {}".format(train_data.__len__(), len(train_loader)))
print("Val dataset samples = {}, batches = {}".format(val_data.__len__(), len(val_loader)))
print("Test dataset samples = {}, batches = {}".format(test_data.__len__(), len(test_loader)))


Train dataset samples = 41999, batches = 41999
Val dataset samples = 18000, batches = 18000
Test dataset samples = 9999, batches = 9999


# State and environment

In [8]:
import sys
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt

# Constants
GAMMA = 0.9

class State:
    def __init__(self, tiles, current_shuffle_list, last_action=None):
        self.tiles = tiles
        self.current_shuffle_list = current_shuffle_list
        self.last_action = last_action

        self.legal_actions = []

        for i in range(8, 0, -1):
          for j in range(i):
            self.legal_actions.append((i,j))

        num_incorrect = 9 - self.num_correct()
        print(f"Initial incorrect: {num_incorrect}")

    def get_legal_actions(self):
        return self.legal_actions

    def num_correct(self):
      count = 0
      for i in range(9):
        if i == self.current_shuffle_list[0][i]:
          count += 1

      return count

    def is_game_over(self):
        for i in range(9):
            if i != self.current_shuffle_list[0][i]:
                return False

        return True

    def move(self, action):
        i, j = action
        self.tiles[0][i], self.tiles[0][j] = self.tiles[0][j].clone(), self.tiles[0][i].clone()

        self.current_shuffle_list[0][i], self.current_shuffle_list[0][j] = self.current_shuffle_list[0][j].item(), self.current_shuffle_list[0][i].item()

        self.last_action = action
        return self

In [9]:
class JigsawEnv():

    def __init__(self, tiles, current_shuffle_list):
        self.state = State(tiles, current_shuffle_list)
        self.num_incorrect = 9 - self.state.num_correct()

        self.actions = []

        for i in range(8, 0, -1): # bc 0-8 is 9 numbers
          for j in range(i):
            self.actions.append((i,j))

    """
    Compute the reward for the current state of the environment.

    Parameters:
    - None

    Returns:
    - reward (float)
    """

    def computeReward(self):
        reward = 0

        if self.state.is_game_over():
            reward += 100
        else:
            for i in range(9):
                if self.state.current_shuffle_list[0][i] == i:
                    reward += 1

        return reward

    """
    Step the environment by one timestep.

    Parameters:
    - action

    Returns:
    - observation (object): agent's observation of the current environment
    - reward (float) : amount of reward returned after previous action
    - done (boolean): whether the episode has ended, in which case further step() calls will return undefined results
    - info (dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning)

    """
    def step(self, action):
        observation = self.state.move(action)
        reward =  self.computeReward()

        return observation, reward, self.state.is_game_over()

    """
    Renders the environment.
    """
    def render(self):
        fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(6,6))
        fig.subplots_adjust(hspace=.1)
        for i in range(3):
            for j in range(3):
                ax[i][j].axis('off')
                ax[i][j].imshow(self.state.tiles[(self.state.order==i * 3 + j).nonzero(as_tuple=True)[0].item()], cmap='gray')

In [10]:
"""
Global variable used to map policy network tensor to action tuple (i,j)
"""

actions = []

for i in range(8, 0, -1): # bc 0-8 is 9 numbers
  for j in range(i):
    actions.append((i,j))

# Policy Network

In [11]:
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    def __init__(self, num_actions=36, lr=1e-3):
        super(PolicyNetwork, self).__init__()

        self.num_actions = num_actions
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(32, 16, kernel_size=9, stride=9)
        self.bn3 = nn.BatchNorm2d(16)

        self.linear1 = nn.Linear(144, 64)
        self.linear2 = nn.Linear(64, 36)

        self.activation = nn.SiLU()

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, state):
        out = self.assemble(state).to(DEVICE)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.activation(out)


        out = self.conv2(out)
        out = self.bn2(out)
        out = self.activation(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.activation(out)

        out = out.view(out.size(0), -1)
        out = self.linear1(out)
        out = self.activation(out)

        out = self.linear2(out)
        out = F.softmax(out, dim=1)
        return out

    """
    Assembles 9x9x9 tensor into 1x27x27 tensor representing full image

    Parameters:
    - state (tensor): state of the environment

    Returns:
    - result_image (tensor): tensor representing full image
    """

    def assemble(self, state):
      num_batch = state.size(dim=0)
      result_image = torch.zeros(num_batch, 1, 27, 27)

      for i in range(9):
          row = i // 3  # Calculate the row index in the 3x3 grid
          col = i % 3   # Calculate the column index in the 3x3 grid

          # Calculate the starting position to place the image in the result tensor
          start_row = row * 9
          start_col = col * 9

          # Assign the 9x9 image to the correct position in the result tensor
          result_image[:, :, start_row:start_row + 9, start_col:start_col + 9] = state[:, 3*row+col, :, :]

      return result_image

    """
    Returns tensor of probabilities for each action

    Parameters:
    - state (tensor): state of the environment

    Returns:
    - probs (tensor): tensor of probabilities for each action
    """

    def get_action(self, state):
        probs = self.forward(Variable(state))
        probs = probs[0]
        probs = probs.cpu()
        return probs

In [12]:
policy_network = PolicyNetwork().to(DEVICE)
dict1 = torch.load('FINAL.pth')
policy_network.load_state_dict(dict1['model_state_dict'])

<All keys matched successfully>

In [13]:
total_params = sum(
	param.numel() for param in policy_network.parameters()
)
print(f'PARAMS: {total_params:,}')


PARAMS: 66,228


# Value Network

In [14]:
class ValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.linear = nn.Linear(1000, 16)
        self.linear2 = nn.Linear(16, 1)
        self.softmax = nn.Softmax()
        self.flat = nn.Flatten()

    def forward(self, x):

        x = self.resnet(x)
        x = self.linear(x)
        x = self.linear2(x)
        # x = self.softmax(x)
        x = self.flat(x)
        x = torch.squeeze(x)

        return x

In [15]:
from google.colab import drive

value_network = ValueNetwork().to(DEVICE)
checkpoint = torch.load('/content/model_policy.pth')
value_network.load_state_dict(checkpoint['model_state_dict'])
value_network.eval()



ValueNetwork(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [16]:
total_params = sum(
	param.numel() for param in value_network.parameters()
)
print(f'PARAMS: {total_params:,}')


PARAMS: 11,705,545


# MCTS

In [17]:
policy_network.eval()
value_network.eval()

from IPython.display import clear_output
clear_output()

In [76]:
"""
TREE SEARCH VISUALIZATION
"""

import networkx as nx
from networkx.drawing.nx_agraph import write_dot, graphviz_layout
import copy


def display_graph(tree):
    G = nx.DiGraph()

    def add_edges(node):
        for child in node.children.values():
            if child.visits != 0:
              G.add_edge(node, child, label=(round(float(child.prior), 2), tuple(child.state.last_action)))
              add_edges(child)

    add_edges(tree)
    write_dot(G,'test.dot')

    # same layout using matplotlib with no labels
    pos =graphviz_layout(G, prog='dot')
    edge_labels = nx.get_edge_attributes(G, 'label')
    nx.draw(G, pos, with_labels=False, arrows=True, node_size=8)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red', font_size=8)

    plt.show()

In [112]:
def render(state):
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(6,6))
    fig.subplots_adjust(hspace=.1)
    for i in range(3):
        for j in range(3):
            ax[i][j].axis('off')
            ax[i][j].imshow(state.tiles[0][i * 3 + j], cmap='gray')

def simulate(node, value_network):
    """
    Use value network to approximate the value of the state.

    Alternatively, we can also using the policy net to perform a greedy search for a terminal state and backprop this estimated value.
    AlphaZero uses a mixing of random rollouts and estimation from their value network.
    """

    assembled = policy_network.assemble(node.state.tiles).to(DEVICE)

    value = value_network(assembled.repeat(1,3,1,1))

    transformed = math.exp(5 * (value - 0.65))
    return transformed *

def custom_sort_key(children_with_values):
    return (children_with_values[1], children_with_values[0].prior)



In [122]:
class TreeNode:
    def __init__(self, state, prior, policy_network=None, value_network=None, parent=None):
        self.state = state
        self.parent = parent
        self.children = {}
        self.visits = 0
        self.value_total = 0.0
        self.prior = prior

        self.policy_network = policy_network
        self.value_network = value_network

    def PUCT(self, child, exploration_weight):
        """
        Inspired by AlphaGo, upper confidence bound for trees.


        PUCT = Q(s, a) + c * P(s, a) * sqrt(parent.vists) / (child.vists + 1)

        Q(s, a) is zero when there have been no visits to that state-action pair.
        Otherwise Q(s, a) is the average reward from vists.

        c is the exploration weight constant.

        P(s, a) is the prior probablity given by the policy network.
        """

        prior_score = exploration_weight * child.prior * math.sqrt(self.visits) / (child.visits + 1)

        if child.visits == 0:
          Q = 0
        else:
          Q = self.value_total / child.visits
        print(f"Q:{Q}")
        print(f"prior_score{prior_score}")
        return Q + prior_score


    def best_child(self, exploration_weight=1.0):
        """
        Using PUCT with exploration_weight parameter.

        We determine the action values from node statistics.
        We return the child with the highest action value.
        """
        if not self.children:
            return None

        children_with_values = [
            (child, self.PUCT(child, exploration_weight))
            for child in self.children.values()
        ]

        # print(sorted(children_with_values, key=custom_sort_key))
        # print(f"action values: {[float(value) for child, value in children_with_values]}")
        return sorted(children_with_values, key=custom_sort_key)[0][0]


    def add_children(self):
        """
        Add all children/edges/actions to the current node
        """
        if len(self.children) == 0:
          self.policy = policy_network.get_action(self.state.tiles)

          for i, prob in enumerate(self.policy):
            action = actions[i]
            next_state = copy.deepcopy(self.state).move(action)

            child = TreeNode(next_state, prob, self.policy_network, self.value_network, parent=self)
            self.children[action] = child


    def expand_and_evaluate(self):
        """
        Given an action, perform that action and expand the graph
        """
        self.add_children()
        reward = simulate(self, self.value_network)

        self.backpropagate(reward)

    def print_node_statistics(self):
        for node in self.children.values():
          print(f"Action: {node.state.last_action}, value: {node.value_total}, visits: {node.visits}, prior: {node.prior}")

    def most_visited_child(self):
        max = 0
        max_visited = None
        for child in self.children.values():
          if child.visits > max:
            max = child.visits
            max_visited = child

        return max_visited

    def depth(self):
        count = 0
        node = self
        while node is not None:
          count += 1
          node = node.parent

        return count

    def backpropagate(self, result):
        """
        Backpropogate up the tree using the reward from the value network.
        Update node statistics.
        """

        node = self
        while node is not None:
            node.visits += 1
            node.value_total += result
            node = node.parent

def mcts(root_node, policy_network=None, value_network=None, iterations=1000, exploration_weight=1.0):
    """
    We perform the following for steps for some number of iterations.
    This function estimates the best action from the given root state.

    1. Selection: choose node with highest action value
    2. Expansion: leaf node is expanded and policies are computed using the policy network
    3. Evaluation: using the value network to estimate the value of the state
    4. Backprop: we estimate the reward using the value network and update node statistics of all parents
    """

    root_node.add_children()
    root_node.visits = 1

    for i in range(iterations):
        node = root_node

        # DEBUG prints
        # print(f"iteration {i}")
        # node.print_node_statistics()

        # Selection
        while node.children:
            node = node.best_child(exploration_weight)

        # Expansion and Backpropgation (evaluation)
        node.expand_and_evaluate()

        exploration_weight = max(0, exploration_weight * 0.99)

    # display_graph(root_node)
    # Return most visited child
    best_child = root_node.most_visited_child()
    return best_child.state.last_action, best_child


In [123]:
tiles, order = None, None
for episode, (x, y, num) in enumerate(train_loader):
  tiles, order = x, y

  print(f"Digit: {num}")
  break

Digit: tensor([2])


In [None]:
env = JigsawEnv(tiles, order)
simulate(env.state, value_network)

In [None]:
print(f"Initial state: {env.state.current_shuffle_list}")
root_node = TreeNode(env.state, 0.0, policy_network, value_network)

print(f"Policy greedy recommendation: {actions[np.argmax(policy_network.get_action(env.state.tiles).detach().numpy())]}")

best_action, _ = mcts(root_node, policy_network, value_network, iterations=100, exploration_weight=100)
print(f"MCTS recommends: {best_action}")

root_node.print_node_statistics()

print(f"Initial state: {env.state.current_shuffle_list}")


In [None]:
root_node.children[(5,4)].print_node_statistics()

In [None]:
def solve(env, policy_network, value_network, iterations=50, exploration_weight=1.0):
    """
    Solves the environment using MCTS.

    Parameters:
    - env: environment
    - policy_network: policy network
    - value_network: value network
    - iterations: number of iterations to run MCTS
    - exploration_weight: exploration weight

    Returns:
    - actions: list of actions to solve the environment
    - state: final state
    """
    root_node = TreeNode(env.state, 0, policy_network, value_network)

    confidence = 0

    iter = 0
    max_swaps = 25

    while confidence < 0.85 and iter < max_swaps:
        best_action, root_node = mcts(root_node, policy_network, value_network, iterations=iterations, exploration_weight=exploration_weight)

        confidence = value_network(policy_network.assemble(root_node.state.tiles).to(DEVICE).repeat(1, 3, 1, 1))

        iter += 1

    return root_node.state, iter



In [None]:
def is_success(final_state):
    """
    Checks if the final state is the goal state.

    ***
    If tiles are identitical, we disregard if they in the "wrong" position
    ***

    Parameters:
    - env: environment
    - final_state: final state

    Returns:
    - success: True if the final state is the goal state, False otherwise
    """

    for i in range(9):
        if i != final_state.current_shuffle_list[0][i]:
            if torch.all(final_state.tiles[0][i] == 0):
                continue
            else:
                return False

    return True


In [None]:
def solve2(env, policy_network, value_network, iterations=50, exploration_weight=1.0):
    """
    Solves the environment using MCTS.

    Parameters:
    - env: environment
    - policy_network: policy network
    - value_network: value network
    - iterations: number of iterations to run MCTS
    - exploration_weight: exploration weight

    Returns:
    - actions: list of actions to solve the environment
    - state: final state
    """
    state = env.state
    root_node = TreeNode(state, 0, policy_network, value_network)

    confidence = 0

    iter = 0
    max_swaps = 25

    while not is_success(state) and iter < 50:
        best_action, root_node = mcts(root_node, policy_network, value_network, iterations=iterations, exploration_weight=exploration_weight)
        state = root_node.state

        # confidence = value_network(policy_network.assemble(root_node.state.tiles).to(DEVICE).repeat(1, 3, 1, 1))
        iter += 1

    return root_node.state, iter



In [None]:
"""
TEST!

Solving boards to completion using MCTS
"""
success = 0
steps = 0

for episode, (x, y, num) in enumerate(val_loader):
  tiles, order = x, y

  env = JigsawEnv(tiles, order)
  state, swaps = solve2(env, policy_network, value_network, iterations=100, exploration_weight=1)

  if is_success(state):
    success += 1
    steps += swaps
    print(f"Success in {swaps} steps!")
  else:
    print("Failed :(")

  render(state)

print(f"Percent success: {success / len(val_loader) * 100}%")
print(f"Average swaps on success: {steps / len(val_loader) * 100}%")

In [None]:
print(success/episode)