In [2]:
import re
import urllib
import warnings
from argparse import Namespace
from pathlib import Path

import torch


model_name = "esm2_t6_8M_UR50D"


def load_hub_checkpoint(url):
    """Load a checkpoint onto the CPU via the torch hub cache.

    The file is looked up under ``<torch.hub.get_dir()>/checkpoints/`` by the
    last path component of ``url`` and is downloaded there first when it is
    not cached yet.

    Args:
        url: HTTP(S) location of the ``.pt`` file.

    Returns:
        Whatever object the checkpoint file deserializes to.
    """
    try:
        return torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Some torch releases raise here even though the download succeeded
        # (see pytorch/pytorch#43106); read the cached file directly instead.
        cached_file = Path(torch.hub.get_dir()) / "checkpoints" / Path(url).name
        return torch.load(cached_file, map_location="cpu")


# Main model checkpoint and the matching contact-regression weights.
model_data = load_hub_checkpoint(
    f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
)
regression_data = load_hub_checkpoint(
    f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt"
)

  from .autonotebook import tqdm as notebook_tqdm


In [60]:
def upgrade_state_dict(state_dict):
    """Return a new dict with the leading encoder prefix removed from each key.

    A key that starts with 'encoder.sentence_encoder.' or 'encoder.' has that
    prefix stripped; any other key is kept as is. Values are passed through
    without copying, and the input mapping is not modified.
    """
    prefixes = ["encoder.sentence_encoder.", "encoder."]
    # The alternatives are grouped so "^" anchors every one of them, and
    # escaped so "." only matches a literal dot. The longer prefix is listed
    # first so it wins over its own shorter head.
    pattern = re.compile("^(?:" + "|".join(re.escape(p) for p in prefixes) + ")")
    return {pattern.sub("", name): param for name, param in state_dict.items()}

# Pull the model config and weights out of the hub checkpoints and re-key the
# weights to the un-prefixed names.
cfg = model_data["cfg"]["model"]
state_dict = upgrade_state_dict(model_data["model"])
regression_dict = regression_data["model"]

PATH = "dataset/checkpoints/esm2_t6_8M_UR50D.pt"

# Write the merged checkpoint only when it is missing, so the cell works on a
# fresh checkout and does not overwrite an existing file on re-runs.
if not Path(PATH).exists():
    Path(PATH).parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            'cfg_model': cfg,
            'model': state_dict,
            'regression': regression_dict,
        },
        PATH,
    )

# Read the file back on the CPU (same device mapping as the hub loads) and
# list each parameter name with its shape.
checkpoint = torch.load(PATH, map_location="cpu")
for k, v in checkpoint['model'].items():
    print(k, '      ', v.shape)

embed_tokens.weight        torch.Size([33, 320])
layers.0.self_attn.k_proj.weight        torch.Size([320, 320])
layers.0.self_attn.k_proj.bias        torch.Size([320])
layers.0.self_attn.v_proj.weight        torch.Size([320, 320])
layers.0.self_attn.v_proj.bias        torch.Size([320])
layers.0.self_attn.q_proj.weight        torch.Size([320, 320])
layers.0.self_attn.q_proj.bias        torch.Size([320])
layers.0.self_attn.out_proj.weight        torch.Size([320, 320])
layers.0.self_attn.out_proj.bias        torch.Size([320])
layers.0.self_attn.rot_emb.inv_freq        torch.Size([8])
layers.0.self_attn_layer_norm.weight        torch.Size([320])
layers.0.self_attn_layer_norm.bias        torch.Size([320])
layers.0.fc1.weight        torch.Size([1280, 320])
layers.0.fc1.bias        torch.Size([1280])
layers.0.fc2.weight        torch.Size([320, 1280])
layers.0.fc2.bias        torch.Size([320])
layers.0.final_layer_norm.weight        torch.Size([320])
layers.0.final_layer_norm.bias        torch

In [25]:
from modeling.esm.data import Alphabet
from modeling.esm.esm2 import ESM2

