In [142]:
import numpy as np
import torch
import torch.nn as nn
from RubikCube.src.env import *
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_

class ValuePolicyNetwork(nn.Module):
    """Two-headed MLP: a shared trunk feeding a value head and a policy head.

    The input is flattened per sample to 480 features. ``forward`` returns a
    ``(batch, 1)`` value tensor and a ``(batch, 12)`` tensor of policy logits.
    """

    def __init__(self):
        super().__init__()

        # Shared trunk: 480 -> 4096 -> 2048, LayerNorm + ELU after each linear layer.
        self.shared_layers = nn.Sequential(
            nn.Linear(480, 4096),
            nn.LayerNorm(4096),
            nn.ELU(),
            nn.Linear(4096, 2048),
            nn.LayerNorm(2048),
            nn.ELU(),
        )

        # Both heads share one layout (2048 -> 512 -> out) and differ only in
        # their output width: 1 scalar value vs. 12 policy logits.
        self.value_head = self._make_head(1)
        self.policy_head = self._make_head(12)

        # Re-initialise every linear layer (Glorot weights, zero biases).
        self.apply(self._init_weights)

    @staticmethod
    def _make_head(out_features):
        """Build a 2048 -> 512 -> ``out_features`` head with LayerNorm + ELU."""
        return nn.Sequential(
            nn.Linear(2048, 512),
            nn.LayerNorm(512),
            nn.ELU(),
            nn.Linear(512, out_features),
        )

    def _init_weights(self, m):
        """Glorot-uniform weights and zero bias for ``nn.Linear``; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return ``(value, policy_logits)`` for a batch ``x`` (flattened per sample)."""
        flat = x.view(x.shape[0], -1)
        trunk = self.shared_layers(flat)
        return self.value_head(trunk), self.policy_head(trunk)


In [143]:
import torch

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Four networks that receive trained checkpoints, plus one that receives
# averaged weights. ValuePolicyNetwork must already be defined.
model1 = ValuePolicyNetwork()
model2 = ValuePolicyNetwork()
model3 = ValuePolicyNetwork()
model4 = ValuePolicyNetwork()
model_combined = ValuePolicyNetwork()

# Load the state dictionaries.
# - map_location='cpu' lets checkpoints saved from a GPU load on a CPU-only host
#   (the models are moved to `device` further down).
# - weights_only=True restricts unpickling to tensors/containers, which is all a
#   state dict needs, and silences the torch.load FutureWarning.
for _model, _path in (
    (model1, 'model1.pth'),
    (model2, 'model2.pth'),
    (model3, 'model3.pth'),
    (model4, 'model4.pth'),
):
    _model.load_state_dict(torch.load(_path, map_location='cpu', weights_only=True))

# Combine the weights by element-wise averaging.
# NOTE(review): model4 is loaded above but is not part of this average; only
# model1..model3 are combined — confirm this is intended.
state_dict1 = model1.state_dict()
state_dict2 = model2.state_dict()
state_dict3 = model3.state_dict()
_averaged_dicts = (state_dict1, state_dict2, state_dict3)

combined_state_dict = {
    key: sum(sd[key] for sd in _averaged_dicts) / len(_averaged_dicts)
    for key in state_dict1
}

# Load the averaged weights into the combined model
model_combined.load_state_dict(combined_state_dict)

# Put every model in evaluation mode and move it to the selected device.
for _model in (model1, model2, model3, model4, model_combined):
    _model.eval()
    _model.to(device)

# Last expression of the cell: displays the combined model's architecture.
model_combined


  model1.load_state_dict(torch.load('model1.pth'))
  model2.load_state_dict(torch.load('model2.pth'))
  model3.load_state_dict(torch.load('model3.pth'))
  model4.load_state_dict(torch.load('model4.pth'))


ValuePolicyNetwork(
  (shared_layers): Sequential(
    (0): Linear(in_features=480, out_features=4096, bias=True)
    (1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=4096, out_features=2048, bias=True)
    (4): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (5): ELU(alpha=1.0)
  )
  (value_head): Sequential(
    (0): Linear(in_features=2048, out_features=512, bias=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=512, out_features=1, bias=True)
  )
  (policy_head): Sequential(
    (0): Linear(in_features=2048, out_features=512, bias=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=512, out_features=12, bias=True)
  )
)

