# Custom Model Architectures

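This notebook shows how to define a custom autoencoder architecture by subclassing seisCAE's `BaseAutoencoder`, train it with `AutoencoderTrainer`, and use it to extract latent features from spectrograms.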

In [None]:
import torch
import torch.nn as nn
import numpy as np
from seiscae.models.base import BaseAutoencoder
from seiscae.training import AutoencoderTrainer
from seiscae.utils import get_device

## Define Custom Autoencoder

In [None]:
class CustomAutoencoder(BaseAutoencoder):
    """
    Convolutional autoencoder with a three-stage conv encoder and an MLP bottleneck.

    Encoder: three ``Conv2d(kernel_size=3, stride=2, padding=1)`` stages
    (1 -> 16 -> 32 -> 64 channels, each followed by BatchNorm2d and ReLU),
    which halve each spatial dimension per stage, followed by a two-layer MLP
    down to ``latent_dim``.

    Decoder: a two-layer MLP back to the flattened feature map, then three
    ``ConvTranspose2d(kernel_size=3, stride=2, padding=1, output_padding=1)``
    stages that exactly double each spatial dimension, ending in a Sigmoid so
    reconstructions lie in [0, 1].

    Args:
        latent_dim: Length of the latent vector produced by ``encode``.
        input_shape: ``(height, width)`` of the single-channel input images.
            Both must be divisible by 8 (``2 ** 3`` for the three stride-2
            stages) so the decoder reconstructs exactly the input size.
            Defaults to ``(256, 40)``, which encodes to a ``(64, 32, 5)``
            feature map.

    Raises:
        ValueError: If either dimension of ``input_shape`` is not divisible by 8.
    """

    # Number of stride-2 stages; each halves (encoder) / doubles (decoder) H and W.
    _NUM_STAGES = 3
    # Channel count of the deepest encoder feature map.
    _BOTTLENECK_CHANNELS = 64

    def __init__(self, latent_dim=32, input_shape=(256, 40)):
        super().__init__()
        self.latent_dim = latent_dim

        height, width = input_shape
        scale = 2 ** self._NUM_STAGES
        # With k=3, s=2, p=1 each encoder stage maps H -> ceil(H / 2), while each
        # decoder stage maps H -> 2 * H. The round trip only restores the input
        # size when both dimensions are divisible by 2 ** _NUM_STAGES.
        if height % scale or width % scale:
            raise ValueError(
                f"input_shape {tuple(input_shape)} must be divisible by {scale} "
                "in both dimensions so the decoder output matches the input"
            )

        # (C, H, W) of the encoder's output feature map. It is computed once so
        # the flatten size and the decoder's reshape cannot drift apart.
        self._encoded_shape = (
            self._BOTTLENECK_CHANNELS,
            height // scale,
            width // scale,
        )
        channels, enc_h, enc_w = self._encoded_shape
        flat_dim = channels * enc_h * enc_w

        # Convolutional feature extractor.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

        # Flattened feature map -> latent vector.
        self.fc_encoder = nn.Sequential(
            nn.Linear(flat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )

        # Latent vector -> flattened feature map.
        self.fc_decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, flat_dim),
            nn.ReLU(),
        )

        # Transposed convolutions; output_padding=1 makes each stage exactly 2x.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(channels, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # Output in [0, 1]
        )

    def encode(self, x):
        """Map a batch of inputs ``(B, 1, H, W)`` to latent vectors ``(B, latent_dim)``."""
        features = self.encoder(x)
        # Keep the batch axis, flatten (C, H, W) into one feature axis.
        return self.fc_encoder(features.flatten(1))

    def decode(self, z):
        """Map latent vectors ``(B, latent_dim)`` to reconstructions ``(B, 1, H, W)``."""
        flat = self.fc_decoder(z)
        # -1 lets the batch size be inferred from the flattened activations.
        feature_map = flat.view(-1, *self._encoded_shape)
        return self.decoder(feature_map)

    def forward(self, x):
        """Encode then decode ``x``; returns the tuple ``(z, x_recon)``."""
        z = self.encode(x)
        x_recon = self.decode(z)
        return z, x_recon

    def get_latent_dim(self):
        """Return the size of the latent vector."""
        return self.latent_dim

## Train Custom Model

In [None]:
# Load precomputed spectrograms (adjust the path to wherever your pipeline saved them).
spectrograms = np.load('../../results/spectrograms.npy')
print(f"Loaded {len(spectrograms)} spectrograms")

# CustomAutoencoder's default configuration expects (256, 40) inputs. Fail fast
# here with a readable message instead of an opaque matrix-shape error raised
# from inside the training loop.
# NOTE(review): assumes the last two array axes are (height, width) — confirm
# against the spectrogram generation step.
expected_shape = (256, 40)
if tuple(spectrograms.shape[-2:]) != expected_shape:
    raise ValueError(
        f"CustomAutoencoder expects spectrograms of shape {expected_shape}, "
        f"got {tuple(spectrograms.shape[-2:])}; adjust the architecture or resample the data"
    )

# Initialize custom model
device = get_device(0)
custom_model = CustomAutoencoder(latent_dim=32)

# Train
trainer = AutoencoderTrainer(custom_model, device, learning_rate=1e-4)

history = trainer.train(
    spectrograms=spectrograms,
    epochs=200,
    batch_size=128,
    save_dir='../../results/custom_model',
)

print(f"Training complete! Final val loss: {history['val_losses'][-1]:.6f}")

In [None]:
# Extract features with custom model: run the trainer's feature extraction
# over the same spectrograms used for training.
# NOTE(review): presumably returns one latent vector per spectrogram,
# i.e. shape (N, latent_dim) — confirm against AutoencoderTrainer.extract_features.
features = trainer.extract_features(spectrograms)
print(f"Features shape: {features.shape}")