In [26]:
alphabet = Alphabet.from_architecture("ESM-1b")

cfg = checkpoint['cfg_model']

# Merge into a fresh dict so checkpoint['model'] itself is left untouched and
# re-running this cell starts from the same inputs.
model_state = dict(checkpoint['model'])
# Use the regression weights stored in this checkpoint rather than a variable
# left over from another cell; a checkpoint without them is tolerated.
regression_state = checkpoint.get("regression")
if regression_state is not None:
    model_state.update(regression_state)

model = ESM2(
    num_layers=cfg.encoder_layers, # 6
    embed_dim=cfg.encoder_embed_dim, # 320
    attention_heads=cfg.encoder_attention_heads, # 20
    alphabet=alphabet,
    token_dropout=cfg.token_dropout, # True
)

# Require an exact key match only when the regression weights were merged in;
# presumably ESM2 defines contact-regression parameters that are absent from
# 'model' alone -- confirm in modeling.esm.esm2.
model.load_state_dict(model_state, strict=regression_state is not None)

<All keys matched successfully>

In [30]:
from modeling.globalnet import GlobalNet
import utils
import argparse

In [41]:
# Build the run configuration from the project's argument definitions.
parser = argparse.ArgumentParser()

# Project helper; presumably registers the project's options on `parser` --
# confirm in utils.
utils.add_argument(parser)
# An explicit empty argument list keeps argparse away from the kernel's own
# command line, so every option takes its declared default.
args: utils.Args = parser.parse_args(args=[])


# Rebinds `model` (an earlier cell bound it to the ESM2 instance). `cfg` and
# `alphabet` are the values set by that earlier cell.
model = GlobalNet(args, cfg, alphabet)

def update_state_dict(state_dict):
    """Return a new dict mapping 'esm2.' + key to the same value for every entry.

    The input mapping is not modified and the entry order is kept.
    """
    prefixed = {}
    for key, tensor in state_dict.items():
        prefixed['esm2.' + key] = tensor
    return prefixed

# Re-key the ESM-2 weights under 'esm2.' before loading them into the GlobalNet
# model; presumably GlobalNet keeps the ESM-2 network in a submodule named
# `esm2` -- confirm in modeling.globalnet.
model_state2 = update_state_dict(model_state)

# strict=False makes load_state_dict tolerate missing or unexpected keys
# instead of raising; the returned report is the cell's displayed output.
model.load_state_dict(model_state2, strict=False)

<All keys matched successfully>

In [54]:
import numpy as np

f = 'output/output_debug/prediction/validation_metrics.npz'
# SECURITY: allow_pickle=True lets np.load unpickle object arrays, which can
# run arbitrary code -- only open archives produced by your own runs.
# The context manager closes the archive's file handle. 'metrics' is read
# inside the block because the archive loads its members lazily on access.
with np.load(f, allow_pickle=True) as npzf:
    metrics = npzf['metrics']

