In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
from graphviz import Digraph

In [53]:
class Node:
    def __init__(self, depth, index, parent=None) -> None:
        self.depth = depth
        self.index = index
        self.parent = parent
        self.left = None
        self.right = None
        self.active = True
        self.mean_estimate = 0.0
    
    def subdivide(self):
        self.left = Node(self.depth + 1, 2 * self.index - 1, parent=self)
        self.right = Node(self.depth + 1, 2 * self.index, parent=self)

    def evict(self):
        self.active = False
        self.left = None
        self.right = None

    def contains(self, x):
        left = (self.index-1) / (2**(self.depth-1))
        right = (self.index) / (2 ** (self.depth-1))
        print(left, right)
        return left <= x <= right

    def __repr__(self) -> str:
        return f"Node({self.depth}, {self.index}, {self.active})"

class Tree:

    def __init__(self, max_depth) -> None:
        self.max_depth = max_depth
        self.root = Node(1, 1)
        
        self.active_depths = {max_depth : list(np.arange(1, 2**max_depth))} ### A modifier le type

        if max_depth is not None:
            self.build_full_tree(node=self.root, max_depth=max_depth)

    def build_full_tree(self, node, max_depth):
        """ Build tree of size max_depth recursively with node as root"""
        if node.depth >= max_depth:
            return
        node.subdivide()
        self.build_full_tree(node.left, max_depth)
        self.build_full_tree(node.right, max_depth)
    
    def find_node(self, depth, index):
        node = self.root
        for d in range(1, depth):
            if node is None or node.active==False:
                return None
            bit = ((index - 1) >> (depth - d - 1)) & 1
            node = node.right if bit else node.left
        return node if node and node.depth == depth and node.index == index else None
        
    def evict(self, depth, index):
        node = self.find_node(depth, index)
        if node:
            node.evict()
    
    def collect_active_nodes(self, depth):
        def helper_collect_active_nodes(node, depth):
        # base
            if node is None or node.active == False:
                return []
            if node.depth == depth:
                return [node]
            return helper_collect_active_nodes(node.left, depth) + helper_collect_active_nodes(node.right, depth) 
    
        return helper_collect_active_nodes(self.root, depth)
    
    def count_active_bins(self, depth):
        return len(self.collect_active_nodes(depth))

    def compute_proba(self, node):
        proba = 1.0
        current = node

        while current.depth > min(self.active_depths):
            parent = current.parent
            if not parent :
                break
            # Count active children at that depth
            active_children = [child for child in [parent.left, parent.right] if child and child.active]
            if not active_children :
                return 0.0
            proba *= 1.0 / len(active_children)
            current = parent
        
        # Initial proba at minimum depth
        min_depth_nodes = [n for n in self.active_depths[min(self.active_depths)] if n.active]
        if not min_depth_nodes:
            return 0.0
        proba *= 1.0 / len(min_depth_nodes)
        return proba
    
            
    def update_estimates(self, x_t, y_t):
        def update_node(node):
            if node is None or not node.active:
                return        
            if node.contains(x_t):
                p = self.compute_proba(node)
                if p > 0:
                    node.mean_esimate = (y_t/p)
            update_node(node.left)
            update_node(node.right)
        update_node(self.root)

    def visualize_tree(self, filename="dyadic_tree"):
            dot = Digraph()

            def add_nodes_edges(node):
                if node is None:
                    return
                # Color node based on activity
                color = "green" if node.active else "red"
                label = f"{node.depth}-{node.index}"
                dot.node(name=str(id(node)), label=label, style="filled", fillcolor=color)

                for child in [node.left, node.right]:
                    if child:
                        dot.edge(str(id(node)), str(id(child)))
                        add_nodes_edges(child)

            add_nodes_edges(self.root)
            dot.render(filename, format="png", cleanup=True)
            print(f"Tree image saved as {filename}.png")

In [71]:
list(np.arange(1, 2**5 + 1), int)

TypeError: list expected at most 1 argument, got 2

In [66]:
2**5

32

In [54]:
tree = Tree(5)
tree.find_node(5, 6)
tree.collect_active_nodes(3)
tree.count_active_bins(3)
tree.evict(4, 2)

0
1
0
1
0
0
1


In [55]:
tree.visualize_tree()

Tree image saved as dyadic_tree.png


In [59]:
tree.find_node(5, 7).parent

0
1
1
0


Node(4, 4, True)

In [56]:
tree.update_estimates(0.1, 0.65)

0.0 1.0


TypeError: '>' not supported between instances of 'NoneType' and 'int'

In [None]:
def sample_replay_prob(m, τ_l_m):
    # Schedule replays for all the block
    Replays = np.zeros((m-1, 8**m - 1 ))
    for s in range(τ_l_m + 1, τ_l_m + 8 ** m):
        for d in range(m):
            if s - τ_l_m % (8 ** d) == 0:
                p_s_d = np.sqrt(8 ** d / (s - τ_l_m))
                R_s_d = int(np.random.random() < p_s_d)
                Replays[s, d] = R_s_d
    return Replays

def MDBE(T):
    # Initialization
    l = 1
    t = 1
    m = 1
    
    while t <= T:
        # (★) Block handling
        if t == 1 or t == τ_l_m + 8 ** m: # if ending of the block without significant shift
            m += 1
            B_MASTER = set(range(1, 1 + 8**m)) # Reset MASTER set
            D_t = {m}
            τ_l_m = t

            # Initialize StoreActive dictionary
            StoreActive = {}
            StoreActive[m] = [τ_l_m, τ_l_m + 8**m]
            for d in range(m) :
                StoreActive[d] = []

            B_t = {}
            Replays = sample_replay_prob(m, τ_l_m)

        # Check if there are replays that are beginning in this precise round t
        for d in range(m) :
            if Replays[t, d] == 1 :
                D_t.add(d)
                StoreActive[d] = [t, t + 8**d]
                B_t[d] = set(range(1, 1 + 8**d))
        
        # Check if a replay ends
        for d in D_t:
            if StoreActive[d][1] == t :
                D_t.discard(d)
                B_t[d] = set()

        # Hierarchical sampling
        d_min = min(D_t)
        B_parent = random.choice(B_t[d_min])

        # faudrait creer une classe d'arbre...
        




        # Replay deactivation
        for d in range(m):
            if StoreActive.get((l, d), {}).get(t, None):
                D[l].discard(d)
                B_t[(l, d)] = set()

        # Bin selection
        d0 = min(D[l])
        B_parent = random.choice(list(B_t[(l, d0)]))
        xt = B_parent  # If no children

        children = get_children(B_parent, d0, B_t)
        if children:
            xt = random.choice(children)
            B_parent = xt

        # Bin eviction
        for d in D[l]:
            for B in B_t[(l, d)]:
                for d_ in [d_ for d_ in range(d + 1, m)]:
                    for B_prime in B_t[(l, d_)]:
                        if bin_condition(B_prime, B, StoreActive, l, d):
                            B_t[(l, d_)] -= {B_prime}

        # MASTER set update
        B_MASTER[(l, m)] = B_MASTER[(l, m)].intersection(B_t.get((l, m), set()))
        if not B_MASTER[(l, m)]:
            τ[(l + 1, 0)] = t + 1
            l += 1
            t = τ[(l, 0)]
            m = 0
            continue

        t += 1


In [15]:
np.log(np.power(10, 6))

np.float64(13.815510557964274)