### **Prerequisites**: Download UniProt data and import Python libraries

In [None]:
# !wget https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz
# !gunzip uniprot_sprot.dat.gz
# shutil.copyfile('uniprot_sprot.dat', '/content/drive/My Drive/uniprot_sprat.dat')

import re
from collections import defaultdict
import torch
import pickle
from google.colab import drive
drive.mount('/content/drive')
import shutil

### Specify UniprotEntry class
Stores the sequence and the parsed features of a single UniProt entry.


In [None]:
def feature_is_header(line_data):
    """Return True when a split ``FT`` line opens a new feature.

    Args:
        line_data: Indexable ``(line_code, key_column, remainder)`` as built
            by ``UniprotEntry.handle_FT``.

    Returns:
        bool: True when the line code is ``'FT'``, the key column starts with
        upper-case letters, and the remainder starts with a location of the
        form ``<?digits`` optionally followed by ``..>?digits``.
    """
    if line_data[0] != 'FT':
        return False

    key_ok = re.match(r"[A-Z]+", line_data[1])
    # Prefix match on purpose: locations whose tail cannot be parsed
    # (e.g. "12..?") still count as headers so the caller can flag them.
    location_ok = re.match(r"<?\d+(?:\.\.>?\d+)?", line_data[2])
    return bool(key_ok and location_ok)

def get_pos(pos_data):
    """Convert a feature location string into a 0-based, half-open pair.

    Locations are 1-based and inclusive (``"5"`` or ``"1..20"``), optionally
    decorated with ``<`` / ``>``. The result can be passed to ``slice(*pos)``
    to cut the feature out of the entry sequence.

    Args:
        pos_data: Location text such as ``"5"``, ``"1..20"`` or ``"<1..>20"``.

    Returns:
        ``[start, end]`` with ``start`` inclusive and ``end`` exclusive, or
        ``None`` when the text does not consist of one or two integers.
    """
    cleaned = pos_data.replace('<', '').replace('>', '')

    try:
        bounds = [int(part) for part in cleaned.split('..')]
    except ValueError:
        return None

    if len(bounds) > 2:
        return None

    # Only the start shifts down by one; the inclusive 1-based end is already
    # the exclusive 0-based end. A single position p becomes [p - 1, p].
    return [bounds[0] - 1, bounds[-1]]

def get_subfeature(feature_string):
    """Split a ``/name="value"`` qualifier into its name and value.

    Args:
        feature_string: Text expected to start with ``/name="value"``.

    Returns:
        ``(name, value)`` when the text starts with ``/``, contains ``="``
        and has a closing double quote; otherwise ``None`` (this includes
        values with no closing quote in ``feature_string``).
    """
    match = re.search(r'^\/(.+)="(.+)"', feature_string)
    return match.groups() if match else None


class UniprotEntry:
    """Sequence and parsed feature table of one UniProt flat-file entry."""

    def __init__(self, entry_ID):
        """Create an empty entry identified by ``entry_ID``."""
        self.entry_ID = entry_ID
        # Parsed features, each ``[key, position, qualifiers]``. While parsing,
        # the last element may be the placeholder string 'ERR', which marks a
        # feature whose location could not be parsed.
        self.features = []
        # Residue string, filled in by the code that reads the sequence lines.
        self.sequence = ''

    def _drop_error_marker(self):
        """Remove a trailing 'ERR' placeholder, if there is one."""
        if self.features and self.features[-1] == 'ERR':
            self.features.pop()

    def handle_FT(self, line):
        """Consume one ``FT`` line and update ``self.features``.

        A header line starts a new feature. Any other line is read as a
        ``/name="value"`` qualifier of the current feature; qualifiers named
        ``evidence`` and ``id`` are not stored.

        Args:
            line: A raw ``FT`` line. Characters 2-20 are read as the key
                column and everything from character 21 as the remainder.
        """
        line_data = (line[:2], line[2:21].strip(), line[21:].strip())

        if feature_is_header(line_data):
            self._drop_error_marker()

            feature_pos = get_pos(line_data[2])
            if feature_pos:
                self.features.append([line_data[1], feature_pos, {}])
            else:
                self.features.append('ERR')
            return

        if line_data[1]:
            # The key column is filled, so this starts a feature, but its
            # location was not recognised (e.g. "?..45"). Flag it so its
            # qualifiers are skipped rather than merged into the previous
            # feature.
            self._drop_error_marker()
            self.features.append('ERR')
            return

        # Qualifier line: needs a valid current feature to attach to.
        if not self.features or self.features[-1] == 'ERR':
            return

        subfeature = get_subfeature(line_data[2])
        if subfeature and subfeature[0] not in ('evidence', 'id'):
            self.features[-1][-1][subfeature[0]] = subfeature[1]


