# Part 1: Build CpG Detector

Here we have a simple problem: given a DNA sequence (over the alphabet N, A, C, G, T), count the number of CpGs in the sequence (i.e. occurrences of a C immediately followed by a G).

We have defined a few helper functions / parameters for performing this task.

We need you to build an LSTM model and train it to accomplish this task in PyTorch.

A good solution will be a model that can be trained, with high confidence in correctness.

In [1]:
from typing import Sequence
from functools import partial
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random

In [2]:
# DO NOT CHANGE HERE
def set_seed(seed=13):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU), in that
    order, and additionally all CUDA devices when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(13)  # seed all RNGs up front so the generated data and model init are repeatable

# Use this for getting x label
def rand_sequence(n_seqs: int, seq_len: int=128) -> Sequence[int]:
    """Lazily yield ``n_seqs`` random integer-encoded sequences.

    Each yielded item is a list of ``seq_len`` integers, every one drawn
    with ``random.randint(0, 4)`` (inclusive on both ends).
    """
    for _ in range(n_seqs):
        encoded = []
        for _ in range(seq_len):
            encoded.append(random.randint(0, 4))
        yield encoded

# Use this for getting y label
def count_cpgs(seq: str) -> int:
    """Return the number of positions in ``seq`` where the dimer "CG" starts.

    ``seq`` must be a string (not a list): every two-character slice is
    compared against the literal "CG".
    """
    return sum(
        1
        for start in range(len(seq) - 1)
        if seq[start:start + 2] == "CG"
    )

# Alphabet helpers: lookup tables between DNA symbols and integer codes.
# A symbol's code is its position in `alphabet` (N=0, A=1, C=2, G=3, T=4).
alphabet = 'NACGT'
dna2int = {symbol: code for code, symbol in enumerate(alphabet)}
int2dna = {code: symbol for symbol, code in dna2int.items()}

# Lazy converters: each returns a `map` object over the given sequence.
# Unknown symbols/codes become None because the lookup goes through dict.get.
intseq_to_dnaseq = partial(map, int2dna.get)
dnaseq_to_intseq = partial(map, dna2int.get)

In [3]:
def prepare_data(num_samples=100):
    """Generate ``num_samples`` random sequences together with their CpG counts.

    Returns a ``(sequences, labels)`` pair: ``sequences`` is a list of
    integer-encoded sequences from ``rand_sequence`` and ``labels[i]`` is
    ``count_cpgs`` applied to the DNA-string form of ``sequences[i]``.
    """
    sequences = []
    labels = []
    for int_seq in rand_sequence(num_samples):
        # Decode to a DNA string first: count_cpgs compares string slices.
        dna = "".join(intseq_to_dnaseq(int_seq))
        sequences.append(int_seq)
        labels.append(count_cpgs(dna))
    return sequences, labels
    
# 2048 training and 512 test samples; every sequence has rand_sequence's default length of 128.
train_x, train_y = prepare_data(2048)
test_x, test_y = prepare_data(512)

In [4]:
# Config (Part 1: fixed-length sequences)
VOCAB_SIZE = 5  # LSTM input size: one one-hot channel per symbol of the 5-letter alphabet 'NACGT'
LSTM_HIDDEN = 32  # LSTM hidden-state size
LSTM_LAYER = 4  # number of stacked LSTM layers
batch_size = 16  # samples per DataLoader batch
learning_rate = 1e-3  # learning rate handed to the Adam optimizer
epoch_num = 100  # number of epochs run by train()

In [5]:
# Data loader and shape check
from torch.utils.data import Dataset, DataLoader

class DnaDataset(Dataset):
    """Map-style dataset pairing integer-encoded sequences with their counts."""

    def __init__(self, sequences, counts):
        # Both containers are stored as given and indexed in parallel.
        self.sequences, self.counts = sequences, counts

    def __len__(self):
        """Number of samples (one per stored sequence)."""
        return len(self.sequences)

    def __getitem__(self, index):
        """Return ``(LongTensor of the sequence, raw count)`` for ``index``."""
        tokens = torch.LongTensor(self.sequences[index])
        label = self.counts[index]
        return tokens, label


