In [None]:
#| default_exp datasets.abstract_exp

In [None]:
#| export
import torch
from fastai.vision.all import DataLoaders
from torch.utils.data import TensorDataset, DataLoader as DataLoader_torch   
import sklearn
import numpy as np
from air.losses_metrics import get_mig

In [None]:
# Run nbdev's export step from inside the notebook (nbdev writes the `#| export` cells to the package modules).
from nbdev import nbdev_export
nbdev_export()  

# Datasets

## Normal dataset 

In [None]:
#| export
def random_functions(x, m, seed = 0):
    '''
    Generate m random nonlinear functions of input x using:
    
    f_i(x) = sin(w_i^T x + b_i)
    
    where w_i and b_i are randomly sampled weights and biases.

    The weights and biases are drawn from a private generator seeded with `seed`, so the
    same seed always yields the same functions and the global torch RNG state is not modified.

    Args:
        x: Input tensor of shape (n_samples, n_features)
        m: Number of random functions to generate
        seed: Random seed for reproducibility
    Returns:
        Tensor of shape (n_samples, m) with the outputs of the m random functions
    '''

    # A local CPU generator yields the same stream as `torch.manual_seed(seed)` would,
    # without reseeding the global RNG behind the caller's back.
    rng = torch.Generator()
    rng.manual_seed(seed)

    k = x.shape[-1]
    # Create random weights and biases for m functions, each depending on all k inputs
    weights = torch.randn(m, k, generator = rng)
    biases = torch.randn(m, generator = rng)

    # Follow x's device (and floating dtype) so that e.g. float64 or GPU inputs do not
    # fail in the matmul; for CPU float32 inputs this is a no-op.
    target_dtype = x.dtype if x.is_floating_point() else weights.dtype
    weights = weights.to(device = x.device, dtype = target_dtype)
    biases = biases.to(device = x.device, dtype = target_dtype)

    # Apply a nonlinear function (sin) to a weighted sum of inputs
    return torch.sin(x @ weights.t() + biases)

In [None]:
#| export
def dataset_abstract(N, num_h = 4, dim_x = None, dim_y = 10, size_train = 0.8, BS = 100,
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                    seed_funcs = 0
                    ):
    """
    Build the synthetic dataset of the abstract experiment, with a one-hot action
    appended to every input.

    Every sample appears twice: once with action [0, 1], whose target is a random function
    of the first two hidden factors, and once with action [1, 0], whose target is a random
    function of all hidden factors except the first.
    
    Parameters:
    - N: Number of hidden-factor samples (the dataset holds 2*N rows, one per action)
    - num_h: Number of hidden factors
    - dim_x: Dimensionality of the observation; if None the hidden factors themselves are used
    - dim_y: Dimensionality of the output
    - size_train: Proportion of data to use for training
    - BS: Batch size
    - device: Device to which tensors will be moved (CPU or GPU)
    - seed_funcs: Seed passed to `random_functions` for the targets and the observation map
    
    Returns:
    - Tuple (DataLoaders wrapping train and test loaders, test loader, train loader)
    """

    # Hidden factors are always drawn with the same fixed seed
    torch.random.manual_seed(0)
    hidden = torch.rand((N, num_h))

    # One target per action, each depending on a different subset of the hidden factors
    target_a1 = random_functions(hidden[:, :2], dim_y, seed_funcs)
    target_a2 = random_functions(hidden[:, 1:], dim_y, seed_funcs)

    # Observation: the raw factors, or a random nonlinear map of them
    if dim_x is not None:
        obs = random_functions(hidden.clone(), dim_x, seed_funcs)
    else:
        obs = hidden.clone()

    # Re-seed the global RNG non-deterministically after the fixed-seed section above
    torch.seed()

    action_1 = torch.tensor([0, 1]).repeat((N, 1))
    action_2 = torch.tensor([1, 0]).repeat((N, 1))

    # Rows 0..N-1 carry action [0, 1]; rows N..2N-1 carry action [1, 0]
    inputs = torch.vstack((torch.hstack((obs, action_1)),
                           torch.hstack((obs, action_2))))
    outputs = torch.vstack((target_a1, target_a2))

    # Contiguous, unshuffled split: the first n_train rows train, the rest test.
    # NOTE(review): for size_train >= 0.5 (including the default 0.8) the test set therefore
    # contains only action-[1, 0] rows — confirm this is intended.
    n_train = int(size_train * inputs.shape[0])

    train_set = TensorDataset(inputs[:n_train].to(device), outputs[:n_train].to(device))
    test_set = TensorDataset(inputs[n_train:].to(device), outputs[n_train:].to(device))

    train_loader = DataLoader_torch(train_set, batch_size = BS, shuffle = True)
    test_loader = DataLoader_torch(test_set, batch_size = BS, shuffle = True)

    return DataLoaders(train_loader, test_loader), test_loader, train_loader

