In [1]:
import argparse
import pickle
import numpy as np

import torch
import torch.utils.data as data
import torch.nn as nn
import pytorch_lightning as pl

from lib.utils import build_model_tabular_suhan
import lib.layers.odefunc_suhan as odefunc

# ODE solvers that may be selected through --solver.
SOLVERS = ["dopri5"]

# Configuration is expressed as an argparse parser, but the notebook parses an
# empty argument list (see the parse_args call below), so every option simply
# resolves to the default declared here.
#
# NOTE(review): several options use type=eval, which evaluates the raw option
# string as Python code. No option strings are parsed in this notebook, but
# the parser should not be fed untrusted argv in this form.
parser = argparse.ArgumentParser('NodeIK')

# ---- dataset -----------------------------------------------------------
parser.add_argument(
    '--data',
    type=str,
    default='2spirals_1d',
    choices=[
        '2spirals_1d', '2spirals_2d',
        'swissroll_1d', 'swissroll_2d',
        'circles_1d', 'circles_2d',
        '2sines_1d', 'target_1d',
    ],
)

# ---- network architecture ----------------------------------------------
parser.add_argument(
    "--layer_type",
    type=str,
    default="concatsquash",
    choices=["concatsquash"],
)
parser.add_argument('--dims', type=str, default='64-64-64')
parser.add_argument(
    "--num_blocks",
    type=int,
    default=1,
    help='Number of stacked CNFs.',
)
parser.add_argument('--time_length', type=float, default=0.5)
parser.add_argument('--train_T', type=eval, default=True)
parser.add_argument(
    "--divergence_fn",
    type=str,
    default="brute_force",
    choices=["brute_force", "approximate"],
)
parser.add_argument(
    "--nonlinearity",
    type=str,
    default="tanh",
    choices=odefunc.NONLINEARITIES,
)

# ---- ODE solver ----------------------------------------------------------
parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
# Absolute and relative tolerance share the same type and default.
for _tolerance_flag in ('--atol', '--rtol'):
    parser.add_argument(_tolerance_flag, type=float, default=1e-5)

# ---- boolean switches (all off by default) -------------------------------
for _switch_flag in ('--residual', '--rademacher', '--spectral_norm'):
    parser.add_argument(_switch_flag, type=eval, default=False, choices=[True, False])

# ---- optimisation --------------------------------------------------------
parser.add_argument('--niters', type=int, default=36000)
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=1e-5)

# ---- options for the proposed method -------------------------------------
parser.add_argument('--std_min', type=float, default=0.0)
parser.add_argument('--std_max', type=float, default=0.1)
parser.add_argument('--std_weight', type=float, default=2)

# ---- logging / bookkeeping -----------------------------------------------
for _freq_flag, _freq_default in (
    ('--viz_freq', 100),
    ('--val_freq', 400),
    ('--log_freq', 10),
):
    parser.add_argument(_freq_flag, type=int, default=_freq_default)
parser.add_argument('--gpu', type=int, default=0)

# Parse an empty argv: inside the notebook only the defaults are used.
args = parser.parse_args([])

# Use the GPU selected by --gpu when CUDA is available, otherwise the CPU.
# (To force CPU execution, assign device = 'cpu' here instead.)
if torch.cuda.is_available():
    device = torch.device('cuda:' + str(args.gpu))
else:
    device = torch.device('cpu')

class Learner(pl.LightningModule):
    """Minimal LightningModule wrapper that holds the network as ``self.model``.

    Only construction is defined here; the notebook uses the class to restore
    a Lightning checkpoint into an already-built network.
    """

    def __init__(self, model: nn.Module):
        """Keep a reference to ``model`` and initialise ``iters`` to 0.

        Args:
            model: the network to wrap; stored unchanged as ``self.model``.
        """
        super().__init__()
        # Presumably an iteration counter used during training -- it is only
        # initialised here, never read in this notebook.
        self.iters = 0
        self.model = model

# Build the network on the selected device. The second argument (7) matches
# the width of the latent samples drawn later in the notebook.
model = build_model_tabular_suhan(args, 7).to(device)

# Restore the trained weights. Learner stores the very object passed as
# `model`, so loading the checkpoint fills in `model`'s parameters.
# map_location remaps the stored tensors onto `device`, which keeps the
# checkpoint loadable when it was saved on a CUDA device but this session has
# fallen back to the CPU.
learn = Learner.load_from_checkpoint(
    'model/panda_sample_model.ckpt',
    map_location=device,
    model=model,
)

# Inference only: put the network into evaluation mode.
model.eval()
# Flag on the first chain element's ODE function; presumably switches off the
# density computation -- it is turned back on before sampling in the next cell.
model.chain[0].odefunc.odefunc.calc_density = False



In [2]:
import time

