In [None]:
# dataset preparation - 1
from starry.topology.data import *


# Load the sample cluster set. The context manager closes the handle as soon
# as parsing is done instead of leaving it open for the life of the kernel.
with open('./test/test.json', 'rb') as file:
	data = loadClusterSet(file)

# Convert the first cluster into its tensor form and show the two matrices.
seq_id, seq_position, masks, matrixH, matrixV = exampleToTensors(data['clusters'][0], 0x100, 0x200)
#print(seq_id, seq_position, masks, matrixH, matrixV)
print(matrixH, matrixV)


In [None]:
# dataset preparation - 2
# Convert every cluster to tensors, then hand the list to the batching helper.
# Relies on `data` loaded by the "dataset preparation - 1" cell.
examples = [exampleToTensors(cluster, 0x100, 0x200) for cluster in data['clusters']]
dataset = batchizeTensorExamples(examples, 4)

# Print the shape of each field of the first batch.
print(
	'dataset:',
	*(dataset[0][key].shape for key in ('seq_id', 'seq_position', 'mask', 'matrixH', 'matrixV')),
)
#print('seq_id.0', dataset[0]['seq_id'][0])
#print('seq_position.0', dataset[0]['seq_position'][0][1])


In [None]:
# dataset scatter
from starry.topology.data import *


# Load the zipped package; the call is unpacked into two datasets.
dataset1, dataset2 = DatasetScatter.loadPackage(
	'zip://./test/test.zip',
	batch_size=2,
	splits='1/3:2/3',
	device='cpu',
)

# Entry count and first file name of each dataset.
print(len(dataset1.entries), dataset1.entries[0]['filename'])
print(len(dataset2.entries), dataset2.entries[0]['filename'])

# Pull one batch from the first dataset and show the shape of each field.
it = iter(dataset1)
batch = next(it)

print(*(batch[key].shape for key in ('seq_id', 'seq_position', 'mask', 'matrixH', 'matrixV')))


In [None]:
# sequence masking
import torch
from starry.transformer.models import get_subsequent_mask, get_pad_mask


# Build both masks for a tiny two-row batch and show them next to their conjunction.
seq = torch.tensor([[3, 2, 1], [4, 5, 6]])
mask1 = get_pad_mask(seq, 1)
mask2 = get_subsequent_mask(seq)
print(mask1, mask2, mask1 & mask2, sep='\n')


In [None]:
# model test
from starry.topology.models.jointers import *


# Forward one batch through a default-constructed model.
# Relies on `dataset` built by the "dataset preparation - 2" cell.
model = TransformJointer()
batch = dataset[0]
pred = model(
	batch['seq_id'],
	batch['seq_position'],
	batch['mask'],
)


In [None]:
import torch


# Broadcasting check for matmul: a (2, 1, 3) stack of row vectors times a
# (3, 1) column vector gives a (2, 1, 1) result.
t1 = torch.tensor([1, 2, 3]).unsqueeze(0).repeat(2, 1, 1)
t2 = torch.tensor([4, 5, 6]).unsqueeze(1)
print(t1, t2)

p = t1 @ t2
print(p)


In [None]:
# predictor test
from ipywidgets import interact_manual
import json
from starry.utils.config import Configuration
from starry.topology.data import *
from starry.topology.predictor import TopologyPredictorHV


def test (config, clusters):
	predictor = TopologyPredictorHV(config)
	results = [*predictor.predict(clusters)]
	text = json.dumps(results)

	print('results:', text)


def setConfig(config_dir, topology_file):
	"""Load a configuration and a topology file, then run `test` on the first four clusters.

	Args:
		config_dir: passed to `Configuration`.
		topology_file: path of the cluster-set file, opened in text mode.
	"""
	config = Configuration(config_dir)
	print('config loaded:', config.id)

	# Close the topology file as soon as it has been parsed instead of leaking the handle.
	with open(topology_file, 'r') as file:
		data = loadClusterSet(file)

	# Only the first four clusters are predicted to keep the run short.
	clusters = data.get('clusters')[:4]
	test(config, clusters)


# Expose setConfig through ipywidgets' interact_manual; both arguments start out as empty strings.
interact_manual(setConfig, config_dir='', topology_file='')


In [None]:
# dump eval
import os
import json
import torch
import dill as pickle
from starry.utils.config import Configuration
from starry.topology.data import Dataset
from starry.topology.trainer import Trainer


