In [None]:
from ortools.sat.python import cp_model

In [None]:
class SolutionCollector(cp_model.CpSolverSolutionCallback):
    """CP-SAT solution callback that records every reported solution.

    Each time the solver invokes ``on_solution_callback``, the current values
    of the tracked variables are appended to ``solutions`` and the current
    objective value to ``energies``, so the two lists stay index-aligned.
    """

    def __init__(self, variables):
        """Create a collector.

        Args:
            variables: model variables whose values are recorded for every
                solution, in this order.
        """
        super().__init__()

        self.__variables = variables
        self.solutions = []  # one list of variable values per solution
        self.energies = []   # objective value of the matching solution

    def on_solution_callback(self):
        """Record the tracked variable values and the objective value."""
        solution = [self.Value(variable) for variable in self.__variables]
        energy = self.ObjectiveValue()

        self.solutions.append(solution)
        self.energies.append(energy)

In [None]:
class ORToolsSampler(dimod.core.sampler.Sampler):
    """dimod sampler that minimizes a QUBO with the OR-Tools CP-SAT solver.

    Every solution reported to the solution callback during the search is
    returned in the resulting SampleSet, not only the final one.
    """

    # Class-level placeholders; the real values are set per instance.
    properties = None
    parameters = None

    def __init__(self):

        self.properties = {}
        # Keyword arguments accepted by ``sample_qubo`` (dimod convention:
        # parameter name -> list of relevant property names).
        self.parameters = {
            'verbose': [],
            'time_limit': [],
            'num_search_workers': [],
        }

    def sample_qubo(self, Q,
                    verbose=False,
                    time_limit=10.0,
                    num_search_workers=16, **kwargs):
        """Minimize a QUBO with CP-SAT and return every solution reported.

        Args:
            Q: QUBO coefficients keyed by ``(i, j)`` index pairs. Indices are
                assumed to be non-negative integers; one Boolean variable is
                created for every index from 0 to the largest index in ``Q``.
            verbose: print model statistics, enable the CP-SAT search log and
                print a short statistics summary after solving.
            time_limit: maximum solve time in seconds.
            num_search_workers: number of parallel CP-SAT search workers.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            dimod.SampleSet: one BINARY sample per reported solution, with the
            solver's objective value as its energy.
        """
        model, variables = self.get_model_from_qubo(Q)

        collector = SolutionCollector(variables)

        if verbose:
            print(model.ModelStats())

        return self.solve_model(model, collector, verbose,
                                num_search_workers, time_limit,
                                num_variables=len(variables))

    def get_model_from_qubo(self, qubo_coefficients):
        """Build a CP-SAT model whose objective is the QUBO energy.

        A diagonal term ``(i, i)`` uses Boolean ``x_i`` directly. Each
        off-diagonal term ``(i, j)`` gets an auxiliary Boolean ``y`` that is
        constrained to equal ``x_i AND x_j`` by the clauses
        ``(not x_i or not x_j or y)``, ``y -> x_i`` and ``y -> x_j``.
        Terms with a zero coefficient are skipped.

        Args:
            qubo_coefficients: mapping ``(i, j) -> coefficient`` with integer
                indices.

        Returns:
            tuple: ``(model, variables)``, where ``variables[k]`` is the
            Boolean variable for QUBO index ``k``.
        """
        model = cp_model.CpModel()
        model.SetName('QUBO CP-SAT Model')

        # default=-1 makes an empty QUBO yield a model with zero variables
        # instead of failing inside max().
        variables_count = max((index for key in qubo_coefficients
                               for index in key), default=-1) + 1

        variables = [model.NewBoolVar(f'linear_{index}')
                     for index in range(variables_count)]

        objective_terms = []

        for (i, j), coefficient in qubo_coefficients.items():

            if coefficient == 0.0:
                continue

            if i == j:
                variable = variables[i]
            else:
                # Linearize the product x_i * x_j as y == (x_i AND x_j).
                variable = model.NewBoolVar(f'quadratic_{i}_{j}')

                model.AddBoolOr([variables[i].Not(),
                                 variables[j].Not(),
                                 variable])
                model.AddImplication(variable, variables[i])
                model.AddImplication(variable, variables[j])

            objective_terms.append(variable * coefficient)

        model.Minimize(sum(objective_terms))

        return model, variables

    def solve_model(self, model, collector, verbose,
                    num_search_workers, time_limit, num_variables=None):
        """Solve ``model`` and convert the collected solutions to a SampleSet.

        Args:
            model: CP-SAT model to solve.
            collector: ``SolutionCollector`` that receives each solution.
            verbose: enable the search log and print solver statistics.
            num_search_workers: number of parallel CP-SAT search workers.
            time_limit: maximum solve time in seconds.
            num_variables: number of sample variables. When given and no
                solution was collected, an empty SampleSet with this many
                variables is returned.

        Returns:
            dimod.SampleSet: collected solutions with their objective values
            as energies.
        """
        solver = cp_model.CpSolver()

        solver.parameters.num_search_workers = num_search_workers
        solver.parameters.log_search_progress = verbose
        solver.parameters.max_time_in_seconds = time_limit

        status = solver.Solve(model, collector)

        if verbose:
            print()
            print('Statistics:')
            print(f'  status   : {solver.StatusName(status)}')
            print(f'  conflicts: {solver.NumConflicts()}')
            print(f'  branches : {solver.NumBranches()}')
            print(f'  wall time: {solver.WallTime()} s')

        if collector.solutions or num_variables is None:
            solutions = np.array(collector.solutions)
        else:
            # No solution was reported (e.g. the time limit expired first).
            # np.array([]) would be 1-D and would not describe "zero samples
            # of num_variables variables", so build that shape explicitly.
            solutions = np.empty((0, num_variables), dtype=int)

        return dimod.SampleSet.from_samples(solutions,
                                            energy=collector.energies,
                                            vartype='BINARY')