class Collater:
    """Collate ``(sequence, label)`` samples into a one-hot float batch.

    Args:
        num_classes: Width of the one-hot encoding. Defaults to 5, the size
            of the 'NACGT' alphabet. It is passed explicitly to
            ``F.one_hot`` because, when omitted, the width is inferred from
            the largest value present in the batch, so a batch that happens
            to lack the highest code would come out narrower than the model
            expects.
    """

    def __init__(self, num_classes: int = 5):
        self.num_classes = num_classes

    def __call__(self, batch):
        """Stack equal-length sequences and one-hot encode them.

        Args:
            batch: Iterable of ``(LongTensor sequence, count)`` pairs whose
                sequences all have the same length (required by
                ``torch.stack``).

        Returns:
            ``(x, y)`` where ``x`` is float32 of shape
            ``(batch, seq_len, num_classes)`` and ``y`` is float32 of shape
            ``(batch,)``.
        """
        sequences, labels = zip(*batch)
        stacked = torch.stack(sequences)
        one_hot = F.one_hot(stacked, num_classes=self.num_classes)
        targets = torch.tensor(labels, dtype=torch.float32)
        return one_hot.to(torch.float32), targets
        

# Wrap the fixed-length data in datasets and batch it with the one-hot collater.
collate_fn = Collater()
training_data = DnaDataset(train_x, train_y)
testing_data = DnaDataset(test_x, test_y)
train_data_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# NOTE(review): the test loader is shuffled as well; evaluation does not need it and it makes
# the first batch printed by eval() depend on RNG state — consider shuffle=False.
test_data_loader = DataLoader(testing_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Shape check: pull a single training batch and print the input / label shapes.
for i, batch in enumerate(train_data_loader):
    x, y = batch
    print(x.shape, y.shape)
    break

torch.Size([16, 128, 5]) torch.Size([16])


In [6]:
# Model
class CpGPredictor(torch.nn.Module):
    """Simple model that uses a LSTM to count the number of CpGs in a sequence.

    The per-timestep LSTM outputs are summed over the time axis and a
    single linear layer maps that sum to one scalar per sequence.

    Args:
        input_size: Feature width of each timestep. Defaults to the
            module-level ``VOCAB_SIZE``.
        hidden_size: LSTM hidden size. Defaults to ``LSTM_HIDDEN``.
        num_layers: Number of stacked LSTM layers. Defaults to
            ``LSTM_LAYER``.

    The fall-backs are resolved when the model is constructed, so
    ``CpGPredictor()`` behaves as before while explicit arguments make the
    model independent of notebook globals.
    """

    def __init__(self, input_size=None, hidden_size=None, num_layers=None):
        super().__init__()
        if input_size is None:
            input_size = VOCAB_SIZE
        if hidden_size is None:
            hidden_size = LSTM_HIDDEN
        if num_layers is None:
            num_layers = LSTM_LAYER
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Predict one count per sequence.

        Args:
            x: Float tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch,)`` with the predicted counts.
        """
        encoded, _ = self.lstm(x)
        # Sum over time so every position can contribute to the total count.
        pooled = torch.sum(encoded, dim=1)
        return self.classifier(pooled).squeeze(-1)

In [7]:
# init model / loss function / optimizer etc.
model = CpGPredictor()
loss_fn = nn.MSELoss()  # the CpG count is fitted as a regression target
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
# Training loop

def train_one_epoch(model, train_data_loader, optimizer, loss_fn):
    """Run one optimisation pass over ``train_data_loader``.

    For every ``(x, y)`` batch the gradients are reset, the loss between
    ``model(x)`` and ``y`` is back-propagated and the optimizer takes a step.

    Args:
        model: Callable mapping a batch ``x`` to predictions.
        train_data_loader: Iterable yielding ``(x, y)`` batches.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        loss_fn: Callable ``loss_fn(y_hat, y)`` returning a scalar tensor.

    Returns:
        The mean of the per-batch loss values as a Python float.

    Raises:
        ValueError: If the loader yields no batches (previously this
            surfaced as an ``UnboundLocalError`` on the loop index).
    """
    total_loss = 0.0
    n_batches = 0
    for x, y in train_data_loader:
        optimizer.zero_grad()

        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_data_loader yielded no batches")
    return total_loss / n_batches

def train(model, train_data_loader, optimizer, loss_fn):
    """Train ``model`` for ``epoch_num`` epochs, reporting every 10th epoch.

    Puts the model in training mode, runs ``train_one_epoch`` once per
    epoch and prints ``"<epoch>: <mean loss>"`` after epochs 10, 20, ...
    """
    model.train()
    for epoch in range(epoch_num):
        t_loss = train_one_epoch(model, train_data_loader, optimizer, loss_fn)
        is_report_epoch = (epoch + 1) % 10 == 0
        if is_report_epoch:
            print(f"{epoch+1}: {t_loss}")

In [9]:
train(model, train_data_loader, optimizer, loss_fn)  # fit the Part 1 model; prints the mean loss every 10 epochs

10: 0.767889107693918
20: 0.10669196597154951
30: 0.02465420945736696
40: 0.01395535550909699
50: 0.0045007479457126465
60: 0.0027540832485328792
70: 0.002481040032080273
80: 0.00042396134654154594
90: 0.0003836447443461566
100: 0.0007991211390390163


In [10]:
# Evaluation loop

def eval(model, test_data_loader):
    """Evaluate ``model`` and print its mean squared error on the test set.

    For the first batch the integer targets and the predictions (rounded
    to two decimals) are printed as a sanity check. Finally the mean
    squared error over all evaluated samples, rounded to 8 decimals, is
    printed.

    The error is averaged over the number of samples actually seen, so the
    result is correct for a partial last batch and does not depend on the
    global ``batch_size``. Inference runs under ``torch.no_grad()`` so no
    autograd graphs are built or kept alive.

    Args:
        model: ``torch.nn.Module`` mapping a batch ``x`` to predictions.
        test_data_loader: Iterable yielding ``(x, y)`` batches.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    squared_error = 0.0
    n_samples = 0

    with torch.no_grad():
        for i, (x, y) in enumerate(test_data_loader):
            y_hat = model(x)
            if i == 0:
                print(y.to(torch.int32).tolist())
                print([round(item, 2) for item in y_hat.tolist()])

            squared_error += torch.sum((y_hat - y) ** 2).item()
            n_samples += y.numel()

    if n_samples == 0:
        raise ValueError("test_data_loader yielded no samples")
    print(round(squared_error / n_samples, 8))

