In [2]:
from torch_geometric.data import Data
import torch_geometric.nn as gnn
import torch
import torch.nn as nn
import torch_geometric
import networkx as nx
import pyvis
import matplotlib.pyplot as plt
import numpy as np

from pathlib import Path
from typing import Protocol, cast
import enum
import dataclasses
import pickle
import random
import json
import uuid
import logging

In [3]:
@dataclasses.dataclass
class Hyperparameters:
    batch_size: int = 64
    runs_per_episode: int = 64
    epochs: int = 10
    learning_rate: float = 1e-5
    eps_clip: float = 0.2
    entropy_coef: float = 0.01
    weight_decay: float = 1e-3
    value_weight: float = 0.5
    policy_weight: float = 1.0
    gae_gamma: float = 0.95
    gae_lambda: float = 0.8
    penalty_per_conflict: float = 5e-5
    temperature: float = 4.0


HP = Hyperparameters()


In [4]:
logger = logging.getLogger("notebook")
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)
np.seterr(all='raise')
DEV = "cuda"
torch.set_float32_matmul_precision("medium")

@dataclasses.dataclass
class Counters:
    episodes: int = 0 
    epochs: int = 0
    runs: int = 0
    steps: int = 0
    batches: int = 0
    train_steps: int = 0

    def from_dict(self, d):
        for k, v in d.items():
            setattr(self, k, v)

COUNTERS = Counters()

In [5]:
class ExpectedValueNormalizationLogits(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits: torch.Tensor, ex: torch.Tensor):
        ex = torch.as_tensor(ex)

        b = torch.zeros(logits.shape[:-1], device=logits.device)

        for _ in range(100):
            normalized = torch.sigmoid(logits + b.unsqueeze(-1))
            f_gamma = normalized.sum(dim=-1) - ex
            f_prime_gamma = (normalized * (1 - normalized)).sum(dim=-1)
            diff = torch.clamp(f_gamma / f_prime_gamma, -2, 2)
            if torch.all(diff.abs() < 1e-6):
                break
            b = b - diff

        normalized = torch.sigmoid(logits + b.unsqueeze(-1))
        ctx.save_for_backward(normalized)
        return normalized

    @staticmethod
    def backward(ctx, g):
        normalized, = ctx.saved_tensors
        p_grad = normalized * (1 - normalized)
        denom = p_grad.sum(dim=-1)
        coordwise = p_grad * g

        grad = coordwise - p_grad * coordwise.sum(axis=-1).unsqueeze(-1) / denom.unsqueeze(-1)

        return grad, None


probs = torch.tensor([
    [0.999, 0.5, 0.5, 0.5, 0.1],
    [0.3, 0.5, 0.5, 0.8, 0.2],
], requires_grad=True)
x = -(1 / probs - 1).log()
y = ExpectedValueNormalizationLogits.apply(x, torch.tensor([2.0, 1.0]))
# print(x, y, y.sum(axis=-1), sep="\n")
y.sum().backward()
# print(probs.grad)

optim = torch.optim.SGD([probs], lr=0.1)
for _ in range(100):
    optim.zero_grad()
    x = -(1 / probs - 1).log()
    y = ExpectedValueNormalizationLogits.apply(x, torch.tensor([2.0, 1.0]))
    loss = y.pow(3.0).sum()
    loss.backward()
    optim.step()
    # print(probs)

In [6]:
class GraphProblem(Protocol):
    global_data: torch.Tensor
    x: torch.Tensor
    edge_index: torch.Tensor
    edge_attr: torch.Tensor
    reducible: torch.Tensor

In [7]:
def compute_returns_advantages(rewards: list[float], values: list[float]) -> tuple[list[float], list[float]]:
    returns = []
    advantages = []
    gae = 0
    for i in reversed(range(len(rewards))):
        delta = rewards[i] + HP.gae_gamma * (values[i + 1] if i + 1 < len(values) else 0) - values[i]
        gae = delta + HP.gae_gamma * HP.gae_lambda * gae
        returns.append(gae + values[i])
        advantages.append(gae)
    return returns[::-1], advantages[::-1]