In [None]:
class OneHotRandomSampler(dimod.core.sampler.Sampler):
    """dimod sampler that draws random one-hot-encoded states for a QUBO.

    Variables ``0 .. n-1`` are split into consecutive groups of
    ``len(one_hot_keys[0])`` bits. Each group is independently set to one of
    ``one_hot_keys``, chosen uniformly at random. The resulting states are
    passed as ``initial_states`` to ``dimod.IdentitySampler``, which is used
    to turn them into a SampleSet for ``Q``.
    """

    # Class-level placeholders; the real values are set per instance.
    properties = None
    parameters = None

    # Default encoding: 2-bit groups, each either (0, 1) or (1, 0).
    DEFAULT_ONE_HOT_KEYS = ((0, 1), (1, 0))

    def __init__(self, seed=None):
        """Create the sampler.

        Args:
            seed: optional seed for ``numpy.random.default_rng``; pass a
                value for reproducible draws. ``None`` (default) seeds from
                fresh OS entropy.
        """
        self.properties = {}
        # Keyword arguments accepted by ``sample_qubo`` (dimod convention:
        # parameter name -> list of relevant property names).
        self.parameters = {'num_reads': [], 'one_hot_keys': []}

        self.random_sampler = np.random.default_rng(seed)
        self.sampler = dimod.IdentitySampler()

    def sample_qubo(self, Q, num_reads, one_hot_keys=None, **kwargs):
        """Return ``num_reads`` random one-hot states scored against ``Q``.

        Args:
            Q: QUBO coefficients keyed by ``(i, j)`` index pairs. Indices are
                assumed to be non-negative integers ``0 .. n-1``.
            num_reads: number of random states to draw.
            one_hot_keys: sequence of equal-length bit patterns, one of which
                is assigned to each group. Defaults to
                ``DEFAULT_ONE_HOT_KEYS`` when ``None`` or empty.
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            dimod.SampleSet: the result of ``IdentitySampler.sample_qubo``
            with the drawn states as ``initial_states``.

        Raises:
            ValueError: if the number of variables is not a multiple of the
                one-hot group size.
        """
        # Explicit checks instead of ``or`` so numpy arrays are accepted
        # (their truth value is ambiguous).
        if one_hot_keys is None or len(one_hot_keys) == 0:
            one_hot_keys = self.DEFAULT_ONE_HOT_KEYS

        one_hot_bits = len(one_hot_keys[0])

        variables_count = max((index for key in Q for index in key),
                              default=-1) + 1

        # A remainder would leave trailing variables without an initial
        # value; fail clearly instead of passing mismatched states on.
        if variables_count % one_hot_bits:
            raise ValueError(
                f'{variables_count} variables cannot be split into '
                f'groups of {one_hot_bits} one-hot bits')

        groups_count = variables_count // one_hot_bits

        # choice() picks whole rows of one_hot_keys, giving
        # shape (num_reads, groups_count, one_hot_bits).
        chosen_keys = self.random_sampler.choice(one_hot_keys,
                                                 size=(num_reads,
                                                       groups_count))
        initial_states = chosen_keys.reshape(num_reads,
                                             groups_count * one_hot_bits)

        return self.sampler.sample_qubo(Q, initial_states=initial_states)