### Process Uniprot Data
Iterate through `uniprot_sprot.dat` and collect the sequence and features of every entry.

In [None]:
# Parse the flat file line by line. The first two characters of each line
# select the handler: 'ID' starts a new entry, 'FT' feeds the feature parser,
# and lines starting with two spaces carry sequence residues.
with open('/content/drive/My Drive/uniprot_sprat.dat') as uniprot_in:
    uniprot_data = []

    for line in uniprot_in:
        c1 = line[:2]

        if c1 == 'ID':
            # Close the previous entry: drop a dangling 'ERR' placeholder.
            # The [-1:] slice also copes with entries that have no features.
            if uniprot_data and uniprot_data[-1].features[-1:] == ['ERR']:
                uniprot_data[-1].features.pop()

            entry_ID = line.split()[1]
            uniprot_data.append(UniprotEntry(entry_ID))

        elif c1 == 'FT':
            uniprot_data[-1].handle_FT(line)

        elif c1 == '  ':
            # Sequence lines are whitespace-separated residue groups.
            uniprot_data[-1].sequence += ''.join(line.split())

    # The last entry has no following 'ID' line to trigger the cleanup above.
    if uniprot_data and uniprot_data[-1].features[-1:] == ['ERR']:
        uniprot_data[-1].features.pop()
            # uniprot_data[-1].sequence += line.strip()




In [None]:
# Save or load uniprot_data.pkl
#
# with open('/content/drive/My Drive/uniprot_data.pkl', 'wb') as f:
#     pickle.dump(uniprot_data, f)

# with open('/content/drive/My Drive/uniprot_data.pkl', 'rb') as f:
#     uniprot_data = pickle.load(f)

In [None]:
# Sanity check: dump the ID, sequence and features of the first 12 entries.

for e_i, entry in enumerate(uniprot_data):
    print(entry.entry_ID, entry.sequence, entry.features, '-----', sep='\n')

    if e_i >= 11:
        break

### Filter and Format
Identify the samples we are interested in and convert sequence data and labels to one-hot values.

In [None]:
# Class labels. A label's position in this list is its class index; labels
# are matched as substrings of "<feature key>-<qualifier value>".
feature_labels = [
    'DNA_BIND',
    'REGION-Disordered',
    'TOPO_DOM-Cytoplasmic',
    'TOPO_DOM-Extracellular',
    'TOPO_DOM-Lumenal',
    'TRANSIT-Mitochondrion',
    'TRANSMEM-Helical',
    'ZN_FING',
]

# Residue alphabet for one-hot encoding; '-' is the padding symbol.
seq_vocab = list('-ABCDEFGHIKLMNPQRSTUVWXYZ')

def get_label_index(feature, labels=None):
    """Return the class index of ``feature``, or ``None`` if it has no label.

    Each qualifier value is combined with the feature key as
    ``"<key>-<value>"`` and compared against the labels by substring match.
    If none of those match, the bare key is tried, so key-only labels such as
    ``'DNA_BIND'`` also match features that have no stored qualifiers.

    Args:
        feature: ``[key, position, qualifiers]`` as stored in
            ``UniprotEntry.features``. Plain strings (the ``'ERR'``
            placeholder) are treated as unlabeled.
        labels: Label strings to match against. Defaults to the module-level
            ``feature_labels``.

    Returns:
        Index into ``labels`` of the first matching label, or ``None``.
    """
    if labels is None:
        labels = feature_labels

    if isinstance(feature, str):
        return None

    key, qualifiers = feature[0], feature[-1]

    # Qualifier-based candidates come first so existing matches keep their
    # result; the bare key is only a fallback.
    candidates = [f"{key}-{value}" for value in qualifiers.values()]
    candidates.append(key)

    for candidate in candidates:
        for idx, label in enumerate(labels):
            if label in candidate:
                return idx
    return None

def get_samples(entry, feature_labels, max_l=20):
    """Collect fixed-length, labeled sequence windows from one entry.

    Args:
        entry: Object exposing ``features`` and ``sequence``.
        feature_labels: Accepted for the caller's convenience; the body does
            not read it (labels are resolved by ``get_label_index``).
        max_l: Window length. Longer features are truncated to their first
            ``max_l`` residues; shorter ones are right-padded with ``'-'``.

    Returns:
        List of ``(window, label_index)`` tuples, one per labeled feature.
    """
    samples = []

    for feature in entry.features:
        label_idx = get_label_index(feature)
        if label_idx is None:
            continue

        window = entry.sequence[slice(*feature[1])][:max_l]
        samples.append((window.ljust(max_l, '-'), label_idx))

    return samples

def generate_onehot(size, idx):
    """Return a list of ``size`` zeros with a single 1 at position ``idx``."""
    onehot = [0] * size
    onehot[idx] = 1
    return onehot

