In [1]:
import sqlite3
from Bio import SeqIO
import os
from torch.utils.data import Dataset,DataLoader
import torch
from dataclasses import dataclass

Create a database with SQLite 3 and name it something like 'spike_prot.db'.

In [2]:
# Open a connection to the SQLite database file "spike_prot.db" (path is relative
# to the notebook's working directory) and create the cursor that the following
# cells use to run their statements.
conn = sqlite3.connect("spike_prot.db")
db_cursor = conn.cursor()

Now let's create a table for the training sequences, with a simple schema, and a matching one for the test sequences.

In [3]:
# (Re)create the train and test sequence tables. Both share one schema:
#   id       INTEGER PRIMARY KEY  - filled in by SQLite (the inserts below omit it)
#   header   TEXT                 - FASTA record id
#   sequence TEXT                 - the sequence as a plain string
#
# Each table is dropped first so that this cell can be re-run (e.g. Restart &
# Run All) against an existing spike_prot.db: a bare CREATE TABLE would raise
# "table ... already exists", and CREATE TABLE IF NOT EXISTS would let the
# insert cells below append duplicate rows. Re-running this cell therefore
# empties both tables; re-run the loading cells afterwards.
for _table in ("train_sequences", "test_sequences"):
    db_cursor.execute("DROP TABLE IF EXISTS {}".format(_table))
    db_cursor.execute('''CREATE TABLE {}
             (id INTEGER PRIMARY KEY,
              header TEXT,
              sequence TEXT)'''.format(_table))

<sqlite3.Cursor at 0x7f23c987ab20>

Read the FASTA files in and insert their records into the corresponding tables.

In [4]:
# Insert every record of the training FASTA file into train_sequences.
train_fasta_path = os.path.abspath('../data/spikeprot0203.clean.uniq.training.fasta')

try:
    # The `with` block closes the FASTA file handle once parsing is done
    # (the handle was previously left open).
    with open(train_fasta_path) as fasta_handle:
        training_seqs = SeqIO.parse(fasta_handle, 'fasta')
        # Stream (header, sequence) pairs straight into a single executemany call.
        db_cursor.executemany(
            "INSERT INTO train_sequences (header, sequence) VALUES (?,?)",
            ((record.id, str(record.seq)) for record in training_seqs),
        )
    conn.commit()
finally:
    # Always release the connection; if an insert failed before the commit,
    # closing without committing discards the partial load.
    conn.close()

In [5]:
# Insert every record of the testing FASTA file into test_sequences.
# The previous cell closed its connection, so a fresh one is opened here.
test_fasta_path = os.path.abspath('../data/spikeprot0203.clean.uniq.testing.fasta')

conn = sqlite3.connect("spike_prot.db")
try:
    db_cursor = conn.cursor()
    # The `with` block closes the FASTA file handle once parsing is done
    # (the handle was previously left open).
    with open(test_fasta_path) as fasta_handle:
        testing_seqs = SeqIO.parse(fasta_handle, 'fasta')
        # Stream (header, sequence) pairs straight into a single executemany call.
        db_cursor.executemany(
            "INSERT INTO test_sequences (header, sequence) VALUES (?,?)",
            ((record.id, str(record.seq)) for record in testing_seqs),
        )
    conn.commit()
finally:
    # Always release the connection; if an insert failed before the commit,
    # closing without committing discards the partial load.
    conn.close()

Test query:

In [6]:
# Sanity check: reopen the database and print one stored training sequence.
# The connection is left open on purpose; the next cell reuses `db_cursor`.
conn = sqlite3.connect("spike_prot.db")
db_cursor = conn.cursor()

# The query has no ORDER BY, so it simply prints whichever row comes back first.
train_result = db_cursor.execute("SELECT sequence FROM train_sequences").fetchone()
(train_sequence,) = train_result
print(train_sequence)


MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHAIHVSGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPINLVRDLPQGFSALEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLTGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQDVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEHVNNSYECDIPIGAGICASYQTQTNSPRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARDLICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGR

