<a href="https://colab.research.google.com/github/olive004/Toy_SNP_Caller/blob/master/Toy_SNP_caller_part_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import random
import numpy as np
# Alphabet of DNA bases that the simulated references and reads are drawn from.
NUCLEOTIDES = "ACGT"

# Function to simulate multiple sequence alignments with sequencing errors and either no SNP, a homozygous SNP, or a heterozygous SNP

In [0]:
def simulate_alignments(reference_length=200, num_alignments = 20000, 
                        coverage = 100, p_sequencing_error=0.03, seed=None):
    """Simulate read pile-ups around a single candidate SNP site.

    For every alignment a random reference is drawn, a mutation type is picked
    uniformly from {0, 1, 2}, and `coverage` noisy reads of the reference are
    generated. The candidate SNP sits at index ``reference_length // 2``.

    Args:
        reference_length: Number of bases in the reference and in every read.
        num_alignments: Number of independent alignments to simulate.
        coverage: Number of reads per alignment (the reference is stored in
            addition, as row 0).
        p_sequencing_error: Per-base probability that a read position is
            replaced by a uniformly drawn base (which may equal the original).
        seed: Optional seed. When given, a private generator is used so the
            same seed always yields the same output; when ``None`` the
            module-level ``random`` generator is used.

    Returns:
        ``(alignments, mutation_types)`` where ``alignments`` is a list of
        ``num_alignments`` alignments, each a list of ``coverage + 1`` rows of
        ``reference_length`` single-character strings, and ``mutation_types``
        is the matching list of labels: 0 -> no SNP, 1 -> homozygous SNP,
        2 -> heterozygous SNP.

    Raises:
        ValueError: If ``reference_length`` is smaller than 1.
    """
    if reference_length < 1:
        raise ValueError("reference_length must be at least 1")

    # A private generator gives reproducible output without touching global
    # state; without a seed the shared module-level generator is used.
    rng = random.Random(seed) if seed is not None else random

    alignments = []
    mutation_types = []
    snp_index = reference_length // 2

    for alignment_number in range(num_alignments):
        if alignment_number % 400 == 0:
            print("Computing alignment ", alignment_number)
        reference = [rng.choice(NUCLEOTIDES) for _ in range(reference_length)]
        reference_base_at_snp = reference[snp_index]
        # The alternative allele is always different from the reference base.
        snp_base = rng.choice([base for base in NUCLEOTIDES if base != reference_base_at_snp])
        mutation_type = rng.choice([0, 1, 2])  # 0 -> no SNP; 1 -> Homozygous SNP; 2 -> Heterozygous SNP
        mutation_types.append(mutation_type)

        alignment = [reference]  # first row is always the reference
        for _ in range(coverage):
            # Copy the reference, replacing each base by a random one with
            # probability p_sequencing_error.
            new_read = [base if rng.random() > p_sequencing_error else rng.choice(NUCLEOTIDES)
                        for base in reference]

            if mutation_type == 1:  # homozygous SNP: every read carries the variant
                new_read[snp_index] = snp_base
            if mutation_type == 2 and rng.random() > 0.5:  # heterozygous SNP: about half of the reads
                new_read[snp_index] = snp_base

            # The SNP base was written after the error pass above, so give the
            # SNP position its own chance of a sequencing error.
            if rng.random() < p_sequencing_error:
                new_read[snp_index] = rng.choice(NUCLEOTIDES)
            alignment.append(new_read)
        alignments.append(alignment)
    return alignments, mutation_types

In [0]:
# Compute 2000 alignments.
# With the simulator's defaults (coverage=100, reference_length=200) each
# alignment has 101 rows (the reference plus 100 reads) of 200 bases.
alignments, mutation_types = simulate_alignments(num_alignments=2000)
# Re-bind the nested Python lists as one NumPy array of single-character
# strings; mutation_types stays a plain list of integer labels.
alignments = np.array(alignments)

# Visualise the alignments

In [0]:
import matplotlib.pyplot as plt
%matplotlib inline


In [0]:
# Integer code for each base, so the character array can be drawn as an image.
transdict = {"A":0, "C": 1, "G":2, "T":3}
# Element-wise dictionary lookup over the whole alignments array.
alignments_ints = np.vectorize(transdict.get)(alignments)
# Global matplotlib setting: applies to every figure created after this cell.
plt.rcParams['figure.dpi'] = 200

