### **Prerequesites**: Download uniprot data and import python libraries

In [21]:
!wget https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.dat.gz
!gunzip uniprot_sprot.dat.gz
import re
from collections import defaultdict
import torch
import pickle
import torch.nn as nn

In [3]:
# Let's see what we're working with.

with open('uniprot_sprot.dat') as f_in:
    for l_i, line in enumerate(f_in):
        print(line.strip())

        if l_i > 500:
            break

ID   001R_FRG3G              Reviewed;         256 AA.
AC   Q6GZX4;
DT   28-JUN-2011, integrated into UniProtKB/Swiss-Prot.
DT   19-JUL-2004, sequence version 1.
DT   09-APR-2025, entry version 45.
DE   RecName: Full=Putative transcription factor 001R;
GN   ORFNames=FV3-001R;
OS   Frog virus 3 (isolate Goorha) (FV-3).
OC   Viruses; Varidnaviria; Bamfordvirae; Nucleocytoviricota; Megaviricetes;
OC   Pimascovirales; Iridoviridae; Alphairidovirinae; Ranavirus; Frog virus 3.
OX   NCBI_TaxID=654924;
OH   NCBI_TaxID=30343; Dryophytes versicolor (chameleon treefrog).
OH   NCBI_TaxID=8404; Lithobates pipiens (Northern leopard frog) (Rana pipiens).
OH   NCBI_TaxID=45438; Lithobates sylvaticus (Wood frog) (Rana sylvatica).
OH   NCBI_TaxID=8316; Notophthalmus viridescens (Eastern newt) (Triturus viridescens).
RN   [1]
RP   NUCLEOTIDE SEQUENCE [LARGE SCALE GENOMIC DNA].
RX   PubMed=15165820; DOI=10.1016/j.virol.2004.02.019;
RA   Tan W.G., Barkman T.J., Gregory Chinchar V., Essani K.;
RT   "Compara

### Specify UniprotEntry class
Stores information about uniprot entry sequence and features


In [4]:
def feature_is_header(line_data):
    c1_check = line_data[0] == 'FT'
    c2_check = re.match(r"[A-Z]+", line_data[1])
    c3_check = re.match(r"<?\d+(?:\.\.\>?d+)?", line_data[2])
    return c1_check and c2_check and c3_check

def get_pos(pos_data):
    pos_data = pos_data.replace('<', '')
    pos_data = pos_data.replace('>', '')

    try:
        pos_data = [int(x)-1 for x in pos_data.split('..')]
    except ValueError:
        return None

    if len(pos_data) == 1:
        pos_data.append(pos_data[0]+1)

    return pos_data

def get_subfeature(feature_string):
    feature_search = re.search(r'^\/(.+)="(.+)"', feature_string)
    if feature_search:
        return feature_search.groups()
    else:
        return None


class UniprotEntry:
    def __init__(self, entry_ID):
        self.entry_ID = entry_ID
        self.features = []
        self.sequence = ''

    def handle_FT(self, line):
        line_data = (line[:2], line[2:21].strip(), line[21:].strip())

        if feature_is_header(line_data):
            if len(self.features) > 0:
                if self.features[-1] == 'ERR':
                    self.features.pop()

            feature_pos = get_pos(line_data[2])
            if feature_pos:
                new_feature = [line_data[1], feature_pos, {}]
                self.features.append(new_feature)
            else:
                self.features.append('ERR')
        else:
            if self.features[-1] == 'ERR':
                return

            subfeature = get_subfeature(line_data[2])

            if subfeature:
                if subfeature[0] not in ['evidence', 'id']:
                    self.features[-1][-1][subfeature[0]] = subfeature[1]


### Process Uniprot Data
Iterate through `uniprot_sprot.dat` and collect all features

