In [1]:
class OneHotQubo:
    """Build a QUBO from a trained factorization model using a one-hot encoding.

    Each logical bit is represented by a pair of binary variables
    ``(2k, 2k + 1)``. A penalty term is added so that pairs with exactly one
    variable set are favoured; ``one_hot_scheme`` maps each logical bit to
    its pair value.

    Args:
        trainer: object exposing ``trainer.model.state_dict()`` with an
            ``'embedding.embedding.weight'`` entry (presumably a torch tensor
            — TODO confirm).
        one_hot_scheme: mapping ``bit -> (pair_value_0, pair_value_1)``.
            Defaults to ``{0: (0, 1), 1: (1, 0)}``; a falsy mapping (``None``
            or ``{}``) also falls back to the default. NOTE: the penalty built
            in ``get_one_hot_penalty`` always favours exactly-one-set pairs,
            so a custom scheme should keep that property.
        maximize: if true, couplings are negated so that a minimizing sampler
            maximizes the model's pairwise interactions.
        penalty_value: scale ``P`` of the one-hot constraint penalty.
    """

    def __init__(self, trainer,
                 one_hot_scheme=None, maximize=True,
                 penalty_value=1000):

        default_one_hot_scheme = {0: (0, 1), 1: (1, 0)}

        self.one_hot_scheme = one_hot_scheme or default_one_hot_scheme

        # Inverse mapping: pair tuple -> logical bit.
        self.one_hot_decoding_scheme = {value: key for key, value in
                                        self.one_hot_scheme.items()}

        self.trainer = trainer
        self.maximize = maximize
        self.penalty_value = penalty_value

        # Load weights. ``.cpu()`` is a no-op for CPU tensors and lets the
        # conversion work when the model lives on an accelerator.
        state_dict = trainer.model.state_dict()
        raw_weight = state_dict['embedding.embedding.weight'].cpu().numpy()

        factorization_matrix = raw_weight.squeeze()

        # ``squeeze`` also drops the embedding axis when its size is 1, which
        # would leave a 1-D vector (``F @ F.T`` would then be a scalar).
        # Keep one row per variable in that case.
        if factorization_matrix.ndim == 1:
            factorization_matrix = raw_weight.reshape(raw_weight.shape[0], -1)

        # Swap each adjacent pair of rows (0<->1, 2<->3, ...).
        # NOTE(review): presumably aligns the embedding row order with the
        # one-hot pair layout — confirm against the model's input encoding.
        swapped_factorization_matrix = np.zeros_like(factorization_matrix)

        swapped_factorization_matrix[::2] = factorization_matrix[1::2]
        swapped_factorization_matrix[1::2] = factorization_matrix[::2]

        self.factorization_matrix = swapped_factorization_matrix

        self.one_hot_penalty = self.get_one_hot_penalty()

        self.qubo_coefficients = self.get_qubo_coefficients()

    def get_one_hot_penalty(self):
        """Return the block-diagonal one-hot penalty matrix.

        Each 2x2 block is ``P * [[-1, 2], [2, -1]]``. Only the upper triangle
        is read by ``get_coupling_coefficients``, so a block contributes
        ``P * (-x0 - x1 + 2*x0*x1)``. That equals ``P * (x0 + x1 - 1)**2``
        minus the constant ``P`` for binary variables (``x**2 == x``).
        """
        one_hot_pair_penalty = self.penalty_value * np.array([[-1, 2],
                                                              [2, -1]])

        bits_count = len(self.factorization_matrix) // 2

        # One penalty block per logical bit on the diagonal.
        return np.kron(np.eye(bits_count), one_hot_pair_penalty)

    def get_qubo_coefficients(self):
        """Combine the model couplings with the one-hot penalty.

        Also stores ``self.coupling_matrix`` (model part only) and
        ``self.penalized_coupling_matrix`` (model part + penalty).

        Returns:
            dict mapping ``(i, j)`` with ``i <= j`` to the QUBO weight.
        """
        coupling_matrix = self.factorization_matrix @ self.factorization_matrix.T

        # Self-interactions are not part of the pairwise model.
        np.fill_diagonal(coupling_matrix, 0)

        if self.maximize:
            coupling_matrix = -coupling_matrix

        penalized_coupling_matrix = coupling_matrix + self.one_hot_penalty

        qubo_coefficients = self.get_coupling_coefficients(penalized_coupling_matrix)

        self.coupling_matrix = coupling_matrix

        self.penalized_coupling_matrix = penalized_coupling_matrix

        return qubo_coefficients

    def bits_to_one_hot(self, bits):
        """Encode logical bits as a flat ``int8`` array of one-hot pairs."""
        one_hot_bits = [self.one_hot_scheme[bit] for bit in bits]

        return np.array(one_hot_bits, dtype='i1').ravel()

    def one_hot_to_bits(self, one_hot):
        """Decode a flat one-hot array (length ``2 * n``) back to ``n`` bits.

        Raises:
            KeyError: if a pair is not a valid code in ``one_hot_scheme``.
        """
        bits = [self.one_hot_decoding_scheme[tuple(one_hot_pair)]
                for one_hot_pair in np.asarray(one_hot).reshape(-1, 2)]

        return np.array(bits, dtype='i1')

    def bit_to_spin(self, binary):
        """Map binary ``0/1`` to spin ``-1/+1``."""
        return binary * 2 - 1

    def spin_to_bit(self, spin):
        """Map spin ``-1/+1`` to binary ``0/1`` (binary ``0/1`` stays unchanged)."""
        return (spin + 1) // 2

    def get_coupling_coefficients(self, coupling_matrix):
        """Return ``{(i, j): coupling_matrix[i][j]}`` for all ``i <= j``.

        The diagonal is included and the lower triangle is ignored.
        """
        return {
            (row_index, column_index): cell
            for row_index, row in enumerate(coupling_matrix)
            for column_index, cell in enumerate(row[row_index:], row_index)
        }

    def is_correct_one_hot(self, record):
        """Return whether every pair in ``record.sample`` is a valid one-hot code.

        ``record.sample`` is read as a mapping from variable label to value
        (spin or binary; ``spin_to_bit`` handles both). Values are ordered by
        variable label, so pairs line up with the ``(2k, 2k + 1)`` QUBO
        layout regardless of the mapping's iteration order.
        """
        sample = record.sample

        fields_spins = np.array([sample[variable] for variable in sorted(sample)])

        fields_one_hot = self.spin_to_bit(fields_spins)

        # Keys of the decoding scheme are exactly the valid pair codes.
        return all(tuple(one_hot_pair) in self.one_hot_decoding_scheme
                   for one_hot_pair in fields_one_hot.reshape(-1, 2))

    @staticmethod
    def get_coupling_matrix(qubo_coefficients):
        """Rebuild a dense square matrix from an ``{(row, col): weight}`` dict.

        Entries absent from the dict are zero. An empty dict yields a
        ``(0, 0)`` matrix.
        """
        coupling_matrix_size = max((index for key in qubo_coefficients
                                    for index in key), default=-1) + 1

        coupling_matrix = np.zeros((coupling_matrix_size, coupling_matrix_size))

        for key, weight in qubo_coefficients.items():
            coupling_matrix[key] = weight

        return coupling_matrix