In [7]:
# Same sanity check for the test table, reusing the cursor opened in the previous cell.
test_result = db_cursor.execute("SELECT sequence FROM test_sequences").fetchone()
(test_sequence,) = test_result
print(test_sequence)

MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHAIHVSGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPINLVRDLPQGFSVLEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLTGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQGVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEHVNNSYECDIPIGAGICASYXTQTNSPRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARDLICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGR

Now let's see how to use this database as a way to load in our data.

In [8]:
class FastaDataset(Dataset):
    """Map-style dataset over the sequences stored in one SQLite table.

    Every ``sequence`` value of ``table_name`` is read into memory once, ordered
    by ``id``, when the dataset is constructed. Indexing afterwards works on the
    in-memory list and does not touch the database.
    """
    def __init__(self, db_file: str, table_name: str, encoding_fn) -> None:
        """Load all sequences of ``table_name`` from ``db_file``.

        db_file: path of the SQLite database file.
        table_name: table to read; it must have ``id`` and ``sequence`` columns.
        encoding_fn: callable applied to each sequence string in ``__getitem__``.

        Raises sqlite3.OperationalError if the query fails (e.g. no such table).
        """
        self.db_file = db_file
        self.table_name = table_name
        self.encoding_fn = encoding_fn

        # Identifiers cannot be bound as "?" parameters, so the table name is
        # embedded as a double-quoted identifier with embedded quotes doubled.
        quoted_table = '"{}"'.format(self.table_name.replace('"', '""'))
        query = "SELECT sequence FROM {} ORDER BY id".format(quoted_table)

        conn = sqlite3.connect(self.db_file)
        try:
            rows = conn.execute(query).fetchall()
        finally:
            # Close the connection even when the query raises; previously a
            # failing query left the connection open.
            conn.close()
        self.sequences = [row[0] for row in rows]

    def __getitem__(self, idx):
        """Return ``encoding_fn`` applied to sequence ``idx`` with every '*' removed."""
        sequence = self.sequences[idx].replace("*", "")
        return self.encoding_fn(sequence)

    def __len__(self):
        """Return the number of sequences loaded from the table."""
        return len(self.sequences)


class FastaDataLoader:
    """Bundle a FastaDataset with a torch DataLoader built on top of it.

    Iterating the wrapper iterates the underlying DataLoader (i.e. yields
    batches), while ``len()`` reports the number of sequences in the dataset,
    not the number of batches.
    """
    def __init__(self, db_file: str, table_name: str, encoding_fn, batch_size: int, shuffle=True):
        """Create the dataset from the given table and wrap it in a DataLoader.

        db_file, table_name, encoding_fn: forwarded unchanged to FastaDataset.
        batch_size, shuffle: forwarded unchanged to DataLoader.
        """
        dataset = FastaDataset(db_file, table_name, encoding_fn)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        self.dataset = dataset
        self.dataloader = loader

    def __len__(self):
        """Return the size of the wrapped dataset."""
        n_sequences = len(self.dataset)
        return n_sequences

    def __iter__(self):
        """Return a fresh iterator over the wrapped DataLoader."""
        batches = iter(self.dataloader)
        return batches

In [9]:
# Vocabulary: one token per amino-acid letter, followed by the special tokens.
ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>']

# <START> and <END> are added to every sequence. <PAD> tokens are added to
# sequences shorter than max_len.
ADDED_TOKENS_PER_SEQ = 2

n_aas = len(ALL_AAS)
# Amino-acid letters take indices 0..n_aas-1; special tokens follow from n_aas on.
aa_to_token_index = {aa: i for i, aa in enumerate(ALL_AAS)}
additional_token_to_index = {token: i + n_aas for i, token in enumerate(ADDITIONAL_TOKENS)}
token_to_index = {**aa_to_token_index, **additional_token_to_index}
index_to_token = {index: token for token, index in token_to_index.items()}
n_tokens = len(token_to_index)

