In [1]:
# Install the notebook's runtime dependencies into the kernel's environment.
# matplotlib and numpy are imported further down but are not installed by this cell.
# TODO(review): pin versions (e.g. `%pip install -q torch==X.Y`) so a re-run reproduces the environment.
%pip install torch
%pip install wandb
%pip install einops

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from math import log2,ceil
import torch

# -------------------- PUZZLE SETTINGS -------------------- 

# Project directory. The machine-specific default can be overridden with the
# ETERNITY_NETWORK_PATH environment variable instead of editing the notebook.
import os
PATH = os.environ.get('ETERNITY_NETWORK_PATH', '/home/wsl/Polymtl/H23/INF6201/Projet/Network')


# global 2-swap gives a size 32640 neighborhood (C(256, 2) pairs on a 16 x 16
# board) which can be too much for the GPU. Capping the swapping range helps
# reduce the neighborhood without losing connectivity.
SWAP_RANGE = 2

MAX_BSIZE = 16

# Largest board plus a one-cell frame on every side.
PADDED_SIZE = MAX_BSIZE + 2

# Edge indices inside a piece tuple: (north, south, west, east).
NORTH = 0
SOUTH = 1
WEST = 2
EAST = 3

# GRAY marks the outer frame; BLACK / RED / WHITE are only used for drawing.
GRAY = 0
BLACK = 23
RED = 24
WHITE = 25
# Number of puzzle colour ids (0 .. N_COLORS - 1), GRAY included.
N_COLORS = 23


# -------------------- NETWORK SETTINGS -------------------- 

GAE_LAMBDA = 0.988
ENTROPY_WEIGHT = 0.007
VALUE_WEIGHT = .5
HIDDEN_SIZES = [32,32,64,128]
KERNEL_SIZES = [3,3,3,3]

# -------------------- TRAINING SETTINGS -------------------- 
UNIT = torch.float

# How each edge colour is encoded in the board tensor: 'binary', 'ordinal' or 'one_hot'.
ENCODING = 'binary'

# Number of tensor channels used per edge for the chosen encoding.
if ENCODING == 'binary':
    COLOR_ENCODING_SIZE = ceil(log2(N_COLORS))
elif ENCODING == 'ordinal':
    COLOR_ENCODING_SIZE = 1
elif ENCODING == 'one_hot':
    COLOR_ENCODING_SIZE = N_COLORS
else:
    raise ValueError(f"Encoding {ENCODING} not supported")
  
EPOCHS = 10
CHECKPOINT_PERIOD = 256*200
MINIBATCH_SIZE = 256
HORIZON = 4 * 256
OPT_EPSILON = 1e-6
LR = 6e-5
GAMMA = 0.95
CLIP_EPS = 0.2

# Run metadata; presumably handed to wandb by the training cell (not visible here — TODO confirm).
CONFIG = {
    'encoding':ENCODING,
    'unit':UNIT,
    'Batch size':MINIBATCH_SIZE,
    'Gamma':GAMMA,
}

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from matplotlib.path import Path
from matplotlib.lines import Line2D
import numpy as np

import math
import copy

GRAY = 0
BLACK = 23
RED = 24
WHITE = 25

NORTH = 0
SOUTH = 1
WEST = 2
EAST = 3