In [11]:
# TODO complete evaluation of the model
# Prints the first test batch (targets, then predictions) followed by the test-set MSE.
eval(model, test_data_loader)

[2, 5, 8, 7, 4, 7, 6, 6, 3, 3, 4, 4, 1, 5, 5, 4]
[2.03, 4.99, 8.01, 7.03, 4.01, 7.01, 6.01, 6.03, 3.0, 3.01, 4.01, 4.03, 1.02, 4.99, 5.01, 4.0]
0.00023319


# Part 2: What if the DNA sequences are not all the same length?

In [12]:
# hint we will need following imports
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [13]:
# DO NOT CHANGE HERE
random.seed(13)  # re-seed Python's RNG, which the variable-length generator below draws from

# Use this for getting x label
def rand_sequence_var_len(n_seqs: int, lb: int=16, ub: int=128) -> Sequence[int]:
    """Lazily yield ``n_seqs`` random sequences of random length.

    For each sequence a length is first drawn with ``random.randint(lb, ub)``
    (both bounds inclusive); the sequence is then a list of that many
    integers drawn with ``random.randint(1, 5)``.
    """
    for _ in range(n_seqs):
        n_tokens = random.randint(lb, ub)
        encoded = []
        for _ in range(n_tokens):
            encoded.append(random.randint(1, 5))
        yield encoded


# Use this for getting y label
def count_cpgs(seq: str) -> int:
    """Count the positions in ``seq`` at which the dimer "CG" begins.

    ``seq`` is expected to be a string, not a list: each two-character
    slice is compared with the literal "CG".
    """
    total = 0
    last_start = len(seq) - 1
    pos = 0
    while pos < last_start:
        if seq[pos:pos + 2] == "CG":
            total += 1
        pos += 1
    return total


