In [1]:
import lmdb
from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
from torch.utils.data import Dataset
from pathlib import Path
import pickle as pkl

class LMDBDataset(Dataset):
    """Creates a dataset from an lmdb file.
    Args:
        data_file (Union[str, Path]): Path to lmdb file.
        in_memory (bool, optional): Whether to load the full dataset into memory.
            Default: False.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):

        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)

        env = lmdb.open(str(data_file), max_readers=1, readonly=True,
                        lock=False, readahead=False, meminit=False)

        with env.begin(write=False) as txn:
            num_examples = pkl.loads(txn.get(b'num_examples'))

        if in_memory:
            cache = [None] * num_examples
            self._cache = cache

        self._env = env
        self._in_memory = in_memory
        self._num_examples = num_examples

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        if not 0 <= index < self._num_examples:
            raise IndexError(index)

        if self._in_memory and self._cache[index] is not None:
            item = self._cache[index]
        else:
            with self._env.begin(write=False) as txn:
                item = pkl.loads(txn.get(str(index).encode()))
                if 'id' not in item:
                    item['id'] = str(index)
                if self._in_memory:
                    self._cache[index] = item
        return item

In [2]:
def dataset_factory(data_file: Union[str, Path], *args, **kwargs) -> Dataset:
    data_file = Path(data_file)
    if not data_file.exists():
        raise FileNotFoundError(data_file)
    if data_file.suffix == '.lmdb':
        return LMDBDataset(data_file, *args, **kwargs)
    else:
        raise ValueError(f"Unrecognized datafile type {data_file.suffix}")

In [3]:
import torch
import numpy as np
class FluorescenceDataset(Dataset):

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 in_memory: bool = False):

        if split not in ('train', 'valid', 'casp12', 'ts115', 'cb513','test'):
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"['train', 'valid', 'casp12', "
                             f"'ts115', 'cb513']")

        data_path = Path(data_path)
        data_file = f'fluorescence/fluorescence_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        token_ids = item['primary']
        input_mask = np.ones_like(token_ids)

        # pad with -1s because of cls/sep tokens
        labels = np.asarray(item['log_fluorescence'], np.float16)
        output = {'input_ids': token_ids,
                'attention_mask': input_mask,
                'labels': labels}
        return output

In [4]:
data_dir = r'./data'
train_dataset = FluorescenceDataset(data_dir,'train')
valid_dataset = FluorescenceDataset(data_dir,'valid')
test_dataset = FluorescenceDataset(data_dir,'test')

In [5]:
train_sequences=[]
train_labels=[]
for seq in range(train_dataset.__len__()):
    train_sequences.append(train_dataset[seq]['input_ids'])
    train_labels.append(train_dataset[seq]['labels'])

In [6]:
valid_sequences=[]
valid_labels=[]
for seq in range(valid_dataset.__len__()):
    valid_sequences.append(valid_dataset[seq]['input_ids'])
    valid_labels.append(valid_dataset[seq]['labels'])

In [7]:
test_sequences=[]
test_labels=[]
for seq in range(test_dataset.__len__()):
    test_sequences.append(test_dataset[seq]['input_ids'])
    test_labels.append(test_dataset[seq]['labels'])

In [8]:
print(train_labels[0:100])

[array([3.824], dtype=float16), array([3.752], dtype=float16), array([3.541], dtype=float16), array([3.691], dtype=float16), array([3.688], dtype=float16), array([3.188], dtype=float16), array([1.301], dtype=float16), array([1.301], dtype=float16), array([3.504], dtype=float16), array([3.719], dtype=float16), array([3.754], dtype=float16), array([3.613], dtype=float16), array([3.268], dtype=float16), array([1.301], dtype=float16), array([1.301], dtype=float16), array([3.186], dtype=float16), array([3.725], dtype=float16), array([1.534], dtype=float16), array([3.64], dtype=float16), array([3.137], dtype=float16), array([3.756], dtype=float16), array([1.301], dtype=float16), array([3.678], dtype=float16), array([1.506], dtype=float16), array([3.607], dtype=float16), array([3.701], dtype=float16), array([3.318], dtype=float16), array([3.463], dtype=float16), array([3.605], dtype=float16), array([3.26], dtype=float16), array([3.492], dtype=float16), array([3.719], dtype=float16), array([3.

In [9]:
check_point = r"./esm2_t30_150M_UR50D"

In [10]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(check_point)

train_tokenized = tokenizer(train_sequences)
valid_tokenized = tokenizer(valid_sequences)
test_tokenized = tokenizer(test_sequences)

In [11]:
print(train_tokenized.keys())

dict_keys(['input_ids', 'attention_mask'])


In [12]:
from datasets import Dataset

train_dataset = Dataset.from_dict(train_tokenized)
valid_dataset = Dataset.from_dict(valid_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
valid_dataset = valid_dataset.add_column("labels",valid_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

In [13]:
print(train_dataset)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 21446
})


In [14]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, EsmForSequenceClassification
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch import nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import SequenceClassifierOutput
import logging
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class EsmClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # logger.info(f"x.shape:{x.shape}")
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
class EsmForSequenceClassificationMLP(EsmForSequenceClassification):
    def __init__(self, config):
        super(EsmForSequenceClassificationMLP, self).__init__(config)
        for param in self.esm.parameters():
            param.requires_grad = False
        self.classifier = EsmClassificationHead(config)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [15]:
num_labels = 1
model = EsmForSequenceClassificationMLP.from_pretrained(check_point,problem_type="regression", num_labels=num_labels)

Some weights of the model checkpoint at ./esm2_t30_150M_UR50D were not used when initializing EsmForSequenceClassificationMLP: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmForSequenceClassificationMLP from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassificationMLP from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassificationMLP were not initialized from the model checkpoint at ./esm2_t30_150M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias', 

In [16]:
# model_name = model_checkpoint.split("/")[-1]
model_name="esm2_t30_150M_UR50D"
batch_size = 30

args = TrainingArguments(
    f"{model_name}-Fluorescence-mlp-0720",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=50,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="mse",
    push_to_hub=False,
    fp16=True,
    fp16_full_eval=True,
)

In [17]:
# from evaluate import load
# import numpy as np

# metric = load("mae")

# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     # predictions = np.argmax(predictions, axis=1)
#     return metric.compute(predictions=predictions, references=labels)

In [18]:
def mse(outputs, labels):
    loss = torch.square(outputs - labels)
    return torch.sum(loss) / len(labels)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = torch.tensor(predictions)
    labels = torch.tensor(labels)
    return {"mse":mse(predictions,labels)}

In [19]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [20]:
trainer.train()



Epoch,Training Loss,Validation Loss,Mse
1,1.9931,0.671636,0.671875
2,0.6669,0.648415,0.648438
3,0.642,0.639091,0.63916
4,0.638,0.626722,0.626465
5,0.6284,0.625319,0.625
6,0.6304,0.614774,0.614746
7,0.6147,0.610322,0.610352
8,0.6101,0.607309,0.607422
9,0.6077,0.604133,0.604492
10,0.6024,0.598005,0.598145


In [None]:
trainer.evaluate()

{'eval_loss': 1.056227684020996,
 'eval_mse': 1.056640625,
 'eval_runtime': 55.738,
 'eval_samples_per_second': 96.2,
 'eval_steps_per_second': 6.028,
 'epoch': 50.0}

In [None]:
predictions,labels,metric = trainer.predict(test_dataset)

In [None]:
metric

{'test_loss': 0.5903475284576416,
 'test_mse': 0.59033203125,
 'test_runtime': 271.8118,
 'test_samples_per_second': 100.132,
 'test_steps_per_second': 6.262}