<a href="https://colab.research.google.com/github/makrez/BioinformaticsTools/blob/master/VariationalAutoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from sklearn.preprocessing import OneHotEncoder
!pip install biopython
from Bio import SeqIO
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
!pip install umap-learn
import umap
import matplotlib.pyplot as plt
import os
from sklearn.decomposition import PCA
!pip install torchviz
from torchviz import make_dot



Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# Mount drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
class hot_dna:
    def __init__(self, sequence):
        sequence = sequence.upper()
        self.sequence = self._preprocess_sequence(sequence)
        self.category_mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'U': 3, '-': 4, 'N': 5}
        self.onehot = self._onehot_encode(self.sequence)

    def _preprocess_sequence(self, sequence):
        ambiguous_bases = {'R', 'Y', 'S', 'W', 'K', 'M', 'B', 'D', 'H', 'V'}
        new_sequence = ""
        for base in sequence:
            if base in ambiguous_bases:
                new_sequence += 'N'
            else:
                new_sequence += base
        return new_sequence

    def _onehot_encode(self, sequence):
        integer_encoded = np.array([self.category_mapping[char] for char in sequence]).reshape(-1, 1)
        onehot_encoder = OneHotEncoder(sparse=False, categories='auto', handle_unknown='ignore')
        onehot_encoded = onehot_encoder.fit_transform(integer_encoded)
        
        full_onehot_encoded = np.zeros((len(sequence), 6))
        full_onehot_encoded[:, :onehot_encoded.shape[1]] = onehot_encoded
        
        return full_onehot_encoded

flatted_sequence = list()
sequence_labels = list()

alignment_length = 4000

with open('/content/drive/MyDrive/Colab Notebooks/autoencoder_data/bacillus.aln') as handle:
    for record in SeqIO.parse(handle, 'fasta'):
        label = str(record.description).rsplit(';', 1)[-1]
        seq_hot = hot_dna(str(record.seq)[10:alignment_length+10]).onehot
        
        if len(seq_hot) == alignment_length:
            flatted_sequence.append(seq_hot)
            sequence_labels.append(label)


