# Inference with METL-Global

In [10]:
import torch
import torchextractor as tx
import torchinfo

import metl

# Load a METL-G model

In [5]:
model, data_encoder = metl.get_from_ident("METL-G-20M-1D")

In [7]:
summary = torchinfo.summary(model, depth=4, verbose=1, row_settings=["var_names"])

Layer (type (var_name))                                                Param #
AttnModel (AttnModel)                                                  --
├─SequentialWithArgs (model)                                           --
│    └─ScaledEmbedding (embedder)                                      --
│    │    └─Embedding (embedding)                                      10,752
│    └─RelativeTransformerEncoder (tr_encoder)                         --
│    │    └─ModuleList (layers)                                        --
│    │    │    └─RelativeTransformerEncoderLayer (0)                   3,154,560
│    │    │    └─RelativeTransformerEncoderLayer (1)                   3,154,560
│    │    │    └─RelativeTransformerEncoderLayer (2)                   3,154,560
│    │    │    └─RelativeTransformerEncoderLayer (3)                   3,154,560
│    │    │    └─RelativeTransformerEncoderLayer (4)                   3,154,560
│    │    │    └─RelativeTransformerEncoderLayer (5)                

# Set up representation extraction
For METL-Global models, I recommend using the representation immediately after the GlobalAveragePooling (avg_pooling) layer. For METL-Local models, I recommend using the representation immediately after the final fully connected layer (fc1). 

In [12]:
return_layers = [
    "model.avg_pooling",
]

extractor = tx.Extractor(model.eval(), return_layers)

# Test a couple of sequences

In [25]:
# note: make sure all the sequences in a batch are the same length
amino_acid_sequences = ["SMART", "MAGIC"]
encoded_seqs = data_encoder.encode_sequences(amino_acid_sequences)

with torch.no_grad():
    model_out, intermediate_out = extractor(torch.tensor(encoded_seqs))
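Since every sequence in a batch needs to be the same length, one way to handle mixed-length inputs is to bucket them by length first. Here is a minimal sketch in plain Python. Note that `group_by_length` is a hypothetical helper I am introducing for illustration; it is not part of the `metl` package.

```python
from collections import defaultdict


def group_by_length(sequences):
    """Group sequences so that every batch holds sequences of one length.

    Hypothetical helper (not part of metl). Returns a dict that maps
    sequence length -> list of (original_index, sequence) pairs.
    """
    batches = defaultdict(list)
    for index, seq in enumerate(sequences):
        # keep the original index so results can be restored to input order
        batches[len(seq)].append((index, seq))
    return dict(batches)


sequences = ["SMART", "MAGIC", "MAGICAL"]
batches = group_by_length(sequences)

for length, items in batches.items():
    print(length, [seq for _, seq in items])
```

Each group could then be encoded with `data_encoder.encode_sequences` and passed through the extractor as its own batch, exactly as in the cell above.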

In [29]:
# model_out contains the final output of the model (Rosetta energy term predictions)
# there are 55 energy terms, the first one is total_score 
# they are listed in order on the main README
model_out.shape

torch.Size([2, 55])

In [32]:
# intermediate_out is a dictionary containing intermediate outputs 
# for all the return_layers specified above
# METL-G has an embedding dimension of 512, so each output vector has length 512
intermediate_out["model.avg_pooling"].shape

torch.Size([2, 512])

# Additional notes
The above will retrieve a length-512 sequence-level representation immediately following the global average pooling layer, which takes the average of the residue-level representations.
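To make the pooling step concrete, here is a small sketch of what averaging residue-level representations looks like in terms of tensor shapes. It uses a random stand-in tensor rather than real model outputs, and it assumes the pooling is a plain mean over the residue dimension, which I have not checked against the model's implementation.

```python
import torch

# stand-in for residue-level outputs, NOT real model outputs:
# (batch size, sequence length, embedding dimension)
residue_repr = torch.randn(2, 5, 512)

# assumption: global average pooling is a plain mean over the residue dimension
sequence_repr = residue_repr.mean(dim=1)

print(sequence_repr.shape)
```

This collapses the per-residue dimension, leaving one 512-dimensional vector per sequence.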

If you want, you can also get the residue-level representations. You can also experiment with the sequence-level representation from after the final fully connected (FC) layer, although I haven't had as much success with that representation for my tasks (it may be too specific to the Rosetta energies). You may have more luck with it, though.

In [34]:
# the above will retrieve a length 512 sequence-level representation
# you can also get a representation for each residue

return_layers = [
    "model.tr_encoder", # residue-level representation
    "model.avg_pooling", # sequence-level representation following avg pooling
    "model.fc1", # sequence-level representation following the final fully connected layer
]

extractor = tx.Extractor(model.eval(), return_layers)

amino_acid_sequences = ["SMART", "MAGIC"]
encoded_seqs = data_encoder.encode_sequences(amino_acid_sequences)

with torch.no_grad():
    model_out, intermediate_out = extractor(torch.tensor(encoded_seqs))

In [42]:
for k, v in intermediate_out.items():
    print("Layer: {}\nOutput shape: {}\n".format(k, v.shape))

Layer: model.tr_encoder
Output shape: torch.Size([2, 5, 512])

Layer: model.avg_pooling
Output shape: torch.Size([2, 512])

Layer: model.fc1
Output shape: torch.Size([2, 512])