class EternityPuzzle:
    """Edge-matching puzzle instance read from a text file.

    Pieces are 4-tuples of colour ids indexed by the direction constants above,
    i.e. ``(north, south, west, east)``. A solution is a list of pieces where
    index ``k = board_size * j + i`` is column ``i`` of row ``j``; row ``j == 0``
    is the one whose south edges must be GRAY.
    """

    def __init__(self, instance_file):
        """Parse ``instance_file``.

        The first line holds the board side length. Every following non-blank
        line holds the (first) four colour ids of one piece. Blank lines, such as
        a trailing empty line, are skipped instead of raising IndexError.

        :param instance_file: path of the instance text file.
        :raises AssertionError: if the number of pieces is not ``board_size ** 2``
            or a piece line has fewer than four colour ids.
        """
        with open(instance_file) as file:
            lines = file.readlines()

        self.board_size = int(lines[0])
        self.n_piece = self.board_size ** 2
        # Horizontal plus vertical adjacencies between pieces inside the board.
        self.n_internal_connection = 2 * self.board_size * (self.board_size - 1)
        # ... plus one connection per edge segment along the outer frame.
        self.n_total_connection = self.n_internal_connection + self.board_size * 4

        self.piece_list = [tuple(int(v) for v in line.split()[:4]) for line in lines[1:] if line.strip()]

        self.n_color = max(color for piece in self.piece_list for color in piece) + 1

        assert (len(self.piece_list) == self.n_piece)

        for p in self.piece_list:
            assert (len(p) == 4)

    def generate_rotation(self, piece):
        """Return ``piece`` followed by its three successive quarter-turn rotations.

        One quarter turn maps ``(N, S, W, E)`` to ``(W, E, S, N)``. Applying it
        again gives the 180 and 270 degree entries.
        """
        initial_shape = piece
        rotation_90 = (piece[2], piece[3], piece[1], piece[0])
        rotation_180 = (piece[1], piece[0], piece[3], piece[2])
        rotation_270 = (piece[3], piece[2], piece[0], piece[1])

        return [initial_shape, rotation_90, rotation_180, rotation_270]

    def get_total_n_conflict(self, solution):
        """Count mismatched edges in ``solution``.

        An internal conflict is a pair of adjacent pieces whose touching edges
        differ. A border conflict is an outer-frame edge that is not GRAY.
        """
        n_conflict = 0

        for j in range(self.board_size):
            for i in range(self.board_size):

                k = self.board_size * j + i
                k_west = k - 1                  # neighbour at column i - 1
                k_south = k - self.board_size   # neighbour at row j - 1

                if i > 0 and solution[k][WEST] != solution[k_west][EAST]:
                    n_conflict += 1

                if i == 0 and solution[k][WEST] != GRAY:
                    n_conflict += 1

                if i == self.board_size - 1 and solution[k][EAST] != GRAY:
                    n_conflict += 1

                if j > 0 and solution[k][SOUTH] != solution[k_south][NORTH]:
                    n_conflict += 1

                if j == 0 and solution[k][SOUTH] != GRAY:
                    n_conflict += 1

                if j == self.board_size - 1 and solution[k][NORTH] != GRAY:
                    n_conflict += 1

        return n_conflict

    def display_solution(self, solution, output_file):
        """Draw ``solution`` with matplotlib and save the figure to ``output_file``.

        Missing trailing pieces are padded with all-WHITE placeholders. The board
        is surrounded by a one-cell GRAY frame. Each piece is drawn as four
        coloured triangles, and each mismatched edge is outlined in red. The
        title reports total and internal connection statistics.
        """
        if len(solution) < self.n_piece:
            solution = solution + [(WHITE, WHITE, WHITE, WHITE)] * (self.n_piece - len(solution))

        origin = 0
        size = self.board_size + 2  # board plus the GRAY frame on each side

        color_dict = self.build_color_dict()
        edge_color = color_dict[BLACK]

        fig, ax = plt.subplots()

        n_total_conflict = self.get_total_n_conflict(solution)
        n_internal_conflict = 0

        instructions = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]
        frame = (0, size - 1)

        def mark_conflict(start, end):
            # Line2D takes all x coordinates first, then all y coordinates.
            xs, ys = zip(start, end)
            ax.add_line(Line2D(xs, ys, color=color_dict[RED], lw=3))

        for j in range(size):  # y-axis
            for i in range(size):  # x-axis
                if i in frame or j in frame:
                    # Rectangle takes (xy, width, height): a frame cell is 1 x 1.
                    # The former i + 1 / j + 1 sizes drew oversized, overlapping cells.
                    ax.add_patch(patches.Rectangle((i, j), 1, 1, fill=True, facecolor=color_dict[GRAY],
                                                   edgecolor=edge_color))
                    continue

                left_bot = (i, j)
                right_bot = (i + 1, j)
                right_top = (i + 1, j + 1)
                left_top = (i, j + 1)
                middle = (i + 0.5, j + 0.5)

                # Drawing coordinates are shifted by one because of the frame.
                k = self.board_size * (j - 1) + (i - 1)
                k_west = k - 1
                k_south = k - self.board_size
                piece = solution[k]

                # West/south edges are checked against GRAY on the first column/row
                # and against the neighbour otherwise. East/north edges are only
                # checked on the last column/row, where they must be GRAY. The checks
                # are independent so a 1 x 1 board validates all four frame edges.
                west_ok = piece[WEST] == GRAY if i == 1 else piece[WEST] == solution[k_west][EAST]
                south_ok = piece[SOUTH] == GRAY if j == 1 else piece[SOUTH] == solution[k_south][NORTH]
                east_ok = piece[EAST] == GRAY if i == size - 2 else True
                north_ok = piece[NORTH] == GRAY if j == size - 2 else True

                # Only west/south mismatches away from the frame are internal;
                # north/east can only fail on the frame itself.
                if not south_ok:
                    mark_conflict(left_bot, right_bot)
                    if j != 1:
                        n_internal_conflict += 1

                if not north_ok:
                    mark_conflict(left_top, right_top)

                if not west_ok:
                    mark_conflict(left_bot, left_top)
                    if i != 1:
                        n_internal_conflict += 1

                if not east_ok:
                    mark_conflict(right_bot, right_top)

                triangles = (
                    ([left_bot, middle, right_bot, left_bot], SOUTH),
                    ([right_top, middle, left_top, right_top], NORTH),
                    ([right_top, middle, right_bot, right_top], EAST),
                    ([left_bot, middle, left_top, left_bot], WEST),
                )
                for vertices, direction in triangles:
                    ax.add_patch(patches.PathPatch(Path(vertices, instructions),
                                                   facecolor=color_dict[piece[direction]],
                                                   edgecolor=edge_color))

        ax.set_xlim(origin, size)
        ax.set_ylim(origin, size)

        title = 'Eternity of size %d X %d\n' \
                'Total connections: %d    Internal connections: %d\n' \
                'Total Valid connections: %d     Internal valid internal connections: %d\n' \
                'Total Invalid connections: %d    Internal invalid connections: %d' % \
                (self.board_size, self.board_size,
                 self.n_total_connection, self.n_internal_connection,
                 self.n_total_connection - n_total_conflict, self.n_internal_connection - n_internal_conflict,
                 n_total_conflict, n_internal_conflict,
                 )
        ax.set_title(title)

        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(output_file)

    def print_solution(self, solution, output_file):
        """Write ``solution`` to ``output_file``.

        The file contains the conflict count, then the board size, then one piece
        per line. Each colour id on a piece line is followed by a single space.
        """
        with open(output_file, "w") as file:
            file.write(str(self.get_total_n_conflict(solution)) + "\n")
            file.write(str(self.board_size))
            for piece in solution:
                file.write("\n")
                for c in piece:
                    file.write(str(c) + " ")

    def build_color_dict(self):
        """Map colour ids (puzzle colours plus BLACK/RED/WHITE) to matplotlib colour names."""

        color_dict = {
            GRAY: 'gray',
            1: 'lightcoral',
            2: 'tab:blue',
            3: 'tab:orange',
            4: 'tab:green',
            5: 'gold',
            6: 'tab:purple',
            7: 'tab:brown',
            8: 'tab:pink',
            9: 'tab:olive',
            10: 'tab:cyan',
            11: 'deeppink',
            12: 'blue',
            13: 'slateblue',
            14: 'darkslateblue',
            15: 'darkviolet',
            16: 'teal',
            17: 'wheat',
            18: 'darkkhaki',
            19: 'indigo',
            20: 'fuchsia',
            21: 'lime',
            22: 'rosybrown',
            BLACK: 'black',
            RED: 'tab:red',
            WHITE: 'white'
        }
        return color_dict

    def hash_piece(self, piece):
        """Rotation-invariant key of ``piece``: the smallest of its four rotations."""
        return min(self.generate_rotation(piece))

    def verify_solution(self,solution):
        """Return True if ``solution`` uses exactly the instance's pieces, in any rotation."""
        hash_init = sorted([self.hash_piece(p) for p in self.piece_list])
        hash_sol = sorted([self.hash_piece(p) for p in solution])

        return hash_init == hash_sol



In [4]:
from datetime import datetime
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from hashlib import sha256
from torch.utils.data import TensorDataset,DataLoader
import torch.nn.init as init


class EpisodeBuffer:
    """Preallocated per-step storage for one episode, written as a ring buffer.

    ``push`` writes at ``ptr`` and wraps ``ptr`` back to 0 once ``capacity``
    steps have been written.
    """

    def __init__(self,capacity:int, n_tiles:int,bsize:int,device) -> None:
        self.capacity = capacity
        steps = (capacity,)
        self.state_buf = torch.empty(steps + (bsize, bsize, 4 * COLOR_ENCODING_SIZE), device=device)
        self.act_buf = torch.empty(steps, dtype=int, device=device)
        self.policy_buf = torch.empty(steps + (n_tiles,), device=device)
        self.value_buf = torch.empty(steps, device=device)
        self.next_value_buf = torch.empty(steps, device=device)
        self.rew_buf = torch.empty(steps, device=device)
        self.final_buf = torch.empty(steps, dtype=int, device=device)
        self.ptr = 0

    def push(
            self,
            state,
            action,
            policy,
            value,
            next_value,
            reward,
            final
            ):
        """Store one transition in slot ``ptr`` and advance ``ptr`` (wrapping at ``capacity``)."""
        slot = self.ptr
        destinations = (
            (self.state_buf, state),
            (self.act_buf, action),
            (self.policy_buf, policy),
            (self.value_buf, value),
            (self.next_value_buf, next_value),
            (self.rew_buf, reward),
            (self.final_buf, final),
        )
        for buffer, item in destinations:
            buffer[slot] = item

        self.ptr = (slot + 1) % self.capacity

    def reset(self):
        """Require ``ptr`` to be back at 0 (empty, or filled a whole number of times), then rewind.

        :raises OSError: carrying the current ``ptr`` when the buffer is partially filled.
        """
        if self.ptr:
            raise OSError(self.ptr)

        self.ptr = 0



