# Import Modules

In [1]:
from functools import partial

import datetime
import pandas as pd
from ray import tune
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.wandb import WandbLogger
from sklearn import preprocessing

import utils

In [2]:
# Wall-clock timestamp of this run, formatted YYYY-mm-dd-HH-MM-SS (local time);
# used below to build a unique W&B project name.
start_date = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

# Prepare Data

In [3]:
# All training CSVs live in one directory; keep the path in a single place.
DATA_DIR = "./deep_occupancy_detection/data"

# "1_*" files are the source domain, "2_*" files the target domain.
source_X = pd.read_csv(f"{DATA_DIR}/1_X_train.csv").values
target_X = pd.read_csv(f"{DATA_DIR}/2_X_train.csv").values
# Labels are flattened to 1-D; this assumes each label CSV has a single column.
source_y_task = pd.read_csv(f"{DATA_DIR}/1_Y_train.csv").values.reshape(-1)
target_y_task = pd.read_csv(f"{DATA_DIR}/2_Y_train.csv").values.reshape(-1)

# reshape(-1) would silently flatten a multi-column label file into a longer vector;
# make sure there is exactly one label per feature row in each domain.
assert source_y_task.shape == (source_X.shape[0],), "source labels/features row mismatch"
assert target_y_task.shape == (target_X.shape[0],), "target labels/features row mismatch"

# NOTE(review): fit_transform is called separately on each domain, so each is
# standardized with its OWN mean/std, and `scaler` is left holding the target
# statistics (the source fit is discarded). If a shared, source-fitted
# normalization was intended, use `scaler.transform(target_X)` on the target
# instead — TODO confirm intent.
scaler = preprocessing.StandardScaler()
source_X = scaler.fit_transform(source_X)
target_X = scaler.fit_transform(target_X)

In [None]:
# Build the source/target data loaders. get_loader returns its outputs in a
# different order than its inputs, and all four arrays are rebound here —
# presumably to converted versions; TODO confirm in utils.get_loader.
(
    source_loader,
    target_loader,
    source_y_task,
    source_X,
    target_X,
    target_y_task,
) = utils.get_loader(
    source_X,
    target_X,
    source_y_task,
    target_y_task,
)

# Ray Tune hyperparameter search (TPE via HyperOpt)

In [3]:
# Fixed (non-searched) inputs handed to every trial alongside the sampled config.
options = dict(
    source_loader=source_loader,
    target_loader=target_loader,
    source_X=source_X,
    target_X=target_X,
    target_y_task=target_y_task,
)

In [None]:
# Search space for Ray Tune.
config = {
    # Layer widths: integers sampled with tune.randint(5, 500).
    **{
        size_key: tune.randint(5, 500)
        for size_key in (
            "hidden_size",
            "domain_fc1_size",
            "domain_fc2_size",
            "task_fc1_size",
            "task_fc2_size",
        )
    },
    # Per-component optimizer settings, each log-uniform with upper bound 1e-1.
    # Produces keys like "feature_learning_rate", "domain_eps", "task_weight_decay".
    **{
        f"{component}_{setting}": tune.loguniform(lower, 1e-1)
        for component in ("feature", "domain", "task")
        for setting, lower in (
            ("learning_rate", 1e-4),
            ("weight_decay", 1e-10),
            ("eps", 1e-10),
        )
    },
    # Weights & Biases settings consumed by WandbLogger; one project per notebook run.
    "wandb": {
        "project": f"project_{start_date}",
        "api_key_file": "./wandb_api_key.txt",
    },
}

In [None]:
# The TPE searcher and the ASHA early-stopping scheduler must agree on the
# objective: minimize the "loss" metric reported by each trial.
search_objective = {"metric": "loss", "mode": "min"}
hyperopt = HyperOptSearch(**search_objective)
scheduler = ASHAScheduler(**search_objective)

In [None]:
# `partial` pre-binds the `options` keyword (loaders and arrays), so Tune can
# call the trainable with only the sampled `config` dict.
analysis = tune.run(
    partial(utils.raytune_trainer, options=options),
    config=config,
    search_alg=hyperopt,
    scheduler=scheduler,
    num_samples=50,
    resources_per_trial={"cpu": 4, "gpu": 1},
    loggers=[WandbLogger],
)