In [8]:
@dataclasses.dataclass
class EpisodeResult:
    states: list[GraphProblem]
    dists: list[torch.distributions.Bernoulli]
    actions: list[torch.Tensor]
    rewards: list[float]
    values: list[float]
    returns: list[float]
    advantages: list[float]

    stats: list[dict]

    @staticmethod
    def empty() -> 'EpisodeResult':
        return EpisodeResult([], [], [], [], [], [], [], [])

    def merge_with(self, other: 'EpisodeResult') -> 'EpisodeResult':
        assert (
            len(self.states) == len(self.dists) == len(self.actions) ==
            len(self.values) == len(self.rewards) == len(self.returns) ==
            len(self.advantages)
        )
        assert (
            len(other.states) == len(other.dists) == len(other.actions) ==
            len(other.values) == len(other.rewards) == len(other.returns) ==
            len(other.advantages)
        )
        return EpisodeResult(
            self.states + other.states,
            self.dists + other.dists,
            self.actions + other.actions,
            self.rewards + other.rewards,
            self.values + other.values,
            self.returns + other.returns,
            self.advantages + other.advantages,
            self.stats + other.stats,
        )

    @staticmethod
    def merge_all(results: list['EpisodeResult']) -> 'EpisodeResult':
        result = results[0]
        for other in results[1:]:
            result = result.merge_with(other)
        return result

    def add(self, *, state, dist, action, reward, value):
        assert (
            len(self.states) == len(self.dists) == len(self.actions) ==
            len(self.values) == len(self.rewards)
        )
        self.states.append(state)
        self.dists.append(dist)
        self.actions.append(action)
        if reward is not None:
            self.rewards.append(reward)
        self.values.append(value)

    def add_reward(self, reward):
        assert (
            len(self.states) == len(self.dists) == len(self.actions) ==
            len(self.values) == len(self.rewards) + 1
        )
        self.rewards.append(reward)

    def complete(self, stats: dict):
        assert (
            len(self.states) == len(self.dists) == len(self.actions) ==
            len(self.values) == len(self.rewards)
        )
        assert len(self.stats) == 0
        self.returns, self.advantages = compute_returns_advantages(self.rewards, self.values)
        self.stats = [stats]

In [9]:
class Agent:
    def __init__(self, strategy: str):
        self.strategy = strategy

    def act(self, graph: GraphProblem, ex: float):
        values = graph.x[:, 0]
        reducible = graph.reducible
        reducible_values = values[reducible] + torch.randn_like(values[reducible]) * 1e-3
        n = len(values)
        m = len(reducible_values)
        if self.strategy == "uniform":
            return torch.distributions.Bernoulli(torch.full((m,), ex / m)), 0.0
        elif self.strategy == "max":
            return torch.distributions.Bernoulli((reducible_values > torch.quantile(reducible_values, ex / n)).float()), 0.0
        elif self.strategy == "min":
            return torch.distributions.Bernoulli((reducible_values < torch.quantile(reducible_values, 1 - ex / n)).float()), 0.0
        elif self.strategy == "none":
            return torch.distributions.Bernoulli(torch.zeros((m,))), 0.0
        else:
            raise ValueError(f"Unknown strategy {self.strategy}")

    def update(self, results: EpisodeResult, silent: bool = False):
        pass

In [10]:
@dataclasses.dataclass
class GameConfig:
    number_of_bits: int = 4
    number_of_cards_to_add: tuple[int] = (16, 8, 12, 16, 24, 32, 48, 48, 48, 48)
    fraction_to_remove: float = 0.5
    fraction_to_make_reducible: float = 0.2