# Conditioning vector handed to the model for every sample (7 values;
# presumably a target pose -- TODO confirm the layout against the training data).
input_pose = np.array([6.1946e-01, -1.6464e-02,  8.6722e-01,  4.7658e-01,  4.9979e-01,  7.2251e-01, -3.2554e-02])

# Number of latent samples pushed through the model in a single batch.
max_len = 40960
# Standard-normal latent samples, one row per sample.
z = torch.normal(0, 1, size=(max_len, 7)).to(device)
c = torch.from_numpy(input_pose).float().to(device)
print('origin_c',c)
# Repeat the condition once per latent sample so both batches line up row-wise.
cc = torch.stack([c]*max_len).to(device)
# One zero per sample, created with z's dtype and device; passed as the third
# model input (presumably the initial log-density accumulator).
zero = torch.zeros(z.shape[0], 1).to(z)

# Re-enable the flag that the setup cell switched off; done before the clock
# starts so that only the forward pass is timed.
model.chain[0].odefunc.odefunc.calc_density = True

# CUDA kernels are launched asynchronously. Synchronise before starting and
# before stopping the clock so the measured interval covers the GPU work
# itself, and use perf_counter(), the monotonic clock meant for intervals.
if z.is_cuda:
    torch.cuda.synchronize(z.device)
start_2 = time.perf_counter()
xx, delta_logp = model(z, cc, zero,reverse=True)
if z.is_cuda:
    torch.cuda.synchronize(z.device)
end_2 = time.perf_counter()

# Evaluation count reported by the first chain element for this call.
evals = model.chain[0].num_evals()
print('evals',evals)
print('after q', xx[:max_len,:])
print(delta_logp[:max_len,0])
print('time',(end_2 - start_2) * 1000,'ms')

origin_c tensor([ 0.6195, -0.0165,  0.8672,  0.4766,  0.4998,  0.7225, -0.0326],
       device='cuda:0')
evals 68.0
after q tensor([[ 1.4650e-01,  2.4423e-01, -2.8553e-01,  ...,  1.3234e+00,
          2.1234e+00, -1.1380e+00],
        [ 2.3675e-01,  1.0548e+00, -2.1947e+00,  ...,  3.0969e+00,
          1.6803e+00, -6.7985e-01],
        [ 9.4373e-01,  2.5037e-03, -9.7939e-01,  ...,  1.4871e+00,
          2.2285e+00, -1.4273e+00],
        ...,
        [ 1.3677e-01,  9.7758e-01, -2.2528e+00,  ...,  3.1838e+00,
          1.7515e+00, -6.0714e-01],
        [ 1.3311e+00,  4.8811e-01, -1.9605e+00,  ...,  2.0299e+00,
          1.8115e+00, -1.2990e+00],
        [-2.5089e+00, -6.1415e-01,  1.5235e+00,  ...,  2.1621e+00,
          1.6336e+00, -1.1419e+00]], device='cuda:0', grad_fn=<SliceBackward0>)
tensor([17.5565, 17.4163, 16.1689,  ..., 15.6779, 17.7838, 13.4034],
       device='cuda:0', grad_fn=<SelectBackward0>)
time 1099.3585586547852 ms


In [3]:

# Rank the samples by the first column of delta_logp, largest value first.
# descending=True yields that order directly (no negate / sort / negate), and
# the result is stored under a name that does not rebind the built-in
# sorted(), which would otherwise stay shadowed for the rest of the session.
delta_logp_desc, idx = torch.sort(delta_logp[:, 0], dim=0, descending=True)

print('sorted', delta_logp_desc[:max_len])
print('idx',idx)
# Reorder the generated samples so row i corresponds to delta_logp_desc[i].
xxx = xx[idx]
print('sorted xxx',xxx[:max_len,:])

sorted tensor([19.7294, 19.6539, 19.6259,  ..., -0.8144, -1.5827, -1.6842],
       device='cuda:0', grad_fn=<NegBackward0>)
idx tensor([31872, 29048,  6142,  ..., 10457, 31825, 39546], device='cuda:0')
sorted xxx tensor([[-0.1419,  0.2324,  0.0453,  ...,  1.3363,  2.1065, -1.1745],
        [-0.1302,  0.2391,  0.0374,  ...,  1.3529,  2.1084, -1.1991],
        [ 0.2052,  0.2530, -0.4256,  ...,  1.4929,  2.0424, -1.1877],
        ...,
        [ 0.9368,  0.7388, -1.9846,  ..., -0.3629,  3.4094,  4.2433],
        [-4.1963, -0.1746, -2.2052,  ...,  5.8940,  1.7983,  3.0739],
        [-2.1688,  0.0978, -0.9324,  ..., -3.2472,  2.1744, -0.0463]],
       device='cuda:0', grad_fn=<SliceBackward0>)