class BatchMemory:
    """Stack of whole episodes used as one PPO training batch.

    Holds up to ``capacity`` episodes of ``ep_length`` steps each. ``load``
    copies an ``EpisodeBuffer`` into the next free row; its capacity is
    expected to equal ``ep_length``. Indexing with a buffer name, e.g.
    ``mem['rew_buf']``, returns only the rows loaded so far.
    """

    def __init__(self,n_tiles:int,bsize:int,capacity:int=MINIBATCH_SIZE, ep_length:int=256,device='cpu') -> None:
        self.capacity = capacity
        self.ep_length = ep_length
        self.n_tiles = n_tiles
        self.device = device
        self.bsize = bsize
        # Allocate directly: routing construction through reset() printed a
        # spurious "Memory not full : 0/<capacity>" warning on every construction.
        self._allocate()


    def load(self,buff:EpisodeBuffer):
        """Copy ``buff`` into row ``ptr``, advance ``ptr``, then reset ``buff``.

        ``buff.reset()`` raises OSError if the episode buffer is partially filled.
        """
        row = self.ptr
        self.state_buf[row] = buff.state_buf
        self.act_buf[row] = buff.act_buf
        self.policy_buf[row] = buff.policy_buf
        self.value_buf[row] = buff.value_buf
        self.next_value_buf[row] = buff.next_value_buf
        self.final_buf[row] = buff.final_buf
        self.rew_buf[row] = buff.rew_buf

        self.ptr += 1

        buff.reset()

    def reset(self):
        """Warn if fewer than ``capacity`` episodes were loaded, then reallocate empty storage."""
        if self.ptr != self.capacity:
            print(Warning(f'Memory not full : {self.ptr}/{self.capacity}'))
        self._allocate()

    def _allocate(self):
        """(Re)allocate uninitialised storage for ``capacity`` episodes and rewind ``ptr``."""
        shape = (self.capacity, self.ep_length)
        self.state_buf = torch.empty(shape + (self.bsize, self.bsize, 4 * COLOR_ENCODING_SIZE), device=self.device)
        self.act_buf = torch.empty(shape, dtype=int, device=self.device)
        self.policy_buf = torch.empty(shape + (self.n_tiles,), device=self.device)
        self.value_buf = torch.empty(shape, device=self.device)
        self.next_value_buf = torch.empty(shape, device=self.device)
        self.rew_buf = torch.empty(shape, device=self.device)
        self.final_buf = torch.empty(shape, dtype=int, device=self.device)
        self.ptr = 0

    def __getitem__(self,key):
        """Return attribute ``key`` restricted to the ``ptr`` rows loaded so far.

        This used to slice ``[:ptr + 1]``. Whenever the batch was not full, that
        also returned one row of uninitialised (``torch.empty``) memory.
        """
        return getattr(self,key)[:self.ptr]



class AdvantageBuffer():
    """Plain record of one sample (state, action, policy, value, reward); every field starts as None.

    NOTE(review): not referenced in the cells shown here — confirm it is still needed.
    """

    def __init__(self) -> None:
        self.state = self.action = self.policy = self.value = self.reward = None



class TabuList():
    """Tabu memory of board states keyed by the SHA-256 digest of their raw bytes.

    ``self.tabu`` maps a state digest to the step at which the entry expires.
    """

    def __init__(self,size, tabu_length=None) -> None:
        """
        :param size: stored as ``self.size``; not otherwise used by this class.
        :param tabu_length: number of steps a pushed state stays tabu. When None
            (the default), the notebook-level ``TABU_LENGTH`` is read at push
            time, exactly as before.
        """
        self.size = size
        self.tabu_length = tabu_length
        self.tabu = {}

    @staticmethod
    def _key(state: torch.Tensor) -> str:
        # hashlib needs a C-contiguous buffer. Hashing a transposed or sliced view
        # used to fail, so force contiguity (bytes are unchanged for contiguous input).
        return sha256(state.detach().cpu().contiguous().numpy()).hexdigest()

    def _length(self):
        # Resolve the tabu tenure lazily so the global is only needed when no
        # explicit tabu_length was given.
        return TABU_LENGTH if self.tabu_length is None else self.tabu_length

    def push(self,state:torch.Tensor, step:int):
        """Mark ``state`` as tabu until ``step + tabu length``."""
        self.tabu[self._key(state)] = step + self._length()

    def in_tabu(self,state):
        """Return True if ``state`` has an entry in the tabu list (expired entries count until ``filter``)."""
        return self._key(state) in self.tabu

    def filter(self,step:int):
        """Drop the entries whose expiry step is ``<= step``."""
        self.tabu = {k:v for k,v in self.tabu.items() if v > step}

    def get_update(self,batch:torch.Tensor,step:int):
        """Return the first non-tabu row of ``batch`` and its index, marking it tabu.

        The returned row is a detached copy with ``batch``'s device and dtype.
        It no longer aliases ``batch``'s storage.

        :returns: ``(row, index)``, or ``(None, None)`` if every row is tabu.
        """
        for i in range(batch.shape[0]):
            key = self._key(batch[i])
            if key not in self.tabu:
                self.tabu[key] = step + self._length()
                return batch[i].detach().clone(), i

        return None,None

    def fast_foward(self):
        """Shift every expiry step down by the smallest one, which is printed.

        Does nothing on an empty list; ``min()`` used to raise ValueError there.
        """
        if not self.tabu:
            return
        m = min(self.tabu.values())
        print(m)
        for k in self.tabu:
            self.tabu[k] -= m
        

class StoppingCriterion():
    """Staleness detector for a trajectory.

    An internal counter is updated at every step:

    * reset to 0 when the score beats ``best_score``;
    * decreased by 0.5 when the score beats the previous score;
    * increased by 1 otherwise.

    Once the counter exceeds ``threshold``, the trajectory is flagged stale
    until ``reset`` is called.
    """

    def __init__(self,threshold) -> None:
        self.threshold = threshold
        self.reset()

    def update(self,score,best_score):
        """Update the counter with the latest ``score`` and latch staleness past the threshold."""
        if score > best_score:
            self.counter = 0
        else:
            self.counter += -0.5 if score > self.prev_score else 1

        self.prev_score = score

        if self.counter > self.threshold:
            self.eos = True

    def is_stale(self):
        """Return True once the counter has exceeded the threshold since the last reset."""
        return self.eos

    def reset(self):
        """Clear the counter, the previous score and the staleness flag."""
        self.counter = 0
        self.prev_score = 0
        self.eos = False