In [5]:
with open('uniprot_sprot.dat') as uniprot_in:
    uniprot_data = []

    for l_i, line in enumerate(uniprot_in):
        # We are doing this for demo purposes
        # so cap the number of entries at 100
        if len(uniprot_data) > 100:
            break
            
        c1 = line[:2]

        if c1 == 'ID':
            if len(uniprot_data) > 0:
                if uniprot_data[-1].features[-1] == 'ERR':
                    uniprot_data[-1].features.pop()

            entry_ID = re.split(r'\s+', line)[1]
            uniprot_data.append(UniprotEntry(entry_ID))

        if c1 == 'FT':
            uniprot_data[-1].handle_FT(line)

        if c1 == '  ':
            seq_data = re.split(r'\s+', line)
            uniprot_data[-1].sequence += ''.join(seq_data)


In [6]:
# Sanity check!
# Check that our uniprot data entries behave as expected
for e_i, entry in enumerate(uniprot_data):
    print(entry.entry_ID)
    print(entry.sequence)
    print(entry.features)
    print('-----')


001R_FRG3G
MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL
[['CHAIN', [0, 255], {'note': 'Putative transcription factor 001R'}]]
-----
002L_FRG3G
MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQTCASGFCTSQPLCARIKKTQVCGLRYSSKGKDPLVSAEWDSRGAPYVRCTYDADLIDTQAQVDQFVSMFGESPSLAERYCMRGVKNTAGELVSRVSSDADPAGGWCRKWYSAHRGPDQDAALGSFCIKNPGAADCKCINRASDPVYQKVKTLHAYPDQCWYVPCAADVGELKMGTQRDTPTNCPTQVCQIVFNMLDDGSVTMDDVKNTINCDFSKYVPPPPPPKPTPPTPPTPPTPPTPPTPPTPPTPRPVHNRKVMFFVAGAVLVAILISTVRW
[['CHAIN', [0, 319], {'note': 'Uncharacterized protein 002L'}], ['TRANSMEM', [300, 317], {'note': 'Helical'}], ['REGION', [260, 293], {'note': 'Disordered'}], ['COMPBIAS', [261, 293], {'note': 'Pro residues'}]]
-----
002R_IIV3
MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWKMNREQALAERYPELQTSEPSEDYSGPVESLELLPLEIKLDIMQYLSWEQIS

### Filter and Format
Identify the samples we are interested in and convert sequence data and labels to one-hot values.

In [9]:
feature_labels = [
 'DNA_BIND',
 'REGION-Disordered',
 'TOPO_DOM-Cytoplasmic',
 'TOPO_DOM-Extracellular',
 'TOPO_DOM-Lumenal',
 'TRANSIT-Mitochondrion',
 'TRANSMEM-Helical',
 'ZN_FING'
]

seq_vocab = ['-','A','B','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','U','V','W','X','Y','Z']

def get_label_index(feature):
    for k,v in feature[-1].items():
        feature_string = f"{feature[0]}-{v}"
        for idx, label in enumerate(feature_labels):
            if label in feature_string:
                return idx
    return None

def get_samples(entry, feature_labels, max_l=20):
    samples = []
    for feature in entry.features:
        label_idx = get_label_index(feature)

        if label_idx is not None:
            seq_slice = entry.sequence[slice(*feature[1])][:max_l]

            if len(seq_slice) < max_l:
                seq_slice += '-'*(max_l-len(seq_slice))

            samples.append((seq_slice, label_idx))

    return samples

def generate_onehot(size, idx):
    onehot = [0 for _ in range(size)]
    onehot[idx] = 1
    return onehot

def seq_to_onehot(seq, vocab):
    seq_onehot = []
    for c in seq:
        symbol_onehot = generate_onehot(len(vocab), vocab.index(c))
        seq_onehot += symbol_onehot
    return seq_onehot

# Store one-hot featured samples in these!
sample_seqs = []
sample_labels = []

for ei, entry in enumerate(uniprot_data):
    samples = get_samples(entry, feature_labels)

    if samples is not None:
        for seq, label in samples:
            sample_seqs.append(seq_to_onehot(seq, seq_vocab))
            sample_labels.append(generate_onehot(len(feature_labels), label))

In [8]:
for seq, label in zip(sample_seqs, sample_labels):
    print(seq)
    print(len(seq), len(seq_vocab))
    print(label)
    break

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [11]:
sample_seqs = torch.tensor(sample_seqs)
sample_labels = torch.tensor(sample_labels)

