In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm
import sys
import os

# Add parent directory to path
sys.path.append('..')
from launch_experiment import initialize as init_agent
from hydra import initialize, compose
import rlkit.torch.pytorch_util as ptu
from scripts.mine.mine import get_estimator

# Choose the compute device once and keep rlkit's pytorch_util helpers on the same backend
cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if cuda_available else torch.device('cpu')
ptu.set_gpu_mode(cuda_available)
print(f"Using device: {device}")


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  from distutils.dep_util import newer, newer_group


import module: half_cheetah_vel
import module: humanoid_dir
import module: ant_dir
import module: hopper_params
import module: ant_goal
import module: walker_params
import module: point_robot
import module: half_cheetah_dir
import module: wrappers




Using device: cuda


In [2]:
def load_agent(env, algo='offlinepearl'):
    """Compose the Hydra ``experiment`` config for ``env``/``algo`` and build the agent from it.

    Args:
        env: Name of the env config group entry (passed as ``+env=<env>``).
        algo: Name of the algo config group entry (passed as ``+algo=<algo>``).

    Returns:
        (agent, cfg): the object built by ``launch_experiment.initialize`` and the composed config.
    """
    hydra_overrides = [f'+env={env}', f'+algo={algo}']
    with initialize(version_base="1.3", config_path="../cfgs"):
        cfg = compose('experiment', overrides=hydra_overrides)
    return init_agent(cfg), cfg


In [10]:
# Main execution: load the trained agent for one environment and report its task split
env_name = 'cheetah-vel'
agent, cfg = load_agent(env=env_name, algo='offlinepearl')

print(f"Environment: {env_name}")
print(f"Training tasks: {agent.train_tasks}")
print(f"Eval tasks: {agent.eval_tasks}")

# Flattened sizes of the observation and action spaces of the agent's environment
state_dim, action_dim = (
    int(np.prod(space.shape))
    for space in (agent.env.observation_space, agent.env.action_space)
)
reward_dim = 1  # reward is treated as a single column
latent_dim = agent.latent_dim  # latent size stored on the algorithm object

print("\nDimensions:")
print(f"  State (observation): {state_dim}")
print(f"  Action: {action_dim}")
print(f"  Latent: {latent_dim}")


