In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import torch
from torch import nn, optim
import torch.nn.init as init
from torch.nn import functional as F
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.cluster import normalized_mutual_info_score

import os

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP

import pytorch_lightning as pl

In [None]:
from typing import Tuple
import math

import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as distributed

from einops import rearrange, repeat

def is_protein(seq, aa_list):
    """
    Tell whether every character of ``seq`` belongs to ``aa_list``.

    Parameters
    ----------
    seq: iterable of str
        Candidate sequence, scanned element by element.
    aa_list: container of str
        Accepted residue symbols.

    Returns
    -------
    bool
        False as soon as one element is missing from ``aa_list``; True
        otherwise (an empty ``seq`` is therefore accepted).
    """
    return all(residue in aa_list for residue in seq)

def l_out_cnn1d(L_in:int,K:int,S:int,P:int,D:int=1) -> float:
    '''Un-floored output length of a 1D convolution.

    L_in is the input length, K the kernel size, S the stride, P the padding
    added on each side and D the dilation. The value is returned as a float
    (true division) so that callers can test whether the configuration tiles
    the input exactly.'''
    dilated_kernel_span = D*(K-1) + 1
    return (L_in + 2*P - dilated_kernel_span)/S + 1

def find_optimal_cnn1d_padding(L_in:int,K,S:int) -> Tuple[int,int]:
    '''Find the minimal padding giving the kernel size K and stride S 
    for a CNN1D without losing any piece of information.

    Paddings P = 0, 1, ... are tried while 2*P <= S. The first P for which the
    un-floored output length is an integer is returned together with that
    length. If no such P exists, the largest padding tried (S // 2) is
    returned with the floored output length.

    Returns
    -------
    (L_out, P): Tuple[int, int]
        Output length of the convolution and the padding to apply on each side.

    Raises
    ------
    AssertionError
        If the kernel is larger than the input.
    '''
    assert L_in>=K, 'Kernel size higher than input dimension, the conv1d will not work'

    max_padding = S//2  # largest P that still satisfies 2*P <= S

    for P in range(max_padding + 1):
        L_out = l_out_cnn1d(L_in,K,S,P)
        if L_out.is_integer():
            # Return the padding that produced the exact fit. The previous loop
            # incremented P after computing L_out and could report P+1 here.
            return int(L_out), P

    # No exact fit: keep the largest padding tried and floor the length.
    return math.floor(L_out), max_padding

def l_out_cnn1d_transpose(L_in:int,K:int,S:int,P:int,D:int=1) -> int:
    '''Output length of a transposed 1D convolution, before any output padding.

    L_in is the input length, K the kernel size, S the stride, P the padding
    and D the dilation.'''
    stride_extent = (L_in-1)*S
    dilated_kernel_span = D*(K-1) + 1
    return stride_extent - 2*P + dilated_kernel_span

def find_out_padding_cnn1d_transpose(L_obj:int,L_in:int,K:int,S:int,P:int) -> int:
    '''Output padding needed after a CNN1D transpose layer to reach L_obj.

    The length produced by the transposed convolution (input length L_in,
    kernel K, stride S, padding P) is compared with the objective L_obj and
    the number of missing positions is returned. An AssertionError is raised
    when the layer already produces more than L_obj positions.'''
    missing = L_obj - l_out_cnn1d_transpose(L_in,K,S,P)
    assert missing>=0, 'Make sure the padding is correct, the ouput \
            of the CNN1D transpose is larger than expeceted'
    return missing

# From the enhancing VQ (https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py)
# Copyright (c) 2020 Phil Wang (MIT Licenced)

def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values such as 0 count as existing)."""
    return not (val is None)

def default(val, d):
    """Return ``val`` when it is not ``None``, otherwise the fallback ``d``."""
    if exists(val):
        return val
    return d

def noop(*args, **kwargs):
    """Accept any arguments, do nothing and return ``None``.

    Used in this file as the stand-in for ``distributed.all_reduce`` when no
    cross-process synchronisation is requested."""
    return None

def l2norm(t):
    """Scale ``t`` to unit Euclidean norm along its last dimension."""
    unit_vectors = F.normalize(t, p = 2, dim = -1)
    return unit_vectors

def log(t, eps = 1e-20):
    """Natural logarithm of ``t`` after clamping its values from below at ``eps``.

    The clamp keeps the result finite for zero (or negative) entries."""
    floored = torch.clamp(t, min = eps)
    return floored.log()

def uniform_init(*shape):
    """Create a new tensor of the given shape filled by ``nn.init.kaiming_uniform_``."""
    # kaiming_uniform_ fills the tensor in place and hands the same tensor back
    return nn.init.kaiming_uniform_(torch.empty(shape))

def gumbel_noise(t):
    """Draw Gumbel noise with the shape, dtype and device of ``t``.

    Uniform samples u in [0, 1) are mapped through -log(-log(u)); the clamped
    ``log`` helper keeps both logarithms finite."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)

def gumbel_sample(t, temperature = 1., dim = -1):
    """Pick an index along ``dim`` from the scores ``t``.

    With ``temperature == 0`` this is a plain argmax. Otherwise the scores are
    divided by the temperature and perturbed with Gumbel noise before the
    argmax, which makes the choice stochastic."""
    scores = t if temperature == 0 else (t / temperature) + gumbel_noise(t)
    return scores.argmax(dim = dim)

def ema_inplace(moving_avg, new, decay):
    """Update ``moving_avg`` in place to ``decay * moving_avg + (1 - decay) * new``.

    The update goes through ``.data``, so it is not recorded by autograd.
    Nothing is returned."""
    keep_weight, blend_weight = decay, 1 - decay
    moving_avg.data.mul_(keep_weight).add_(new, alpha = blend_weight)

def laplace_smoothing(x, n_categories, eps = 1e-5):
    """Additive smoothing: ``(x + eps) / (sum(x) + n_categories * eps)``."""
    numerator = x + eps
    denominator = x.sum() + n_categories * eps
    return numerator / denominator

def sample_vectors(samples, num):
    """Randomly pick ``num`` rows (first dimension) of ``samples``.

    Rows are drawn without replacement when enough are available, and with
    replacement when ``num`` exceeds the number of rows."""
    num_samples, device = samples.shape[0], samples.device

    if num_samples < num:
        # not enough rows: draw indices with replacement
        indices = torch.randint(0, num_samples, (num,), device = device)
    else:
        indices = torch.randperm(num_samples, device = device)[:num]

    return samples[indices]

def batched_sample_vectors(samples, num):
    """Apply ``sample_vectors`` independently to each slice along dim 0 and stack the results."""
    per_slice = []
    for sample in samples.unbind(dim = 0):
        per_slice.append(sample_vectors(sample, num))
    return torch.stack(per_slice, dim = 0)

def pad_shape(shape, size, dim = 0):
    """Return ``shape`` as a list with the entry at position ``dim`` replaced by ``size``.

    ``dim`` is compared with the non-negative positions only, so a negative or
    out-of-range ``dim`` leaves every entry untouched."""
    padded = []
    for axis, extent in enumerate(shape):
        padded.append(size if axis == dim else extent)
    return padded

def sample_multinomial(total_count, probs):
    """Split ``total_count`` draws among categories with probabilities ``probs``.

    The multinomial sample is built sequentially on the CPU: each category
    receives a binomial draw from the draws still unassigned, using its
    probability renormalised by the probability mass not yet consumed.

    Returns a long tensor shaped like ``probs``, moved back to the device
    ``probs`` came from."""
    origin_device = probs.device
    cpu_probs = probs.cpu()

    draws_left = cpu_probs.new_full((), total_count)
    mass_left = cpu_probs.new_ones(())
    counts = torch.empty_like(cpu_probs, dtype = torch.long)

    for idx, prob in enumerate(cpu_probs):
        drawn = torch.binomial(draws_left, prob / mass_left)
        counts[idx] = drawn
        # shrink what remains for the following categories
        draws_left -= drawn
        mass_left -= prob

    return counts.to(origin_device)

def all_gather_sizes(x, dim):
    """Gather the length of ``x`` along ``dim`` from every process.

    Returns a 1-D long tensor with one entry per process of the default
    ``torch.distributed`` group, which must be initialised."""
    local_size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device)
    world_size = distributed.get_world_size()
    gathered = [torch.empty_like(local_size) for _ in range(world_size)]
    distributed.all_gather(gathered, local_size)
    return torch.stack(gathered)

