# Task clustering

# Imports

In [8]:
import numpy as np
from scipy.optimize import minimize
from scipy.stats import multivariate_normal
from scipy.special import logsumexp, softmax
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from scipy.linalg import solve_triangular, cholesky
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

In [9]:
np.random.seed(10)

# Neural Network Model

In [10]:
class MultiTaskNN:
    def __init__(self, n_input, n_hidden, n_tasks, activation='tanh'):
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_tasks = n_tasks
        self.activation = activation
        
        # Initialize with larger scale for better convergence
        self.W = np.random.randn(n_hidden, n_input + 1) * 0.5
        self.m = np.random.randn(n_hidden + 1) * 0.5
        self.Sigma = np.eye(n_hidden + 1) * 0.5
        self.sigma = 1.0
        self.A_map = [np.zeros(n_hidden + 1) for _ in range(n_tasks)]
    
    def _activate(self, x):
        if self.activation == 'tanh':
            return np.tanh(x)
        elif self.activation == 'linear':
            return x
        else:
            raise ValueError("Activation must be 'tanh' or 'linear'")
    
    def compute_hidden_activations(self, X):
        X_bias = np.hstack([X, np.ones((X.shape[0], 1))])
        H = self._activate(np.dot(X_bias, self.W.T))
        return np.hstack([H, np.ones((H.shape[0], 1))])
    
    def predict(self, X, task_idx):
        H_bias = self.compute_hidden_activations(X)
        return np.dot(H_bias, self.A_map[task_idx])
    
    def compute_sufficient_statistics(self, X, y):
        H_bias = self.compute_hidden_activations(X)
        return {
            'sum_hhT': np.dot(H_bias.T, H_bias),
            'sum_hy': np.dot(H_bias.T, y),
            'sum_yy': np.dot(y, y),
            'n_samples': X.shape[0]
        }
    
    def log_likelihood(self, params, all_stats):
        try:
            # Unpack parameters
            param_idx = 0
            
            # W
            W_size = self.n_hidden * (self.n_input + 1)
            W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
            param_idx += W_size
            
            # m
            m_size = self.n_hidden + 1
            m = params[param_idx:param_idx + m_size]
            param_idx += m_size
            
            # Sigma (Cholesky decomposition)
            L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
            tril_indices = np.tril_indices(self.n_hidden + 1)
            L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
            param_idx += len(tril_indices[0])
            
            # sigma (log scale)
            log_sigma = params[param_idx]
            sigma = np.exp(log_sigma)
            
            total_log_lik = 0.0
            self.A_map = []
            
            # Add regularization to Sigma
            Sigma = np.dot(L, L.T) + 1e-6 * np.eye(self.n_hidden + 1)
            
            # Precompute Sigma inverse using Cholesky
            try:
                L_sigma = cholesky(Sigma, lower=True)
                Sigma_inv = solve_triangular(L_sigma, np.eye(self.n_hidden + 1), lower=True)
                Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
            except np.linalg.LinAlgError:
                return -np.inf
            
            for stats in all_stats:
                sum_hhT = stats['sum_hhT']
                sum_hy = stats['sum_hy']
                sum_yy = stats['sum_yy']
                n_samples = stats['n_samples']
                
                # Add small constant to avoid division by zero
                sigma_sq = max(sigma**2, 1e-8)
                
                # Compute Q_i with regularization
                Q_i = (1.0 / sigma_sq) * sum_hhT + Sigma_inv
                
                try:
                    L_Q = cholesky(Q_i + 1e-6*np.eye(self.n_hidden + 1), lower=True)
                    Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
                    Q_inv = np.dot(Q_inv.T, Q_inv)
                except np.linalg.LinAlgError:
                    return -np.inf
                
                R_i = (1.0 / sigma_sq) * sum_hy + np.dot(Sigma_inv, m)
                
                # Compute MAP estimate with regularization
                A_i = np.linalg.solve(Q_i + 1e-6*np.eye(self.n_hidden + 1), R_i)
                self.A_map.append(A_i)
                
                # Compute log determinants
                logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
                logdet_Sigma = 2 * np.sum(np.log(np.diag(L_sigma)))
                
                # Compute log likelihood terms
                term1 = -0.5 * (logdet_Sigma + n_samples * 2 * log_sigma + logdet_Q_i)
                term2 = 0.5 * (np.dot(R_i, np.dot(Q_inv, R_i)) - (1.0 / sigma_sq) * sum_yy - np.dot(m, np.dot(Sigma_inv, m)))
                
                if not np.isfinite(term1 + term2):
                    return -np.inf
                
                total_log_lik += term1 + term2
            
            return total_log_lik if np.isfinite(total_log_lik) else -np.inf
        
        except:
            return -np.inf
    
    def fit(self, X_list, y_list, max_iter=100):
        # Normalize data
        X_list = [(X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-8) for X in X_list]
        y_list = [(y - np.mean(y)) / (np.std(y) + 1e-8) for y in y_list]
        
        # Compute sufficient statistics
        all_stats = [self.compute_sufficient_statistics(X, y) for X, y in zip(X_list, y_list)]
        
        # Initial parameters with better scaling
        initial_params = []
        initial_params.extend(self.W.flatten())
        initial_params.extend(self.m)
        
        # Initialize Sigma with Cholesky decomposition
        L = np.linalg.cholesky(self.Sigma + 1e-6 * np.eye(self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        initial_params.extend(L[tril_indices])
        
        initial_params.append(np.log(self.sigma))
        
        # Optimize with bounds for stability
        bounds = []
        bounds.extend([(None, None)] * (self.n_hidden * (self.n_input + 1)))  # W
        bounds.extend([(None, None)] * (self.n_hidden + 1))  # m
        
        # L - diagonal elements must be positive
        for i in range(len(tril_indices[0])):
            if tril_indices[0][i] == tril_indices[1][i]:  # diagonal
                bounds.append((1e-8, None))
            else:
                bounds.append((None, None))
                
        bounds.append((np.log(1e-8), None))  # log_sigma
        
        # Optimization with error handling
        try:
            result = minimize(
                lambda p: -self.log_likelihood(p, all_stats),
                initial_params,
                method='L-BFGS-B',
                bounds=bounds,
                options={
                    'maxiter': max_iter,
                    'disp': True,
                    'maxfun': 15000,  # Increased function evaluations
                    'maxls': 50  # Increased line searches
                }
            )
            
            # Store optimized parameters
            self._unpack_parameters(result.x)
            
            # Recompute MAP estimates
            _ = self.log_likelihood(result.x, all_stats)
            
            return result
        
        except Exception as e:
            print(f"Optimization failed: {str(e)}")
            return None
    
    def _unpack_parameters(self, params):
        """Helper to unpack optimized parameters"""
        param_idx = 0
        
        # W
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size
        
        # m
        m_size = self.n_hidden + 1
        self.m = params[param_idx:param_idx + m_size]
        param_idx += m_size
        
        # Sigma (Cholesky)
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])
        self.Sigma = np.dot(L, L.T) + 1e-6 * np.eye(self.n_hidden + 1)
        
        # sigma
        self.sigma = max(np.exp(params[param_idx]), 1e-8)

