# Import Modules

In [14]:
from functools import partial

import pandas as pd
import numpy as np
from ray import tune
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.wandb import WandbLogger
import torch
from torch import nn
from torch import optim

import utils

In [15]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Utils

In [16]:
def pipeline_rnn_raytune(config, options):
    """Train a ``utils.RNN`` for one Ray Tune trial, reporting the test loss every epoch.

    Args:
        config: Mapping of sampled hyperparameters. The keys ``"lr"``,
            ``"weight_decay"``, ``"eps"``, ``"num_epochs"`` and ``"batch_size"``
            are read.
        options: Mapping of fixed settings. The keys ``"training_size"`` and
            ``"target_values"`` are read.

    Returns:
        None. The per-epoch test loss is published with ``tune.report(loss=...)``.
    """
    # Fixed settings.
    train_size = options["training_size"]
    series = options["target_values"]

    # Sampled hyperparameters; the first three go straight to Adam.
    adam_kwargs = {
        "lr": config["lr"],
        "weight_decay": config["weight_decay"],
        "eps": config["eps"],
    }
    epochs = config["num_epochs"]
    per_batch = config["batch_size"]

    # The fifth return value is not used in this function
    # (presumably a fitted scaler -- TODO confirm against utils.preprocess_data).
    train_loader, test_y, train, test, _ss = utils.preprocess_data(
        series, train_size=train_size, batch_size=per_batch
    )

    # The model's input size is read from the third axis of `train`.
    model = utils.RNN(input_size=train.shape[2]).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), **adam_kwargs)
    criterion = nn.MSELoss()

    for _epoch in range(epochs):
        # ---- training pass over every mini-batch ----
        model.train()
        for inputs, targets in train_loader:
            batch_loss = criterion(model(inputs), targets)

            optimizer.zero_grad()
            batch_loss.backward()

            # Cap the global gradient norm at 1 before the parameter update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

        # ---- evaluation: predict test.shape[0] values and score them against test_y ----
        model.eval()
        with torch.no_grad():
            forecast = model.predict(train, test, test.shape[0]).reshape(-1)
            test_loss = criterion(forecast, test_y)

        # One report per epoch; this is the metric the search algorithm and scheduler read.
        tune.report(loss=test_loss.item())

def pipeline_raytune(options, config, trainer=pipeline_rnn_raytune, num_samples=100):
    """Run a Ray Tune hyperparameter search and return its result.

    The search minimises the ``"loss"`` metric reported by ``trainer``, using
    HyperOpt to propose configurations and ASHA to stop weak trials early.
    Results are also sent to the ``WandbLogger``.

    Args:
        options: Fixed settings forwarded unchanged to every trial as the
            ``options`` keyword argument of ``trainer``.
        config: Ray Tune search space (also carries the ``"wandb"`` settings
            used by the logger).
        trainer: Trainable called as ``trainer(config, options=options)``; it
            must report a ``"loss"`` metric.
        num_samples: Number of configurations to sample. Defaults to 100.

    Returns:
        The object returned by ``tune.run``, so callers can inspect the trials
        (e.g. look up the best configuration) once the search has finished.
    """
    # Instantiate HyperOptSearch, ASHAScheduler
    hyperopt = HyperOptSearch(metric="loss", mode="min")
    scheduler = ASHAScheduler(
        metric='loss', mode='min', max_t=1000,
        grace_period=12, reduction_factor=2
    )

    # Only reserve a GPU when CUDA is actually usable. Unconditionally asking
    # for one makes every trial unschedulable on a CPU-only machine.
    gpus_per_trial = 1 if torch.cuda.is_available() else 0

    # Optimization
    analysis = tune.run(
        partial(trainer, options=options),
        config=config,
        num_samples=num_samples,
        search_alg=hyperopt,
        resources_per_trial={'cpu': 4, 'gpu': gpus_per_trial},
        scheduler=scheduler,
        loggers=[WandbLogger]
    )

    # Hand the result back instead of discarding it.
    return analysis

# Load Data

In [17]:
# Vegetable to model ("トマト" is Japanese for "tomato"). It is passed to
# utils.get_target_values and reused as the Weights & Biases project name.
target_vegetable = "トマト"

In [18]:
# Read the mapped train/test table and parse its "date" strings (YYYY-MM-DD) into datetimes.
train_test = pd.read_csv("./data/mapped_train_test.csv")
train_test["date"] = pd.to_datetime(train_test["date"], format="%Y-%m-%d")

# Attach the weather columns side by side; concat(axis=1) pairs rows by index label.
weather = pd.read_csv("./data/sorted_mapped_adjusted_weather.csv")
train_test = pd.concat([train_test, weather], axis=1)

# One-hot encode the calendar year: build a temporary "year" column, expand it
# into indicator columns, then swap the temporary column for the indicators.
train_test["year"] = train_test["date"].dt.year
years = pd.get_dummies(train_test["year"])
train_test = pd.concat([train_test.drop(columns="year"), years], axis=1)

# Same treatment for the calendar month.
train_test["month"] = train_test["date"].dt.month
months = pd.get_dummies(train_test["month"])
train_test = pd.concat([train_test.drop(columns="month"), months], axis=1)

# Weekday one-hot encoding (currently disabled):
# train_test["weekday"] = train_test.date.dt.weekday
# weekdays = pd.get_dummies(train_test["weekday"])
# train_test = train_test.drop(columns="weekday")
# train_test = pd.concat([train_test, weekdays], axis=1)

# The "area" column already exists, so it is expanded and replaced directly.
areas = pd.get_dummies(train_test["area"])
train_test = pd.concat([train_test.drop(columns="area"), areas], axis=1)

# Keep only the leading rows, as many as ./data/train.csv contains
# (presumably the training portion of the table -- verify the row order).
train = train_test.iloc[: pd.read_csv("./data/train.csv").shape[0]]

# Extract the values for the selected vegetable from those rows.
target_values = utils.get_target_values(train, target_vegetable)

# Set Config, Options

In [19]:
# Fixed (non-tuned) inputs handed unchanged to every trial.
rnn_options = dict(
    target_values=target_values,
    training_size=4000,
)

# Search space sampled by Ray Tune. The "wandb" entry holds the project name
# and API-key file used for Weights & Biases logging.
rnn_config = dict(
    lr=tune.loguniform(1e-4, 1e-1),
    weight_decay=tune.choice([0, 1e-10, 1e-7, 1e-5, 1e-3]),
    eps=tune.choice([1e-11, 1e-8, 1e-5, 1e-3, 1e-1]),
    num_epochs=tune.choice([25, 50, 75, 100, 150]),
    # Either mini-batches of 16 or a batch size equal to the row count of `train`.
    batch_size=tune.choice([16, train.shape[0]]),
    wandb=dict(
        project=f"{target_vegetable}",
        api_key_file="./wandb_api_key.txt",
    ),
)

# Run Ray Tune Hyperparameter Search

In [None]:
# Launch the hyperparameter search for the RNN trainer with the options and search space above.
pipeline_raytune(
    options=rnn_options,
    config=rnn_config,
    trainer=pipeline_rnn_raytune,
)