# Hybrid Stock Prediction Model Training
In the "HybridStockPredictionModel" notebook we created a model that can be used to make efficient stock predictions for new business ideas.

First, we create the Dataset class that will be used to train the model:

In [3]:
import torch
from torch.utils.data import Dataset

from Model import HybridStockPredictionModel


class StockPerformanceDataset(Dataset):
    """Map-style dataset yielding (text embedding, static features, targets).

    Each sample is read from a dictionary in ``data``; the idea text is
    embedded on access by calling ``text_encoder.encode``.
    """

    def __init__(self, data, text_encoder, forecast_steps=12):
        """
        Args:
            data: A list of dictionaries where each dictionary contains:
                  - 'idea_text': A string of the business idea.
                  - 'static_features': A list or array of static features.
                  - 'targets': A list or array of target stock performance values.
            text_encoder: The Sentence-BERT model for text encoding; only its
                ``encode`` method is used here.
            forecast_steps: Number of months to forecast; targets are cut to
                at most this many leading values.
        """
        self.data = data
        self.text_encoder = text_encoder
        self.forecast_steps = forecast_steps

    @staticmethod
    def _to_float_tensor(values):
        """Build a float32 tensor from ``values``."""
        return torch.tensor(values, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(text_embedding, static_features, targets)`` for sample ``idx``."""
        sample = self.data[idx]
        description = sample['idea_text']
        static_features = self._to_float_tensor(sample['static_features'])

        # The text is embedded on every access; nothing is cached.
        text_embedding = self._to_float_tensor(self.text_encoder.encode(description))

        # Keep only the first `forecast_steps` target values.
        horizon = sample['targets'][:self.forecast_steps]
        targets = self._to_float_tensor(horizon)

        return text_embedding, static_features, targets


### Training the model
Here we import the model and set it up for training

In [4]:
from torch.utils.data import DataLoader
from Model.HybridStockPredictionModel import StockPerformancePredictionModel
 
train_data = []  # import from csv
# NOTE: `train_data` must be populated before the DataLoader below is built.
# With shuffle=True an empty dataset fails with
# "ValueError: num_samples should be a positive integer value".

# Select the compute device once; the training and evaluation loops move every
# batch to `device`, so it has to exist before they run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = StockPerformancePredictionModel(text_embedding_dim=384, static_feature_dim=10, hidden_dim=128, forecast_steps=12)
# Move the model here, before any optimizer is constructed from its parameters.
model.to(device)

# Assuming 'train_data' is a list of dictionaries with the required keys
dataset = StockPerformanceDataset(train_data, model.text_encoder)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)


ValueError: num_samples should be a positive integer value, but got num_samples=0

#### We can also create a custom loss function and an optimizer

In [None]:
import torch.optim as optim
import torch.nn as nn

# Adam over all model parameters, trained against a mean-squared-error loss
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()


You can also create custom loss functions:

In [None]:
# Example for mixed loss
regression_loss_fn = nn.MSELoss()
classification_loss_fn = nn.CrossEntropyLoss()


# Suppose the last indicator is categorical
def custom_loss(predictions, targets):
    """Combine a regression loss and a classification loss over indicators.

    Args:
        predictions: Tensor indexed as (batch, steps, columns). The first
            ``targets.size(-1) - 1`` columns are regression outputs; every
            remaining column is a class logit for the categorical indicator.
        targets: Tensor indexed as (batch, steps, indicators). All but the
            last indicator are regression targets; the last one holds the
            class index (cast to integer here).

    Returns:
        Scalar tensor: MSE over the regression indicators plus cross-entropy
        over the categorical indicator.
    """
    # Number of regression columns is dictated by the targets, so predictions
    # may carry any number of class logits after them.
    num_regression = targets.size(-1) - 1

    # Separate out indicators
    reg_targets = targets[:, :, :num_regression]
    class_targets = targets[:, :, -1].long()
    reg_preds = predictions[:, :, :num_regression]
    class_logits = predictions[:, :, num_regression:]

    # Compute losses for each part
    reg_loss = regression_loss_fn(reg_preds, reg_targets)
    # CrossEntropyLoss expects (N, C) logits against (N,) class indices, so
    # batch and step dimensions are flattened together on both sides.
    class_loss = classification_loss_fn(
        class_logits.reshape(-1, class_logits.size(-1)),
        class_targets.reshape(-1),
    )

    # Combine losses
    return reg_loss + class_loss


### Start of training

Here we provide a custom training loop:

In [None]:
num_epochs = 10
model.train()  # switch the model to training mode before iterating

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for text_embedding, static_features, targets in train_loader:
        # Transfer the whole batch to the selected device (CPU or GPU)
        text_embedding, static_features, targets = (
            text_embedding.to(device),
            static_features.to(device),
            targets.to(device),
        )

        # Reset gradients, then run the forward pass and score it
        optimizer.zero_grad()
        predictions = model(text_embedding, static_features)
        loss = criterion(predictions, targets)

        # Back-propagate and apply the parameter update
        loss.backward()
        optimizer.step()

        # Track the running total so the epoch mean can be reported
        epoch_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    average_epoch_loss = epoch_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_epoch_loss:.4f}')


### Evaluation
After training we are going to evaluate our model:

In [None]:
from torch.utils.data import DataLoader

# 'val_data' is expected to have the same structure as 'train_data'
val_dataset = StockPerformanceDataset(
    data=val_data,
    text_encoder=model.text_encoder,
)
# Validation batches are served in dataset order (no shuffling)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

def evaluate(model, val_loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    Args:
        model: Called as ``model(text_embedding, static_features)``; it is
            switched to evaluation mode and left in that mode.
        val_loader: Iterable of ``(text_embedding, static_features, targets)``
            batches that also supports ``len()``.
        criterion: Callable ``criterion(predictions, targets)`` whose result
            exposes ``.item()``.

    Returns:
        Sum of the per-batch loss values divided by ``len(val_loader)``.

    Note:
        Batches are moved to the module-level ``device``.
    """
    model.eval()
    batch_losses = []
    # Gradients are not needed for validation
    with torch.no_grad():
        for batch in val_loader:
            text_embedding, static_features, targets = batch
            text_embedding = text_embedding.to(device)
            static_features = static_features.to(device)
            targets = targets.to(device)

            predictions = model(text_embedding, static_features)
            batch_losses.append(criterion(predictions, targets).item())

    # Float start value keeps the accumulation identical to a 0.0-initialised sum
    return sum(batch_losses, 0.0) / len(val_loader)

# Early-stopping state: best validation loss seen so far and the number of
# consecutive epochs without improvement that are tolerated
best_val_loss = float('inf')
patience = 3
trigger_times = 0

for epoch in range(num_epochs):
    # --- Training pass (same steps as the plain training loop) ---
    model.train()
    epoch_loss = 0.0
    for text_embedding, static_features, targets in train_loader:
        text_embedding = text_embedding.to(device)
        static_features = static_features.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        predictions = model(text_embedding, static_features)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # --- Validation pass ---
    val_loss = evaluate(model, val_loader, criterion)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss/len(train_loader):.4f}, Validation Loss: {val_loss:.4f}')

    # An improvement records the new best and resets the patience counter
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        trigger_times = 0
        continue

    # No improvement: count it and stop once patience is exhausted
    trigger_times += 1
    if trigger_times >= patience:
        print("Early stopping triggered!")
        break
