### Load Data

In [None]:
def load_protein_datasets(dataset_pathes, fields_column='protein'):
    """Build one ProteinsDataset per named group of CSV files.

    Args:
        dataset_pathes: mapping of dataset name -> iterable of CSV paths.
            A single path (``str`` or ``os.PathLike``) is also accepted and
            treated as a one-element list. Each group's files are appended,
            in order, into one dataset via ``ProteinsDataset.extend``.
        fields_column: name of the column holding the sequences to encode;
            forwarded to ``ProteinsDataset``.

    Returns:
        dict mapping each dataset name to its loaded ProteinsDataset.
    """
    import os

    datasets = {}

    # Use a distinct loop name so the argument mapping is not shadowed.
    for dataset_name, paths in dataset_pathes.items():

        # A bare path would otherwise be iterated character by character.
        if isinstance(paths, (str, os.PathLike)):
            paths = [paths]

        dataset = ProteinsDataset(fields_column=fields_column)

        for dataset_path in paths:
            dataset.extend(dataset_path)

        datasets[dataset_name] = dataset

    print("Datasets loaded:", *datasets.keys())

    return datasets

In [None]:
def load_regressors(regressors_path):
    """Unpickle the regressor mapping stored at ``regressors_path``.

    The loaded object is expected to be a mapping (its ``.keys()`` are
    printed). NOTE: ``pickle.load`` can execute arbitrary code, so only
    point this at trusted files.

    Returns:
        The unpickled object, unchanged.
    """
    with open(regressors_path, mode='rb') as regressors_file:
        loaded = pickle.load(regressors_file)

    names = list(loaded.keys())
    print("Regressors loaded:", *names)

    return loaded

### Decode and Enrich

In [None]:
def decode_sampling_results(qubo, dataset, sampling_results):
    """Decode every record of a sampler result into bits, symbols and a number.

    Args:
        qubo: object exposing ``one_hot_to_bits(samples)``, used to turn the
            raw one-hot sample vector into the dataset's bit encoding.
        dataset: object exposing ``decode(bits) -> str`` (e.g. ProteinsDataset).
        sampling_results: object whose ``.data()`` yields records with
            ``.sample`` (mapping variable -> value), ``.energy`` and
            ``.num_occurrences`` — presumably a dimod SampleSet; confirm.

    Returns:
        pd.DataFrame with one row per record and columns ``decimal``,
        ``samples``, ``bits``, ``symbols``, ``qubo_energy``, ``num_occurrences``.
        ``decimal`` is the bit vector read as a big-endian binary integer.
    """
    results = []

    for sampling_result in sampling_results.data():

        # Variable order follows the iteration order of the sample mapping.
        samples = np.array(list(sampling_result.sample.values()))

        bits = qubo.one_hot_to_bits(samples)

        symbols = dataset.decode(bits)

        # Read bits as a big-endian binary number. Accumulate in a Python int:
        # the former numpy form, bits.dot(1 << arange(n)[::-1]), used int64 and
        # silently produced wrong values once n > 63 bits.
        decimal = 0
        for bit in bits:
            decimal = decimal * 2 + int(bit)

        results.append({'decimal': decimal,
                        'samples': samples,
                        'bits': bits,
                        'symbols': symbols,
                        'qubo_energy': sampling_result.energy,
                        'num_occurrences': sampling_result.num_occurrences})

    return pd.DataFrame(results)