In [6]:
class VAE(nn.Module):
    def __init__(self, alignment_length=4000, latent_dim=50, fc_hidden=512):
        super(VAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv1d(6, 12, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(12, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(alignment_length*32//8, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, fc_hidden),
            nn.ReLU(),
        )

        self.fc_mu = nn.Linear(fc_hidden, latent_dim)
        self.fc_var = nn.Linear(fc_hidden, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, fc_hidden),
            nn.ReLU(),
            nn.Linear(fc_hidden, ((alignment_length//2)//2)*32),
            nn.ReLU()
        )

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose1d(32, 12, kernel_size=4, stride=2, padding=1), 
            nn.ReLU(),
            nn.ConvTranspose1d(12, 6, kernel_size=4, stride=2, padding=1),
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, log_var = self.fc_mu(h), self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h = self.decoder(z)
        h = h.view(h.shape[0], 32, -1)
        out = self.decoder_conv(h)
        out = F.softmax(out, dim=1)
        return out

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

class SequenceDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        return self.sequences[index]

# def vae_loss(x, x_recon, mu, log_var):
#     BCE = F.binary_cross_entropy(x_recon, x, reduction='sum')
#     KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
#     return BCE + KLD

loss_function = torch.nn.CosineEmbeddingLoss(reduction='none')

def vae_loss(x, x_recon, mu, log_var, target):
    cos_loss = loss_function(x_recon.contiguous().view(x_recon.size(0), -1), x.contiguous().view(x.size(0), -1), target)
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return cos_loss.mean() + KLD



    
def train_model(model, optimizer, train_dataloader, val_dataloader, device, n_epochs):
    train_losses = []
    val_losses = []

    # for epoch in range(1, n_epochs + 1):
    #     train_loss = 0.0
    #     val_loss = 0.0

    #     # Training loop
    #     model.train()
    #     for batch_idx, data in enumerate(train_dataloader):
    #         # Transform the data tensor to have the shape (batch_size, n_channels, sequence_length)
    #         data = data.permute(0, 2, 1)
    #         data = data.float().to(device)

    #         # Zero the gradients
    #         optimizer.zero_grad()

    #         # Forward pass
    #         x_recon, mu, log_var = model(data)

    #         # Compute the loss
    #         loss = vae_loss(data, x_recon, mu, log_var)

    #         # Backward pass and optimization
    #         loss.backward()
    #         optimizer.step()

    #         # Update the training loss
    #         train_loss += loss.item() * data.size(0)

    for batch_idx, data in enumerate(train_dataloader):
        # Transform the data tensor to have the shape (batch_size, n_channels, sequence_length)
        data = data.permute(0, 2, 1)
        data = data.float().to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        x_recon, mu, log_var = model(data)

        # Define the target tensor for the cosine embedding loss
        target = torch.ones(data.size(0)).to(device)

        # Compute the loss
        loss = vae_loss(data, x_recon, mu, log_var, target)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Update the training loss
        train_loss += loss.item() * data.size(0)
        # Print the average training loss for the epoch
        train_loss /= len(train_dataloader.dataset)

        # Validation loop
        model.eval()
        with torch.no_grad():
            for batch_idx, data in enumerate(val_dataloader):
              # Transform the data tensor to have the shape (batch_size, n_channels, sequence_length)
              data = data.permute(0, 2, 1)
              data = data.float().to(device)

              # Forward pass
              x_recon, mu, log_var = model(data)

              # Compute the loss
              loss = vae_loss(data, x_recon, mu, log_var)

              # Update the validation loss
              val_loss += loss.item() * data.size(0)

        val_loss /= len(val_dataloader.dataset)

        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(epoch, train_loss, val_loss))
        
        # Append the training loss to the list
        train_losses.append(train_loss)
        val_losses.append(val_loss)

    return train_losses, val_losses



def plot_latent_space(model, device, flatted_sequence, sequence_labels, latent_space_dim, hyperparameters):
    # Set the model to evaluation mode
    model.eval()

    # Create a list to store the latent vectors
    latent_vectors = []
    labels = []

    # Iterate over the sequences in the dataset and obtain their latent vectors
    for sequence, label in zip(flatted_sequence, sequence_labels):
        # Transform the sequence tensor to have the shape (1, n_channels, sequence_length)
        sequence = torch.from_numpy(sequence).unsqueeze(0).permute(0, 2, 1)
        sequence = sequence.float().to(device)

        # Obtain the latent vector (mu)
        mu, _ = model.encode(sequence)
        mu = mu.detach().cpu().numpy()

        # Add the latent vector to the list
        latent_vectors.append(mu)
        labels.append(label)

    # Convert the list of latent vectors to a 2D numpy array
    latent_vectors_array = np.array(latent_vectors).squeeze()

    if latent_space_dim > 2:
        umap_model = umap.UMAP(n_neighbors=100, min_dist=0.1, random_state=42)
        coords = umap_model.fit_transform(latent_vectors_array)
    else:
        coords = latent_vectors_array

    # Create a scatter plot for each unique label
    unique_labels = set(labels)
    color_dict = {label: plt.cm.tab20(i) for i, label in enumerate(unique_labels)}

    fig, ax = plt.subplots()
    for label in unique_labels:
        # Get the indices of the data points with this label
        indices = [i for i, x in enumerate(labels) if x == label]

        # Get the corresponding latent vectors or UMAP coordinates and colors
        coords_subset = coords[indices]
        color = color_dict[label]

        # Add the scatter plot for this label to the axes
        ax.scatter(coords_subset[:, 0], coords_subset[:, 1], color=color, label=label)

    # Add the legend to the axes
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    ax.set_title(f"Latent Dim: {hyperparameters['latent_dim']} | Learning Rate: {hyperparameters['learning_rate']} | Epochs: {hyperparameters['n_epochs']}")

    # Save the plot in the 'plots' subdirectory
    os.makedirs(os.path.join(os.getcwd(), "plots"), exist_ok=True)
    filename = f"latent_dim_{hyperparameters['latent_dim']}_lr_{hyperparameters['learning_rate']}_n_epochs_{hyperparameters['n_epochs']}.png"
    plt.savefig(os.path.join(os.getcwd() , "plots", filename), bbox_inches="tight")

    plt.show()


def plot_loss(train_losses, val_losses, hyperparameters):
    plt.figure()
    plt.plot(train_losses, label="Training Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Latent Dim: {hyperparameters['latent_dim']} | Learning Rate: {hyperparameters['learning_rate']} | Epochs: {hyperparameters['n_epochs']}")
    plt.legend()
    os.makedirs(os.path.join(os.getcwd(), "plots"), exist_ok=True)
    filename = f"loss_latent_dim_{hyperparameters['latent_dim']}_lr_{hyperparameters['learning_rate']}_n_epochs_{hyperparameters['n_epochs']}.png"
    plt.savefig(os.path.join(os.getcwd() , "plots", filename), bbox_inches="tight")

    plt.show()


def main(alignment_length, latent_dim, learning_rate, n_epochs): #, architecture_name):
    # Instantiate the model with the chosen architecture
    model = VAE(alignment_length=alignment_length, latent_dim=latent_dim) #, architecture=architectures[architecture_name])
    model = model.to(device)

    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Define the dataloader
    dataset = SequenceDataset(flatted_sequence)
    
    # Split the dataset into training and validation sets
    train_set, val_set = train_test_split(dataset, test_size=0.2, random_state=42)
    
    train_dataloader = DataLoader(train_set, batch_size=32, shuffle=True)
    val_dataloader = DataLoader(val_set, batch_size=32, shuffle=False)

    # Train the model
    train_losses, val_losses = train_model(model, optimizer, train_dataloader, val_dataloader, device, n_epochs)

    # Define the hyperparameters dictionary
    hyperparameters = {
        'latent_dim': latent_dim,
        'learning_rate': learning_rate,
        'n_epochs': n_epochs
    }

    # Plot the loss curves
    plot_loss(train_losses, val_losses, hyperparameters)
    
    # Plot the latent space
    plot_latent_space(model, device, flatted_sequence, sequence_labels, latent_dim, hyperparameters)
    
    # Calculate the latent vectors for the PCA plot
    latent_vectors_array, _ = encode_latent_vectors(model, device, flatted_sequence)

    # Plot the PCA
    try:
        plot_pca(latent_vectors_array.cpu().numpy(), sequence_labels, hyperparameters)
    except Exception as e:
        print("Failed to plot PCA:", e)

    # Plot the model
    try:
        plot_model(model)
    except Exception as e:
        print("Failed to plot model:", e)

def encode_latent_vectors(model, device, sequences):
    model.eval() # Set the model to evaluation mode
    with torch.no_grad(): # Deactivate gradients for the following code
        sequences_tensor = torch.Tensor(sequences).to(device)
        latent_vectors, _ = model.encode(sequences_tensor)
    return latent_vectors


if __name__ == '__main__':

    print(f"alignment length is {alignment_length}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_epochs = 10

    latent_dims = [2,4]
    learning_rates = [0.01]

    for latent_dim in latent_dims:
        for learning_rate in learning_rates:
              main(alignment_length, latent_dim, learning_rate, n_epochs)



alignment length is 4000


UnboundLocalError: ignored