uniprot_corpus = (sample_seqs, sample_labels, seq_vocab, feature_labels)

# Don't mess with this unless you want to overwrite precious data!
# torch.save((sample_seqs, sample_labels, seq_vocab, feature_labels), 'uniprot_corpus.pt')

  sample_seqs = torch.tensor(sample_seqs)
  sample_labels = torch.tensor(sample_labels)


In [14]:
# Let's convert our one-hot data back into text data
# (just another check)

fun_hot_seqs = sample_seqs.reshape((-1, 20, len(seq_vocab))).argmax(2)
fun_hot_labels = sample_labels.argmax(1)

In [13]:
for seq_index, label_index in zip(fun_hot_seqs, fun_hot_labels):
    print(''.join([seq_vocab[i] for i in seq_index]), feature_labels[label_index])

VMFFVAGAVLVAILIST--- TRANSMEM-Helical
YVPPPPPPKPTPPTPPTPPT REGION-Disordered
GTGYESDSDPENEHFDDESF REGION-Disordered
YNSEDEDFEYDSDSEDDDSD REGION-Disordered
MLFLGTIGLAVVVGGLMAYG TRANSMEM-Helical
GKTPSSGTSFHTASPSFSSR REGION-Disordered
MQNPLPEVMSPEHDKRTTT- REGION-Disordered
RFTKPSSSVAKSTSPSLRNS REGION-Disordered
TVLAFKGEGALALAGLLVMA TRANSMEM-Helical
FGVVHSHTPKKKYTSRDSDS REGION-Disordered
VKTIAMLAMLVIVAALIYMG TRANSMEM-Helical
YKMWFLYALILALIFGVFMW TRANSMEM-Helical
APEGMGPHHAASSSHHSAQH REGION-Disordered
AVTGSSSNVKIRKSAPARNE REGION-Disordered
LVTYGGKDGPSDNEDGPSDD REGION-Disordered
MATNYCDEFERNPTRNPRTG REGION-Disordered
PRVAAASPCPEFARDPTRNP REGION-Disordered
GGASPRRVSPARAFPNRRVS REGION-Disordered
GLSPFRSHMRKSPARRSPAR REGION-Disordered
SRPSGVSRTSGTSGSSGSSA REGION-Disordered
MLQNYAIVLGMAVAVAIWYF TRANSMEM-Helical
APPGPNPPKPDPPKPDPPKM REGION-Disordered
RQPRVVPVTSSDPEVVDDED REGION-Disordered
SFVHKLPTFYTAGVGAIIGG TRANSMEM-Helical
HWAGIALYCVGWVTLASVIY TRANSMEM-Helical
ILKGSILSCIVISAVWSILE TRANSMEM-Hel

### Construct a dataloader
Specify a class which partitions the data into train-test splits and retrieves randomized batches. 

In [15]:
uniprot_corpus = torch.load('uniprot_corpus.pt')

In [16]:
class UniprotLoader:
    def __init__(self, uniprot_corpus, batch_size=64):
        self.sequences = uniprot_corpus[0]
        self.labels = uniprot_corpus[1]
        self.sequence_vocab = uniprot_corpus[2]
        self.label_vocab = uniprot_corpus[3]

        self.seq_length = self.sequences.shape[1]
        self.label_length = self.labels.shape[1]

        self.train_index, self.test_index = self.get_train_test_splits()

        self.batch_size = batch_size

        self.train_iterator = self.new_batch_iterator(self.train_index)
        self.test_iterator = self.new_batch_iterator(self.test_index)

        self.train_batch_count = len(self.train_index) // self.batch_size
        self.test_batch_count = len(self.test_index) // self.batch_size

    def new_batch_iterator(self, index):
        random_index = index[torch.randperm(len(index))]

        for i in range(0, len(random_index), self.batch_size):
            yield random_index[i : i + self.batch_size]

    def get_train_test_splits(self, r=0.2):
        train_index = []
        test_index = []

        for l_idx in range(len(self.label_vocab)):
            class_index = torch.where(self.labels[:, l_idx]==1)[0]
            class_size = len(class_index)
            class_index_random = class_index[torch.randperm(class_size)]

            test_cutoff = int(class_size*r)

            test_index.append(class_index_random[:test_cutoff])
            train_index.append(class_index_random[test_cutoff:])

        return(torch.hstack(train_index).sort()[0], torch.hstack(test_index).sort()[0])

    def get_batch(self, dataset='train'):
        iterator = self.train_iterator if dataset == 'train' else self.test_iterator
        random_index = next(iterator, None)

        if random_index is None:
            if dataset == 'train':
                iterator = self.new_batch_iterator(self.train_index)
                self.train_iterator = iterator
            else:
                iterator = self.new_batch_iterator(self.test_index)
                self.test_iterator = iterator

            random_index = next(self.train_iterator, None)

        return self.sequences[random_index].float(), self.labels[random_index].float()


