In [None]:
import lmdb
from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
from torch.utils.data import Dataset
from pathlib import Path
import pickle as pkl

class LMDBDataset(Dataset):
    """Map-style dataset backed by an LMDB file of pickled records.

    The file is expected to hold a pickled example count under the key
    ``b'num_examples'`` and one pickled record per example under the key
    ``str(index).encode()``.

    Args:
        data_file (Union[str, Path]): Path to lmdb file.
        in_memory (bool, optional): When True, every record is kept in an
            in-process cache after its first read (the cache is filled lazily,
            not eagerly at construction time). Default: False
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):
        lmdb_path = Path(data_file)
        if not lmdb_path.exists():
            raise FileNotFoundError(lmdb_path)

        # Read-only, lock-free environment.
        self._env = lmdb.open(str(lmdb_path), max_readers=1, readonly=True,
                              lock=False, readahead=False, meminit=False)

        # NOTE(review): pickle executes arbitrary code on load; only open
        # LMDB files that come from a trusted source.
        with self._env.begin(write=False) as txn:
            self._num_examples = pkl.loads(txn.get(b'num_examples'))

        self._in_memory = in_memory
        if in_memory:
            # One slot per example; None marks "not read yet".
            self._cache = [None] * self._num_examples

    def __len__(self) -> int:
        """Return the example count stored in the LMDB file."""
        return self._num_examples

    def __getitem__(self, index: int):
        """Return the record stored at ``index``.

        Records that lack an ``'id'`` key get ``str(index)`` assigned to it.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not (0 <= index < self._num_examples):
            raise IndexError(index)

        # Serve from the cache when this record has been read before.
        if self._in_memory:
            cached = self._cache[index]
            if cached is not None:
                return cached

        with self._env.begin(write=False) as txn:
            record = pkl.loads(txn.get(str(index).encode()))

        if 'id' not in record:
            record['id'] = str(index)
        if self._in_memory:
            self._cache[index] = record
        return record
def dataset_factory(data_file: Union[str, Path], *args, **kwargs) -> Dataset:
    """Build the dataset implementation that matches ``data_file``'s suffix.

    Only ``.lmdb`` files are supported; extra positional and keyword arguments
    are forwarded to ``LMDBDataset``.

    Raises:
        FileNotFoundError: If ``data_file`` does not exist.
        ValueError: If the file suffix is not ``.lmdb``.
    """
    path = Path(data_file)
    if not path.exists():
        raise FileNotFoundError(path)
    if path.suffix != '.lmdb':
        raise ValueError(f"Unrecognized datafile type {path.suffix}")
    return LMDBDataset(path, *args, **kwargs)

In [None]:

