In [1]:
import os

import torch
import datasets

from torch.utils.flop_counter import FlopCounterMode
from torch.utils.data import DataLoader, Dataset, TensorDataset, RandomSampler

from torch import nn

from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    SchedulerType,
    HfArgumentParser,
    is_torch_tpu_available,
    get_scheduler,
    set_seed,
    TrainingArguments
)

In [2]:
# Location of the tokenized UniRef50 DatasetDict saved with `save_to_disk`.
# Set the UNIREF50_DIR environment variable to point elsewhere instead of
# editing a hard-coded, machine-specific path.
dataset_dir = os.environ.get("UNIREF50_DIR", "/home/jarekk/datasets/uniref50")

dataset = datasets.load_from_disk(dataset_dir)
dataset

Loading dataset from disk:   0%|          | 0/62 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'special_tokens_mask', 'attention_mask'],
        num_rows: 10000000
    })
    validation: Dataset({
        features: ['input_ids', 'special_tokens_mask', 'attention_mask'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['input_ids', 'special_tokens_mask', 'attention_mask'],
        num_rows: 50000
    })
})

In [3]:
# Checkpoint whose architecture is traced for FLOPs, and the MLM masking rate.
model_id = "facebook/esm2_t48_15B_UR50D"
mlm_probability = 0.15
seed = 42

# DataCollatorForLanguageModeling chooses masked positions at random; seed all
# RNGs so the sampled batch (and its labels) is reproducible on re-run.
set_seed(seed)

tokenizer_kwargs = {}

tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=mlm_probability,
)

# batch_size=1 so one batch is one sequence; no sampler/shuffle is given, so
# the loader walks the train split in order.
loader = DataLoader(
    dataset["train"],
    collate_fn=data_collator,
    batch_size=1,
)

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [4]:
# Draw the first collated (randomly masked) batch from the training loader.
# The iterator is kept so further batches can be pulled with next(loader_iter).
loader_iter = iter(loader)
batch = next(loader_iter)

In [5]:
# Inspect the collated batch; later cells read its input_ids, attention_mask and labels.
batch