In [144]:
class MetaModel(nn.Module):
    """Stacking ensemble: a trainable network on top of frozen base models.

    Every base model is evaluated (in eval mode, without gradients) on the
    flattened input. Its value output and 12 policy logits form 13 features,
    and the features of all base models are concatenated per sample into a
    ``len(models) * 13`` vector that feeds the trainable trunk and heads.

    Args:
        models: sequence of callables/modules returning ``(value, policy)``
            with 1 and 12 entries per sample respectively. It is stored as a
            plain list, so the base models are NOT registered as sub-modules:
            their parameters do not appear in ``parameters()`` or
            ``state_dict()``, and ``.to()`` on this module does not move them.

    ``forward`` returns ``(value, policy_logits)`` shaped ``(batch, 1)`` and
    ``(batch, 12)``, for any batch size.
    """

    def __init__(self, models):
        super(MetaModel, self).__init__()
        self.models = models
        self.shared_layers = nn.Sequential(
            nn.Linear(len(self.models) * 13, 2048),
            nn.ELU(),
            nn.Linear(2048, 4096),
            nn.LayerNorm(4096),
            nn.ELU()
        )
        self.value_head = nn.Sequential(
            nn.Linear(4096, 512),  # 4096 -> 512
            nn.LayerNorm(512),
            nn.ELU(),
            nn.Linear(512, 1)  # 512 -> 1 (scalar value)
        )

        # Define policy head
        self.policy_head = nn.Sequential(
            nn.Linear(4096, 512),  # 4096 -> 512
            nn.LayerNorm(512),
            nn.ELU(),
            nn.Linear(512, 12)  # 512 -> 12 (policy logits)
        )

    def forward(self, x):
        # Flatten each sample before handing it to the base models.
        b = x.shape[0]
        x = x.view(b, -1)

        per_model_features = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                v, p = model(x)
            # (batch, 1) value followed by (batch, 12) logits -> (batch, 13).
            per_model_features.append(
                torch.cat((v.reshape(b, -1), p.reshape(b, -1)), dim=-1)
            )

        # Concatenate along the feature axis so the batch axis is preserved:
        # each row is [v_1, p_1, v_2, p_2, ...]. Flattening everything with
        # view(-1) only produced a correctly sized vector for a batch of one.
        features = torch.cat(per_model_features, dim=-1)

        # Use the device of this module's own weights rather than a
        # notebook-level global.
        features = features.to(self.shared_layers[0].weight.device)

        shared_out = self.shared_layers(features)
        value_out = self.value_head(shared_out)
        policy_out = self.policy_head(shared_out)
        return value_out, policy_out

In [156]:
import torch
import torch.nn as nn

class MeanVectorModel(nn.Module):
    """Parameter-free ensemble that averages the outputs of several models.

    Each model is called on the same input in eval mode and without gradients.
    The returned value is the mean of the models' value outputs; the returned
    policy is the mean of the models' softmaxed policy outputs (probabilities,
    not logits).
    """

    def __init__(self, models):
        super().__init__()
        # Plain Python list: the wrapped models are not registered as sub-modules.
        self.models = models

    def forward(self, x):
        values, probabilities = [], []
        with torch.no_grad():
            for member in self.models:
                member.eval()
                value, logits = member(x)
                values.append(value)
                probabilities.append(torch.softmax(logits, dim=-1))

        # Average over the ensemble axis (dim 0 of the stacked tensors).
        mean_value = torch.stack(values).mean(dim=0)
        mean_policy = torch.stack(probabilities).mean(dim=0)
        return mean_value, mean_policy

In [146]:
from copy import deepcopy
import random

# Assuming Cube class and move definitions are already provided

def generate_samples(k: int, l: int):
    """Generate training samples from random scrambles of a fresh ``Cube``.

    Runs ``l`` scrambles of ``k`` random moves each and records one sample
    after every move, so ``k * l`` samples are returned in generation order.

    Each sample is a dict with:
        "state":   two-element list holding copies of the two items returned
                   by ``cube.get_state()`` after the move.
        "actions": copy of the list of move indices applied so far.
    """
    records = []
    for _scramble in range(l):
        cube = Cube()
        history = []
        for _depth in range(k):
            # Draw move indices in [0, 11] until one is acceptable. Moves are
            # treated as (even, odd) pairs; a draw is rejected when it is the
            # pair partner of the previous move (presumably its inverse, which
            # would undo it — confirm against Cube's move numbering).
            while True:
                candidate = random.randint(0, 11)
                if not history:
                    break
                partner = candidate + 1 if candidate % 2 == 0 else candidate - 1
                if history[-1] != partner:
                    break

            cube.move(candidate)
            history.append(candidate)

            snapshot = deepcopy(cube.get_state())
            records.append({
                "state": [snapshot[0], snapshot[1]],
                "actions": deepcopy(history),
            })

    return records


