In [1]:
import os
! git clone https://git.wur.nl/bioinformatics/grs34806-deep-learning-project-data.git
os.chdir("grs34806-deep-learning-project-data")

Cloning into 'grs34806-deep-learning-project-data'...
remote: Enumerating objects: 21, done.[K
remote: Total 21 (delta 0), reused 0 (delta 0), pack-reused 21 (from 1)[K
Receiving objects: 100% (21/21), 8.74 MiB | 13.58 MiB/s, done.


In [13]:
! pip install biopython --quiet
! pip install torch --quiet
from Bio import SeqIO

In [23]:
import torch


def read(seqfile, posfile):
    """Load sequences and derive a boolean label for each of them.

    Args:
        seqfile: Path to a tab-separated text file. Column 0 of every line is
            the sequence identifier and column 1 is the sequence; any further
            columns are ignored. Blank lines are skipped.
        posfile: Path to a text file with one identifier per line listing the
            positive cases (annotated with the function). Identifiers that do
            not occur in ``seqfile`` are ignored.

    Returns:
        A tuple ``(datalist, labellist)``: ``datalist`` holds the sequences in
        file order and ``labellist`` is a parallel list of bools that is True
        where the identifier of the sequence is listed in ``posfile``.

    Raises:
        IndexError: If a non-blank line of ``seqfile`` has no second column.
    """
    # Identifier -> index of its first occurrence in datalist. A dict lookup
    # replaces the linear list.index() scan per positive identifier. Keeping
    # only the first occurrence matches list.index() when identifiers repeat.
    first_index = {}
    datalist = []
    labellist = []
    with open(seqfile, 'r') as f:
        for line in f:
            stripped = line.rstrip()
            if not stripped:
                # Tolerate empty lines (e.g. a trailing newline at end of file).
                continue
            fields = stripped.split('\t')
            first_index.setdefault(fields[0], len(datalist))
            datalist.append(fields[1])
            labellist.append(False)
    with open(posfile, 'r') as f:
        for line in f:
            position = first_index.get(line.rstrip())
            if position is not None:
                labellist[position] = True
    return datalist, labellist



def generate_train_test(datalist, labellist):
    """Split a dataset into a training part and a test part.

    The first three quarters of the items (rounded down) become the training
    set and the remainder becomes the test set. The split is positional: no
    shuffling or stratification is performed, so the order of the input
    determines which items end up in which part.

    Args:
        datalist: List of sequences.
        labellist: List of labels, parallel to ``datalist``.

    Returns:
        A tuple ``(traindatalist, trainlabellist, testdatalist, testlabellist)``.
    """
    # Multiply before dividing: `len // 4 * 3` floors first and therefore
    # under-fills the training set whenever the length is not a multiple of 4
    # (e.g. 3 items would give an empty training set).
    split = len(datalist) * 3 // 4
    traindatalist = datalist[:split]
    trainlabellist = labellist[:split]
    testdatalist = datalist[split:]
    testlabellist = labellist[split:]
    return traindatalist, trainlabellist, testdatalist, testlabellist


def load_array(data_arrays, batch_size, is_train=True):
    """Wrap a group of tensors in a PyTorch ``DataLoader``.

    Args:
        data_arrays: Iterable of tensors; they are unpacked into a
            ``TensorDataset``, so item ``i`` of the dataset is the tuple of
            the ``i``-th entries of every tensor.
        batch_size: Number of items per batch.
        is_train: When True the loader shuffles the items; when False they
            are yielded in their stored order.

    Returns:
        A ``torch.utils.data.DataLoader`` over the tensors.

    Original cross-reference: :numref:`sec_utils`.
    """
    tensor_dataset = torch.utils.data.TensorDataset(*data_arrays)
    loader = torch.utils.data.DataLoader(
        tensor_dataset,
        batch_size=batch_size,
        shuffle=is_train,
    )
    return loader


def load_data(batch_size, num_steps, dataset, is_train=True):
    """Turn raw sequences and labels into a batched PyTorch data iterator.

    Every sequence is mapped to integer tokens (one per character), then
    truncated or padded to exactly ``num_steps`` tokens. The 20 letters of
    ``ACDEFGHIKLMNPQRSTVWY`` map to 0-19; index 20 is used both for any other
    character and for padding, so the two cannot be told apart afterwards.

    Args:
        batch_size: Number of sequences per batch.
        num_steps: Fixed length every sequence is truncated/padded to.
        dataset: Pair ``(sequences, labels)`` of parallel lists.
        is_train: When True (the default, matching the previous behaviour)
            the iterator shuffles the data. Pass False for evaluation data
            that should be iterated in its stored order.

    Returns:
        A ``DataLoader`` yielding ``[token_tensor, label_tensor]`` batches.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    mapaa2num = {aa: i for i, aa in enumerate(amino_acids)}
    # First index after the known letters (20): shared by unknown characters
    # and padding. Passed explicitly so both helpers are guaranteed to agree.
    non_aa_num = len(mapaa2num)
    seq, lab = dataset
    tokens = tokenize(seq, mapaa2num, non_aa_num)
    seq_array = build_seq_array(tokens, num_steps, non_aa_num)
    data_arrays = (seq_array, torch.tensor(lab))
    return load_array(data_arrays, batch_size, is_train)


def tokenize(data, map2num, non_aa_num=20):
    """Convert every sequence in ``data`` into a list of integer tokens.

    Args:
        data: Iterable of sequences (each an iterable of symbols, e.g. a str).
        map2num: Mapping from symbol to its integer token.
        non_aa_num: Token used for any symbol that is missing from ``map2num``.

    Returns:
        A list with one list of integer tokens per input sequence.
    """
    return [
        [map2num.get(symbol, non_aa_num) for symbol in entry]
        for entry in data
    ]


def build_seq_array(lines, num_steps, non_aa_num=20):
    """Stack tokenised sequences into one tensor of fixed-length rows.

    Args:
        lines: Iterable of token lists.
        num_steps: Length every row is truncated or padded to.
        non_aa_num: Token used to pad rows shorter than ``num_steps``.

    Returns:
        A tensor built from the truncated/padded rows, one row per sequence.
    """
    rows = []
    for tokens in lines:
        rows.append(truncate_pad(tokens, num_steps, non_aa_num))
    return torch.tensor(rows)


def truncate_pad(line, num_steps, padding_token):
    """Return ``line`` cut or right-padded to exactly ``num_steps`` items.

    Args:
        line: List of tokens.
        num_steps: Target length.
        padding_token: Value appended to lines shorter than ``num_steps``.

    Returns:
        The first ``num_steps`` items when ``line`` is too long; otherwise
        ``line`` followed by as many ``padding_token`` items as are missing.
    """
    shortfall = num_steps - len(line)
    if shortfall < 0:
        # Too long: keep only the leading num_steps items.
        return line[:num_steps]
    # Short enough: append the missing padding (none when shortfall is 0).
    return line + [padding_token] * shortfall



In [18]:
# Driver cell: read one dataset, split it, and build the batched iterators.

# Number of sequences per batch produced by the iterators.
batch_size = 5
# Fixed sequence length: longer sequences are truncated to their first
# num_steps tokens, shorter ones are right-padded with token 20.
# NOTE(review): the file name below suggests sequences of length 100-200
# (unverified); if so, only the first 10 residues of each sequence are used.
num_steps = 10

# Example for one of the simulated datasets
datalist, labellist = read("len100_200_n1000.seq", "len100_200_n1000.pos")
# Positional split: the first part of the file is the training set and the
# remainder is the test set (no shuffling before the split).
traindatalist, trainlabellist, testdatalist, testlabellist = generate_train_test(datalist, labellist)
# load_data expects a (sequences, labels) pair.
traindataset = [traindatalist, trainlabellist]
testdataset = [testdatalist, testlabellist]

# Build the iterators using batch_size and num_steps (maximum sequence length).
# NOTE: both iterators are created with shuffling enabled (load_array's
# default is_train=True), including test_iter.
train_iter = load_data(batch_size, num_steps, traindataset)
test_iter = load_data(batch_size, num_steps, testdataset)

# Show the first batch: a [token tensor, label tensor] pair.
print(next(iter(train_iter)))

# # Define MYMODEL yourself - we do not give details about it here
# net = MYMODEL
# # trainfunction will have additional arguments;
# # This is yours to make - we do not give details about it
# trainfunction(net, train_iter, test_iter, ....)

[tensor([[10, 19,  0,  8, 14, 19,  3, 13, 15, 19],
        [10, 13, 13, 17,  4,  2,  2, 14, 17,  4],
        [10,  6, 13,  3,  7, 18,  1, 18,  3,  0],
        [10, 19, 11, 18,  5, 18, 10,  3, 15, 13],
        [10,  2,  5, 17, 14, 16,  9, 14, 12, 13]]), tensor([ True, False, False,  True, False])]


In [24]:
# Load a second dataset and report the sizes of its train/test split.
# NOTE(review): the positives file name presumably encodes GO:0005739
# ("3A" looks like a URL-encoded ':') - confirm against the data repository.
# This rebinds datalist/labellist and the four split lists from the cell
# above; train_iter and test_iter are NOT rebuilt here, so they still refer
# to the dataset they were created from.
datalist, labellist = read('expr5Tseq_filtGO_100-1000.lis', 'GO_3A0005739.annotprot')
traindatalist, trainlabellist, testdatalist, testlabellist = generate_train_test(datalist, labellist)
# Sanity check: data and label lists of each part must have equal length.
print(len(traindatalist))
print(len(trainlabellist))
print(len(testdatalist))
print(len(testlabellist))

5088
5088
1696
1696
