In [1]:
# Dataset: load the evtopo test config and peek at the first validation batch.
import os
from starry.utils.config import Configuration
from starry.utils.dataset_factory import loadDataset


# NOTE(review): this is None when the DATA_DIR env var is unset — confirm that
# loadDataset tolerates data_dir=None rather than failing later with an obscure error.
DATA_DIR = os.getenv('DATA_DIR')

config = Configuration.create('configs/evtopo-test.yaml')
train, val = loadDataset(config, data_dir=DATA_DIR, device='cpu')

# Keep the iterator in the global namespace: a later cell pulls further batches from it.
it = iter(val)
tensors = next(it)

# Report the shape of the first entry along dim 0 (presumably the batch dimension)
# for every field of the batch dict.
for field_name, batch_value in tensors.items():
	sample_shape = batch_value[0].shape
	print(f'{field_name}:\n', sample_shape)


type:
 torch.Size([19])
staff:
 torch.Size([19])
feature:
 torch.Size([19, 15])
x:
 torch.Size([19])
y1:
 torch.Size([19])
y2:
 torch.Size([19])
matrixH:
 torch.Size([324])
tick:
 torch.Size([19])
division:
 torch.Size([19])
dots:
 torch.Size([19])
beam:
 torch.Size([19])
stemDirection:
 torch.Size([19])
grace:
 torch.Size([19])
timeWarped:
 torch.Size([19])
fullMeasure:
 torch.Size([19])
fake:
 torch.Size([19])


In [2]:
# RectifySieveJointer: smoke-test a forward pass on the batch loaded in the dataset cell.
import torch
from starry.topology.models.rectifyJointer import RectifySieveJointer


# No checkpoint is loaded, so the weights are presumably randomly initialised; seed the
# RNG so the printed values are reproducible across Restart & Run All.
torch.manual_seed(0)

# NOTE(review): positional args (1, 1) — presumably the trunk/rectifier layer counts that the
# RectifySieveJointerLoss cell passes by keyword; confirm against the constructor signature.
model = RectifySieveJointer(1, 1)

with torch.no_grad():
	rec, j = model(tensors)

# j is indexed like the batch: one entry per sample, each shaped like that sample's matrixH.
assert len(j) == len(tensors['matrixH']), 'jointer output should have one entry per batch sample'
assert tensors['matrixH'][0].shape == j[0].shape

print('types:', tensors['type'].shape)
print('rec:', {k: v[0, 0] for k, v in rec.items()})
print('j:', len(j), j[0].shape)


types: torch.Size([64, 19])
rec: {'tick': tensor(0.7731), 'division': tensor([0.3088, 0.2697, 0.0939, 0.1620, 0.0754, 0.0462, 0.0438]), 'dots': tensor([0.2978, 0.5361, 0.1661]), 'beam': tensor([0.1714, 0.2564, 0.1247, 0.4474]), 'stemDirection': tensor([0.2084, 0.3240, 0.4676]), 'grace': tensor(0.4681), 'timeWarped': tensor(0.4583), 'fullMeasure': tensor(0.7600), 'fake': tensor(0.5281)}
j: 64 torch.Size([324])


In [6]:
# RectifySieveJointerLoss: smoke-test the loss wrapper on the batch loaded in the dataset cell.
import torch
from starry.topology.models.rectifyJointer import RectifySieveJointerLoss


# No checkpoint is loaded, so the weights are presumably randomly initialised; seed the
# RNG so the loss and metrics below are reproducible across Restart & Run All.
torch.manual_seed(0)

model = RectifySieveJointerLoss(n_trunk_layers=1, n_rectifier_layers=1)

with torch.no_grad():
	loss, metrics = model(tensors)

# Fail loudly on NaN/inf instead of silently printing a meaningless number.
assert torch.isfinite(loss).all(), f'non-finite loss: {loss}'

print('loss:', loss)
# str(v) is used for display; judging by the saved output, metric objects render as "value (count)".
print('metrics:', {k: str(v) for k, v in metrics.items()})


loss: tensor(47.3543)
metrics: {'acc_topo': '0.0 (1)', 'loss_topo': '3.637711524963379 (1)', 'error_tick': '1067.55615234375 (1024)', 'acc_division': '0.205078125 (1024)', 'acc_dots': '0.6318359375 (1024)', 'acc_beam': '0.21153846383094788 (832)', 'acc_stemDirection': '0.21995192766189575 (832)', 'acc_grace': '0.0012019231216982007 (832)', 'acc_timeWarped': '0.7607421875 (1024)', 'acc_fullMeasure': '0.8125 (64)', 'acc_fake': '0.208984375 (1024)'}


In [5]:
# Evaluate the loss model on the next validation batch; re-run this cell to step through batches.
# Relies on `it`/`val` from the dataset cell and on `model` being the RectifySieveJointerLoss
# instance (model(tensors) is unpacked into loss, metrics).
try:
	tensors = next(it)
except StopIteration:
	# The validation iterator is exhausted after repeated runs; start a fresh pass
	# instead of crashing the cell.
	it = iter(val)
	tensors = next(it)

with torch.no_grad():
	loss, metrics = model(tensors)

# Fail loudly on NaN/inf instead of silently printing a meaningless number.
assert torch.isfinite(loss).all(), f'non-finite loss: {loss}'

print('loss:', loss)
print('metrics:', {k: str(v) for k, v in metrics.items()})


loss: tensor(41.0790)
metrics: {'acc_topo': '0.0 (1)', 'loss_topo': '3.0107696056365967 (1)', 'error_tick': '1068.025634765625 (1024)', 'acc_division': '0.0009765625 (1024)', 'acc_dots': '0.1201171875 (1024)', 'acc_beam': '0.2788461446762085 (832)', 'acc_stemDirection': '0.1430288404226303 (832)', 'acc_grace': '0.9387019276618958 (832)', 'acc_timeWarped': '0.0693359375 (1024)', 'acc_fullMeasure': '0.109375 (64)', 'acc_fake': '0.8427734375 (1024)'}