def reward(cube: Cube, action):
    """Return 1 if applying ``action`` to ``cube`` yields a solved cube, else -1.

    The move is applied to a deep copy, so the ``cube`` passed in is left unchanged.
    """
    successor = deepcopy(cube)
    successor.move(action)
    return 1 if successor.is_solved() else -1


def custom_loss(y_vi_pred, y_pi_pred, y_vi, y_pi, weight, alpha=1):
    """Weighted sum of a value (MSE) loss and a policy (cross-entropy) loss.

    Args:
        y_vi_pred: predicted values, one per sample; any shape with the same
            number of elements as ``y_vi`` (e.g. ``(batch, 1)`` or ``(batch,)``).
        y_pi_pred: predicted policy logits, ``(batch, num_actions)``.
        y_vi: target values.
        y_pi: policy targets accepted by ``nn.CrossEntropyLoss``.
        weight: per-sample weights multiplied into both per-sample losses.
        alpha: factor applied to the policy term.

    Returns:
        Scalar tensor ``mean(w * mse) + alpha * mean(w * cross_entropy)``.

    Raises:
        RuntimeError: if ``y_vi_pred`` cannot be reshaped to ``y_vi``'s shape.
    """
    # Align the prediction with the target before the element-wise loss. A
    # (batch, 1) prediction against a (batch,) target would otherwise broadcast
    # to a (batch, batch) matrix and mix every sample with every other one.
    y_vi_pred = y_vi_pred.reshape(y_vi.shape)

    # Compute per-sample losses for value and policy
    loss_v = nn.MSELoss(reduction='none')(y_vi_pred, y_vi)
    loss_p = nn.CrossEntropyLoss(reduction='none')(y_pi_pred, y_pi)

    # Apply weights to the per-sample losses
    weighted_loss_v = (loss_v * weight).mean()
    weighted_loss_p = (loss_p * weight).mean()

    # Combine losses
    return weighted_loss_v + alpha * weighted_loss_p

In [147]:
from tqdm import tqdm

def _encode_cube(cube, device):
    """One-hot encode a cube's state as network input.

    Concatenates the two tensors returned by ``cube.get_state()`` along dim 0
    and one-hot encodes every entry with 24 classes, returning a float tensor
    on ``device``.
    """
    corners, edges = cube.get_state()
    cube_representation = torch.concat([corners, edges], dim=0)
    return F.one_hot(cube_representation, num_classes=24).float().to(device)


def train(samples, model, batch_size, epochs, optimizer, loss_fn, device):
    """Train ``model`` on ``samples`` with targets bootstrapped from the model itself.

    For every sample the 12 child states (one per move) are evaluated with the
    current model under ``torch.no_grad()``. The value target is the maximum
    over children of ``model value + reward``; the policy target is the one-hot
    vector of the maximising move. Each sample is weighted by
    ``1 / len(sample['actions'])``.

    Args:
        samples: list of dicts with keys ``'state'`` (two items passed to
            ``Cube(...)``) and ``'actions'`` (moves applied so far, non-empty).
        model: callable returning ``(value, policy_logits)`` for a batch of
            encoded states.
        batch_size: number of samples per optimisation step.
        epochs: number of passes over ``samples``.
        optimizer: optimiser holding the parameters to update.
        loss_fn: called as ``loss_fn(value_pred, policy_pred, value_target,
            policy_target, weights)``; must return a scalar tensor.
        device: device on which batches and targets are built.
    """
    for epoch in range(epochs):
        bar = tqdm(range(0, len(samples), batch_size), desc="Training")
        running_loss = 0
        for step, i in enumerate(bar):
            batch = samples[i:i + batch_size]

            # Allocate the batch buffers directly on the target device.
            values = torch.zeros(len(batch), device=device)
            policies = torch.zeros((len(batch), 12), device=device)
            data = torch.zeros((len(batch), 20, 24), device=device)
            weights = torch.zeros(len(batch), device=device)

            for j, sample in enumerate(batch):
                children_vi = torch.zeros(12, device=device)

                cube = Cube(sample['state'][0], sample['state'][1])

                for action in range(12):
                    child_cube = deepcopy(cube)
                    r = reward(child_cube, action)

                    child_cube.move(action)
                    with torch.no_grad():
                        v, _ = model(_encode_cube(child_cube, device).unsqueeze(0))

                    # v holds a single element whatever its shape ((1, 1) or (1,)).
                    children_vi[action] = v.reshape(()) + r

                values[j] = torch.max(children_vi)
                policies[j] = F.one_hot(torch.argmax(children_vi), num_classes=12).float()
                data[j] = _encode_cube(cube, device)

                weights[j] = 1 / len(sample['actions'])

            vi_pred, pi_pred = model(data)
            # Flatten the value prediction to (batch,) so it lines up
            # element-wise with `values`; a (batch, 1) prediction would
            # broadcast against the (batch,) target inside an element-wise loss.
            vi_pred = vi_pred.reshape(-1)
            pi_pred = pi_pred.view(-1, 12)
            loss = loss_fn(vi_pred, pi_pred, values, policies, weights)
            optimizer.zero_grad()
            loss.backward()

            running_loss += loss.item()
            bar.set_description(f'loss: {running_loss / (step + 1):4f}, {loss.item()}')
            optimizer.step()