In [None]:
def one_cold_decode_sampling_results(qubo, dataset, sampling_results):
    """Decode sampler records whose raw sample vector already is the bit encoding.

    Same as ``decode_sampling_results`` except that no one-hot conversion is
    applied: ``bits`` is the sample vector itself, and ``qubo`` is unused. The
    parameter is kept so both decoders share one signature.

    Args:
        qubo: unused.
        dataset: object exposing ``decode(bits) -> str`` (e.g. ProteinsDataset).
        sampling_results: object whose ``.data()`` yields records with
            ``.sample`` (mapping variable -> value), ``.energy`` and
            ``.num_occurrences`` — presumably a dimod SampleSet; confirm.

    Returns:
        pd.DataFrame with one row per record and columns ``decimal``,
        ``samples``, ``bits``, ``symbols``, ``qubo_energy``, ``num_occurrences``.
        ``decimal`` is the bit vector read as a big-endian binary integer.
    """
    results = []

    for sampling_result in sampling_results.data():

        # Variable order follows the iteration order of the sample mapping.
        samples = np.array(list(sampling_result.sample.values()))

        bits = samples

        symbols = dataset.decode(bits)

        # Read bits as a big-endian binary number. Accumulate in a Python int:
        # the former numpy form, bits.dot(1 << arange(n)[::-1]), used int64 and
        # silently produced wrong values once n > 63 bits.
        decimal = 0
        for bit in bits:
            decimal = decimal * 2 + int(bit)

        results.append({'decimal': decimal,
                        'samples': samples,
                        'bits': bits,
                        'symbols': symbols,
                        'qubo_energy': sampling_result.energy,
                        'num_occurrences': sampling_result.num_occurrences})

    return pd.DataFrame(results)

In [None]:
def add_energies(results, *, x_penalty=0.5):
    """Predict binding energies for decoded samples and add them as columns.

    Relies on the names ``target_regressor``, ``offtarget_regressors`` and
    ``offtarget_proteins``, which must be defined at module level elsewhere in
    the notebook. ``results`` is modified in place and also returned.

    Args:
        results: DataFrame with a ``bits`` column (equal-length bit vectors)
            and a ``symbols`` column (decoded strings), such as the output of
            ``decode_sampling_results``.
        x_penalty: amount added to ``target_binding_energy`` for each ``'X'``
            (undecodable) symbol in a row. Defaults to 0.5, the previously
            hard-coded value.

    Returns:
        ``results`` with ``target_binding_energy``, ``offtarget_binding_energy``
        and ``binding_energy`` (target minus off-target) columns set.
    """
    bits = np.vstack(results['bits'])

    target_binding_energies = target_regressor.predict(bits)

    # One row of predictions per off-target regressor.
    offtarget_binding_energies = np.vstack([offtarget_regressor.predict(bits)
                                            for offtarget_regressor in offtarget_regressors])

    results['target_binding_energy'] = target_binding_energies
    # Sum over off-target regressors divided by the number of off-target proteins.
    # NOTE(review): presumably one regressor per protein, making this a mean — confirm.
    results['offtarget_binding_energy'] = (offtarget_binding_energies.sum(axis=0)
                                           / len(offtarget_proteins))

    # Penalize sequences containing codes that do not decode to a valid symbol.
    x_penalties = results['symbols'].str.count('X') * x_penalty

    results['target_binding_energy'] += x_penalties

    print("x_penalties:", x_penalties.to_numpy(), x_penalties.mean())

    results['binding_energy'] = results['target_binding_energy'] - results['offtarget_binding_energy']

    return results

### Proteins Dataset