In [11]:
class ToyEnv:
    def __init__(self, game_cfg: GameConfig, runs_per_episode: int = 16):
        self.runs_per_episode = runs_per_episode
        self.game_cfg = game_cfg

    def run_instance(self, agent: Agent):
        COUNTERS.runs += 1
        result = EpisodeResult.empty()

        max_card = 2 ** self.game_cfg.number_of_bits
        deck = np.random.randint(0, max_card, size=self.game_cfg.number_of_cards_to_add[0])
        reducible = np.ones_like(deck, dtype=bool)
        total_reward = 0

        for step in range(len(self.game_cfg.number_of_cards_to_add) - 1):
            COUNTERS.steps += 1
            logger.info("Running instance step %d", step)

            problem = Data(
                global_data=torch.tensor([
                    len(deck) / sum(self.game_cfg.number_of_cards_to_add),
                ], dtype=torch.float32),
                x=torch.tensor(np.array([
                    deck / (max_card - 1),
                    *[np.bitwise_and(deck, 1 << i) >> i for i in range(self.game_cfg.number_of_bits)],
                ]), dtype=torch.float32).permute(1, 0),
                edge_index=torch.tensor(np.array([
                    np.arange(len(deck)),
                    np.arange(1, len(deck) + 1) % len(deck),
                ]), dtype=torch.long),
                edge_attr=torch.zeros((len(deck), 0), dtype=torch.float32),
                reducible=torch.tensor(reducible, dtype=torch.bool),
            )

            ex = self.game_cfg.fraction_to_remove * sum(reducible)
            dist, value = agent.act(problem, ex)
            action = dist.sample()

            indices = np.arange(len(deck))[reducible][action.cpu().numpy() == 1]
            deck = np.delete(deck, indices)
            reducible = np.delete(reducible, indices)

            reducible = np.random.rand(*reducible.shape) < self.game_cfg.fraction_to_make_reducible

            reward = deck.sum() + np.bitwise_xor(deck[1:], deck[:-1]).sum() + np.bitwise_xor(deck[-1], deck[0])
            reward /= max_card * sum(self.game_cfg.number_of_cards_to_add)
            total_reward += reward

            result.add(state=(problem, ex), dist=dist, action=action, reward=reward, value=value)

            deck = list(deck)
            reducible = list(reducible)

            for _ in range(self.game_cfg.number_of_cards_to_add[step + 1]):
                card = random.randrange(max_card)
                index = random.randrange(len(deck) + 1) 
                deck.insert(index, card)
                reducible.insert(index, True)

            deck = np.array(deck)
            reducible = np.array(reducible)

            logger.info("Finished instance step %d", step)

        result.complete({"total_reward": total_reward})
        return result

    def run_episode(self, agent: Agent) -> EpisodeResult:
        results = []
        for _ in range(self.runs_per_episode):
            results.append(self.run_instance(agent))

        return EpisodeResult.merge_all(results)

In [16]:
logging.basicConfig(level=logging.WARNING)
logger.setLevel(logging.WARNING)

game_cfg = GameConfig()

agent = Agent(strategy="uniform")

env = ToyEnv(
    game_cfg=game_cfg,
    runs_per_episode=HP.runs_per_episode,
)

while True:
    print(f"Episode {COUNTERS.episodes}")
    results = env.run_episode(agent)
    logger.info("Finished episode, starting training")
    agent.update(results)

    print(f"Rewards: {sum(results.rewards)}")
    COUNTERS.episodes += 1
    del results

Episode 247
Rewards: 76.5054166666666
Episode 248
Rewards: 75.53895833333331
Episode 249
Rewards: 75.85958333333345
Episode 250
Rewards: 75.59041666666666
Episode 251
Rewards: 74.77791666666674
Episode 252
Rewards: 75.49187500000004
Episode 253
Rewards: 75.82125
Episode 254
Rewards: 74.77145833333331
Episode 255
Rewards: 75.58125000000004
Episode 256
Rewards: 74.77395833333331
Episode 257
Rewards: 77.010625
Episode 258
Rewards: 75.09770833333339
Episode 259
Rewards: 77.62312499999992
Episode 260
Rewards: 76.50333333333329
Episode 261
Rewards: 76.58354166666676
Episode 262
Rewards: 76.91020833333327
Episode 263
Rewards: 74.75062500000006
Episode 264
Rewards: 76.25249999999987
Episode 265
Rewards: 75.76166666666654
Episode 266
Rewards: 75.26458333333336
Episode 267
Rewards: 74.69291666666668
Episode 268
Rewards: 74.3166666666666
Episode 269


KeyboardInterrupt: 