def seq_to_onehot(seq, vocab):
    """Encode ``seq`` as one flat list of concatenated one-hot vectors.

    Args:
        seq: Iterable of symbols.
        vocab: List of known symbols; a symbol's list position is its index.

    Returns:
        Flat list of length ``len(seq) * len(vocab)``.

    Raises:
        ValueError: If a symbol of ``seq`` is not in ``vocab``.
    """
    vocab_size = len(vocab)
    flat = []
    for symbol in seq:
        flat.extend(generate_onehot(vocab_size, vocab.index(symbol)))
    return flat

# Flatten every entry's labeled windows into two parallel lists of one-hot
# encodings: one for the sequences and one for the class labels.
sample_seqs, sample_labels = [], []

for ei, entry in enumerate(uniprot_data):
    samples = get_samples(entry, feature_labels)
    if samples is None:
        continue

    for seq, label in samples:
        sample_seqs.append(seq_to_onehot(seq, seq_vocab))
        sample_labels.append(generate_onehot(len(feature_labels), label))

In [None]:
# Show the first encoded sample: its flat one-hot sequence, the sequence
# length next to the vocabulary size, and its one-hot label.
for seq, label in zip(sample_seqs, sample_labels):
    print(seq, f"{len(seq)} {len(seq_vocab)}", label, sep='\n')
    break

In [None]:
# Turn the nested Python lists into tensors and bundle them with both
# vocabularies, so the corpus can be saved and loaded as one object.
sample_seqs, sample_labels = torch.tensor(sample_seqs), torch.tensor(sample_labels)

uniprot_corpus = (sample_seqs, sample_labels, seq_vocab, feature_labels)
# torch.save((sample_seqs, sample_labels, seq_vocab, feature_labels), '/content/drive/My Drive/uniprot_corpus.pt')
#uniprot_corpus = torch.load('/content/drive/My Drive/uniprot_corpus.pt')

In [None]:
# Undo the one-hot encoding: regroup each flat row into
# (positions, vocabulary) and take the index of the hot entry. The number of
# positions is inferred from the tensor rather than hard-coded, so this keeps
# working if the window length used to build the samples changes.
fun_hot_seqs = sample_seqs.reshape((sample_seqs.shape[0], -1, len(seq_vocab))).argmax(2)
fun_hot_labels = sample_labels.argmax(1)

In [None]:
# Decode the first 22 samples back to residue strings and label names.
for r_i, (seq_index, label_index) in enumerate(zip(fun_hot_seqs, fun_hot_labels)):
    decoded = ''.join(seq_vocab[i] for i in seq_index)
    print(decoded, feature_labels[label_index])

    if r_i >= 21:
        break

In [None]:
class UniprotLoader:
    """Serve shuffled mini-batches from a per-class train/test split.

    Args:
        uniprot_corpus: ``(sequences, labels, sequence_vocab, label_vocab)``.
            ``sequences`` and ``labels`` are 2-D tensors with one row per
            sample; ``labels`` is one-hot with one column per label.
        batch_size: Maximum number of samples per batch.
    """

    def __init__(self, uniprot_corpus, batch_size=64):
        self.sequences = uniprot_corpus[0]
        self.labels = uniprot_corpus[1]
        self.sequence_vocab = uniprot_corpus[2]
        self.label_vocab = uniprot_corpus[3]

        self.seq_length = self.sequences.shape[1]
        self.label_length = self.labels.shape[1]

        self.train_index, self.test_index = self.get_train_test_splits()

        self.batch_size = batch_size

        self.train_iterator = self.new_batch_iterator(self.train_index)
        self.test_iterator = self.new_batch_iterator(self.test_index)

        # Number of full batches per split; a trailing partial batch is not
        # counted here even though the iterators do yield it.
        self.train_batch_count = len(self.train_index) // self.batch_size
        self.test_batch_count = len(self.test_index) // self.batch_size

    def new_batch_iterator(self, index):
        """Yield ``index`` in a fresh random order, ``batch_size`` rows at a time."""
        random_index = index[torch.randperm(len(index))]

        for start in range(0, len(random_index), self.batch_size):
            yield random_index[start : start + self.batch_size]

    def get_train_test_splits(self, r=0.2):
        """Split the sample indices class by class.

        Args:
            r: Fraction of each class (rounded down) assigned to the test set.

        Returns:
            ``(train_index, test_index)`` as sorted 1-D index tensors.
        """
        train_index = []
        test_index = []

        for l_idx in range(len(self.label_vocab)):
            class_index = torch.where(self.labels[:, l_idx] == 1)[0]
            class_size = len(class_index)
            class_index_random = class_index[torch.randperm(class_size)]

            test_cutoff = int(class_size * r)

            test_index.append(class_index_random[:test_cutoff])
            train_index.append(class_index_random[test_cutoff:])

        return (torch.hstack(train_index).sort()[0], torch.hstack(test_index).sort()[0])

    def get_batch(self, dataset='train'):
        """Return the next ``(sequences, labels)`` batch as float tensors.

        When the selected split's iterator is exhausted, it is replaced with
        a freshly shuffled one and the batch is drawn from that.

        Args:
            dataset: ``'train'`` selects the training split; any other value
                selects the test split.

        Raises:
            ValueError: If the selected split contains no samples.
        """
        use_train = dataset == 'train'
        iterator = self.train_iterator if use_train else self.test_iterator
        random_index = next(iterator, None)

        if random_index is None:
            split_index = self.train_index if use_train else self.test_index
            iterator = self.new_batch_iterator(split_index)
            if use_train:
                self.train_iterator = iterator
            else:
                self.test_iterator = iterator

            # Draw from the iterator that was just refreshed, so a test
            # request never receives training rows.
            random_index = next(iterator, None)
            if random_index is None:
                raise ValueError(f"no samples available for dataset {dataset!r}")

        return self.sequences[random_index].float(), self.labels[random_index].float()