import torch
import numpy as np
class StabilityDataset(Dataset):
    """Stability-regression examples read from ``stability/stability_<split>.lmdb``.

    Args:
        data_path (Union[str, Path]): Directory that contains the ``stability``
            folder.
        split (str): Name of the split to load; must be one of the values
            accepted in ``__init__``.
        in_memory (bool, optional): Forwarded to the underlying dataset to
            enable its record cache. Default: False

    Raises:
        ValueError: If ``split`` is not an accepted split name.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 in_memory: bool = False):

        # Single source of truth for both the check and the error message, so
        # the message can no longer drift from what is actually accepted
        # (previously 'test' was accepted but missing from the message).
        valid_splits = ('train', 'valid', 'casp12', 'ts115', 'cb513', 'test')
        if split not in valid_splits:
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"{list(valid_splits)}")

        data_path = Path(data_path)
        data_file = f'stability/stability_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory=in_memory)

    def __len__(self) -> int:
        """Return the number of examples in the underlying dataset."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Return one example as a dict.

        Returns:
            dict: ``'input_ids'`` is the record's ``'primary'`` field unchanged,
            ``'attention_mask'`` is ``np.ones_like`` of that field, and
            ``'labels'`` is the record's ``'stability_score'`` as a float16
            array.
        """
        item = self.data[index]
        token_ids = item['primary']
        # NOTE(review): if 'primary' is a raw sequence string rather than an
        # array of token ids, this mask is not a per-token mask — confirm
        # before relying on 'attention_mask'.
        input_mask = np.ones_like(token_ids)

        # NOTE(review): float16 keeps only ~3 significant decimal digits of the
        # regression target; consider float32 if precision matters.
        labels = np.asarray(item['stability_score'], np.float16)
        output = {'input_ids': token_ids,
                  'attention_mask': input_mask,
                  'labels': labels}
        return output
    

In [None]:
    
from transformers import AutoTokenizer
# Aliased so the Hugging Face class does not rebind the torch ``Dataset`` name
# that the dataset classes defined in earlier cells inherit from (re-running
# those cells after this one would otherwise pick up the wrong base class).
from datasets import Dataset as HFDataset

data_dir = r'./data'
check_point = r'./esm2_t33_650M_UR50D'


def _collect_sequences_and_labels(dataset):
    """Return ``(sequences, labels)`` lists gathered from a map-style dataset.

    Each example is fetched exactly once; its ``'input_ids'`` and ``'labels'``
    entries are appended to the respective list in index order.
    """
    sequences, labels = [], []
    for index in range(len(dataset)):
        example = dataset[index]  # one read per example instead of two
        sequences.append(example['input_ids'])
        labels.append(example['labels'])
    return sequences, labels


# Raw per-split inputs and regression targets.
train_sequences, train_labels = _collect_sequences_and_labels(StabilityDataset(data_dir, 'train'))
valid_sequences, valid_labels = _collect_sequences_and_labels(StabilityDataset(data_dir, 'valid'))
test_sequences, test_labels = _collect_sequences_and_labels(StabilityDataset(data_dir, 'test'))

tokenizer = AutoTokenizer.from_pretrained(check_point)

train_tokenized = tokenizer(train_sequences)
valid_tokenized = tokenizer(valid_sequences)
test_tokenized = tokenizer(test_sequences)

# Hugging Face datasets consumed by the Trainer: tokenizer output plus a
# "labels" column holding the regression targets.
train_dataset = HFDataset.from_dict(train_tokenized).add_column("labels", train_labels)
valid_dataset = HFDataset.from_dict(valid_tokenized).add_column("labels", valid_labels)
test_dataset = HFDataset.from_dict(test_tokenized).add_column("labels", test_labels)

In [None]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, EsmForSequenceClassification, EsmModel
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch import nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
import logging
# INFO-level root logging with timestamp, logger name and level in each record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class EsmClassificationHead(nn.Module):
    """Sequence-level prediction head applied to the first position's features.

    Args:
        config: Object providing ``hidden_dropout_prob`` and ``num_labels``.
        n_hidden (int, optional): Half of the expected feature width; the head
            consumes vectors of size ``2 * n_hidden``. Default: 1024
    """

    def __init__(self, config, n_hidden=1024):
        super().__init__()
        width = n_hidden * 2
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` (indexed as ``[batch, position, feature]``) to logits.

        Only position 0 of every sequence is used: dropout -> dense -> tanh ->
        dropout -> output projection.
        """
        first_token = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        hidden = torch.tanh(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(hidden))