In [None]:
# Smoke test: build the dataset with N = 2000; `a` is the (DataLoaders, test loader, train loader) tuple.
a = dataset_abstract(2000)

## No actions dataset

In [None]:
#| export
def dataset_abstract_no_actions(N, num_h = 4, dim_x = 10, size_train = 0.8, BS = 100,
                    device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                    seed_funcs = 0, torch_seed = None
                    ):
    """
    Build the action-free variant of the abstract-experiment dataset: each sample is an
    observation x (a random nonlinear function of the hidden factors) used both as input
    and as target.
    
    Parameters:
    - N: Number of samples
    - num_h: Number of hidden factors
    - dim_x: Dimensionality of the observation
    - size_train: Proportion of data to use for training
    - BS: Batch size
    - device: Device to which tensors will be moved (CPU or GPU)
    - seed_funcs: Seed passed to `random_functions` for the observation map
    - torch_seed: Seed for drawing the hidden factors; if None a non-deterministic seed is used
    
    Returns:
    - Tuple (DataLoaders wrapping train and test loaders, test loader, train loader)
    """

    # Hidden factors are reproducible only when an explicit seed is given
    if torch_seed is not None:
        torch.random.manual_seed(torch_seed)
    else:
        torch.seed()

    hidden = torch.rand((N, num_h))
    obs = random_functions(hidden.clone(), dim_x, seed_funcs)

    # Re-seed the global RNG non-deterministically once the data has been generated
    torch.seed()

    # Contiguous split: first n_train rows for training, the remainder for testing
    n_train = int(size_train * obs.shape[0])
    train_part = obs[:n_train]
    test_part = obs[n_train:]

    # Input and target are the same observations
    train_set = TensorDataset(train_part.to(device), train_part.to(device))
    test_set = TensorDataset(test_part.to(device), test_part.to(device))

    train_loader = DataLoader_torch(train_set, batch_size = BS, shuffle = True)
    test_loader = DataLoader_torch(test_set, batch_size = BS, shuffle = True)

    return DataLoaders(train_loader, test_loader), test_loader, train_loader

In [None]:
# Smoke test with a tiny dataset; unpacks (DataLoaders, test loader, train loader).
a,b,c = dataset_abstract_no_actions(10)

# MIG abstract dataset

## Data generation for MIG computation

