In [3]:
# save a default checkpoint
#
# Builds the model described by the config, re-initializes every parameter with
# more than one dimension using Xavier-uniform, and writes the resulting state
# dict to `test.chkpt` inside the config's local directory.
import torch

from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel


# Alternative config used earlier: 'configs/testencoder.yaml'
CONFIG_PATH = 'training/melody/20220601-matchjointer-raw'
SEED = 0

# Seed BEFORE building the model: parameters with dim <= 1 are skipped by the
# re-init loop below and keep whatever (possibly random) values the model's
# constructor assigned, so construction must be covered by the seed too.
torch.manual_seed(SEED)

config = Configuration.createOrLoad(CONFIG_PATH)
model = loadModel(config['model'])

# initialize parameters (xavier init needs at least 2 dims, hence the filter)
for p in model.parameters():
	if p.dim() > 1:
		torch.nn.init.xavier_uniform_(p)

torch.save({'model': model.state_dict()}, config.localPath('test.chkpt'))


In [22]:
# deduce test
#
# Restores the best checkpoint of the testencoder run and runs a single forward
# pass on a fixed toy input, printing both input and output for inspection.
import torch

from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel


config = Configuration.createOrLoad('training/melody/20220526-testencoder')
model = loadModel(config['model'])

checkpoint = torch.load(config.localPath(config['best']), map_location='cpu')
model.load_state_dict(checkpoint['model'])

model.eval()

with torch.no_grad():
	# Named `sample`, not `input`: a top-level `input = ...` in a notebook shadows
	# the builtin `input()` for every later cell in the same kernel.
	# Shape is (1, 1, 4); the meaning of each axis is not visible here —
	# presumably (batch, channel/seq, features), TODO confirm against the model.
	sample = torch.tensor([1, 2, 3, 4], dtype=torch.float32).reshape((1, 1, 4))
	print('input:', sample)
	output = model(sample)
	print('out:', output)


input: tensor([[[1., 2., 3., 4.]]])
out: tensor([[[-1.2304, -0.1907, -0.1379,  1.5591]]])


In [4]:
# MatchJointerRaw test
#
# Restores the best MatchJointerRaw checkpoint and feeds it one tiny hand-made
# pair of inputs; `matching` is printed exactly as the model returns it.
import torch

from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel


RUN_DIR = 'training/melody/20220601-matchjointer-raw'

config = Configuration.createOrLoad(RUN_DIR)
model = loadModel(config['model'])

best_path = config.localPath(config['best'])
checkpoint = torch.load(best_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

# Both inputs are float32 tensors of shape (1, 2, 2); the second is all zeros.
c = torch.tensor([[1, -1], [2, -2]], dtype=torch.float32).unsqueeze(0)
s = torch.zeros_like(c)

with torch.no_grad():
	matching = model(c, s)
	print('matching:', matching)


matching: (tensor([[[1., 1.],
         [1., 1.]]]), tensor([[[ 0.7071, -0.7071],
         [ 0.7071, -0.7071]]]), tensor([[[ 0.7071,  0.7071],
         [-0.7071, -0.7071]]]))


In [2]:
# MatchJointer1 test
#
# Restores the best MatchJointer1 checkpoint and feeds the same three-step toy
# sequence as both inputs. The model returns three values; only `matching` is
# printed, the other two are kept as `vc` / `vs` for interactive inspection.
import torch

from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel


def _toy_sequence():
	"""Return a fresh (float32, long, float32) triple of shape-(1, 3) tensors.

	A new triple is built on every call so the two model inputs never share
	tensor objects.
	"""
	return (
		torch.tensor([[0, 1, 2]], dtype=torch.float32),
		torch.tensor([[60, 62, 64]], dtype=torch.long),
		torch.tensor([[0.4, 0.5, 0.6]], dtype=torch.float32),
	)


config = Configuration.createOrLoad('training/melody/20220609-matchjointer1-test')
model = loadModel(config['model'])

checkpoint = torch.load(config.localPath(config['best']), map_location='cpu')
model.load_state_dict(checkpoint['model'])

model.eval()

with torch.no_grad():
	c = _toy_sequence()
	s = _toy_sequence()
	matching, vc, vs = model(*c, *s)
	print('matching:', matching)


matching: tensor([[[0.2726, 0.2509, 0.2567],
         [0.2726, 0.2509, 0.2566],
         [0.2725, 0.2509, 0.2566]]])


In [22]:
# unsqueeze operation
import torch


class TestModel(torch.nn.Module):
	"""Toy module that tiles its input three times along the last axis.

	For an input of shape (..., n) the output has shape (..., 3 * n), e.g.
	(1, 3) -> (1, 9). It exists to inspect what TorchScript / ONNX produce for
	the operation. Commented-out variants of the forward pass used
	`x.unsqueeze(-1).repeat(1, 1, 3)`, `x.repeat(1, 3)`,
	`x.unsqueeze(-1).tile((1, 1, 3))` and `torch.tile(x.unsqueeze(-1), (1, 1, 3))`
	(not all of which give the same output shape).
	"""

	def forward(self, x):
		# Exports to a single onnx::Concat node with axis=-1.
		return torch.cat([x, x, x], dim=-1)


model = TestModel()
example_x = torch.zeros(1, 3, dtype=torch.float32)
TORCHSCRIPT_PATH = './test.pt'
ONNX_PATH = './test.onnx'

# TorchScript: compile the module and serialize it.
scriptedm = torch.jit.script(model)
scriptedm.save(TORCHSCRIPT_PATH)

# ONNX: export using a (1, 3) float32 dummy input, opset 12.
torch.onnx.export(
	model,
	(example_x,),
	ONNX_PATH,
	verbose=True,
	input_names=('x',),
	output_names=('y',),
	opset_version=12,
)


graph(%x : Float(1, 3, strides=[3, 1], requires_grad=0, device=cpu)):
  %y : Float(1, 9, strides=[9, 1], requires_grad=0, device=cpu) = onnx::Concat[axis=-1](%x, %x, %x) # <ipython-input-22-b754483d1ad1>:14:0
  return (%y)

