In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pl_data_module import QRTDataModule
from pl_module import QRTChallengeRegressor
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import pytorch_lightning as pl
from ray import tune

In [None]:
def train_qrt(config, data_dir, num_gpus=0):
    """Train one QRTChallengeRegressor trial and report metrics to Ray Tune.

    Args:
        config: Hyperparameter mapping. It is handed whole to
            ``QRTChallengeRegressor``; this function itself reads
            ``config["batch_size"]`` and ``config["num_epochs"]``.
        data_dir: Data location forwarded to ``QRTDataModule``.
        num_gpus: GPU count given to ``pl.Trainer(gpus=...)``. A float
            (e.g. ``0.5`` when several trials share one device) is rounded
            up to a whole number first; any non-float value is forwarded
            unchanged.

    Returns:
        None. Results reach Ray Tune through ``TuneReportCallback``.
    """
    import math  # local import so the function does not rely on another cell

    # Ray Tune may assign a fractional GPU share per trial, whereas the
    # trainer's ``gpus`` argument is a device count; round floats up, as the
    # Ray Tune Lightning tutorial does.
    trainer_gpus = math.ceil(num_gpus) if isinstance(num_gpus, float) else num_gpus

    model = QRTChallengeRegressor(config)
    dm = QRTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])

    # Name reported to Tune -> name looked up among the trainer's metrics.
    # Assumes the Lightning module logs "ptl/mse_loss" and "ptl/cel_loss"
    # during validation — TODO confirm in pl_module.
    metrics = {"mse_loss": "ptl/mse_loss", "cel_loss": "ptl/cel_loss"}

    trainer = pl.Trainer(
        max_epochs=config["num_epochs"],
        gpus=trainer_gpus,
        # 0 turns the progress bar off.
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])
    trainer.fit(model, dm)

In [None]:
# Search space sampled by Ray Tune; one concrete value per key is drawn for
# every trial and passed to train_qrt as its `config` argument.
config = {
    # Not read directly in this notebook; presumably consumed by
    # QRTChallengeRegressor, which receives the whole config — confirm there.
    "dropout": tune.choice([0, 0.2, 0.5]),
    "lr": tune.loguniform(1e-5, 1e-1),
    # Read by train_qrt: trainer max_epochs and data-module batch size.
    "num_epochs": tune.choice([20, 40, 60, 80, 100]),
    "batch_size": tune.choice([32, 64, 128]),
}

In [None]:
# Run the hyperparameter search: 10 sampled trials, ranked by lowest "cel_loss".
analysis = tune.run(
    # Bind the arguments of train_qrt that are fixed rather than searched.
    tune.with_parameters(
        train_qrt,
        data_dir="/content/QRT_DataChallenge/data",
        num_gpus=0,
    ),
    config=config,
    num_samples=10,
    # NOTE(review): train_qrt is a plain function that never saves a
    # checkpoint itself — confirm this flag has the intended effect with the
    # Ray version in use.
    checkpoint_at_end=True,
    metric="cel_loss",
    mode="min",
)

In [None]:
# Hyperparameters of the best trial. Presumably resolved with the metric/mode
# given to tune.run ("cel_loss", "min") — confirm against the Ray version in use.
config_to_use = analysis.best_config

In [None]:
# Retrain once with the selected hyperparameters, outside of Ray Tune.
model = QRTChallengeRegressor(config_to_use)
dm = QRTDataModule(
    data_dir="/content/QRT_DataChallenge/data",
    # NOTE(review): num_splits=1 presumably means "use all data, no
    # validation split" — confirm in QRTDataModule.
    num_splits=1,
    num_workers=1,
    batch_size=config_to_use["batch_size"],
)
# Not passed to the trainer below (no TuneReportCallback in this run).
metrics = {
    "mse_loss": "ptl/mse_loss",
    "cel_loss": "ptl/cel_loss",
}
trainer = pl.Trainer(
    max_epochs=config_to_use["num_epochs"],
    # 0 turns the progress bar off.
    progress_bar_refresh_rate=0,
)
trainer.fit(model, dm)

In [None]:
from utils import extract_AB, createSubmission

# Extract the two parameter objects from the network held at `model.model`.
A, beta = extract_AB(model.model)
# NOTE(review): the meaning of the positional `1` is not visible here —
# check createSubmission in utils.
df_out = createSubmission(
    A,
    beta,
    1,
    "full_train_data_no_val",
)

In [None]:
# Show the selected hyperparameters as the cell output.
config_to_use