In [None]:
import torch


def generate_solutions(n, batch_size, device='cpu'):
    r"""
    Draw random spin configurations whose entries are -1 or 1.

    Args:
        n:              number of spins in each configuration
        batch_size:     how many configurations to draw
        device:         torch device the result is allocated on

    Returns:
        Tensor of shape (batch_size, n) with every entry equal to -1 or 1.
    """
    # Uniform bits in {0, 1}, then the affine map b -> 2b - 1 sends them to {-1, 1}.
    bits = torch.randint(low=0, high=2, size=(batch_size, n), device=device)
    return 2 * bits - 1


def gauge_problems(s, J) -> None:
    r"""
    Given problems with {1}^n solution, turn them into problems with s solutions.

    Before the gauge, each matrix in J has solution [1, 1, ..., 1]. Wherever the
    target solution has -1 at position i, the i-th row and the i-th column of the
    corresponding matrix are multiplied by -1. This maps problems whose solution is
    [1, 1, ..., 1] to problems whose solution is s, because the energy

        H(s) = -1/2 s^T J s

    of s under the gauged matrix equals the energy of [1, ..., 1] under the original.

    Args:
        s:      solutions to the Ising problems, shape (batch_size, n), entries in {-1, 1}
        J:      coupling matrices, shape (batch_size, n, n); modified in place
    """
    # True where the target solution has a -1 entry.
    masks = ((-s + 1) // 2).bool()
    # +1 where the mask is False, -1 where it is True.
    signs = 1 - 2 * masks.to(J.dtype)
    # Flipping row a and column b multiplies J[i, a, b] by signs[i, a] * signs[i, b]
    # (diagonal entries are flipped twice, i.e. unchanged). Applying both flips at
    # once is an element-wise, in-place product with the outer product of the signs.
    J.mul_(signs.unsqueeze(-1) * signs.unsqueeze(-2))


def generate_problem(s, m: int, device='cpu'):
    """
    Generate Ising problems with planted solutions (Wishart planted ensemble).

    For every row of ``s`` a coupling matrix J is built so that this row is a
    ground state of the energy

        H(x) = -1/2 x^T J x

    Construction:
      1. draw a Gaussian matrix W of shape (n, m) per problem;
      2. centre each column of W so that W^T [1, ..., 1] = 0;
      3. set J = -1/2 W W^T and zero its diagonal. For x in {-1, 1}^n this gives
         H(x) = 1/4 ||W^T x||^2 - const, which is minimised by [1, ..., 1]
         because of step 2;
      4. gauge J so that the planted ground state becomes s.

    Args:
        s:              tensor of shape (batch_size, n) with entries in {-1, 1};
                        row i is the solution planted in problem i
        m:              hardness parameter; number of columns of W
        device:         device on which W (and hence J) is generated

    Returns:
        Tensor of shape (batch_size, n, n) holding the coupling matrices.
    """
    batch_size, n = s.shape
    W = torch.normal(mean=0, std=1, size=(batch_size, n, m), device=device)
    # Project each column onto the subspace orthogonal to the all-ones vector.
    # Without this step [1, ..., 1] (and hence s after gauging) is generally
    # NOT a ground state, so the "planted" solution would not be known.
    W = W - W.mean(dim=1, keepdim=True)
    J = (-1/2) * W @ torch.transpose(W, 1, 2)
    # Remove the diagonal: for +-1 spins it only shifts H by the constant -1/2 * tr(J).
    J -= torch.diag_embed(torch.diagonal(J, dim1=-2, dim2=-1))
    gauge_problems(s, J)

    return J

In [14]:
# Sizes and device shared by the cells below.
n, m = 100, 100
batch_size = 1000
device = 'cpu'


In [15]:
# Sample a standard-normal tensor of shape (batch_size, n, m) and display it.
W = torch.normal(
    mean=0,
    std=1,
    size=(batch_size, n, m),
    device=device,
)
W

tensor([[[ 0.1818, -0.2128, -0.5857,  ..., -1.2006,  0.7204, -1.5179],
         [-0.5381, -0.0901,  0.9798,  ...,  0.9043, -1.0803,  0.2611],
         [-1.1213,  1.9463,  0.8505,  ...,  1.3700, -1.1342, -1.5091],
         ...,
         [-0.0234, -0.4382, -0.9228,  ...,  1.7390, -0.6246,  0.2231],
         [-0.2651,  0.3002,  1.7054,  ..., -1.4700, -0.3047, -1.6840],
         [-0.1358,  0.1617, -2.3575,  ...,  0.4891,  0.9848,  0.7390]],

        [[ 0.7575,  0.2173,  0.7474,  ..., -1.0344, -0.5404, -0.3037],
         [-0.7029,  1.1119, -1.4521,  ...,  0.2124,  1.2849,  0.1539],
         [ 0.0774, -1.4229, -1.9466,  ..., -0.2085, -1.3353,  0.8630],
         ...,
         [ 0.1956,  0.9682, -0.5685,  ...,  0.8469, -0.4411,  2.0474],
         [-1.1767,  0.4950,  1.4246,  ..., -2.0444, -0.9921, -1.5306],
         [-1.3622, -1.1847,  0.7522,  ..., -1.1981,  0.1977, -1.3593]],

        [[-1.9488, -0.9965,  0.9755,  ...,  0.1735, -0.4670,  2.7776],
         [ 0.9732,  1.0471, -1.6237,  ..., -0