In [0]:
# Index of the simulated alignment to display.
alignment_idx = 13
mutation_type_names = {0: "No mutation",
                       1: "Homozygous SNP",
                       2: "Heterozygous SNP"}

# Rows are reads (row 0 is the reference), columns are reference positions and
# colours are the integer-coded bases.
plt.imshow(alignments_ints[alignment_idx],cmap='jet')
# Label the figure so it can be understood without reading the code.
plt.title("Alignment {}: {}".format(alignment_idx, mutation_type_names[mutation_types[alignment_idx]]))
plt.xlabel("Position in reference")
plt.ylabel("Read (row 0 = reference)")
print ("Mutation type: ", mutation_type_names[mutation_types[alignment_idx]])

In [0]:
# Write the basecaller


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [0]:
#print(alignments)

In [0]:
# Show a short preview and the per-class counts instead of dumping every label.
print(mutation_types[:20])
# Entry k is the number of alignments labelled with mutation type k.
print(np.bincount(mutation_types, minlength=3))

In [0]:
# Integer code assigned to each nucleotide character.
transdict = {"A":0, "C": 1, "G":2, "T":3}

def char2int(alignments):
    """Map an array of nucleotide characters to their integer codes.

    Args:
        alignments: Array-like of single-character strings ("A", "C", "G", "T")
            of any shape, including empty.

    Returns:
        An int64 NumPy array with the same shape as the input.

    Raises:
        KeyError: If the input contains a symbol that is not in ``transdict``.
    """
    # Fixing otypes keeps the result int64 on every platform and lets
    # np.vectorize accept empty input; __getitem__ (rather than .get) makes an
    # unknown symbol fail here instead of producing None entries.
    return np.vectorize(transdict.__getitem__, otypes=[np.int64])(alignments)

def int2onehot(x, num_classes=4):
    """One-hot encode an array of integer class codes.

    Args:
        x: Array-like of integer codes in ``[0, num_classes)``, any shape.
        num_classes: Size of the one-hot axis (4 for the nucleotide alphabet).

    Returns:
        A float32 tensor shaped like ``x`` with one extra trailing axis of
        length ``num_classes``.
    """
    # F.one_hot only accepts int64 index tensors, so cast explicitly; this also
    # makes int32 arrays usable.
    indices = torch.from_numpy(np.asarray(x)).long()
    return F.one_hot(indices, num_classes).to(torch.float32)


In [0]:
# Map every base character to its integer code (same shape as `alignments`).
alignment_ints = char2int(alignments)
print(alignment_ints)

# int2onehot appends a trailing axis of size 4 (one slot per base); permute
# moves that axis to position 1, giving
# (alignment, base channel, read row, position).
onehot = int2onehot(alignment_ints).permute(0, 3, 1, 2)
print(onehot)

In [0]:

# After the permute the axes of `onehot` are
# (alignment, base channel, read row, position). For one alignment the read-row
# axis is therefore index 1 of the slice; row 0 of that axis is the reference.
np.shape(onehot[1, :, 1:, :]) # exclude reference

In [0]:
# Device handle for the GPU; creating it does not move any tensor or module.
# NOTE(review): nothing in the visible cells uses `cuda` — confirm it is still needed.
cuda = torch.device('cuda')