In [None]:
#| export
def create_data_mig(N, num_h, seed_funcs, bins, dim_x):
    
    '''
    Creates the data needed to compute the Mutual Information Gap (MIG) for the abstract experiment 

    Parameters
    ----------
    - N: Number of samples
    - num_h: Number of hidden factors
    - seed_funcs: Seed for random functions
    - bins: Bins for discretization
    - dim_x: Dimension of the observation space
    
    Returns
    -------
    - true_factors: The true hidden factors, shape (N, num_h)
    - data: Observations with a one-hot action appended; the N rows with action [0, 1]
      are stacked on top of the N rows with action [1, 0]
    - entropy: For each hidden factor, the normalized mutual information of its
      discretized values with themselves (expected to be 1)

    Raises
    ------
    - AssertionError: if any entry of `entropy` is not (numerically) equal to 1
    '''
    
    true_factors = torch.rand(N, num_h)

    entropy = np.zeros((num_h))
    for idx in range(num_h):
        cj = true_factors[:, idx].numpy()
        # Discretize with the left edges of a `bins`-bin histogram of the factor itself
        cj = np.digitize(cj, np.histogram(cj, bins = bins)[1][:-1])
        entropy[idx] = sklearn.metrics.normalized_mutual_info_score(cj, cj)

    # The score is a ratio of floating point quantities, so compare with a tolerance
    # rather than exact equality to avoid spurious failures from rounding.
    assert np.allclose(entropy, 1), "normalized self mutual information should be 1 for every factor"

    # `random_functions` may reseed the global torch RNG with `seed_funcs`. Save and restore
    # the CPU RNG state around the call so that later draws (e.g. `true_factors` on the next
    # call of this function) are not silently tied to `seed_funcs`.
    rng_state = torch.random.get_rng_state()
    observations = random_functions(true_factors.clone(), dim_x, seed_funcs)
    torch.random.set_rng_state(rng_state)
    
    a1 = torch.tensor([0,1]).repeat((N,1))
    a2 = torch.tensor([1,0]).repeat((N,1))
    
    data1 = torch.hstack((observations, a1))
    data2 = torch.hstack((observations, a2))
    
    data = torch.vstack((data1, data2))

    return true_factors, data, entropy

In [None]:
# Smoke test: generate a tiny MIG dataset (the entropy output is discarded).
true_factors, data, _ = create_data_mig(N = 10, num_h = 4, seed_funcs = 3, bins = 20, dim_x = 15)

## MIG calculator class

