In [2]:
# MatchJointerRaw test
# Smoke test: push one batch of random criterion/sample note sequences through
# the raw MatchJointer model and print the shapes of its three outputs.
import torch

from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel


# Seed torch's global RNG so the random inputs below (and any randomly
# initialised weights created by loadModel) are reproducible across re-runs.
SEED = 0
torch.manual_seed(SEED)

# Sequence lengths of the two note streams (batch size is 1).
N_CRITERION = 10
N_SAMPLE = 6

# Ranges for the random note features: time and velocity are uniform floats in
# [0, MAX_TIME) / [0, MAX_VELOCITY); pitch is an integer in [0, N_PITCHES).
# NOTE(review): the time unit (ms? ticks?) is not established here — confirm.
MAX_TIME = 10000
N_PITCHES = 128
MAX_VELOCITY = 127


config = Configuration.createOrLoad('configs/matchjointer1-test.yaml')
model = loadModel(config['model'])


def random_notes(length: int) -> tuple:
    """Draw `length` random notes for a batch of 1.

    Returns (time, pitch, velocity), each of shape (1, length). The draws are made
    in that order, so RNG consumption matches drawing the three tensors inline.
    """
    time = torch.rand(1, length) * MAX_TIME
    pitch = torch.randint(N_PITCHES, (1, length))
    velocity = torch.rand(1, length) * MAX_VELOCITY
    return time, pitch, velocity


c_time, c_pitch, c_velocity = random_notes(N_CRITERION)
s_time, s_pitch, s_velocity = random_notes(N_SAMPLE)

matching, c_vec, s_vec = model(c_time, c_pitch, c_velocity, s_time, s_pitch, s_velocity)

# NOTE(review): the recorded output of this cell shows matching as (1, 6, 10),
# c_vec as (1, 6, 128) and s_vec as (1, 128, 10). That is, c_vec's length
# matches the *sample* stream and s_vec's matches the *criterion* stream.
# Confirm the model's return order / naming is intended.
print('matching:', matching.shape)
print('c_vec:', c_vec.shape)
print('s_vec:', s_vec.shape)


matching: torch.Size([1, 6, 10])
c_vec: torch.Size([1, 6, 128])
s_vec: torch.Size([1, 128, 10])


In [2]:
# dataset
# Build the train/val loaders from the test config, then inspect the first and
# the fifth training batches.
import os
from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


# Dataset root, read from the environment.
# NOTE(review): os.environ.get yields None when DATA_DIR is unset, and that None
# is passed straight to loadDataset — confirm loadDataset tolerates it.
DATA_DIR = os.environ.get('DATA_DIR')

config = Configuration.create('configs/matchjointer1-test.yaml')
train, val = loadDataset(config, data_dir=DATA_DIR, device='cpu')

print('n_examples:', train.dataset.n_examples, val.dataset.n_examples)


def describe_tensors(tensors):
    """Return a list of (shape, dtype) pairs, one per tensor in `tensors`."""
    return [(tensor.shape, tensor.dtype) for tensor in tensors]


it = iter(train)
batch = next(it)
criterion_summary = describe_tensors(batch['criterion'])
sample_summary = describe_tensors(batch['sample'])
print('tensors:', criterion_summary, sample_summary, batch['ci'].shape)

# Discard the next three batches so `batch` ends up being the fifth one.
# (Loop variable is not `_`, to avoid clobbering IPython's last-output variable.)
for skipped_batch in range(3):
    next(it)
batch = next(it)

# Last 15 `ci` entries of each row.
print('cis:', batch['ci'][:, -15:])
# Index 1 of 'sample'/'criterion' is the pitch tensor (see the labels below).
# NOTE(review): `pitch > 0` treats pitch 0 as padding, so a genuine pitch-0
# note would also be masked out — confirm padding uses 0.
print('sample_mask:', batch['sample'][1] > 0)
print('pitch-c', batch['criterion'][1][:, :20])
print('pitch-s', batch['sample'][1][:, -20:])


n_examples: 14997 3809
tensors: [(torch.Size([4, 64]), torch.float32), (torch.Size([4, 64]), torch.int64), (torch.Size([4, 64]), torch.float32)] [(torch.Size([4, 64]), torch.float32), (torch.Size([4, 64]), torch.int64), (torch.Size([4, 64]), torch.float32)] torch.Size([4, 64])
cis: tensor([[ 3,  4,  5,  7,  6,  8,  9, 10, 11, 12, 13, 15, 14, 16, 16],
        [ 4,  5,  7,  6,  8,  9, 10, 11, 12, 13, 15, 14, 16, 16, 17],
        [ 5,  7,  6,  8,  9, 10, 11, 12, 13, 15, 14, 16, 16, 17, 18],
        [ 7,  6,  8,  9, 10, 11, 12, 13, 15, 14, 16, 16, 17, 18, 19]])
sample_mask: tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False,  True,  True,  True,
          True,  True,  True

In [3]:
# MatchJointer1Loss
# Evaluate the loss wrapper of the test model config on `batch`.
# NOTE: `batch` is not defined in this cell — it is the last batch drawn by the
# dataset cell, which must run first.
import torch

from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel


config = Configuration.createOrLoad('configs/matchjointer1-test.yaml')
# postfix='Loss' asks the factory for the loss variant of the model (per this
# cell's heading, MatchJointer1Loss); calling it yields (loss, metric).
model = loadModel(config['model'], postfix='Loss')

loss, metric = model(batch)

for label, value in (('loss:', loss), ('metric:', metric)):
    print(label, value)


loss: tensor(0.2318, grad_fn=<BinaryCrossEntropyBackward0>)
metric: {'accuracy': 0.0}
