Using the trained model #2

@fteufel
Hi @zatchwu, I want to use your trained model to sample/score signal peptides.

The following is what I came up with by going through the provided notebooks, aiming for a more straightforward sequence -> model -> prediction workflow that is independent of the datasets you were using. It would be great to get some feedback on whether what I'm doing here is correct.

  1. Encoding amino acid data for the transformer
# token map: '$' marks sequence start, '.' marks sequence end,
# and ' ' (id 0) is presumably the padding token
SPGEN_AA_TO_ID = {
 ' ': 0,
 '$': 1,
 '.': 2,
 'A': 3,
 'C': 4,
 'D': 5,
 'E': 6,
 'F': 7,
 'G': 8,
 'H': 9,
 'I': 10,
 'K': 11,
 'L': 12,
 'M': 13,
 'N': 14,
 'P': 15,
 'Q': 16,
 'R': 17,
 'S': 18,
 'T': 19,
 'U': 20,
 'V': 21,
 'W': 22,
 'X': 23,
 'Y': 24,
 'Z': 25,
}

# wrap each raw sequence with the start ('$') and end ('.') tokens
sp = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[x] for x in sp] + [SPGEN_AA_TO_ID['.']]
prot = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[x] for x in prot] + [SPGEN_AA_TO_ID['.']]
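To batch sequences of different lengths, I pad to a fixed length and build the position masks like this (the helper below is my own sketch, assuming token 0, i.e. ' ', is the pad token; the masks are 0 at true positions and 1 at padded positions, matching step 3):

import torch

def encode_and_pad(seq, max_len):
    # wrap with '$'/'.' as above, then pad with token 0 (' ')
    ids = [SPGEN_AA_TO_ID['$']] + [SPGEN_AA_TO_ID[x] for x in seq] + [SPGEN_AA_TO_ID['.']]
    ids = ids + [SPGEN_AA_TO_ID[' ']] * (max_len - len(ids))
    # position mask: 0 at true positions, 1 at masked (padded) positions
    positions = [0 if tok != SPGEN_AA_TO_ID[' '] else 1 for tok in ids]
    return torch.LongTensor(ids), torch.LongTensor(positions)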
  2. Loading the model
def load_spgen_model():
    # the weights were extracted from the .chkpt file with the same name
    state_dict = torch.load('../../SPGen/remote_generation/signal_peptide/outputs/SIM99_550_12500_64_6_5_0.1_64_100_0.0001_-0.03_99_weightsonly.pt')
    # positional args presumably: source vocab size, target vocab size,
    # maximum sequence length; the rest mirrors the training hyperparameters
    model = Models.Transformer(
        27,
        27,
        107,
        proj_share_weight=True,
        embs_share_weight=True,
        d_k=64,
        d_v=64,
        d_model=550,
        d_word_vec=550,
        d_inner_hid=1100,
        n_layers=6,
        n_head=5,
        dropout=0.1)

    model.load_state_dict(state_dict)
    model.eval()

    return model
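For completeness, this is how I then instantiate it (the device handling is just my setup, not from your code):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_spgen_model().to(device)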
  3. Making predictions (logits) and scoring the perplexity. I encode the data as shown in step 1 and build prot_positions/sp_positions masks that are 0 at true positions and 1 at masked (padded) positions.
import numpy as np
from tqdm import tqdm


def get_perplexity_batch(transformer, src_seq, src_positions, tgt_seq, tgt_positions):
    '''Adapted from Translator()._epoch().'''
    ppls = []

    # ignore_index=0 keeps padding positions out of the loss, so pads
    # don't dilute the perplexity
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=SPGEN_AA_TO_ID[' '])

    pred = transformer((src_seq, src_positions), (tgt_seq, tgt_positions))

    # score each sequence in the batch separately: the loss over the
    # shifted target is the mean token NLL, and exp of that is the perplexity
    for idx in range(len(src_seq)):
        loss = loss_fn(pred[idx].view(-1, 27), tgt_seq[idx, 1:].view(-1))
        ppls.append(torch.exp(loss).item())

    return ppls
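
# Quick sanity check of the perplexity formula on dummy logits (my own
# test, not from the repo): uniform logits should score a perplexity of
# about the vocab size (27), near-perfect predictions about 1.
logits = torch.zeros(5, 27)
targets = torch.tensor([3, 4, 5, 6, 7])
print(torch.exp(torch.nn.CrossEntropyLoss()(logits, targets)).item())  # ~27.0
logits[torch.arange(5), targets] = 20.0
print(torch.exp(torch.nn.CrossEntropyLoss()(logits, targets)).item())  # ~1.0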


def predict_spgen(model, loader):

    with torch.no_grad():
        ppl = []
        for idx, batch in tqdm(enumerate(loader), total=len(loader)):

            proteins, prot_positions, sps, sp_positions = batch
            proteins, prot_positions, sps, sp_positions = (
                proteins.to(device), prot_positions.to(device),
                sps.to(device), sp_positions.to(device))

            # get_perplexity_batch runs the forward pass itself, so no
            # separate model() call is needed here
            ppls = get_perplexity_batch(model, proteins, prot_positions, sps, sp_positions)

            ppl.extend(ppls)

    return np.array(ppl)
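
Putting the pieces together, this is roughly how I run it on new data (the TensorDataset wrapping, batch size, and max-length constant are my own choices; protein_seqs and sp_seqs are lists of plain amino acid strings):

from torch.utils.data import DataLoader, TensorDataset

MAX_LEN = 107  # must not exceed the model's maximum sequence length

prots, prot_pos = zip(*(encode_and_pad(s, MAX_LEN) for s in protein_seqs))
sps, sp_pos = zip(*(encode_and_pad(s, MAX_LEN) for s in sp_seqs))

dataset = TensorDataset(torch.stack(prots), torch.stack(prot_pos),
                        torch.stack(sps), torch.stack(sp_pos))
loader = DataLoader(dataset, batch_size=64)

ppls = predict_spgen(model, loader)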

My code runs, but it is hard to tell whether everything is in place or there's an error somewhere. It would be great to get some feedback; I'm also open to any other way to run the model on new data.

Thanks!
