# Protein Sequence Generation with DeepChem and ProtBERT

In this tutorial, we will walk through how to generate protein sequences using pre-trained protein language models (PLMs) integrated into DeepChem. Specifically, we will demonstrate how to use the ProtBERT model—a transformer-based model trained on millions of protein sequences—to generate plausible amino acid sequences conditioned on learned representations.


This tutorial is part of DeepChem’s ongoing effort to bring cutting-edge protein language modeling tools to the open-source drug discovery and bioinformatics communities.


## Setup

### Install necessary libraries


In [None]:
!pip install --pre deepchem
import deepchem
deepchem.__version__

### Import libraries

In [21]:
import deepchem as dc
from deepchem.models.torch_models.prot_bert import ProtBERT
import torch
import torch.nn as nn
from deepchem.models.torch_models.torch_model import TorchModel
from typing import Tuple, List, Any
from transformers import AutoTokenizer
import torch.nn.functional as F
from deepchem.data import NumpyDataset



## Load Dataset

### Tokenizer Wrapper for ProtBERT

In this section, we define a wrapper class `ProtBERTTokenizerWrapper` for the ProtBERT tokenizer. This wrapper allows us to conveniently set the maximum sequence length and handle the necessary tokenization steps (like truncation, padding, and conversion to PyTorch tensors) when calling the tokenizer. 

We use the `AutoTokenizer.from_pretrained()` method to load the pre-trained ProtBERT tokenizer from HuggingFace, and then customize it with additional parameters such as truncation, padding, and max length.
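
To make the effect of truncation and padding concrete, here is a minimal pure-Python sketch of how a fixed-length encoding is built. The vocabulary `TOY_VOCAB` and the helper `toy_encode` are hypothetical and exist only for illustration; the real ProtBERT tokenizer uses its own vocabulary and token IDs.

```python
from typing import Dict, List

# Hypothetical toy vocabulary; the real ProtBERT vocabulary and IDs differ.
TOY_VOCAB = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3, "M": 4, "S": 5, "A": 6}


def toy_encode(sequence: str, max_len: int) -> Dict[str, List[int]]:
    """Encode a space-separated sequence to exactly max_len token IDs."""
    residues = sequence.split()
    # Truncate so that [CLS] + residues + [SEP] fits into max_len.
    residues = residues[: max_len - 2]
    ids = [TOY_VOCAB["[CLS]"]]
    ids += [TOY_VOCAB.get(r, TOY_VOCAB["[UNK]"]) for r in residues]
    ids.append(TOY_VOCAB["[SEP]"])
    # The attention mask marks real tokens with 1 and padding with 0.
    mask = [1] * len(ids)
    pad = max_len - len(ids)
    ids += [TOY_VOCAB["[PAD]"]] * pad
    mask += [0] * pad
    return {"input_ids": ids, "attention_mask": mask}


short = toy_encode("M S A", max_len=8)            # padded up to 8 tokens
long = toy_encode("M S A S A M S A S A", max_len=8)  # truncated down to 8 tokens
print(short)
print(long)
```

Every encoded sequence ends up with the same length, which is what allows the featurized examples to be stacked into a batch later on.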


In [2]:
# Temporary workaround in sending kwargs to HuggingFaceFeaturizer
class ProtBERTTokenizerWrapper:
    def __init__(self, model_name, max_len=128):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_len = max_len

    def __call__(self, sequence):
        return self.tokenizer(sequence, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')

temp_tokenizer = ProtBERTTokenizerWrapper('Rostlab/prot_bert',max_len=512)
protbert_featurizer = dc.feat.HuggingFaceFeaturizer(temp_tokenizer)

In [3]:
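# No supervised tasks are passed: the loader only featurizes the "protein" column
loader = dc.data.CSVLoader([],
                           feature_field="protein",
                           featurizer=protbert_featurizer)
# Replace this path with the location of your own copy of the CSV file
dataset = loader.create_dataset("/home/shivasankaran/deepchem/plastic_degrading_train.csv")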
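
The loader above reads sequences from a CSV column named `protein`. The sequences used later in this tutorial are written with a space between residues, so the sketch below shows one way to prepare such a file. It is an illustration only: the helper `space_separate`, the file name `toy_proteins.csv`, and the two example sequences are hypothetical, and the layout of the original dataset file is assumed rather than known.

```python
import csv
import os
import tempfile


def space_separate(sequence: str) -> str:
    """Turn 'MSA' into 'M S A' (uppercase, one space between residues)."""
    return " ".join(sequence.replace(" ", "").upper())


# Hypothetical example sequences, not taken from the tutorial's dataset.
raw_sequences = ["MSASPRLG", "mlakqikk"]

csv_path = os.path.join(tempfile.mkdtemp(), "toy_proteins.csv")
with open(csv_path, "w", newline="") as handle:
    writer = csv.writer(handle)
    writer.writerow(["protein"])  # header must match feature_field="protein"
    for seq in raw_sequences:
        writer.writerow([space_separate(seq)])

# Read the file back to check its contents.
with open(csv_path, newline="") as handle:
    rows = list(csv.reader(handle))
print(rows)
```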

## Models

### 1. Variational Autoencoder (VAE) for Protein Generation

In this section, we define the `VAE` module, which generates protein sequences from a latent vector. The class contains no encoder or sampling step of its own; it acts as the decoder of the generator, and the latent vector `z` is supplied from outside. Its architecture consists of several key components:

- **Latent-to-Hidden Mapping:** A fully connected layer that maps the latent vector `z` to the hidden dimension.
- **Positional Embeddings:** An embedding layer that adds positional information to each sequence position, enabling the model to learn dependencies across positions.
- **Transformer Decoder:** A stack of Transformer decoder layers that processes the sequence under a causal mask, so each position attends only to earlier positions.
- **Output Linear Layer:** A fully connected layer that produces a logit for every token in the vocabulary.

The forward pass maps the latent vector `z` to a hidden representation, repeats it across all sequence positions, adds positional embeddings, and passes the result through the Transformer decoder to produce per-position logits over the vocabulary.
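
The decoder defined in this section applies a causal mask built with `torch.triu(..., diagonal=1).bool()`. As a minimal sketch of what that mask contains, the pure-Python helper below (a hypothetical `causal_mask`, not part of DeepChem or PyTorch) builds the same pattern as nested lists. Following PyTorch's convention for boolean attention masks, `True` marks a position that may not be attended to.

```python
from typing import List


def causal_mask(seq_len: int) -> List[List[bool]]:
    """mask[i][j] is True when position i must NOT attend to position j."""
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


mask = causal_mask(4)
for row in mask:
    # "x" = blocked (future position), "." = visible (current or earlier position)
    print(" ".join("x" if blocked else "." for blocked in row))
```

Row `i` of the mask blocks every column `j > i`, so the first position sees only itself and the last position sees the whole sequence.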


In [4]:

class VAE(nn.Module):
    def __init__(self, latent_dim, hidden_dim, vocab_size, num_heads, num_layers, max_seq_len):
        super(VAE, self).__init__()
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)

        # Positional embeddings for input sequences
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_dim)

        # Transformer Decoder components
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=4 * hidden_dim,
            dropout=0.1,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Linear layer for output token probabilities
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, z, seq_len):
        # Map latent vector to the hidden dimension
        hidden = self.latent_to_hidden(z)  # [batch_size, hidden_dim]

        # Repeat hidden state across sequence length
        hidden_seq = hidden.unsqueeze(1).repeat(1, seq_len, 1)  # [batch_size, seq_len, hidden_dim]

        # Add positional embeddings
        positions = torch.arange(seq_len, device=z.device).unsqueeze(0).expand(z.size(0), seq_len)
        hidden_seq += self.position_embeddings(positions)  # [batch_size, seq_len, hidden_dim]

        # Transformer decoder requires inputs in shape [seq_len, batch_size, hidden_dim]
        hidden_seq = hidden_seq.permute(1, 0, 2)

        # Create a causal mask to ensure the decoder only attends to previous tokens
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=z.device), diagonal=1).bool()

        # Decode
        outputs = self.decoder(
            tgt=hidden_seq,
            memory=hidden_seq,  # No encoder; latent is used as memory
            tgt_mask=causal_mask,
        )  # [seq_len, batch_size, hidden_dim]

        # Project outputs to vocabulary size
        outputs = self.fc_out(outputs.permute(1, 0, 2))  # [batch_size, seq_len, vocab_size]

        return outputs


### 2. Protein Generator Model

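The `ProtGenerator` class is a higher-level wrapper that combines an encoder with the `VAE` decoder defined above. The encoder turns the input protein sequences into a latent representation, and the decoder reconstructs sequences from that representation. Because the training target is the input sequence itself, the model can be trained without labels. The `generate` method reuses the same path for sampling: it encodes a seed sequence, optionally perturbs the latent representation with Gaussian noise, and samples one token per position from the decoder output.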


In [17]:

class ProtGenerator(nn.Module):
    def __init__(self,encoder,decoder,tokenizer):
        super(ProtGenerator, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tokenizer = tokenizer
    
    def forward(self,inputs):
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        z = self.encoder.get_last_hidden_state(input_ids,attention_mask)
        recon_x = self.decoder(z, seq_len = input_ids.size(1))
        return recon_x

    def generate(self, protein_sequence, seq_len, noise_level=0.0, temperature=1.0, exclude_tokens=None):
        encoded = self.tokenizer(protein_sequence, truncation=True, padding='max_length', max_length=512, return_tensors='pt')
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        # Inference only, so no gradients are needed for the encoder or the decoder
        with torch.no_grad():
            z = self.encoder.get_last_hidden_state(input_ids, attention_mask)

            if noise_level > 0:
                noise = torch.randn_like(z) * noise_level  # Generate Gaussian noise
                z = z + noise  # Add noise to the latent representation

            outputs = self.decoder(z, seq_len)  # Shape: (1, seq_len, vocab_size)

        outputs = outputs.squeeze(0)  # Remove the batch dimension
        generated_sequence = []

        for i in range(seq_len):
            # Apply softmax to get probabilities, adjusted by temperature
            probabilities = F.softmax(outputs[i] / temperature, dim=-1)

            # Zero out the probability of every excluded token ID (e.g. [PAD], [CLS])
            if exclude_tokens:
                for token_id in exclude_tokens:
                    probabilities[token_id] = 0

            # Renormalize so the remaining probabilities sum to 1
            probabilities = probabilities / probabilities.sum()

            # Sample a token from the probability distribution
            token_index = torch.multinomial(probabilities, num_samples=1).item()

            # Append the corresponding amino acid to the generated sequence
            token = self.tokenizer.decode([token_index]).strip()
            generated_sequence.append(token)

        # Excluded token IDs have zero probability and are never sampled,
        # so the decoded tokens can be joined directly
        return ' '.join(generated_sequence)



    

### 3. ProtGenerator Wrapper for DeepChem

Here, we define the ProtGeneratorDCWrapper class, which wraps the ProtGenerator model into DeepChem's framework. The wrapper integrates a loss function (SparseSoftmaxCrossEntropy) and ensures the model adheres to DeepChem's expected format for training and evaluation.

In this wrapper, the model is expected to receive a batch of inputs containing input_ids and attention_mask, and the outputs are processed by the VAE model for protein sequence generation.
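
For readers unfamiliar with the loss named above, the following is a minimal pure-Python sketch of sparse softmax cross-entropy for a single sequence position: the negative log-probability that the softmax over the logits assigns to the integer label. The function `sparse_softmax_cross_entropy` is a hypothetical illustration, not DeepChem's implementation, which operates on batched tensors.

```python
import math
from typing import List


def sparse_softmax_cross_entropy(logits: List[float], label: int) -> float:
    """Return -log(softmax(logits)[label]) for one sequence position."""
    peak = max(logits)  # subtract the maximum for numerical stability
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_norm - logits[label]


# With uniform logits over four classes the loss is log(4).
uniform = sparse_softmax_cross_entropy([0.0, 0.0, 0.0, 0.0], label=2)
# A large logit on the correct class gives a much smaller loss.
confident = sparse_softmax_cross_entropy([0.0, 0.0, 5.0, 0.0], label=2)
print(uniform, confident)
```

"Sparse" refers to the label format: each target is a single token ID rather than a one-hot vector, which matches the way the wrapper reuses `input_ids` as labels.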

In [34]:
from deepchem.models.losses import SparseSoftmaxCrossEntropy
import numpy as np
class ProtGeneratorDCWrapper(TorchModel):
    def __init__(self,model,**kwargs):
        self.model = model
        loss = SparseSoftmaxCrossEntropy()
        output_types = ['loss', 'predict']
        super(ProtGeneratorDCWrapper, self).__init__(model,
                                      loss=loss,
                                      output_types=output_types,
                                      **kwargs)

    
    def predict(self, dataset, seq_len=512, noise_level=0.0, temperature=1.0, exclude_tokens=None):
        """
        Generates protein sequences for input protein sequences.

        Args:
            dataset (deepchem.Dataset): DeepChem dataset whose `X` holds raw protein sequences.
            seq_len (int): Number of tokens to sample for each generated sequence.
            noise_level (float): Standard deviation of the Gaussian noise added to latent representations.
            temperature (float): Sampling temperature.
            exclude_tokens (list): Token IDs that are never sampled.

        Returns:
            np.ndarray: Generated protein sequences.
        """
        inputs = dataset.X  # Raw protein sequences, one string per row
        generated_sequences = []

        for protein_sequence in inputs:
            generated_seq = self.model.generate(
                protein_sequence,
                seq_len=seq_len,  # Adjust as needed
                noise_level=noise_level,
                temperature=temperature,
                exclude_tokens=exclude_tokens
            )
            generated_sequences.append(generated_seq)

        return np.array(generated_sequences)
    
    def _prepare_batch(
            self,
            batch) -> Tuple[List[Any], List[torch.Tensor], List[torch.Tensor]]:

        inputs, labels, _ = batch
        inputs = inputs[0]
        input_ids = torch.stack([
            x['input_ids'][0] 
            for x in inputs
        ])
        attention_mask = torch.stack([
            x['attention_mask'][0]
            for x in inputs
        ])
        # The model is trained to reconstruct its input, so the labels are the input IDs
        labels = input_ids.clone()
        inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        # One weight per example in the batch
        weights = torch.ones(input_ids.size(0))
        return (inputs, [labels], [weights])



In [35]:
tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert')
prot_bert_feat_extractor = ProtBERT(task="feature_extractor", model_path='Rostlab/prot_bert', device='cpu')
vae_decoder = VAE(
    latent_dim=prot_bert_feat_extractor.config.hidden_size,
    hidden_dim=512,
    vocab_size=prot_bert_feat_extractor.tokenizer.vocab_size,
    num_heads=8,
    num_layers=4,
    max_seq_len=1024,
)
gen_model_torch = ProtGenerator(prot_bert_feat_extractor, vae_decoder, tokenizer)
gen_model_dc = ProtGeneratorDCWrapper(gen_model_torch, device='cpu', batch_size=2)


## Training the Protein Generation Model

Now that we've constructed the model and loaded our protein dataset, it's time to train the model. We'll run a single epoch to demonstrate the training loop and ensure everything is working end-to-end.

This step will fine-tune the ProtBERT-based generator on the provided dataset.



In [36]:
gen_model_dc.fit(dataset,nb_epoch=1)

## Generating Protein Sequences

With the model trained, we can now generate new protein sequences. These sequences are sampled based on the learned representations from the fine-tuned ProtBERT model.

This step demonstrates how the model can be used to create plausible protein-like sequences that follow the patterns learned from the training data.
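
The `temperature` and `exclude_tokens` arguments control how each token is sampled. As a minimal pure-Python sketch of that step, the hypothetical helper `sampling_probs` below divides the logits by the temperature, applies a softmax, zeroes out the excluded token IDs, and renormalizes. The logits are made-up numbers for a four-token vocabulary, not model output.

```python
import math
import random
from typing import List, Optional


def sampling_probs(logits: List[float],
                   temperature: float = 1.0,
                   exclude: Optional[List[int]] = None) -> List[float]:
    """Temperature-scaled softmax with excluded token IDs set to zero."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    for idx in exclude or []:
        weights[idx] = 0.0
    total = sum(weights)  # assumes at least one token is not excluded
    return [w / total for w in weights]


logits = [2.0, 1.0, 0.5, 3.0]  # hypothetical scores for token IDs 0..3
sharp = sampling_probs(logits, temperature=0.5, exclude=[3])
flat = sampling_probs(logits, temperature=2.0, exclude=[3])

# Draw one token ID from the flatter distribution with a seeded generator.
rng = random.Random(0)
token_id = rng.choices(range(len(logits)), weights=flat, k=1)[0]
print(sharp, flat, token_id)
```

A temperature below 1 concentrates probability on the highest-scoring tokens, while a temperature above 1 spreads it out. In both cases the excluded ID keeps a probability of exactly zero and can never be drawn.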


In [37]:
input_sequences = [
    "M S A S P R L G F V Q C I S P A G L H R M A Y H E W G D P A N P R V L V C A H G L T R G R D F D T V A S A L C G D Y R V V C P D V A G R G R S E W L A D A N G Y V V P Q Y V S D M V T L I A R L N V E K V D W F G T S M G G L I G M G L A G L P K S P V R N V L L N D V G P K L A P S A V E R I G A Y L G L P V R F K T F E E G L A Y L Q T I S A S F G R H T P E Q W R E L N A A I L K P V Q G T D G L E W G L H Y D P Q L A V P F R K S T P E A I A A G E A A L W R T F E A I E G P V L V V R G A Q S D L L L R E T V A E M V A R G K H V S S V E V P D V G H A P T F V D P A Q I A I A P Q F F T G A",
    "M L A K Q I K K A N S R S T L L R K S L L F A A P I I L A V S S S S V Y A L T Q V S N F G T N P G N L Q M F K H V P S G M P A N A P L V V A L H G C T Q T A A A Y E A S G W S A L G N T H K F Y V V Y P Q Q Q S G N N S N K C F N W F E P G D I T R G Q G E A L S I K Q M V D N M K A N H S I D P S R V Y V T G L S A G A F M T T V M A A T Y P D V F A G A A P I A G G P Y K C A T S M T S A F T C M S P G V D K T P A A W G D L A R G G Y S G Y N G P K P K I S I W H G S S D Y T V A P A N Q N E T V E Q F T N Y H G I D Q T P D V S D T V G G F P H K V Y K S A N G T P L V E T Y T I T G M G H G T P V D P G T G A N Q C G T A G A Y I L D V N V C S S Y Y I G Q F F G I I G G G G T T T T T T S G N V T T T T A A T T T T T T A T Q G Y T Q T T S A T V T N H Y V A G R I N V T Q Y N V L G A R Y G Y V T T I P L Y Y C P S L S G W T D K A N C S P I"
]

dc_dataset = NumpyDataset(np.array(input_sequences))

In [38]:

seq_len = 300  # Length of the desired generated sequence
temperature = 1.0  # Values above 1 flatten the distribution; values below 1 sharpen it
noise_level = 0.1  # Standard deviation of the noise added to the latent representation
# Special-token IDs that should never be sampled
exclude_tokens = [
    tokenizer.pad_token_id,
    tokenizer.cls_token_id,
    tokenizer.mask_token_id,
    tokenizer.unk_token_id,
    tokenizer.sep_token_id
]
gen_model_dc.predict(dc_dataset,
                     seq_len=seq_len,
                     noise_level=noise_level,
                     temperature=temperature,
                     exclude_tokens=exclude_tokens)

array(['D E M M Q Y S E W E P T G B U T N O H C F D H D K P A F I L Q T Y U S Y G W O H T L G O M K A R N H U F L U D N B S G A F B X Z M Y U C I U E Y X R X Q L S T G H G E D G O G K Y A A W Z K Z N Q Q P C E B M A E Y N K S E R I O P A A B N C P G G Y T A Y B Z E R Q A U C I F X I D K G A M Y S M U X T B C L O Q Y V T P B B K K F O H Y C O T G G Q T L C D I I X S H N O O G Q L C F O G A I U O F P L G G F T K O H S K H G Q Q G O U A W I G M U B G X A P E P Y T E A T E X C C O U Z G S R Z S R N F O T U B S D I L N H K U G Y H Z K S Q K G P T L Q P L X F I O N M T E G G L E O O E D D A O Q L W B D C Z',
       'X Q A B T M L Z M Z E G S P H K G P U A Q H C O D D S H A L N K B S E X S A H O X T Z M C T K K E R X D E U K T K O S W O S Q Z K D O A V Y E Z B O Y I Q G N X K T Q X I O N L H O M C B T K W F G E G H H L H O Z D H O P G N Y V T E B Z A E M A X M K Q Z T N T A H C G Y D K Z D Z E H O F S I T A P Z T Z E E C A K O B Z P C X N S G D G M Z P S O D L W E O N M X G C V Q S M K E Q Q 

# Congratulations! Time to join the Community!
Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

## Star DeepChem on [GitHub](https://github.com/deepchem/deepchem)
This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

## Join the DeepChem Discord
The DeepChem [Discord](https://discord.gg/cGzwCdrUqS) hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!

## Citing this Tutorial

If you use this tutorial or the underlying tools in your work, please consider citing our paper:
```
@article{vanaja2024open,
  title={Open-Source Protein Language Models for Function Prediction and Protein Design},
  author={Vanaja Pandi, Shivasankaran and Ramsundar, Bharath},
  journal={arXiv e-prints},
  pages={arXiv--2412},
  year={2024}
}
```