class Conv3to2d(nn.Module):
    """Merge the last two spatial axes of a 5-D input, then apply a square 2-D convolution.

    ``(b, c, h, w, d)`` is rearranged to ``(b, c, h, w * d)`` before ``self.conv``.
    """

    def __init__(self,kernel_size,input_channels,layer_size,device) -> None:
        super().__init__()

        square_kernel = (kernel_size, kernel_size)
        self.conv = nn.Conv2d(
            input_channels,
            layer_size,
            square_kernel,
            dtype=UNIT,
            device=device,
        )

    def forward(self,x):
        merged = rearrange(x, 'b c h w d -> b c h (w d)')
        return self.conv(merged)

class View(nn.Module):
    """Reshape every sample to ``shape``, letting the leading (batch) axis be inferred."""

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, input):
        target = (-1, *self.shape)
        return input.view(target)

class SoftmaxStable(nn.Module):
    """Softmax over the last axis, computed after shifting each row by its maximum."""

    def forward(self, x):
        row_max, _ = x.max(dim=-1, keepdim=True)
        return F.softmax(x - row_max, dim=-1)
    
class ActorCritic(nn.Module):
    """Separate convolutional actor and critic networks over an encoded board.

    Both networks share one architecture up to their heads:
    Conv3d -> BN -> ReLU -> Conv3to2d -> BN -> ReLU -> Conv2d -> BN -> ReLU
    -> Conv2d -> BN -> flatten -> ReLU -> Linear(128) -> ReLU.
    The actor head is ``Linear(128, action_dim)`` + softmax, so it returns
    probabilities. The critic head is ``Linear(128, 1)``.

    The input is assumed to be ``(batch, 1, state_dim, state_dim, 4 * COLOR_ENCODING_SIZE)``.
    That matches ``PPOAgent.update`` (``unsqueeze(1)`` on the state buffer) and the
    flattened size computed below. TODO confirm for rollout callers.
    """

    def __init__(self, state_dim, action_dim, hidden_sizes, kernel_sizes, device):
        super(ActorCritic, self).__init__()
        self.state_dim = state_dim

        # Height after all four (stride-1, unpadded) convolutions.
        lin_h = self.lin_size(kernel_sizes,self.state_dim)
        # The 3-D conv shrinks width and depth. Conv3to2d then merges them into
        # one width axis of size w' * d', which the three 2-D convs shrink further.
        twodto3d = self.lin_size(kernel_sizes[:1],self.state_dim) * self.lin_size(kernel_sizes[:1],4 * COLOR_ENCODING_SIZE)
        lin_w = self.lin_size(kernel_sizes[1:],twodto3d)

        lin_size = lin_h * lin_w * hidden_sizes[-1]

        # The trunk is rebuilt for each network (no weight sharing). Head layers are
        # appended so Sequential indices (and thus state_dict keys) are unchanged.
        self.actor = nn.Sequential(
            *self._trunk(hidden_sizes, kernel_sizes, lin_size, device),
            nn.Linear(128, action_dim,device=device),
            SoftmaxStable()
        )

        self.critic = nn.Sequential(
            *self._trunk(hidden_sizes, kernel_sizes, lin_size, device),
            nn.Linear(128, 1,device=device),
        )

        # The former test ``type(m) == nn.Module`` never matched a concrete layer,
        # so Xavier initialisation was silently skipped. Target the layers whose
        # weights are >= 2-D; BatchNorm weights are 1-D and xavier cannot apply there.
        def init_weights(m):
            if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.Linear)):
                init.xavier_normal_(m.weight)

        self.critic.apply(init_weights)
        self.actor.apply(init_weights)

    @staticmethod
    def _trunk(hidden_sizes, kernel_sizes, lin_size, device):
        """Return a fresh list of the shared feature-extractor layers (ending with Linear(128) + ReLU)."""
        return [
            nn.Conv3d(1, hidden_sizes[0],kernel_sizes[0],device=device),
            nn.BatchNorm3d(hidden_sizes[0],device=device),
            nn.ReLU(),
            Conv3to2d(kernel_sizes[1],hidden_sizes[0],hidden_sizes[1],device),
            nn.BatchNorm2d(hidden_sizes[1],device=device),
            nn.ReLU(),
            nn.Conv2d(hidden_sizes[1], hidden_sizes[2],kernel_sizes[2],device=device),
            nn.BatchNorm2d(hidden_sizes[2],device=device),
            nn.ReLU(),
            nn.Conv2d(hidden_sizes[2], hidden_sizes[3],kernel_sizes[3],device=device),
            nn.BatchNorm2d(hidden_sizes[3],device=device),
            View((lin_size,)),
            nn.ReLU(),
            nn.Linear(lin_size, 128,device=device),
            nn.ReLU(),
        ]

    def lin_size(self, kernel_sizes, dim, strides=None):
        """Spatial size of an axis of length ``dim`` after unpadded convs with ``kernel_sizes``/``strides`` (default 1)."""

        size = dim

        if strides is None:
            strides = [1] * len(kernel_sizes)

        for ks,st in zip(kernel_sizes,strides):
            size = (size - ks) // st + 1

        return size


    def forward(self, state):
        """Return ``(policy, value)`` for ``state``."""
        policy = self.actor(state)
        value = self.critic(state)
        return policy, value

