In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.optim import Adam

In [38]:
device  = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else 'cpu'
print(f"using device {device}")

using device mps


In [59]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class LSTM_Scratch(nn.Module):
    def __init__(self):
        super().__init__()
        mean, std = 0.0, 1.0

        # Forget Gate
        self.wa1 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.wa2 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.ba1 = nn.Parameter(torch.tensor(0.0))

        # Input Gate
        self.wb1 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.wb2 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.bb1 = nn.Parameter(torch.tensor(0.0))

        # Candidate Memory
        self.wc1 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.wc2 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.bc1 = nn.Parameter(torch.tensor(0.0))

        # Output Gate
        self.wd1 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.wd2 = nn.Parameter(torch.normal(mean=torch.tensor(mean), std=torch.tensor(std)))
        self.bd1 = nn.Parameter(torch.tensor(0.0))

    def lstm_unit(self, input_value, long_memory, short_memory):
        long_memory_percent = torch.sigmoid((short_memory * self.wa1) + (input_value * self.wa2) + self.ba1)
        potential_memory_percent = torch.sigmoid((short_memory * self.wb1) + (input_value * self.wb2) + self.bb1)
        potential_long_memory = torch.tanh((short_memory * self.wc1) + (input_value * self.wc2) + self.bc1)
        updated_long_memory = (long_memory * long_memory_percent) + (potential_memory_percent * potential_long_memory)
        output_short_memory = torch.sigmoid((short_memory * self.wd1) + (input_value * self.wd2) + self.bd1)
        updated_short_memory = torch.tanh(updated_long_memory) * output_short_memory
        return updated_long_memory, updated_short_memory

    def forward(self, input):
        batch_size, seq_len = input.shape
        long_memory = torch.zeros(batch_size, device=input.device)
        short_memory = torch.zeros(batch_size, device=input.device)

        for t in range(seq_len):
            long_memory, short_memory = self.lstm_unit(input[:, t], long_memory, short_memory)

        return short_memory


In [63]:
model =LSTM_Scratch()
print('prediction of the values')
print("Company A og_value: 0, pred_value: ", model(torch.tensor([[0.0, 0.5, 0.25, 1.0]])).detach())
print("Company B og_value: 1, pred_value: ", model(torch.tensor([[1.0, 0.5, 0.25, 1.0]])).detach())


prediction of the values
Company A og_value: 0, pred_value:  tensor([-0.0936])
Company B og_value: 1, pred_value:  tensor([-0.0938])


In [76]:
def train_model(model, dataloader, num_epochs=500, lr=0.01):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for epoch in range(num_epochs):
        total_loss = 0.0

        for inputs, labels in dataloader:
            outputs = model(inputs)  # Pass entire batch
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if epoch % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}")

    print("Training complete!")


# Prepare data
inputs = torch.tensor([[0., 0.5, 0.25, 0.], [1., 0.5, 0.25, 1.]]).float()
labels = torch.tensor([1., 1.]).float().unsqueeze(1)

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Train Model
model = LSTM_Scratch()
train_model(model, dataloader)

# Predictions
print("\nPredictions:")
print("Company A pred_value:", model(torch.tensor([[0.0, 0.5, 0.25, 1.0]])).detach().item())
print("Company B pred_value:", model(torch.tensor([[1.0, 0.5, 0.25, 1.0]])).detach().item())

Epoch [1/500], Loss: 0.8386
Epoch [101/500], Loss: 0.0357
Epoch [201/500], Loss: 0.0054
Epoch [301/500], Loss: 0.0026
Epoch [401/500], Loss: 0.0016
Training complete!

Predictions:
Company A pred_value: 0.9819430708885193
Company B pred_value: 0.9841746091842651
