# LSTM for Sales Prediction
An LSTM over each sample's weekly sales sequence, combined with embedded categorical features and numeric tabular features, to predict sales.

In [None]:
import torch
# Seed PyTorch's global RNG before anything else runs so weight initialisation
# is repeatable. NOTE(review): this sits ahead of the project imports below;
# keep that order in case those modules draw random numbers at import time.
torch.manual_seed(42)
from utils.dataset import get_dataloaders, CATEGORICAL_COLS, NUMERIC_COLS
from utils.training import train_model
from utils.metrics import mae, rmse, r2, plot_losses, plot_predictions
from torch import nn
# Three loaders (train / validation / test) plus a `stats` object; the model
# below reads `stats.cat_maps` to size its embedding tables.
# NOTE(review): 'data' is presumably the dataset directory -- confirm in utils.dataset.
train_loader, valid_loader, test_loader, stats = get_dataloaders('data')
class SalesLSTM(nn.Module):
    """Regressor that fuses categorical, numeric and sequence inputs.

    Three branches are computed independently and concatenated:

    * one ``nn.Embedding`` per entry of ``CATEGORICAL_COLS``,
    * a single Linear+ReLU layer over the ``NUMERIC_COLS`` features,
    * an LSTM over the sequence input, of which only the last time step's
      output is kept.

    The concatenation is passed through a two-layer MLP head that emits one
    scalar per sample.

    Args:
        stats: Object exposing ``cat_maps``, a mapping from each categorical
            column name to a sized container; its length is used as the
            number of embedding rows for that column.
        emb_dim: Width of every categorical embedding.
        num_hidden: Output width of the numeric-feature layer.
        seq_input_size: Number of features per time step in ``x_seq``.
        lstm_hidden: Hidden size of the LSTM.
        head_hidden: Width of the hidden layer in the output head.

    The defaults reproduce the previously hard-coded architecture, so
    ``SalesLSTM(stats)`` builds the same layers (and the same state-dict
    keys and shapes) as before.
    """

    def __init__(self, stats, emb_dim=4, num_hidden=32, seq_input_size=2,
                 lstm_hidden=64, head_hidden=32):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation yields identical weights for the default sizes.
        self.embs = nn.ModuleList(
            [nn.Embedding(len(stats.cat_maps[col]), emb_dim) for col in CATEGORICAL_COLS]
        )
        self.num_mlp = nn.Sequential(nn.Linear(len(NUMERIC_COLS), num_hidden), nn.ReLU())
        self.lstm = nn.LSTM(input_size=seq_input_size, hidden_size=lstm_hidden, batch_first=True)
        # Width of [embeddings | numeric | sequence] after concatenation.
        fused_dim = lstm_hidden + emb_dim * len(CATEGORICAL_COLS) + num_hidden
        self.head = nn.Sequential(
            nn.Linear(fused_dim, head_hidden), nn.ReLU(), nn.Linear(head_hidden, 1)
        )

    def forward(self, x_num, x_cat, x_seq):
        """Return one prediction per sample as a 1-D tensor.

        Args:
            x_num: Numeric features, last dimension ``len(NUMERIC_COLS)``.
            x_cat: Integer category indices; column ``i`` is looked up in the
                embedding built for ``CATEGORICAL_COLS[i]``.
            x_seq: Sequence input laid out batch-first, with
                ``seq_input_size`` features per time step.
        """
        cat_feats = torch.cat(
            [table(x_cat[:, i]) for i, table in enumerate(self.embs)], dim=1
        )
        num_feats = self.num_mlp(x_num)
        lstm_out, _ = self.lstm(x_seq)
        # Keep only the final time step's output as the sequence summary.
        last_step = lstm_out[:, -1, :]
        fused = torch.cat([cat_feats, num_feats, last_step], dim=1)
        # The head emits (batch, 1); drop the trailing singleton dimension.
        return self.head(fused).squeeze(1)
# Build the model and train it. train_model hands back the loss history and
# the path of the checkpoint to restore for evaluation.
model = SalesLSTM(stats)
history, best_path = train_model(model, train_loader, valid_loader, epochs=8, patience=2)
plot_losses(history)

# Evaluate on whichever device the parameters currently live on.
# map_location lets a checkpoint saved on an accelerator load on a CPU-only
# machine (and vice versa) instead of failing during deserialisation.
device = next(model.parameters()).device
model.load_state_dict(torch.load(best_path, map_location=device))
model.eval()

preds, targets = [], []
with torch.no_grad():
    for x_num, x_cat, x_seq, y in test_loader:
        # Move each batch to the model's device; a no-op when already there.
        batch_pred = model(x_num.to(device), x_cat.to(device), x_seq.to(device))
        # Collect on CPU so predictions and targets share a device for the
        # metrics and plotting below.
        preds.append(batch_pred.cpu())
        targets.append(y.cpu())
y_true = torch.cat(targets)
y_pred = torch.cat(preds)

# Report test-set metrics, then plot predictions against the targets.
for label, metric in (('MAE', mae), ('RMSE', rmse), ('R2', r2)):
    print(label, metric(y_true, y_pred).item())
plot_predictions(y_true, y_pred)