def all_gather_variably_sized(x, sizes, dim = 0):
    """
    Collect one tensor per rank when ranks hold different lengths along ``dim``.

    Parameters
    ----------
    x: tensor
        The local tensor contributed by this process.
    sizes: iterable of int
        One length per rank (position i is used for rank i); used to size the
        receive buffers along ``dim``.
    dim: int
        Dimension along which the lengths differ.

    Returns
    -------
    list of tensors, one per entry of ``sizes``; the entry at the local rank is
    ``x`` itself.
    """
    rank = distributed.get_rank()
    all_x = []

    for i, size in enumerate(sizes):
        # the owning rank sends its own tensor; the others allocate an
        # uninitialised buffer shaped like x with `size` along `dim`
        t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
        # non-blocking broadcast from rank i; the returned work handle is discarded
        distributed.broadcast(t, src = i, async_op = True)
        all_x.append(t)

    # NOTE(review): this barrier is the only synchronisation after the async
    # broadcasts above — confirm it guarantees the buffers are filled on the backend in use
    distributed.barrier()
    return all_x

def sample_vectors_distributed(local_samples, num):
    """
    Draw ``num`` vectors in total from samples spread over all processes.

    Rank 0 decides how many vectors each rank contributes (multinomial draw
    weighted by the number of local samples), broadcasts that allocation, and
    every rank then samples locally and exchanges its picks.

    Parameters
    ----------
    local_samples: tensor
        Local samples with a leading dimension of size 1 (removed on entry).
    num: int
        Total number of vectors to draw across all ranks.

    Returns
    -------
    tensor with a leading dimension of size 1 holding the concatenated picks
    of every rank.
    """
    # drop the leading singleton dimension
    local_samples = rearrange(local_samples, '1 ... -> ...')

    rank = distributed.get_rank()
    all_num_samples = all_gather_sizes(local_samples, dim = 0)

    if rank == 0:
        # allocation proportional to each rank's share of the samples
        samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
    else:
        # receive buffer for the allocation computed on rank 0
        samples_per_rank = torch.empty_like(all_num_samples)

    distributed.broadcast(samples_per_rank, src = 0)
    samples_per_rank = samples_per_rank.tolist()

    local_samples = sample_vectors(local_samples, samples_per_rank[rank])
    all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim = 0)
    out = torch.cat(all_samples, dim = 0)

    # restore the leading singleton dimension
    return rearrange(out, '... -> 1 ...')

def batched_bincount(x, *, minlength):
    """Row-wise bincount of the integer tensor ``x`` of shape (batch, n).

    Returns a (batch, minlength) tensor of the same dtype and device as ``x``
    where entry [b, v] counts the occurrences of value v in row b."""
    counts = torch.zeros(x.shape[0], minlength, dtype = x.dtype, device = x.device)
    # add 1 at position x[b, i] of row b for every element
    counts.scatter_add_(-1, x, torch.ones_like(x))
    return counts

def kmeans(
    samples,
    num_clusters,
    num_iters = 10,
    use_cosine_sim = False,
    sample_fn = batched_sample_vectors,
    all_reduce_fn = noop
):
    """
    Batched k-means, run independently for every codebook (leading dimension).

    Parameters
    ----------
    samples: tensor
        Points to cluster; the first dimension indexes the codebook and the
        last one the feature dimension. The einops patterns below treat it as
        (h, n, d).
    num_clusters: int
        Number of centroids per codebook.
    num_iters: int
        Number of assignment/update rounds. Must be at least 1, since ``bins``
        is only assigned inside the loop.
    use_cosine_sim: bool
        If True, points are assigned by largest dot product and centroids are
        L2-normalised; otherwise assignment uses the smallest Euclidean distance.
    sample_fn: callable
        Called as ``sample_fn(samples, num_clusters)`` to pick the initial centroids.
    all_reduce_fn: callable
        Called on the counts and on the new centroids at each round
        (a no-op by default).

    Returns
    -------
    means: tensor
        Centroids after the last round.
    bins: tensor
        Number of points assigned to each centroid in the last round.
    """
    num_codebooks, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device

    means = sample_fn(samples, num_clusters)

    for _ in range(num_iters):
        # larger score == closer, so the assignment below is an argmax in both modes
        if use_cosine_sim:
            dists = samples @ rearrange(means, 'h n d -> h d n')
        else:
            dists = -torch.cdist(samples, means, p = 2)

        buckets = torch.argmax(dists, dim = -1)
        bins = batched_bincount(buckets, minlength = num_clusters)
        all_reduce_fn(bins)

        # empty clusters: divide by 1 instead of 0 and keep their previous centroid below
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)

        # sum the points of each cluster, then divide by the cluster size
        new_means.scatter_add_(1, repeat(buckets, 'h n -> h n d', d = dim), samples)
        new_means = new_means / rearrange(bins_min_clamped, '... -> ... 1')
        all_reduce_fn(new_means)

        if use_cosine_sim:
            new_means = l2norm(new_means)

        means = torch.where(
            rearrange(zero_mask, '... -> ... 1'),
            means,
            new_means
        )

    return means, bins

def batched_embedding(indices, embeds):
    """Look up codebook vectors for a batch of indices, per codebook.

    ``indices`` follows the pattern 'h b n' and ``embeds`` the pattern
    'h c d'; the result follows 'h b n d', with entry [h, b, n] equal to
    ``embeds[h, indices[h, b, n]]``."""
    num_batches = indices.shape[1]
    code_dim = embeds.shape[-1]
    gather_index = repeat(indices, 'h b n -> h b n d', d = code_dim)
    tiled_codes = repeat(embeds, 'h c d -> h b c d', b = num_batches)
    return tiled_codes.gather(2, gather_index)

# regularization losses

def orthogonal_loss_fn(t):
    """Orthogonality penalty on a set of codes ``t`` following the pattern 'h n d'.

    The codes are L2-normalised, their pairwise cosine similarities are
    squared and averaged over the h * n * n pairs, and 1 / n is subtracted."""
    # eq (2) from https://arxiv.org/abs/2112.00384
    num_codebooks, num_codes = t.shape[:2]
    unit_codes = l2norm(t)
    gram = einsum('h i d, h j d -> h i j', unit_codes, unit_codes)
    mean_squared_similarity = (gram ** 2).sum() / (num_codebooks * num_codes ** 2)
    return mean_squared_similarity - (1 / num_codes)


In [None]:
from einops import rearrange
import pandas as pd
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.distributed as distributed
from torch.cuda.amp import autocast

