# Import Modules

In [1]:
import numpy as np
import pandas as pd

import utils

# Load Data

In [2]:
# Vegetables that can be modelled. The target is chosen by position below,
# so the order of this list matters; positions are annotated for convenience.
# NOTE(review): presumably these strings must match the values that
# utils.get_target_values filters on — confirm against utils.
target_vegetables = [
    "だいこん",        # 0
    "にんじん",        # 1
    "キャベツ",        # 2
    "レタス",          # 3
    "はくさい",        # 4
    "こまつな",        # 5
    "ほうれんそう",    # 6
    "ねぎ",            # 7
    "きゅうり",        # 8
    "トマト",          # 9
    "ピーマン",        # 10
    "じゃがいも",      # 11
    "なましいたけ",    # 12
    "セルリー",        # 13
    "そらまめ",        # 14
    "ミニトマト",      # 15
]

# Position 15 selects "ミニトマト"; change the index to model another vegetable.
target_vegetable = target_vegetables[15]

In [3]:
# Read both splits and stack them so every categorical column is one-hot
# encoded over the combined set of values seen in train and test.
train = pd.read_csv("./data/train.csv")
test = pd.read_csv("./data/test.csv")

# Remember where train ends. Splitting with a negative offset derived from
# the test length (`[:-k]` / `[-k:]`) silently swaps the splits when k == 0,
# because `-0` is just `0`.
n_train = train.shape[0]

train_test = pd.concat([train, test])
train_test["date"] = pd.to_datetime(train_test["date"], format="%Y%m%d")
train_test = train_test.reset_index(drop=True)

# Temporary calendar columns derived from the parsed date; each one is
# replaced by its indicator columns in the loop below.
train_test["year"] = train_test.date.dt.year
train_test["month"] = train_test.date.dt.month
train_test["weekday"] = train_test.date.dt.weekday

# One-hot encode each column in turn: drop the source column and append its
# indicator columns on the right, so the final column order is
# year, month, weekday, area indicators.
# NOTE(review): get_dummies on a Series labels the indicator columns with the
# raw values, so month (1-12) and weekday (0-6) indicators can end up sharing
# the labels 1-6. Left unchanged here because the utils helpers may depend on
# the current layout — confirm, and consider passing `prefix=` if they do not.
for column in ("year", "month", "weekday", "area"):
    indicators = pd.get_dummies(train_test[column])
    train_test = pd.concat([train_test.drop(columns=column), indicators], axis=1)

# Undo the stacking: the first n_train rows are train, the rest are test.
train = train_test.iloc[:n_train]
test = train_test.iloc[n_train:]

# Restrict both splits to the selected vegetable (implementation lives in utils).
train = utils.get_target_values(train, target_vegetable)
test = utils.get_target_values(test, target_vegetable)

# utils.preprocess_data returns a loader, the processed splits and `ss`,
# which is later used via ss.inverse_transform.
# NOTE(review): T=10 is presumably the input window length — confirm in utils.
train_loader, train, test, ss = utils.preprocess_data(train, test, T=10)

# Training

In [4]:
# Ask the model for as many future steps as there are rows in test.
future = test.shape[0]
pred_y = utils.pipeline_rnn(
    train_loader,
    train,
    test,
    future=future,
    num_epochs=100,
)

# Turn the model output into a NumPy column vector.
# NOTE(review): the .cpu()/.detach() calls suggest pred_y is a torch tensor —
# confirm against utils.pipeline_rnn.
pred_y = pred_y.cpu().detach().numpy().reshape(-1, 1)

# Place the predictions in column 0 and reuse test's remaining columns to the
# right, then map only column 0 back through the fitted scaler.
pred_y = np.hstack([pred_y, test[:, 1:]])
pred_y[:, :1] = ss.inverse_transform(pred_y[:, :1])

In [5]:
# Column 0 holds the values that were passed through ss.inverse_transform in
# the previous cell; show them as a plain Python list.
pred_y[:, 0].tolist()

[115.43459701082304,
 115.55757150204619,
 118.76982923269031,
 119.45430800436384,
 119.68139086857053,
 121.79715602064907,
 121.37559637499965,
 122.94709550783455,
 122.898602391597,
 122.76255610200414,
 124.56504495592458,
 124.06931248335461,
 125.36195438666728,
 125.11162827753265,
 124.8035994243981,
 126.16671796160378,
 125.60636914045907,
 126.53416891623172,
 126.12322757985542,
 125.69948043482067]