In [1]:

import os
import time
import pickle
import numpy as np
import torch
from tqdm import tqdm

from src.torch_erg import load_pglib_opf as lp
from src.torch_erg.samplers import (
    MHSampler_Hard,
    GWGSampler,
    DLMC_Sampler,
    DLP_Sampler,
)

#############################################################
#  Observables
#############################################################

def basic_observables(mtx: torch.Tensor) -> torch.Tensor:
    """Return the statistics ``[edge_count, triangle_count]`` of a graph.

    Parameters
    ----------
    mtx : torch.Tensor
        Square matrix used as an adjacency matrix. Assumed symmetric with a
        zero diagonal (undirected simple graph) — TODO confirm for all callers;
        the counts below are only meaningful under that assumption.

    Returns
    -------
    torch.Tensor
        1-D tensor of length 2: ``[edges, triangles]``.
    """
    # A symmetric adjacency matrix stores every undirected edge twice.
    n_edges = mtx.sum() / 2

    # trace(A^3) counts closed walks of length 3; each triangle contributes
    # 6 of them (3 starting vertices x 2 traversal directions).
    walks_len3 = torch.trace((mtx @ mtx) @ mtx)
    n_triangles = walks_len3 / 6

    return torch.stack((n_edges, n_triangles))

In [2]:
# Load the test case via the project's PGLib-OPF loader ("30_ieee", presumably
# the IEEE 30-bus system). Only `ordmat` is used in the cells shown here; it
# is treated as the adjacency matrix of the network.
ordmat, ordlist, buslist, countlist = lp.pow_parser("30_ieee")
graph = torch.tensor(ordmat, dtype=torch.float32)

# Observed statistics [edges, triangles] of the loaded graph.
obs0 = basic_observables(graph)

# Initial parameters for EE (one per observable, same order as obs0).
# A neutral alternative start is betas0 = [0.0, 0.0].
betas0 = torch.tensor([-2.0, -0.1], dtype=torch.float32)


In [3]:
class S(DLP_Sampler):
    """``DLP_Sampler`` whose statistics are ``basic_observables`` (edges, triangles).

    Parameters
    ----------
    backend : str
        Forwarded unchanged to ``DLP_Sampler`` (e.g. ``"cpu"``).
    stepsize_alpha : float, optional
        Forwarded to ``DLP_Sampler`` as ``stepsize_alpha``. Defaults to 0.3,
        the value this class previously hard-coded, so existing calls such as
        ``S(backend="cpu")`` behave exactly as before.
    """

    def __init__(self, backend, stepsize_alpha=0.3):
        super().__init__(backend, stepsize_alpha=stepsize_alpha)

    def observables(self, mtx):
        """Return ``[edges, triangles]`` for the adjacency matrix ``mtx``."""
        return basic_observables(mtx)


# NOTE(review): `mtx` is reassigned by the `sampler.proposal(...)` cell before
# any visible use; this assignment looks like a leftover and is a removal candidate.
mtx = torch.tensor(ordmat, dtype=torch.float32)

In [9]:
sampler = S(backend="cpu")

# Value passed as the last argument of `proposal`: beta . obs for the starting
# graph (presumably its current energy/log-weight — confirm against DLP_Sampler).
initial_energy = torch.dot(betas0, obs0)

# Start from a copy of the already-built `graph` rather than re-parsing
# `ordmat`. The copy keeps one source of truth for the initial state and
# ensures any in-place update made by the sampler cannot alias `graph`.
mtx, acc = sampler.proposal(graph.clone(), obs0, betas0, initial_energy)

tensor([0.8972, 0.5000, 0.5000, 0.5000, 0.5000, 0.8972, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.8972, 0.5000, 0.5000, 0.5000, 0.5000, 0.8995, 0.8995,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.8972, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.8995, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.8995, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 