---------- Networks initialized -------------
[Network] Total number of parameters : 0.235 M
-----------------------------------------------
---------- Networks initialized -------------
[Network] Total number of parameters : 0.144 M
-----------------------------------------------
---------- Networks initialized -------------
[Network] Total number of parameters : 0.144 M
-----------------------------------------------
---------- Networks initialized -------------
[Network] Total number of parameters : 0.142 M
-----------------------------------------------
---------- Networks initialized -------------
[Network] Total number of parameters : 0.144 M
-----------------------------------------------
---------- Networks initialized -------------
[Network] Total number of parameters : 0.094 M
-----------------------------------------------
Environment: cheetah-vel
Training tasks: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
Eval tasks: [5, 6, 7, 8, 9, 30, 

In [84]:
def _sample_task_transitions(agent, task_indices, batch_size, is_training):
    """Sample one batch per task and concatenate each unpacked field across tasks.

    Draws from ``agent.replay_buffer`` when ``is_training`` is truthy, otherwise from
    ``agent.eval_buffer``. Each batch goes through ``agent.unpack_batch`` (the same call
    the agent's ``sample_context`` uses), giving per-task lists like
    [obs, act, rewards, next_obs, terms]; each field is then concatenated on dim 0.
    """
    buffer = agent.replay_buffer if is_training else agent.eval_buffer
    # Draw every task's batch before unpacking (same draw order as before the refactor)
    batches = [ptu.np_to_pytorch_batch(buffer.random_batch(idx, batch_size=batch_size))
               for idx in task_indices]
    unpacked = [agent.unpack_batch(batch, sparse_reward=False) for batch in batches]
    # zip(*unpacked) regroups [[o1, a1, ...], [o2, a2, ...]] into [(o1, o2, ...), (a1, a2, ...), ...]
    return [torch.cat(field, dim=0) for field in zip(*unpacked)]


def _flatten_context(fields, use_next_obs):
    """Concatenate transition fields on the feature axis (dim 2) and flatten to 2-D rows.

    The trailing terminals field is always dropped; next_obs is dropped as well unless
    ``use_next_obs`` is set, mirroring the agent's context layout.
    """
    parts = fields[:-1] if use_next_obs else fields[:-2]  # [o, a, r, no] or [o, a, r]
    context = torch.cat(parts, dim=2)
    return context.reshape(-1, context.shape[2])


def get_batch_with_latent(agent, task_idx, include_action=True, batch_size=64, is_training=None):
    """
    Get batch from agent's buffer and compute latent representation Z

    Three independent draws are made per call. The first provides the joint samples.
    The second provides the context of the marginal samples. The third provides the Z
    of the marginal samples. Pairing Z and context from different draws breaks their
    dependence while keeping the same task(s).

    Args:
        agent: The loaded agent
        task_idx: Task index (single integer or list of task indices)
        include_action: Whether to include action in the context
        batch_size: Batch size (per task)
        is_training: If True, use replay_buffer; if False, use eval_buffer;
                     if None, automatically determine based on the first task index

    Returns:
        Tuple of (joint, marginal) tensors, each with rows ``[z, context]``

    Raises:
        ValueError: if ``task_idx`` is an empty collection.
    """
    # Handle both single task index and list of task indices
    if not hasattr(task_idx, '__iter__') or isinstance(task_idx, (str, bytes)):
        task_indices = [task_idx]
    else:
        task_indices = list(task_idx)
    if not task_indices:
        raise ValueError("task_idx must contain at least one task index")

    # Auto-determine buffer if is_training is None (only the first task is inspected)
    if is_training is None:
        if task_indices[0] in agent.train_tasks:
            is_training = True
        elif task_indices[0] in agent.eval_tasks:
            is_training = False
        else:
            # Default to training buffer if not found in either
            is_training = True
            print(f"Warning: task {task_indices[0]} not found in train_tasks or eval_tasks, using replay_buffer")

    joint_fields = _sample_task_transitions(agent, task_indices, batch_size, is_training)
    marginal_fields1 = _sample_task_transitions(agent, task_indices, batch_size, is_training)
    marginal_fields2 = _sample_task_transitions(agent, task_indices, batch_size, is_training)

    context_encoder = agent.agent.context_encoder
    # Switch to eval mode for encoding, but restore the caller's mode afterwards so this
    # analysis helper does not leave the agent's encoder silently changed.
    was_training = context_encoder.training
    context_encoder.eval()
    try:
        with torch.no_grad():
            use_next_obs = agent.use_next_obs_in_context
            context_input = _flatten_context(joint_fields, use_next_obs)
            marginal_context_input1 = _flatten_context(marginal_fields1, use_next_obs)
            marginal_context_input2 = _flatten_context(marginal_fields2, use_next_obs)

            # Encode to get latent representation Z
            z_raw = context_encoder(context_input)
            marginal_z_raw2 = context_encoder(marginal_context_input2)

            # Information-bottleneck encoders output 2 * latent_dim (mean, std): keep the mean half
            latent_dim = agent.agent.latent_dim
            if z_raw.shape[1] == latent_dim * 2:
                z = z_raw[:, :latent_dim]
                marginal_z = marginal_z_raw2[:, :latent_dim]
            else:
                z = z_raw
                marginal_z = marginal_z_raw2

            if not include_action:
                # Drop the action columns, which follow the observation columns in the context
                state_dim = int(np.prod(agent.env.observation_space.shape))  # observation_dim
                action_dim = int(np.prod(agent.env.action_space.shape))
                context_input = torch.cat([context_input[:, :state_dim], context_input[:, state_dim + action_dim:]], dim=1)
                marginal_context_input1 = torch.cat([marginal_context_input1[:, :state_dim], marginal_context_input1[:, state_dim + action_dim:]], dim=1)
            return torch.cat([z, context_input], dim=1), torch.cat([marginal_z, marginal_context_input1], dim=1)
    finally:
        context_encoder.train(was_training)


### MINE with IB

In [71]:
def _split_joint_transitions(joint, state_dim, action_dim, reward_dim):
    """Split rows of ``[z, s, a, r, s']`` (from get_batch_with_latent) into named tensors.

    The trailing ``2 * state_dim + action_dim + reward_dim`` columns are taken as the
    transition; every column before them is the latent Z. This layout only holds when
    the agent puts next observations in its context (``agent.use_next_obs_in_context``).

    Raises:
        ValueError: if ``joint`` has no columns left over for Z.
    """
    context_dim = 2 * state_dim + action_dim + reward_dim
    if joint.shape[1] <= context_dim:
        raise ValueError(
            f"Expected more than {context_dim} columns ([z, s, a, r, s']), got {joint.shape[1]}; "
            "is agent.use_next_obs_in_context enabled?")
    transition = joint[:, -context_dim:]
    a_start = state_dim
    r_start = a_start + action_dim
    ns_start = r_start + reward_dim
    return {
        'latents': joint[:, :-context_dim],
        'states': transition[:, :a_start],
        'actions': transition[:, a_start:r_start],
        'rewards': transition[:, r_start:ns_start],
        'next_states': transition[:, ns_start:],
    }


def train_mine_model(agent, train_tasks, latent_dim, action_dim, state_dim, reward_dim,
                    epochs=100, batch_size=256, samples_per_task=1000, lr=2e-4, estimator='dv', device='cpu'):
    """
    Train MINE model to estimate I(Z; A | S, R, S')

    For conditional MI I(Z; A | S, R, S'), we use the chain rule:
    I(Z; A | S, R, S') = I(Z; A, S, R, S') - I(Z; S, R, S')

    We estimate this by training two MINE models:
    1. I(Z; A, S, R, S') = I(Z; concat(A, S, R, S'))
    2. I(Z; S, R, S') = I(Z; concat(S, R, S'))

    Then: I(Z; A | S, R, S') = I(Z; A, S, R, S') - I(Z; S, R, S')

    Samples come from ``get_batch_with_latent`` (joint rows ``[z, s, a, r, s']``), so the
    agent must include next observations in its context.

    Args:
        agent: The loaded agent
        train_tasks: List of training task indices
        latent_dim: Dimension of latent Z
        action_dim: Dimension of action A
        state_dim: Dimension of state S
        reward_dim: Dimension of reward R (typically 1)
        epochs: Number of training epochs
        batch_size: Batch size for training
        samples_per_task: Number of samples per task per epoch
        lr: Learning rate
        estimator: MINE estimator type ('dv', 'fdiv', 'nwj')
        device: Device to use

    Returns:
        Dictionary with models, losses, and mi_estimates

    Raises:
        ValueError: if ``train_tasks`` is empty or the sampled rows do not match the
            expected ``[z, s, a, r, s']`` layout / ``latent_dim``.
    """
    if len(train_tasks) == 0:
        raise ValueError("train_tasks must contain at least one task index")

    # Create conditioning variable dimensions
    X_dim = state_dim + reward_dim + state_dim  # S + R + S'
    AX_dim = action_dim + state_dim + reward_dim + state_dim  # A + S + R + S'

    mine_args = {
        'estimator': estimator,
        'est_lr': lr,
        'variant': 'unbiased',
        'device': device
    }

    # MINE model 1: I(Z; A, S, R, S') (each model gets its own copy of the args)
    mine_model1 = get_estimator(AX_dim, latent_dim, dict(mine_args)).to(device)
    mine_model1.train()
    optimizer1, _ = mine_model1._configure_optimizers()

    # MINE model 2: I(Z; S, R, S')
    mine_model2 = get_estimator(X_dim, latent_dim, dict(mine_args)).to(device)
    mine_model2.train()
    optimizer2, _ = mine_model2._configure_optimizers()

    # Training loop
    losses1 = []
    losses2 = []
    mi_estimates1 = []  # I(Z; A, S, R, S')
    mi_estimates2 = []  # I(Z; S, R, S')
    mi_conditional = []  # I(Z; A | S, R, S')

    for epoch in tqdm(range(epochs), desc="Training MINE"):
        epoch_losses1 = []
        epoch_losses2 = []
        epoch_mi1 = []
        epoch_mi2 = []

        # Fresh samples from every training task. get_batch_with_latent returns
        # (joint, marginal); only the joint rows are needed because the marginal for Z
        # is formed below by shuffling Z within each minibatch.
        task_parts = []
        for task_idx in train_tasks:
            joint, _ = get_batch_with_latent(agent, task_idx, include_action=True,
                                             batch_size=samples_per_task, is_training=True)
            task_parts.append(_split_joint_transitions(joint, state_dim, action_dim, reward_dim))

        # Concatenate all tasks
        states = torch.cat([d['states'] for d in task_parts], dim=0).to(device)
        actions = torch.cat([d['actions'] for d in task_parts], dim=0).to(device)
        rewards = torch.cat([d['rewards'] for d in task_parts], dim=0).to(device)
        next_states = torch.cat([d['next_states'] for d in task_parts], dim=0).to(device)
        latents = torch.cat([d['latents'] for d in task_parts], dim=0).to(device)
        if latents.shape[1] != latent_dim:
            raise ValueError(f"Sampled Z has width {latents.shape[1]}, expected latent_dim={latent_dim}")

        # Create conditioning variable X = [S, R, S']
        X = torch.cat([states, rewards, next_states], dim=1)

        # Create joint variable AX = [A, S, R, S']
        AX = torch.cat([actions, states, rewards, next_states], dim=1)

        # Sample batches; with fewer samples than batch_size, use one (smaller) batch
        # instead of zero batches (which would make the epoch means NaN)
        n_samples = X.shape[0]
        n_batches = max(1, n_samples // batch_size)
        indices = np.random.permutation(n_samples)

        for i in range(n_batches):
            batch_idx = indices[i * batch_size:(i + 1) * batch_size]

            # Model 1: I(Z; A, S, R, S')
            ax_batch = AX[batch_idx]
            z_batch = latents[batch_idx]
            # Shuffling Z within the batch gives samples from the product of marginals
            shuffle_idx = np.random.permutation(len(z_batch))
            z_marginal = z_batch[shuffle_idx]

            mi1 = mine_model1.get_mi_bound(ax_batch, z_batch, z_marginal=z_marginal,
                                          update_ema=(epoch > 10))
            loss1 = -mi1

            optimizer1.zero_grad()
            loss1.backward()
            optimizer1.step()

            epoch_losses1.append(loss1.item())
            epoch_mi1.append(mi1.item())

            # Model 2: I(Z; S, R, S') -- same shuffled Z as model 1
            x_batch = X[batch_idx]

            mi2 = mine_model2.get_mi_bound(x_batch, z_batch, z_marginal=z_marginal,
                                          update_ema=(epoch > 10))
            loss2 = -mi2

            optimizer2.zero_grad()
            loss2.backward()
            optimizer2.step()

            epoch_losses2.append(loss2.item())
            epoch_mi2.append(mi2.item())

        mine_model1.step_epoch()
        mine_model2.step_epoch()

        avg_loss1 = np.mean(epoch_losses1)
        avg_loss2 = np.mean(epoch_losses2)
        avg_mi1 = np.mean(epoch_mi1)
        avg_mi2 = np.mean(epoch_mi2)
        avg_mi_cond = avg_mi1 - avg_mi2  # I(Z; A | S, R, S')

        losses1.append(avg_loss1)
        losses2.append(avg_loss2)
        mi_estimates1.append(avg_mi1)
        mi_estimates2.append(avg_mi2)
        mi_conditional.append(avg_mi_cond)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}:")
            print(f"  I(Z; A, S, R, S') = {avg_mi1:.4f}, I(Z; S, R, S') = {avg_mi2:.4f}")
            print(f"  I(Z; A | S, R, S') = {avg_mi_cond:.4f}")

    # Return both models and combined results
    results = {
        'model_joint': mine_model1,  # I(Z; A, S, R, S')
        'model_cond': mine_model2,   # I(Z; S, R, S')
        'losses_joint': losses1,
        'losses_cond': losses2,
        'mi_joint': mi_estimates1,
        'mi_cond_base': mi_estimates2,
        'mi_conditional': mi_conditional
    }

    return results


In [72]:
def evaluate_mine_model(mine_models, agent, eval_tasks, state_dim, reward_dim,
                       samples_per_task=1000, device='cpu'):
    """
    Evaluate MINE models on eval tasks to get I(Z; A | S, R, S') estimates

    Samples come from ``get_batch_with_latent`` (joint rows ``[z, s, a, r, s']``), so the
    agent must include next observations in its context. The models' train/eval modes
    are restored before returning.

    Args:
        mine_models: Dictionary with 'model_joint' and 'model_cond' MINE models
        agent: The loaded agent
        eval_tasks: List of eval task indices
        state_dim: Dimension of state
        reward_dim: Dimension of reward
        samples_per_task: Number of samples per task for evaluation
        device: Device to use

    Returns:
        List of MI estimates per task

    Raises:
        ValueError: if the sampled rows are too narrow for the ``[z, s, a, r, s']`` layout.
    """
    mine_model_joint = mine_models['model_joint']
    mine_model_cond = mine_models['model_cond']
    # Remember the modes so evaluation does not leave the models switched to eval()
    was_training = (mine_model_joint.training, mine_model_cond.training)
    mine_model_joint.eval()
    mine_model_cond.eval()

    # Column layout of the transition part of each row; Z is whatever precedes it
    action_dim = int(np.prod(agent.env.action_space.shape))
    context_dim = 2 * state_dim + action_dim + reward_dim
    a_start = state_dim
    r_start = a_start + action_dim
    ns_start = r_start + reward_dim

    mi_estimates = []

    try:
        with torch.no_grad():
            for task_idx in tqdm(eval_tasks, desc="Evaluating on eval tasks"):
                # Get batch from agent's eval buffer; only the joint rows are needed
                joint, _ = get_batch_with_latent(agent, task_idx, include_action=True,
                                                 batch_size=samples_per_task, is_training=False)
                if joint.shape[1] <= context_dim:
                    raise ValueError(
                        f"Expected more than {context_dim} columns ([z, s, a, r, s']), got {joint.shape[1]}; "
                        "is agent.use_next_obs_in_context enabled?")
                joint = joint.to(device)
                latents = joint[:, :-context_dim]
                transition = joint[:, -context_dim:]
                states = transition[:, :a_start]
                actions = transition[:, a_start:r_start]
                rewards = transition[:, r_start:ns_start]
                next_states = transition[:, ns_start:]

                # Create conditioning variable X = [S, R, S']
                X = torch.cat([states, rewards, next_states], dim=1)

                # Create joint variable AX = [A, S, R, S']
                AX = torch.cat([actions, states, rewards, next_states], dim=1)

                # Evaluate both MI estimates over several random subsets for a stable estimate
                mi_joint_vals = []
                mi_cond_vals = []
                n_eval_samples = min(1000, len(X))

                for _ in range(10):  # Multiple evaluations for stability
                    indices = np.random.choice(len(X), n_eval_samples, replace=False)

                    # Evaluate I(Z; A, S, R, S')
                    ax_batch = AX[indices]
                    z_batch = latents[indices]
                    # Shuffling Z within the subset gives samples from the product of marginals
                    shuffle_idx = np.random.permutation(len(z_batch))
                    z_marginal = z_batch[shuffle_idx]

                    mi1 = mine_model_joint.get_mi_bound(ax_batch, z_batch, z_marginal=z_marginal, update_ema=False)
                    mi_joint_vals.append(mi1.item())

                    # Evaluate I(Z; S, R, S') with the same shuffled Z
                    x_batch = X[indices]

                    mi2 = mine_model_cond.get_mi_bound(x_batch, z_batch, z_marginal=z_marginal, update_ema=False)
                    mi_cond_vals.append(mi2.item())

                # Compute conditional MI: I(Z; A | S, R, S') = I(Z; A, S, R, S') - I(Z; S, R, S')
                avg_mi_joint = np.mean(mi_joint_vals)
                avg_mi_cond_base = np.mean(mi_cond_vals)
                avg_mi_conditional = avg_mi_joint - avg_mi_cond_base

                mi_estimates.append({
                    'task_idx': task_idx,
                    'mi_joint': avg_mi_joint,
                    'mi_cond_base': avg_mi_cond_base,
                    'mi_estimate': avg_mi_conditional,  # I(Z; A | S, R, S')
                    'mi_std': np.std([m1 - m2 for m1, m2 in zip(mi_joint_vals, mi_cond_vals)])
                })
    finally:
        mine_model_joint.train(was_training[0])
        mine_model_cond.train(was_training[1])

    return mi_estimates


In [73]:
# # Train MINE models on training tasks data
# print("\nTraining MINE models for I(Z; A | S, R, S')...")
# train_results = train_mine_model(
#     agent,
#     agent.train_tasks,
#     latent_dim=latent_dim,
#     action_dim=action_dim,
#     state_dim=state_dim,
#     reward_dim=reward_dim,
#     epochs=100,
#     batch_size=256,
#     lr=2e-4,
#     estimator='dv',
#     device=device
# )

# # Extract results
# mine_models = {
#     'model_joint': train_results['model_joint'],
#     'model_cond': train_results['model_cond']
# }
# train_losses_joint = train_results['losses_joint']
# train_losses_cond = train_results['losses_cond']
# train_mi_joint = train_results['mi_joint']
# train_mi_cond_base = train_results['mi_cond_base']
# train_mi_conditional = train_results['mi_conditional']


In [None]:
# # Plot training progress
# fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# # Loss plots
# axes[0, 0].plot(train_losses_joint, label='I(Z; A, S, R, S\')')
# axes[0, 0].plot(train_losses_cond, label='I(Z; S, R, S\')')
# axes[0, 0].set_xlabel('Epoch')
# axes[0, 0].set_ylabel('Loss (Negative MI)')
# axes[0, 0].set_title('Training Loss')
# axes[0, 0].legend()
# axes[0, 0].grid(True)

# # MI estimates
# axes[0, 1].plot(train_mi_joint, label='I(Z; A, S, R, S\')')
# axes[0, 1].plot(train_mi_cond_base, label='I(Z; S, R, S\')')
# axes[0, 1].set_xlabel('Epoch')
# axes[0, 1].set_ylabel('MI Estimate')
# axes[0, 1].set_title('MI Estimates during Training')
# axes[0, 1].legend()
# axes[0, 1].grid(True)

# # Conditional MI
# axes[1, 0].plot(train_mi_conditional, color='green', linewidth=2)
# axes[1, 0].set_xlabel('Epoch')
# axes[1, 0].set_ylabel('I(Z; A | S, R, S\')')
# axes[1, 0].set_title('Conditional MI: I(Z; A | S, R, S\')')
# axes[1, 0].grid(True)

# # Combined view
# axes[1, 1].plot(train_mi_joint, label='I(Z; A, S, R, S\')', alpha=0.7)
# axes[1, 1].plot(train_mi_cond_base, label='I(Z; S, R, S\')', alpha=0.7)
# axes[1, 1].plot(train_mi_conditional, label='I(Z; A | S, R, S\')', linewidth=2, color='green')
# axes[1, 1].set_xlabel('Epoch')
# axes[1, 1].set_ylabel('MI Estimate')
# axes[1, 1].set_title('All MI Estimates')
# axes[1, 1].legend()
# axes[1, 1].grid(True)

# plt.tight_layout()
# plt.show()

# print(f"\nFinal MI estimates on training data:")
# print(f"  I(Z; A, S, R, S') = {train_mi_joint[-1]:.4f}")
# print(f"  I(Z; S, R, S') = {train_mi_cond_base[-1]:.4f}")
# print(f"  I(Z; A | S, R, S') = {train_mi_conditional[-1]:.4f}")


In [None]:
# # Evaluate on eval tasks
# print("\nEvaluating MINE models on eval tasks...")
# eval_results = evaluate_mine_model(mine_models, agent, agent.eval_tasks, state_dim, reward_dim, device=device)

# # Extract MI estimates and task indices
# task_indices = [r['task_idx'] for r in eval_results]
# mi_joint_vals = [r['mi_joint'] for r in eval_results]
# mi_cond_base_vals = [r['mi_cond_base'] for r in eval_results]
# mi_values = [r['mi_estimate'] for r in eval_results]  # Conditional MI
# mi_stds = [r['mi_std'] for r in eval_results]

# print(f"\nMI Estimates for eval tasks:")
# for i, (task_idx, mi_j, mi_cb, mi, std) in enumerate(zip(task_indices, mi_joint_vals, mi_cond_base_vals, mi_values, mi_stds)):
#     print(f"  Task {task_idx}: I(Z; A, S, R, S') = {mi_j:.4f}, I(Z; S, R, S') = {mi_cb:.4f}")
#     print(f"    => I(Z; A | S, R, S') = {mi:.4f} ± {std:.4f}")


In [None]:
# # Plot results
# # Sort by task index for better visualization
# sort_idx = np.argsort(task_indices)
# sorted_task_indices = np.array(task_indices)[sort_idx]
# sorted_mi_joint = np.array(mi_joint_vals)[sort_idx]
# sorted_mi_cond_base = np.array(mi_cond_base_vals)[sort_idx]
# sorted_mi_values = np.array(mi_values)[sort_idx]
# sorted_mi_stds = np.array(mi_stds)[sort_idx]

# # Plot 1: Conditional MI with error bars
# fig, ax = plt.subplots(1, 1, figsize=(12, 6))
# ax.errorbar(range(len(sorted_task_indices)), sorted_mi_values, yerr=sorted_mi_stds, 
#             fmt='o-', capsize=5, capthick=2, linewidth=2, markersize=8, color='green', label='I(Z; A | S, R, S\')')
# ax.set_xlabel('Eval Task Index', fontsize=12)
# ax.set_ylabel('I(Z; A | S, R, S\')', fontsize=12)
# ax.set_title(f'Conditional Mutual Information I(Z; A | S, R, S\') for Eval Tasks - {env_name}', fontsize=14)
# ax.set_xticks(range(len(sorted_task_indices)))
# ax.set_xticklabels(sorted_task_indices, rotation=45, ha='right')
# ax.grid(True, alpha=0.3)
# ax.axhline(y=train_mi_conditional[-1], color='r', linestyle='--', linewidth=2, 
#            label=f'Training: {train_mi_conditional[-1]:.4f}')
# ax.legend(fontsize=11)
# plt.tight_layout()
# plt.show()

# # Plot 2: Bar plot for conditional MI
# fig, ax = plt.subplots(1, 1, figsize=(14, 6))
# bars = ax.bar(range(len(sorted_task_indices)), sorted_mi_values, yerr=sorted_mi_stds,
#               capsize=5, alpha=0.7, color='steelblue', edgecolor='black')
# ax.set_xlabel('Eval Task Index', fontsize=12)
# ax.set_ylabel('I(Z; A | S, R, S\')', fontsize=12)
# ax.set_title(f'Conditional Mutual Information I(Z; A | S, R, S\') for Eval Tasks - {env_name}', fontsize=14)
# ax.set_xticks(range(len(sorted_task_indices)))
# ax.set_xticklabels(sorted_task_indices, rotation=45, ha='right')
# ax.grid(True, alpha=0.3, axis='y')
# ax.axhline(y=train_mi_conditional[-1], color='r', linestyle='--', linewidth=2, 
#            label=f'Training Average: {train_mi_conditional[-1]:.4f}')
# ax.legend(fontsize=11)
# plt.tight_layout()
# plt.show()

# # Plot 3: Comparison of joint, conditional base, and conditional MI
# fig, ax = plt.subplots(1, 1, figsize=(14, 6))
# x_pos = range(len(sorted_task_indices))
# width = 0.25

# ax.bar([x - width for x in x_pos], sorted_mi_joint, width, label='I(Z; A, S, R, S\')', alpha=0.7, color='blue')
# ax.bar(x_pos, sorted_mi_cond_base, width, label='I(Z; S, R, S\')', alpha=0.7, color='orange')
# ax.bar([x + width for x in x_pos], sorted_mi_values, width, label='I(Z; A | S, R, S\')', alpha=0.7, color='green', yerr=sorted_mi_stds, capsize=3)

# ax.set_xlabel('Eval Task Index', fontsize=12)
# ax.set_ylabel('MI Estimate', fontsize=12)
# ax.set_title(f'MI Estimates Comparison for Eval Tasks - {env_name}', fontsize=14)
# ax.set_xticks(x_pos)
# ax.set_xticklabels(sorted_task_indices, rotation=45, ha='right')
# ax.legend(fontsize=11)
# ax.grid(True, alpha=0.3, axis='y')
# plt.tight_layout()
# plt.show()


### MINE - Simple

In [12]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd

In [39]:
class Mine(nn.Module):
    """Statistics network T(x) for MINE: an MLP with two ReLU hidden layers.

    Args:
        input_dim: Width of each input row (a joint or marginal sample).
        hidden_dim: Width of both hidden layers.
        output_dim: Output width (1 for a scalar statistic).
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Mine, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        # Small-std normal weights and zero biases keep T's initial outputs close to 0
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.normal_(layer.weight, std=0.02)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = self.fc3(x)
        return output


def mutual_info(joint, marginal, mine_net):
    """Donsker-Varadhan lower bound on MI: mean(T(joint)) - log(mean(exp(T(marginal)))).

    Returns:
        (mi, t, et) where ``t = T(joint)`` and ``et = exp(T(marginal))``.
    """
    t = mine_net(joint)
    t_marginal = mine_net(marginal)
    et = torch.exp(t_marginal)
    # log(mean(exp(x))) computed as logsumexp(x) - log(n) so the bound stays finite
    # even when exp(T) would overflow
    log_mean_et = torch.logsumexp(t_marginal.reshape(-1), dim=0) - np.log(t_marginal.numel())
    mi = torch.mean(t) - log_mean_et
    return mi, t, et


def _as_input_tensor(x, device):
    """Return ``x`` as a tensor on ``device``; non-tensor inputs are converted to float32."""
    if isinstance(x, torch.Tensor):
        return x.to(device)
    return torch.as_tensor(x, dtype=torch.float32, device=device)


def train_mine(info, mine_net, optimizer, ma_et, ma_rate=0.01):
    """Run one MINE update step using the moving-average bias-corrected gradient.

    Args:
        info: ``(joint, marginal)`` samples (tensors, numpy arrays or nested lists).
        mine_net: Statistics network; inputs are moved to the device of its parameters.
        optimizer: Optimizer over ``mine_net``'s parameters.
        ma_et: Running average of mean(exp(T(marginal))) (a float on the first call).
        ma_rate: Update rate of that running average.

    Returns:
        ``(loss, mi, ma_et)``: the loss as a Python float, the detached MI estimate, and
        the updated (detached) running average to pass into the next call.
    """
    joint, marginal = info

    # Follow the network's device instead of hard-coding CUDA
    device = next(mine_net.parameters()).device
    joint = _as_input_tensor(joint, device)
    marginal = _as_input_tensor(marginal, device)

    mi, t, et = mutual_info(joint, marginal, mine_net)
    # Detach before folding into the running average: ma_et only enters the loss detached,
    # so gradients are unchanged, but without this every call would chain its autograd
    # graph onto the previous one and memory would grow with the number of iterations.
    ma_et = (1 - ma_rate) * ma_et + ma_rate * torch.mean(et).detach()

    loss = -(torch.mean(t) - (1 / ma_et.mean()).detach() * torch.mean(et))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(), mi.detach(), ma_et
    

In [85]:
def train(agent, task_idx, mine_net, optimizer, include_action=True, iter_num=int(2e+4), batch_size=1024, log_freq=int(1e+2), is_training=None):
    """Train ``mine_net`` on fresh batches from one task and record the MI estimate per step.

    Args:
        agent: The loaded agent.
        task_idx: Task index (or list of indices) passed to ``get_batch_with_latent``.
        mine_net: Statistics network to train.
        optimizer: Optimizer over ``mine_net``'s parameters.
        include_action: Whether the context rows keep the action columns.
        iter_num: Number of update steps.
        batch_size: Samples drawn per task per step.
        log_freq: Print progress every ``log_freq`` steps.
        is_training: Buffer selection forwarded to ``get_batch_with_latent``. ``None``
            auto-detects from ``agent.train_tasks`` / ``agent.eval_tasks``; ``False``
            reproduces the previous behaviour of always reading the eval buffer.

    Returns:
        List of per-step MI estimates as Python floats.
    """
    result = []
    ma_et = 1.0
    for i in range(iter_num):
        batch = get_batch_with_latent(agent, task_idx, include_action, batch_size, is_training=is_training)
        loss, mi, ma_et = train_mine(batch, mine_net, optimizer, ma_et)
        mi_value = mi.detach().cpu().item() if isinstance(mi, torch.Tensor) else float(mi)
        result.append(mi_value)
        if i % log_freq == 0:
            print(f"Iteration {i}, loss : {loss:.4f}, MI: {mi_value:.4f}")
    return result


In [86]:
def ma(a, window_size=100):
    """Simple moving average over full windows only.

    Returns a list with one ``np.mean`` per window start; it is empty when ``a`` is
    shorter than ``window_size``.
    """
    n_windows = len(a) - window_size + 1
    averages = []
    for start in range(n_windows):
        averages.append(np.mean(a[start:start + window_size]))
    return averages


# Statistics networks for the two terms of the chain rule
#   I(Z; A | S, R[, S']) = I(Z; S, A, R[, S']) - I(Z; S, R[, S'])
# get_batch_with_latent returns rows [z, context]. The context is [s, a, r], with s'
# appended only when agent.use_next_obs_in_context is set, and loses the action columns
# when include_action=False. The input widths below follow that layout instead of
# assuming Z and s' away.
# NOTE(review): Z width is taken as agent.agent.latent_dim, the width get_batch_with_latent
# slices to for 2*latent_dim encoders; confirm it also holds for the other encoder branch.
mine_context_dim = state_dim + action_dim + reward_dim + (state_dim if agent.use_next_obs_in_context else 0)
mine_net_1 = Mine(input_dim=agent.agent.latent_dim + mine_context_dim, hidden_dim=200, output_dim=1).to(device)
mine_net_2 = Mine(input_dim=agent.agent.latent_dim + mine_context_dim - action_dim, hidden_dim=200, output_dim=1).to(device)

mine_opt_1 = optim.Adam(mine_net_1.parameters(), lr=1e-3)
mine_opt_2 = optim.Adam(mine_net_2.parameters(), lr=1e-3)




In [87]:
# Held-out (evaluation) task indices of the loaded agent, shown as the cell output
eval_tasks = agent.eval_tasks
eval_tasks

[5, 6, 7, 8, 9, 30, 31, 32, 33, 34]

In [89]:
# Conditional MI for one eval task via the chain rule:
#   I(Z; A | S, R[, S']) = I(Z; S, A, R[, S']) - I(Z; S, R[, S'])
# Each term gets a fresh statistics network trained on this task's samples.
task_idx = 5
n_iters = int(2e+4)
smooth_window = 100  # same as ma()'s default window

# Rows from get_batch_with_latent are [z, context]. The context holds s, a, r, plus s'
# only when agent.use_next_obs_in_context is set, and has no action columns when
# include_action=False.
# NOTE(review): Z width is taken as agent.agent.latent_dim, as sliced in get_batch_with_latent.
mine_context_dim = state_dim + action_dim + reward_dim + (state_dim if agent.use_next_obs_in_context else 0)
mine_net_1 = Mine(input_dim=agent.agent.latent_dim + mine_context_dim, hidden_dim=200, output_dim=1).to(device)
mine_net_2 = Mine(input_dim=agent.agent.latent_dim + mine_context_dim - action_dim, hidden_dim=200, output_dim=1).to(device)
mine_opt_1 = optim.Adam(mine_net_1.parameters(), lr=1e-3)
mine_opt_2 = optim.Adam(mine_net_2.parameters(), lr=1e-3)
result_1 = train(agent, task_idx, mine_net_1, mine_opt_1, include_action=True, iter_num=n_iters, batch_size=64, log_freq=int(1e+2))
result_2 = train(agent, task_idx, mine_net_2, mine_opt_2, include_action=False, iter_num=n_iters, batch_size=64, log_freq=int(1e+2))

# Compare the last smoothed value of each curve; shrink the window for short runs so
# ma() does not return an empty list (which would make [-1] raise IndexError)
window = min(smooth_window, len(result_1), len(result_2))
ma_result_1 = ma(result_1, window_size=window)
ma_result_2 = ma(result_2, window_size=window)

conditional_MI = ma_result_1[-1] - ma_result_2[-1]
print(conditional_MI)




Iteration 0, loss : 1.0003, MI: -0.0003
Iteration 100, loss : 1.0034, MI: -0.0042
Iteration 200, loss : 0.9995, MI: 0.0002
Iteration 300, loss : 1.0002, MI: -0.0007
Iteration 400, loss : 1.0002, MI: -0.0006
Iteration 500, loss : 0.9907, MI: 0.0080
Iteration 600, loss : 1.0073, MI: -0.0072
Iteration 700, loss : 1.0015, MI: -0.0016
Iteration 800, loss : 1.0000, MI: 0.0006
Iteration 900, loss : 1.0003, MI: 0.0002
Iteration 1000, loss : 1.0002, MI: -0.0005
Iteration 1100, loss : 1.0029, MI: -0.0023
Iteration 1200, loss : 1.0022, MI: -0.0026
Iteration 1300, loss : 1.0038, MI: -0.0034
Iteration 1400, loss : 0.9986, MI: 0.0016
Iteration 1500, loss : 0.9949, MI: 0.0057
Iteration 1600, loss : 1.0109, MI: -0.0082
Iteration 1700, loss : 0.9833, MI: 0.0138
Iteration 1800, loss : 1.0292, MI: -0.0329
Iteration 1900, loss : 0.9975, MI: 0.0105
Iteration 2000, loss : 0.8293, MI: 0.1680
Iteration 2100, loss : 0.8490, MI: 0.2093
Iteration 2200, loss : 0.6987, MI: 0.3130
Iteration 2300, loss : 0.4613, MI: