In [None]:
## Fit and analyze autoencoders

In [None]:
import os
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from behavenet import get_user_dir, make_dir_if_not_exists
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_session_dir
from behavenet.fitting.utils import get_best_model_version
from behavenet.fitting.utils import get_lab_example

# Reload edited project modules automatically before each cell executes, so
# changes to imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Notebook-wide output options. Neither value is referenced in the cells
# visible in this part of the notebook.
save_outputs = True  # true to save figures/movies to user's figure directory
# NOTE(review): `format` shadows the Python builtin of the same name for the
# rest of the kernel session; consider renaming once all users are checked.
format = 'png'  # figure format ('png' | 'jpeg' | 'pdf'); movies saved as mp4

In [None]:
from behavenet.models import load_data as ld
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 32  # number of samples per batch for every split


def _load_split(split):
    """Build the image ParquetDataset for one split ('train' | 'val' | 'test')."""
    return ld.ParquetDataset(get_user_dir('data'), data_type="image", split=split)


def _make_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the notebook-wide batch size."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


# One dataset per split, all read from the user's data directory
train_dataset = _load_split("train")
val_dataset = _load_split("val")
test_dataset = _load_split("test")

# Only the training loader is shuffled; val/test keep their stored order
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

split_sizes = "/".join(str(len(ds)) for ds in (train_dataset, val_dataset, test_dataset))
print(f'Data in train/validation/test: {split_sizes}')

In [None]:
import torch
from pytorch_lightning import Trainer, seed_everything
from behavenet.models import lightning_ae as ae

# Hyperparameters for this run, gathered in one place so they are easy to tune.
SEED = 0              # fixed seed so Restart & Run All reproduces the same fit
INPUT_CHANNELS = 1
INPUT_HEIGHT = 140
INPUT_WIDTH = 170
LATENT_DIM = 9
LEARNING_RATE = 1e-4
MAX_EPOCHS = 100      # adjust based on convergence

# Seed the Python, NumPy and PyTorch global RNGs *before* the model is built,
# so that weight initialisation and the shuffled training order are repeatable.
seed_everything(SEED)

# Initialize autoencoder with correct input shape
autoencoder = ae.LightningAutoencoder(
    input_channels=INPUT_CHANNELS,
    input_height=INPUT_HEIGHT,
    input_width=INPUT_WIDTH,
    latent_dim=LATENT_DIM,
    learning_rate=LEARNING_RATE,
)

# Initialize Trainer; fall back to CPU when no CUDA device is available
trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    log_every_n_steps=10,
)

# Train model (validation runs on val_loader during fitting)
trainer.fit(autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Save model (currently disabled, so this cell does not write the model to
# the user's model directory).
# NOTE(review): the standard Lightning API is `trainer.save_checkpoint(path)`;
# confirm that LightningAutoencoder really defines `save_checkpoint` before
# re-enabling the lines below.
#model_save_path = os.path.join(get_user_dir('models'), 'ae')
#make_dir_if_not_exists(model_save_path)
#autoencoder.save_checkpoint(model_save_path)

In [None]:
# Plot the per-epoch training and validation loss recorded on the model
import matplotlib.pyplot as plt

train_losses = autoencoder.train_losses
# The last validation entry is discarded before plotting.
# NOTE(review): presumably this aligns the two curves; if the extra entry comes
# from Lightning's pre-training sanity check it would be the *first* element,
# not the last - confirm against LightningAutoencoder.
val_losses = autoencoder.val_losses[:-1]

# (values, legend label, marker) for each curve, drawn in this order
curves = (
    (train_losses, "Train Loss", "o"),
    (val_losses, "Validation Loss", "s"),
)

fig, ax = plt.subplots(figsize=(8, 5))
for losses, label, marker in curves:
    epochs = range(1, len(losses) + 1)  # epochs are numbered from 1 on the x-axis
    ax.plot(epochs, losses, label=label, marker=marker)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss Over Epochs")
ax.legend()
ax.grid()
plt.show()

## Check reconstructions on test data

In [None]:
# Get the first batch of test images and move it to the model's device
test_batch = next(iter(test_loader))
test_images = test_batch[0].to(autoencoder.device)

# Forward pass through the autoencoder. Switch to eval mode first so layers
# that behave differently during training (e.g. dropout, batch norm) run in
# inference mode; gradients are not needed for visualisation.
autoencoder.eval()
with torch.no_grad():
    reconstructed_images = autoencoder(test_images)

# Convert tensors to numpy for visualization
test_images_np = test_images.cpu().numpy()
reconstructed_images_np = reconstructed_images.cpu().numpy()

# Plot original and reconstructed images side by side.
# Show at most 5 images, but never more than the batch actually contains
# (a small test split would otherwise raise IndexError below).
n_images = min(5, test_images_np.shape[0])
# squeeze=False keeps `axes` two-dimensional even when n_images == 1
fig, axes = plt.subplots(2, n_images, figsize=(15, 5), squeeze=False)

for i in range(n_images):
    # Channel 0 is displayed; top row = originals, bottom row = reconstructions
    axes[0, i].imshow(test_images_np[i, 0], cmap="gray")  # Original
    axes[0, i].axis("off")
    axes[1, i].imshow(reconstructed_images_np[i, 0], cmap="gray")  # Reconstructed
    axes[1, i].axis("off")

axes[0, 0].set_title("Original Images")
axes[1, 0].set_title("Reconstructed Images")
plt.show()

In [None]:
import torch.nn.functional as F

# Mean squared error between the reconstructions and the originals of the
# test batch evaluated above (a single batch, not the whole test split).
mse_loss = F.mse_loss(input=reconstructed_images, target=test_images)
mse_value = mse_loss.item()
print(f"Mean Squared Error (MSE): {mse_value}")

In [None]:
from skimage.metrics import structural_similarity as ssim
import numpy as np

# Structural similarity for every image in the batch, computed on channel 0.
# data_range=1.0 is passed through unchanged.
# NOTE(review): this assumes pixel values span [0, 1] - confirm against the dataset.
original_frames = test_images_np[:, 0]
reconstructed_frames = reconstructed_images_np[:, 0]
ssim_values = [
    ssim(original, reconstruction, data_range=1.0)
    for original, reconstruction in zip(original_frames, reconstructed_frames)
]

average_ssim = np.mean(ssim_values)
print(f"Average SSIM: {average_ssim:.4f}")