## Example usage with synthetic data

In [11]:
def generate_synthetic_data(n_tasks=20, n_samples_train=10, n_samples_test=300, 
                           n_input=10, n_hidden=1, activation='tanh'):
    # True parameters
    true_W = np.random.randn(n_hidden, n_input + 1)

    true_Sigma = np.eye(n_hidden + 1) * 0.5
    
    # Two cluster means
    true_m = np.array(
        [1.5, 0.5])
    
    
    # Generate data for each task
    train_data = []
    test_data = []
    
    for i in range(n_tasks):
        # Generate covariates
        X_train = np.random.randn(n_samples_train, n_input)
        X_test = np.random.randn(n_samples_test, n_input)
        
        # Scale per task to zero mean and unit variance
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
        
        # Compute hidden activations
        X_train_bias = np.hstack([X_train, np.ones((n_samples_train, 1))])
        X_test_bias = np.hstack([X_test, np.ones((n_samples_test, 1))])
        
        if activation == 'tanh':
            h_train = np.tanh(np.dot(X_train_bias, true_W.T))
            h_test = np.tanh(np.dot(X_test_bias, true_W.T))
        else:
            h_train = np.dot(X_train_bias, true_W.T)
            h_test = np.dot(X_test_bias, true_W.T)
        
        # Add bias term
        h_train = np.hstack([h_train, np.ones((n_samples_train, 1))])
        h_test = np.hstack([h_test, np.ones((n_samples_test, 1))])

        # Generate task-specific weights from true distribution
        A = np.random.multivariate_normal(true_m, true_Sigma)
        
        # Generate responses with noise
        y_train = np.dot(h_train, A) + np.random.randn(n_samples_train)*0.1
        y_test = np.dot(h_test, A) + np.random.randn(n_samples_test)*0.1
        
        train_data.append((X_train, y_train))
        test_data.append((X_test, y_test))
    
    return train_data, test_data

In [12]:
def evaluate_model(model, test_data):
    # 1. Test MSE
    mse = 0
    for i, (X_test, y_test) in enumerate(test_data):
        y_pred = model.predict(X_test, i)
        mse += np.mean((y_test - y_pred)**2)
    mse /= len(test_data)
    
    return {
        'test_mse': mse,
    }

In [13]:
def run_simulations(n_simulations=10, activation='tanh'):
    results = []
    successful_simulations = 0
    
    for sim in range(n_simulations):
        print(f"\nSimulation {sim+1}/{n_simulations}")
        
        try:
            # Generate data
            train_data, test_data = generate_synthetic_data(activation=activation)
            
            # Initialize and fit model
            model = MultiTaskNN(
                n_input=10,
                n_hidden=1,
                n_tasks=20,
                activation=activation
            )

            X = []
            y = []
            for i, (X_i, y_i) in enumerate(train_data):
                X.append(X_i)
                y.append(y_i)
            
            # Fit model with error handling
            fit_result = model.fit(X, y, max_iter=100)
            if fit_result is None:
                print("Skipping simulation due to fitting error")
                continue
                
            # Evaluate
            metrics = evaluate_model(model, test_data)
            results.append(metrics)
            successful_simulations += 1
            
            print(f"Test MSE: {metrics['test_mse']:.4f}")
            
        except Exception as e:
            print(f"Error in simulation {sim+1}: {str(e)}")
            continue
    
    if successful_simulations == 0:
        print("Warning: All simulations failed")
        return None
    
    # Aggregate results
    avg_results = {
        'avg_test_mse': np.mean([r['test_mse'] for r in results]),
        'success_rate': successful_simulations / n_simulations
    }
    
    return avg_results

In [26]:
print("Testing with tanh activation:")
tanh_results = run_simulations(activation='tanh')

print("\nTesting with linear activation:")
linear_results = run_simulations(activation='linear')

print("\nFinal Results:")
print("Tanh activation:", tanh_results)
print("Linear activation:", linear_results)

Testing with tanh activation:

Simulation 1/10
Test MSE: 2.9261

Simulation 2/10
Test MSE: 2.6003

Simulation 3/10
Test MSE: 2.8722

Simulation 4/10
Test MSE: 2.8375

Simulation 5/10
Test MSE: 1.6876

Simulation 6/10
Test MSE: 3.7151

Simulation 7/10
Test MSE: 4.1093

Simulation 8/10
Test MSE: 3.7436

Simulation 9/10
Test MSE: 3.1874

Simulation 10/10
Test MSE: 3.2853

Testing with linear activation:

Simulation 1/10
Test MSE: 51.3946

Simulation 2/10
Test MSE: 11.9994

Simulation 3/10
Test MSE: 28.1998

Simulation 4/10
Test MSE: 13.0579

Simulation 5/10
Test MSE: 68.7731

Simulation 6/10
Test MSE: 48.0725

Simulation 7/10
Test MSE: 45.6132

Simulation 8/10
Test MSE: 26.2639

Simulation 9/10
Test MSE: 83.5740

Simulation 10/10
Test MSE: 22.4269

Final Results:
Tanh activation: {'avg_test_mse': 3.0964356174038157, 'success_rate': 1.0}
Linear activation: {'avg_test_mse': 39.93753318607524, 'success_rate': 1.0}


# Task-dependent Prior Mean