In [None]:
class ProteinsDataset(torch.utils.data.Dataset):
    """Binary-encoded sequence dataset built from one or more CSV files.

    Each character of the ``fields_column`` string is mapped to a fixed-width
    bit tuple via ``encoding``. A sequence becomes the concatenation of those
    bits (a row of ``fields``). ``targets`` holds the ``target_column`` values
    in the same row order, and ``data`` keeps the raw concatenated CSV rows.
    """

    # Encoding

    AMINOACIDS = 'ACDEFGHIKLMNPQRSTVWY'

    # 2 bits per nucleotide.
    default_dna_encoding = {'A': (0, 0), 'C': (0, 1), 'G': (1, 0), 'T': (1, 1)}

    # 5 bits per amino acid: the 5-bit binary form of its index in AMINOACIDS.
    # Codes 20-31 have no symbol and decode to 'X'.
    default_amino_encoding = {aminoacid: tuple(int(bit) for bit in f"{index:05b}")
                              for index, aminoacid in enumerate(AMINOACIDS)}

    default_encodings = {'seq': default_dna_encoding,
                         'protein': default_amino_encoding}

    def __init__(self, fields_column, target_column='energy', encoding=None):
        """Create an empty dataset; call ``extend`` to load CSV files.

        Args:
            fields_column: column holding the symbol strings to encode.
            target_column: column holding the regression target.
            encoding: mapping symbol -> bit tuple. If omitted (or empty), the
                default for ``fields_column`` ('seq' or 'protein') is used;
                any other column name then raises KeyError.
        """
        self.fields_column = fields_column
        self.target_column = target_column

        self.encoding = encoding or self.default_encodings[fields_column]

        self.decoding = {bits: symbol for symbol, bits in self.encoding.items()}

        # Width taken from the first entry; assumes all codes have equal width.
        self.bits_per_symbol = len(next(iter(self.encoding.values())))

        self.fields = None
        self.targets = np.array([])
        self.data = pd.DataFrame()
        # Computed by extend(); None until the first non-empty file is loaded.
        self.field_dimensions = None

    def extend(self, dataset_path):
        """Append the rows of the CSV at ``dataset_path`` to the dataset.

        The first CSV column becomes the index. Column names are normalized by
        ``unify_column_names``. A file with no rows is ignored.
        """
        data = pd.read_csv(dataset_path, index_col=0)

        data.rename(columns=self.unify_column_names(data.columns), inplace=True)

        if len(data) == 0:
            return

        # Fields

        field_series = data[self.fields_column]

        # Size the bit matrix from the first row by position. A label lookup
        # (field_series[0]) fails when the CSV index does not contain 0.
        field_bits_count = len(field_series.iloc[0]) * self.bits_per_symbol

        if self.fields is None:
            self.fields = np.empty((0, field_bits_count), dtype=int)

        new_fields = np.vstack([self.encode(symbols) for symbols in field_series])
        new_targets = data[self.target_column].to_numpy()

        self.fields = np.vstack([self.fields, new_fields])
        self.targets = np.concatenate([self.targets, new_targets])

        self.data = pd.concat([self.data, data])

        # Dimensions: per bit position, max observed value + 1, but at least 2.
        self.field_dimensions = np.max(self.fields, axis=0).astype(int) + 1
        self.field_dimensions[self.field_dimensions < 2] = 2

    def unify_column_names(self, column_names):
        """Map each column name to itself with everything up to the first '_' removed.

        Names without an underscore are kept unchanged, e.g.
        ``'x_protein' -> 'protein'``, ``'a_b_c' -> 'b_c'``, ``'id' -> 'id'``.
        """
        return {name: (name.partition('_')[2] if '_' in name else name)
                for name in column_names}

    def encode(self, symbols):
        """Return the flat list of bits for ``symbols``; KeyError on unknown symbols."""
        return [bit for symbol in symbols for bit in self.encoding[symbol]]

    def decode(self, bits):
        """Inverse of ``encode``: one symbol per chunk, 'X' for unknown codes."""
        bit_chunks = np.array(bits).reshape(-1, self.bits_per_symbol)

        return ''.join(self.decoding.get(tuple(bit_chunk), 'X')
                       for bit_chunk in bit_chunks)

    def save(self, file_path):
        """Write the accumulated raw CSV rows to ``file_path``.

        NOTE(review): extend() reads with index_col=0, but this writes
        index=False, so the original first CSV column is dropped on a
        round-trip. Records added via append_records are not included.
        """
        self.data.to_csv(file_path, index=False)

    def __len__(self):
        """Number of encoded rows (0 before anything is loaded)."""
        return 0 if self.fields is None else self.fields.shape[0]

    def __getitem__(self, index):
        """Return ``(fields_row, target)`` for ``index``."""
        fields = self.fields[index]
        target = self.targets[index].squeeze()

        return fields, target

    def append_records(self, new_fields, new_targets, record_repetitions_count):
        """Append ``new_fields``/``new_targets`` repeated ``record_repetitions_count`` times.

        The whole block is repeated, not each row individually
        (rows a, b -> a, b, a, b). Rows go to ``fields``/``targets`` only.
        """
        repeated_fields = np.tile(new_fields, (record_repetitions_count, 1))
        repeated_targets = np.tile(new_targets, (record_repetitions_count, 1)).ravel()

        self.fields = np.vstack((self.fields, repeated_fields))
        self.targets = np.concatenate((self.targets, repeated_targets))

        # TODO: append to self.data