In [1]:
import os

import torch
import datasets

from torch.utils.flop_counter import FlopCounterMode
from torch.utils.data import DataLoader, Dataset, TensorDataset, RandomSampler

from torch import nn

from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    SchedulerType,
    HfArgumentParser,
    is_torch_tpu_available,
    get_scheduler,
    set_seed,
    TrainingArguments
)

In [2]:
# Directory holding the on-disk dataset. The original notebook hard-coded an
# absolute path inside one user's home directory, so it only ran on that
# machine. Set the UNIREF50_DIR environment variable to point elsewhere; when
# it is unset the original path is used, so existing behaviour is unchanged.
dataset_dir = os.environ.get("UNIREF50_DIR", "/home/jarekk/datasets/uniref50")

# Load the saved splits and display their summary as the cell output.
dataset = datasets.load_from_disk(dataset_dir)
dataset

Loading dataset from disk:   0%|          | 0/62 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'special_tokens_mask', 'attention_mask'],
        num_rows: 10000000
    })
    validation: Dataset({
        features: ['input_ids', 'special_tokens_mask', 'attention_mask'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['input_ids', 'special_tokens_mask', 'attention_mask'],
        num_rows: 50000
    })
})

In [3]:
# Checkpoint identifier; used here for the tokenizer and later for the config.
model_id = "facebook/esm2_t36_3B_UR50D"
# Masking probability handed to the masked-LM collator.
mlm_probability = 0.15

# No extra tokenizer options; the dict is kept so options can be added in one place.
tokenizer_kwargs = {}
tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)

# The collator turns dataset rows into a batch (its output includes 'labels').
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=mlm_probability,
)

# One training example per batch.
loader = DataLoader(dataset["train"], batch_size=1, collate_fn=data_collator)

In [4]:
# Pull the first collated batch from the loader. The iterator is kept in its
# own variable so further batches can be drawn with next(loader_iter).
loader_iter = iter(loader)
batch = next(loader_iter)

In [5]:
# Display the collated batch.
# NOTE(review): this prints the full tensors; showing only shapes would keep the output short.
batch

{'input_ids': tensor([[ 0, 20,  4, 19,  7,  4, 11, 32,  5, 17,  9, 15, 19, 20,  6, 13, 32,  4,
          4, 16, 15, 11, 19, 32,  7, 13, 18,  4, 32, 15, 15, 10,  7, 15, 17, 16,
          6, 12,  7, 32, 16, 19, 19, 12,  9, 13, 17, 21,  9,  5, 12, 32, 14, 32,
          9,  4, 18, 19, 10,  5, 16,  9,  9, 15,  5, 10, 32,  5,  5, 12, 19, 32,
         14,  8, 32, 32, 15, 15,  8, 15, 11,  9, 15,  8, 18, 27,  8,  8, 15, 19,
          8,  4,  8, 32, 12, 20,  7, 23,  6, 32,  5,  6, 16, 14, 32, 10, 32, 16,
         11, 22,  8, 15, 19,  6, 32, 15,  8,  5,  7, 22, 10, 23, 13, 17, 10,  4,
         32, 17,  6, 11, 15, 17, 23, 15, 21,  8, 14, 11,  4, 32,  9, 16, 14,  4,
         19,  9,  5, 12, 20, 11,  5, 12, 17,  8,  7,  7,  9, 32, 16,  6,  9, 18,
          7,  6,  5, 32, 10,  9, 17,  7, 12, 28, 32, 12,  6,  8, 19,  8, 11, 10,
         21, 12, 11,  8,  9, 19, 13,  9, 16, 12, 29, 15,  4, 16, 32, 32, 20,  4,
         11,  4, 12,  9,  9, 17,  5, 15, 16,  6,  8, 32, 17,  9, 17,  8, 13,  9,
         16, 1

In [6]:
# Pull out of the batch the three entries that are later passed to the model call.
input_ids, attention_mask, labels = (
    batch['input_ids'],
    batch['attention_mask'],
    batch['labels'],
)

In [7]:

# Fetch the configuration for the same checkpoint as the tokenizer.
# No overrides are applied; config_kwargs is the single place to add them.
config_kwargs = {}
config = AutoConfig.from_pretrained(model_id, **config_kwargs)

In [8]:
# Display the loaded configuration.
config

EsmConfig {
  "_name_or_path": "facebook/esm2_t36_3B_UR50D",
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "intermediate_size": 10240,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 40,
  "num_hidden_layers": 36,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.39.3",
  "use_cache": true,
  "vocab_list": null,
  "vocab_size": 33
}

In [9]:
# Convert the three batch tensors to NumPy arrays, presumably so the next cell
# can rebuild them as fresh tensors under the meta-device context.
# NOTE(review): the names are rebound from tensors to arrays, so this cell
# cannot be re-run on its own (an ndarray has no .numpy()); re-run the cell
# that extracts them from `batch` first.
input_ids, attention_mask, labels = (
    tensor.numpy() for tensor in (input_ids, attention_mask, labels)
)

In [10]:
# Everything constructed inside this context is created on the "meta" device.
# The model is built from the config (not from_pretrained), and the three
# inputs are re-created as tensors inside the same context.
with torch.device("meta"):
    model = AutoModelForMaskedLM.from_config(config)
    input_ids, attention_mask, labels = map(
        torch.tensor, (input_ids, attention_mask, labels)
    )
    

In [11]:
# Print the repr of every model parameter; each line shows the parameter's
# device, size and requires_grad flag.
for param in model.parameters():
    print(param)

Parameter containing:
tensor(..., device='meta', size=(33, 2560), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(1026, 2560), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560, 2560), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560, 2560), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560, 2560), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560, 2560), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560,), requires_grad=True)
Parameter containing:
tensor(..., device='meta', size=(2560

In [11]:
# FLOP counter bound to `model`, so counts can be broken down per submodule.
# display=True asks it to print its summary table.
flop_counter = FlopCounterMode(mods=model, display=True)

In [12]:
# Run one forward pass and one backward pass inside the counter, so the FLOPs
# of both are tallied together. Labels are passed so the output carries a loss.
with flop_counter:
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    outputs.loss.backward()

Module                        FLOP    % Total
-----------------------  ---------  ---------
EsmForMaskedLM           9007.611B    100.00%
 - aten.addmm            2905.814B     32.26%
 - aten.bmm               289.910B      3.22%
 - aten.mm               5811.887B     64.52%
 EsmForMaskedLM.esm      8987.219B     99.77%
  - aten.addmm           2899.103B     32.19%
  - aten.bmm              289.910B      3.22%
  - aten.mm              5798.206B     64.37%
 EsmForMaskedLM.lm_head    20.392B      0.23%
  - aten.addmm              6.711B      0.07%
  - aten.mm                13.681B      0.15%


In [29]:
# Total FLOPs accumulated by the counter, as a plain integer.
# NOTE(review): the saved output of this cell (2131171737600) does not match
# the 9007.611B total in the table printed by the counted forward/backward
# run, and this cell's execution count is out of sequence. The value looks
# stale; re-run this cell right after the counted run.
flop_counter.get_total_flops()

2131171737600

In [14]:
# Number of parameters reported by the model.
model.num_parameters()

2841632194

In [15]:
# Display the module tree of the model.
model

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 2560, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 2560, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-35): 36 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=2560, out_features=2560, bias=True)
              (key): Linear(in_features=2560, out_features=2560, bias=True)
              (value): Linear(in_features=2560, out_features=2560, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=2560, out_features=2560, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((2560,), eps=1e-05, 