<a href="https://colab.research.google.com/github/nongiga/HyperSmorf/blob/main/NeuralEmbedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference notebook (Neural SEED, geomstats ICLR 2021 challenge): https://github.com/geomstats/challenge-iclr-2021/blob/main/gcorso__Neural-Sequence-Distance-Embeddings/Neural_SEED.ipynb

In [None]:
# INITIAL SET UP: ONLY NEED TO RUN ONCE PER RUNTIME
# This cell installs dependencies, clones the neural_seed repository, makes it the
# working directory, builds its compiled extensions, and selects the geomstats backend.
# NOTE: the chdir below is relative, so running this cell a second time in the same
# runtime operates relative to the already-changed working directory.

# Dependencies (lines starting with `!` are shell commands, not Python).
!pip3 install geomstats
# clustalw is the command-line tool invoked in the last cell of this notebook.
!apt install clustalw
!pip install biopython
!pip install python-Levenshtein
!pip install Cython
!pip install networkx
!pip install tqdm
# gdown is only needed for the (currently commented-out) dataset downloads.
!pip install gdown
# torch/torchvision/torchaudio are pinned to specific versions (CUDA 10.1 builds for torch/torchvision).
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
# Fetch the project code that the later `edit_distance`, `util`, `hierarchical_clustering`
# and `multiple_alignment` imports come from, then make it the working directory so those
# packages and the relative data paths resolve.
!git clone https://github.com/gcorso/neural_seed.git
import os
os.chdir("neural_seed")
# Build the two extension modules in place (mst and unionfind).
# NOTE(review): each `!` line runs in its own shell, so the trailing `cd ..` commands
# should have no effect on the notebook's working directory -- confirm before removing.
!cd hierarchical_clustering/relaxed/mst; python setup.py build_ext --inplace; cd ../unionfind; python setup.py build_ext --inplace; cd ..; cd ..; cd ..;
# Must be set before geomstats is first imported; presumably geomstats reads this
# variable at import time to choose its numerical backend -- TODO confirm.
os.environ['GEOMSTATS_BACKEND'] = 'pytorch'

In [None]:
# INITIAL SET UP CONTINUED: ONLY NEED ONCE PER RUNTIME
import torch
import os 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

# from geomstats.geometry.poincare_ball import PoincareBall
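# TODO: decide whether to restore the import above. PairEmbeddingDistance (defined below)
# references PoincareBall, so that name must be resolvable before the class is instantiated.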

from edit_distance.train import load_edit_distance_dataset
from util.data_handling.data_loader import get_dataloaders
from util.ml_and_math.loss_functions import AverageMeter

In [None]:
# IMPORT DATASET HERE
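# The commented-out gdown commands below are kept for reference from the Neural SEED notebook.
# Whatever is loaded here must provide the files later cells read
# (e.g. './edit_qiita_large.pkl' and './hc_qiita_large_extr.pkl').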

# [IMPORT THINGS HERE]

# !gdown --id 1yZTOYrnYdW9qRrwHSO5eRc8rYIPEVtY2 # for edit distance approximation
# !gdown --id 1hQSHR-oeuS9bDVE6ABHS0SoI4xk3zPnB # for closest string retrieval
# !gdown --id 1ukvUI6gUTbcBZEzTVDpskrX8e6EHqVQg # for hierarchical clustering

In [None]:
class LinearEncoder(nn.Module):
    """Baseline encoder: a single fully connected layer over the flattened sequence.

    Args:
        len_sequence: number of positions in each input sequence.
        embedding_size: width of the embedding produced for each sequence.
        alphabet_size: per-position feature count; multiplied by ``len_sequence``
            to obtain the input width of the linear layer (default 4).
    """

    def __init__(self, len_sequence, embedding_size, alphabet_size=4):
        super().__init__()
        flat_features = alphabet_size * len_sequence
        self.encoder = nn.Linear(flat_features, embedding_size)

    def forward(self, sequence):
        """Embed a batch of sequences.

        Every axis after the leading (batch) axis is collapsed into one feature
        axis, which is then passed through the linear layer.
        """
        batch_size = sequence.size(0)
        flattened = sequence.reshape(batch_size, -1)
        return self.encoder(flattened)


class PairEmbeddingDistance(nn.Module):
    """Wrap an encoder so that it outputs one distance per pair of input sequences.

    Each sequence is encoded with ``embedding_model``, rescaled to a learnable
    radius, and the two embeddings of a pair are compared with the distance
    function of a geomstats ``PoincareBall`` of dimension ``embedding_size``.
    The result is multiplied by a learnable scalar.

    Args:
        embedding_model: module mapping a batch of sequences to 2-D embeddings
            (one row per sequence).
        embedding_size: dimension passed to ``PoincareBall``.
        scaling: accepted for interface compatibility but not read; the learnable
            multiplier ``self.scaling`` is always created and applied.
    """

    def __init__(self, embedding_model, embedding_size, scaling=False):
        super(PairEmbeddingDistance, self).__init__()
        # Local import: PoincareBall is not imported at module level in this notebook,
        # so referencing it without this line raises NameError on instantiation.
        # NOTE(review): GEOMSTATS_BACKEND should already be set to 'pytorch' before
        # geomstats is first imported -- confirm the setup cell has been run.
        from geomstats.geometry.poincare_ball import PoincareBall

        self.hyperbolic_metric = PoincareBall(embedding_size).metric.dist
        self.embedding_model = embedding_model
        # Learnable norm given to every embedding (clamped in normalize_embeddings).
        self.radius = nn.Parameter(torch.Tensor([1e-2]), requires_grad=True)
        # Learnable multiplier applied to the computed distance. This deliberately
        # shadows the `scaling` argument, which is never consulted.
        self.scaling = nn.Parameter(torch.Tensor([1.]), requires_grad=True)

    def normalize_embeddings(self, embeddings):
        """Rescale every row of ``embeddings`` to an L2 norm equal to the clamped radius.

        The radius is clamped to [1e-7, 1 - 1e-3], so the resulting rows always
        have norm strictly below 1.
        """
        min_scale = 1e-7
        max_scale = 1 - 1e-3
        return F.normalize(embeddings, p=2, dim=1) * self.radius.clamp_min(min_scale).clamp_max(max_scale)

    def encode(self, sequence):
        """Encode sequences with the wrapped model, then apply the radius normalization."""
        enc_sequence = self.embedding_model(sequence)
        enc_sequence = self.normalize_embeddings(enc_sequence)
        return enc_sequence

    def forward(self, sequence):
        """Return one scaled distance per pair.

        ``sequence`` is unpacked as a 4-D tensor ``(B, _, N, _)``; the pairs are
        flattened into ``2 * B`` individual sequences for encoding.
        (Assumes the second axis has size 2, i.e. holds the two members of each
        pair -- TODO confirm against the data loader.)
        """
        # flatten couples
        (B, _, N, _) = sequence.shape
        sequence = sequence.reshape(2 * B, N, -1)

        # encode sequences
        enc_sequence = self.encode(sequence)

        # regroup consecutive embeddings into pairs and compare them
        enc_sequence = enc_sequence.reshape(B, 2, -1)
        distance = self.hyperbolic_metric(enc_sequence[:, 0], enc_sequence[:, 1])
        distance = distance * self.scaling

        return distance

In [None]:
def train(model, loader, optimizer, loss, device):
    """Run one optimisation pass over ``loader`` and return the average training loss.

    For every batch: move it to ``device``, reset gradients, compute the loss of
    the model output against the labels, backpropagate and step the optimizer.
    Batch losses are accumulated in an ``AverageMeter`` together with the batch
    size (presumably yielding a sample-weighted mean -- verify in AverageMeter).
    """
    model.train()
    meter = AverageMeter()

    for batch, targets in loader:
        batch = batch.to(device)
        targets = targets.to(device)

        # fresh gradients, then forward / backward / update
        optimizer.zero_grad()
        batch_loss = loss(model(batch), targets)
        batch_loss.backward()
        optimizer.step()

        # record this batch's loss along with its number of examples
        meter.update(batch_loss.item(), batch.shape[0])

    return meter.avg


def test(model, loader, loss, device):
    """Evaluate ``model`` on every batch of ``loader`` and return the average loss.

    The model is switched to eval mode and no parameters are updated. Batch
    losses are accumulated in an ``AverageMeter`` together with the batch size
    (presumably yielding a sample-weighted mean -- verify in AverageMeter).
    """
    avg_loss = AverageMeter()
    model.eval()

    # Nothing here calls backward(), so disable autograd: the loss values are
    # unchanged, but no computation graph is built or kept alive per batch.
    with torch.no_grad():
        for sequences, labels in loader:
            # move examples to right device
            sequences, labels = sequences.to(device), labels.to(device)

            # forward propagation and loss computation
            output = model(sequences)
            loss_val = loss(output, labels).item()
            avg_loss.update(loss_val, sequences.shape[0])

    return avg_loss.avg

In [None]:
# Configuration: every tunable value of this experiment is named here rather than
# being scattered through the cell as literals.
EMBEDDING_SIZE = 128                    # embedding dimension (encoder output and PoincareBall)
SEED = 2021                             # seed for the torch RNGs
DATA_PATH = './edit_qiita_large.pkl'    # must already exist in the working directory
BATCH_SIZE = 128
SEQUENCE_LENGTH = 152                   # len_sequence argument of LinearEncoder
LEARNING_RATE = 0.001
NUM_EPOCHS = 21                         # epochs 0..20, so the final epoch is also printed
LOG_EVERY = 5                           # print progress every LOG_EVERY epochs

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(SEED)
if device == 'cuda':
    torch.cuda.manual_seed(SEED)

# load data
datasets = load_edit_distance_dataset(DATA_PATH)
loaders = get_dataloaders(datasets, batch_size=BATCH_SIZE, workers=1)

# model, optimizer and loss
encoder = LinearEncoder(SEQUENCE_LENGTH, EMBEDDING_SIZE)
model = PairEmbeddingDistance(embedding_model=encoder, embedding_size=EMBEDDING_SIZE)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss = nn.MSELoss()

# training: one optimisation pass plus one validation pass per epoch
for epoch in range(NUM_EPOCHS):
    t = time.time()
    loss_train = train(model, loaders['train'], optimizer, loss, device)
    loss_val = test(model, loaders['val'], loss, device)

    # print progress (the reported time covers both the train and the val pass)
    if epoch % LOG_EVERY == 0:
        print('Epoch: {:02d}'.format(epoch),
              'loss_train: {:.6f}'.format(loss_train),
              'loss_val: {:.6f}'.format(loss_val),
              'time: {:.4f}s'.format(time.time() - t))

# testing: report the final loss on every split returned by get_dataloaders
for dset in loaders.keys():
    avg_loss = test(model, loaders[dset], loss, device)
    print('Final results {}: loss = {:.6f}'.format(dset, avg_loss))

In [None]:
# Hierarchical clustering evaluation using the model trained above.
# The import resolves against the cloned neural_seed repository (the working directory).
from hierarchical_clustering.unsupervised.unsupervised import hierarchical_clustering_testing

# Passes the trained pair-distance model, the path of a pickled dataset, and the same
# batch size / device as training. `distance='hyperbolic'` presumably selects the metric
# used between embeddings -- confirm in hierarchical_clustering_testing.
hierarchical_clustering_testing(encoder_model=model, data_path='./hc_qiita_large_extr.pkl',
                                batch_size=128, device=device, distance='hyperbolic')

In [None]:
# Guide-tree construction for multiple alignment, using the model trained above.
# The import resolves against the cloned neural_seed repository (the working directory).
from multiple_alignment.guide_tree.guide_tree import approximate_guide_trees

# performs neighbour joining algorithm on the estimate of the pairwise distance matrix
# Uses the 'test' split of the `datasets` dict loaded in the training cell.
approximate_guide_trees(encoder_model=model, dataset=datasets['test'],
                        batch_size=128, device=device, distance='hyperbolic')

# Command line clustalw using the tree generated with the previous command. 
# The substitution matrix and gap penalties are set to simulate the classical edit distance used to train the model 
!clustalw -infile="sequences.fasta" -dnamatrix=multiple_alignment/guide_tree/matrix.txt -transweight=0 -type='DNA' -gapopen=1 -gapext=1 -gapdist=10000 -usetr