{'input_ids': tensor([[ 0, 20,  4, 19,  7, 32, 11, 16, 32, 17,  9, 15, 19, 20,  6, 13,  5,  4,
          4, 16, 15, 11, 19, 11,  7, 32, 32,  4,  8, 15, 32, 10,  7, 15,  0, 16,
          6, 12,  7, 14, 32, 19, 19, 12,  9, 32, 17, 21,  9,  5, 12, 32, 14, 10,
          9,  4, 18, 32, 10,  5, 16,  9,  9, 15,  5, 32, 10,  5, 32, 12, 19, 10,
         10,  8, 11,  9, 15, 15,  8, 15, 32,  9, 32,  8, 15, 19,  8,  8, 15, 19,
          8, 32, 29, 13, 12, 20,  7, 23,  6,  9, 23,  6, 16, 14, 19, 10, 10, 16,
         11, 32,  8, 32, 19,  6, 16, 15,  8,  5,  7, 22, 10, 23, 13, 17, 10,  4,
         32, 17,  6, 11, 15, 17, 23, 15, 32,  8, 14, 32, 32, 15,  9, 16, 14,  4,
         19,  9,  5, 32, 20, 11,  5, 12, 17,  8,  7,  7,  9, 17, 16,  6, 32, 18,
          7,  6,  5, 18, 10,  9, 25,  7, 12, 10,  7, 12,  6,  8, 19,  8, 11, 10,
         21, 12, 14,  8,  9, 19, 32, 32, 16, 12,  9, 15,  4, 16,  6,  9, 20,  4,
         11,  4, 12,  9,  9, 17,  5, 15, 16,  6,  8,  7, 17,  9, 17,  8, 13,  9,
         16, 1

In [6]:
# Pull out the three tensors the model call needs from the collated batch.
input_ids, attention_mask, labels = (
    batch[key] for key in ("input_ids", "attention_mask", "labels")
)

In [7]:

# Fetch only the architecture config for the checkpoint (no weights);
# put any overrides into config_kwargs.
config_kwargs = {}
config = AutoConfig.from_pretrained(model_id, **config_kwargs)

config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

In [8]:
# Display the loaded config (layer count, hidden size, vocab, ...).
config

EsmConfig {
  "_name_or_path": "facebook/esm2_t48_15B_UR50D",
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 5120,
  "initializer_range": 0.02,
  "intermediate_size": 20480,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 40,
  "num_hidden_layers": 48,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.39.3",
  "use_cache": true,
  "vocab_list": null,
  "vocab_size": 33
}

In [9]:
# CPU numpy copies of the batch tensors. Read from `batch` rather than from the
# current input_ids/attention_mask/labels so this cell stays re-runnable: after
# the meta-device cell those names hold meta tensors, which have no data to
# export, and re-running on numpy arrays would also fail (no `.numpy()`).
input_ids = batch["input_ids"].numpy()
attention_mask = batch["attention_mask"].numpy()
labels = batch["labels"].numpy()

In [10]:
# Build the model on the meta device: parameters get shapes/dtypes but no
# storage, so the large checkpoint architecture can be traced without
# allocating weights.
with torch.device("meta"):
    model = AutoModelForMaskedLM.from_config(config)

# Move the inputs to meta straight from `batch` (instead of re-wrapping the
# previous cell's numpy arrays), so this cell can be re-run safely.
# Values are dropped on meta; only shapes/dtypes remain, which is all the
# FLOP counter needs.
input_ids = batch["input_ids"].to("meta")
attention_mask = batch["attention_mask"].to("meta")
labels = batch["labels"].to("meta")
    

In [12]:
# Print the module tree to confirm the architecture built from the config.
print(model)

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 5120, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 5120, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-47): 48 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=5120, out_features=5120, bias=True)
              (key): Linear(in_features=5120, out_features=5120, bias=True)
              (value): Linear(in_features=5120, out_features=5120, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=5120, out_features=5120, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((5120,), eps=1e-05, 

In [13]:
# Counter that attributes FLOPs to `model`'s submodules and prints a
# per-module table when its context exits.
flop_counter = FlopCounterMode(mods=model, display=True)

In [14]:
# One full training step (forward with MLM loss, then backward) inside the
# counter, so FLOPs from both passes are recorded for the traced batch.
with flop_counter:
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    outputs.loss.backward()

Module                         FLOP    % Total
-----------------------  ----------  ---------
EsmForMaskedLM           47239.791B    100.00%
 - aten.addmm            15488.726B     32.79%
 - aten.bmm                773.094B      1.64%
 - aten.mm               30977.971B     65.58%
 EsmForMaskedLM.esm      47158.741B     99.83%
  - aten.addmm           15461.882B     32.73%
  - aten.bmm               773.094B      1.64%
  - aten.mm              30923.765B     65.46%
 EsmForMaskedLM.lm_head     81.050B      0.17%
  - aten.addmm              26.844B      0.06%
  - aten.mm                 54.206B      0.11%


In [18]:
# get_total_flops() covers the whole traced batch (forward + backward).
# Divide by the batch dimension so the value is per sequence for any batch
# size; integer division keeps an int and is a no-op for batch_size=1.
flops_per_sequence = flop_counter.get_total_flops() // input_ids.shape[0]
flops_per_sequence

47239790592000

In [19]:
# Training FLOPs (forward + backward) per token position. Use the traced
# sequence length rather than a hard-coded 512 so the figure stays correct if
# the data or padding length changes. Padded positions are counted because
# the model processes them too.
tokens_per_sequence = input_ids.shape[-1]
model_flops = flops_per_sequence / tokens_per_sequence
model_flops

92265216000.0