class PPOAgent:
    """Proximal Policy Optimisation agent wrapping an ``ActorCritic`` model.

    ``config`` must provide the keys read in ``__init__``: gamma, clip_eps, lr,
    epochs, minibatch_size, horizon, state_dim, entropy_weight, value_weight,
    gae_lambda, device, n_tiles, hidden_sizes, kernel_sizes.
    """

    def __init__(self,config):
        self.gamma = config['gamma']
        self.clip_eps = config['clip_eps']
        self.lr = config['lr']
        self.epochs = config['epochs']
        self.minibatch_size = config['minibatch_size']
        self.horizon = config['horizon']
        self.state_dim = config['state_dim']
        self.entropy_weight = config['entropy_weight']
        self.value_weight = config['value_weight']
        self.gae_lambda = config['gae_lambda']
        device = config['device']
        n_tiles = config['n_tiles']
        hidden_sizes = config['hidden_sizes']
        kernel_sizes = config['kernel_sizes']

        self.action_dim = n_tiles
        self.model = ActorCritic(self.state_dim, self.action_dim, hidden_sizes, kernel_sizes, device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr,weight_decay=1e-4)

    def get_action(self, policy:torch.Tensor,mask:torch.BoolTensor):
        """Sample one action per row among the positions allowed by ``mask``.

        ``softmax(policy)`` is zeroed outside ``mask``, renormalised, then sampled.

        NOTE(review): the actor network already ends with a softmax. If callers
        pass the model's output here, the probabilities are softmaxed twice,
        which flattens the distribution. Confirm what callers pass.

        :raises Exception: if ``mask`` has no allowed position at all.
        """
        if mask.count_nonzero() == 0:
            raise Exception("No playable tile")
        sm = torch.softmax(policy,-1) * mask
        sm /= sm.sum(dim=-1, keepdim=True)
        return torch.multinomial(sm,1)

    def compute_gae(self, rewards, values, next_values, finals):
        """Generalised Advantage Estimation along the last (time) axis.

        Accepts a single trajectory ``(T,)`` or a batch of trajectories ``(B, T)``,
        which is how ``BatchMemory`` stores episodes. The recursion used to run
        over the first axis; for batched input that chained different episodes
        together instead of walking back through time.
        """
        not_final = 1 - finals
        td_errors = rewards + self.gamma * next_values * not_final - values
        advantages = torch.zeros_like(td_errors)
        gae = 0
        for t in reversed(range(td_errors.shape[-1])):
            gae = td_errors[..., t] + self.gamma * self.gae_lambda * not_final[..., t] * gae
            advantages[..., t] = gae
        return advantages

    # String annotation: the class can be defined without BatchMemory in scope.
    def update(self, mem: 'BatchMemory'):
        """Run ``self.epochs`` PPO epochs over the episodes in ``mem``, then reset ``mem``.

        Advantages come from ``compute_gae`` along each episode's time axis, and
        returns are ``advantages + values``. Both are standardised per minibatch.
        The loss is: clipped-ratio policy loss + ``value_weight`` * MSE(value, returns)
        - ``entropy_weight`` * mean per-sample policy entropy. Metrics of the last
        minibatch of the last epoch are logged to wandb.

        Runs on CUDA when available and on the CPU otherwise; the device used to
        be hard-coded to ``'cuda'``.
        """
        advantages = self.compute_gae(
            rewards=mem['rew_buf'],
            values=mem['value_buf'],
            next_values=mem['next_value_buf'],
            finals=mem['final_buf']
            )

        returns = advantages + mem['value_buf']

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        t0 = datetime.now()
        dataset = TensorDataset(
            rearrange(mem['state_buf'],'b ep h w d -> (b ep) h w d').unsqueeze(1).to(device),
            rearrange(mem['act_buf'],'b ep -> (b ep)').to(device),
            rearrange(mem['policy_buf'],'b ep p -> (b ep) p').to(device),
            rearrange(advantages,'b ep -> (b ep)').to(device),
            rearrange(returns,'b ep -> (b ep)').to(device)
        )

        loader = DataLoader(dataset, batch_size=self.minibatch_size, shuffle=True)

        self.model = self.model.to(device)
        # Perform multiple update epochs
        for k in range(self.epochs):
            for batch in loader:
                batch_states, batch_actions, batch_old_policies, batch_advantages, batch_returns = batch

                batch_advantages = (batch_advantages - batch_advantages.mean()) / (batch_advantages.std()+1e-7)
                batch_returns = (batch_returns - batch_returns.mean()) / (batch_returns.std()+1e-7)

                batch_policy, batch_value = self.model(batch_states)

                # Probability ratio of the taken actions, new policy vs. behaviour policy.
                taken = batch_actions.unsqueeze(1)
                action_probs = batch_policy.gather(1, taken)
                old_action_probs = batch_old_policies.gather(1, taken)
                ratio = action_probs / (old_action_probs + 1e-6)
                clipped_ratio = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps)
                adv = batch_advantages.unsqueeze(1)
                policy_loss = -torch.min(ratio * adv, clipped_ratio * adv).mean()

                # squeeze(-1) keeps a (batch,) shape even for a 1-sample minibatch.
                value_loss = F.mse_loss(batch_value.squeeze(-1), batch_returns) * self.value_weight

                # Mean per-sample entropy. The former masked version flattened the
                # tensor before .sum(-1), so it summed over the whole minibatch.
                # clamp_min keeps log() finite where a probability is exactly 0.
                entropy = -(batch_policy * torch.log(batch_policy.clamp_min(1e-12))).sum(dim=-1).mean()
                entropy_loss = -self.entropy_weight * entropy

                loss = policy_loss + value_loss + entropy_loss

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),0.5)
                self.optimizer.step()

                # Sum of per-parameter gradient norms (after clipping), for logging.
                # The in-place clamp_ that used to follow ran after optimizer.step()
                # and never affected an update, so it was removed.
                g = sum(torch.norm(param.grad) for param in self.model.parameters() if param.grad is not None)

            if k == self.epochs -1:
                wandb.log({
                    "Total loss":loss,
                    "Cumul grad norm":g,
                    "Value loss":value_loss,
                    "Entropy loss":entropy_loss,
                    "Policy loss":policy_loss,
                    })

        if  not torch.cuda.is_available():
            self.model = self.model.cpu()
        print(datetime.now()-t0)
        mem.reset()

In [5]:
import argparse
import random
import sys
import torch
from torch import Tensor
from einops import rearrange, repeat


# -------------------- UTILS --------------------

def parse_arguments():
    """Parse the command line: ``--instance`` (default ``'input'``) and ``--hotstart`` (default ``False``)."""
    parser = argparse.ArgumentParser()

    # Instances parameters
    options = (
        ('--instance', 'input'),
        ('--hotstart', False),
    )
    for flag, default in options:
        parser.add_argument(flag, type=str, default=default)

    return parser.parse_args()


def initialize_sol(instance_file:str, device):
    """Load an instance and build its starting tensors.

    :returns: ``(board, tiles, n_tiles)``:
        ``board`` is an all-zero ``(board_size + 2, board_size + 2, 4 * COLOR_ENCODING_SIZE)``
        tensor on ``device``; ``tiles`` is the ``to_tensor`` encoding of the pieces,
        flattened to one row per piece and moved to ``device``; ``n_tiles`` is the
        number of pieces.
    """
    puzzle = EternityPuzzle(instance_file)
    piece_count = len(puzzle.piece_list)
    encoded = to_tensor(puzzle.piece_list)
    flat_tiles = rearrange(encoded, 'h w d -> (h w) d').to(device)
    side = puzzle.board_size + 2
    empty_board = torch.zeros((side, side, 4 * COLOR_ENCODING_SIZE), device=device)
    return empty_board, flat_tiles, piece_count




def pprint(state,bsize):
    """Print the central ``bsize x bsize`` crop of a ``PADDED_SIZE`` board, or of each board in a batch.

    A tensor whose first axis is not ``PADDED_SIZE`` long is treated as a batch
    and iterated. Otherwise it is treated as a single board.

    NOTE(review): ``pprint`` is defined again further down in this cell; that
    later definition is the one in effect once the cell has run.
    """
    offset = (PADDED_SIZE - bsize) // 2
    window = slice(offset, offset + bsize)

    boards = state if state.size()[0] != PADDED_SIZE else [state]
    for board in boards:
        print(board[window, window])
        



def ucb(q,count,step,c=0):
    """Upper-confidence-bound score ``q + c * sqrt(-log((count + 0.1) / (step + 0.1)))``.

    The exploration coefficient ``c`` used to be hard-coded to 0, which reduced
    the score to ``q``. The default keeps that behaviour; pass ``c > 0`` to enable
    the bonus. The bonus is real-valued only while ``count <= step``; beyond
    that the square root is NaN.
    """
    return q + c * torch.sqrt(-torch.log((count + 0.1)/(step + 0.1)))

