In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

In [None]:
class CliffFailureDataset(Dataset):
    """Sliding-window dataset pairing past erosion/cluster maps with a future failure map.

    Each sample consists of ``sequence_length`` consecutive time slices of the
    erosion and cluster cubes (stacked as two channels) and a binary failure
    map taken ``prediction_horizon`` steps after the end of the window.
    """

    def __init__(self, erosion_cube, cluster_cube, sequence_length=5,
                 prediction_horizon=3, failure_percentile=95):
        """
        Args:
            erosion_cube: (time, height, width) erosion rate data
            cluster_cube: (time, height, width) cluster ID data
            sequence_length: number of time steps to use as input (>= 1)
            prediction_horizon: how many steps ahead to predict failure (>= 0)
            failure_percentile: percentile of the whole erosion cube above
                which a cell is labelled as a failure (default 95)

        Raises:
            ValueError: if the two cubes differ in shape, or if
                ``sequence_length`` / ``prediction_horizon`` are out of range.
        """
        erosion_cube = np.asarray(erosion_cube)
        cluster_cube = np.asarray(cluster_cube)

        # Fail fast: a mismatch would otherwise only surface inside
        # __getitem__ (np.stack), typically in the middle of a training run.
        if erosion_cube.shape != cluster_cube.shape:
            raise ValueError(
                "erosion_cube and cluster_cube must have the same shape, got "
                f"{erosion_cube.shape} and {cluster_cube.shape}"
            )
        if sequence_length < 1:
            raise ValueError("sequence_length must be at least 1")
        if prediction_horizon < 0:
            raise ValueError("prediction_horizon must be non-negative")

        self.erosion_cube = erosion_cube
        self.cluster_cube = cluster_cube
        self.seq_len = sequence_length
        self.pred_horizon = prediction_horizon
        self.failure_percentile = failure_percentile

        # Create failure labels by detecting large erosion events
        self.failure_labels = self._create_failure_labels()

        # A window ending (exclusively) at t needs seq_len slices before t and
        # a label at t + pred_horizon; the range is empty when the cube is
        # too short, which yields a zero-length dataset.
        self.valid_indices = list(range(
            sequence_length,
            len(erosion_cube) - prediction_horizon
        ))

    def _create_failure_labels(self):
        """Return a float array (same shape as the erosion cube) of 0/1 failure flags.

        A cell is a failure when its erosion exceeds the
        ``failure_percentile``-th percentile computed over the entire cube,
        ignoring NaNs.
        """
        threshold = np.nanpercentile(self.erosion_cube, self.failure_percentile)
        return (self.erosion_cube > threshold).astype(float)

    def __len__(self):
        """Number of usable windows."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(input_seq, target)`` for window ``idx``.

        Returns:
            input_seq: float32 tensor (seq_len, 2, H, W); channel 0 is
                erosion, channel 1 is cluster ID.
            target: float32 tensor (H, W), the failure map at
                ``time_idx + prediction_horizon``.
        """
        time_idx = self.valid_indices[idx]
        window = slice(time_idx - self.seq_len, time_idx)

        # Stack as 2-channel input (erosion, clusters) -> (seq_len, 2, H, W)
        input_seq = np.stack(
            [self.erosion_cube[window], self.cluster_cube[window]], axis=1
        )

        # Target: failure map at the future time step
        target = self.failure_labels[time_idx + self.pred_horizon]

        # torch.tensor always copies, so samples never alias the stored cubes.
        return (
            torch.tensor(input_seq, dtype=torch.float32),
            torch.tensor(target, dtype=torch.float32),
        )

class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM step operating on 2-D feature maps.

    One convolution over the concatenated input and previous hidden state
    produces all four gate pre-activations at once.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        """
        Args:
            input_dim: channels of the input tensor
            hidden_dim: channels of the hidden and cell states
            kernel_size: size of the square convolution kernel
            bias: whether the gate convolution has a bias term
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # Half the kernel size, so odd kernels preserve the spatial size.
        self.padding = kernel_size // 2
        self.bias = bias

        # Output channels are laid out as [input | forget | output | candidate].
        self.conv = nn.Conv2d(
            self.input_dim + self.hidden_dim,
            4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one time step.

        Args:
            input_tensor: (batch, input_dim, H, W)
            cur_state: tuple ``(h, c)``, each (batch, hidden_dim, H, W)

        Returns:
            ``(h_next, c_next)`` with the same shapes as the incoming state.
        """
        prev_hidden, prev_cell = cur_state

        # Joint convolution over input and hidden channels.
        gates = self.conv(torch.cat((input_tensor, prev_hidden), dim=1))

        # Four equal channel groups, one per gate.
        in_pre, forget_pre, out_pre, cand_pre = gates.chunk(4, dim=1)
        in_gate = torch.sigmoid(in_pre)
        forget_gate = torch.sigmoid(forget_pre)
        out_gate = torch.sigmoid(out_pre)
        candidate = torch.tanh(cand_pre)

        c_next = forget_gate * prev_cell + in_gate * candidate
        h_next = out_gate * torch.tanh(c_next)
        return h_next, c_next