class Net(nn.Module):
    """Small CNN that classifies one alignment pile-up into three mutation classes.

    Input:  float tensor of shape (batch, 4, rows, positions), where the 4
            channels are the one-hot encoded bases. The notebook feeds
            rows = 101 (reference + 100 reads) and positions = 200.
    Output: raw, unnormalised logits of shape (batch, 3), one per mutation
            type (0 = no SNP, 1 = homozygous SNP, 2 = heterozygous SNP).
    """

    def __init__(self):
        super(Net, self).__init__()
        # Shape comments are (channels, height, width) for a 4 x 101 x 200 input.
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=5, stride=1, padding=2)  # 6 x 101 x 200
        self.pool1 = nn.MaxPool2d(2, 2)  # 6 x 50 x 100 (odd height is floored)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=2)  # 16 x 50 x 100
        self.pool2 = nn.MaxPool2d(2, 2)  # 16 x 25 x 50
        # Flattened feature vector: 16 * 25 * 50 = 20,000 values per sample.
        self.fc1 = nn.Linear(50*25*16, 3)

    def forward(self, x):
        """Return class logits of shape (batch, 3) for a batch of pile-ups."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten everything except the batch axis. Unlike view(-1, N), this
        # can never silently change the batch size when the spatial size is
        # unexpected; fc1 raises a clear shape error instead.
        x = torch.flatten(x, 1)
        # No activation on the last layer: the loss expects raw logits, and a
        # ReLU here would clamp every negative logit to zero.
        return self.fc1(x)
      


In [0]:

# Instantiate the network and push one small batch through it as a shape check.
# No torch.cuda.device context is used: the model and the inputs both stay on
# the CPU, so the cell does not need a GPU to run.
model = Net()

x_in = onehot[:20, :, :, :]
print('x_in:', x_in.shape)
# Call the module itself rather than .forward() so registered hooks also run.
y = model(x_in) # NCHW
print(y)

In [0]:
import torch.optim as optim

# Multi-class cross-entropy between the network outputs and the integer mutation labels.
criterion = nn.CrossEntropyLoss()
# SGD with momentum over every parameter of `model`.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [0]:
batch_size = 20
log_every = 20  # number of mini-batches between loss reports
# Derive the batch count from the data instead of hard-coding 2000 / 20.
num_batches = len(onehot) // batch_size

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i in range(num_batches):
        # Take the i-th contiguous slice of examples and the matching labels.
        batch = slice(batch_size * i, batch_size * (i + 1))
        inputs = onehot[batch]
        labels = torch.tensor(mutation_types[batch], dtype=torch.long)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss of the last `log_every` mini-batches. The
        # interval must be smaller than the number of batches per epoch,
        # otherwise nothing is ever printed.
        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.0

print('Finished Training')

In [0]:
# NOTE(review): called with the defaults, so this simulates 20000 alignments
# (ten times the training set), and neither result is used by the visible cells
# that follow — confirm whether a smaller num_alignments was intended.
test_inputs, test_outputs = simulate_alignments()

Computing alignment  0


In [0]:
# Run the first 20 encoded alignments through the network again and list each
# row of outputs next to its true mutation label, separated by a tab.
y_in = mutation_types[:20]
x_in = onehot[:20]
y = model.forward(x_in) # NCHW
print('\n'.join('{}\t{}'.format(scores, truth) for scores, truth in zip(y, y_in)))

In [0]:
import torch
import torch.utils.data

In [0]:
# simulate_alignments returns an (alignments, mutation_types) pair. Keep the
# pair as `testset`, but unpack it so that only the alignments are encoded;
# passing the whole pair to char2int mixes base characters with integer labels.
testset = simulate_alignments(reference_length=200, num_alignments = 20, coverage = 100, p_sequencing_error=0.03)
test_alignments, test_mutation_types = testset

test_alignment_ints = char2int(np.array(test_alignments))
print(test_alignment_ints)

# Same channels-first layout as the training tensor:
# (alignment, base channel, read row, position).
onehot_test = int2onehot(test_alignment_ints).permute(0, 3, 1, 2)
print(onehot_test)

In [0]:

# `testset` is the (alignments, mutation_types) pair returned by
# simulate_alignments. Encode it exactly like the training data and wrap the
# tensors in a dataset, so the loader yields (inputs, labels) batches.
eval_alignments, eval_mutation_types = testset
eval_inputs = int2onehot(char2int(np.array(eval_alignments))).permute(0, 3, 1, 2)
eval_labels = torch.tensor(eval_mutation_types, dtype=torch.long)

# The data is already in memory, so no worker processes are needed.
testloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(eval_inputs, eval_labels),
    batch_size=20, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    # Separate batch names keep the training globals (`alignments`,
    # `mutation_types`, `labels`) untouched, and accuracy is computed against
    # the labels of this batch rather than the last training batch.
    for batch_inputs, batch_labels in testloader:
        outputs = model(batch_inputs)
        predicted = outputs.argmax(dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

print('Accuracy of the network on the %d test alignments: %d %%' % (
    total, 100 * correct / total))