# Import Modules

In [1]:
import random

import numpy as np
import pandas as pd
import torch

import utils

# Load and Preprocess Data

In [2]:
# Select the vegetable series this notebook forecasts; the index is a position
# in utils.VEGETABLES.
VEGETABLE_INDEX = 15
target_vegetables = utils.VEGETABLES
target_vegetable = target_vegetables[VEGETABLE_INDEX]

# Seed the Python, NumPy and PyTorch RNGs before any preprocessing or RNN
# training runs, so Restart & Run All reproduces the forecast. Some GPU kernels
# (e.g. cuDNN RNNs) can still be nondeterministic unless deterministic
# algorithms are enabled as well.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Read both splits and stack them so the one-hot encodings below produce the
# same set of dummy columns for train and test.
train = pd.read_csv("./data/train.csv")
test = pd.read_csv("./data/test.csv")
n_train = len(train)  # split point used to separate the stacked frame again

train_test = pd.concat([train, test], axis=0, ignore_index=True)
train_test["date"] = pd.to_datetime(train_test["date"], format="%Y%m%d")

# One-hot encode calendar fields derived from `date`. The prefixes matter:
# month (1-12) and weekday (0-6) dummies are both labelled by bare integers, so
# without prefixes the frame would end up with duplicate columns named 1..6.
calendar_dummies = [
    pd.get_dummies(train_test["date"].dt.year, prefix="year"),
    pd.get_dummies(train_test["date"].dt.month, prefix="month"),
    pd.get_dummies(train_test["date"].dt.weekday, prefix="weekday"),
]
area_dummies = pd.get_dummies(train_test["area"])

# Column order matches the step-by-step version: original columns (minus
# `area`), then year, month, weekday and area dummies.
train_test = pd.concat(
    [train_test.drop(columns="area"), *calendar_dummies, area_dummies],
    axis=1,
)
assert train_test.columns.is_unique, "feature columns must have unique names"

# Split positionally at the original boundary; unlike slicing with
# `-test.shape[0]`, this stays correct when the test split is empty.
train = train_test.iloc[:n_train]
test = train_test.iloc[n_train:]

train = utils.get_target_values(train, target_vegetable)
test = utils.get_target_values(test, target_vegetable)

train_loader, train, test, ss = utils.preprocess_data(train, test)

# Set Hyperparameters

In [4]:
# Number of future steps requested from utils.pipeline_rnn: one per row of the
# preprocessed test data.
future = len(test)

# Optimiser / training settings forwarded to utils.pipeline_rnn.
num_epochs = 100
learning_rate = 5e-3
weight_decay = 0.001

# Train and Forecast

In [6]:
pred_y = utils.pipeline_rnn(
    train_loader,
    train,
    test,
    future=future,
    num_epochs=num_epochs,
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Drop the autograd graph, move to host memory and shape as a single column.
scaled_forecast = pred_y.detach().cpu().numpy().reshape(-1, 1)
# Put the forecast in column 0 next to the remaining test columns, so the
# array keeps the layout of the preprocessed test data.
pred_y = np.concatenate([scaled_forecast, test[:, 1:]], axis=1)
# Only column 0 is un-scaled; `ss` is presumably fit on the target column
# alone (it accepts a one-column input here) — confirm in utils.preprocess_data.
pred_y[:, :1] = ss.inverse_transform(pred_y[:, :1])

In [7]:
# Forecast in original units (column 0 after inverse scaling), as plain floats.
[float(value) for value in pred_y[:, 0]]

[111.74423447775294,
 105.72466600870825,
 104.16115295096395,
 101.00860033486602,
 97.83318537110665,
 98.1556074594964,
 96.78717763973647,
 96.51854079887933,
 95.44881955277171,
 94.22062802045394,
 96.1551945974754,
 95.94487972466595,
 96.26377798136103,
 95.45969746800301,
 94.31725250692176,
 96.29108342115919,
 96.0912549267261,
 96.40418650221841,
 95.57208556871839,
 94.3935766586269]