# Alphabet helpers: symbol <-> integer-code tables for the padded setting.
# Real symbols are numbered from 1 (N=1 ... T=5); code 0 is reserved for padding.
alphabet = 'NACGT'
dna2int = {symbol: code for code, symbol in enumerate(alphabet, start=1)}
int2dna = {code: symbol for symbol, code in dna2int.items()}
dna2int["<pad>"] = 0
int2dna[0] = "<pad>"

# Lazy converters: each returns a `map` object; unknown entries map to None (dict.get).
intseq_to_dnaseq = partial(map, int2dna.get)
dnaseq_to_intseq = partial(map, dna2int.get)

In [14]:
def prepare_data(num_samples=100, min_len=16, max_len=128):
    """Generate variable-length random sequences and their CpG counts.

    Returns a ``(sequences, labels)`` pair: ``sequences`` holds
    ``num_samples`` integer-encoded sequences from ``rand_sequence_var_len``
    (lengths between ``min_len`` and ``max_len``) and ``labels[i]`` is
    ``count_cpgs`` applied to the DNA-string form of ``sequences[i]``.
    """
    sequences = []
    labels = []
    for int_seq in rand_sequence_var_len(num_samples, min_len, max_len):
        # Decode to a DNA string first: count_cpgs compares string slices.
        dna = "".join(intseq_to_dnaseq(int_seq))
        sequences.append(int_seq)
        labels.append(count_cpgs(dna))
    return sequences, labels
    
    
# Sequence lengths are drawn from [64, 128] (both inclusive) for the train and test sets.
min_len, max_len = 64, 128
train_x, train_y = prepare_data(2048, min_len, max_len)
test_x, test_y = prepare_data(512, min_len, max_len)

In [15]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset of variable-length encoded sequences and their counts."""

    def __init__(self, sequences, counts) -> None:
        # Stored as given; the two containers are indexed in parallel.
        self.sequences, self.counts = sequences, counts

    def __len__(self):
        """Number of samples (one per stored sequence)."""
        return len(self.sequences)

    def __getitem__(self, index):
        """Return ``(LongTensor of the sequence, raw count)`` for ``index``."""
        tokens = torch.LongTensor(self.sequences[index])
        label = self.counts[index]
        return tokens, label

    
# this will be a collate_fn for dataloader to pad sequence  
class PadSequence:
    """Collate variable-length ``(sequence, label)`` samples into a packed batch.

    Args:
        num_classes: Width of the one-hot encoding. Defaults to 6: the five
            symbol codes 1..5 plus the padding code 0. It is passed
            explicitly to ``F.one_hot`` because, when omitted, the width is
            inferred from the largest value present in the batch, so a
            batch lacking the highest code would come out narrower than the
            model expects.
    """

    def __init__(self, num_classes: int = 6):
        self.num_classes = num_classes

    def __call__(self, batch):
        """Pad, one-hot encode, sort by length and pack a batch.

        Args:
            batch: Iterable of ``(sequence, count)`` pairs; sequences may
                have different lengths.

        Returns:
            ``(packed, labels)``: ``packed`` is a float32 ``PackedSequence``
            of one-hot vectors and ``labels`` is a float32 tensor, both
            ordered by descending sequence length.
        """
        sequences, labels = zip(*batch)
        # as_tensor reuses tensors the dataset already produced instead of
        # copying them (torch.tensor(tensor) copies and emits a UserWarning).
        sequences = [torch.as_tensor(seq, dtype=torch.long) for seq in sequences]
        lengths = torch.tensor([len(seq) for seq in sequences])

        padded = pad_sequence(sequences, batch_first=True, padding_value=0)
        one_hot = F.one_hot(padded, num_classes=self.num_classes).to(torch.float32)

        # pack_padded_sequence (enforce_sorted=True by default) needs the
        # batch in descending length order; labels are reordered to match.
        lengths, sort_indices = lengths.sort(descending=True)
        one_hot = one_hot[sort_indices]
        targets = torch.tensor(labels, dtype=torch.float32)[sort_indices]

        packed = pack_padded_sequence(one_hot, lengths, batch_first=True)
        return packed, targets

# Wrap the variable-length data in datasets and batch it with the padding/packing collater.
collate_fn = PadSequence()
training_data = MyDataset(train_x, train_y)
testing_data = MyDataset(test_x, test_y)
train_data_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# NOTE(review): the test loader is shuffled as well; evaluation does not need it and it makes
# the first batch printed by eval() depend on RNG state — consider shuffle=False.
test_data_loader = DataLoader(testing_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Shape check: unpack one packed training batch back to a padded tensor and print the shapes.
for i, batch in enumerate(train_data_loader):
    x, y = batch
    print(pad_packed_sequence(x, batch_first=True)[0].shape, y.shape)
    break

torch.Size([16, 125, 6]) torch.Size([16])


  sequences = [torch.tensor(seq) for seq in sequences]


In [16]:
# Config (Part 2: variable-length sequences)
VOCAB_SIZE = 6  # LSTM input size: five symbol codes (1..5) plus the padding code 0
LSTM_HIDDEN = 32  # LSTM hidden-state size
LSTM_LAYER = 4  # number of stacked LSTM layers
batch_size = 16  # NOTE: the Part 2 loaders were already built above with the earlier (identical) value
learning_rate = 1e-3  # learning rate handed to the Adam optimizer
epoch_num = 100  # number of epochs run by train()

In [17]:
# Model

class CpGPredictor(torch.nn.Module):
    """Simple model that uses a LSTM to count the number of CpGs in a sequence.

    This variant consumes a ``PackedSequence`` so that padded positions are
    never fed through the LSTM. The per-timestep outputs are unpacked,
    summed over the time axis and mapped to one scalar per sequence by a
    linear layer.

    Args:
        input_size: Feature width of each timestep. Defaults to the
            module-level ``VOCAB_SIZE``.
        hidden_size: LSTM hidden size. Defaults to ``LSTM_HIDDEN``.
        num_layers: Number of stacked LSTM layers. Defaults to
            ``LSTM_LAYER``.

    The fall-backs are resolved when the model is constructed, so
    ``CpGPredictor()`` behaves as before while explicit arguments make the
    model independent of notebook globals.
    """

    def __init__(self, input_size=None, hidden_size=None, num_layers=None):
        super().__init__()
        if input_size is None:
            input_size = VOCAB_SIZE
        if hidden_size is None:
            hidden_size = LSTM_HIDDEN
        if num_layers is None:
            num_layers = LSTM_LAYER
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Predict one count per sequence.

        Args:
            x: ``PackedSequence`` of float feature vectors.

        Returns:
            Tensor of shape ``(batch,)`` with the predicted counts, in the
            same order as the sequences in ``x``.
        """
        packed_output, _ = self.lstm(x)
        # Unpacking pads with zeros, so padded steps add nothing to the sum.
        encoded, _ = pad_packed_sequence(packed_output, batch_first=True)
        pooled = torch.sum(encoded, dim=1)
        return self.classifier(pooled).squeeze(-1)