def binary(x: Tensor, bits, dtype=None):
    """Little-endian bit decomposition of an integer tensor.

    Returns a tensor of shape ``x.shape + (bits,)`` whose entry ``[..., b]`` is 1
    if bit ``b`` of ``x`` is set and 0 otherwise, cast to ``dtype`` (``UNIT`` when
    None). The bit mask is created on ``x``'s device. It used to always live on
    the CPU, which made CUDA inputs fail with a device mismatch.
    """
    mask = 2**torch.arange(bits, device=x.device)
    out_dtype = UNIT if dtype is None else dtype
    return x.unsqueeze(-1).bitwise_and(mask).ne(0).to(out_dtype)


def to_tensor(list_sol:list, encoding = ENCODING,gray_borders:bool=False) -> Tensor:
    """
    Convert a solution from list format to a ``(b_size, b_size, 4 * enc)`` tensor.

    ``list_sol`` is a row-major list of tiles indexable by the NORTH/SOUTH/WEST/EAST
    constants. ``b_size = int(sqrt(len(list_sol)))``, and only the first
    ``b_size ** 2`` tiles are encoded. Each cell concatenates the encodings of its
    four edges in slot order N - W - S - E (see ``orientation``). Per-edge encoding:

    * ``'binary'``  -- ``ceil(log2(N_COLORS))`` bits, least-significant bit first;
    * ``'ordinal'`` -- the colour id itself (one value per edge);
    * anything else -- a one-hot vector of length ``N_COLORS``. For example,
      with 4 colours an all-gray tile is ``[1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0]``.

    With ``gray_borders``, tiles are first reordered. Scanning positions in
    row-major order, a border position takes the first remaining tile containing
    GRAY, and an interior position takes the first remaining tile without GRAY.
    Positions with no such tile left are skipped, and the board side is
    recomputed from the tiles placed.

    The result has dtype ``UNIT`` and lives on CUDA when available, otherwise on the CPU.
    """

    # Slot d holds the edge facing orientation[d], i.e. N, W, S, E. Keeping a cyclic
    # order (either NESW or NWSE) makes rotating a tile a roll of the four slots.
    orientation = [0,2,1,3]

    list_sol = list_sol.copy()
    b_size = int(len(list_sol)**0.5)

    if gray_borders:
        sol = []
        for k in range(len(list_sol)):
            on_border = k < b_size or k % b_size == 0 or k % b_size == b_size - 1 or k > b_size * (b_size - 1)
            # First remaining tile whose "has a gray edge" status matches the position.
            match = next((i for i, tile in enumerate(list_sol) if (GRAY in tile) == on_border), None)
            if match is not None:
                sol.append(list_sol.pop(match))
    else:
        sol = list_sol
    b_size = int(len(sol)**0.5)
    n_cells = b_size * b_size

    # (n_cells, 4) colour ids, already permuted into slot order.
    colors = torch.tensor(
        [[tile[o] for o in orientation] for tile in sol[:n_cells]],
        dtype=torch.long,
    ).reshape(n_cells, 4)

    # Every cell is fully determined by its own tile. The previous implementation's
    # GRAY pre-fill of the frame was always overwritten, and its ordinal-branch
    # `tens.unsqueeze(-1)` discarded its result, so both are gone.
    if encoding == 'binary':
        color_enc_size = ceil(torch.log2(torch.tensor(N_COLORS)))
        encoded = binary(colors, color_enc_size)                          # (n_cells, 4, bits)
        width = 4 * color_enc_size
    elif encoding == 'ordinal':
        encoded = colors.to(UNIT)                                         # (n_cells, 4)
        width = 4
    else:
        encoded = torch.nn.functional.one_hot(colors, N_COLORS).to(UNIT)  # (n_cells, 4, N_COLORS)
        width = 4 * N_COLORS

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return encoded.reshape(b_size, b_size, width).to(device=device, dtype=UNIT)

def base10(x:Tensor):
    """Decode a little-endian bit vector into a Python int.

    ``x[k]`` is the coefficient of ``2**k``, e.g. ``tensor([1., 0., 1.]) -> 5``.
    The weighted sum is accumulated in ``x``'s dtype and truncated with ``int()``;
    an empty tensor decodes to 0.
    """
    weighted_bits = (x[k] * 2**k for k in range(x.size(0)))
    return int(sum(weighted_bits))

def pprint(state,bsize):
    """Print the playable ``bsize x bsize`` window of a board, or of each board in a batch.

    The playable area is centred in a square padded grid, so the crop offset is
    ``(board_dim - bsize) // 2``. ``state`` is treated as a single board when its
    first dimension is either the legacy ``MAX_BSIZE + 2`` padding or the current
    ``bsize + 2`` padding (one border cell per side, as used by ``to_list`` and
    ``get_conflicts``). Otherwise it is treated as a batch of boards. This check
    is ambiguous only if the batch size happens to equal one of those sizes.

    The previous version always used the ``MAX_BSIZE + 2`` offset, so boards
    padded to ``bsize + 2`` with ``bsize < MAX_BSIZE`` printed empty slices.
    """
    single = state.size(0) in (MAX_BSIZE + 2, bsize + 2)
    board_dim = state.size(0) if single else state.size(1)
    offset = (board_dim - bsize) // 2
    window = slice(offset, offset + bsize)

    if single:
        print(state[window, window])
    else:
        for board in state:
            print(board[window, window])
        


def to_list(sol:torch.Tensor,bsize:int) -> list:
    """Decode a padded board tensor back into the puzzle's list-of-tuples format.

    Reads the playable cells ``sol[1..bsize, 1..bsize]`` (index 0 is border
    padding) in row-major order. Each cell's four side encodings are decoded
    according to the global ``ENCODING``. Slots are read in ``orientation``
    order, presumably the inverse of the encoder's slot mapping. The
    ``[0, 2, 1, 3]`` permutation is its own inverse; confirm against the encoder.

    :param sol: board tensor, assumed ``(bsize+2, bsize+2, 4 * COLOR_ENCODING_SIZE)``.
    :param bsize: side length of the playable board.
    :return: list of ``bsize**2`` 4-tuples of int colors.
    :raises ValueError: if ``ENCODING`` is not one of 'binary', 'ordinal', 'one_hot'.
    """
    orientation = [0,2,1,3]
    enc = COLOR_ENCODING_SIZE

    # Board tensors are stored as floats (UNIT). Cast once so decoded colors are
    # Python ints. The previous bare `sol.int()` discarded its result.
    sol = sol.int()

    if ENCODING == 'binary':
        def decode(cell, slot):
            return base10(cell[slot * enc:(slot + 1) * enc])
    elif ENCODING == 'ordinal':
        def decode(cell, slot):
            return cell[slot].item()
    elif ENCODING == 'one_hot':
        def decode(cell, slot):
            # Index of the single hot entry in this side's one-hot block.
            return torch.where(cell[slot * enc:(slot + 1) * enc] == 1)[0].item()
    else:
        raise ValueError(f"Encoding {ENCODING} not supported")

    offset = 1  # skip the one-cell border padding
    list_sol = []
    for i in range(offset, offset + bsize):
        for j in range(offset, offset + bsize):
            cell = sol[i, j]
            list_sol.append(tuple(decode(cell, slot) for slot in orientation))

    return list_sol

