In [1]:
%load_ext autoreload
%autoreload 2

import os

# Restrict this kernel to a single GPU. With CUDA_DEVICE_ORDER=PCI_BUS_ID the
# device indices follow PCI bus order, so "3" names a fixed physical card.
# Both variables are set before `import torch` so they are in place before
# CUDA is initialised.
# NOTE(review): the GPU index is hard-coded; adjust it for the host machine.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="3",
)

import torch
from torch import nn
from torch.nn import (
    KLDivLoss,
)
from transformers import (
    PreTrainedModel,
)

from src.model.configuration_md_pssm import MDPSSMConfig
from src.model.modeling_outputs import PSSMOutput

from plms import ProstT5, PLMConfig

In [2]:
class PSSMHead(nn.Module):
    """Convolutional head that turns per-residue T5 embeddings into a PSSM.

    Based on https://github.com/hefeda/PGP/blob/master/prott5_batch_predictor.py#L144:
    two 1-D convolutions over the sequence axis with ReLU and dropout in between.
    A softmax over the output channels follows, so every residue gets a
    probability distribution over ``num_classes`` columns.
    """

    def __init__(
        self,
        embedding_dim: int = 1024,
        hidden_dim: int = 32,
        num_classes: int = 20,
        kernel_size: int = 7,
        dropout: float = 0.25,
    ):
        """
        Args:
            embedding_dim: Channel size of the incoming embeddings (default 1024,
                as in the original PGP head).
            hidden_dim: Channels of the intermediate convolution.
            num_classes: Number of output columns per residue (default 20).
            kernel_size: Convolution window. Must be odd so that the symmetric
                padding ``kernel_size // 2`` preserves the sequence length.
            dropout: Dropout probability applied between the two convolutions.

        Raises:
            ValueError: If ``kernel_size`` is even.
        """
        super().__init__()
        if kernel_size % 2 != 1:
            raise ValueError(f"kernel_size must be odd to preserve sequence length, got {kernel_size}")
        padding = kernel_size // 2
        # Layer order (and therefore the state_dict keys classifier.0 / classifier.3)
        # is kept identical to the original hard-coded head.
        self.classifier = nn.Sequential(
            nn.Conv1d(embedding_dim, hidden_dim, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim, num_classes, kernel_size=kernel_size, padding=padding),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map embeddings to per-residue probability profiles.

        Args:
            x: Channels-last embeddings, ``[batch_size, seq_len, embedding_dim]``.

        Returns:
            ``[batch_size, seq_len, num_classes]`` tensor whose last axis sums to 1.
        """
        # Conv1d convolves over the last axis with channels second, so move the
        # embedding dimension into the channel slot and back afterwards.
        logits = self.classifier(x.transpose(1, 2)).transpose(1, 2)
        return torch.softmax(logits, dim=-1)


class T5EncoderModelForPssmGeneration(PreTrainedModel):
    """ProstT5 encoder topped with a ``PSSMHead`` that predicts per-residue profiles.

    Processing steps:

    - The encoder's last hidden state is masked: padding positions and each
      sequence's final attended token are zeroed.
    - The masked states go through the convolutional head, which outputs a
      softmax distribution per position.
    - When ``labels`` are given, the loss is the KL divergence between the labels
      and the prediction, averaged over residues whose label row has no ``-100``.
    """

    def __init__(self, config: MDPSSMConfig):
        """
        Args:
            config (MDPSSMConfig): Supplies ``model_name`` (the encoder checkpoint)
                and, optionally, ``device``. ``"auto"`` is used when no device
                attribute is present.
        """
        super().__init__(config=config)
        device_map = getattr(config, "device", "auto")
        plm_config = PLMConfig(
            name_or_path=config.model_name,
            device=device_map,
        )

        self.protein_encoder = ProstT5(config=plm_config)
        self.pssm_head = PSSMHead()
        # "batchmean" divides the summed KL by pred.size(0), i.e. the number of kept residues.
        self.loss_fct = KLDivLoss(reduction="batchmean")

    @staticmethod
    def _residue_mask(attention_mask, hidden_states):
        """Return a copy of ``attention_mask`` with each sequence's last attended token zeroed.

        The caller's tensor is never modified. When ``attention_mask`` is None,
        every position of ``hidden_states`` is treated as attended.
        """
        if attention_mask is None:
            attention_mask = torch.ones(
                hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device
            )
        mask = attention_mask.clone()
        # Assumes right padding: sum - 1 is then the index of the last attended
        # token (presumably the tokenizer's </s>; TODO confirm).
        last_positions = mask.sum(dim=1) - 1
        batch_indices = torch.arange(mask.size(0), device=mask.device)
        mask[batch_indices, last_positions] = 0
        return mask

    def _pssm_loss(self, pssm, labels):
        """KL divergence between ``labels`` and ``pssm`` over residues without ``-100`` labels."""
        # [batch_size * seq_len, 20]
        target = labels.flatten(end_dim=1)
        pred = pssm.flatten(end_dim=1)

        # A -100 anywhere in a residue's label row marks that residue as ignored.
        keep = ~torch.any(target == -100, dim=1)
        pred = pred[keep]
        target = target[keep]

        if pred.numel() == 0:
            # "batchmean" would divide by zero (NaN); return a zero that keeps the graph.
            return pssm.sum() * 0.0

        # Softmax outputs can underflow to exactly 0, and log(0) = -inf would make
        # the loss and its gradients non-finite. Compute in at least float32 and
        # clamp to the smallest positive normal value first.
        compute_dtype = torch.promote_types(pred.dtype, torch.float32)
        log_pred = pred.to(compute_dtype).clamp_min(torch.finfo(compute_dtype).tiny).log()
        return self.loss_fct(log_pred, target)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Predict PSSMs and, optionally, the KL-divergence training loss.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: 1 for real tokens, 0 for padding. It is assumed to be
                right-padded (TODO confirm). It is not modified.
            inputs_embeds: Accepted for API compatibility; currently unused.
            labels: Target profiles, expected to match the prediction's shape
                ``[batch_size, seq_len, 20]``. Residue rows containing -100 are ignored.
            output_attentions: Accepted for API compatibility; currently unused.
            output_hidden_states: When truthy, the encoder's last hidden state is
                returned as ``hidden_states`` in the dict output.
            return_dict: When falsy (including the default None) a tuple is
                returned; otherwise a ``PSSMOutput``.

        Returns:
            ``PSSMOutput``, or a tuple ``([loss,] pssm, encoder_outputs[2:-1])``.
        """
        encoder_outputs = self.protein_encoder.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
        )

        # [batch_size, seq_len, hidden_dim]
        hidden_states = encoder_outputs["last_hidden_state"]

        residue_mask = self._residue_mask(attention_mask, hidden_states)
        masked_hidden_states = hidden_states * residue_mask.unsqueeze(-1).to(hidden_states.dtype)

        # [batch_size, seq_len, 20]
        pssm = self.pssm_head(masked_hidden_states)

        loss = self._pssm_loss(pssm, labels) if labels is not None else None

        if not return_dict:
            # NOTE(review): return_dict=None also lands here, unlike the usual
            # transformers fallback to config.use_return_dict. The encoder slice is
            # nested as one tuple element rather than concatenated; confirm that
            # this layout is intended.
            output = (pssm, encoder_outputs[2:-1])
            return ((loss,) + output) if loss is not None else output

        return PSSMOutput(
            loss=loss,
            pssms=pssm,
            hidden_states=encoder_outputs["last_hidden_state"] if output_hidden_states else None,
            masks=residue_mask,
        )


In [None]:
# Choose the best available backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Build the PSSM model; the config carries the chosen device to the encoder.
config = MDPSSMConfig(device=device)
model = T5EncoderModelForPssmGeneration(config)

In [None]:
# Load the stock ProtT5-XL (UniRef50) encoder from the Hugging Face Hub.
# torch_dtype="auto" lets transformers pick the dtype recorded for the checkpoint.
# NOTE(review): this rebinds `model`, discarding the PSSM model built in the cell above.
from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained(
    pretrained_model_name_or_path="Rostlab/prot_t5_xl_uniref50",
    torch_dtype="auto",
)

model.safetensors:   0%|          | 0.00/11.3G [00:00<?, ?B/s]