# Import Modules

In [128]:
import numpy as np
import pandas as pd

import utils

# Load Data and Build Features

In [189]:
# Candidate vegetables, presumably spelled exactly as in the data's vegetable
# column (they are passed to utils.get_target_values below).
target_vegetables = [
    'だいこん',
    'にんじん',
    'キャベツ',
    'レタス',
    'はくさい',
    'こまつな',
    'ほうれんそう',
    'ねぎ',
    'きゅうり',
    'トマト',
    'ピーマン',
    'じゃがいも',
    'なましいたけ',
    'セルリー',
    'そらまめ',
    'ミニトマト',
]
# Only one vegetable is modelled per run; index 15 selects 'ミニトマト' (mini tomato).
target_vegetable = target_vegetables[15]

In [190]:
# Read the raw splits and stack them so one-hot encoding sees every category
# that occurs in either split (keeps train/test feature columns aligned).
train = pd.read_csv("./data/train.csv")
test = pd.read_csv("./data/test.csv")
n_train = len(train)  # split point, remembered before `train` is overwritten

train_test = pd.concat([train, test], ignore_index=True)
train_test["date"] = pd.to_datetime(train_test["date"], format="%Y%m%d")

# One-hot encode calendar features and area. Calendar dummies are prefixed
# because year/month/weekday values are plain integers: unprefixed, months 1-6
# and weekdays 1-6 would yield identically named (duplicate) columns.
calendar_parts = {
    "year": train_test["date"].dt.year,
    "month": train_test["date"].dt.month,
    "weekday": train_test["date"].dt.weekday,
}
dummies = [pd.get_dummies(values, prefix=name) for name, values in calendar_parts.items()]
dummies.append(pd.get_dummies(train_test["area"]))
train_test = pd.concat([train_test.drop(columns="area"), *dummies], axis=1)
assert not train_test.columns.duplicated().any(), "duplicate feature columns"

# Split back positionally: the first n_train rows came from train.csv.
train = train_test.iloc[:n_train]
test = train_test.iloc[n_train:]

train = utils.get_target_values(train, target_vegetable)
test = utils.get_target_values(test, target_vegetable)

# T=10 is forwarded to utils.preprocess_data; presumably a sequence window length — TODO confirm.
train_loader, train, test, ss = utils.preprocess_data(train, test, T=10)

# Train Model and Predict

In [191]:
# Number of future steps requested from the model: one per row of the
# preprocessed test array.
future = test.shape[0]
pred_y = utils.pipeline_rnn(train_loader, train, test, future=future, num_epochs=100)

# Bring the prediction tensor to host memory as a single NumPy column.
pred_column = pred_y.cpu().detach().numpy().reshape(-1, 1)
# Prediction goes in column 0, followed by test's remaining columns, so the
# array has the same column layout as `test`.
pred_y = np.hstack((pred_column, test[:, 1:]))
# Map only the prediction column back to the original scale.
pred_y[:, :1] = ss.inverse_transform(pred_y[:, :1])

In [193]:
# Display the back-transformed predictions (column 0) as plain Python floats.
[row[0] for row in pred_y.tolist()]

[110.43846757812832,
 102.73139228754266,
 96.86388481662581,
 90.19372679303862,
 82.15694767972803,
 78.18960686845598,
 72.59586877586456,
 66.91223467254113,
 62.32391640930072,
 57.30478844221243,
 55.78295278034459,
 53.01857490925758,
 50.09077295957809,
 48.110966852467115,
 45.725724332043896,
 46.02804185409845,
 45.01168878365618,
 43.59525314080972,
 42.836454715895044,
 41.544536304596974]