<a href="https://colab.research.google.com/github/monikagulia1/STATE-SPACE-MODEL/blob/main/SSM_(State_Space_Model).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [96]:
import torch
import torch.nn as nn

class SSM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SSM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Define the matrices for the state transition and output
        self.A = nn.Parameter(torch.randn(hidden_dim, hidden_dim))  # State transition matrix
        self.B = nn.Parameter(torch.randn(hidden_dim, input_dim))  # Input matrix
        self.C = nn.Parameter(torch.randn(output_dim, hidden_dim))  # Output matrix
        self.D = nn.Parameter(torch.randn(output_dim, input_dim))  # Direct input matrix

    def forward(self, x, state):
        # x: input tensor of shape (batch_size, input_dim)
        # state: previous state tensor of shape (batch_size, hidden_dim)

        # Calculate the next state
        next_state = torch.matmul(state, self.A) + torch.matmul(x, self.B.t())

        # Calculate the output
        output = torch.matmul(next_state, self.C.t()) + torch.matmul(x, self.D.t())

        return output, next_state

# Example parameters
input_dim = 10
hidden_dim = 20
output_dim = 5
batch_size = 32
seq_length = 20
learning_rate = 0.001
num_epochs = 10

# Initialize the SSM model
model = SSM(input_dim, hidden_dim, output_dim)

# Define a loss function and optimizer
criterion = nn.MSELoss()  # Mean Squared Error Loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Generate some dummy data
input_data = torch.randn(batch_size, seq_length, input_dim)
target_data = torch.randn(batch_size, seq_length, output_dim)

# Training loop
for epoch in range(num_epochs):
    # Initialize the hidden state
    hidden_state = torch.zeros(batch_size, hidden_dim)

    # Forward pass
    outputs = []
    for t in range(seq_length):
        # Get the input for the current time step
        x = input_data[:, t, :]

        # Pass the input and previous state to the model
        output, hidden_state = model(x, hidden_state)
        outputs.append(output)

    # Stack the outputs
    outputs = torch.stack(outputs, dim=1)

    # Calculate the loss
    loss = criterion(outputs, target_data)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")


Epoch 1, Loss: 1.001483881470932e+25
Epoch 2, Loss: 1.001484112055233e+25
Epoch 3, Loss: 1.0014836508866312e+25
Epoch 4, Loss: 1.0014842273473835e+25
Epoch 5, Loss: 1.0014837661787816e+25
Epoch 6, Loss: 1.0014837661787816e+25
Epoch 7, Loss: 1.001483881470932e+25
Epoch 8, Loss: 1.001484112055233e+25
Epoch 9, Loss: 1.0014837661787816e+25
Epoch 10, Loss: 1.0014836508866312e+25
