# Train a feed-forward neural network (FFNN) decoder to predict the animal's location from CA1 hippocampus data

## Set up environment paths

In [None]:
# Project-local environment setup. Presumably puts the repository on sys.path
# so that `neuralgeom` (imported below) resolves — TODO confirm in setup.py.
import setup
setup.main()

# IPython autoreload: mode 2 reloads all modules before each cell executes,
# so edits to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from lovely_numpy import lo
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

import neuralgeom.datasets.experimental as experimental

## Load neural activity & labels

In [None]:
# Recording / preprocessing parameters forwarded to the loader.
expt_id = "34"                  # which experiment to load
timestep_microsec = int(2e5)    # time-bin width in microseconds (0.2 s)
vel_threshold = 5               # velocity threshold; semantics defined by the loader

neural_activity, labels = experimental.load_neural_activity(
    expt_id=expt_id,
    vel_threshold=vel_threshold,
    timestep_microsec=timestep_microsec,
)

# The 1e-6 factor implies labels["times"] is in microseconds.
times_in_seconds = labels["times"] * 1e-6
angles = labels["angles"]

# Rows are time bins, columns are neurons.
print(
    f"There are {neural_activity.shape[1]} neurons binned over "
    f"{neural_activity.shape[0]} timesteps"
)

In [None]:
# The decoder architecture below mirrors the VAE encoder.

# Seed PyTorch's global RNG so the network's random weight initialisation is
# reproducible (provided the cells are executed in order).
torch.manual_seed(0)

# Network input width: one feature per neuron (column) of neural_activity.
input_dim = neural_activity.shape[1]

# Define a simple feedforward neural network
class Decoder(nn.Module):
    """Feed-forward regressor: in_features -> hidden -> hidden -> out_features.

    ReLU follows each of the two hidden layers; the output layer is linear
    (no activation), as appropriate for an unbounded regression target.

    Args:
        in_features: Width of each input vector. ``None`` (default) keeps the
            original behaviour of ``Decoder()`` by reading the notebook-level
            global ``input_dim`` at construction time.
        hidden_features: Widths of the two hidden layers, default ``(50, 20)``.
        out_features: Width of the output vector, default ``2``.
    """

    def __init__(self, in_features=None, hidden_features=(50, 20), out_features=2):
        super().__init__()
        if in_features is None:
            # Backward-compatible fallback to the global defined in the notebook.
            in_features = input_dim
        hidden1, hidden2 = hidden_features
        # .float() pins parameters to float32 regardless of torch's default
        # dtype, matching the cast applied to inputs in forward().
        self.fc1 = nn.Linear(in_features, hidden1).float()  # first hidden layer
        self.fc2 = nn.Linear(hidden1, hidden2).float()  # second hidden layer
        self.fc3 = nn.Linear(hidden2, out_features).float()  # output layer

    def forward(self, x):
        """Map inputs of shape (..., in_features) to (..., out_features)."""
        x = x.float()  # e.g. float64 tensors produced by torch.from_numpy
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # no activation: linear regression output

In [None]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    _backend_name = "cuda"
elif torch.backends.mps.is_available():
    _backend_name = "mps"
else:
    _backend_name = "cpu"
device = torch.device(_backend_name)
print(f"Using device: {device}")

In [None]:
# Regression loss: mean squared error between predictions and targets.
criterion = nn.MSELoss()

# Instantiate the decoder and move its parameters to the selected device.
decoder = Decoder().to(device)

# Adam over every decoder parameter, learning rate 1e-2.
optimizer = optim.Adam(params=decoder.parameters(), lr=1e-2)

In [None]:
# Encode each angle as the unit-circle point (cos θ, sin θ). This 2-D encoding
# is the regression target and has no discontinuity at the 0 / 2π wrap.
#
# NOTE(review): this cell previously used an undefined `angles_radians`, which
# suggests a degree -> radian conversion was intended. Set `angles_in_degrees`
# explicitly once the units of labels["angles"] are confirmed. With None the
# units are inferred: any |angle| > 2π is taken to mean degrees (this assumes
# radian angles, if used, are wrapped to a single turn).
angles_in_degrees = None
angles_array = np.asarray(angles, dtype=float)
if angles_in_degrees is None:
    angles_in_degrees = bool(np.nanmax(np.abs(angles_array)) > 2 * np.pi)
angles_radians = np.deg2rad(angles_array) if angles_in_degrees else angles_array
# Shape (n_timesteps, 2): column 0 = cos, column 1 = sin.
cos_sin_angles = np.column_stack((np.cos(angles_radians), np.sin(angles_radians)))

In [None]:
# Whole-recording tensors on the training device: inputs are the binned
# activity of every neuron per timestep, targets the matching (cos, sin) pairs.
input_data = torch.from_numpy(neural_activity).float().to(device)
target_data = torch.from_numpy(cos_sin_angles).float().to(device)

# Chronological 80/20 split (first 80% of timesteps for training). Splitting
# in time rather than at random keeps neighbouring, likely correlated, time
# bins from landing on both sides of the split.
train_fraction = 0.8
train_size = int(train_fraction * len(input_data))
val_size = len(input_data) - train_size

train_data = TensorDataset(input_data[:train_size], target_data[:train_size])
val_data = TensorDataset(input_data[train_size:], target_data[train_size:])

batch_size = 64
# Shuffle only the training batches. Validation order does not affect the
# model, and a fixed order keeps the reported validation loss deterministic.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [None]:
def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Make one pass over `loader` and return the per-sample mean loss.

    With an `optimizer`, the model is in training mode and each batch is
    back-propagated and stepped. Without one, the model is put in eval mode
    and gradients are disabled. Returns NaN for an empty loader.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, n_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # loss_fn averages within the batch. Weight by batch size so a
            # smaller final batch does not skew the epoch mean.
            batch_n = inputs.shape[0]
            total_loss += loss.item() * batch_n
            n_samples += batch_n
    return total_loss / n_samples if n_samples else float("nan")


# Per-epoch mean losses, kept for plotting.
n_epochs = 10
train_losses = []
val_losses = []

for epoch in range(n_epochs):
    train_loss = _run_epoch(decoder, train_loader, criterion, optimizer)
    train_losses.append(train_loss)
    print(f"Epoch {epoch+1}, Training Loss: {train_loss}")

    val_loss = _run_epoch(decoder, val_loader, criterion)
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1}, Validation Loss: {val_loss}")

print('Finished Training')

In [None]:
# Plot training and validation losses. Epochs are numbered from 1 to match the
# "Epoch N" lines printed during training. Each series gets its own x-range,
# so the plot still works if training stopped between the two phases.
plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training loss')
plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()