In [None]:
class OneColdQubo:
    """QUBO built directly from a factorization model's ``'weights'`` matrix.

    Every row of the weight matrix is treated as one binary variable. No
    constraint penalty is added.

    Args:
        trainer: object exposing ``trainer.model.state_dict()`` with a
            ``'weights'`` entry that supports ``.numpy()`` (presumably a
            torch tensor — TODO confirm).
        maximize: if true, couplings are negated so that a minimizing sampler
            maximizes the pairwise interactions.
    """

    def __init__(self, trainer, maximize=True):
        self.trainer = trainer
        self.maximize = maximize

        weights = trainer.model.state_dict()['weights'].numpy()

        # Pairwise interaction strengths <w_i, w_j>; self-interactions dropped.
        interactions = weights @ weights.T
        np.fill_diagonal(interactions, 0)

        # TODO: Add Linear Weights

        interactions = -interactions if self.maximize else interactions

        coefficients = self.get_coupling_coefficients(interactions)

        self.factorization_matrix = weights
        self.coupling_matrix = interactions
        self.qubo_coefficients = coefficients

    def bit_to_spin(self, binary):
        """Convert binary ``0/1`` values to spins ``-1/+1``."""
        return binary * 2 - 1

    def spin_to_bit(self, spin):
        """Convert spins ``-1/+1`` to binary ``0/1``."""
        return (spin + 1) // 2

    def get_coupling_coefficients(self, coupling_matrix):
        """Collect the upper triangle (diagonal included) as ``{(i, j): value}``."""
        return {
            (i, j): row[j]
            for i, row in enumerate(coupling_matrix)
            for j in range(i, len(row))
        }