# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Module-wide default device: CUDA when a GPU is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Embedding Decoder Network
class EmbeddingDecoder(nn.Module):
    def __init__(self, embedding_dim=64, action_dim=None, hidden_dim=128):
        """
        Network that decodes embeddings back to actions

        A three-layer MLP (Linear -> ReLU -> Linear -> ReLU -> Linear) with no
        output activation, so decoded actions are unbounded.

        Args:
            embedding_dim: Dimension of the embedding vectors
            action_dim: Dimension of the action space (required; the ``None``
                default exists only for signature compatibility)
            hidden_dim: Dimension of hidden layers

        Raises:
            TypeError: if ``action_dim`` is not provided
        """
        super().__init__()

        # Without this check nn.Linear(hidden_dim, None) fails with an opaque
        # TypeError from tensor allocation that never mentions action_dim.
        if action_dim is None:
            raise TypeError("EmbeddingDecoder requires an integer action_dim")

        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, embedding):
        """
        Decode an embedding into an action

        Args:
            embedding: The embedding vector to decode [batch_size, embedding_dim]

        Returns:
            action: The decoded action [batch_size, action_dim]
        """
        return self.decoder(embedding)


# Diffusion Model Components
class Diffusion:
    def __init__(self, embedding_dim, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        """
        DDPM-style diffusion process for generating embeddings

        The linear beta schedule and every quantity derived from it are
        precomputed once and stored on the module-level ``device``. The
        per-call methods move whatever they gather from these schedules onto
        the device of their inputs, so callers whose data lives on another
        device still work.

        Args:
            embedding_dim: Dimension of embeddings
            timesteps: Number of diffusion timesteps
            beta_start: Noise level beta at the first timestep
            beta_end: Noise level beta at the last timestep
        """
        self.embedding_dim = embedding_dim
        self.timesteps = timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # Define noise schedule (linear in beta)
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # Calculations for the closed-form forward process q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Calculations for posterior q(x_{t-1} | x_t, x_0); this is 0 at t = 0
        # because alphas_cumprod_prev[0] == 1.
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

    @staticmethod
    def _extract(schedule, t, like):
        """
        Gather ``schedule[t]`` as a ``[batch_size, 1]`` column on ``like``'s device.

        The index is moved to the schedule's device for the lookup and the
        result is moved to the device of ``like``. The column shape lets it
        broadcast against ``[batch_size, embedding_dim]`` tensors.
        """
        return schedule[t.to(schedule.device)].reshape(-1, 1).to(like.device)

    def q_sample(self, x_0, t, noise=None):
        """
        Forward diffusion process: q(x_t | x_0)

        Args:
            x_0: Initial embedding [batch_size, embedding_dim]
            t: Timestep [batch_size]
            noise: Optional pre-generated noise [batch_size, embedding_dim]

        Returns:
            x_t: Noisy embedding at timestep t
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_0)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_0)

        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

    def p_sample(self, model, x_t, t, state, t_index):
        """
        Sample from p(x_{t-1} | x_t) using model prediction

        Args:
            model: Denoising model, called as ``model(x_t, t, state)`` and
                expected to return predicted noise shaped like ``x_t``
            x_t: Embedding at timestep t [batch_size, embedding_dim]
            t: Current timestep [batch_size]
            state: Current state input [batch_size, state_dim]
            t_index: Index of current timestep (integer); 0 returns the
                posterior mean without adding noise

        Returns:
            x_{t-1}: Denoised embedding at timestep t-1
        """
        with torch.no_grad():
            betas_t = self._extract(self.betas, t, x_t)
            sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_t)
            sqrt_recip_alphas_t = self._extract(self.sqrt_recip_alphas, t, x_t)

            # Predict noise
            predicted_noise = model(x_t, t, state)

            # Compute mean for p(x_{t-1} | x_t)
            model_mean = sqrt_recip_alphas_t * (
                x_t - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
            )

            if t_index == 0:
                return model_mean

            posterior_variance_t = self._extract(self.posterior_variance, t, x_t)
            noise = torch.randn_like(x_t)
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    def p_sample_loop(self, model, shape, state, show_progress=True):
        """
        Generate samples using the diffusion model

        Runs every timestep from ``timesteps - 1`` down to 0. Tensors are
        created on the device of ``state``.

        Args:
            model: Denoising model
            shape: Shape of embeddings to generate [batch_size, embedding_dim]
            state: Current state input [batch_size, state_dim]
            show_progress: Whether to display a tqdm progress bar

        Returns:
            embeddings: Generated embeddings [batch_size, embedding_dim]
        """
        batch_size = shape[0]
        sample_device = state.device

        # Start from pure noise (drawn from the default generator, then moved)
        embeddings = torch.randn(shape).to(sample_device)

        # Iteratively denoise
        steps = reversed(range(0, self.timesteps))
        for i in tqdm(steps, desc='Sampling', total=self.timesteps, disable=not show_progress):
            t = torch.full((batch_size,), i, device=sample_device, dtype=torch.long)
            embeddings = self.p_sample(model, embeddings, t, state, i)

        return embeddings

    def train_loss(self, model, x_0, state):
        """
        Compute training loss for diffusion model

        Args:
            model: Denoising model
            x_0: Original embeddings [batch_size, embedding_dim]
            state: Current state input [batch_size, state_dim]

        Returns:
            loss: Mean squared error between predicted and actual noise
        """
        batch_size = x_0.shape[0]
        # Sample timesteps on the data's device so the model sees consistent inputs.
        t = torch.randint(0, self.timesteps, (batch_size,), device=x_0.device)

        # Sample noise
        noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise)

        # Predict noise
        predicted_noise = model(x_t, t, state)

        # Compute loss
        return F.mse_loss(predicted_noise, noise)


# Denoising Model (UNet-like architecture for diffusion)
class ConditionalDenoisingModel(nn.Module):
    def __init__(self, embedding_dim, state_dim, time_dim=128, hidden_dim=256):
        """
        Noise-prediction network for the diffusion process.

        It concatenates the noisy embedding with an encoded state and an
        encoded timestep. It then runs an MLP "U" shape: two down blocks, a
        middle block, and two up blocks that each concatenate the matching
        down-block output (skip connection).

        Args:
            embedding_dim: Size of the embeddings being denoised
            state_dim: Size of the conditioning state vector
            time_dim: Size of the timestep features
            hidden_dim: Width of the outer blocks (inner blocks use half)
        """
        super().__init__()

        half = hidden_dim // 2
        joint_dim = embedding_dim + hidden_dim + time_dim

        # Timestep -> sinusoidal features -> small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
        )

        # State -> hidden_dim features
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Contracting path
        self.down1 = nn.Sequential(nn.Linear(joint_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
        self.down2 = nn.Sequential(nn.Linear(hidden_dim, half), nn.ReLU(), nn.Linear(half, half))

        # Bottleneck
        self.mid = nn.Sequential(nn.Linear(half, half), nn.ReLU(), nn.Linear(half, half))

        # Expanding path; each input is (previous output ++ skip tensor)
        self.up1 = nn.Sequential(nn.Linear(2 * half, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
        self.up2 = nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, embedding_dim))

    def forward(self, x, t, state):
        """
        Estimate the noise contained in ``x`` at timestep ``t`` given ``state``.

        Args:
            x: Noisy embedding [batch_size, embedding_dim]
            t: Timesteps [batch_size]
            state: Conditioning states [batch_size, state_dim]

        Returns:
            Predicted noise [batch_size, embedding_dim]
        """
        time_features = self.time_mlp(t)
        state_features = self.state_encoder(state)
        h = torch.cat([x, state_features, time_features], dim=1)

        skip_outer = self.down1(h)
        skip_inner = self.down2(skip_outer)

        h = self.mid(skip_inner)
        h = self.up1(torch.cat([h, skip_inner], dim=1))
        return self.up2(torch.cat([h, skip_outer], dim=1))


# Sinusoidal position embeddings for diffusion timesteps
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Map integer diffusion timesteps to fixed sinusoidal feature vectors.

    For ``half_dim = dim // 2`` frequencies ``f_k = exp(-k * ln(10000) / (half_dim - 1))``
    the output is ``[sin(t * f_0..), cos(t * f_0..)]``. When ``dim`` is odd, a
    zero column is appended so the output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """
        Args:
            time: 1-D tensor of timesteps [batch_size]

        Returns:
            Embeddings [batch_size, dim]
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids dividing by zero when half_dim == 1 (dim 2 or 3),
        # which previously produced NaN features.
        log_step = np.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        args = time[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Keep the width equal to dim so downstream Linear(dim, ...) layers fit.
            emb = F.pad(emb, (0, 1))
        return emb


# DiffusionPolicy - Combined model for policy generation
class DiffusionPolicy:
    def __init__(self, state_dim, action_dim, embedding_dim=64, hidden_dim=256, 
                 diffusion_steps=100, device=None):
        """
        Policy that generates actions using diffusion model and decoder

        A state-conditioned diffusion model generates an embedding, and the
        decoder maps that embedding to an action.

        Args:
            state_dim: Dimension of the state space
            action_dim: Dimension of the action space
            embedding_dim: Dimension of the embedding space
            hidden_dim: Dimension of hidden layers
            diffusion_steps: Number of diffusion timesteps
            device: Device to run computations on (``torch.device`` or string);
                defaults to CUDA when available, otherwise CPU
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Normalise strings such as "cpu" so self.device is always a torch.device.
        self.device = torch.device(device)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim

        # Create model components (order kept so seeded initialisation is unchanged)
        self.denoising_model = ConditionalDenoisingModel(
            embedding_dim=embedding_dim,
            state_dim=state_dim,
            hidden_dim=hidden_dim
        ).to(self.device)

        self.decoder = EmbeddingDecoder(
            embedding_dim=embedding_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim
        ).to(self.device)

        # NOTE(review): Diffusion stores its noise schedules on the module-level
        # default device, which can differ from self.device when one is passed.
        self.diffusion = Diffusion(
            embedding_dim=embedding_dim,
            timesteps=diffusion_steps
        )

        # Optimizers
        self.denoising_optimizer = optim.Adam(self.denoising_model.parameters(), lr=1e-4)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=1e-4)

    def _make_loader(self, inputs, targets, batch_size):
        """
        Wrap paired arrays/tensors as float32 tensors on ``self.device`` in a shuffled DataLoader.

        Raises:
            ValueError: if ``inputs`` is empty (the per-epoch mean would divide by zero)
        """
        inputs = torch.as_tensor(inputs, dtype=torch.float32, device=self.device)
        targets = torch.as_tensor(targets, dtype=torch.float32, device=self.device)
        if len(inputs) == 0:
            raise ValueError("Cannot train on an empty dataset")
        dataset = torch.utils.data.TensorDataset(inputs, targets)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    def _fit(self, dataloader, optimizer, loss_fn, epochs):
        """
        Run ``epochs`` passes of minibatch optimisation and return the per-epoch mean losses.

        ``loss_fn(batch_inputs, batch_targets)`` must return a scalar loss tensor.
        Progress is printed every 10 epochs.
        """
        losses = []

        for epoch in range(epochs):
            epoch_loss = 0.0
            for batch_inputs, batch_targets in dataloader:
                optimizer.zero_grad()

                loss = loss_fn(batch_inputs, batch_targets)
                loss.backward()

                optimizer.step()
                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(dataloader)
            losses.append(avg_loss)

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")

        return losses

    def train_diffusion(self, embeddings, states, batch_size=64, epochs=100):
        """
        Train the diffusion model

        Args:
            embeddings: State-action embeddings [num_samples, embedding_dim]
            states: Corresponding states [num_samples, state_dim]
            batch_size: Training batch size
            epochs: Number of training epochs

        Returns:
            losses: List of training losses (mean batch loss per epoch)

        Raises:
            ValueError: if ``embeddings`` is empty
        """
        dataloader = self._make_loader(embeddings, states, batch_size)
        return self._fit(
            dataloader,
            self.denoising_optimizer,
            lambda emb, st: self.diffusion.train_loss(self.denoising_model, emb, st),
            epochs,
        )

    def train_decoder(self, embeddings, actions, batch_size=64, epochs=100):
        """
        Train the embedding decoder

        Args:
            embeddings: State-action embeddings [num_samples, embedding_dim]
            actions: Corresponding actions [num_samples, action_dim]
            batch_size: Training batch size
            epochs: Number of training epochs

        Returns:
            losses: List of training losses (mean batch MSE per epoch)

        Raises:
            ValueError: if ``embeddings`` is empty
        """
        dataloader = self._make_loader(embeddings, actions, batch_size)
        return self._fit(
            dataloader,
            self.decoder_optimizer,
            lambda emb, act: F.mse_loss(self.decoder(emb), act),
            epochs,
        )

    @staticmethod
    def _fast_timesteps(total_steps, num_steps):
        """
        Return a descending, evenly strided subset of ``range(total_steps)``.

        The subset has between 1 and ``total_steps`` entries, always starts at
        ``total_steps - 1`` and ends at 0 (when more than one step is used).
        """
        num_steps = max(1, min(num_steps, total_steps))
        points = np.linspace(total_steps - 1, 0, num_steps)
        # dict.fromkeys de-duplicates while preserving the descending order.
        return list(dict.fromkeys(int(round(p)) for p in points))

    def _sample_fast(self, shape, state, num_steps):
        """
        Deterministic DDIM (eta = 0) sampling over a strided subset of timesteps.

        Each step predicts the noise, reconstructs an x_0 estimate from it, and
        re-noises that estimate to the next listed (earlier) timestep. Unlike the
        single-step ancestral update in ``Diffusion.p_sample``, this remains
        valid when intermediate timesteps are skipped.
        """
        sample_device = state.device
        alphas_cumprod = self.diffusion.alphas_cumprod.to(sample_device)
        schedule = self._fast_timesteps(self.diffusion.timesteps, num_steps)
        one = torch.ones((), device=sample_device)

        # Start from pure noise
        x = torch.randn(shape).to(sample_device)

        with torch.no_grad():
            for i, step in enumerate(schedule):
                t = torch.full((shape[0],), step, device=sample_device, dtype=torch.long)
                predicted_noise = self.denoising_model(x, t, state)

                alpha_bar_t = alphas_cumprod[step]
                # After the last listed step, jump to a clean sample (alpha_bar = 1).
                alpha_bar_prev = alphas_cumprod[schedule[i + 1]] if i + 1 < len(schedule) else one

                x0_pred = (x - torch.sqrt(1.0 - alpha_bar_t) * predicted_noise) / torch.sqrt(alpha_bar_t)
                x = torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1.0 - alpha_bar_prev) * predicted_noise

        return x

    def select_action(self, state, num_samples=1, fast_sampling=False, fast_steps=10):
        """
        Generate action for a given state using the diffusion policy

        Args:
            state: A single state, [state_dim] or [1, state_dim] (array-like or tensor)
            num_samples: Number of action samples to generate
            fast_sampling: If True, use deterministic strided DDIM sampling with
                ``fast_steps`` model evaluations instead of the full chain
            fast_steps: Number of denoising steps used when ``fast_sampling`` is True

        Returns:
            action: Generated action [action_dim] when ``num_samples == 1``,
                otherwise [num_samples, action_dim] (numpy array)

        Raises:
            ValueError: if more than one state is passed
        """
        # Accept lists, numpy arrays or tensors on any device/dtype.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)

        if state.dim() == 1:
            state = state.unsqueeze(0)
        if state.shape[0] != 1:
            raise ValueError("select_action expects a single state of shape [state_dim] or [1, state_dim]")

        # Repeat state for multiple samples
        state = state.repeat(num_samples, 1)

        # Generate embeddings through diffusion
        shape = (num_samples, self.embedding_dim)

        if fast_sampling:
            embeddings = self._sample_fast(shape, state, fast_steps)
        else:
            # Use full diffusion process
            embeddings = self.diffusion.p_sample_loop(self.denoising_model, shape, state)

        # Decode embeddings to actions
        with torch.no_grad():
            actions = self.decoder(embeddings)

        actions = actions.cpu().numpy()
        return actions[0] if num_samples == 1 else actions

    def save(self, path):
        """Save the denoising model and decoder state dicts to the specified path"""
        torch.save({
            'denoising_model': self.denoising_model.state_dict(),
            'decoder': self.decoder.state_dict(),
        }, path)

    def load(self, path):
        """
        Load the model components from the specified path

        Tensors are mapped onto ``self.device`` so a checkpoint saved on a GPU
        can be loaded on a CPU-only machine (and vice versa).
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.denoising_model.load_state_dict(checkpoint['denoising_model'])
        self.decoder.load_state_dict(checkpoint['decoder'])


# Example usage
def example_usage(state_dim, action_dim, num_samples=1000):
    """
    Smoke-run the whole pipeline on random data.

    The run proceeds as follows:
    - Draw uniform random states and actions.
    - Embed them with the project's ``StateActionEmbedding`` network.
    - Train the diffusion model and the decoder for 50 epochs each.
    - Sample one action for a random state and print it.
    - Save the policy to ``diffusion_policy.pt`` in the working directory.

    Args:
        state_dim: Dimension of the state space
        action_dim: Dimension of the action space
        num_samples: Number of random (state, action) pairs to generate
    """
    # Random training data (states drawn before actions)
    raw_states = np.random.rand(num_samples, state_dim)
    raw_actions = np.random.rand(num_samples, action_dim)

    # Embedding network from the project-local action_embedding_net module
    from action_embedding_net import StateActionEmbedding
    encoder = StateActionEmbedding(state_dim, action_dim)

    state_batch = torch.FloatTensor(raw_states)
    action_batch = torch.FloatTensor(raw_actions)
    with torch.no_grad():
        pair_embeddings = encoder(state_batch, action_batch).numpy()

    policy = DiffusionPolicy(state_dim, action_dim)
    policy.train_diffusion(pair_embeddings, raw_states, epochs=50)
    policy.train_decoder(pair_embeddings, raw_actions, epochs=50)

    probe_state = np.random.rand(state_dim)
    sampled_action = policy.select_action(probe_state)
    print(f"Generated action: {sampled_action}")

    policy.save("diffusion_policy.pt")


if __name__ == "__main__":
    # Small smoke run: 4-dimensional states, 2-dimensional actions.
    example_usage(4, 2)