In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
class DenoisingAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(100, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 100),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [4]:
# Create the network and the optimizer
net = DenoisingAutoencoder()
optimizer = optim.Adam(net.parameters())

# Define the loss function and the metric for evaluation
criterion = nn.MSELoss()

In [7]:

# Generate some random DNA sequences and add noise to them
dna_sequences = torch.randint(0, 4, (100, 100))
noisy_dna_sequences = dna_sequences + torch.randint(-1, 2, (100, 100))
noisy_dna_sequences = noisy_dna_sequences.clamp(0, 4)

# Convert the DNA sequences to one-hot encoded vectors
dna_sequences_1hot = nn.functional.one_hot(dna_sequences, num_classes=5).float()
noisy_dna_sequences_1hot = nn.functional.one_hot(noisy_dna_sequences, num_classes=5).float()


In [8]:
dna_sequences

tensor([[1, 1, 3,  ..., 3, 2, 3],
        [1, 0, 2,  ..., 1, 0, 2],
        [0, 2, 3,  ..., 3, 1, 2],
        ...,
        [0, 2, 0,  ..., 3, 0, 2],
        [2, 1, 0,  ..., 1, 1, 0],
        [3, 3, 2,  ..., 0, 0, 2]])

In [9]:
dna_sequences_1hot

tensor([[[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0.],
         ...,
         [0., 0., 0., 1., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.]],

        [[0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         ...,
         [0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         ...,
         [0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.]],

        ...,

        [[1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 1., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0.]],

        [[0., 0., 1., 0., 0.],
         [0., 1., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         ...,
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         