class CosineSimCodebook(nn.Module):
    """
    Codebook for vector quantisation using cosine similarity.

    Inputs and codes are L2-normalised; each input vector is assigned to the
    code with the largest dot product (optionally with Gumbel sampling). In
    training mode the codes and their usage counts are updated by exponential
    moving average, and codes whose usage falls below a threshold are replaced
    by vectors sampled from the current batch.
    """
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks = 1,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        decay = 0.8,
        eps = 1e-5,
        threshold_ema_dead_code = 3,
        use_ddp = False,
        learnable_codebook = False,
        sample_codebook_temp = 0.
    ):
        """
        Parameters
        ----------
        dim: int
            Dimension of each code vector.
        codebook_size: int
            Number of codes per codebook.
        num_codebooks: int
            Number of independent codebooks (leading dimension of ``embed``).
        kmeans_init: bool
            If True the codes start at zero and are initialised by k-means on
            the first batch seen by ``forward``; otherwise they are random
            unit vectors.
        kmeans_iters: int
            Number of k-means rounds used for that initialisation.
        sync_kmeans: bool
            With ``use_ddp``, run the k-means sampling/reduction across processes.
        decay: float
            EMA decay for the codes and the cluster sizes.
        eps: float
            Stored on the instance; not read by the code of this class.
        threshold_ema_dead_code: number
            Codes whose EMA cluster size is below this value are replaced
            (0 disables the replacement).
        use_ddp: bool
            If True, statistics are all-reduced with ``torch.distributed``.
        learnable_codebook: bool
            If True the codes are an ``nn.Parameter`` instead of a buffer.
        sample_codebook_temp: float
            Temperature passed to ``gumbel_sample`` (0 means plain argmax).
        """
        super().__init__()
        self.decay = decay

        if not kmeans_init:
            embed = l2norm(uniform_init(num_codebooks, codebook_size, dim))
        else:
            # placeholder, overwritten by init_embed_ on the first forward pass
            embed = torch.zeros(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.sample_codebook_temp = sample_codebook_temp

        # distributed variants are only selected when requested; otherwise local sampling / no-op reductions
        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        # `initted` is a buffer, so the "already initialised" flag is part of the state_dict
        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer('embed', embed)

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialise the codes by cosine-similarity k-means on ``data``; does nothing once initialised."""
        if self.initted:
            return

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            use_cosine_sim = True,
            sample_fn = self.sample_fn,
            all_reduce_fn = self.kmeans_all_reduce_fn
        )

        self.embed.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, batch_samples, batch_mask):
        """Overwrite the codes selected by ``batch_mask`` with L2-normalised vectors sampled from ``batch_samples``.

        Both arguments are iterated along their first dimension, one slice per codebook."""
        batch_samples = l2norm(batch_samples)

        for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            if not torch.any(mask):
                continue

            # draw as many replacement vectors as there are masked codes in this codebook
            sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
            self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...')

    def expire_codes_(self, batch_samples):
        """Replace codes whose EMA cluster size is below ``threshold_ema_dead_code`` using ``batch_samples``."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        # flatten everything between the codebook and feature dimensions
        batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(batch_samples, batch_mask = expired_codes)

    @autocast(enabled = False)
    def forward(self, x):
        """
        Quantise ``x`` against the codebook.

        Parameters
        ----------
        x: tensor
            Input vectors, features on the last dimension. Inputs with fewer
            than 4 dimensions get a leading codebook dimension of size 1,
            which is removed again from the outputs.

        Returns
        -------
        quantize: tensor
            For each input vector, the selected row of ``self.embed``.
        embed_ind: tensor
            Index of the selected code for each input vector.
        """
        needs_codebook_dim = x.ndim < 4

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, '... -> 1 ...')

        shape, dtype = x.shape, x.dtype

        flatten = rearrange(x, 'h ... d -> h (...) d')
        flatten = l2norm(flatten)

        # k-means initialisation on the first batch (no-op afterwards)
        self.init_embed_(flatten)

        embed = self.embed if not self.learnable_codebook else self.embed.detach()
        embed = l2norm(embed)

        # cosine similarity between every input vector and every code
        dist = einsum('h n d, h c d -> h n c', flatten, embed)
        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])

        # the lookup uses self.embed, not the normalised local copy `embed`
        quantize = batched_embedding(embed_ind, self.embed)

        if self.training:
            # number of vectors assigned to each code in this batch
            bins = embed_onehot.sum(dim = 1)
            self.all_reduce_fn(bins)

            ema_inplace(self.cluster_size, bins, self.decay)

            # unused codes: avoid dividing by zero and keep their current value below
            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
            self.all_reduce_fn(embed_sum)

            embed_normalized = embed_sum / rearrange(bins, '... -> ... 1')
            embed_normalized = l2norm(embed_normalized)

            embed_normalized = torch.where(
                rearrange(zero_mask, '... -> ... 1'),
                embed,
                embed_normalized
            )

            ema_inplace(self.embed, embed_normalized, self.decay)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

        return quantize, embed_ind