array({'save': True, 'min_eval_loss': 8.818276030950903, 'val_info': {0: {'loss': 8.939831887387541, 'rvalue': 0.015984963448895943, 'pvalue': 0.8618493171786743, 'rrmse': 92.21702783787866}, 1: {'loss': 8.818276030950903, 'rvalue': 0.032543108429948835, 'pvalue': 0.7230748759612504, 'rrmse': 89.93632310925378}, 2: {'loss': 7.694267730076003, 'rvalue': 0.08236897820645024, 'pvalue': 0.3690888765880832, 'rrmse': 70.24695992667573}}, 'output': {'epoch': 2, 'preds': array([-1.2985408, -1.3079586, -1.2835433, -1.2802734, -1.3124374,
       -1.259903 , -1.3035116, -1.2745309, -1.2891798, -1.2968376,
       -1.2796082, -1.2910616, -1.3048191, -1.2624491, -1.2900293,
       -1.2273464, -1.2623799, -1.3147895, -1.2796191, -1.2620436,
       -1.2622919, -1.2624171, -1.2796676, -1.2623378, -1.2619572,
       -1.2626233, -1.26233  , -1.2623463, -1.2619588, -1.2621503,
       -1.2622385, -1.2624427, -1.2624762, -1.2458231, -1.2985127,
       -1.2985487, -1.2618921, -1.2985145, -1.2985339, -1.26633

In [55]:
import torch
from transformers import AutoTokenizer, AutoModel

# SECURITY: trust_remote_code=True allows transformers to download and execute
# Python code shipped in the model repository; review the repository (and
# consider pinning a revision) before running this cell.
# NOTE(review): the recorded output of this cell is a ModuleNotFoundError for
# 'transformers', so these two lines did not run in the saved session.
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
# Rebinds `model`, which an earlier cell bound to the GlobalNet instance.
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

ModuleNotFoundError: No module named 'transformers'

In [3]:
# Alternative source checkpoint, kept for reference:
# checkpoint = torch.load('dataset/checkpoints/6-new-12w-0/pytorch_model.bin')
# map_location="cpu" matches the hub checkpoint loads at the top of the
# notebook and keeps the load working on a machine that lacks the device the
# file was saved from.
checkpoint = torch.load(
    'dataset/checkpoints/5-new-12w-0/pytorch_model.bin',
    map_location="cpu",
)

In [4]:
def update_state_dict2(state_dict):
    """Return a new dict with a leading 'bert.' removed from each key.

    Only a prefix at the very start of a key is stripped, and only once; keys
    that do not start with 'bert.' are kept unchanged even if 'bert.' occurs
    later in the name. Values are passed through without copying.
    """
    prefix = 'bert.'
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): param
        for name, param in state_dict.items()
    }
# Re-key the loaded weights with update_state_dict2 (removes 'bert.' from the
# parameter names).
cp = update_state_dict2(checkpoint)


# NOTE: rebinds PATH, which an earlier cell used for the ESM-2 checkpoint file.
PATH = 'dataset/checkpoints/dnabert_t12.pt'
# Store the re-keyed weights under a single 'model' entry.
torch.save({
            'model': cp,
            }, PATH)
# Read the file back and list each parameter name with its shape as a check of
# what was written.
# NOTE(review): this rebinds `checkpoint` from the raw state dict to the wrapped
# {'model': ...} dict, so re-running this cell without re-running the load cell
# first would feed the wrapped dict to update_state_dict2.
checkpoint = torch.load(PATH)
for k, v in checkpoint['model'].items():
    print(k, '      ', v.shape)

embeddings.word_embeddings.weight        torch.Size([1029, 768])
embeddings.position_embeddings.weight        torch.Size([512, 768])
embeddings.token_type_embeddings.weight        torch.Size([2, 768])
embeddings.LayerNorm.weight        torch.Size([768])
embeddings.LayerNorm.bias        torch.Size([768])
encoder.layer.0.attention.self.query.weight        torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias        torch.Size([768])
encoder.layer.0.attention.self.key.weight        torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias        torch.Size([768])
encoder.layer.0.attention.self.value.weight        torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias        torch.Size([768])
encoder.layer.0.attention.output.dense.weight        torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias        torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight        torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias        t

In [5]:
# List each re-keyed parameter name next to the shape of its tensor.
for k in cp:
    v = cp[k]
    print(k, '      ', v.shape)

embeddings.word_embeddings.weight        torch.Size([1029, 768])
embeddings.position_embeddings.weight        torch.Size([512, 768])
embeddings.token_type_embeddings.weight        torch.Size([2, 768])
embeddings.LayerNorm.weight        torch.Size([768])
embeddings.LayerNorm.bias        torch.Size([768])
encoder.layer.0.attention.self.query.weight        torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias        torch.Size([768])
encoder.layer.0.attention.self.key.weight        torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias        torch.Size([768])
encoder.layer.0.attention.self.value.weight        torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias        torch.Size([768])
encoder.layer.0.attention.output.dense.weight        torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias        torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight        torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias        t