In [157]:
# Base models shared by both ensemble wrappers.
ensemble_models = [model1, model2, model3]

# Trainable stacking ensemble and parameter-free averaging ensemble, both
# built on the same base models and moved to the active device.
meta_model, mean_vector_model = (
    wrapper(ensemble_models).to(device)
    for wrapper in (MetaModel, MeanVectorModel)
)

In [149]:
from torch.optim import RMSprop, AdamW

# Training data: 10 scrambles of depth 20 (one sample per move), shuffled.
samples = generate_samples(k=20, l=10)
random.shuffle(samples)

models = [model1, model2, model3, model_combined]

# Optimisation set-up for the meta model.
batch_size = 1
loss_fn = custom_loss
optimizer = AdamW(params=meta_model.parameters(), lr=2e-4, weight_decay=2e-4)
# Alternative optimiser: RMSprop(model.parameters(), lr=2e-4)

# One epoch over the shuffled samples.
train(
    samples=samples,
    model=meta_model,
    batch_size=batch_size,
    epochs=1,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
)

loss: 0.666881, 0.21742305159568787: 100%|██████████| 200/200 [00:15<00:00, 12.96it/s]


In [154]:
import random
from tqdm import tqdm

def benchmark(models: list, device, iterations=1000, max_scrambles=20, max_steps=30):
    """Measure how often each model solves randomly scrambled cubes greedily.

    For every iteration one random scramble (1..``max_scrambles`` moves, each
    drawn uniformly from 0..11) is generated and replayed on a fresh ``Cube``
    for every model, so all models face the same scramble. A model then
    repeatedly plays the arg-max move of its policy output until the cube is
    solved or ``max_steps`` moves have been made.

    Args:
        models: models returning ``(value, policy)`` for a batch of one encoded state.
        device: device the models and inputs are moved to.
        iterations: number of scrambles to evaluate.
        max_scrambles: upper bound (inclusive) on the scramble length.
        max_steps: maximum number of moves a model may play per scramble
            (default 30, the previously hard-coded limit).

    Returns:
        ``numpy`` array with, per model, the fraction of scrambles solved.
    """
    solved_times = np.zeros(len(models), dtype=int)
    for model in models:
        model.to(device)

    for _ in tqdm(range(iterations), desc="Benchmarking Progress"):
        steps = [random.randint(0, 11) for _ in range(random.randint(1, max_scrambles))]

        for m_idx, model in enumerate(models):
            # Rebuild the same scrambled cube for every model.
            cube = Cube()
            for step in steps:
                cube.move(step)

            # Eval mode only needs to be set once per rollout.
            model.eval()
            for _move in range(max_steps):
                if cube.is_solved():
                    break
                corners, edges = cube.get_state()
                cube_representation = torch.concat([corners, edges], dim=0)
                cube_representation_encoded = F.one_hot(cube_representation, num_classes=24).float().to(device)
                with torch.no_grad():
                    v, p = model(cube_representation_encoded.unsqueeze(0))
                # Greedy policy: always play the highest-scoring move.
                cube.move(torch.argmax(p).item())

            if cube.is_solved():
                solved_times[m_idx] += 1

    return solved_times / iterations

In [158]:
# Benchmark the four individual networks against the averaging ensemble
# (meta_model is left out of this run).
models = [
    model1,
    model2,
    model3,
    model4,
    mean_vector_model,
]
accuracies = benchmark(models=models, device=device)
print(accuracies)

Benchmarking Progress: 100%|██████████| 1000/1000 [02:54<00:00,  5.72it/s]

[0.339 0.343 0.344 0.115 0.353]