class UNetConvLSTM(nn.Module):
    """U-Net style encoder/decoder whose encoder levels are ConvLSTM cells.

    The encoder runs one ConvLSTM cell per resolution level over the whole
    input sequence (each deeper level sees the 2x max-pooled hidden state of
    the level above). The decoder upsamples the final bottleneck hidden state
    with transposed convolutions and adds the last-timestep hidden states of
    the shallower levels as skip connections.
    """

    def __init__(self, input_channels=2, hidden_dims=(64, 128, 256), kernel_size=3):
        """
        Args:
            input_channels: channels per time step of the input
            hidden_dims: hidden channels of each encoder level, shallow to deep
            kernel_size: kernel size used by every ConvLSTM cell
        """
        super().__init__()
        # Copy so the module never shares (or is affected by later mutation
        # of) the caller's sequence or the default value.
        self.hidden_dims = list(hidden_dims)
        dims = self.hidden_dims

        # Encoder ConvLSTM layers
        self.encoder_convlstms = nn.ModuleList()
        self.encoder_convlstms.append(
            ConvLSTMCell(input_channels, dims[0], kernel_size)
        )
        for i in range(1, len(dims)):
            self.encoder_convlstms.append(
                ConvLSTMCell(dims[i - 1], dims[i], kernel_size)
            )

        # Decoder layers: kernel 4 / stride 2 / padding 1 exactly doubles the
        # spatial size while going from level i to level i-1 channels.
        self.decoder_convs = nn.ModuleList()
        for i in range(len(dims) - 1, 0, -1):
            self.decoder_convs.append(
                nn.ConvTranspose2d(dims[i], dims[i - 1], 4, 2, 1)
            )

        # Final output layer
        self.final_conv = nn.Conv2d(dims[0], 1, 1)

        # Pooling for encoder
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, channels, height, width)
        Returns:
            failure_prob: (batch, 1, height, width)
        Raises:
            ValueError: if the sequence dimension is empty.
        """
        batch_size, seq_len, _, height, width = x.shape
        if seq_len == 0:
            raise ValueError("UNetConvLSTM requires at least one time step")

        # Zero-initialised (h, c) per level, created directly with the
        # input's device and dtype. Floor division matches MaxPool2d(2, 2).
        encoder_states = []
        for level, hidden_dim in enumerate(self.hidden_dims):
            size = (height // (2 ** level), width // (2 ** level))
            h = x.new_zeros(batch_size, hidden_dim, *size)
            c = x.new_zeros(batch_size, hidden_dim, *size)
            encoder_states.append((h, c))

        # Process sequence through encoder
        for t in range(seq_len):
            layer_input = x[:, t]  # (batch, channels, H, W)
            layer_outputs = []

            for level, convlstm in enumerate(self.encoder_convlstms):
                if level > 0:
                    layer_input = self.pool(layer_outputs[level - 1])

                h, c = convlstm(layer_input, encoder_states[level])
                encoder_states[level] = (h, c)
                layer_outputs.append(h)

        # After the loop layer_outputs holds the last time step: the deepest
        # entry is the bottleneck, the rest are the skip connections.
        *skip_connections, out = layer_outputs

        # Decoder with skip connections (deepest skip first)
        for decoder_conv, skip in zip(self.decoder_convs, reversed(skip_connections)):
            out = F.relu(decoder_conv(out))
            # Pooling floors odd sizes, so doubling can come out one pixel
            # short of the skip map; resize only in that case so inputs that
            # already divide evenly are computed exactly as before.
            if out.shape[-2:] != skip.shape[-2:]:
                out = F.interpolate(out, size=skip.shape[-2:], mode="nearest")
            out = out + skip

        # Final prediction
        return torch.sigmoid(self.final_conv(out))

def create_data_loaders(
    location,
    batch_size=8,
    test_size=0.2,
    results_root="/Volumes/group/LiDAR/LidarProcessing/LidarProcessingCliffs/results",
):
    """Create train/test data loaders from saved cubes.

    Args:
        location: site name; selects the ``{results_root}/{location}/data_cubes``
            directory.
        batch_size: batch size for both loaders.
        test_size: fraction of samples held out, forwarded to
            ``train_test_split``.
        results_root: directory containing one sub-directory per location.
            Defaults to the original network-volume path.

    Returns:
        ``(train_loader, test_loader, scaler)`` where ``scaler`` is the
        ``StandardScaler`` fitted on the NaN-filled erosion values.
    """
    base_path = f"{results_root}/{location}/data_cubes"

    # np.load on an .npz returns an NpzFile holding an open file handle;
    # the context manager closes it once the array has been read.
    with np.load(f"{base_path}/cube_ero_10cm_filled.npz") as archive:
        erosion_data = archive['data']
    with np.load(f"{base_path}/cube_clusters_ero_10cm_filled.npz") as archive:
        cluster_data = archive['data']

    # Handle NaN values (filled BEFORE scaling; inference must do the same)
    erosion_data = np.nan_to_num(erosion_data, nan=0.0)
    cluster_data = np.nan_to_num(cluster_data, nan=0.0)

    # Normalize erosion data: every cell is treated as one sample of a
    # single feature, so one global mean/std is used for the whole cube.
    scaler = StandardScaler()
    erosion_normalized = scaler.fit_transform(
        erosion_data.reshape(-1, 1)
    ).reshape(erosion_data.shape)

    # Create dataset
    dataset = CliffFailureDataset(erosion_normalized, cluster_data)

    # NOTE(review): this is a random split over overlapping time windows and
    # the scaler is fitted on all data, so the test loss is likely optimistic.
    # Confirm whether a chronological split is wanted before relying on it.
    train_indices, test_indices = train_test_split(
        range(len(dataset)), test_size=test_size, random_state=42
    )

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, scaler

def train_model(model, train_loader, test_loader, num_epochs=50, lr=0.001):
    """Fit ``model`` with binary cross-entropy and report per-epoch losses.

    Uses CUDA when available, Adam as the optimizer, and lowers the learning
    rate when the mean test loss plateaus. Model outputs are expected to be
    probabilities shaped (batch, 1, H, W); targets are (batch, H, W).

    Returns:
        ``(train_losses, test_losses)``: lists of the mean batch loss per
        epoch on each loader.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)

    train_losses, test_losses = [], []

    for epoch in range(num_epochs):
        # ---- optimisation pass ----
        model.train()
        running_train = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            # Drop the singleton channel axis so predictions match the labels.
            batch_loss = loss_fn(model(inputs).squeeze(1), labels)
            batch_loss.backward()
            optimizer.step()

            running_train += batch_loss.item()

        # ---- evaluation pass (no gradients) ----
        model.eval()
        running_test = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                running_test += loss_fn(model(inputs).squeeze(1), labels).item()

        epoch_train = running_train / len(train_loader)
        epoch_test = running_test / len(test_loader)
        train_losses.append(epoch_train)
        test_losses.append(epoch_test)

        # The scheduler monitors the held-out loss.
        scheduler.step(epoch_test)

        if not epoch % 10:
            print(f'Epoch {epoch}: Train Loss: {epoch_train:.4f}, Test Loss: {epoch_test:.4f}')

    return train_losses, test_losses