In [7]:
import os
import torch
from torch import Tensor
from math import  exp
import wandb

# wandb logging cadence used by train_model: per-step metrics are logged every
# LOG_EVERY steps and end-of-episode metrics every LOG_EVERY episodes.
LOG_EVERY = 1

def _save_checkpoint(agent, instance:str, step:int):
    """Save ``agent.model``'s weights to ``models/checkpoint/<inst>/<step // CHECKPOINT_PERIOD>.pt``.

    ``<inst>`` is ``instance`` with the ``instances/eternity_`` prefix and the
    ``.txt`` suffix removed.
    """
    inst = instance.replace("instances/eternity_","").replace(".txt","")
    ckpt_dir = f"models/checkpoint/{inst}"
    # makedirs also creates missing parents and is a no-op when the directory
    # already exists. The previous bare `except: os.mkdir(...)` failed when
    # models/checkpoint/ was absent and swallowed every other save error.
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(agent.model.state_dict(), f"{ckpt_dir}/{step // CHECKPOINT_PERIOD}.pt")


def train_model(instance:str,hotstart:str = None):
    """Train a PPO agent that fills an Eternity II board one tile per step.

    Runs episodes forever until interrupted with Ctrl-C (KeyboardInterrupt).
    - Each episode starts from ``initialize_sol(instance, device)`` and places one
      tile per step with ``place_tile``.
    - Transitions go to an ``EpisodeBuffer``, which is loaded into ``BatchMemory``
      at the start of the next episode.
    - ``agent.update`` runs every ``HORIZON`` steps.
    - Metrics are logged to wandb, and the policy weights are checkpointed every
      ``CHECKPOINT_PERIOD`` steps under ``models/checkpoint/<instance>/``.
    - On exit it prints whether the last board is a valid solution, then prints
      ``best_score``.

    :param instance: path to the puzzle instance file, e.g. 'instances/eternity_complet.txt'.
    :param hotstart: currently unused.
    :return: None
    """
    pz = EternityPuzzle(instance)
    n_tiles = len(pz.piece_list)
    bsize = pz.board_size

    # -------------------- NETWORK INIT --------------------

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cfg = {
        'n_tiles' : n_tiles ,
        'gamma' : GAMMA ,
        'clip_eps' : CLIP_EPS ,
        'lr' : LR ,
        'epochs' : EPOCHS ,
        'minibatch_size' : MINIBATCH_SIZE ,
        'horizon' : HORIZON,
        'state_dim' : bsize+2,
        'hidden_sizes': HIDDEN_SIZES,
        'kernel_sizes': KERNEL_SIZES,
        'entropy_weight':ENTROPY_WEIGHT,
        'value_weight':VALUE_WEIGHT,
        'gae_lambda':GAE_LAMBDA,
        'device' : device
    }

    agent = PPOAgent(cfg)

    # Holds the previous step's transition until its next_value (this step's
    # value estimate) is known.
    move_buffer = AdvantageBuffer()

    memory = BatchMemory(
        n_tiles=n_tiles,
        bsize=bsize+2,
        ep_length=bsize**2,
        capacity=HORIZON//n_tiles,
        device=device
    )

    ep_buf = EpisodeBuffer(
        capacity=n_tiles,
        n_tiles=n_tiles,
        bsize=bsize+2,
        device=device
    )

    # Entropy (bits) of a uniform choice among MAX_BSIZE**2 tiles; used to
    # normalise the logged policy entropy.
    max_entropy = torch.log2(torch.tensor(MAX_BSIZE**2))
    policy_entropy = max_entropy

    # -------------------- TRAINING LOOP --------------------

    torch.cuda.empty_cache()

    step = 0

    print("start")

    best_score = 0
    episode = 0
    state = None  # stays None if interrupted before the first episode starts

    try:

        while True:

            state, remaining_tiles, n_tiles = initialize_sol(instance,device)
            state = state.to(dtype=UNIT)
            # mask[k] is cleared once tile k has been placed.
            mask = (torch.zeros_like(remaining_tiles[:,0]) == 0)
            episode_end = False

            conflicts = 0
            ep_reward = 0
            consec_good_moves = 0
            # NOTE(review): never updated in the loop, so best_score always stays 0.
            episode_best_score = 0
            ep_step = 0

            # Hand the previous episode's transitions to the batch memory.
            if episode != 0:
                memory.load(ep_buf)

            while not episode_end:

                with torch.no_grad():
                    policy, value = agent.model(state.unsqueeze(0).unsqueeze(0))

                selected_tile_idx = agent.get_action(policy,mask)

                selected_tile = remaining_tiles[selected_tile_idx]

                mask[selected_tile_idx] = False

                new_state, new_conf = place_tile(state,selected_tile,ep_step)

                conflicts += new_conf

                if new_conf == 0:
                    consec_good_moves += 1
                    reward = streak(consec_good_moves,n_tiles)
                else:
                    consec_good_moves = 0
                    reward = -0.3 * new_conf

                ep_reward += reward

                # Complete the previous transition now that its bootstrap value is known.
                if ep_step != 0:

                    ep_buf.push(
                        state = move_buffer.state,
                        action = move_buffer.action,
                        policy = move_buffer.policy,
                        value = move_buffer.value,
                        next_value = value,
                        reward = move_buffer.reward,
                        final = 0
                    )

                if ep_step == n_tiles-1:

                    # Terminal transition: no bootstrap, and it carries THIS step's
                    # reward. It previously pushed move_buffer.reward, which still
                    # holds the previous step's reward at this point.
                    ep_buf.push(
                        state = state,
                        action = selected_tile_idx,
                        policy = policy,
                        value = value,
                        next_value = 0,
                        reward = reward,
                        final = 1
                    )
                    if episode % 10 == 0:
                        print(f"END EPISODE {episode} - Conflicts {conflicts}/{bsize * 2 *(bsize+1)}",end='\r')
                    episode_end = True
                    episode += 1

                    if episode % LOG_EVERY == 0:
                        # ep_step + 1 rewards were accumulated this episode. The old
                        # divisor (step - ep_start) was one short because `step` is
                        # only incremented further below.
                        wandb.log({"Mean episode reward":ep_reward/(ep_step + 1),'Final conflicts':conflicts})

                if step % HORIZON == 0 and step != 0:

                    agent.update(
                        mem=memory
                    )

                move_buffer.state = state
                move_buffer.action = selected_tile_idx
                move_buffer.policy = policy
                move_buffer.value = value
                move_buffer.reward = reward

                state = new_state

                with torch.no_grad():
                    policy_prob = torch.softmax(policy,dim=-1).squeeze(-1)
                    # Drop zero-probability entries so 0 * log2(0) does not yield NaN.
                    policy_prob = policy_prob[policy_prob != 0]
                    policy_entropy = -(policy_prob * torch.log2(policy_prob)).sum()

                if step % LOG_EVERY==0:

                    wandb.log(
                        {
                            'Score':conflicts,
                            'Relative policy entropy': policy_entropy/max_entropy,
                            'Value': value,
                            'Reward': reward,
                        }
                    )

                step += 1
                ep_step += 1

                # checkpoint the policy net
                if step % CHECKPOINT_PERIOD == 0:
                    _save_checkpoint(agent, instance, step)

            if episode_best_score > best_score:
                best_score = episode_best_score

    except KeyboardInterrupt:
        pass

    if state is not None:
        print("STILL VALID :",pz.verify_solution(to_list(state,bsize)))
    print(best_score)
    return


