In [1]:
import os
from typing import Any, List, Tuple

import deepchem as dc
import torch
import torch.nn as nn
from deepchem.models.torch_models.prot_bert import ProtBERT
from deepchem.models.torch_models.torch_model import TorchModel
from transformers import AutoTokenizer



No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
  from .autonotebook import tqdm as notebook_tqdm
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch.utils._import_utils'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'MXMNet' from 'deepchem.models.torch_models' (/home/shivasankaran/deepchem/deepchem/models/torch_models/__init__.py)
Skipped loading some Jax models, missing a dependency. No module named 'haiku'
Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'


## Load Dataset

### Tokenizer Wrapper for ProtBERT

In this section, we define a wrapper class `ProtBERTTokenizerWrapper` around the ProtBERT tokenizer. `HuggingFaceFeaturizer` does not currently let us pass tokenizer keyword arguments, so the wrapper fixes them once. It sets the maximum sequence length and, on every call, applies truncation, padding to that length, and conversion to PyTorch tensors.

We use the `AutoTokenizer.from_pretrained()` method to load the pre-trained ProtBERT tokenizer from the Hugging Face Hub. The wrapper then passes the additional parameters (truncation, padding, and maximum length) on every call.


In [3]:
# Temporary workaround: HuggingFaceFeaturizer does not forward tokenizer kwargs,
# so this callable bakes the tokenization options into every call instead.
class ProtBERTTokenizerWrapper:
    """Callable that tokenizes one input with fixed truncation/padding settings.

    Parameters
    ----------
    model_name : str
        Identifier handed to ``AutoTokenizer.from_pretrained``.
    max_len : int, default 128
        Length every encoding is truncated and padded to.
    """

    def __init__(self, model_name, max_len=128):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_len = max_len

    def __call__(self, sequence):
        """Return the tokenizer's output for ``sequence`` as PyTorch tensors.

        NOTE(review): ProtBERT's model card feeds residues separated by single
        spaces (e.g. "M K T ..."); confirm the input strings are formatted that way.
        """
        options = dict(
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt',
        )
        return self.tokenizer(sequence, **options)

# Every protein is tokenized to exactly 512 ProtBERT tokens (truncated or padded).
temp_tokenizer = ProtBERTTokenizerWrapper(model_name='Rostlab/prot_bert', max_len=512)
protbert_featurizer = dc.feat.HuggingFaceFeaturizer(temp_tokenizer)

In [4]:
# Location of the training CSV. Set PLASTIC_DEGRADING_TRAIN_CSV to point at your
# copy instead of editing this machine-specific default path.
TRAIN_CSV = os.environ.get(
    "PLASTIC_DEGRADING_TRAIN_CSV",
    "/home/shivasankaran/deepchem/plastic_degrading_train.csv",
)

# tasks=[]: no target columns are read; only the "protein" column is featurized.
loader = dc.data.CSVLoader(tasks=[],
                           feature_field="protein",
                           featurizer=protbert_featurizer)
dataset = loader.create_dataset(TRAIN_CSV)

## Models

### 1. Variational Autoencoder (VAE) for Protein Generation

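In this section, we define the decoder of a **Variational Autoencoder (VAE)**-style model, which generates protein token sequences from a latent vector. As implemented, the class contains no sampling (reparameterization) step and no KL-divergence term, so the latent vector is decoded deterministically. The architecture consists of several key components: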

- **Latent-to-Hidden Mapping:** A fully connected layer to map the latent vector `z` to a hidden space.
- **Positional Embeddings:** An embedding layer to add positional information to the input sequence, enabling the model to learn dependencies across sequence positions.
- **Transformer Decoder:** A Transformer decoder layer (with multiple layers) that processes the sequence and latent vector for sequence generation.
- **Output Linear Layer:** A fully connected layer that generates token probabilities for the vocabulary.

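The forward pass maps the latent vector `z` to a hidden vector, repeats it across every position, and adds positional embeddings. There is no separate encoder output, so this same sequence is used both as the decoder's target input and as its memory. A causal mask restricts self-attention to earlier positions. The result is projected to per-position vocabulary logits of shape `[batch_size, seq_len, vocab_size]`.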


In [5]:

class VAE(nn.Module):
    """Transformer decoder that maps a latent vector to per-position token logits.

    Despite the name, this module is only the generative (decoder) half: there is
    no reparameterization/sampling step and no KL term, so ``z`` is decoded
    deterministically.

    Parameters
    ----------
    latent_dim : int
        Size of the last dimension of ``z``.
    hidden_dim : int
        Model width of the Transformer decoder (PyTorch requires it to be
        divisible by ``num_heads``).
    vocab_size : int
        Number of output logits per position.
    num_heads : int
        Attention heads per decoder layer.
    num_layers : int
        Number of stacked ``nn.TransformerDecoderLayer`` modules.
    max_seq_len : int
        Size of the learned positional-embedding table; ``forward`` rejects
        longer sequences.
    """

    def __init__(self, latent_dim, hidden_dim, vocab_size, num_heads, num_layers, max_seq_len):
        super(VAE, self).__init__()
        # Plain int attribute: not part of state_dict, so checkpoints stay compatible.
        self.max_seq_len = max_seq_len
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)

        # Learned positional embeddings: the only per-position signal the decoder gets.
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_dim)

        # Transformer Decoder components (sequence-first layout, batch_first=False).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=4 * hidden_dim,
            dropout=0.1,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Linear layer for output token logits
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, z, seq_len):
        """Decode ``z`` into logits of shape ``[batch_size, seq_len, vocab_size]``.

        Parameters
        ----------
        z : torch.Tensor
            Latent codes of shape ``[batch_size, latent_dim]``.
        seq_len : int
            Number of positions to generate; must be in ``[1, max_seq_len]``.

        Raises
        ------
        ValueError
            If ``seq_len`` is outside ``[1, max_seq_len]`` (otherwise the
            positional-embedding lookup would fail with an opaque index error).
        """
        if seq_len < 1 or seq_len > self.max_seq_len:
            raise ValueError(
                f"seq_len={seq_len} must be between 1 and max_seq_len={self.max_seq_len}"
            )
        batch_size = z.size(0)

        # Map latent vector to the hidden dimension
        hidden = self.latent_to_hidden(z)  # [batch_size, hidden_dim]

        # Broadcast the latent over all positions and add positional embeddings.
        # expand() avoids copying; the out-of-place add materializes the result.
        positions = torch.arange(seq_len, device=z.device)
        pos_emb = self.position_embeddings(positions).unsqueeze(0)  # [1, seq_len, hidden_dim]
        hidden_seq = hidden.unsqueeze(1).expand(batch_size, seq_len, -1) + pos_emb

        # Transformer decoder expects [seq_len, batch_size, hidden_dim]
        hidden_seq = hidden_seq.permute(1, 0, 2)

        # Boolean causal mask: True above the diagonal blocks attention to later positions.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=z.device), diagonal=1).bool()

        # No encoder output exists, so the decoder cross-attends (unmasked) to the
        # same latent+position sequence it decodes.
        outputs = self.decoder(
            tgt=hidden_seq,
            memory=hidden_seq,
            tgt_mask=causal_mask,
        )  # [seq_len, batch_size, hidden_dim]

        # Project back to batch-first vocabulary logits
        return self.fc_out(outputs.permute(1, 0, 2))  # [batch_size, seq_len, vocab_size]


### 2. Protein Generator Model

The `ProtGenerator` class is a higher-level wrapper that combines an encoder with a decoder (here, the VAE decoder defined above). The encoder turns the input protein sequences into a latent representation, and the decoder reconstructs the sequences from it. The training target is the input sequence itself, so the model can be trained without labels (self-supervised reconstruction).


In [6]:

class ProtGenerator(nn.Module):
    """Encoder-decoder pair that reconstructs token sequences via a latent code.

    Parameters
    ----------
    encoder : object
        Must provide ``get_feat(input_ids, attention_mask)`` returning the latent.
    decoder : callable
        Called as ``decoder(latent, seq_len=...)``; its result is returned as-is.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, inputs):
        """Encode ``inputs['input_ids']`` and decode back to the same length."""
        token_ids = inputs['input_ids']
        latent = self.encoder.get_feat(token_ids, inputs['attention_mask'])
        return self.decoder(latent, seq_len=token_ids.size(1))

    

### 3. ProtGenerator Wrapper for DeepChem

Here, we define the ProtGeneratorDCWrapper class, which wraps the ProtGenerator model into DeepChem's framework. The wrapper integrates a loss function (SparseSoftmaxCrossEntropy) and ensures the model adheres to DeepChem's expected format for training and evaluation.

In this wrapper, `_prepare_batch` stacks the tokenized samples into a dictionary containing `input_ids` and `attention_mask`, which is passed to the model. The labels are the `input_ids` themselves, so the cross-entropy loss measures how well the decoder reconstructs the input tokens.

In [7]:
from deepchem.models.losses import SparseSoftmaxCrossEntropy
class ProtGeneratorDCWrapper(TorchModel):
    """DeepChem ``TorchModel`` adapter for a token-reconstruction generator.

    The wrapped module receives a dict with ``'input_ids'`` and
    ``'attention_mask'`` and is trained with ``SparseSoftmaxCrossEntropy``
    against the input token ids themselves (reconstruction objective).

    Parameters
    ----------
    model : torch.nn.Module
        Module whose ``forward`` accepts the dict built by ``_prepare_batch``
        and returns per-token logits.
    **kwargs
        Forwarded unchanged to ``TorchModel.__init__`` (e.g. ``device``,
        ``batch_size``).
    """

    def __init__(self, model, **kwargs):
        self.model = model
        loss = SparseSoftmaxCrossEntropy()
        output_types = ['loss', 'predict']
        super(ProtGeneratorDCWrapper, self).__init__(model,
                                                     loss=loss,
                                                     output_types=output_types,
                                                     **kwargs)

    def _prepare_batch(
            self,
            batch) -> Tuple[Any, List[torch.Tensor], List[torch.Tensor]]:
        """Convert a DeepChem batch of tokenizer outputs into model tensors.

        Parameters
        ----------
        batch : tuple
            ``(inputs, labels, weights)`` from DeepChem's batch generator.
            Only ``inputs[0]`` is used; the dataset's labels and weights are
            ignored.

        Returns
        -------
        tuple
            ``(model_inputs, [labels], [weights])``. ``model_inputs`` is a dict
            of stacked ``input_ids`` / ``attention_mask`` tensors, ``labels`` is
            a copy of ``input_ids``, and ``weights`` holds one 1.0 per sample.
        """
        inputs, _, _ = batch
        encodings = inputs[0]

        # NOTE(review): assumes each encoding holds tensors with a leading batch
        # dimension of 1 (one sequence tokenized with return_tensors='pt');
        # [0] drops it before stacking.
        input_ids = torch.stack([enc['input_ids'][0] for enc in encodings])
        attention_mask = torch.stack([enc['attention_mask'][0] for enc in encodings])
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Reconstruction target: the model should reproduce its own input tokens.
        labels = input_ids.clone()

        # One weight per sample, sized from the batch itself. (Iterating over the
        # 2-key model-input dict would always yield 2 weights, which is wrong for
        # any batch size other than 2, including a short final batch.)
        weights = torch.ones(input_ids.size(0), device=input_ids.device)
        # TODO(review): padded positions are scored by the loss as well; consider
        # per-token weights from attention_mask instead.

        model_inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        return (model_inputs, [labels], [weights])

In [8]:
# Pretrained ProtBERT used purely as a feature extractor (the encoder).
prot_bert_feat_extractor = ProtBERT(
    task="feature_extractor",
    model_path='Rostlab/prot_bert',
    device='cpu',
)

# Decoder sized to ProtBERT's hidden width and vocabulary.
# NOTE(review): max_seq_len must be >= the tokenizer's max_len (512 above).
vae_decoder = VAE(
    latent_dim=prot_bert_feat_extractor.config.hidden_size,
    hidden_dim=512,
    vocab_size=prot_bert_feat_extractor.tokenizer.vocab_size,
    num_heads=8,
    num_layers=4,
    max_seq_len=1024,
)

gen_model_torch = ProtGenerator(prot_bert_feat_extractor, vae_decoder)
gen_model_dc = ProtGeneratorDCWrapper(
    gen_model_torch,
    device='cpu',
    batch_size=2,
    wandb=True,
)

## Train

In [None]:
# One pass over the training set; the cell displays the value returned by fit().
train_loss = gen_model_dc.fit(dataset, nb_epoch=1)
train_loss