In [16]:
import esm
# Instantiate the pretrained checkpoint selected by the loader name
# (presumably ESM-2 with 33 layers / 650M parameters -- inferred from the
# function name, confirm against the esm docs). The call returns a pair that
# is unpacked into `model` and `vocab`.
model, vocab = esm.pretrained.esm2_t33_650M_UR50D()

In [2]:
# Materialise the iterator returned by model.modules() into a list so the
# individual modules can be inspected by index.
modules = list(model.modules())

In [17]:
from typing import Union
import torch
from esm.model.esm2 import ESM2


class ESMModel(ESM2):
    """ESM2 whose transformer layers are spread over the visible CUDA devices.

    The numerical computation is left entirely to ``ESM2.forward``; this
    subclass only decides on which device each entry of ``self.layers`` lives
    and installs forward hooks that carry activations to a layer's device and
    back again.  Everything that is not an entry of ``self.layers`` stays
    together on the first device, so parameters shared between those modules
    are never split across devices.

    When no CUDA device is available nothing is moved and the model behaves
    like a plain ``ESM2`` on its current device.

    Note: the hooks use ``register_forward_pre_hook(..., with_kwargs=True)``,
    which requires PyTorch >= 2.0.
    """

    def __init__(
        self,
        num_layers: int = 33,
        embed_dim: int = 1280,
        attention_heads: int = 20,
        alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
        token_dropout: bool = True,
    ) -> None:
        """Build the underlying ESM2 model.

        All arguments are forwarded unchanged, in this order, to
        ``ESM2.__init__``.  No device placement happens here, so weights can
        still be loaded on the CPU before the first forward pass.
        """
        super().__init__(num_layers, embed_dim, attention_heads, alphabet, token_dropout)
        # Placement is done lazily by the first forward() (or an explicit
        # distribute() call); the handles guard against double registration.
        self._is_distributed = False
        self._hook_handles = []

    @staticmethod
    def _move(obj, device):
        """Return ``obj`` with every tensor found in it moved to ``device``.

        Tuples, lists and dicts are rebuilt recursively; any other object
        (``None``, bools, ...) is returned untouched.
        """
        if isinstance(obj, torch.Tensor):
            # Tensor.to is NOT in-place: the returned tensor must be used.
            return obj.to(device)
        if isinstance(obj, tuple):
            return tuple(ESMModel._move(item, device) for item in obj)
        if isinstance(obj, list):
            return [ESMModel._move(item, device) for item in obj]
        if isinstance(obj, dict):
            return {key: ESMModel._move(value, device) for key, value in obj.items()}
        return obj

    @staticmethod
    def _inputs_to_module_device(module, args, kwargs):
        """Forward pre-hook: move a module's inputs onto that module's device."""
        # Looked up at call time so the hook stays correct even if the module
        # is moved again later (e.g. by a subsequent ``model.to(...)``).
        device = next(module.parameters()).device
        return ESMModel._move(args, device), ESMModel._move(kwargs, device)

    def _outputs_to_home_device(self, module, args, output):
        """Forward hook: bring a layer's outputs back to the embedding's device."""
        # Returning every layer output to one common device keeps the parent
        # forward() free of cross-device operations between layers.
        return ESMModel._move(output, self.embed_tokens.weight.device)

    def distribute(self, devices=None):
        """Place the transformer layers on ``devices`` in contiguous blocks.

        Args:
            devices: iterable of device specifiers.  Defaults to every visible
                CUDA device (``cuda:0`` ... ``cuda:N-1``).  If the resulting
                list is empty, no module is moved.

        Returns:
            ``self``, to allow call chaining.
        """
        if devices is None:
            devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        devices = list(devices)

        if not self._hook_handles:
            for layer in self.layers:
                self._hook_handles.append(
                    layer.register_forward_pre_hook(
                        self._inputs_to_module_device, with_kwargs=True
                    )
                )
                self._hook_handles.append(
                    layer.register_forward_hook(self._outputs_to_home_device)
                )

        if devices:
            # First put the whole model on the first device, then relocate
            # only the individual transformer layers.
            self.to(devices[0])
            layer_count = len(self.layers)
            for index, layer in enumerate(self.layers):
                layer.to(devices[index * len(devices) // layer_count])

        self._is_distributed = True
        return self

    def forward(self, x, *args, **kwargs):
        """Run the regular ESM2 forward pass on the distributed model.

        Args:
            x: input tensor handed to ``ESM2.forward`` as its first argument
                (presumably the token tensor -- confirm against the esm docs).
            *args, **kwargs: passed through unchanged to ``ESM2.forward``.

        Returns:
            Whatever ``ESM2.forward`` returns.
        """
        if not self._is_distributed:
            self.distribute()
        # Inputs must live where the embedding lives; reassign because
        # Tensor.to returns a new tensor rather than moving in place.
        x = x.to(self.embed_tokens.weight.device)
        return super().forward(x, *args, **kwargs)

In [18]:
# Build an ESMModel with its default hyperparameters and copy the weights of
# the already-loaded `model` into it.
esm_model = ESMModel()
esm_model.load_state_dict(model.state_dict())
# NOTE: this rebinds `model` to the new instance; from here on the originally
# loaded object is no longer reachable through the name `model`.
model = esm_model

In [20]:
# Show the module tree of the (re-bound) model as the cell output.
model

ESMModel(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280,