uniprot_loader = UniprotLoader(uniprot_corpus, batch_size=2048)

### Build your model
Construct your sequence classifier and test it in battle

**TODO #1:** Implement the SequenceClassifier class. This could look very much like the MNIST classifier.

In [32]:
class SequenceClassifier(nn.Module):
    '''
    There's nothing here! :O
    Consult the MNIST classifier model 
    or some other resource to construct this class.
    '''

**TODO #2:** Instantiate your sequence classifier. You will need to identify the correct hyperparameters, or have hard-coded them directly into the SequenceClassifier class. 

Be sure to load your model onto the GPU! 

In [29]:
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(sequence_classifier.parameters(), lr=0.001)

device = "How do you access the GPU?"
sequence_classifier = "Replace this string!"

NameError: name 'sequence_classifier' is not defined

In [88]:
def train(dataloader, model, loss_function, optimizer, test_freq=100):
    device = next(model.parameters()).device
    loss_vals = []

    for batch_index in range(dataloader.train_batch_count):
        model.train()
        batch_samples, batch_labels = dataloader.get_batch()

        optimizer.zero_grad()

        outputs = model(batch_samples.to(device))
        loss = loss_function(outputs, torch.argmax(batch_labels.to(device), dim=1))

        loss.backward()
        optimizer.step()

        loss_vals.append(loss.item())

        if (batch_index + 1) % test_freq == 0: 
            test_accuracy = test(dataloader, model)
            accuracy_report = ' | '.join([f'{x.item():0.2f}' for x in test_accuracy]) 
            avg_train_loss = sum(loss_vals) / len(loss_vals)
            
            report_string = [f'|-----Epoch [%d/%d], Batch [{batch_index+1}/{uniprot_loader.train_batch_count}]-----|\n',
                             f'| Train Loss: {avg_train_loss:.4f}\n',
                             f'| Per-label test accuracy: {accuracy_report}\n']

            report_string.append('|' + ''.join(['-' for _ in range(len(report_string[0])-3)]) + '|\n')
            
            yield ''.join(report_string)
            


def test(dataloader, model):
    with torch.no_grad():
        model.eval()
        totals = torch.zeros(len(dataloader.label_vocab))
        correct = torch.zeros(len(dataloader.label_vocab))
    
        for batch_index in range(dataloader.test_batch_count):
            batch_samples, batch_labels = dataloader.get_batch(dataset='test')
    
            outputs = model(batch_samples.to(device))
            predictions = outputs.argmax(dim=1)
            true_labels = batch_labels.argmax(dim=1).to(device)
            hits = (predictions == true_labels)
            correct_label_index, correct_counts = true_labels[hits].cpu().unique(return_counts=True)
    
            totals += torch.sum(batch_labels, dim=0)
            correct[correct_label_index] += correct_counts
    
        return (correct / totals)


|-----Epoch [1/100], Batch [100/298]-----|
| Train Loss: 0.0152
| Per-label test accuracy: 0.92 | 0.98 | 0.80 | 0.76 | 0.52 | 0.79 | 1.00 | 0.98
|----------------------------------------|

|-----Epoch [1/100], Batch [200/298]-----|
| Train Loss: 0.0145
| Per-label test accuracy: 0.92 | 0.97 | 0.80 | 0.76 | 0.51 | 0.83 | 1.00 | 0.98
|----------------------------------------|

