# In[ ]:
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger

from fpl_engineering.data.lit_data_module import FPLDataModule
from fpl_engineering.models.lit_models import FPLLSTMRegressor
from fpl_engineering.utils import get_project_root
# Output locations, resolved relative to the repository root.
project_root = str(get_project_root())
log_dir = '/'.join((project_root, 'logs'))

# Lightning accelerator: use the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# W&B sweep definition: random search over the hyperparameters below,
# optimizing the validation loss logged during training.
sweep_config = {
    "method": "random",  # random search
    "metric": {  # we want to minimize val_loss
        "name": "val_loss",
        "goal": "minimize",
    },
    "parameters": {
        # NOTE(review): if FPLLSTMRegressor forwards `dropout` to
        # torch.nn.LSTM, it has no effect when n_layers == 1 -- confirm.
        "dropout": {
            # Choose from pre-defined values
            "values": [0.3, 0.4, 0.5, 0.75, 0.85],
        },
        "hidden_size": {
            # Choose from pre-defined values
            "values": [64, 128, 256, 512],
        },
        "batch_size": {
            # Choose from pre-defined values
            "values": [64, 128, 256, 512],
        },
        "lr": {
            # Log-uniform between 1e-5 and 0.1. Learning rates span orders of
            # magnitude. A uniform(0, 0.1) range could sample a (near-)zero
            # rate and would almost never try values below 1e-3.
            "distribution": "log_uniform_values",
            "min": 1e-5,
            "max": 0.1,
        },
        "seq_len": {
            # Choose from pre-defined values
            "values": [4, 5],
        },
        "n_layers": {
            # Choose from pre-defined values
            "values": [1, 2],
        },
    },
    # Hyperband stops under-performing runs early. NOTE(review): max_iter is
    # presumably counted in val_loss reports (one per epoch?), so it is kept
    # equal to static_config['epochs'] (40) -- keep them in sync.
    "early_terminate": {
        "type": "hyperband",
        "s": 2,
        "eta": 3,
        "max_iter": 40,
    },
}

# Settings held fixed across every sweep trial (not searched over).
static_config = dict(n_features=35, epochs=40)

def sweep_iteration():
    """Train one sweep trial using hyperparameters sampled by the W&B agent.

    Invoked by ``wandb.agent`` once per trial. The sampled values
    (batch_size, seq_len, hidden_size, n_layers, dropout, lr) are read from
    ``wandb.config``. Fixed settings (n_features, epochs) come from the
    module-level ``static_config``. Builds ``FPLDataModule`` and
    ``FPLLSTMRegressor`` and fits them with a Lightning ``Trainer`` that
    logs through ``WandbLogger``.
    """
    # Start the W&B run. Only the fixed settings are passed as config; the
    # sampled hyperparameters are supplied by the sweep agent. Passing
    # `sweep_config` here would copy the sweep's meta-settings (method,
    # metric, parameters, early_terminate) into every run's config.
    wandb.init(config=static_config)
    config = wandb.config  # only populated after wandb.init()

    data_module = FPLDataModule(
        batch_size=config.batch_size,
        seq_len=config.seq_len,
        download_data=False,
    )

    model = FPLLSTMRegressor(
        n_features=static_config['n_features'],
        hidden_size=config.hidden_size,
        seq_len=config.seq_len,
        batch_size=config.batch_size,
        num_layers=config.n_layers,
        dropout=config.dropout,
        learning_rate=config.lr,
    )

    # No EarlyStopping callback is registered. Early termination of weak
    # trials is configured in sweep_config['early_terminate'] (hyperband).
    callbacks = [TQDMProgressBar(refresh_rate=1)]
    logger = WandbLogger(save_dir=log_dir, project='FPL')

    trainer = pl.Trainer(
        accelerator=device,
        max_epochs=static_config['epochs'],
        logger=logger,
        callbacks=callbacks,
        log_every_n_steps=20,
        # 16-bit mixed precision targets GPUs. Some Lightning versions reject
        # precision=16 on CPU, so fall back to full precision there.
        precision=16 if device == 'cuda' else 32,
    )

    trainer.fit(model, datamodule=data_module)



if __name__ == '__main__':
    # Register the sweep defined by `sweep_config` with W&B, then run up to
    # 100 trials of `sweep_iteration` in this process via an agent.
    sweep_id = wandb.sweep(sweep_config, project="FPL")
    wandb.agent(sweep_id, function=sweep_iteration, project='FPL', count=100)
    wandb.finish()