In [None]:
from tts.visualize import simple_interactive_plot
from tts.data import  _get_tumor_feature_ranges, synthetic_tumor_data, TTSDataset, create_train_val_test_dataloaders, create_dataloader
from tts.config import Config
from tts.model import TTS
from tts.lit_module import LitTTS
import pytorch_lightning as pl
import numpy as np

# Random seed shared by the config, the synthetic data generator and Lightning.
seed = 0
# Seed python/numpy/torch (and dataloader workers) up front. `deterministic=True` in the
# Trainer only selects deterministic algorithms; it does not fix the initial RNG state, so
# without this call weight initialisation and shuffling may differ between runs.
pl.seed_everything(seed, workers=True)


# NOTE(review): meanings of n_features / n_basis / T are inferred from the keyword names only
# (4 features matches the 4 tumour features used for plotting below) — confirm against Config.
config = Config(n_features=4,n_basis=5,T=1,seed=seed)
# NOTE(review): 3000 presumably exceeds the number of samples generated below (first argument
# 2000 — TODO confirm), which would make every split a single full batch.
config.training.batch_size = 3000

# Synthetic tumour-growth data generated with the 'wilkerson' equation.
# NOTE(review): positional args (2000, 20, 1.0, 0.0) look like n_samples, n_time_steps,
# time_horizon, noise_std — confirm against synthetic_tumor_data's signature.
X, ts, ys = synthetic_tumor_data(2000,20,1.0,0.0,seed=seed,equation='wilkerson')
dataset = TTSDataset(config, (X, ts, ys))
train_dataloader, val_dataloader, test_dataloader = create_train_val_test_dataloaders(config, dataset)
# Or use train_dataloader = create_dataloader(config, dataset, indices, shuffle=True)

tts = TTS(config)
litmodel = LitTTS(config, tts)


tb_logger = pl.loggers.TensorBoardLogger(save_dir='logs/', name='tts')
# Keep only the checkpoint with the lowest logged 'val_loss'; it is reloaded after training.
best_val_checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, filename='best_val')

# NOTE: 'auto_lr_find' and Trainer.tune are PyTorch Lightning 1.x APIs (removed in 2.0,
# where lightning.pytorch.tuner.Tuner.lr_find replaces them).
trainer_dict = {
    'deterministic': True,
    'devices': 1,
    'auto_lr_find': True,
    'enable_model_summary': False,
    'enable_progress_bar': False,
    'accelerator': 'cpu',
    'max_epochs': 200,
    'logger': tb_logger,
    # Validation (and therefore checkpoint selection) only happens every 10th epoch.
    'check_val_every_n_epoch': 10,
    'log_every_n_steps':1,
    'callbacks': [best_val_checkpoint]
}


trainer = pl.Trainer(**trainer_dict)

# With auto_lr_find=True, tune() runs the learning-rate finder on the training data before fitting.
trainer.tune(litmodel,train_dataloaders=train_dataloader)

trainer.fit(model=litmodel,train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# best_model_path stays "" when no checkpoint was saved (e.g. 'val_loss' never logged);
# fail with a clear message instead of passing an empty path to load_from_checkpoint.
assert best_val_checkpoint.best_model_path, "no 'best_val' checkpoint was saved - is 'val_loss' logged?"
# Reload the weights from the best validation epoch rather than keeping the final-epoch weights.
litmodel = LitTTS.load_from_checkpoint(best_val_checkpoint.best_model_path, config=config, model=tts)



In [None]:
def trajectory(t, **x):
    """Forecast a trajectory from the trained model for one set of tumour features.

    Args:
        t: time point(s) forwarded unchanged to ``forecast_trajectory``.
        **x: feature values keyed by name; must contain ``age``, ``weight``,
            ``initial_tumor_volume`` and ``dosage`` (a missing key raises KeyError,
            any extra keys are ignored).

    Returns:
        Whatever ``litmodel.model.forecast_trajectory`` returns for the feature
        vector, which is assembled in the fixed order listed above.
    """
    ordered_values = [
        x[name] for name in ("age", "weight", "initial_tumor_volume", "dosage")
    ]
    feature_vector = np.array(ordered_values)
    return litmodel.model.forecast_trajectory(feature_vector, t)


# Interactive plot of the trained model's forecast over t in (0, 2); the feature ranges
# (presumably used as slider bounds — confirm in tts.visualize) cover the four tumour features.
simple_interactive_plot(
    trajectory,
    1,
    (0, 2),
    _get_tumor_feature_ranges("age", "weight", "initial_tumor_volume", "dosage"),
    n_points=100,
)

In [None]:
# Ground truth: plot the reference function `_tumor_volume_2` over the same feature ranges
# and t-interval (0, 2) as the model forecast, for side-by-side comparison.
from tts.visualize import simple_interactive_plot
from tts.data import _get_tumor_feature_ranges, _tumor_volume_2

simple_interactive_plot(
    _tumor_volume_2,
    1,
    (0, 2),
    _get_tumor_feature_ranges("age", "weight", "initial_tumor_volume", "dosage"),
    n_points=1000,
)