# Demand Forecasting Model Training
This notebook trains a demand forecasting model using historical sales and external factors.

In [None]:
import pandas as pd
import numpy as np
import torch
from backend.synthetic_data import generate_fake_demand_data, create_training_dataset
from transformers import TimeSeriesTransformerForPrediction, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's 'talk' context and 'whitegrid' style to the figures drawn
# further down.
sns.set_context('talk')
sns.set_style('whitegrid')

# Share of samples held out for validation, and the seed that keeps the
# split reproducible between runs.
VALIDATION_FRACTION = 0.2
SPLIT_SEED = 42

# Synthetic sales and external-factor data stand in for real inputs for now.
sales_df, external_df = generate_fake_demand_data()
X, y = create_training_dataset(sales_df, external_df)

# NOTE(review): train_test_split shuffles rows by default. If the rows of X
# are time ordered, a shuffled split lets later observations land in the
# training set -- confirm whether a chronological hold-out (shuffle=False)
# is what is wanted here.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=VALIDATION_FRACTION,
    random_state=SPLIT_SEED,
)

# Checkpoint to start from. The identifier below is a placeholder: replace it
# with the actual model path before running.
PRETRAINED_CHECKPOINT = 'huggingface/timeseries-base'
model = TimeSeriesTransformerForPrediction.from_pretrained(PRETRAINED_CHECKPOINT)

# Training configuration: evaluate once per epoch, batches of 8 for both
# training and evaluation, 5 epochs; checkpoints and logs go to local folders.
training_config = {
    'output_dir': './models/demand_predictor',
    'evaluation_strategy': 'epoch',
    'logging_dir': './logs',
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
    'num_train_epochs': 5,
}
training_args = TrainingArguments(**training_config)

# Create PyTorch dataset
class TimeSeriesDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each input sequence with its target.

    Every item is a dict holding two float32 tensors: ``'input_ids'`` built
    from ``sequences[idx]`` and ``'labels'`` built from ``labels[idx]``.

    pandas inputs are converted to NumPy arrays up front so that ``idx`` is
    always a row position. Indexing a DataFrame with an integer looks up a
    *column* and indexing a Series with an integer looks up an index *label*;
    after a shuffled split neither corresponds to the row position requested.
    Lists, NumPy arrays and tensors are stored unchanged.

    NOTE(review): the ``'input_ids'``/``'labels'`` keys are kept as they were.
    ``TimeSeriesTransformerForPrediction`` appears to expect
    ``past_values``/``future_values``-style inputs instead -- confirm against
    the model's forward signature before training with it.

    Args:
        sequences: Indexable collection of input sequences, one per sample.
        labels: Indexable collection of targets, aligned with ``sequences``.

    Raises:
        ValueError: If ``sequences`` and ``labels`` differ in length.
    """

    def __init__(self, sequences, labels):
        sequences = self._as_positional(sequences)
        labels = self._as_positional(labels)
        # A length mismatch would otherwise surface mid-epoch as an
        # IndexError, or silently ignore the surplus labels.
        if len(sequences) != len(labels):
            raise ValueError(
                f"sequences and labels must have the same length, "
                f"got {len(sequences)} and {len(labels)}"
            )
        self.sequences = sequences
        self.labels = labels

    @staticmethod
    def _as_positional(data):
        """Return pandas-like *data* as a NumPy array; pass anything else through."""
        to_numpy = getattr(data, 'to_numpy', None)
        return to_numpy() if callable(to_numpy) else data

    def __len__(self):
        """Return the number of samples."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the sample at row position ``idx`` as a dict of float32 tensors."""
        return {
            'input_ids': torch.tensor(self.sequences[idx], dtype=torch.float32),
            'labels': torch.tensor(self.labels[idx], dtype=torch.float32),
        }

# Wrap the training and validation splits in the torch dataset defined above.
train_dataset = TimeSeriesDataset(sequences=X_train, labels=y_train)
val_dataset = TimeSeriesDataset(sequences=X_val, labels=y_val)

# Hand the model, its training configuration and both datasets to the
# Hugging Face Trainer, then run the training loop.
trainer_kwargs = {
    'model': model,
    'args': training_args,
    'train_dataset': train_dataset,
    'eval_dataset': val_dataset,
}
trainer = Trainer(**trainer_kwargs)

trainer.train()

# Score the validation set and scatter the flattened predictions against the
# validation targets.
val_predictions = trainer.predict(val_dataset)
predicted_values = val_predictions.predictions.flatten()

# scatterplot draws on the current axes and returns them; label that Axes
# object directly.
ax = sns.scatterplot(x=predicted_values, y=y_val)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Predicted vs Actual')
plt.show()