## Load dataset

In [1]:
import json

# The notebook writes this same file back as UTF-8 with non-ASCII characters
# left unescaped, so read it as UTF-8 explicitly instead of relying on the
# platform's default encoding (which is not UTF-8 everywhere).
with open('data/data.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

## Define model

In [68]:
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Alphabet of one-letter residue codes; a residue's position in this string is
# its column in the one-hot encoding.
amino_acids = 'ARNDCQEGHILKMFPSTWYV'
# All-zero row used to represent "no residue" (e.g. positions past a sequence end).
empty_acid = np.zeros(len(amino_acids))

# Residue code -> column index, built once so encoding does not rescan the
# alphabet string for every residue.
_AMINO_ACID_INDEX = {acid: column for column, acid in enumerate(amino_acids)}


def one_hot_encode_sequence(sequence):
    """One-hot encode a sequence of one-letter residue codes.

    Args:
        sequence: String (or other sized iterable) of single-character codes,
            each of which must occur in ``amino_acids``.

    Returns:
        Float array of shape ``(len(sequence), len(amino_acids))`` where row
        ``i`` holds a single 1 in the column of ``sequence[i]``.

    Raises:
        ValueError: If ``sequence`` contains a code that is not in
            ``amino_acids``.
    """
    try:
        indices = [_AMINO_ACID_INDEX[acid] for acid in sequence]
    except KeyError as err:
        # Keep ValueError (what str.index raised before) but name the culprit.
        raise ValueError(
            'unknown amino acid {!r}; expected one of {!r}'.format(
                err.args[0], amino_acids)
        ) from None

    array = np.zeros((len(sequence), len(amino_acids)))
    # Explicit integer dtype keeps the fancy index valid for an empty sequence.
    columns = np.asarray(indices, dtype=np.intp)
    array[np.arange(len(sequence)), columns] = 1
    return array


class Model:
    """Per-residue classifier over a sliding window of one-hot encoded residues.

    Every residue is described by the one-hot rows of itself and of its
    ``window`` neighbours on each side, flattened into one vector; all-zero
    rows stand in for positions that fall outside the sequence. A decision
    tree maps that vector to the residue's one-character label.
    """

    def __init__(self, window=4):
        """Create an unfitted model.

        Args:
            window: Number of neighbouring residues taken on each side of the
                target residue. The default of 4 yields a 9-residue window.

        Raises:
            ValueError: If ``window`` is negative.
        """
        if window < 0:
            raise ValueError('window must be non-negative, got {!r}'.format(window))
        self.window = window
        self.tree = DecisionTreeClassifier()

    def _preprocess_sequence(self, sequence):
        """Build one flattened window feature vector per residue.

        Args:
            sequence: Sequence of one-letter residue codes for one protein.

        Returns:
            List of ``len(sequence)`` 1-D arrays; entry ``i`` is the flattened
            one-hot rows of positions ``i - window`` .. ``i + window``.
        """
        encoded = one_hot_encode_sequence(sequence)
        # Zero rows on both ends represent the missing neighbours of residues
        # near the sequence boundaries, so every window is a plain slice.
        padded = np.pad(
            encoded, ((self.window, self.window), (0, 0)), mode='constant')
        span = 2 * self.window + 1
        # padded[i] corresponds to original position i - window.
        return [padded[i:i + span].flatten() for i in range(len(sequence))]

    def _preprocess_sequences(self, data):
        """Stack the window features of every protein in ``data``.

        Args:
            data: Iterable of dicts with a ``'sequence'`` entry.

        Returns:
            Array with one row per residue, proteins in input order.
        """
        X = []
        for protein_dict in data:
            X.extend(self._preprocess_sequence(protein_dict['sequence']))
        return np.array(X)

    def _preprocess_labels(self, data):
        """Concatenate the label strings of ``data`` into one label per residue.

        Args:
            data: Iterable of dicts with a ``'labels'`` string entry.

        Returns:
            1-D array holding each label character, proteins in input order.
        """
        joined_labels = ''.join(p['labels'] for p in data)
        return np.array(list(joined_labels))

    def fit(self, data):
        """Fit the decision tree on every residue of every protein in ``data``.

        Args:
            data: Iterable of dicts, each with a ``'sequence'`` string and a
                ``'labels'`` string of the same length.

        Raises:
            ValueError: If ``data`` is empty, or if a protein's sequence and
                labels differ in length.
        """
        data = list(data)
        if not data:
            raise ValueError('cannot fit on an empty dataset')

        # Features and labels are concatenated independently, so a per-protein
        # length mismatch would silently shift labels onto the wrong residues
        # whenever the totals happen to agree. Reject it up front.
        for idx, protein_dict in enumerate(data):
            n_residues = len(protein_dict['sequence'])
            n_labels = len(protein_dict['labels'])
            if n_residues != n_labels:
                raise ValueError(
                    'protein {}: {} residues but {} labels'.format(
                        idx, n_residues, n_labels))

        X = self._preprocess_sequences(data)
        y = self._preprocess_labels(data)
        self.tree.fit(X, y)

    def predict(self, sequence):
        """Predict one label character per residue of ``sequence``.

        Args:
            sequence: Sequence of one-letter residue codes.

        Returns:
            String of predicted labels with the same length as ``sequence``;
            the empty string for an empty sequence.
        """
        # An empty feature matrix cannot be passed to the tree; there is
        # nothing to predict anyway.
        if not sequence:
            return ''
        X = self._preprocess_sequence(sequence)
        preds = self.tree.predict(X)
        return ''.join(preds)

## Perform leave-one-out cross-validation (LOOCV)

In [73]:
model = Model()

# Leave-one-out: for each protein, refit on all the others and store the
# prediction for the held-out protein on its own record.
for i, protein_dict in enumerate(data):
    sequence = protein_dict['sequence']
    # Everything except the i-th entry, in the original order.
    training_data = [other for j, other in enumerate(data) if j != i]
    model.fit(training_data)
    preds = model.predict(sequence)
    # protein_dict is data[i], so this annotates the dataset in place.
    protein_dict['preds'] = preds

## Store the predictions in the JSON file

In [75]:
import json

# Serialize before opening the file: opening in 'w' mode truncates it, so if
# serialization failed midway the dataset on disk would be lost. Building the
# full text first leaves the existing file untouched on error.
serialized = json.dumps(data, ensure_ascii=False, indent=4)

with open('data/data.json', 'w', encoding='utf-8') as f:
    f.write(serialized)