## Define the model architecture

In [277]:
import torch
import torch.nn as nn
from src.env import *
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_

class ValuePolicyNetwork(nn.Module):
    def __init__(self):
        super(ValuePolicyNetwork, self).__init__()

        # Define shared layers
        self.shared_layers = nn.Sequential(
            nn.Linear(480, 4096),
            nn.LayerNorm(4096),
            nn.ELU(),
            nn.Linear(4096, 2048),
            nn.LayerNorm(2048),
            nn.ELU()
        )

        # Define value head
        self.value_head = nn.Sequential(
            nn.Linear(2048, 512),  # 2048 -> 512
            nn.LayerNorm(512),
            nn.ELU(),
            nn.Linear(512, 1)  # 512 -> 1 (scalar value)
        )

        # Define policy head
        self.policy_head = nn.Sequential(
            nn.Linear(2048, 512),  # 2048 -> 512
            nn.LayerNorm(512),
            nn.ELU(),
            nn.Linear(512, 12)  # 512 -> 12 (policy logits)
        )
        
        # Apply Glorot initialization
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            xavier_uniform_(m.weight)  # Glorot initialization for weights
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # Transform input
        b = x.shape[0]
        x = x.view(b, -1)
        shared_out = self.shared_layers(x)
        value_out = self.value_head(shared_out)
        policy_out = self.policy_head(shared_out)
        return value_out, policy_out


## Generate training samples

In [278]:
from copy import deepcopy
import random

def generate_samples(k: int, l: int):
    samples = []
    for _ in range(l):
        cube = Cube()
        actions = []
        # Make random moves up to depth k
        for _ in range(k):
            appended = False
            move_index = 0
            while not appended:
                move_index = random.randint(0, 11)
                if len(actions) == 0:
                    appended = True
                    continue
                # Ensure that adjacent moves are not repeated
                if move_index%2 == 0 and actions[-1] == move_index + 1:
                    continue
                if move_index%2 == 1 and actions[-1] == move_index - 1:
                    continue
                appended = True
            cube.move(move_index)
            actions.append(move_index)
            state = (deepcopy(cube.get_state()), deepcopy(actions))
            samples.append(state)

    # Transform samples to dictionary format
    samples_dict = []
    for state, actions in samples:
        sample_dict = {
            "state": [state[0], state[1]],
            "actions": actions
        }
        samples_dict.append(sample_dict)

    return samples_dict


In [279]:
# Define reward
def reward(cube: Cube, action):
    new_cube = deepcopy(cube)
    new_cube.move(action)
    # +1 if solved
    if new_cube.is_solved():
        return 1
    # -1 otherwise
    else:
        return -1

In [280]:
def custom_loss(y_vi_pred, y_pi_pred, y_vi, y_pi, weight, alpha=1):
    # Compute per-sample losses for value and policy
    loss_v = nn.MSELoss(reduction='none')(y_vi_pred, y_vi)  # Shape: (batch_size,)
    loss_p = nn.CrossEntropyLoss(reduction='none')(y_pi_pred, y_pi)  # Shape: (batch_size,)

    # Apply weights to the per-sample losses
    weighted_loss_v = (loss_v * weight).mean()
    weighted_loss_p = (loss_p * weight).mean()

    # Combine losses
    return weighted_loss_v + alpha * weighted_loss_p

In [281]:
from tqdm import tqdm

def train(samples, model, epochs, optimizer, loss_fn, device):
    for epoch in range(epochs):
        bar = tqdm(range(0, len(samples), batch_size), desc="Training")
        running_loss = 0
        for step, i in enumerate(bar):
            batch = samples[i:i + batch_size]
            
            values = torch.zeros(len(batch)).to(device)
            policies = torch.zeros((len(batch), 12)).to(device)
            data = torch.zeros((len(batch), 20, 24)).to(device)
            weights = torch.zeros(len(batch)).to(device)
            
            # Compute the target value and target policy
            for j, sample in enumerate(batch):
                children_vi = torch.zeros(12).to(device)
                children_pi = []

                cube = Cube(sample['state'][0], sample['state'][1])
                
                # Go through all actions
                for action in range(12):
                    child_cube = deepcopy(cube)
                    r = reward(child_cube, action)

                    child_cube.move(action)
                    corners, edges = child_cube.get_state()      
                    cube_representation = torch.concat([corners, edges], dim=0)
                    cube_representation_encoded = F.one_hot(cube_representation, num_classes=24).float().to(device)
                    with torch.no_grad():
                        v, p = model(cube_representation_encoded.unsqueeze(0))

                    children_vi[action] = v + r
                    children_pi.append(p)
                
                # Choose the best action
                target_vi = torch.max(children_vi)
                target_pi = F.one_hot(torch.argmax(children_vi), num_classes=12).float()
                
                values[j] = target_vi
                policies[j] = target_pi
                corners, edges = cube.get_state()
                cube_representation = torch.concat([corners, edges], dim=0)
                cube_representation_encoded = F.one_hot(cube_representation, num_classes=24).float().to(device)
                data[j] = cube_representation_encoded
                
                # Save the weight for the loss computation. Longer paths have lower weight
                weights[j] = 1 / len(sample['actions'])
            
            # Compute loss and backpropagate
            vi_pred, pi_pred = model(data)
            loss = loss_fn(vi_pred, pi_pred, values, policies, weights)
            optimizer.zero_grad()
            loss.backward()
            
            running_loss += loss.item()
            bar.set_description(f'loss: {running_loss / (step + 1):4f}, {loss.item()}')
            optimizer.step()


In [282]:
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

model = ValuePolicyNetwork().to(device)

print(device)

cuda:0


In [337]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=2e-4)
loss_fn = custom_loss


for global_epoch in range(100):
    samples = generate_samples(20, 1000)
    # Shake the samples
    random.shuffle(samples)
    train(samples, model, 1, optimizer, loss_fn, device)
    torch.save(model.state_dict(), 'model.pth')

loss: 0.503194, 0.5405396223068237: 100%|██████████| 15000/15000 [03:31<00:00, 70.98it/s]   


## Test greedy algorithm

In [349]:
steps = [1, 9, 7]
cube = Cube()
for step in steps:
    cube.move(step)

i = 0

solution = []
while not cube.is_solved():
    corners, edges = cube.get_state()
    cube_representation = torch.concat([corners, edges], dim=0)
    cube_representation_encoded = F.one_hot(cube_representation, num_classes=24).float().to(device)
    with torch.no_grad():
        v, p = model(cube_representation_encoded.unsqueeze(0))
    print(v.item())
    # print(p.detach().cpu().numpy())
    action = torch.argmax(p).item()
    print(action)
    solution.append(action)
    cube.move(action)
    i += 1
    
    if i >= 10:
        break
    
if cube.is_solved():
    print("Solved!")
    print(solution)
else:
    print("Not Solved :(")
    # print(solution)

481.9053649902344
0
482.7429504394531
6
484.2241516113281
8
Solved!
[6, 8, 0]