config = Configuration(r'./training/score-topology/20210809-chopin-20210807-l6+1')

# os.environ[...] fails with a clear KeyError('DATA_DIR') when the variable is
# unset, instead of an opaque TypeError from os.path.join(None, ...).
# NOTE: pickle can execute arbitrary code while loading; only point DATA_DIR at trusted files.
with open(os.path.join(os.environ['DATA_DIR'], config['data.file_name']), 'rb') as data_file:
    meta = pickle.load(data_file)
    print(meta['ids'])

    config['model.args.d_model'] = 0x200
    # The package is read from the same handle, after the metadata object.
    val, = Dataset.loadPackage(data_file, batch_size=1, splits='0/12', device='cpu')
    # Only a single batch is needed; it is taken while the file is still open.
    batch = next(iter(val))
#print('batch:', batch['seq_id'])

# Run the deducer sub-module alone, in eval mode and without gradient tracking.
trainer = Trainer(config)
trainer.model.eval()
deducer = trainer.model.deducer
with torch.no_grad():
    pred = deducer(batch['seq_id'], batch['seq_position'], batch['mask'])
print('pred:', pred)

output = {
    'seq_id': batch['seq_id'].tolist(),
    'seq_position': batch['seq_position'].tolist(),
    'mask': batch['mask'].tolist(),
    'pred': pred[0].tolist(),
}
# NOTE(review): './test/test.json' is also the file the "dataset preparation - 1"
# cell reads as a cluster set; this dump overwrites it with a different schema.
# The context manager makes sure the JSON is flushed and the handle closed.
with open('./test/test.json', 'w') as out_file:
    json.dump(output, out_file)


In [None]:
# test eval
import os
import torch
import dill as pickle
from starry.utils.config import Configuration
from starry.topology.data import Dataset
from starry.topology.trainer import Trainer


# Forward slashes keep the path valid on POSIX as well as on Windows.
config = Configuration('./training/20210809-chopin-20210807-l6+1')

# os.environ[...] fails with a clear KeyError('DATA_DIR') when the variable is
# unset, instead of an opaque TypeError from os.path.join(None, ...).
# NOTE: pickle can execute arbitrary code while loading; only point DATA_DIR at trusted files.
# The whole evaluation runs inside the `with` so the handle stays open while
# `val` is being iterated and is closed afterwards.
with open(os.path.join(os.environ['DATA_DIR'], config['data.file_name']), 'rb') as data_file:
    meta = pickle.load(data_file)
    print(meta['ids'][11])

    config['model.args.d_model'] = 0x200
    val, = Dataset.loadPackage(data_file, batch_size=2, splits='11/12', device='cpu')

    trainer = Trainer(config)
    trainer.model.eval()
    with torch.no_grad():
        for i, batch in enumerate(val):
            loss, acc = trainer.model(batch)
            #print('loss:', loss)
            print('acc:', acc)

            # Stops after the batch with index 11, i.e. at most 12 batches are scored.
            if i > 10:
                break



In [None]:
# remove deducer in checkpoint state_dict
from ipywidgets import interact_manual
import torch
from starry.utils.config import Configuration
from starry.utils.model_factory import loadModel


def run(config_dir):
	"""Rewrite the best checkpoint of a configuration so it holds only the deducer's weights.

	The checkpoint is loaded into the '<model.type>Loss' model, and the file is then
	overwritten in place with the epoch and `model.deducer.state_dict()`.

	Args:
		config_dir: passed to `Configuration`.
	"""
	config = Configuration(config_dir)
	if config['best'] is None:
		print('no best field found')
		# Without a checkpoint name there is nothing to convert; stop here
		# rather than carrying on and mutating the config.
		return

	config['model.type'] += 'Loss'
	model = loadModel(config['model'])

	cp_path = config.localPath(config['best'])
	checkpoint = torch.load(cp_path, map_location='cpu')
	model.load_state_dict(checkpoint['model'])

	# NOTE(review): this overwrites the original checkpoint file in place.
	checkpoint = {'epoch': checkpoint['epoch'], 'model': model.deducer.state_dict()}
	torch.save(checkpoint, cp_path)

	print('checkpoint saved:', cp_path)


# Expose run through ipywidgets' interact_manual; config_dir starts out as an empty string.
interact_manual(run, config_dir='')