In [None]:
# Embedding Generation Notebook

# Import necessary libraries
import os

import torch

from src.data_utils import load_mnist_data
from src.embedding_models import BasicAutoencoder, IntermediateAutoencoder, AdvancedAutoencoder, EnhancedAutoencoder

# Request shuffled mini-batches of 64 from the MNIST loader (fraction=0.5)
data_loader = load_mnist_data(
    fraction=0.5,
    batch_size=64,
    shuffle=True,
)

# Latent code size and the device used for the model and every batch
code_dim = 50
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the basic autoencoder and place it on the selected device
basic_autoencoder = BasicAutoencoder(code_dim=code_dim)
basic_autoencoder = basic_autoencoder.to(device)

# Reconstruction objective (mean squared error) and the Adam optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(basic_autoencoder.parameters(), lr=0.001)

# Fit the basic autoencoder to reconstruct its own input
num_epochs = 10
basic_autoencoder.train()
for epoch in range(num_epochs):
    # Accumulates the per-batch loss values seen during this epoch
    running_loss = 0
    for inputs, _ in data_loader:
        inputs = inputs.to(device).float()

        # Drop gradients left over from the previous optimization step
        optimizer.zero_grad()

        # Only the reconstruction is needed for the loss; the code is ignored here
        _code, reconstruction = basic_autoencoder(inputs)
        loss = criterion(reconstruction, inputs)

        # Backpropagate the reconstruction error and update the weights
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the loss averaged over the number of batches
    mean_loss = running_loss / len(data_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss:.4f}")

# Compute embeddings for every batch with the trained autoencoder
# Evaluation mode plus no_grad: this pass is inference only.
# NOTE(review): data_loader was created with shuffle=True, so (assuming
# load_mnist_data forwards that flag) the row order may differ between runs.
# Embeddings and labels stay aligned because both come from the same batches.
basic_autoencoder.eval()
all_embeddings = []
all_labels = []
with torch.no_grad():
    for images, labels in data_loader:
        images = images.to(device).float()
        encoded, _ = basic_autoencoder(images)
        # Keep CPU copies so the saved file does not depend on the training device
        all_embeddings.append(encoded.cpu())
        all_labels.append(labels.cpu())

# Stack the per-batch results along the batch dimension
all_embeddings = torch.cat(all_embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Save embeddings to file
embedding_path = "data/embeddings/basic_autoencoder_embeddings.pt"
# torch.save does not create missing parent directories; without this the
# script would fail here, after training has already finished.
os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
torch.save({"embeddings": all_embeddings, "labels": all_labels}, embedding_path)
print(f"Embeddings saved to {embedding_path}")
