In [1]:
import torch
import numpy
import sys

sys.path.append('../')
sys.path.append('../model')

In [2]:
from classifiers import EncoderClassifier, EncoderClassifierConfig
from training import run_training, LightningWrapper
from encoder import create_model, add_arguments
from esm_train import get_esm, device, parser
from data_loading import prepare_datasets
from torchmetrics import F1Score, MatthewsCorrCoef, Precision, Recall, AUROC, \
MeanMetric, AveragePrecision, PrecisionRecallCurve, MetricCollection
from torch.utils.data import DataLoader
from functools import partial
from data_loading import prep_batch

  from .autonotebook import tqdm as notebook_tqdm


cuda:0


In [3]:
# Restore the hyper-parameters of the fold-0 run and rebuild the model from them;
# the trained weights in checkpoint['state_dict'] are loaded into it afterwards.
chkpt_path = '../model/new_logs/encoder_S_60_focal/fold_0/chkpt.ckpt'
# Load onto the CPU: this notebook only reads 'hyper_parameters' and 'state_dict', and
# the weights are copied into the model before it is moved to `device`. Keeping this
# long-lived `checkpoint` dict on the GPU would only duplicate the weights there, and
# map_location also lets the notebook run on a machine without CUDA.
checkpoint = torch.load(chkpt_path, map_location='cpu')
args = parser.parse_args(args=[])

# Override the parser defaults with the values recorded in the checkpoint.
for name, value in checkpoint['hyper_parameters'].items():
    setattr(args, name, value)

model, tokenizer = create_model(args)


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EncoderClassifier(
  (base): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 480, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((480,), eps=1e-05, element

In [4]:
full_dataset = prepare_datasets(args, ignore_label=args.ignore_label)


def _thresholded_binary_metrics(ignore_index):
    """Return a dict of fresh binary F1 / precision / recall metric objects.

    A new set of objects is created on every call so that the step-level and
    epoch-level collections below never share metric state.

    Args:
        ignore_index: label value excluded from every metric.
    """
    return {
        'f1': F1Score(task='binary', ignore_index=ignore_index),
        'precision': Precision(task='binary', ignore_index=ignore_index),
        'recall': Recall(task='binary', ignore_index=ignore_index),
    }


# Passed to LightningWrapper as `step_metrics`.
step_metrics = MetricCollection(_thresholded_binary_metrics(args.ignore_label))

# Passed to LightningWrapper as `epoch_metrics`: the same three metrics plus
# AUROC, average precision and Matthews correlation.
epoch_metrics = MetricCollection({
    **_thresholded_binary_metrics(args.ignore_label),
    'auroc': AUROC(task='binary', ignore_index=args.ignore_label),
    'auprc': AveragePrecision(task='binary', ignore_index=args.ignore_label),
    'mcc': MatthewsCorrCoef(task='binary', ignore_index=args.ignore_label),
})

In [5]:
train_ds, dev_ds, test_ds = full_dataset.get_fold(0)


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the batch size, collate function and worker
    settings shared by every split.

    Args:
        dataset: the split to iterate over.
        shuffle: whether to reshuffle every epoch (only the training split does).
    """
    return DataLoader(
        dataset, args.batch_size, shuffle=shuffle,
        collate_fn=partial(prep_batch, tokenizer=tokenizer, ignore_label=args.ignore_label),
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=args.num_workers > 0,
        num_workers=args.num_workers,
    )


train = _make_loader(train_ds, shuffle=True)
dev = _make_loader(dev_ds, shuffle=False)
test = _make_loader(test_ds, shuffle=False)

Train size: 7180
Dev size: 1796
Test size: 2244


In [6]:
    model = LightningWrapper(args, model, step_metrics=step_metrics, epoch_metrics=epoch_metrics, ds_size=len(train), logdir='.')
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model = model.classifier

In [7]:
class OutputGatherModule(torch.nn.Module):
    """Transparent wrapper that remembers the output of its latest forward pass.

    Calls are forwarded unchanged to the wrapped module and its result is
    returned as-is; a reference to that result is also kept in `last_batch`
    so it can be inspected after a larger model has run.

    Args:
        module: the module whose outputs should be captured. It is registered
            as the submodule `module`, so its parameters appear under the
            `module.` prefix in this wrapper's state_dict.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module
        # None until the first forward call, then the most recent output.
        self.last_batch = None

    def forward(self, *inputs, **options):
        self.last_batch = self.module(*inputs, **options)
        return self.last_batch
    

In [8]:
# Wrap the encoder so every forward pass stores its output in model.encoder.last_batch.
# The isinstance guard makes the cell safe to re-run (otherwise wrappers would nest).
if not isinstance(model.encoder, OutputGatherModule):
    model.encoder = OutputGatherModule(model.encoder)

In [9]:
# Collect encoder outputs for every test protein.
#   embed_df[prot_id] = {'prot_embed': position-0 vector,
#                        'embeds': vectors at the remaining positions kept by `mask`,
#                        'mask': positions whose label is not model.ignore_index}
#   all_embeds: the kept per-position vectors of all proteins, as one flat list of rows.
embed_df = {}
all_embeds = []
with torch.no_grad():
    for batch in test:
        # The predictions are not used here; the forward pass is run so that the
        # wrapped encoder records its output in model.encoder.last_batch.
        _loss, _batch_preds = model.predict(**batch.to(model.device))
        batch_labels = batch['labels']
        embeds = model.encoder.last_batch
        mask = batch_labels != model.ignore_index
        ids = test_ds.data.iloc[batch['indices'].cpu().numpy()]['id']

        for i, prot_id in enumerate(ids):
            # Assumes embeds[i][1:] is position-aligned with mask[i] — TODO confirm
            # against the encoder's output layout.
            residue_embeds = embeds[i][1:][mask[i]].cpu().numpy()
            embed_df[prot_id] = {
                'prot_embed': embeds[i][0].cpu().numpy(),
                'embeds': residue_embeds,
                'mask': mask[i].cpu().numpy(),
            }
            # Extend with the array's rows rather than `.tolist()`: converting every
            # float into a Python object multiplies memory use many times over.
            all_embeds.extend(residue_embeds)


  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


In [10]:
from sklearn.manifold import TSNE
import pandas as pd
import numpy as np

In [11]:
# random_state fixes t-SNE's random initialisation/optimisation so the layout is
# reproducible between runs.
tsne = TSNE(n_jobs=15, random_state=0)

In [12]:
# One row per test protein (index = protein id) with columns 'prot_embed', 'embeds', 'mask'.
# Guarded so re-running the cell is safe: after the first run `embed_df` is already a
# DataFrame, and DataFrame.from_dict would raise on it.
if isinstance(embed_df, dict):
    embed_df = pd.DataFrame.from_dict(embed_df, orient='index')

In [13]:
# Stack the collected per-position rows into one (n_positions, embed_dim) matrix.
all_embeds = np.asarray(all_embeds)
# Fail here, with a readable message, if nothing was collected or the rows are ragged.
assert all_embeds.ndim == 2 and len(all_embeds) > 0, \
    f'expected a non-empty 2-D embedding matrix, got shape {all_embeds.shape}'

In [None]:
# fit_transform runs the same optimisation as fit() (tsne.embedding_ is still set) but
# also returns the low-dimensional coordinates, so they are available by name below.
tsne_coords = tsne.fit_transform(all_embeds)
tsne_coords.shape