<img src="./SGLD_algo.png" width="800" height="800"/>

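The SGLD optimizer updates the parameters $w_t$ as follows, where $\epsilon$ is the step size, $n$ the number of training samples, $m$ the minibatch size, $\beta$ the inverse temperature, $\gamma$ the strength of the localization term that pulls $w_t$ toward the initial point $w_0$, $\lambda$ the weight decay, and $\sigma$ the noise scale: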
$$\Delta w_t = \frac{\epsilon}{2}\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, w_t\right)+\gamma\left(w_0-w_t\right) - \lambda w_t\right) + N(0, \epsilon\sigma^2)$$

In [18]:
# Scratch check: square a 2x2 integer matrix using the @ (matmul) operator.
a = torch.tensor([[1, 0], [2, 1]])
b = torch.tensor([[0, 1], [1, 0]])

print(a @ a)

tensor([[1, 0],
        [4, 1]])


In [8]:
import os
import warnings
from typing import Callable, List, Union

import torch

from devinterp.optim.sgld import SGLD
from devinterp.slt.callback import SamplerCallback

In [7]:
class BIFEstimator(SamplerCallback):
    """Sampler callback that records per-sample losses and observables to estimate BIFs.

    Each posterior sample, identified by a 0-based ``(chain, draw)`` pair, contributes
    one column to two trace matrices:

    * ``losses``      -- shape ``(num_data, num_chains * num_draws)``
    * ``observables`` -- shape ``(num_obs, num_chains * num_draws)``

    The sample from chain ``c`` and draw ``d`` is stored in column ``c * num_draws + d``.
    ``get_results`` returns the BIF matrix where ``BIF[i, j]`` is the unbiased sample
    covariance, across all recorded samples, between loss row ``i`` and observable row ``j``.
    """

    def __init__(
        self,
        num_chains: int,
        num_draws: int,
        num_data: int,
        num_obs: int,
        init_loss: torch.Tensor,
        device: Union[torch.device, str] = "cpu",
        eval_field: Union[List[str], None] = None,
        nbeta: Union[float, None] = None,
        temperature: Union[float, None] = None,
    ):
        """
        :param num_chains: Number of sampling chains.
        :param num_draws: Number of draws recorded per chain.
        :param num_data: Length of each loss vector passed to ``update``.
        :param num_obs: Length of each observable vector passed to ``update``.
        :param init_loss: Loss at the initial parameters (single-element tensor or number).
        :param device: Device on which the trace matrices are allocated.
        :param eval_field: Names of the ``__call__`` keyword arguments that hold the loss
            vector and the observable vector, in that order. Defaults to ``["loss", "obs"]``.
        :param nbeta: Effective inverse temperature; scales the LLC statistics in ``finalize``.
        :param temperature: Deprecated alias for ``nbeta``.
        :raises AssertionError: if neither ``nbeta`` nor ``temperature`` is given.
        """
        self.num_chains = num_chains
        self.num_draws = num_draws
        self.num_data = num_data
        self.num_obs = num_obs
        num_samples = num_chains * num_draws

        self.losses = torch.zeros(
            (num_data, num_samples), dtype=torch.float32, device=device
        )
        self.observables = torch.zeros(
            (num_obs, num_samples), dtype=torch.float32, device=device
        )
        self.init_loss = init_loss

        assert nbeta is not None or temperature is not None, (
            "Please provide a value for nbeta."
        )
        if nbeta is None and temperature is not None:
            nbeta = temperature
            warnings.warn("Temperature is deprecated. Please use nbeta instead.")

        self.nbeta = torch.tensor(nbeta, dtype=torch.float32, device=device)
        self.temperature = temperature

        self.device = device
        # Fresh default per instance, and a copy of any caller-supplied list, so the
        # field names are never shared/aliased mutable state.
        self.eval_field = ["loss", "obs"] if eval_field is None else list(eval_field)

    def update(
        self, chain: int, draw: int, loss_vec: torch.Tensor, obs_vec: torch.Tensor
    ):
        """Store the loss and observable vectors of sample ``(chain, draw)``.

        :raises IndexError: if ``chain``/``draw`` lie outside ``[0, num_chains)``/``[0, num_draws)``.
        :raises RuntimeError: if ``loss_vec`` contains NaN.
        """
        # Negative indices would silently wrap around and overwrite other samples.
        if not (0 <= chain < self.num_chains and 0 <= draw < self.num_draws):
            raise IndexError(
                f"(chain, draw) = ({chain}, {draw}) out of range for "
                f"{self.num_chains} chains x {self.num_draws} draws"
            )
        if torch.isnan(loss_vec).any():
            raise RuntimeError(f"NaN detected in loss at chain {chain}, draw {draw}")

        # Each chain owns a contiguous run of num_draws columns, so
        # reshape(num_chains, num_draws) in finalize() recovers the per-chain grouping.
        col = chain * self.num_draws + draw
        self.losses[:, col] = loss_vec.to(self.device)
        self.observables[:, col] = obs_vec.to(self.device)

    def get_results(self):
        """Return the BIF matrix, the LLC statistics and the raw traces.

        Only valid after ``finalize`` has run (it sets the LLC statistics), i.e. after
        :python:`devinterp.slt.sampler.sample(..., [bif_estimator_instance], ...)`.

        :returns: A dict :python:`
        {"init_loss": init_loss, "BIF": BIF, "llc/mean": llc_mean, "llc/std": llc_std,
        "llc-chain/{i}": llc_of_chain_i, "loss/trace": losses, "obs/trace": observables}`,
        where ``BIF`` is a ``(num_data, num_obs)`` tensor and ``BIF[i, j]`` is the unbiased
        sample covariance between loss row ``i`` and observable row ``j``.
        :raises ValueError: if fewer than two samples were configured.
        """
        init_loss = (
            self.init_loss.item()
            if isinstance(self.init_loss, torch.Tensor)
            else self.init_loss
        )

        num_samples = self.num_chains * self.num_draws
        if num_samples < 2:
            raise ValueError(
                "BIF needs at least two samples (num_chains * num_draws >= 2)."
            )
        # Cov(L, O) = L C O^T / (CT - 1) with the centering matrix C = I - 11^T / CT.
        # Centering each row directly gives the same result without materialising a
        # CT x CT matrix (which is quadratic in the number of samples).
        centered_losses = self.losses - self.losses.mean(dim=1, keepdim=True)
        centered_obs = self.observables - self.observables.mean(dim=1, keepdim=True)
        bif = centered_losses @ centered_obs.T / (num_samples - 1)

        return {
            "init_loss": init_loss,
            "BIF": bif,
            "llc/mean": self.llc_mean.cpu().numpy().item(),
            "llc/std": self.llc_std.cpu().numpy().item(),
            **{
                f"llc-chain/{i}": self.llc_per_chain[i].cpu().numpy().item()
                for i in range(self.num_chains)
            },
            "loss/trace": self.losses.cpu().numpy(),
            "obs/trace": self.observables.cpu().numpy(),
        }

    def __call__(self, chain: int, draw: int, **kwargs):
        """Sampler hook: forward the configured loss/observable kwargs to ``update``."""
        loss_key, obs_key = self.eval_field[0], self.eval_field[1]
        self.update(chain, draw, kwargs[loss_key], kwargs[obs_key])

    def _reduce_across_processes(self, tensor: torch.Tensor) -> torch.Tensor:
        """Combine a trace tensor across distributed workers; return the merged tensor.

        Returns ``tensor`` unchanged when no distributed backend is active.
        """
        device = str(self.device)
        dist = torch.distributed
        dist_ready = dist.is_available() and dist.is_initialized()

        if os.environ.get("USE_SPMD", "0") == "1" and not device.startswith("cpu:"):
            if device.startswith("cuda") and torch.cuda.device_count() > 1 and dist_ready:
                dist.barrier()
                dist.all_reduce(tensor, op=dist.ReduceOp.AVG)
            return tensor

        # NOTE(review): USE_TPU_BACKEND and TPU_TYPE are not defined in this notebook;
        # they presumably come from devinterp's backend configuration. The device is
        # checked first so CPU/CUDA runs never evaluate these names.
        if device.startswith("xla:") and USE_TPU_BACKEND:
            import torch_xla.core.xla_model as xm

            if TPU_TYPE == "v4":
                return xm.all_reduce(xm.REDUCE_SUM, tensor)
            if TPU_TYPE == "v2/v3":
                tensor = tensor.cpu()
                if dist_ready:
                    dist.all_reduce(tensor)
                else:
                    warnings.warn(
                        "torch.distributed has not been initialized. If running on TPU v2/v3, and you want to run chains in parallel, you need to initialize torch.distributed after calling xmp.spawn() as follows:\n"
                        ">>> import torch_xla.runtime as xr\n"
                        ">>> store = torch.distributed.TCPStore('127.0.0.1', 12345, 4, xr.global_ordinal() == 0)\n"
                        ">>> torch.distributed.init_process_group(backend='gloo', store=store, rank=xr.global_ordinal()//2, world_size=xr.world_size()//2)"
                    )
                return tensor
            raise NotImplementedError(f"TPU type {TPU_TYPE} not supported")

        # Multi-GPU run: each process filled only its own columns (the rest are zero),
        # so a SUM all-reduce merges them. Explicit check instead of swallowing errors.
        if device.startswith("cuda") and dist_ready:
            dist.all_reduce(tensor)
        return tensor

    def finalize(self):
        """Merge traces across workers, then compute per-chain LLC statistics.

        Sets ``llc_per_chain`` (one value per chain), ``llc_mean`` and ``llc_std``, where
        each chain's value is ``nbeta * (mean loss over the chain's samples - init_loss)``.
        """
        # Observables must be merged as well, otherwise the BIF sees only local columns.
        self.losses = self._reduce_across_processes(self.losses)
        self.observables = self._reduce_across_processes(self.observables)

        # Mean over the num_data rows gives one loss per sample (assumes a sample's loss
        # is the average of its per-datum losses -- TODO confirm); columns are laid out
        # chain-major (see update), so the reshape groups them by chain.
        per_sample_loss = self.losses.mean(dim=0)
        avg_losses = per_sample_loss.reshape(self.num_chains, self.num_draws).mean(dim=1)

        device = str(self.device)
        xla_bf16 = os.environ.get("XLA_USE_BF16", "0") == "1" and device.startswith("xla:")
        cuda_spmd = device.startswith("cuda") and os.environ.get("USE_SPMD", "0") == "1"
        if xla_bf16 or cuda_spmd:
            # Do the arithmetic in float32 on CPU to bypass automatic bfloat16 issues.
            nbeta = self.nbeta.to(device="cpu", dtype=torch.float32)
            avg_losses = avg_losses.to(device="cpu", dtype=torch.float32)
        else:
            # The traces may have been moved (e.g. to CPU by the TPU v2/v3 reduction).
            nbeta = self.nbeta.to(avg_losses.device)
        # init_loss may be a Python number or a tensor living on another device.
        init_loss = torch.as_tensor(
            self.init_loss, dtype=avg_losses.dtype, device=avg_losses.device
        )
        self.llc_per_chain = nbeta * (avg_losses - init_loss)

        self.llc_mean = self.llc_per_chain.mean()
        self.llc_std = self.llc_per_chain.std()
        

In [1]:
# Placeholder: no model has been assigned yet at this point in the notebook.
model = None

In [4]:
a = dict(one=1)
# "two" is absent, so setdefault inserts it with value 6 and returns that value.
print(a.setdefault("two", 6))

6