class EsmForSequenceClassificationMLP(EsmForSequenceClassification):
    """ESM sequence-level model with a frozen encoder and a CNN + BiLSTM head.

    The encoder's token embeddings are concatenated with the outputs of two
    wide 1-D convolution branches (32 channels each), batch-normalised, run
    through a 2-layer bidirectional LSTM, and handed to
    ``EsmClassificationHead`` to produce the logits.

    Args:
        config: ESM model configuration (``hidden_size``, ``num_labels``,
            ``hidden_dropout_prob``, ``problem_type`` are read).
        n_hidden (int, optional): Hidden size per LSTM direction; the head
            receives ``2 * n_hidden`` features. Default: 1024
    """

    def __init__(self, config, n_hidden=1024):
        super(EsmForSequenceClassificationMLP, self).__init__(config)
        self.esm = EsmModel(config, add_pooling_layer=False)
        # LSTM input = encoder features + 32 channels from each conv branch.
        self.lstm = nn.LSTM(input_size=config.hidden_size+64, hidden_size=n_hidden, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5)
        self.lstm_dropout = nn.Dropout(p=0.5)
        # Odd kernels with padding = kernel_size // 2 keep the sequence length.
        self.conv1_1d = nn.Sequential(*[
            nn.Dropout(p=0.5),
            nn.Conv1d(in_channels=config.hidden_size, out_channels=32, kernel_size=129, padding=64),
            nn.ReLU(),
        ])
        self.conv2_1d = nn.Sequential(*[
            nn.Dropout(p=0.5),
            nn.Conv1d(in_channels=config.hidden_size, out_channels=32, kernel_size=257, padding=128),
            nn.ReLU(),
        ])
        self.batch_norm = nn.BatchNorm1d(config.hidden_size+64, track_running_stats=False)
        # Freeze the encoder; only the layers defined here receive gradients.
        for param in self.esm.parameters():
            param.requires_grad = False
        # Forward n_hidden so the head width always matches the BiLSTM output
        # (previously the head was fixed at 2 * 1024 regardless of n_hidden).
        self.classifier = EsmClassificationHead(config, n_hidden=n_hidden)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy). When `labels` is `None` no
            loss is computed and only the logits are returned.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        batch_size, length, _ = sequence_output.size()

        # Per-sequence lengths for packing; kept on CPU as int64.
        if attention_mask is None:
            # No mask supplied: treat every position of every sequence as real.
            lengths = torch.full((batch_size,), length, dtype=torch.long)
        else:
            lengths = torch.sum(attention_mask, dim=1).cpu().long()

        # Conv1d and BatchNorm1d expect channels on dim 1.
        x = sequence_output.permute(0,2,1)
        r1 = self.conv1_1d(x)
        r2 = self.conv2_1d(x)
        # concatenate channels from residuals and input
        x = torch.cat([x, r1, r2], dim=1)

        x = self.batch_norm(x)
        x = x.permute(0,2,1)
        # Pack so the LSTM skips padded positions, then restore the padded layout.
        x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        x, _ = self.lstm(x)
        x, _ = pad_packed_sequence(x, total_length=length, batch_first=True)

        x = self.lstm_dropout(x)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type from the labels' ORIGINAL dtype; casting
            # before this point would make integer class labels look like floats.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                # MSELoss needs float targets (labels may arrive as float16).
                labels = labels.to(torch.float32)
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                # CrossEntropyLoss needs integer class indices.
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())
            elif self.config.problem_type == "multi_label_classification":
                # BCEWithLogitsLoss needs float targets.
                labels = labels.to(torch.float32)
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Single regression target (stability score).
num_labels = 1
model = EsmForSequenceClassificationMLP.from_pretrained(check_point,problem_type="regression", num_labels=num_labels)
model_name="esm2_t33_650M_UR50D"
batch_size = 30

args = TrainingArguments(
    f"{model_name}-stability-cnn_lstm_230731",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=5e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=100,
    weight_decay=0.001,
    save_total_limit=10,
    load_best_model_at_end=True,
    metric_for_best_model="mse",
    # MSE is an error: lower is better. Without this flag the Trainer treats
    # any metric other than the loss as "greater is better" and would reload
    # the checkpoint with the HIGHEST validation MSE at the end of training.
    greater_is_better=False,
    push_to_hub=False,
    fp16=False,
    fp16_full_eval=False,
)
def mse(outputs, labels):
    """Sum of squared errors divided by the number of examples.

    Args:
        outputs (torch.Tensor): Predicted values.
        labels (torch.Tensor): Target values.

    Returns:
        torch.Tensor: 0-d tensor ``sum((outputs - labels) ** 2) / len(labels)``.
    """
    # Predictions of shape (N, 1) against labels of shape (N,) would silently
    # broadcast to an (N, N) difference matrix and inflate the result; when the
    # element counts agree, align the prediction shape with the labels instead.
    if outputs.shape != labels.shape and outputs.numel() == labels.numel():
        outputs = outputs.reshape(labels.shape)
    loss = torch.square(outputs - labels)
    return torch.sum(loss) / len(labels)

def compute_metrics(eval_pred):
    """Metric callback: unpack ``(predictions, labels)`` and report ``mse``.

    Both members of ``eval_pred`` are converted with ``torch.tensor`` before
    being handed to ``mse``.

    Returns:
        dict: ``{"mse": <result of mse(predictions, labels)>}``.
    """
    raw_predictions, raw_labels = eval_pred
    prediction_tensor = torch.tensor(raw_predictions)
    label_tensor = torch.tensor(raw_labels)
    return {"mse": mse(prediction_tensor, label_tensor)}
# Wire the model, hyper-parameters, data splits and metric callback together,
# then run training.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()


In [None]:
# Run inference on the test split and persist predictions, labels and metrics.
predictions,labels,metric = trainer.predict(test_dataset)

# torch.save does not create missing parent directories; make sure the output
# folder exists so a long training run is not lost to a missing directory.
result_dir = Path("./result/sta-train")
result_dir.mkdir(parents=True, exist_ok=True)

torch.save(labels, result_dir / "labels.pt")
torch.save(predictions, result_dir / "predictions.pt")
torch.save(metric, result_dir / "metric.pt")