def place_tile(state:Tensor,tile:Tensor,step:int):
    """Place ``tile`` in the cell for scan position ``step``, choosing its best rotation.

    ``state`` is not modified; a copy is filled. The target cell is
    ``(step // bsize + 1, step % bsize + 1)``, where ``bsize`` is the padded size
    minus 2. Each rotation is produced by rolling the side encodings one slot
    (``COLOR_ENCODING_SIZE`` values) along the last dim. The first rotation tried
    is one roll, and the unrotated tile is tried last. Ties keep the earliest
    rotation with the fewest conflicts.

    :return: ``(best_state, best_conf)``, where ``best_conf`` is the
        ``filling_conflicts`` count for the chosen rotation.
    """
    board = state.clone()
    bsize = board.size(0) - 2
    row, col = divmod(step, bsize)
    row, col = row + 1, col + 1  # skip the border padding

    best_conf = 540  # sentinel above any count filling_conflicts can return
    candidate = tile
    for _ in range(4):
        candidate = candidate.roll(COLOR_ENCODING_SIZE, -1)
        board[row, col, :] = candidate
        conf = filling_conflicts(board, bsize, step)
        if conf < best_conf:
            best_state, best_conf = board.clone(), conf

    return best_state, best_conf

def streak(streak_length:int, n_tiles):
    """Reward for a conflict-free placement, given the current run of such placements.

    Equals 1 at ``streak_length == 0`` and rises toward 2 as the streak grows.
    It reaches ``2 - e**-3`` (about 1.95) when ``streak_length == 0.8 * n_tiles``.
    """
    exponent = -streak_length * 3/(0.8 * n_tiles)
    return 2 - exp(exponent)



def filling_conflicts(state:Tensor, bsize:int, step):
    """Count the conflicts of the tile at scan position ``step`` of a padded board.

    ``step`` maps to cell ``(step // bsize + 1, step % bsize + 1)``; row and
    column 0 are border padding. Each cell stores four side encodings of
    ``COLOR_ENCODING_SIZE`` values, in slot order 0=north, 1=west, 2=south,
    3=east. The checks are:

    * west: must equal the west neighbour's east side, or be gray on column 1;
    * south: must equal the south neighbour's north side, or be gray on row 1;
    * east: must be gray on the last column (``j == bsize``);
    * north: must be gray on the last row (``i == bsize``).

    Gray is encoded as all zeros for 'binary' and 'ordinal'. This assumes GRAY
    encodes to zeros, as the previous check did. For 'one_hot', gray is the
    one-hot vector at index ``GRAY``; the previous all-zeros comparison flagged
    every one-hot border side as a conflict.

    :return: number of mismatched sides, between 0 and 4.
    """
    C = COLOR_ENCODING_SIZE
    i = step // bsize + 1
    j = step % bsize + 1

    def side(r, c, slot):
        return state[r, c, slot * C:(slot + 1) * C]

    def differs(a, b):
        return not torch.all(a == b)

    gray = torch.zeros_like(side(i, j, 0))
    if ENCODING == 'one_hot':
        gray[GRAY] = 1

    conf = 0

    west_expected = gray if j == 1 else side(i, j - 1, 3)
    if differs(side(i, j, 1), west_expected):
        conf += 1

    south_expected = gray if i == 1 else side(i - 1, j, 0)
    if differs(side(i, j, 2), south_expected):
        conf += 1

    if j == bsize and differs(side(i, j, 3), gray):
        conf += 1

    if i == bsize and differs(side(i, j, 0), gray):
        conf += 1

    return conf
        
        



def get_conflicts(state:Tensor, bsize:int, step:int = 0) -> int:
    """Count unmatched edges on a padded board.

    The playable tiles are ``state[1:bsize+1, 1:bsize+1]``; the surrounding
    ring of cells is the border. Each cell holds four side encodings of
    ``COLOR_ENCODING_SIZE`` values, in slot order 0=N, 1=W, 2=S, 3=E.
    - The board has ``2 * bsize * (bsize + 1)`` edges, interior plus border.
    - An edge is matched when the two facing side encodings are identical.
    - Interior edges are seen from both tiles, so the duplicate is subtracted.

    ``step`` is unused.

    :return: the number of unmatched edges as a 0-d integer tensor, despite
        the ``int`` annotation.
    """
    C = COLOR_ENCODING_SIZE
    tiles = state[1:bsize + 1, 1:bsize + 1]
    padded = state[0:bsize + 2, 0:bsize + 2]

    def side(cells, slot):
        return cells[:, :, slot * C:(slot + 1) * C]

    def agree(a, b):
        return torch.all(a == b, dim=-1)

    # *_ok[r, c]: tile (r, c)'s side matches the facing side of the adjacent cell.
    north_ok = agree(side(tiles, 0), side(padded[2:, 1:-1], 2))
    south_ok = agree(side(tiles, 2), side(padded[:-2, 1:-1], 0))
    west_ok = agree(side(tiles, 1), side(padded[1:-1, :-2], 3))
    east_ok = agree(side(tiles, 3), side(padded[1:-1, 2:], 1))

    seen = north_ok.sum() + south_ok.sum() + east_ok.sum() + west_ok.sum()
    duplicated = (west_ok[:, 1:] & east_ok[:, :-1]).sum() + (north_ok[:-1, :] & south_ok[1:, :]).sum()

    return 2 * bsize * (bsize + 1) - (seen - duplicated)




In [8]:
!wandb login cdd836c352ffd933807c80225c7b616d7ba369d7
!wandb online

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/azureuser/.netrc
W&B online. Running your script from this directory will now sync to the cloud.


In [11]:
# ----------- MAIN CALL -----------

instance = 'instances/eternity_complet.txt'

# The path is resolved against the kernel's current working directory. Fail
# here, before a wandb run is created, if it does not exist.
if not os.path.isfile(instance):
    raise FileNotFoundError(
        f"Instance file not found: {os.path.abspath(instance)} (cwd={os.getcwd()})"
    )

CONFIG['Instance'] = instance

wandb.init(
    project='Eternity II',
    group='Distributional approach',
    config=CONFIG
)

try:
    train_model(instance)
finally:
    # Close the run even if training raises, so the next wandb.init starts clean.
    wandb.finish()

FileNotFoundError: [Errno 2] No such file or directory: 'instances/eternity_complet.txt'