def predict_failure_risk(model, erosion_cube, cluster_cube, scaler, device='cpu'):
    """Generate failure risk predictions for entire time series.

    Args:
        model: trained network mapping (1, seq_len, 2, H, W) to a per-cell
            probability map.
        erosion_cube: (time, height, width) raw erosion data (may contain NaN).
        cluster_cube: (time, height, width) cluster ID data (may contain NaN).
        scaler: fitted scaler exposing ``transform``; applied to the erosion
            values flattened into a single feature column.
        device: device on which inference runs.

    Returns:
        Array with one predicted map per sliding window of the dataset, in
        chronological order.
    """
    model.eval()
    model = model.to(device)

    # Fill NaNs BEFORE scaling, exactly as create_data_loaders does: filling
    # afterwards would map missing cells to the scaled mean (0.0) instead of
    # the scaled value of a raw 0.0 that the model saw during training.
    erosion_filled = np.nan_to_num(erosion_cube, nan=0.0)
    cluster_filled = np.nan_to_num(cluster_cube, nan=0.0)

    # Normalize erosion data with the training-time statistics
    erosion_normalized = scaler.transform(
        erosion_filled.reshape(-1, 1)
    ).reshape(erosion_filled.shape)

    dataset = CliffFailureDataset(erosion_normalized, cluster_filled)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    predictions = []

    with torch.no_grad():
        for data, _ in dataloader:
            pred = model(data.to(device))
            # Remove only the leading batch/channel singleton axes; a bare
            # squeeze() would also drop a spatial axis of size 1.
            pred = pred.squeeze(0).squeeze(0)
            predictions.append(pred.cpu().numpy())

    return np.array(predictions)

# Usage example:
def run_cliff_failure_analysis(location="Delmar"):
    """End-to-end run: load data, train the network, plot losses, save weights.

    Args:
        location: site name used to locate the data cubes and to prefix the
            output files written to the current working directory.

    Returns:
        ``(model, scaler)``: the trained network and the fitted erosion scaler.
    """
    print(f"Setting up cliff failure analysis for {location}...")

    # Data
    train_loader, test_loader, scaler = create_data_loaders(location)
    n_train = len(train_loader.dataset)
    print(f"Created datasets with {n_train} training samples")

    # Network
    model = UNetConvLSTM(input_channels=2, hidden_dims=[32, 64, 128])
    print("Initialized U-Net ConvLSTM model")

    # Optimisation
    print("Starting training...")
    loss_history = train_model(model, train_loader, test_loader, num_epochs=30)

    # Loss curves: one line per split, in (train, test) order
    plt.figure(figsize=(10, 5))
    for curve, curve_label in zip(loss_history, ('Training Loss', 'Test Loss')):
        plt.plot(curve, label=curve_label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Progress')
    figure_path = f'{location}_training_curves.png'
    plt.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Persist the learned weights
    weights_path = f'{location}_cliff_failure_model.pth'
    torch.save(model.state_dict(), weights_path)
    print(f"Model saved as {weights_path}")

    return model, scaler