class VectorQuantize(nn.Module):
    """
    Vector-quantisation layer built on ``CosineSimCodebook``.

    The input is optionally projected to the codebook dimension, quantised,
    and projected back. ``forward`` also returns a per-sample quantisation /
    commitment loss and the perplexity of the code usage in the batch.
    """
    def __init__(
        self,
        dim,
        codebook_size,
        codebook_dim,
        heads = 1,
        separate_codebook_per_head = False,
        decay = 0.8,
        eps = 1e-5,
        kmeans_init = True,
        kmeans_iters = 10,
        sync_kmeans = True,
        threshold_ema_dead_code = 3,
        commitment_weight = 1.,
        orthogonal_reg_weight = 0.,
        orthogonal_reg_active_codes_only = False,
        orthogonal_reg_max_codes = None,
        sample_codebook_temp = 0.,
        sync_codebook = False
    ):
        """
        Parameters
        ----------
        dim: int
            Feature dimension of the input and of the projected output.
        codebook_size: int
            Number of codes.
        codebook_dim: int or None
            Dimension of each code; ``None`` falls back to ``dim``. Linear
            projections are inserted when ``codebook_dim * heads != dim``.
        heads: int
            Number of heads the projected input is split into.
        separate_codebook_per_head: bool
            Use one codebook per head instead of a shared one.
        commitment_weight: float
            Weight of the commitment term added to the loss (skipped when not > 0).
        orthogonal_reg_weight, orthogonal_reg_active_codes_only, orthogonal_reg_max_codes:
            Stored on the instance; ``orthogonal_reg_weight > 0`` also makes the
            codebook a learnable parameter. No orthogonal loss is computed in
            ``forward``.
        sync_codebook: bool
            Passed to the codebook as ``use_ddp``.
        decay, eps, kmeans_init, kmeans_iters, sync_kmeans, threshold_ema_dead_code, sample_codebook_temp:
            Forwarded unchanged to ``CosineSimCodebook``.
        """
        super().__init__()
        self.heads = heads
        self.separate_codebook_per_head = separate_codebook_per_head

        codebook_dim = default(codebook_dim, dim)
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.eps = eps
        self.commitment_weight = commitment_weight

        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        codebook_class = CosineSimCodebook

        self._codebook = codebook_class(
            dim = codebook_dim,
            num_codebooks = heads if separate_codebook_per_head else 1,
            codebook_size = codebook_size,
            kmeans_init = kmeans_init,
            kmeans_iters = kmeans_iters,
            sync_kmeans = sync_kmeans,
            decay = decay,
            eps = eps,
            threshold_ema_dead_code = threshold_ema_dead_code,
            use_ddp = sync_codebook,
            learnable_codebook = has_codebook_orthogonal_loss,
            sample_codebook_temp = sample_codebook_temp
        )

        self.codebook_size = codebook_size

    @property
    def codebook(self):
        """The code vectors; the leading codebook dimension is dropped when a single codebook is shared."""
        codebook = self._codebook.embed
        if self.separate_codebook_per_head:
            return codebook

        return rearrange(codebook, '1 ... -> ...')

    def forward(self, x):
        """
        Quantise ``x`` and compute the associated losses and statistics.

        Parameters
        ----------
        x: tensor
            The einops patterns below treat it as 'b n dim'.

        Returns
        -------
        dict with the keys 'quantize_projected_in', 'quantize_latent',
        'quantize_projected_out', 'loss_vq_commit_pbe', 'perplexity' and
        'encoding_indices' (see the comments on the return statement).
        """
        heads, is_multiheaded = self.heads, self.heads > 1

        x = self.project_in(x)

        if is_multiheaded:
            ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
            x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            # straight-through estimator: forward value is the code, gradient flows to x
            quantize = x + (quantize - x).detach()

        # per-element squared error between the quantised output and the (detached) input
        detached_inputs = x.detach()
        loss = F.mse_loss(quantize, detached_inputs, reduction='none')
        # NOTE(review): averaging over dims (1, 2) yields one value per sample only
        # for a 3-D x, i.e. heads == 1 — confirm before using several heads
        loss_pbe = torch.mean(loss, dim=(1,2)) # (batch_size)

        if self.commitment_weight > 0:
            # commitment term: pulls the encoder output towards the (detached) code
            detached_quantize = quantize.detach()
            commit_loss = F.mse_loss(detached_quantize, x, reduction='none')

            loss_pbe = loss_pbe + torch.mean(commit_loss * self.commitment_weight, dim=(1,2)) # (batch_size)

        if is_multiheaded:
            # merge the heads back into the feature dimension
            if self.separate_codebook_per_head:
                quantize = rearrange(quantize, 'h b n d -> b n (h d)', h = heads)
                embed_ind = rearrange(embed_ind, 'h b n -> b n h', h = heads)
            else:
                quantize = rearrange(quantize, '1 (b h) n d -> b n (h d)', h = heads)
                embed_ind = rearrange(embed_ind, '1 (b h) n -> b n h', h = heads)

        quantize_latent = quantize.detach().clone()
        quantize = self.project_out(quantize)

        # perplexity = exp(entropy) of the code-usage frequencies over the whole batch
        avg_probs = torch.mean(F.one_hot(embed_ind, self.codebook_size).type(torch.float32).view((-1, self.codebook_size)), 0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return {
            'quantize_projected_in': x, # (batch_size, l_r, codebook_dim)
            'quantize_latent': quantize_latent, # (batch_size, l_r, codebook_dim)
            'quantize_projected_out': quantize, # (batch_size, l_r, dim)
            'loss_vq_commit_pbe': loss_pbe, # (batch_size)
            'perplexity': perplexity, # scalar, computed over the whole batch
            'encoding_indices': embed_ind # (batch_size, l_r)
        }



In [None]:
from typing import Tuple
import numpy as np
import math
import random
import pandas as pd
from pandas.api.types import CategoricalDtype

from Bio import SeqIO
import torch

# Default residue alphabet for the one-hot encoder: 21 symbols, the last one being '-'
# (presumably the alignment gap symbol — TODO confirm). The position in this list is the one-hot index.
alphabet = ['A', 'C', 'D', 'E', 'F', 'G','H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '-']

def data_loader_masking_bert_onehot_fasta(seq: list, batch_size: int, perc_masked_residues: float, 
                                          is_masking: bool) -> torch.utils.data.DataLoader:
    ''' 
    Generate a Torch dataloader iterator from a list of sequences.

    Parameters
    ----------
    seq: list
        Sequences (str) to encode; every sequence must have the same length
        for the default batch collation to stack them.
    batch_size: int
    perc_masked_residues: float
        Ratio of residues to apply the BERT masking on (between 0 and 1).
    is_masking: bool
        True for training: masking is applied and the sequences are visited
        in a new random order at every pass. False for evaluation.

    Returns
    -------
    torch.utils.data.DataLoader
        Yields batches of (one-hot input, masked one-hot input).

    '''
    # DataLoader raises a ValueError when `shuffle=True` is combined with an
    # IterableDataset, so the shuffling is delegated to the dataset itself.
    iterator = IterableMaskingBertOnehotDatasetFasta(seq, perc_masked_residues=perc_masked_residues,
                                                     is_masking=is_masking, shuffle=is_masking)
    loader = torch.utils.data.DataLoader(iterator, batch_size=batch_size, num_workers=0)
    return loader 


class IterableMaskingBertOnehotDatasetFasta(torch.utils.data.IterableDataset):
    '''
    BERT-style masking onehot generator over a collection of sequences.

    Parameters
    ----------
    seq: iterable of str
        Sequences to encode.
    perc_masked_residues: float
        Ratio of residues to apply the BERT masking on (between 0 and 1).
    is_masking: bool
        Whether the masking is applied.
    shuffle: bool
        If True, the sequences are visited in a fresh random order on every
        iteration (the stored collection itself is left untouched).
    '''
    def __init__(self, seq, perc_masked_residues=0.0, is_masking=False, shuffle=False):
        self.seq = seq
        self.perc_masked_residues = perc_masked_residues
        self.is_masking = is_masking
        self.shuffle = shuffle

    def __iter__(self):
        '''Yield one (onehot, masked onehot) pair of tensors per sequence.'''
        sequences = self.seq
        if self.shuffle:
            # shuffle a copy so that self.seq keeps its original order
            sequences = list(self.seq)
            random.shuffle(sequences)

        for sequence in sequences:
            yield torch_masking_BERT_onehot(sequence, perc_masked_residues=self.perc_masked_residues,is_masking=self.is_masking)


        
def torch_masking_BERT_onehot(seq: str, perc_masked_residues: float=0.0, 
                              is_masking: bool=False, alphabet: list=alphabet) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    BERT-style masking on a one-hot encoding input. When a residue is masked, it is replaced 
    by the dummy vector [1/A,...,1/A] of size A = len(alphabet). Of the positions selected by
    perc_masked_residues, 80% are masked, 10% are replaced by a residue drawn at random from
    the alphabet, and the rest are left as they are.

    Parameters
    ----------
    seq: str
    perc_masked_residues: float
        Ratio of residues to apply the BERT masking on (between 0 and 1).
    is_masking: bool
        False for evaluation.
    alphabet: list 
        List of string of the alphabet of residues used in the one hot encoder.
        Characters of seq that are not in the alphabet are encoded as an all-zero row.

    Returns
    -------
    onehot_seq: tensor
        One hot encoded input, shape (len(seq), len(alphabet)).
    m_tf_onehot_seq: tensor 
        BERT masked one hot encoded input (a copy of onehot_seq when is_masking is False).

    Raises
    ------
    NotImplementedError
        If is_masking is True and perc_masked_residues is outside [0, 1].

    '''
    # The caller-supplied alphabet is used as given. It used to be overwritten
    # here by a hard-coded list, which silently ignored the argument.
    residue_dtype = CategoricalDtype(categories=alphabet)

    # One Hot Encoding
    onehot_seq = np.array(pd.get_dummies(pd.Series(list(seq)).astype(residue_dtype))).astype(float)
    onehot_seq = torch.tensor(onehot_seq, dtype=torch.float32)
    ln_seq = len(onehot_seq)

    m_tf_onehot_seq = onehot_seq.clone().detach()

    if is_masking:
        if not 0 <= perc_masked_residues <= 1:
            raise NotImplementedError('Masking percentage should be between 0 and 1.')

        # the onehot vector of the masked residue
        len_alphabet = len(alphabet)
        masked_letter = [1/len_alphabet]*len_alphabet

        # MASKING
        nb_masking = math.floor(ln_seq * perc_masked_residues)
        nb_to_mask = math.floor(nb_masking*0.8) #80% replace with mask token
        nb_to_replace = math.floor(nb_masking*0.1) #10% replace with random residue

        if nb_to_mask != 0:
            # distinct random positions: the first nb_to_mask get the mask vector,
            # the following nb_to_replace get a random residue
            shuffled_positions = random.sample(range(ln_seq),ln_seq)
            rd_ids = torch.tensor(shuffled_positions[:nb_to_mask+nb_to_replace], dtype=torch.int64)

            rd_alphabet_selection_to_replace = random.choices(alphabet, k=nb_to_replace)
            dummies_to_replace = np.array(
                pd.get_dummies(pd.Series(rd_alphabet_selection_to_replace).astype(residue_dtype))
            ).astype(float)

            updates = np.array([masked_letter]*nb_to_mask)
            updates = torch.Tensor(np.concatenate((updates,dummies_to_replace)))

            m_tf_onehot_seq[rd_ids] = updates

    return onehot_seq, m_tf_onehot_seq             


In [None]:
from einops.layers.torch import Rearrange

class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding added to a batch of embeddings.

    Parameters
    ----------
    d_embedding: int
        Embedding dimension (even or odd).
    max_len: int
        Number of positions for which the encoding is precomputed.
    """
    def __init__(self, d_embedding, max_len):
        super(PositionalEncoding, self).__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_embedding, 2) * (-math.log(10000.0) / d_embedding))
        angles = position * div_term # (max_len, ceil(d_embedding / 2))
        pe = torch.zeros(max_len, d_embedding)

        # apply sin to even indices in the array; 2i
        pe[:, 0::2] = torch.sin(angles)

        # apply cos to odd indices in the array; 2i+1
        # (an odd d_embedding has one fewer odd column than even columns,
        # so the extra angle column is dropped)
        pe[:, 1::2] = torch.cos(angles[:, :d_embedding // 2])
        # buffer: follows the module across devices but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, input_seq_len, d_embedding],
               with input_seq_len <= max_len

        Returns:
            x with the encoding of the first input_seq_len positions added.
        """
        x = x + self.pe[:x.size(1)]
        return x

class MHAEncoderBlock(nn.Module):
  """
  One transformer-encoder style block: multi-head self-attention followed by a
  two-layer perceptron, each wrapped in a residual connection and a LayerNorm
  applied after the addition.

  Args:
    d_embedding: embedding dimension of the input and output
    num_heads: number of attention heads
    d_ff: hidden dimension of the perceptron
    dropout: dropout probability (residual branches and inside the perceptron)
  """
  def __init__(self, d_embedding, num_heads, d_ff, dropout):
    super().__init__()

    self.self_MHA = torch.nn.MultiheadAttention(d_embedding, num_heads, batch_first=True)

    feed_forward_layers = [
        nn.Linear(d_embedding, d_ff),
        nn.Dropout(dropout),
        nn.ReLU(inplace=True),
        nn.Linear(d_ff, d_embedding),
    ]
    self.MLperceptron = nn.Sequential(*feed_forward_layers)

    self.layernorm1 = nn.LayerNorm(d_embedding, eps=1e-6)
    self.layernorm2 = nn.LayerNorm(d_embedding, eps=1e-6)

    self.dropout = nn.Dropout(dropout)

  def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
      x: Tensor, shape [batch_size, input_seq_len, d_embedding]

    Returns:
      the transformed tensor (same shape as x) and the attention weights
      returned by the attention layer
    """
    # Self-attention sub-layer: residual connection, then normalisation
    attended, attn_output_weights = self.self_MHA(x, x, x)  # (batch_size, input_seq_len, d_embedding)
    after_attention = self.layernorm1(x + self.dropout(attended))

    # Perceptron sub-layer: residual connection, then normalisation
    transformed = self.MLperceptron(after_attention)
    block_output = self.layernorm2(after_attention + self.dropout(transformed))

    return block_output, attn_output_weights
    

class Encoder(nn.Module):
  """
  Sequence encoder: strided Conv1d embedding of the one-hot input, positional
  encoding, then a stack of MHAEncoderBlock layers.

  Args:
    d_embedding: embedding dimension produced by the convolution
    kernel, stride: kernel size and stride of the convolution
    num_heads, num_mha_layers, d_ff: configuration of the attention blocks
    length_seq: length of the input sequences
    alphabet_size: size of the one-hot alphabet (input channels)
    dropout: dropout probability
  """
  def __init__(self, d_embedding, kernel, stride, num_heads, num_mha_layers, d_ff,
              length_seq, alphabet_size, dropout=0):
    super().__init__()

    # CNN1d embedding: reduced length and padding are derived from the sequence length
    reduced_len, conv_padding = find_optimal_cnn1d_padding(L_in=length_seq, K=kernel, S=stride)
    self.l_red, self.padding = reduced_len, conv_padding
    strided_conv = nn.Conv1d(alphabet_size, d_embedding, kernel_size=kernel,
                             stride=stride, padding=conv_padding)
    self.cnn_embedding = nn.Sequential(
        Rearrange('b l r -> b r l'),  # Conv1d expects channels before length
        strided_conv,
        Rearrange('b r l -> b l r'))

    # Positional encoding
    self.en_pos_encoding = PositionalEncoding(d_embedding, max_len=reduced_len)
    self.en_dropout = nn.Dropout(dropout)

    # MHA blocks
    attention_blocks = [MHAEncoderBlock(d_embedding, num_heads, d_ff, dropout)
                        for _ in range(num_mha_layers)]
    self.en_MHA_blocks = nn.ModuleList(attention_blocks)

  def forward(self, x) -> torch.Tensor: 
    """
    Args:
      x: Tensor, shape [batch_size, input_seq_len, alphabet_size]

    Returns:
      Tensor, shape [batch_size, l_red, d_embedding]
    """
    # CNN1d Embedding
    h = self.cnn_embedding(x) # (batch_size, l_red, d_embedding)

    # Positional encoding, then dropout
    h = self.en_dropout(self.en_pos_encoding(h))

    # MHA blocks (the attention weights are not used)
    for block in self.en_MHA_blocks:
      h, _ = block(h) # (batch_size, l_red, d_embedding)

    return h


class Decoder(nn.Module):
  """
  Sequence decoder: positional encoding, a stack of MHAEncoderBlock layers,
  then a transposed Conv1d back to the sequence length, followed by a softmax
  over the alphabet.

  Args:
    d_embedding: embedding dimension of the latent input
    kernel, stride: kernel size and stride of the transposed convolution
    num_heads, num_mha_layers, d_ff: configuration of the attention blocks
    length_seq: length of the reconstructed sequences
    alphabet_size: size of the one-hot alphabet (output channels)
    dropout: dropout probability
  """
  def __init__(self, d_embedding, kernel, stride, num_heads, num_mha_layers, d_ff,
                  length_seq, alphabet_size, dropout=0):
    super().__init__()

    # Positional encoding
    reduced_len, conv_padding = find_optimal_cnn1d_padding(L_in=length_seq, K=kernel, S=stride)
    self.l_red, self.padding = reduced_len, conv_padding
    self.de_pos_encoding = PositionalEncoding(d_embedding, max_len=reduced_len)
    self.de_dropout = nn.Dropout(dropout)

    # MHA blocks
    attention_blocks = [MHAEncoderBlock(d_embedding, num_heads, d_ff, dropout)
                        for _ in range(num_mha_layers)]
    self.de_MHA_blocks = nn.ModuleList(attention_blocks)

    # Dense reconstruction (these two layers are created but not called in forward)
    self.dense_to_alphabet = nn.Linear(d_embedding, alphabet_size)
    self.dense_reconstruction = nn.Linear(alphabet_size*reduced_len, length_seq*alphabet_size)

    # CNN1d reconstruction: the output padding makes up for the positions the
    # transposed convolution alone would miss to reach length_seq
    self.out_pad = find_out_padding_cnn1d_transpose(L_obj=length_seq, L_in=reduced_len, K=kernel,
                                                    S=stride, P=conv_padding)
    transposed_conv = nn.ConvTranspose1d(d_embedding, alphabet_size, kernel_size=kernel, stride=stride,
                                         padding=conv_padding, output_padding=self.out_pad)
    self.cnn_reconstruction = nn.Sequential(
        Rearrange('b l r -> b r l'),  # ConvTranspose1d expects channels before length
        transposed_conv,
        Rearrange('b r l -> b l r'))

  def forward(self, q) -> torch.Tensor:
    """
    Args:
      q: Tensor, shape [batch_size, l_red, d_embedding]

    Returns:
      Tensor, shape [batch_size, input_seq_len, alphabet_size], normalised by
      a softmax over the last dimension
    """
    # Positional encoding, then dropout
    z = self.de_dropout(self.de_pos_encoding(q))

    # MHA blocks (the attention weights are not used)
    for block in self.de_MHA_blocks:
      z, _ = block(z) # (batch_size, l_red, d_embedding)

    # CNN reconstruction 
    logits = self.cnn_reconstruction(z) # (batch_size, input_seq_len, alphabet_size)
    return F.softmax(logits, dim=-1)


class AbNatiV_Model(pl.LightningModule):
  """
  VQ-VAE Lightning module: Encoder -> VectorQuantize -> Decoder, trained on
  pairs of (one-hot input, masked one-hot input).

  Args:
    hparams: mapping read with the keys 'd_embedding', 'kernel', 'stride',
      'num_heads', 'num_mha_layers', 'd_ff', 'length_seq', 'alphabet_size',
      'drop', 'num_embeddings', 'embedding_dim_code_book', 'decay',
      'commitment_cost' and 'learning_rate'.
  """
  def __init__(self, hparams: dict):
    super(AbNatiV_Model, self).__init__()

    self.encoder = Encoder(hparams['d_embedding'], hparams['kernel'], hparams['stride'], hparams['num_heads'], 
                            hparams['num_mha_layers'], hparams['d_ff'], hparams['length_seq'], 
                            hparams['alphabet_size'], dropout=hparams['drop'])

    self.decoder = Decoder(hparams['d_embedding'], hparams['kernel'], hparams['stride'], hparams['num_heads'], 
                            hparams['num_mha_layers'], hparams['d_ff'], hparams['length_seq'], 
                            hparams['alphabet_size'], dropout=hparams['drop'])

    self.vqvae = VectorQuantize(
            dim=hparams['d_embedding'],
            codebook_size=hparams['num_embeddings'],
            codebook_dim=hparams['embedding_dim_code_book'],
            decay=hparams['decay'],
            kmeans_init=True,
            commitment_weight=hparams['commitment_cost']
            )

    self.learning_rate = hparams['learning_rate']
    # per-batch validation summaries, filled by validation_step and consumed
    # (then cleared) by on_validation_epoch_end
    self._validation_step_outputs = []
    self.save_hyperparameters()


  def forward(self, data) -> dict:
    """
    Args:
      data: pair (inputs, masked inputs), each of shape
        (batch_size, input_seq_len, alphabet_size)

    Returns:
      dict with the reconstruction, the reconstruction errors at several
      granularities, the total per-sample loss and every entry returned by
      the vector quantiser.
    """
    inputs = data[0]
    m_inputs = data[1]

    # the masked version is encoded; the error is measured against the unmasked input
    x = self.encoder(m_inputs)
    vq_outputs = self.vqvae(x)
    x_recon = self.decoder(vq_outputs['quantize_projected_out'])

    # Loss computing 
    recon_error_pres_pposi = F.mse_loss(x_recon, inputs, reduction='none')
    recon_error_pposi = torch.mean(recon_error_pres_pposi, dim=-1)
    recon_error_pbe = torch.mean(recon_error_pposi, dim=1)

    loss_pbe = torch.add(recon_error_pbe, vq_outputs['loss_vq_commit_pbe'])

    return {
        'inputs': inputs, # (batch_size, input_seq_len, alphabet_size)
        'x_recon': x_recon, # (batch_size, input_seq_len, alphabet_size)
        'recon_error_pres_pposi': recon_error_pres_pposi, # (batch_size, input_seq_len, alphabet_size)
        'recon_error_pposi': recon_error_pposi, # (batch_size, input_seq_len)
        'recon_error_pbe': recon_error_pbe, # (batch_size)
        'loss_pbe': loss_pbe, # (batch_size)
        **vq_outputs
    }

  def configure_optimizers(self):
    """AdamW over the encoder, decoder and quantiser parameters."""
    optim_groups = list(self.encoder.parameters()) + \
                    list(self.decoder.parameters()) + \
                    list(self.vqvae.parameters()) 

    return torch.optim.AdamW(optim_groups, lr=self.learning_rate)

  def training_step(self, batch, batch_idx) -> torch.Tensor:
    """Run one batch, log the training metrics and return the mean total loss."""
    vqvae_output = self(batch)

    loss_vqvae = torch.mean(vqvae_output['loss_pbe'])
    self.log("train_loss_vqvae", loss_vqvae, on_step=True, prog_bar=True, logger=True)

    loss_vq_commit = torch.mean(vqvae_output['loss_vq_commit_pbe'])
    self.log("train_loss_vq_commit", loss_vq_commit, on_step=True, prog_bar=True, logger=True)

    nmse_accuracy = torch.mean(vqvae_output['recon_error_pbe'])
    self.log("train_loss_nmse_recons", nmse_accuracy, on_step=True, prog_bar=True, logger=True)

    perplexity = vqvae_output['perplexity']
    self.log("train_perplexity", perplexity, on_step=True, prog_bar=True, logger=True)

    return loss_vqvae

  @staticmethod
  def _summarise_validation_output(step_output: dict) -> dict:
    """Reduce one validation_step output to the three detached scalars aggregated per epoch."""
    model_output = step_output['model_output']
    return {
        'val_loss': torch.as_tensor(step_output['val_loss']).detach(),
        'recon_error': torch.mean(model_output['recon_error_pbe']).detach(),
        'perplexity': torch.as_tensor(model_output['perplexity']).detach(),
    }

  def validation_step(self, batch, batch_idx) -> dict:
    """Run one validation batch; returns its mean loss and the full model output."""
    model_output = self(batch)
    step_output = {'val_loss': torch.mean(model_output['loss_pbe']), 'model_output': model_output}
    # only scalars are kept for the epoch-level averages, not the per-batch tensors
    self._validation_step_outputs.append(self._summarise_validation_output(step_output))
    return step_output

  def on_validation_epoch_end(self, outputs=None) -> dict:
    """
    Log the validation metrics averaged over the batches of the epoch.

    Lightning calls this hook without arguments; the summaries stored by
    validation_step are then used. A list of validation_step outputs can
    still be passed explicitly through `outputs`.

    Returns:
      dict with 'val_loss', 'val_nmse_accuracy' and 'val_perplexity', or an
      empty dict when no validation batch was seen.
    """
    if outputs is None:
      summaries = self._validation_step_outputs
    else:
      summaries = [self._summarise_validation_output(out) for out in outputs]

    # start the next validation run from a clean slate
    self._validation_step_outputs = []

    if not summaries:
      return {}

    total_val_loss = torch.stack([s['val_loss'] for s in summaries]).mean()
    self.log('val_loss', total_val_loss, on_epoch=True, logger=True)

    total_val_accuracy = torch.stack([s['recon_error'] for s in summaries]).mean()
    self.log('val_nmse_accuracy', total_val_accuracy, on_epoch=True, logger=True)

    total_val_perplexity = torch.stack([s['perplexity'] for s in summaries]).mean()
    self.log('val_perplexity', total_val_perplexity, on_epoch=True, logger=True)

    return {'val_loss': total_val_loss, 'val_nmse_accuracy': total_val_accuracy, 'val_perplexity': total_val_perplexity}




In [None]:
class HParams(object):
    """Default hyper-parameters, exposed as instance attributes."""
    def __init__(self):
        defaults = dict(
            alphabet_size=21,
            batch_size=64,
            commitment_cost=2, # In the loss function
            d_embedding=768, # assert d_embedding % num_heads == 0
            d_ff=128, # Hidden layer dimension of point wise feed forward network
            decay=0.90, # This is only used for EMA updates.
            drop=0,
            embedding_dim_code_book=64,
            kernel=8,
            learning_rate=4.0e-05,
            length_seq=125,
            limit_val_batches=400,
            max_epochs=15,
            num_embeddings=512,
            num_heads=8,
            num_mha_layers=1,
            perc_masked_residues=0.15,
            run_name='abnativ_v1',
            stride=8,
        )
        # set on the instance (not the class) so vars(self) lists every setting, in this order
        for name, value in defaults.items():
            setattr(self, name, value)


In [None]:
from itertools import chain

class VDJ_dataset(torch.utils.data.Dataset):
    """One-hot encoded sequence dataset with optional per-sequence labels.

    Sequences are right-padded with '-' to a common length and one-hot encoded
    over the set of symbols found in the (padded) data.

    Args:
        vdj: iterable of sequence strings.
        labels: per-sequence labels, or None for an unlabeled dataset.
        task: 'binary.classification', 'multiclass.classification', 'gex' or
            'regression'; decides how `labels` are converted.
        max_len: padded length; defaults to the longest sequence in `vdj`.

    Raises:
        ValueError: if labels are given together with an unknown `task`.
    """

    def __init__(self, vdj, labels, task, max_len = None):
        self.task = task
        self.max_len = self.get_max_len(vdj) if max_len is None else max_len

        vdj = self.add_mask(vdj)
        self.vdj = self.encode_sequences(vdj)

        if labels is None:
            self.label_to_ix = None
            self.ix_to_label = None
            self.labels = None
            self.n_classes = None

        elif self.task == 'binary.classification' or self.task == 'multiclass.classification':
            # Classes are numbered in order of first appearance, so the mapping
            # is the same on every run (iterating a set of strings is not).
            classes = list(dict.fromkeys(labels))
            self.label_to_ix = {label: ix for ix, label in enumerate(classes)}
            self.ix_to_label = dict(enumerate(classes))
            self.n_classes = len(classes)
            print('Found {} classes for {}'.format(self.n_classes, self.task))

            self.labels = [torch.tensor(self.label_to_ix[label]).int() for label in labels]

        elif self.task == 'gex':
            self.label_to_ix = None
            self.ix_to_label = None
            self.labels = np.array(labels)
            self.n_classes = self.labels.shape[-1]  # TODO: add gex normalization
            print('Found {} genes for {}'.format(self.n_classes, self.task))

            self.labels = [torch.Tensor(label) for label in self.labels]

        elif self.task == 'regression':
            self.label_to_ix = None
            self.ix_to_label = None
            self.n_classes = None  # no classes for a continuous target
            self.labels = np.array(labels)
            self.labels = self.z_score(self.min_max(self.labels))

        else:
            raise ValueError('Unknown task: {!r}'.format(self.task))

    def __len__(self):
        return len(self.vdj)

    def __getitem__(self, item):
        """Return (encoded sequence, label), or only the encoded sequence when unlabeled."""
        vdj = self.vdj[item]
        if self.labels is None:
            return vdj

        return (vdj, self.labels[item])

    def get_max_len(self, sequences):
        """Length of the longest sequence (0 for an empty collection)."""
        return max((len(seq) for seq in sequences), default=0)

    def add_mask(self, sequences, mask_token = '-'):
        """Return copies of `sequences` right-padded with `mask_token` up to `self.max_len`."""
        # a non-positive repeat count yields '', so longer sequences are returned unchanged
        return [sequence + mask_token * (self.max_len - len(sequence)) for sequence in sequences]

    def encode_sequences(self, sequences):
        """One-hot encode `sequences` and record the symbol/index lookup tables.

        Sets `seq_to_ids`, `ids_to_seq` and `n_residues`; returns a list with one
        float tensor of shape (len(sequence), n_residues) per sequence, padded
        along the first axis towards `self.max_len`.
        """
        # sorted symbols give a reproducible index assignment
        residues = sorted(set(chain.from_iterable(sequences)))
        self.seq_to_ids = {residue: ix for ix, residue in enumerate(residues)}
        self.ids_to_seq = dict(enumerate(residues))
        self.n_residues = len(residues)

        encoded = []
        for sequence in sequences:
            ids = torch.tensor([self.seq_to_ids[residue] for residue in sequence], dtype=torch.int64)
            one_hot = F.one_hot(ids, num_classes = self.n_residues).float()
            # no-op for sequences that add_mask already brought to max_len
            encoded.append(F.pad(one_hot, (0, 0, 0, self.max_len - one_hot.size(0))))

        return encoded

    def min_max(self, scores):
        """Rescale `scores` linearly to [0, 1]; constant input maps to all zeros."""
        scores = np.asarray(scores)
        lowest = scores.min()
        span = scores.max() - lowest
        if span == 0:
            return np.zeros(scores.shape, dtype=float)
        return (scores - lowest) / span

    def z_score(self, scores):
        """Standardise `scores` to zero mean and unit variance; returns a Python list.

        Constant input (zero standard deviation) maps to all zeros.
        """
        scores = np.array(scores)
        std = scores.std()
        if std == 0:
            return np.zeros(scores.shape, dtype=float).tolist()
        return ((scores - scores.mean()) / std).tolist()

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
# Load the VDJ table and keep one row per distinct 'pasted_cdr3' string.
VDJ = pd.read_csv('./data/VDJ_IgG_all.csv')
VDJ = VDJ.drop_duplicates(['pasted_cdr3'])
seq = VDJ['pasted_cdr3'].tolist()
label = VDJ['antigen'].tolist()
# VDJ_dataset pads and one-hot encodes the sequences and turns the antigens into class indices.
dataset = VDJ_dataset(seq, label, task = 'multiclass.classification')
# NOTE(review): both loaders wrap the same dataset, so "validation" is measured on the training data.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)
# Padded copies of the raw strings, plus the dimensions derived by the dataset.
seq = dataset.add_mask(seq)
seq_len = dataset.max_len
vocab = dataset.n_residues

In [None]:
from Bio import SeqIO
import os 

# Read every FASTA file in ./data/fasta, drop records whose sequence contains a
# digit, and label each kept record by the file it came from.
fasta_dir = './data/fasta'
files = sorted(os.listdir(fasta_dir))  # os.listdir order is arbitrary; sort for reproducibility
sequences = []
labels = []
numbers = ['1','2','3','4','5','6','7','8','9','0'] 
for input_file in files:
    # all records of one file share a label, decided by the file name
    file_label = 'TNFR2' if input_file.startswith('antig1') else 'OVA'
    with open(os.path.join(fasta_dir, input_file)) as handle:
        for fasta in SeqIO.parse(handle, 'fasta'):
            sequence = str(fasta.seq)
            if any(ext in sequence for ext in numbers):
                continue
            # sequence and label are appended together so the two lists stay aligned
            sequences.append(sequence)
            labels.append(file_label)

# Keep the first occurrence of every distinct sequence (np.unique returns them sorted) ...
sequences = np.array(sequences)
labels = np.array(labels)
_, ids = np.unique(sequences, return_index =True)
sequences = sequences[ids]
labels = labels[ids]
# ... and discard sequences longer than 300 residues.
lens = np.array([len(s) for s in sequences])
ids = np.where(lens <= 300)
sequences = sequences[ids].tolist()
labels = labels[ids].tolist()

In [None]:
def select_unique(df, scores, labels):
    """Collapse rows that share a 'sequence' value into a single row.

    The `scores` columns are averaged over the duplicates; the `labels`
    columns are taken from the first row in which each sequence occurs.
    Returns a new frame with 'sequence', the score columns and the label columns.
    """
    first_rows = df.drop_duplicates(['sequence'])
    label_part = first_rows[labels + ['sequence']]
    score_part = df.groupby('sequence')[scores].mean().reset_index()

    return score_part.merge(label_part, on='sequence')

# Load the pooled score table and reduce it to one row per distinct sequence.
# NOTE(review): absolute, user-specific path — this cell only runs where that file exists.
df = pd.read_csv('/cluster/home/tcotet/fitness_landscapes/data/all_scores_pooled.csv')
# numeric columns, averaged over duplicate sequences by select_unique
scores = ['interface_score', 'total_score', 'catalytic_score', 'interface_potential', 'total_potential', 'catalytic_potential', 'generation', 'mutations']
# annotation columns, taken from the first occurrence of each sequence (`labels` is re-bound below)
labels = ['score_taken_from', 'design_method', 'cat_resn', 'cat_resi', 'parent_index']
df = df[df['sequence'].notnull()]
df = select_unique(df, scores, labels)
# From here on `sequences` / `labels` hold this table's sequences and their 'cat_resi' values,
# replacing whatever earlier cells stored under those names.
sequences = df['sequence'].tolist()
labels = df['cat_resi'].tolist()

In [None]:
# Rebuild `dataset` from the current `sequences` / `labels`, treating the labels as classes.
dataset = VDJ_dataset(sequences, labels, task = 'multiclass.classification')
# Replace the raw strings by their '-'-padded versions (all of length dataset.max_len).
sequences = dataset.add_mask(sequences)


In [None]:
# Instantiate the defaults and expose them as a plain dict.
# `hparams` is the instance's own attribute dict, so edits to one are visible through the other.
hp = HParams()
hparams = hp.__dict__

In [None]:
# Build the loaders from the padded sequence strings; these replace any loaders created earlier.
# Masking is switched off for both (is_masking=False).
# NOTE(review): the training loader still receives perc_masked_residues — presumably ignored
# while is_masking=False; confirm in data_loader_masking_bert_onehot_fasta.
# NOTE(review): both loaders are built from the same `sequences`, so validation runs on training data.
train_loader = data_loader_masking_bert_onehot_fasta(sequences, hparams['batch_size'],
                        hparams['perc_masked_residues'], is_masking=False)
val_loader = data_loader_masking_bert_onehot_fasta(sequences, hparams['batch_size'],
                        perc_masked_residues=0, is_masking=False)


In [None]:
# The notebook runs on the CPU: the inference cells feed loader batches to the model
# without moving them to a device, so the model has to stay on the CPU as well.
device = torch.device('cpu')

In [None]:
# Build the model from the hyper-parameter dict and place it on the selected device.
model = AbNatiV_Model(hparams).to(device)


In [None]:
# Take the optimizer from the module itself so the manual training loop uses the module's own settings.
# NOTE(review): assumes configure_optimizers() returns a bare optimizer (the loop calls
# zero_grad()/step() on it directly), not a dict or tuple with schedulers — confirm.
optimizer = model.configure_optimizers()

In [None]:
import warnings
# Silence every warning category for the rest of the session.
# NOTE(review): this also hides warnings that point at real problems.
warnings.simplefilter("ignore")

In [None]:
# Manual training loop (no Lightning Trainer): one optimizer step per batch,
# with the mean loss of each epoch printed when the epoch finishes.
torch.set_grad_enabled(True)
model.train()
losses = []  # loss of every optimisation step, across all epochs
print('Started training')
for epoch in range(hparams['max_epochs']):
    epoch_losses = []  # kept per epoch so the printed value is not diluted by earlier epochs
    for idx, batch in enumerate(train_loader):
        # clear gradients left over from the previous step
        optimizer.zero_grad()

        batch[0] = batch[0].to(device)
        batch[1] = batch[1].to(device)

        # train step
        loss = model.training_step(batch, idx)

        # backward
        loss.backward()

        # update parameters
        optimizer.step()

        epoch_losses.append(loss.detach().item())
    losses.extend(epoch_losses)
    # an empty loader would make np.mean warn and return NaN; report NaN explicitly instead
    epoch_mean = np.mean(epoch_losses) if epoch_losses else float('nan')
    print(f"epoch: {epoch + 1} loss {epoch_mean}", flush=True)

In [None]:
# The inspection cells below read from `data_loader`; point it at the training loader.
data_loader = train_loader

In [None]:
# Pull a single batch from the loader for manual inspection.
batch = next(iter(data_loader))

In [None]:
# Encode the first element of the sample batch and quantise it, in eval mode and without gradients.
model.eval()
with torch.no_grad():
    encoded = model.encoder(batch[0])
    vq_outputs = model.vqvae(encoded)



In [None]:
# Symbols used to turn predicted indices back into text: 20 residue letters plus '-' (21 entries).
# NOTE(review): decoding below assumes the model's output channels follow this order — confirm
# against the one-hot encoding used by the data loader.
alphabet = ['A', 'C', 'D', 'E', 'F', 'G','H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', '-']



In [None]:
# Turn each entry of `outs` into a string: reshape it to (-1, max_len, 21), take the
# highest-scoring alphabet index per position (first row only), and drop the '-' symbols.
# NOTE(review): `outs` is only assigned in cells further down, so this fails on a fresh
# top-to-bottom run.
o = [out.reshape(-1, dataset.max_len, 21) for out in outs]
idd = [torch.max(torch.tensor(out), dim = -1)[1].detach().tolist()[0] for out in o]
recons = [''.join([alphabet[j] for j in idd[i]]).replace('-','') for i in range(len(idd))]


In [None]:
# Inspect the shape of the decoder output.
outs.shape

In [None]:
# NOTE(review): `out` is not read by the next cell — its comprehensions rebind `out` while
# iterating `outs`. Possibly `outs = batch[0]` was intended (to decode the input batch); confirm.
out = batch[0]

In [None]:
# Same decoding as for `recons` (argmax over the 21 channels, '-' removed), stored as `recons1`.
# NOTE(review): unless `outs` was reassigned in between, the result is identical to `recons`.
o = [out.reshape(-1, dataset.max_len, 21) for out in outs]
idd = [torch.max(torch.tensor(out), dim = -1)[1].detach().tolist()[0] for out in o]
recons1 = [''.join([alphabet[j] for j in idd[i]]).replace('-','') for i in range(len(idd))]


In [None]:
# Display the decoded strings.
recons1

In [None]:
# Decode the batch held in `data` into strings, one per entry: argmax over the 21
# channels at every position, then remove the '-' symbols.
# The comprehension iterates `data` itself; iterating `outs` while reshaping the whole of
# `data` produced the same (first) string once per entry of `outs`.
# NOTE(review): `data` is left behind by the cells that loop over the loader; on a fresh
# kernel it must be defined before this cell runs.
o = [out.reshape(-1, dataset.max_len, 21) for out in data]
idd = [torch.max(torch.as_tensor(out), dim = -1)[1].detach().tolist()[0] for out in o]
recons2 = [''.join([alphabet[j] for j in idd[i]]).replace('-','') for i in range(len(idd))]


In [None]:
# Display the decoded strings.
recons2

In [None]:
# Decode random inputs: draw 10 Gaussian "encoder outputs", quantise them and run the decoder.
# D is defined here because no earlier cell sets it; 16 is the value the
# feature-extraction cell uses for the same axis.
D = 16
model.eval()
with torch.no_grad():  # inference only, as in the other evaluation cells
    encoded = torch.randn((10, D,768))
    vq_outputs = model.vqvae(encoded)
    outs = model.decoder(vq_outputs['quantize_projected_out'])


In [None]:
import Levenshtein as lv
# Edit distance between the i-th strings of `recons` and `recons1`.
# NOTE(review): both lists are decoded from `outs` by identical code, so they only differ
# if `outs` was reassigned between the two decoding cells.
dists = [lv.distance(recons[i], recons1[i]) for i in range(len(recons))]

In [None]:

# Decode the most recent quantised representation held in `vq_outputs`.
outs = model.decoder(vq_outputs['quantize_projected_out'])

In [None]:
# Forward pass in three pieces: encode `data`, quantise, decode.
# NOTE(review): `data` is the loop variable of the feature-extraction cell further down,
# so this cell fails on a fresh top-to-bottom run.
# NOTE(review): not wrapped in torch.no_grad(), unlike the other evaluation cells.
model.eval()
encoded = model.encoder(data)
vq_outputs = model.vqvae(encoded)
outs = model.decoder(vq_outputs['quantize_projected_out'])


In [None]:
# Inspect the shape of the encoder output.
encoded.shape

In [None]:
# Inspect the shape of the encoder output (duplicate of the previous cell).
encoded.shape

In [None]:
# Push every batch of `data_loader` through the encoder and the VQ layer and collect
# the flattened activations of each stage, one row per sequence.
# NOTE(review): rows are stored in loader order; if the loader shuffles, they do not line
# up with the order of `sequences` / `labels` — confirm before colouring plots by label.
model.eval()
N = len(sequences)
D = 16  # every activation is reshaped to D*768 or D*64 values per sequence
encoded_features = np.zeros((N, D * 768))
proj_in_features = np.zeros((N, D * 64))
latent_features = np.zeros((N, D * 64))
proj_out_features = np.zeros((N, D * 768))
decoded = []
start_ind = 0
with torch.no_grad():
    # The second batch element is not needed here. It is deliberately not called `labels`:
    # that name holds the notebook-level label list, which the loop would overwrite.
    for (data, _batch_targets) in data_loader:
        encoded = model.encoder(data)
        encoded_feat = encoded.reshape(-1, D*768).cpu().detach().numpy()
        vq_outputs = model.vqvae(encoded)

        # the arrays have N rows, so N (not N+1) is the largest valid end index
        end_ind = min(start_ind + data.size(0), N)

        encoded_features[start_ind:end_ind] = encoded_feat
        proj_in_features[start_ind:end_ind] = vq_outputs['quantize_projected_in'].reshape(-1, D*64).cpu().detach().numpy()
        latent_features[start_ind:end_ind] = vq_outputs['quantize_latent'].reshape(-1, D*64).cpu().detach().numpy()
        proj_out_features[start_ind:end_ind] = vq_outputs['quantize_projected_out'].reshape(-1, D*768).cpu().detach().numpy()

        start_ind += data.size(0)
        


In [None]:
# Map every label to its integer class index using the dataset's lookup table.
# NOTE(review): `labels` is overwritten in place, so this cell is not idempotent — a second
# run would look up the indices themselves in label_to_ix.
labels = [dataset.label_to_ix[label] for label in labels]

In [None]:
# Project the extracted features to 2-D and colour every point by its class index.
pca = PCA(n_components = 2)
tsne = TSNE(n_components = 2)
umap = UMAP()
# PCA of the features after the output projection; swap in one of the alternatives to compare.
features = pca.fit_transform(proj_out_features)
#features = tsne.fit_transform(latent_features)
#features = umap.fit_transform(latent_features)
p = ['#F75C55','#F9DA7A','#ADA59E','#FE6E34','#2219D1','#F5D7BC','#D5DCF2','#590925','#1AFFD5','#007FFF','#7D83FF']
# wrap around the palette so that more classes than colours cannot raise an IndexError
cols = [p[lab % len(p)] for lab in labels]
# plot only the first 2 dimensions
fig = plt.figure(figsize=(8, 6))
plt.scatter(features[:, 0], features[:, 1], marker='o', c = cols,
        edgecolor='none', s = 1)
plt.xlabel('component 1')
plt.ylabel('component 2')
# The colours are categorical (one fixed colour per class), so a colorbar has no scale
# to display; a legend with one entry per class present in `labels` is shown instead.
present = sorted(set(labels))
handles = [plt.Line2D([], [], marker='o', linestyle='', color=p[lab % len(p)],
                      label=str(dataset.ix_to_label[lab])) for lab in present]
plt.legend(handles=handles, title='class', fontsize='small');

In [None]:
# Train with the Lightning Trainer, logging to MLflow and checkpointing every epoch.
# NOTE(review): if the manual training loop was run before, `model` is already trained
# when it reaches this cell.
pl.seed_everything(11)

# Logging
logger = MLFlowLogger(experiment_name='vqvae', run_name='test1')

# Checkpointing
ckpt_root_dir = os.path.join('checkpoints', 'test1')
ckpt_callback = ModelCheckpoint(ckpt_root_dir, save_top_k=-1) # to save every epoch

# The logger and the checkpoint callback only take effect when they are handed to the Trainer.
trainer = pl.Trainer(max_epochs=hparams['max_epochs'],
                     deterministic=True, accelerator='auto',
                     logger=logger, callbacks=[ckpt_callback])


# Training 
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