|-----Epoch [2/100], Batch [100/298]-----|
| Train Loss: 0.0166
| Per-label test accuracy: 0.92 | 0.97 | 0.81 | 0.75 | 0.53 | 0.83 | 1.00 | 0.98
|----------------------------------------|

|-----Epoch [2/100], Batch [200/298]-----|
| Train Loss: 0.0160
| Per-label test accuracy: 0.93 | 0.97 | 0.80 | 0.75 | 0.53 | 0.81 | 1.00 | 0.97
|----------------------------------------|

|-----Epoch [3/100], Batch [100/298]-----|
| Train Loss: 0.0156
| Per-label test accuracy: 0.92 | 0.97 | 0.80 | 0.76 | 0.53 | 0.79 | 1.00 | 0.98
|----------------------------------------|

|-----Epoch [3/100], Batch [200/298]-----|
| Train Loss


KeyboardInterrupt



In [80]:
epochs = 100

for epoch in range(100):
    for output in train(uniprot_loader, sequence_classifier, loss_function, optimizer):
        print(output % (epoch+1, epochs))

|-----Epoch [1/100], Batch [100/298]-----|
Train Loss: 0.0175
Per-label test accuracy: 0.925 0.969 0.809 0.720 0.519 0.862 0.996 0.982
--------------

|-----Epoch [1/100], Batch [200/298]-----|
Train Loss: 0.0174
Per-label test accuracy: 0.922 0.971 0.810 0.753 0.513 0.797 0.996 0.978
--------------

|-----Epoch [2/100], Batch [100/298]-----|
Train Loss: 0.0153
Per-label test accuracy: 0.910 0.972 0.812 0.738 0.531 0.815 0.996 0.979
--------------

|-----Epoch [2/100], Batch [200/298]-----|
Train Loss: 0.0160
Per-label test accuracy: 0.937 0.973 0.811 0.733 0.519 0.770 0.997 0.974
--------------

|-----Epoch [3/100], Batch [100/298]-----|
Train Loss: 0.0163
Per-label test accuracy: 0.912 0.977 0.797 0.752 0.504 0.848 0.997 0.975
--------------

|-----Epoch [3/100], Batch [200/298]-----|
Train Loss: 0.0162
Per-label test accuracy: 0.922 0.967 0.818 0.740 0.508 0.798 0.997 0.977
--------------

|-----Epoch [4/100], Batch [100/298]-----|
Train Loss: 0.0153
Per-label test accuracy: 0.921 0

KeyboardInterrupt: 

### Reframing the problem
In this exercise, we were tasked with classifying a fixed length sequence with a well established label. While this makes for an instructive toy problem, it takes little imagination to appreciate how limited in utility our model is.

Consider a revised version of this problem: Suppose that instead of receiving a fixed length of sequence to be classified, we are given a full protein sequence. In this case, we are tasked not only with classifying our regions of interest, but identifying the actual bounds of those regions as well. 

This would make for a much more useful tool than what we have constructed here. However, updating our model to tackle this task is not a trivial matter. 

Your final **TODO** is to spare some thought about how to approach this problem. Feel free to even take a stab at it here, though I certainly won't hold you to it!

## Instructor's Manual
All TODO solutions available here

**TODO #1**

In [27]:
class SequenceClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, hidden_depth=4):
        super(SequenceClassifier, self).__init__()

        self.hidden_size = hidden_size

        self.input_layer = nn.Linear(input_size, hidden_size)
        self.hidden_layers = nn.Sequential(*[nn.Linear(hidden_size, hidden_size), 
                                            nn.ReLU()] * hidden_depth)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.hidden_layers(x)
        x = self.output_layer(x)
        return x

**TODO #2**

Note: You may have chosen to hard code the hyperparameters into your model, in which case you will not need to include arguments at all when instantiating your SequenceClassifier.

In [30]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sequence_classifier = SequenceClassifier(uniprot_loader.seq_length, 1024, len(uniprot_loader.label_vocab)).to(device)

**TODO #3**

Just have fun!