# Build the loader over the full corpus; each batch holds up to 2048 samples.
uniprot_loader = UniprotLoader(uniprot_corpus, batch_size=2048)

In [None]:
class SequenceClassifier(torch.nn.Module):
    """Fully connected classifier over flattened one-hot sequence windows.

    The network is an input ``Linear`` layer, four ``Linear`` + ``ReLU``
    pairs of width ``hidden_size``, and an output ``Linear`` layer. There is
    no activation between the input layer and the first hidden ``Linear``,
    and the output is returned without a final activation.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of every hidden layer.
        output_size: Number of output units per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.hidden_size = hidden_size

        self.input_layer = torch.nn.Linear(input_size, hidden_size)
        self.hidden_layers = self.generate_hidden_layers(4, torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(hidden_size, output_size)

    def generate_hidden_layers(self, layer_count, activation):
        """Build ``layer_count`` (Linear, activation) pairs as a Sequential.

        The same ``activation`` module instance is reused after every Linear.
        """
        width = self.hidden_size
        layers = []

        for _ in range(layer_count):
            layers += [torch.nn.Linear(in_features=width, out_features=width), activation]

        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of inputs to one row of ``output_size`` values each."""
        return self.output_layer(self.hidden_layers(self.input_layer(x)))

# Use the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Input width is the flattened one-hot window; one output unit per label.
sequence_classifier = SequenceClassifier(
    input_size=uniprot_loader.seq_length,
    hidden_size=1024,
    output_size=len(uniprot_loader.label_vocab),
)
sequence_classifier = sequence_classifier.to(device)


In [None]:
# Cross-entropy over the classifier's raw outputs, optimised with Adam.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=sequence_classifier.parameters(), lr=1e-3)

In [None]:
# prompt: pytorch training loop boilerplate
epochs = 100
log_interval = 100  # report the loss and evaluate every this many training batches

for epoch in range(epochs):
    sequence_classifier.train()
    running_loss = 0.0

    for batch_index in range(uniprot_loader.train_batch_count):
        batch_samples, batch_labels = uniprot_loader.get_batch()

        optimizer.zero_grad()

        outputs = sequence_classifier(batch_samples.to(device))
        # The loss takes class indices, so collapse the one-hot labels.
        loss = loss_function(outputs, torch.argmax(batch_labels.to(device), dim=1))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (batch_index + 1) % log_interval == 0:
            # Average over the batches accumulated since the last report.
            print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_index+1}/{uniprot_loader.train_batch_count}], Loss: {running_loss/log_interval:.4f}')
            running_loss = 0.0

            # Per-class accuracy over test_batch_count test batches.
            sequence_classifier.eval()
            with torch.no_grad():
                totals = torch.zeros(len(uniprot_loader.label_vocab))
                correct = torch.zeros(len(uniprot_loader.label_vocab))

                # Loop variable is unused; the outer batch_index is left alone.
                for _ in range(uniprot_loader.test_batch_count):
                    test_samples, test_labels = uniprot_loader.get_batch(dataset='test')

                    test_outputs = sequence_classifier(test_samples.to(device))
                    predictions = test_outputs.argmax(dim=1)
                    true_labels = test_labels.argmax(dim=1).to(device)
                    hits = (predictions == true_labels)
                    correct_label_index, correct_counts = true_labels[hits].cpu().unique(return_counts=True)

                    # One-hot rows summed column-wise give samples per class.
                    totals += torch.sum(test_labels, dim=0)
                    correct[correct_label_index] += correct_counts

                # A class with no evaluated samples prints as nan (0 / 0).
                print(correct / totals)
            sequence_classifier.train()



print('Finished Training')