def tokenize_seq(seq: str, max_len: int = 1500) -> torch.IntTensor:
    """
    Tokenize a sequence into exactly ``max_len`` token indices.

    The result is laid out as ``<START>``, one token per residue, ``<PAD>``
    tokens up to position ``max_len - 2``, then ``<END>`` as the last token.
    Letters are upper-cased first; any character that is not in ``ALL_AAS``
    maps to ``<OTHER>``.

    It is the caller's responsibility to choose the maximum length. When
    tokenizing a batch of sequences, it should be at least the length of the
    longest sequence in the batch plus the two special tokens. Sequences that
    do not fit are truncated to their first ``max_len - 2`` residues so that
    every output has the same length and can be stacked into a batch
    (previously such sequences produced a tensor longer than ``max_len``).

    seq: input sequence.
    max_len: number of tokens in the output, including <START> and <END>.

    Returns a 1-D int32 tensor of length ``max_len``.
    Raises ValueError if ``max_len`` cannot hold <START> and <END>.
    """
    if max_len < ADDED_TOKENS_PER_SEQ:
        raise ValueError(
            "max_len must be at least {} to hold <START> and <END>, got {}".format(
                ADDED_TOKENS_PER_SEQ, max_len))

    other_token_index = additional_token_to_index['<OTHER>']
    # Upper-case before slicing so the truncation applies to the final characters.
    residues = seq.upper()[:max_len - ADDED_TOKENS_PER_SEQ]

    token_seq = [additional_token_to_index['<START>']]
    token_seq += [aa_to_token_index.get(aa, other_token_index) for aa in residues]

    # Pad up to max_len - 1 so that <END> lands on the final position.
    # n_pads is never negative because the residues were truncated above.
    n_pads = max_len - 1 - len(token_seq)
    token_seq.extend([token_to_index['<PAD>']] * n_pads)
    token_seq.append(additional_token_to_index['<END>'])
    return torch.tensor(token_seq, dtype=torch.int32)

In [10]:
@dataclass
class training_config:
    """Default settings for the training data loader.

    The notebook reads these defaults directly from the class
    (e.g. ``training_config.batch_size``) without creating an instance.
    """
    # Number of sequences per batch handed to the loader.
    batch_size: int = 32
    # Whether the loader shuffles the sequences.
    shuffle: bool = True
    # Table holding the training sequences.
    table_name: str = 'train_sequences'
    # SQLite database file, relative to the notebook's working directory.
    db_file: str = 'spike_prot.db'
    
@dataclass
class testing_config:
    """Default settings for the testing data loader.

    The notebook reads these defaults directly from the class
    (e.g. ``testing_config.batch_size``) without creating an instance.
    """
    # Number of sequences per batch handed to the loader.
    batch_size: int = 32
    # Whether the loader shuffles the sequences.
    # NOTE(review): evaluation data is usually left unshuffled - confirm True is intended.
    shuffle: bool = True
    # Table holding the test sequences.
    table_name: str = 'test_sequences'
    # SQLite database file, relative to the notebook's working directory.
    db_file: str = 'spike_prot.db'
    

In [11]:
#load in training data
train_loader = FastaDataLoader(db_file=training_config.db_file,
                               table_name=training_config.table_name, 
                               encoding_fn = tokenize_seq, 
                               batch_size=training_config.batch_size,
                               shuffle=training_config.shuffle)

for i, batch in enumerate(train_loader):
    print(batch)
    if i == 1:
        break

tensor([[23, 10,  4,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24],
        ...,
        [23, 10,  4,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24]], dtype=torch.int32)
tensor([[23, 10,  4,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24],
        ...,
        [23, 20, 20,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24],
        [23, 10,  4,  ..., 25, 25, 24]], dtype=torch.int32)


In [13]:
# Build the testing loader from the class-level defaults of testing_config.
test_loader = FastaDataLoader(
    db_file=testing_config.db_file,
    table_name=testing_config.table_name,
    encoding_fn=tokenize_seq,
    batch_size=testing_config.batch_size,
    shuffle=testing_config.shuffle,
)

# Preview the first two batches, then stop.
for batch_idx, batch in enumerate(test_loader):
    print(batch)
    if batch_idx >= 1:
        break

OperationalError: no such table: testing_sequences