In [None]:
import torch
import numpy as np
from tqdm import tqdm

In [2]:
model_path = "m01.chkpt"
chkpt = torch.load(model_path)
opt = chkpt["settings"]
epoch = chkpt["epoch"]
model_state = chkpt["model"]

In [3]:
import transformer.Models
from transformer.Optim import ScheduledOptim
import torch.optim as optim
import torch.utils.data
from dataset import ProteinDataset, paired_collate_fn
from train import cal_performance

In [4]:
device = torch.device('cuda' if opt.cuda else 'cpu')
the_model = transformer.Models.Transformer(opt.max_token_seq_len,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_inner=opt.d_inner_hid,
        n_layers=opt.n_layers,
        n_head=opt.n_head,
        dropout=opt.dropout)
the_model.load_state_dict(model_state)

In [5]:
data = torch.load(opt.data)
data_loader = torch.utils.data.DataLoader(
        ProteinDataset(
            seqs=data['test']['seq'],
            angs=data['test']['ang']),
        num_workers=2,
        batch_size=1,#opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

In [15]:
with torch.no_grad():
    for batch in tqdm(
            data_loader, mininterval=2,
            desc='  - (Evaluation) ', leave=False):

        # prepare data
        src_seq, src_pos, tgt_seq, tgt_pos = batch
        gold = tgt_seq[:, 1:]

        # forward
        pred = the_model(src_seq, src_pos, tgt_seq, tgt_pos)
        loss = cal_performance(pred, gold, device)
        print("Loss: {0:.2f}, Predshape: {1}".format(loss, pred.shape))

  - (Evaluation) :   0%|          | 0/16 [00:00<?, ?it/s]

Loss: 1.76, Predshape: torch.Size([48, 11])
Loss: 1.93, Predshape: torch.Size([15, 11])
Loss: 1.24, Predshape: torch.Size([43, 11])
Loss: 1.32, Predshape: torch.Size([45, 11])
Loss: 1.46, Predshape: torch.Size([47, 11])
Loss: 1.60, Predshape: torch.Size([51, 11])
Loss: 1.57, Predshape: torch.Size([50, 11])
Loss: 1.59, Predshape: torch.Size([23, 11])
Loss: 1.21, Predshape: torch.Size([45, 11])


                                                         

Loss: 1.55, Predshape: torch.Size([37, 11])
Loss: 1.73, Predshape: torch.Size([35, 11])
Loss: 1.39, Predshape: torch.Size([45, 11])
Loss: 1.72, Predshape: torch.Size([55, 11])
Loss: 2.04, Predshape: torch.Size([25, 11])
Loss: 1.26, Predshape: torch.Size([45, 11])
Loss: 1.52, Predshape: torch.Size([16, 11])




In [10]:
np.array(pred)

array([[-1.20497656e+00, -5.47692180e-01,  3.97164971e-01,
         1.88373792e+00,  1.96541476e+00,  2.04401517e+00,
        -6.39292061e-01, -2.67679784e-02,  4.13794518e-02,
         9.07431170e-02, -1.42316369e-03],
       [-1.16540074e+00, -5.31542122e-01,  4.18082595e-01,
         1.87687898e+00,  1.93631840e+00,  2.02516270e+00,
        -6.18998051e-01, -3.15383337e-02,  4.82610576e-02,
         6.15231209e-02, -1.89916547e-02],
       [-1.24943113e+00, -6.18891120e-01,  3.16177011e-01,
         1.79587376e+00,  1.89624989e+00,  1.97840679e+00,
        -6.66306376e-01, -1.18638165e-01, -8.54163840e-02,
         2.04268023e-01,  1.60053857e-02],
       [-1.20393419e+00, -5.47202289e-01,  3.98404777e-01,
         1.88185394e+00,  1.96413124e+00,  2.04177070e+00,
        -6.40191197e-01, -2.51256172e-02,  3.96777391e-02,
         9.15276781e-02, -1.86934171e-03],
       [-1.20396948e+00, -5.47154009e-01,  3.97958577e-01,
         1.88195574e+00,  1.96435142e+00,  2.04204178e+00,
  

In [None]:
opt.cuda