In [None]:
from FLIP.baselines.esm_next.models.esmc import ESMC
import torch
from build_dgl_graph import normalize_adj
import numpy as np
import dgl
from utils import Config
from models import build_model
from utils_flip_self import *

## Regression Task Example
We provide example model weights trained on HP-S. Please download them from https://zenodo.org/records/15736743 and place the checkpoint at the path assigned to `ckpt_path` in the cell below.

⚠️ Attention: For custom datasets, please train a new model.

In [51]:
# Regression example: encode one protein sequence with ESM-C, turn the encoder
# outputs into a DGL graph, and print the Tm value predicted by the model
# described in `config` and restored from `ckpt_path`.
seq = 'MNPLIKRIIFIVVVLGLWELGSQLQIWSELILPAPSSVGVALYEGFANMTLVYDFVASFRRLIIGLAIALFIGTSIGLLIAKSKTADDTIGSMLLAFQSVPSIVWLPLAMMWFGLNDKAVIFVVVLGGTFVMAMNVRSGIKNVSPDFIRAARTMGAKGFDLFIRVIFPASIPYFVTGSRLAWAFAWRALIAGELLSTGPGLGYTLSYASDFGEMELVIGVMIIIGVIGLIVDQVIFQRIEKSVAKRWGLEL'  # sampled from HP-S-test
config = 'configs/S_esmc/model3.py'
cfg = Config.fromfile(config)
ckpt_path = 'results/S_esmc/seed-101/model3/epoch_best.pth'

# Inputs longer than this many residues are truncated before encoding.
MAX_SEQ_LEN = 800

# Check for a GPU first so a CPU-only machine fails immediately, rather than
# after the pretrained encoder has already been loaded.
if not torch.cuda.is_available():
	raise RuntimeError('This script requires a GPU to run.')

# Slicing is a no-op for sequences that are already short enough.
seq = seq[:MAX_SEQ_LEN]

encoder = ESMC.from_pretrained('esmc_600m')
encoder.eval()
encoder = encoder.cuda()

with torch.no_grad():
	toks = encoder._tokenize([seq])
	toks = toks.to('cuda')
	# Only the 2nd and 4th outputs of the encoder are used here.
	_, embedding, _, attention = encoder(toks)
	# Keep positions 1..len(seq) of the first batch item
	# (presumably skips a leading special token -- TODO confirm).
	embedding = embedding[0, 1: len(seq) + 1].float().cpu().numpy()
	# NOTE(review): attention is sliced from index 0 while the embedding is
	# sliced from index 1 -- confirm this matches how training graphs were built.
	attention = attention[0, :len(seq), :len(seq)].float().cpu().numpy()

	# One node per sequence position; one edge per non-zero attention entry.
	src, dst = np.nonzero(attention)
	edge_feats = normalize_adj(attention)
	# Assumes normalize_adj keeps the non-zero pattern of `attention`, so the
	# extracted values line up with (src, dst) -- TODO confirm.
	edge_feats = edge_feats[np.nonzero(edge_feats)]
	graph = dgl.graph((src, dst), num_nodes=len(seq))
	graph.ndata['x'] = torch.from_numpy(embedding).float()
	graph.edata['x'] = torch.from_numpy(edge_feats).float()
	graph = graph.to('cuda')

	# Build the predictor from the config and restore the trained weights.
	model = build_model(cfg.model)
	model.eval()
	model = model.cuda()
	model, test_epoch = load_ckpt(model, ckpt_path)
	pred_tm = model.forward_test(graph)
	print(f'Predicted Tm: {pred_tm.item()}')


==> load test checkpoint..
test epoch:  10
Predicted Tm: 51.55893325805664


## Classification Task Example
We provide example model weights trained on HP-S with 5 classes. Please download them from https://zenodo.org/records/15736743 and place the checkpoint at the path assigned to `ckpt_path` in the cell below.

⚠️ Attention: For custom datasets, please train a new model.


In [48]:
# Classification example: encode one protein sequence with ESM-C, turn the
# encoder outputs into a DGL graph, and print the class index with the highest
# logit from the model described in `config` and restored from `ckpt_path`.
seq = 'MNPLIKRIIFIVVVLGLWELGSQLQIWSELILPAPSSVGVALYEGFANMTLVYDFVASFRRLIIGLAIALFIGTSIGLLIAKSKTADDTIGSMLLAFQSVPSIVWLPLAMMWFGLNDKAVIFVVVLGGTFVMAMNVRSGIKNVSPDFIRAARTMGAKGFDLFIRVIFPASIPYFVTGSRLAWAFAWRALIAGELLSTGPGLGYTLSYASDFGEMELVIGVMIIIGVIGLIVDQVIFQRIEKSVAKRWGLEL'  # sampled from HP-S-test
config = 'configs/S_esmc_cls/model0.py'
cfg = Config.fromfile(config)
ckpt_path = 'results/S_esmc_cls/seed-101/model0/epoch_best.pth'

# Inputs longer than this many residues are truncated before encoding.
MAX_SEQ_LEN = 800

# Check for a GPU first so a CPU-only machine fails immediately, rather than
# after the pretrained encoder has already been loaded.
if not torch.cuda.is_available():
	raise RuntimeError('This script requires a GPU to run.')

# Slicing is a no-op for sequences that are already short enough.
seq = seq[:MAX_SEQ_LEN]

encoder = ESMC.from_pretrained('esmc_600m')
encoder.eval()
encoder = encoder.cuda()

with torch.no_grad():
	toks = encoder._tokenize([seq])
	toks = toks.to('cuda')
	# Only the 2nd and 4th outputs of the encoder are used here.
	_, embedding, _, attention = encoder(toks)
	# Keep positions 1..len(seq) of the first batch item
	# (presumably skips a leading special token -- TODO confirm).
	embedding = embedding[0, 1: len(seq) + 1].float().cpu().numpy()
	# NOTE(review): attention is sliced from index 0 while the embedding is
	# sliced from index 1 -- confirm this matches how training graphs were built.
	attention = attention[0, :len(seq), :len(seq)].float().cpu().numpy()

	# One node per sequence position; one edge per non-zero attention entry.
	src, dst = np.nonzero(attention)
	edge_feats = normalize_adj(attention)
	# Assumes normalize_adj keeps the non-zero pattern of `attention`, so the
	# extracted values line up with (src, dst) -- TODO confirm.
	edge_feats = edge_feats[np.nonzero(edge_feats)]
	graph = dgl.graph((src, dst), num_nodes=len(seq))
	graph.ndata['x'] = torch.from_numpy(embedding).float()
	graph.edata['x'] = torch.from_numpy(edge_feats).float()
	graph = graph.to('cuda')

	# Build the classifier from the config and restore the trained weights.
	model = build_model(cfg.model)
	model.eval()
	model = model.cuda()
	model, test_epoch = load_ckpt(model, ckpt_path)
	logits = model.forward_test(graph)
	# argmax over dim 1 picks the highest-scoring class for each graph.
	pred_cls = logits.argmax(dim=1).cpu().numpy()
	print(f'Predicted class: {pred_cls[0]}')


==> load test checkpoint..
test epoch:  9
Predicted class: 3