In [None]:
#| export
class mig_calc_abs_data():

    '''
    Class to compute the Mutual Information Gap (MIG) for the abstract experiment dataset.

    On construction (and on every `reset_dataset` call) it draws hidden factors, maps them to
    observations with `random_functions` and stores the resulting model input in `self.data`.

    Parameters
    ----------
    - N: Number of samples
    - num_h: Number of hidden factors
    - dim_x: Dimension of the observation space
    - seed_funcs: Seed for random functions
    - torch_seed: Seed for torch random number generator (None for a non-deterministic seed)
    - device: Device to which tensors will be moved (CPU or GPU)
    - action: Action to be used in the dataset:
        None  -> every observation appears twice, with one-hot actions [0,1] and [1,0];
        False -> observations only, no action appended;
        other -> observations repeated twice, each row with the given action appended
    - bins: Number of bins for discretization
    - normalized_MI: Whether to use normalized mutual information
    '''

    def __init__(self, N, num_h, dim_x, seed_funcs, 
                 torch_seed = None, device = 'cuda' if torch.cuda.is_available() else 'cpu', 
                 action = None, bins = 20,
                 normalized_MI = True
                ):

        self.N = N
        self.dim_x = dim_x
        self.num_h = num_h
        self.action = action
        self.seed_funcs = seed_funcs
        self.torch_seed = torch_seed 
        self.bins = bins
        self.normalized_MI = normalized_MI
        self.device = device

        self.reset_dataset()
    
    def reset_dataset(self):
        '''
        (Re)generate `self.true_factors`, `self.entropy` and `self.data`.

        The hidden factors are drawn with `torch_seed` when it is given, otherwise with a
        non-deterministic seed. `self.entropy[i]` is the normalized mutual information of the
        discretized i-th factor with itself.
        '''
        
        if self.torch_seed is None:
            torch.seed()
        else:
            torch.random.manual_seed(self.torch_seed)

        self.true_factors = torch.rand(self.N, self.num_h)

        self.entropy = np.zeros((self.num_h))
        for idx in range(self.num_h):
            cj = self.true_factors[:, idx].numpy()
            # Discretize with the left edges of a histogram of the factor itself
            cj = np.digitize(cj, np.histogram(cj, bins = self.bins)[1][:-1])
            self.entropy[idx] = sklearn.metrics.normalized_mutual_info_score(cj, cj)
        
        observations = random_functions(self.true_factors.clone(), self.dim_x, self.seed_funcs)

        # Re-seed the global RNG non-deterministically after the seeded section above
        torch.seed()        

        if self.action is None:                
            # Both one-hot actions: rows 0..N-1 get [0,1], rows N..2N-1 get [1,0]
            data1 = torch.hstack((observations, torch.tensor([0,1]).repeat((self.N,1))))
            data2 = torch.hstack((observations, torch.tensor([1,0]).repeat((self.N,1))))            
            self.data = torch.vstack((data1, data2)).to(self.device)
        elif self.action is False:
            # No action: the model input is the observation alone (N rows)
            self.data = observations.to(self.device)        
        else:
            # Fixed user-supplied action appended to the twice-repeated observations (2N rows)
            self.data = torch.hstack((observations.repeat(2,1), torch.tensor(self.action).repeat((2*self.N,1)))).to(self.device)
        

    def compute_mig(self, 
                    model, # must 3 outputs, where 2nd and 3rd are mu and logvar
                    bins = 20,
                    normalized_MI = True,
                    only_mig = False,
                    disentangled_var = 1,
                    compute_with_mu = True
                   ):
        '''
        Run `model` on the stored dataset and compute the MIG of its latent representation.

        Parameters
        ----------
        - model: Callable returning three outputs, the 2nd and 3rd being mu and logvar
        - bins: Currently ignored; `self.bins` (set at construction) is what is passed to `get_mig`
        - normalized_MI: Currently ignored; `self.normalized_MI` is what is passed to `get_mig`
        - only_mig: If True, return only the `mig` value given by `get_mig`
        - disentangled_var: Row of the mutual-information matrix treated as the factor expected
          to be disentangled (assumes rows of `mi` index hidden factors — TODO confirm against `get_mig`)
        - compute_with_mu: If True use mu as the representation, otherwise a sample
          obtained with `reparametrize(mu, logvars)`

        Returns
        -------
        - `mig` if `only_mig` is True, otherwise the tuple (mi, mig, dis_mig, ent_mig) where
          `mi` has each row sorted in descending order, `dis_mig` is the gap between the two
          largest entries of row `disentangled_var` divided by that factor's entropy, and
          `ent_mig` is the mean of the same quantity over the remaining rows
        '''

        with torch.no_grad():
            _, mu, logvars = model(self.data)

        if compute_with_mu:
            input_mi = mu.detach().cpu().numpy()
        else:
            # Sample only when the sample is actually used.
            # NOTE(review): `reparametrize` is not defined or imported in the visible part of
            # this file — confirm it is in scope.
            input_mi = reparametrize(mu, logvars).cpu().detach().numpy()

        # Identity check, matching `reset_dataset`: only `action is False` yields N rows of data.
        # (`== False` would also match action = 0 and fail on array-like actions.)
        if self.action is False:
            mi, mig = get_mig(input_mi, self.true_factors, self.num_h, bins = self.bins, normalized_MI = self.normalized_MI)            
        else:
            # Data holds 2N rows, so the factors are repeated to match
            mi, mig = get_mig(input_mi, self.true_factors.repeat((2,1)), self.num_h, bins = self.bins, normalized_MI = self.normalized_MI)

        if only_mig:
            return mig
        
        # Sort every row in descending order so columns 0 and 1 hold the two largest values
        mi = np.sort(mi, 1)[:, ::-1]

        # Expected disentangled neuron: normalise by the entropy of that same factor
        dis_mig = (mi[disentangled_var,0] - mi[disentangled_var,1])/self.entropy[disentangled_var]

        # Rest of neurons
        ent_mi = np.delete(mi.copy(), disentangled_var, axis = 0)        
        ent_mig = ((ent_mi[:,0]-ent_mi[:,1])/self.entropy[np.delete(np.arange(self.num_h), disentangled_var)]).mean()

        return mi, mig, dis_mig, ent_mig

In [None]:
# Smoke test: build the calculator without actions (action=False), so the stored data holds observations only.
mig_calc = mig_calc_abs_data(N = 20, num_h = 3, dim_x = 16, seed_funcs = 123, action=False)