In [14]:
class MultiTaskNNDependentMean:
    def __init__(self, n_input, n_hidden, n_tasks, n_features, activation='tanh'):
        """
        Initialize the multi-task neural network model.
        
        Args:
            n_input: Number of input features
            n_hidden: Number of hidden units
            n_tasks: Number of tasks
            n_features: Number of task features
            activation: Activation function ('tanh' or 'linear')
        """
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_tasks = n_tasks
        self.n_features = n_features
        self.activation = activation
        
        # Initialize shared parameters with larger scale
        self.W = np.random.randn(n_hidden, n_input + 1) * 0.5
        
        # Initialize hyperparameters
        self.M = np.random.randn(n_hidden + 1, n_features) * 0.1
        self.Sigma = np.eye(n_hidden + 1)
        self.sigma = 1.0
        
        # Store MAP estimates
        self.A_map = [np.zeros(n_hidden + 1) for _ in range(n_tasks)]
    
    def _activate(self, x):
        """Apply activation function to hidden units"""
        if self.activation == 'tanh':
            return np.tanh(x)
        elif self.activation == 'linear':
            return x
        else:
            raise ValueError("Activation must be 'tanh' or 'linear'")
    
    def compute_hidden_activations(self, X):
        """Compute hidden unit activations"""
        X_bias = np.hstack([X, np.ones((X.shape[0], 1))])
        H = self._activate(np.dot(X_bias, self.W.T))
        return np.hstack([H, np.ones((H.shape[0], 1))])
    
    def predict(self, X, task_idx, task_features):
        """Make predictions for a specific task"""
        H_bias = self.compute_hidden_activations(X)
        return np.dot(H_bias, self.A_map[task_idx])
    
    def compute_sufficient_statistics(self, X, y):
        """Compute sufficient statistics for a single task"""
        H_bias = self.compute_hidden_activations(X)
        return {
            'sum_hhT': np.dot(H_bias.T, H_bias),
            'sum_hy': np.dot(H_bias.T, y),
            'sum_yy': np.dot(y, y),
            'n_samples': X.shape[0]
        }
    
    def log_likelihood(self, params, all_stats, all_task_features):
        """Compute the log likelihood with numerical stability improvements"""
        # Unpack parameters
        param_idx = 0
        
        # W
        W_size = self.n_hidden * (self.n_input + 1)
        W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size

        # M: (n_hidden + 1 x n_features)
        M_size = (self.n_hidden + 1) * self.n_features
        M = params[param_idx:param_idx + M_size].reshape(self.n_hidden + 1, self.n_features)
        param_idx += M_size
        
        
        # Sigma (Cholesky decomposition)
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])
        
        # sigma (log scale)
        log_sigma = params[param_idx]
        sigma = np.exp(log_sigma)
        
        total_log_lik = 0.0
        self.A_map = []
        
        # Precompute Sigma inverse using Cholesky
        try:
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
        except np.linalg.LinAlgError:
            return -np.inf  # Invalid covariance matrix
            
        for stats in all_stats:
            sum_hhT = stats['sum_hhT']
            sum_hy = stats['sum_hy']
            sum_yy = stats['sum_yy']
            n_samples = stats['n_samples']

            # Compute task-dependent prior mean
            m_i = np.dot(M, task_features)
            
            # Compute Q_i using Cholesky for stability
            Q_i = (1.0 / (sigma**2)) * sum_hhT + Sigma_inv
            
            try:
                L_Q = np.linalg.cholesky(Q_i)
                Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
                Q_inv = np.dot(Q_inv.T, Q_inv)
            except np.linalg.LinAlgError:
                return -np.inf

            R_i = (1.0 / (sigma**2)) * sum_hy + np.dot(Sigma_inv, m_i)
            
            # Compute MAP estimate
            A_i = np.dot(Q_inv, R_i)
            self.A_map.append(A_i)
            
            # Compute log determinants efficiently
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))
            
            # Compute log likelihood terms
            term1 = -0.5 * (logdet_Sigma + n_samples * 2 * log_sigma + logdet_Q_i)
            term2 = 0.5 * (np.dot(R_i, np.dot(Q_inv, R_i)) - (1.0 / (sigma**2)) * sum_yy - np.dot(m_i, np.dot(Sigma_inv, m_i)))
            
            total_log_lik += term1 + term2
        
        return total_log_lik
    
    def fit(self, X_list, y_list, task_features_list, max_iter=100):
        """Fit the model with improved optimization"""
        # Normalize data
        X_list = [(X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-8) for X in X_list]
        y_list = [(y - np.mean(y)) / (np.std(y) + 1e-8) for y in y_list]
        
        # Compute sufficient statistics
        all_stats = [self.compute_sufficient_statistics(X, y) for X, y in zip(X_list, y_list)]
        
        # Initial parameters with better scaling
        initial_params = []
        initial_params.extend(self.W.flatten())
        initial_params.extend(self.M.flatten())
        
        L = np.linalg.cholesky(self.Sigma + 1e-6 * np.eye(self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        initial_params.extend(L[tril_indices])
        
        initial_params.append(np.log(self.sigma))
        
        # Optimize with bounds for stability
        bounds = []
        
        # W - no bounds
        bounds.extend([(None, None)] * (self.n_hidden * (self.n_input + 1)))
        
        # M - no bounds
        bounds.extend([(None, None)] * ((self.n_hidden + 1) * self.n_features))
        
        # L - diagonal elements must be positive
        for i in range(len(tril_indices[0])):
            if tril_indices[0][i] == tril_indices[1][i]:  # diagonal
                bounds.append((1e-8, None))
            else:
                bounds.append((None, None))
                
        # log_sigma must be > log(1e-8)
        bounds.append((np.log(1e-8), None))
        
        # Optimize
        result = minimize(
            lambda p: -self.log_likelihood(p, all_stats, task_features_list),
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': max_iter, 'disp': True}
        )
        
        # Store optimized parameters
        self._unpack_parameters(result.x)
        
        # Recompute MAP estimates
        _ = self.log_likelihood(result.x, all_stats, task_features_list)
        
        return result
    
    def _unpack_parameters(self, params):
        """Helper to unpack optimized parameters"""
        param_idx = 0
        
        # W
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = params[param_idx:param_idx + W_size].reshape(self.n_hidden, self.n_input + 1)
        param_idx += W_size
        
        # M: (n_hidden + 1 x n_features)
        M_size = (self.n_hidden + 1) * self.n_features
        M = params[param_idx:param_idx + M_size].reshape(self.n_hidden + 1, self.n_features)
        param_idx += M_size
        
        # Sigma (Cholesky)
        L = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
        tril_indices = np.tril_indices(self.n_hidden + 1)
        L[tril_indices] = params[param_idx:param_idx + len(tril_indices[0])]
        param_idx += len(tril_indices[0])
        self.Sigma = np.dot(L, L.T)
        
        # sigma
        self.sigma = np.exp(params[param_idx])

## Test correctness

In [15]:
# Generate synthetic data with more meaningful relationships
np.random.seed(42)
n_input = 5
n_hidden = 4
n_tasks = 10
n_features = 3
n_samples = 200

# True parameters with stronger relationships
true_W = np.random.randn(n_hidden, n_input + 1) * 0.7
true_M = np.random.randn(n_hidden + 1, n_features) * 0.7

X_list = []
y_list = []
task_features_list = []

for i in range(n_tasks):
    # Generate meaningful task features
    task_features = np.random.randn(n_features) * 2
    task_features_list.append(task_features)
    
    X = np.random.randn(n_samples, n_input)
    X_bias = np.hstack([X, np.ones((n_samples, 1))])
    H = np.tanh(np.dot(X_bias, true_W.T))
    H_bias = np.hstack([H, np.ones((n_samples, 1))])
    
    # Compute task mean from features
    task_mean = np.dot(true_M, task_features)
    
    # Generate task weights
    true_A = np.random.multivariate_normal(task_mean, np.eye(n_hidden+1)*0.1)
    y = np.dot(H_bias, true_A) + np.random.randn(n_samples)*0.1
    
    X_list.append(X)
    y_list.append(y)

In [16]:
# Create and fit model
model = MultiTaskNNDependentMean(n_input=n_input, n_hidden=n_hidden,
                                     n_tasks=n_tasks, n_features=n_features,
                                     activation='tanh')
result = model.fit(X_list, y_list, task_features_list, max_iter=200)

print("Optimization success:", result.success)
print("Final sigma:", model.sigma)

# Evaluate
X_test = np.random.randn(10, n_input)
y_pred = model.predict(X_test, task_idx=0, task_features=task_features_list[0])
print("Sample predictions:", y_pred[:5])

Optimization success: True
Final sigma: 0.7124352023358816
Sample predictions: [-0.47234402 -0.32747134  0.19923604  0.87527359 -0.43738474]


  df = fun(x1) - f0


# Clustering of Tasks

In [17]:
class MultiTaskNNClustering:
    def __init__(self, n_input, n_hidden, n_tasks, n_clusters, activation='tanh'):
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_tasks = n_tasks
        self.n_clusters = n_clusters
        self.activation = activation
        
        # Initialize with larger scale and better conditioning
        self.W = np.random.randn(n_hidden, n_input + 1) * 0.5
        self.q = np.ones(n_clusters) / n_clusters
        self.m = np.random.randn(n_clusters, n_hidden + 1) * 0.5
        
        # Initialize Sigma with larger diagonal for numerical stability
        self.Sigma = np.array([np.eye(n_hidden + 1) * 0.5 for _ in range(n_clusters)])
        self.sigma = 1.0
        
        self.A = np.zeros((n_tasks, n_hidden + 1))
        self.z = np.zeros((n_tasks, n_clusters))
    
    def _activate(self, x):
        if self.activation == 'tanh':
            return np.tanh(x)
        elif self.activation == 'linear':
            return x
        else:
            raise ValueError("Unknown activation function")
    
    def compute_hidden(self, X):
        X_bias = np.hstack([X, np.ones((X.shape[0], 1))])
        h = self._activate(np.dot(X_bias, self.W.T))
        return np.hstack([h, np.ones((h.shape[0], 1))])
    
    def predict(self, X, task_idx):
        h = self.compute_hidden(X)
        return np.dot(h, self.A[task_idx])
    
    def _compute_task_log_likelihood(self, X_i, y_i, cluster_idx):
        n_i = len(y_i)
        h_i = self.compute_hidden(X_i)
        
        # Add small constant to avoid division by zero
        sigma_sq = max(self.sigma**2, 1e-8)
        
        try:
            # Use Cholesky decomposition for numerical stability
            L = cholesky(self.Sigma[cluster_idx], lower=True)
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
            
            Q_i = (1/sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
            L_Q = cholesky(Q_i, lower=True)
            Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden + 1), lower=True)
            Q_inv = np.dot(Q_inv.T, Q_inv)
            
            R_i = (1/sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])
            
            # Compute log determinants efficiently
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
            
            term1 = -0.5 * (logdet_Sigma + n_i * np.log(sigma_sq) + logdet_Q_i)
            term2 = 0.5 * (np.dot(R_i.T, np.dot(Q_inv, R_i)) - (1/(2*sigma_sq)) * np.sum(y_i**2) - np.dot(self.m[cluster_idx].T, np.dot(Sigma_inv, self.m[cluster_idx])))
            
            return term1 + term2
            
        except np.linalg.LinAlgError:
            # Return -inf if matrix is not positive definite
            return -np.inf
    
    def e_step(self, data):
        log_responsibilities = np.zeros((self.n_tasks, self.n_clusters))
        
        for i, (X_i, y_i) in enumerate(data):
            for alpha in range(self.n_clusters):
                log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                log_responsibilities[i, alpha] = np.log(self.q[alpha] + 1e-8) + log_lik
            
            # Normalize using logsumexp for numerical stability
            log_responsibilities[i] -= logsumexp(log_responsibilities[i])
        
        self.z = np.exp(log_responsibilities)
    
    def m_step(self, data):
        def objective(params):
            W = params[:self.n_hidden * (self.n_input + 1)].reshape(self.n_hidden, self.n_input + 1)
            log_sigma = params[-1]
            sigma = np.exp(log_sigma)
            
            self.W = W
            self.sigma = max(sigma, 1e-8)  # Prevent sigma from becoming too small
            
            total_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    total_log_lik += self.z[i, alpha] * log_lik
            
            return -total_log_lik if np.isfinite(total_log_lik) else np.inf
        
        # Initial parameters with bounds
        initial_params = np.concatenate([
            self.W.flatten(),
            [np.log(self.sigma)]
        ])
        
        # Add bounds for sigma (log_sigma > log(1e-8))
        bounds = [(None, None)] * len(initial_params)
        bounds[-1] = (np.log(1e-8), None)
        
        result = minimize(
            objective,
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': 50, 'disp': True}
        )
        
        opt_params = result.x
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = opt_params[:W_size].reshape(self.n_hidden, self.n_input + 1)
        self.sigma = max(np.exp(opt_params[-1]), 1e-8)
        
        # Update cluster parameters with regularization
        for alpha in range(self.n_clusters):
            self.q[alpha] = max(np.sum(self.z[:, alpha]) / self.n_tasks, 1e-8)
            
            sum_z = np.sum(self.z[:, alpha])
            if sum_z > 1e-8:
                weighted_R = np.zeros(self.n_hidden + 1)
                weighted_Q = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                
                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden(X_i)
                    L = cholesky(self.Sigma[alpha] + 1e-6*np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
                    Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
                    
                    Q_i = (1/max(self.sigma**2, 1e-8)) * np.dot(h_i.T, h_i) + Sigma_inv
                    R_i = (1/max(self.sigma**2, 1e-8)) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[alpha])
                    
                    weighted_R += self.z[i, alpha] * R_i
                    weighted_Q += self.z[i, alpha] * Q_i
                
                try:
                    self.m[alpha] = np.linalg.solve(weighted_Q + 1e-6*np.eye(self.n_hidden + 1), weighted_R)
                except:
                    pass
                
                # Update covariance with regularization
                weighted_cov = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden(X_i)
                    A_i = self._compute_map_estimate(X_i, y_i, alpha)
                    diff = A_i - self.m[alpha]
                    weighted_cov += self.z[i, alpha] * np.outer(diff, diff)
                
                self.Sigma[alpha] = weighted_cov / sum_z + 1e-6 * np.eye(self.n_hidden + 1)
    
    def _compute_map_estimate(self, X_i, y_i, cluster_idx):
        h_i = self.compute_hidden(X_i)
        sigma_sq = max(self.sigma**2, 1e-8)
        
        L = cholesky(self.Sigma[cluster_idx] + 1e-6*np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
        
        Q_i = (1/sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
        R_i = (1/sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])
        
        return np.linalg.solve(Q_i + 1e-6*np.eye(self.n_hidden + 1), R_i)
    
    def fit(self, data, max_iter=100, tol=1e-4):
        prev_log_lik = -np.inf
        
        for iteration in tqdm((range(max_iter))):
            self.e_step(data)
            self.m_step(data)
            
            # Compute current log likelihood
            current_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                cluster_log_liks = []
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    cluster_log_liks.append(np.log(self.q[alpha] + 1e-8) + log_lik)
                current_log_lik += logsumexp(cluster_log_liks)
            
            if np.isnan(current_log_lik):
                print("Warning: log likelihood is nan, stopping early")
                break
                
            if iteration > 0 and np.abs(current_log_lik - prev_log_lik) < tol:
                print(f"Converged at iteration {iteration}")
                break
                
            prev_log_lik = current_log_lik
            #if iteration % 10 == 0:
            #    print(f"Iteration {iteration}, log likelihood: {current_log_lik}")
        
        self._compute_final_weights(data)
    
    def _compute_final_weights(self, data):
        for i, (X_i, y_i) in enumerate(data):
            most_likely_cluster = np.argmax(self.z[i])
            self.A[i] = self._compute_map_estimate(X_i, y_i, most_likely_cluster)
    
    def get_cluster_assignments(self):
        return np.argmax(self.z, axis=1)
    
    def get_task_similarity(self):
        assignments = self.get_cluster_assignments()
        return np.array([[1.0 if a == b else 0.0 for b in assignments] for a in assignments])

## Example usage with synthetic data

In [36]:
def generate_synthetic_data(n_tasks=20, n_samples_train=10, n_samples_test=300, 
                           n_input=10, n_hidden=1, n_clusters=2, activation='tanh'):
    # True parameters
    true_W = np.random.randn(n_hidden, n_input + 1)
    
    # Two cluster means
    true_m = np.array([
        [1.5, 0.5],  # Cluster 0 (bias term added)
        [-1.5, -0.5]  # Cluster 1 (bias term added)
    ])
    
    # Generate tasks - assign half to each cluster
    cluster_assignments = np.zeros(n_tasks)
    cluster_assignments[n_tasks//2:] = 1
    
    # Generate data for each task
    train_data = []
    test_data = []
    
    for i in range(n_tasks):
        # Generate covariates
        X_train = np.random.randn(n_samples_train, n_input)
        X_test = np.random.randn(n_samples_test, n_input)
        
        # Scale per task to zero mean and unit variance
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
        
        # Compute hidden activations
        X_train_bias = np.hstack([X_train, np.ones((n_samples_train, 1))])
        X_test_bias = np.hstack([X_test, np.ones((n_samples_test, 1))])
        
        if activation == 'tanh':
            h_train = np.tanh(np.dot(X_train_bias, true_W.T))
            h_test = np.tanh(np.dot(X_test_bias, true_W.T))
        else:
            h_train = np.dot(X_train_bias, true_W.T)
            h_test = np.dot(X_test_bias, true_W.T)
        
        # Add bias term
        h_train = np.hstack([h_train, np.ones((n_samples_train, 1))])
        h_test = np.hstack([h_test, np.ones((n_samples_test, 1))])
        
        # Get cluster for this task
        cluster = int(cluster_assignments[i])
        A_i = np.random.multivariate_normal(true_m[cluster], np.eye(2)*0.1)
        
        # Generate responses with noise
        y_train = np.dot(h_train, A_i) + np.random.randn(n_samples_train)*0.1
        y_test = np.dot(h_test, A_i) + np.random.randn(n_samples_test)*0.1
        
        train_data.append((X_train, y_train))
        test_data.append((X_test, y_test))
    
    return train_data, test_data, cluster_assignments

In [37]:
def evaluate_model(model, test_data, true_clusters):
    # 1. Cluster assignment accuracy
    pred_clusters = model.get_cluster_assignments()
    cluster_acc = np.mean(pred_clusters == true_clusters)
    
    # 2. Test MSE
    mse = 0
    for i, (X_test, y_test) in enumerate(test_data):
        y_pred = model.predict(X_test, i)
        mse += np.mean((y_test - y_pred)**2)
    mse /= len(test_data)
    
    # 3. Task similarity matrix accuracy
    true_similarity = np.array([[1.0 if a == b else 0.0 for b in true_clusters] 
                               for a in true_clusters])
    pred_similarity = model.get_task_similarity()
    similarity_acc = np.mean(true_similarity == pred_similarity)
    
    return {
        'cluster_accuracy': cluster_acc,
        'test_mse': mse,
        'similarity_accuracy': similarity_acc
    }

In [38]:
def run_simulations(n_simulations=10, activation='tanh'):
    results = []
    
    for sim in range(n_simulations):
        print(f"\nSimulation {sim+1}/{n_simulations}")
        
        # Generate data
        train_data, test_data, true_clusters = generate_synthetic_data(activation=activation)
        
        # Initialize and fit model
        model = MultiTaskNNClustering(
            n_input=10,
            n_hidden=1,
            n_tasks=20,
            n_clusters=2,
            activation=activation
        )
        
        model.fit(train_data, max_iter=100)
        print('yes')
        
        # Evaluate
        metrics = evaluate_model(model, test_data, true_clusters)
        print('yes')
        results.append(metrics)
        
        print(f"Cluster accuracy: {metrics['cluster_accuracy']:.3f}")
        print(f"Test MSE: {metrics['test_mse']:.4f}")
        print(f"Similarity accuracy: {metrics['similarity_accuracy']:.3f}")
    
    # Aggregate results
    avg_results = {
        'avg_cluster_accuracy': np.mean([r['cluster_accuracy'] for r in results]),
        'avg_test_mse': np.mean([r['test_mse'] for r in results]),
        'avg_similarity_accuracy': np.mean([r['similarity_accuracy'] for r in results])
    }
    
    return avg_results

In [39]:
print("Testing with tanh activation:")
tanh_results = run_simulations(activation='tanh')

print("\nTesting with linear activation:")
linear_results = run_simulations(activation='linear')

print("\nFinal Results:")
print("Tanh activation:", tanh_results)
print("Linear activation:", linear_results)

Testing with tanh activation:

Simulation 1/10


  sigma = np.exp(log_sigma)
  df = fun(x1) - f0
100%|██████████| 100/100 [01:08<00:00,  1.46it/s]


yes
yes
Cluster accuracy: 0.450
Test MSE: 2.2244
Similarity accuracy: 0.505

Simulation 2/10


 35%|███▌      | 35/100 [00:31<00:59,  1.10it/s]


Converged at iteration 35
yes
yes
Cluster accuracy: 0.950
Test MSE: 0.0128
Similarity accuracy: 0.905

Simulation 3/10


  sigma = np.exp(log_sigma)
  df = fun(x1) - f0
 23%|██▎       | 23/100 [00:41<02:18,  1.80s/it]


Converged at iteration 23
yes
yes
Cluster accuracy: 0.000
Test MSE: 0.0127
Similarity accuracy: 1.000

Simulation 4/10


 13%|█▎        | 13/100 [00:34<03:53,  2.69s/it]


Converged at iteration 13
yes
yes
Cluster accuracy: 1.000
Test MSE: 0.0136
Similarity accuracy: 1.000

Simulation 5/10


 61%|██████    | 61/100 [00:34<00:22,  1.74it/s]


Converged at iteration 61
yes
yes
Cluster accuracy: 0.550
Test MSE: 0.0130
Similarity accuracy: 0.505

Simulation 6/10


  sigma = np.exp(log_sigma)
  df = fun(x1) - f0
100%|██████████| 100/100 [00:44<00:00,  2.22it/s]


yes
yes
Cluster accuracy: 0.600
Test MSE: 1.0080
Similarity accuracy: 0.520

Simulation 7/10


100%|██████████| 100/100 [00:39<00:00,  2.51it/s]


yes
yes
Cluster accuracy: 1.000
Test MSE: 0.6215
Similarity accuracy: 1.000

Simulation 8/10


 28%|██▊       | 28/100 [00:20<00:53,  1.36it/s]


Converged at iteration 28
yes
yes
Cluster accuracy: 0.950
Test MSE: 0.0141
Similarity accuracy: 0.905

Simulation 9/10


 11%|█         | 11/100 [00:16<02:12,  1.48s/it]


Converged at iteration 11
yes
yes
Cluster accuracy: 1.000
Test MSE: 0.0150
Similarity accuracy: 1.000

Simulation 10/10


  sigma = np.exp(log_sigma)
  df = fun(x1) - f0
 34%|███▍      | 34/100 [00:23<00:45,  1.45it/s]


Converged at iteration 34
yes
yes
Cluster accuracy: 0.850
Test MSE: 0.0143
Similarity accuracy: 0.745

Testing with linear activation:

Simulation 1/10


 11%|█         | 11/100 [00:13<01:52,  1.26s/it]


Converged at iteration 11
yes
yes
Cluster accuracy: 1.000
Test MSE: 0.0148
Similarity accuracy: 1.000

Simulation 2/10


  sigma = np.exp(log_sigma)
  df = fun(x1) - f0
  8%|▊         | 8/100 [00:07<01:30,  1.01it/s]


Converged at iteration 8
yes
yes
Cluster accuracy: 1.000
Test MSE: 0.0139
Similarity accuracy: 1.000

Simulation 3/10


  sigma = np.exp(log_sigma)
  df = fun(x1) - f0
 11%|█         | 11/100 [00:07<01:00,  1.47it/s]


Converged at iteration 11
yes
yes
Cluster accuracy: 1.000
Test MSE: 0.0151
Similarity accuracy: 1.000

Simulation 4/10


 45%|████▌     | 45/100 [00:29<00:36,  1.51it/s]


Converged at iteration 45
yes
yes
Cluster accuracy: 0.150
Test MSE: 0.0147
Similarity accuracy: 0.745

Simulation 5/10


 24%|██▍       | 24/100 [00:17<00:56,  1.34it/s]


Converged at iteration 24
yes
yes
Cluster accuracy: 0.700
Test MSE: 0.0128
Similarity accuracy: 0.580

Simulation 6/10


  9%|▉         | 9/100 [00:09<01:31,  1.01s/it]


Converged at iteration 9
yes
yes
Cluster accuracy: 0.000
Test MSE: 0.0128
Similarity accuracy: 1.000

Simulation 7/10


  sigma = np.exp(log_sigma)
  df = fun(x1) - f0
  8%|▊         | 8/100 [00:06<01:18,  1.17it/s]


Converged at iteration 8
yes
yes
Cluster accuracy: 0.000
Test MSE: 0.0140
Similarity accuracy: 1.000

Simulation 8/10


 13%|█▎        | 13/100 [00:11<01:14,  1.17it/s]


Converged at iteration 13
yes
yes
Cluster accuracy: 0.000
Test MSE: 0.0145
Similarity accuracy: 1.000

Simulation 9/10


 38%|███▊      | 38/100 [00:30<00:48,  1.27it/s]


Converged at iteration 38
yes
yes
Cluster accuracy: 0.100
Test MSE: 0.0129
Similarity accuracy: 0.820

Simulation 10/10


 20%|██        | 20/100 [00:18<01:13,  1.09it/s]

Converged at iteration 20
yes
yes
Cluster accuracy: 1.000
Test MSE: 0.0123
Similarity accuracy: 1.000

Final Results:
Tanh activation: {'avg_cluster_accuracy': 0.735, 'avg_test_mse': 0.39494072882681475, 'avg_similarity_accuracy': 0.8084999999999999}
Linear activation: {'avg_cluster_accuracy': 0.49499999999999994, 'avg_test_mse': 0.013777777418109297, 'avg_similarity_accuracy': 0.9145}





# Gating of Tasks

In [19]:
class MultiTaskNNGating:
    def __init__(self, n_input, n_hidden, n_tasks, n_clusters, n_features, activation='tanh'):
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.n_tasks = n_tasks
        self.n_clusters = n_clusters
        self.n_features = n_features
        self.activation = activation
        
        # Initialize with larger scale for better convergence
        self.W = np.random.randn(n_hidden, n_input + 1) * 0.5
        self.U = np.random.randn(n_clusters, n_features) * 0.5
        self.m = np.random.randn(n_clusters, n_hidden + 1) * 0.5
        
        # Initialize covariance matrices with larger diagonal
        self.Sigma = np.array([np.eye(n_hidden + 1) * 0.5 for _ in range(n_clusters)])
        self.sigma = 1.0
        
        self.A = np.zeros((n_tasks, n_hidden + 1))
        self.z = np.zeros((n_tasks, n_clusters))
    
    def _activate(self, x):
        if self.activation == 'tanh':
            return np.tanh(x)
        elif self.activation == 'linear'
            return x
        else:
            raise ValueError("Unknown activation function")
    
    def compute_hidden(self, X):
        X_bias = np.hstack([X, np.ones((X.shape[0], 1))])
        h = self._activate(np.dot(X_bias, self.W.T))
        return np.hstack([h, np.ones((h.shape[0], 1))])
    
    def compute_gating_probabilities(self, F):
        """Compute task-cluster assignment probabilities with numerical stability"""
        # Ensure F is 2D array
        F = np.atleast_2d(F)
        if F.shape[0] == 1 and self.n_tasks > 1:
            F = np.repeat(F, self.n_tasks, axis=0)
            
        logits = np.dot(F, self.U.T)
        return softmax(logits, axis=1)
    
    def predict(self, X, task_idx, task_features=None):
        h = self.compute_hidden(X)
        if task_features is not None:
            # If task features provided, use gating to determine cluster
            probs = self.compute_gating_probabilities(task_features)
            cluster = np.argmax(probs)
            return np.dot(h, self._compute_map_estimate(X, np.zeros(len(X)), cluster))
        return np.dot(h, self.A[task_idx])
    
    def _compute_task_log_likelihood(self, X_i, y_i, cluster_idx):
        n_i = len(y_i)
        h_i = self.compute_hidden(X_i)
        
        # Add small constant to avoid division by zero
        sigma_sq = max(self.sigma**2, 1e-8)
        
        try:
            # Use Cholesky decomposition for numerical stability
            L = cholesky(self.Sigma[cluster_idx] + 1e-6*np.eye(self.n_hidden+1), lower=True)
            Sigma_inv = solve_triangular(L, np.eye(self.n_hidden+1), lower=True)
            Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
            
            Q_i = (1/sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
            L_Q = cholesky(Q_i, lower=True)
            Q_inv = solve_triangular(L_Q, np.eye(self.n_hidden+1), lower=True)
            Q_inv = np.dot(Q_inv.T, Q_inv)
            
            R_i = (1/sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])
            
            # Compute log determinants
            logdet_Sigma = 2 * np.sum(np.log(np.diag(L)))
            logdet_Q_i = 2 * np.sum(np.log(np.diag(L_Q)))
            
            term1 = -0.5 * (logdet_Sigma + n_i * np.log(sigma_sq) + logdet_Q_i)
            term2 = 0.5 * (np.dot(R_i.T, np.dot(Q_inv, R_i)) - (1/(2*sigma_sq)) * np.sum(y_i**2) - np.dot(self.m[cluster_idx].T, np.dot(Sigma_inv, self.m[cluster_idx])))
            
            return term1 + term2
            
        except np.linalg.LinAlgError:
            return -np.inf
    
    def e_step(self, data, task_features):
        """Expectation step with improved numerical stability"""
        # Ensure task_features is 2D array
        task_features = np.atleast_2d(task_features)
        if task_features.shape[0] == 1 and self.n_tasks > 1:
            task_features = np.repeat(task_features, self.n_tasks, axis=0)
        
        q = self.compute_gating_probabilities(task_features)
        log_responsibilities = np.zeros((self.n_tasks, self.n_clusters))
        
        for i, (X_i, y_i) in enumerate(data):
            for alpha in range(self.n_clusters):
                log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                log_responsibilities[i, alpha] = np.log(q[i, alpha] + 1e-8) + log_lik
            
            # Normalize using logsumexp
            log_responsibilities[i] -= logsumexp(log_responsibilities[i])
        
        self.z = np.exp(log_responsibilities)
    
    def m_step(self, data, task_features):
        """Maximization step with regularization"""
        # Optimize W and sigma
        def objective(params):
            W = params[:self.n_hidden*(self.n_input+1)].reshape(self.n_hidden, self.n_input+1)
            log_sigma = params[-1]
            sigma = np.exp(log_sigma)
            
            self.W = W
            self.sigma = max(sigma, 1e-8)
            
            total_log_lik = 0.0
            for i, (X_i, y_i) in enumerate(data):
                for alpha in range(self.n_clusters):
                    log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                    total_log_lik += self.z[i, alpha] * log_lik
            
            return -total_log_lik if np.isfinite(total_log_lik) else np.inf
        
        # Initial parameters with bounds
        initial_params = np.concatenate([
            self.W.flatten(),
            [np.log(self.sigma)]
        ])
        
        bounds = [(None, None)] * len(initial_params)
        bounds[-1] = (np.log(1e-8), None)  # sigma > 1e-8
        
        result = minimize(
            objective,
            initial_params,
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': 50, 'disp': True}
        )
        
        # Update parameters
        opt_params = result.x
        W_size = self.n_hidden * (self.n_input + 1)
        self.W = opt_params[:W_size].reshape(self.n_hidden, self.n_input + 1)
        self.sigma = max(np.exp(opt_params[-1]), 1e-8)
        
        # Update cluster parameters with regularization
        for alpha in range(self.n_clusters):
            sum_z = np.sum(self.z[:, alpha])
            if sum_z > 1e-8:
                # Update m_α
                weighted_R = np.zeros(self.n_hidden + 1)
                weighted_Q = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                
                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden(X_i)
                    L = cholesky(self.Sigma[alpha] + 1e-6*np.eye(self.n_hidden+1), lower=True)
                    Sigma_inv = solve_triangular(L, np.eye(self.n_hidden+1), lower=True)
                    Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
                    
                    Q_i = (1/max(self.sigma**2, 1e-8)) * np.dot(h_i.T, h_i) + Sigma_inv
                    R_i = (1/max(self.sigma**2, 1e-8)) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[alpha])
                    
                    weighted_R += self.z[i, alpha] * R_i
                    weighted_Q += self.z[i, alpha] * Q_i
                
                try:
                    self.m[alpha] = np.linalg.solve(weighted_Q + 1e-6*np.eye(self.n_hidden+1), weighted_R)
                except:
                    pass
                
                # Update Σ_α with regularization
                weighted_cov = np.zeros((self.n_hidden + 1, self.n_hidden + 1))
                for i, (X_i, y_i) in enumerate(data):
                    h_i = self.compute_hidden(X_i)
                    A_i = self._compute_map_estimate(X_i, y_i, alpha)
                    diff = A_i - self.m[alpha]
                    weighted_cov += self.z[i, alpha] * np.outer(diff, diff)
                
                self.Sigma[alpha] = weighted_cov / sum_z + 1e-6 * np.eye(self.n_hidden + 1)
        
        # Update gating parameters U
        if self.n_clusters > 1:
            task_features = np.atleast_2d(task_features)
            if task_features.shape[0] == 1 and self.n_tasks > 1:
                task_features = np.repeat(task_features, self.n_tasks, axis=0)
                
            lr = LogisticRegression(
                multi_class='multinomial',
                solver='lbfgs',
                fit_intercept=False,
                max_iter=100,
                penalty='l2',
                C=1.0
            )
            try:
                lr.fit(task_features, self.get_cluster_assignments(), sample_weight=np.max(self.z, axis=1))
                self.U = lr.coef_
            except:
                pass

    def _compute_map_estimate(self, X_i, y_i, cluster_idx):
        h_i = self.compute_hidden(X_i)
        sigma_sq = max(self.sigma**2, 1e-8)
        
        L = cholesky(self.Sigma[cluster_idx] + 1e-6*np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = solve_triangular(L, np.eye(self.n_hidden + 1), lower=True)
        Sigma_inv = np.dot(Sigma_inv.T, Sigma_inv)
        
        Q_i = (1/sigma_sq) * np.dot(h_i.T, h_i) + Sigma_inv
        R_i = (1/sigma_sq) * np.dot(h_i.T, y_i) + np.dot(Sigma_inv, self.m[cluster_idx])
        
        return np.linalg.solve(Q_i + 1e-6*np.eye(self.n_hidden + 1), R_i)
    
    def fit(self, data, task_features, max_iter=100, tol=1e-4):
        """Improved fitting with better initialization and checks"""
        prev_log_lik = -np.inf
        
        # Normalize task features
        task_features = np.atleast_2d(task_features)
        if task_features.shape[0] == 1 and self.n_tasks > 1:
            task_features = np.repeat(task_features, self.n_tasks, axis=0)
        
        self.task_feature_mean = np.mean(task_features, axis=0)
        self.task_feature_std = np.std(task_features, axis=0) + 1e-8
        task_features = (task_features - self.task_feature_mean) / self.task_feature_std
        
        for iteration in range(max_iter):
            try:
                self.e_step(data, task_features)
                self.m_step(data, task_features)
                
                # Compute current log likelihood
                current_log_lik = 0.0
                q = self.compute_gating_probabilities(task_features)
                
                for i, (X_i, y_i) in enumerate(data):
                    cluster_log_liks = []
                    for alpha in range(self.n_clusters):
                        log_lik = self._compute_task_log_likelihood(X_i, y_i, alpha)
                        cluster_log_liks.append(np.log(q[i, alpha] + 1e-8) + log_lik)
                    current_log_lik += logsumexp(cluster_log_liks)
                
                if np.isnan(current_log_lik):
                    print("Warning: log likelihood is nan, stopping early")
                    break
                    
                if iteration > 0 and abs(current_log_lik - prev_log_lik) < tol:
                    print(f"Converged at iteration {iteration}")
                    break
                    
                prev_log_lik = current_log_lik
                print(f"Iteration {iteration}, log likelihood: {current_log_lik}")
                
            except Exception as e:
                print(f"Error at iteration {iteration}: {str(e)}")
                break
        
        self._compute_final_weights(data)
        return self
    
    def _compute_final_weights(self, data):
        for i, (X_i, y_i) in enumerate(data):
            most_likely_cluster = np.argmax(self.z[i])
            self.A[i] = self._compute_map_estimate(X_i, y_i, most_likely_cluster)
    
    def get_cluster_assignments(self):
        return np.argmax(self.z, axis=1)
    
    def get_task_similarity(self):
        assignments = self.get_cluster_assignments()
        return np.array([[1.0 if a == b else 0.0 for b in assignments] for a in assignments])

## Test correctness

In [21]:
# Generate synthetic data with clear cluster structure
n_input = 5
n_hidden = 4  # Increased hidden units
n_tasks = 10
n_clusters = 2
n_features = 3
n_samples = 200

# Create task features with clear separation
task_features = np.zeros((n_tasks, n_features))
for i in range(n_tasks):
    if i < n_tasks // 2:
        task_features[i] = [1.0, -0.5, 0.7] + np.random.randn(n_features)*0.1
    else:
        task_features[i] = [-0.8, 1.2, -0.3] + np.random.randn(n_features)*0.1

data = []
for i in range(n_tasks):
    X = np.random.randn(n_samples, n_input)
    if i < n_tasks // 2:  # First cluster
        y = 2.5 * X[:, 0] - 1.5 * X[:, 1] + np.random.randn(n_samples) * 0.1
    else:  # Second cluster
        y = -2.0 * X[:, 0] + 3.0 * X[:, 2] + np.random.randn(n_samples) * 0.1
    data.append((X, y))

# Initialize and fit model
model = MultiTaskNNGating(
    n_input=n_input,
    n_hidden=n_hidden,
    n_tasks=n_tasks,
    n_clusters=n_clusters,
    n_features=n_features,
    activation='tanh'
)
model.fit(data, task_features, max_iter=100)

# Check results
assignments = model.get_cluster_assignments()
print("Cluster assignments:", assignments)
print("True clusters:", [0 if i < n_tasks//2 else 1 for i in range(n_tasks)])

# Make predictions
X_test = np.random.randn(5, n_input)
for task_idx in [0, n_tasks//2]:
    preds = model.predict(X_test, task_idx)
    print(f"Task {task_idx} (cluster {assignments[task_idx]}) predictions:", preds[:3])

  sigma = np.exp(log_sigma)
  df = fun(x1) - f0


Error at iteration 0: index 1 is out of bounds for axis 1 with size 1
Cluster assignments: [0 0 0 0 0 1 1 1 1 1]
True clusters: [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
Task 0 (cluster 0) predictions: [2.14593525 2.08323812 0.43084465]
Task 5 (cluster 1) predictions: [-3.72567776  1.19625966  3.69211654]