In [18]:
# init model / loss function / optimizer etc.
model = CpGPredictor()  # fresh instance of the packed-sequence model; replaces the Part 1 model
loss_fn = nn.MSELoss()  # the CpG count is fitted as a regression target
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [19]:
# training (you can modify the code below)
# Reuses train() from Part 1; only the batches (PackedSequence inputs) differ.
train(model, train_data_loader, optimizer, loss_fn)

  sequences = [torch.tensor(seq) for seq in sequences]


10: 0.9518940644338727
20: 0.06461269300780259
30: 0.016101358594823978
40: 0.0043148350664523605
50: 0.002812138138779119
60: 0.0012343386067641404
70: 0.0009755911191859923
80: 0.0015113188141526734
90: 0.0002691288632092892
100: 0.0005508342567850377


In [20]:
eval(model, test_data_loader)  # prints the first test batch (targets, predictions) and the test-set MSE

[1, 5, 4, 9, 6, 3, 7, 5, 5, 7, 1, 3, 3, 2, 4, 3]
[1.01, 5.04, 4.01, 9.06, 6.04, 3.02, 7.03, 5.03, 5.03, 7.03, 1.01, 3.02, 3.02, 2.02, 4.03, 3.02]


  sequences = [torch.tensor(seq) for seq in sequences]


0.00051393


In [21]:
# Persist only the weights (state_dict) of the Part 2 model; to reload, construct a
# CpGPredictor first and pass the loaded dict to load_state_dict().
torch.save(model.state_dict(), 'autonomize_lstm.pt')