In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from federated_inference.simulations.isolated.configs.data_config import DataConfiguration
from federated_inference.simulations.isolated.configs.transform_config import DataTransformConfiguration
from federated_inference.simulations.isolated.configs.model_config import ModelConfiguration
from federated_inference.simulations.simulation import Simulation
from federated_inference.simulations.isolated.models.IsolatedMnistModel import IsolatedMNISTModel
from federated_inference.simulations.isolated.models.IsolatedFmnistModel import IsolatedFMNISTModel
from federated_inference.simulations.isolated.simulation import IsolatedVerticalSimulation

In [None]:
from federated_inference.simulations.naive.configs.data_config import DataConfiguration
from federated_inference.simulations.naive.configs.transform_config import DataTransformConfiguration
from federated_inference.simulations.naive.configs.model_config import ModelConfiguration
from federated_inference.simulations.naive.models.NaiveMnistModel import NaiveMNISTModel
from federated_inference.simulations.naive.models.NaiveFmnistModel import NaiveFMNISTModel
from federated_inference.simulations.naive.simulation import NaiveVerticalSimulation

In [None]:
import torch 
import random 
import numpy as np

    
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator, current CUDA device and all CUDA devices), then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Library-independent generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator, current GPU, then all GPUs.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    


In [None]:
# Train both simulations on FMNIST for every seed and store their results.
naive_results = []
isolated_results = []
if __name__ == "__main__":
    # Both packages export a class called DataConfiguration; aliases make it
    # explicit which one configures which simulation (naive config for the
    # naive simulation, isolated config for the isolated one).
    from federated_inference.simulations.naive.configs.data_config import DataConfiguration as NaiveDataConfiguration
    from federated_inference.simulations.isolated.configs.data_config import DataConfiguration as IsolatedDataConfiguration

    seeds = [3,4,5]
    for seed in seeds:
        set_seed(seed)

        # Naive simulation.
        data_config = NaiveDataConfiguration('FMNIST')
        # The same transform configuration object is shared by both simulations.
        transform_config = DataTransformConfiguration()
        simulation = NaiveVerticalSimulation(seed, data_config, transform_config, NaiveFMNISTModel, exist = False)
        simulation.train()
        simulation.test_inference()
        result = simulation.collect_results(seed, save = True)
        naive_results.append(result)

        # Isolated simulation. The FMNIST model is used here so that the
        # model class matches the FMNIST data and the naive run above.
        data_config = IsolatedDataConfiguration('FMNIST')
        simulation = IsolatedVerticalSimulation(data_config, transform_config, seed, IsolatedFMNISTModel, exist=False)
        simulation.train()
        simulation.test()
        result = simulation.collect_results(seed, save = True, figures=True)
        isolated_results.append(result)


In [None]:
import plotly.graph_objects as go
class NaiveIsolatedExperiment():
    """Evaluate the naive and the isolated vertical simulation over several seeds and plot the metrics.

    Attributes:
        naive_results: One result per seed from the naive simulation. The
            metrics are read as ``result['client'][0][metric]``.
        isolated_results: One result per seed from the isolated simulation.
            The metrics are read as ``result['clients'][i][metric]``.
        name: Dataset name passed to ``_simulation_results``; used in figure
            titles. ``None`` until ``_simulation_results`` has run.
    """

    def __init__(self):
        self.isolated_results = []
        self.naive_results = []
        self.name = None

    def _simulation_results(self, name = "MNIST", models = None):
        """Run both simulations for seeds 1..5 and store their results.

        Results of a previous call are discarded first, so the statistics
        never mix runs from different datasets.

        Args:
            name: Dataset name handed to both ``DataConfiguration`` classes.
            models: Indexable pair ``(naive_model, isolated_model)``.
                Defaults to ``(NaiveMNISTModel, IsolatedMNISTModel)``.
        """
        from federated_inference.simulations.naive.configs.data_config import DataConfiguration as NaiveDataConfiguration
        from federated_inference.simulations.naive.configs.transform_config import DataTransformConfiguration
        from federated_inference.simulations.isolated.configs.data_config import DataConfiguration as IsolatedDataConfiguration

        # Resolved at call time instead of in the signature: no shared
        # mutable default, and the class can be defined before the models.
        if models is None:
            models = (NaiveMNISTModel, IsolatedMNISTModel)
        naive_model, isolated_model = models[0], models[1]

        seeds = [1,2,3,4,5]
        self.name = name
        self.naive_results = []
        self.isolated_results = []
        for seed in seeds:
            set_seed(seed)

            # Simulations are built with exist=True and are not trained here
            # (presumably previously trained models are reused - confirm).
            data_config = NaiveDataConfiguration(name)
            transform_config = DataTransformConfiguration()
            simulation = NaiveVerticalSimulation(seed, data_config, transform_config, naive_model, exist = True)
            simulation.test_inference()
            self.naive_results.append(simulation.collect_results(seed, save = False))

            # The transform configuration created above is shared with the isolated run.
            data_config = IsolatedDataConfiguration(name)
            simulation = IsolatedVerticalSimulation(data_config, transform_config, seed, isolated_model, exist=True)
            simulation.test()
            self.isolated_results.append(simulation.collect_results(seed, save = False, figures=False))

    @staticmethod
    def _mean_var(values):
        """Return ``(mean, variance)`` of ``values`` using ``np.mean`` / ``np.var``."""
        import numpy as np
        return (np.mean(values), np.var(values))

    def _metric_stats(self, metric, idx=None):
        """Aggregate one metric over all seeds.

        Args:
            metric: Key of the metric in a client result, e.g. ``'accuracy'``.
            idx: Optional index into the metric value (used for per-class
                metrics); ``None`` takes the value as is.

        Returns:
            ``(naive, isolated)`` where ``naive`` is a ``(mean, variance)``
            tuple and ``isolated`` is a list with one such tuple per client.

        Raises:
            RuntimeError: If no results have been collected yet.
        """
        if not self.naive_results or not self.isolated_results:
            raise RuntimeError("No results collected; call _simulation_results() first.")

        def pick(client_result):
            value = client_result[metric]
            return value if idx is None else value[idx]

        naive = self._mean_var([pick(r['client'][0]) for r in self.naive_results])
        n_clients = len(self.isolated_results[0]['clients'])
        isolated = [
            self._mean_var([pick(r['clients'][i]) for r in self.isolated_results])
            for i in range(n_clients)
        ]
        return naive, isolated

    def _precision_recall(self, idx):
        """Return precision and recall statistics of class ``idx``.

        Returns:
            ``(naive_precision, isolated_precision, naive_recall, isolated_recall)``;
            the naive entries are ``(mean, variance)`` tuples, the isolated
            entries are lists of such tuples, one per client.
        """
        var1, var2 = self._metric_stats('precision', idx)
        var3, var4 = self._metric_stats('recall', idx)
        return var1, var2, var3, var4

    def __precision_recall_fig(self, name, idx):
        """Show a grouped bar chart of precision and recall for class ``idx``.

        Error bars are the square root of the variance over the seeds.
        """
        import plotly.graph_objects as go
        var1, var2, var3, var4 = self._precision_recall(idx)
        # Labels for x-axis
        labels = ["Naive Server"] + [f"Isolated Client {i}" for i in range(len(var2))]

        fig = go.Figure()
        for trace_name, color, naive, isolated in (
            ('Precision', 'steelblue', var1, var2),
            ('Recall', 'darkorange', var3, var4),
        ):
            fig.add_trace(go.Bar(
                x=labels,
                y=[naive[0]] + [v[0] for v in isolated],
                name=trace_name,
                marker_color=color,
                error_y=dict(
                    type='data',
                    array=[naive[1]**0.5] + [v[1]**0.5 for v in isolated],
                    visible=True
                )
            ))

        # Layout
        fig.update_layout(
            title=f"{name} Precision and Recall of Class {idx}",
            xaxis_title='Experiment',
            yaxis_title='Score',
            yaxis=dict(range=[0.2, 1.0]),
            barmode='group',
            template='plotly_white'
        )

        fig.show()

    def show_recall_figs(self, n_classes=10):
        """Show the precision/recall figure for classes ``0 .. n_classes - 1``."""
        for i in range(n_classes):
            self.__precision_recall_fig(self.name, i)

    def _accuracy(self):
        """Return accuracy statistics.

        Returns:
            ``(naive, isolated)``: a ``(mean, variance)`` tuple for the naive
            simulation and a list of such tuples, one per isolated client.
        """
        return self._metric_stats('accuracy')

    def _accuracy_comparison(self):
        """Show a bar chart comparing the accuracy of the naive server and each isolated client."""
        import plotly.graph_objects as go
        var1, var2 = self._accuracy()

        results = [("Naive Server", var1[0], var1[1])] + [(f"Isolated Client {i}", acc, var) for i, (acc, var) in enumerate(var2)]

        labels = [r[0] for r in results]
        accuracies = [r[1] for r in results]
        variances = [r[2] for r in results]

        fig = go.Figure()

        fig.add_trace(go.Bar(
            x=labels,
            y=accuracies,
            name='Accuracy',
            marker_color='steelblue',
            error_y=dict(
                type='data',
                array=[v ** 0.5 for v in variances],  # Use stddev for error bars
                visible=True
            )
        ))

        # Customize layout
        fig.update_layout(
            title=f"{self.name} Accuracy",
            xaxis_title='Experiment',
            yaxis_title='Accuracy',
            yaxis=dict(range=[0.7, 1.0]),
            template='plotly_white'
        )

        fig.show()

# MNIST experiment: evaluate both simulations, then show one
# precision/recall figure per class.
ex = NaiveIsolatedExperiment()
ex._simulation_results(name="MNIST", models=[NaiveMNISTModel, IsolatedMNISTModel])
ex.show_recall_figs()

In [None]:
# Use a fresh experiment for FMNIST so that its statistics cannot include the
# MNIST runs collected by the experiment above.
ex = NaiveIsolatedExperiment()
ex._simulation_results("FMNIST", [NaiveFMNISTModel, IsolatedFMNISTModel])
ex.show_recall_figs()

In [None]:
import torch.nn.functional as F
def pred(self):
    """Run the model of ``self`` over its test set and collect per-sample confidences.

    ``self`` is expected to provide ``data.test_dataset``, ``_pred_loader``,
    ``model`` and a ``model_config`` with ``BATCH_SIZE_TEST``, ``TEST_SHUFFLE``
    and ``DEVICE`` (in this notebook: a simulation client).

    Returns:
        A tuple ``(confidences, confidence_corrects, predictions)`` of flat
        Python lists with one entry per test sample, in loader order:
        the highest softmax probability, the softmax probability assigned to
        the true class, and the predicted class index.

    Note:
        The lists follow the iteration order of the loader; they only line up
        with ``test_dataset.targets`` when the loader does not shuffle.
    """
    confidences = []
    confidence_corrects = []
    predictions = []

    device = self.model_config.DEVICE
    testloader = self._pred_loader(
        self.data.test_dataset,
        self.model_config.BATCH_SIZE_TEST,
        self.model_config.TEST_SHUFFLE,
    )
    self.model.eval()

    with torch.no_grad():
        for data, target in testloader:
            data = data.to(device).float()
            target = target.to(device).long()
            output = self.model(data)
            probs = F.softmax(output, dim=1)

            # Reductions along dim=1 keep the batch dimension, so a final
            # batch with a single sample still yields a list, not a scalar.
            top_prob = probs.max(dim=1).values
            # Probability of the true class of every sample.
            true_prob = probs.gather(1, target.reshape(-1, 1)).squeeze(1)
            predicted = output.argmax(dim=1)

            confidences.extend(top_prob.tolist())
            confidence_corrects.extend(true_prob.tolist())
            predictions.extend(predicted.tolist())

    return confidences, confidence_corrects, predictions


# pandas is imported here because this cell runs before the cell that
# imports it; without this a fresh kernel fails with a NameError on `pd`.
import pandas as pd

for client in simulation.clients:

    confidences, confidences_trues, predictions = pred(client)
    true_labels = client.data.test_dataset.targets.squeeze().tolist()

    # One row per test sample.
    df = pd.DataFrame({
        'target': list(true_labels),
        'confidence': confidences,
        'confidence_correct': confidences_trues,
        'predictions': predictions,
    })

    # Same report twice: first for the correctly classified samples,
    # then for the misclassified ones.
    for subset in (df[df['target'] == df['predictions']], df[df['target'] != df['predictions']]):
        # Group by target and calculate mean and variance
        summary = subset.groupby('target')['confidence'].agg(['mean', 'var']).reset_index()

        # Rename columns for clarity (optional)
        summary.columns = ['target', 'avg_max_confidence', 'variance']

        print(summary)

        summary = subset.groupby('target')['confidence_correct'].agg(['mean', 'var']).reset_index()
        print(summary)

        print("_"*40)
    # Second separator line marks the end of this client's report.
    print("_"*40)


In [None]:
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support

# One row per test sample, built from the variables left behind by the
# per-client loop (i.e. the values of the last client processed).
df = pd.DataFrame(dict(
    target=list(true_labels),
    confidence=confidences,
    confidence_correct=confidences_trues,
    predictions=predictions,
))

# Partition into correctly and wrongly classified samples
df_correct = df.loc[df['target'].eq(df['predictions'])]
df_wrong = df.loc[df['target'].ne(df['predictions'])]

def summarize_group(df_sub, label):
    """Per-target mean and variance of both confidence columns.

    Args:
        df_sub: DataFrame with the columns ``target``, ``confidence`` and
            ``confidence_correct``.
        label: Prefix for the resulting column names.

    Returns:
        DataFrame indexed by ``target`` with the columns
        ``{label}_avg_max_conf``, ``{label}_var_max_conf``,
        ``{label}_avg_true_conf`` and ``{label}_var_true_conf``.
    """
    by_target = df_sub.groupby('target')

    # Statistics of the highest softmax probability.
    max_conf_stats = by_target['confidence'].agg(['mean', 'var'])
    max_conf_stats = max_conf_stats.rename(
        columns={'mean': f'{label}_avg_max_conf', 'var': f'{label}_var_max_conf'}
    )

    # Statistics of the probability assigned to the true class.
    true_conf_stats = by_target['confidence_correct'].agg(['mean', 'var'])
    true_conf_stats = true_conf_stats.rename(
        columns={'mean': f'{label}_avg_true_conf', 'var': f'{label}_var_true_conf'}
    )

    return max_conf_stats.join(true_conf_stats, how='outer')

# Confidence statistics for correct and wrong predictions, kept separate
summary_correct = summarize_group(df_correct, 'correct')
summary_wrong = summarize_group(df_wrong, 'wrong')

# Per-class precision and recall over the classes present in the labels
class_labels = sorted(set(true_labels))
precision, recall, _, _ = precision_recall_fscore_support(
    true_labels,
    predictions,
    labels=class_labels,
    zero_division=0,
)
metrics_df = pd.DataFrame(dict(target=class_labels, precision=precision, recall=recall))

# Join everything on the class label; classes without correct or wrong
# predictions produce NaNs in the joined columns, which are replaced by 0.
final_summary = (
    metrics_df.set_index('target')
    .join(summary_correct, how='left')
    .join(summary_wrong, how='left')
    .reset_index()
    .fillna(0)
)

# Display result
print(final_summary)

In [None]:
    # NOTE: this cell consumes `all_features` / `all_labels`, which are filled
    # by the feature-extraction cells further down in this notebook; one of
    # them has to be run first.
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    # Stack batches into full dataset arrays
    features_tensor = torch.cat(all_features)  # one row of flattened features per sample
    labels_tensor = torch.cat(all_labels)


    # Convert tensors to NumPy arrays
    features_np = features_tensor.numpy()
    labels_np = labels_tensor.numpy()
    
    # Keep only the classes listed in target_digits
    target_digits = [1, 0]
    mask = np.isin(labels_np, target_digits)
    features_np = features_np[mask]
    labels_np = labels_np[mask]

    sample_size = 5000
    # If dataset is larger than sample_size, randomly sample without replacement for faster PCA
    if features_np.shape[0] > sample_size:
        rng = np.random.default_rng(seed=42)  # Use a fixed seed for reproducibility
        idx = rng.choice(features_np.shape[0], size=sample_size, replace=False)
        features_np = features_np[idx]
        labels_np = labels_np[idx]

    # Apply PCA to reduce features to 2 dimensions
    pca = PCA(n_components=2)
    features_pca = pca.fit_transform(features_np)

    # Plot the 2D PCA embedding, color-coded by label. Only the selected
    # classes are drawn: after the filter above no other class has samples,
    # so looping over all ten classes would only add empty legend entries.
    # Sorting keeps the series in ascending class order.
    plt.figure(figsize=(10, 8))
    for digit in sorted(target_digits):
        idxs = labels_np == digit
        plt.scatter(features_pca[idxs, 0], features_pca[idxs, 1], label=str(digit), alpha=0.5, s=20)

    plt.title(f"PCA of CNN Features for Digits {', '.join(str(d) for d in sorted(target_digits))}")
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.legend(title='Digit')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

In [None]:
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    # Extract the flattened outputs of `model.features` for the first client's
    # training data.
    client = simulation.clients[0] 
    client.model.eval()
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in client.trainloader:
            # The device is taken from the client's own configuration; no
            # notebook-level `model_config` variable is defined.
            images = images.to(client.model_config.DEVICE)
            features = client.model.features(images)
            features = torch.flatten(features, start_dim=1) 
            all_features.append(features.cpu())
            all_labels.append(labels)


    # Stack batches into full dataset arrays
    features_tensor = torch.cat(all_features)  # one row of flattened features per sample
    labels_tensor = torch.cat(all_labels)


    # Convert tensors to NumPy arrays
    features_np = features_tensor.numpy()
    labels_np = labels_tensor.numpy()

    sample_size = 5000
    # If dataset is larger than sample_size, randomly sample without replacement for faster PCA
    if features_np.shape[0] > sample_size:
        rng = np.random.default_rng(seed=42)  # Use a fixed seed for reproducibility
        idx = rng.choice(features_np.shape[0], size=sample_size, replace=False)
        features_np = features_np[idx]
        labels_np = labels_np[idx]

    # Apply PCA to reduce features to 3 dimensions (all classes are kept)
    pca = PCA(n_components=3)
    features_pca = pca.fit_transform(features_np)

    # Plot the first two principal components, color-coded by label.
    # NOTE: `target_digits` is not defined in this cell; it must have been set
    # by one of the cells that filter on it.
    plt.figure(figsize=(10, 8))
    for digit in target_digits:
        idxs = labels_np == digit
        plt.scatter(features_pca[idxs, 0], features_pca[idxs, 1], label=str(digit), alpha=0.5, s=20)

    # The series labels are only visible with a legend.
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.legend(title='Digit')
    plt.show()


In [None]:
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    # Extract the flattened outputs of `model.features` for the first client's
    # training data.
    client = simulation.clients[0] 
    client.model.eval()
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in client.trainloader:
            # The device is taken from the client's own configuration; no
            # notebook-level `model_config` variable is defined.
            images = images.to(client.model_config.DEVICE)
            features = client.model.features(images)
            features = torch.flatten(features, start_dim=1) 
            all_features.append(features.cpu())
            all_labels.append(labels)


    # Stack batches into full dataset arrays
    features_tensor = torch.cat(all_features)  # one row of flattened features per sample
    labels_tensor = torch.cat(all_labels)


    # Convert tensors to NumPy arrays
    features_np = features_tensor.numpy()
    labels_np = labels_tensor.numpy()
    
    # Keep only the classes listed in target_digits
    target_digits = [0,1]
    mask = np.isin(labels_np, target_digits)
    features_np = features_np[mask]
    labels_np = labels_np[mask]

    sample_size = 5000
    # If dataset is larger than sample_size, randomly sample without replacement for faster PCA
    if features_np.shape[0] > sample_size:
        rng = np.random.default_rng(seed=42)  # Use a fixed seed for reproducibility
        idx = rng.choice(features_np.shape[0], size=sample_size, replace=False)
        features_np = features_np[idx]
        labels_np = labels_np[idx]

    # Apply PCA to reduce features to 3 dimensions; `features_pca` and
    # `labels_np` are what the 3D scatter cell reads.
    pca = PCA(n_components=3)
    features_pca = pca.fit_transform(features_np)


In [None]:
import plotly.subplots as sp
import plotly.graph_objects as go
import torch
import numpy as np

def tensor_to_numpy_image(tensor_img):
    """
    Convert a torch Tensor image to a uint8 numpy array suitable for plotly.

    A 3-D tensor is assumed to be [C, H, W] and is returned as [H, W, C];
    a single-channel 3-D image is returned as [H, W]. Tensors of any other
    rank keep their shape.

    Value handling by dtype:
    - uint8: returned unchanged.
    - other integer dtypes: assumed to already be on a 0..255 scale; values
      are clipped to [0, 255] and cast (no rescaling, which would saturate
      every non-zero pixel).
    - everything else (floats, bool): assumed to be on a 0..1 scale; values
      are multiplied by 255, clipped to [0, 255] and cast.

    Anything that is not a torch Tensor is assumed to already be a numpy
    image and is returned as is.
    """
    if not isinstance(tensor_img, torch.Tensor):
        # Assume already numpy image
        return tensor_img

    img = tensor_img.detach().cpu().numpy()
    if img.ndim == 3:
        # C,H,W -> H,W,C
        img = np.transpose(img, (1, 2, 0))
        # Grayscale (1 channel): drop the channel dim. Only done for 3-D
        # input so that a 2-D [H, 1] image never loses a spatial dimension.
        if img.shape[-1] == 1:
            img = img.squeeze(-1)

    if img.dtype == np.uint8:
        return img
    if np.issubdtype(img.dtype, np.integer):
        return img.clip(0, 255).astype(np.uint8)
    return (img * 255).clip(0, 255).astype(np.uint8)

def create_clients_image_subplots(datasets, n_images=3):
    """
    Build a Plotly figure with one subplot row per client dataset and up to
    `n_images` images per row.

    Parameters:
    - datasets: sequence of indexable datasets (one per client); each item is
      unpacked as an (image, label) pair
    - n_images: number of image columns (default 3); a dataset with fewer
      items leaves the remaining cells of its row empty

    Returns:
    - the Plotly figure (it is not shown by this function)
    """
    n_clients = len(datasets)
    column_titles = [f"Img {i+1}" for i in range(n_images)]
    fig = sp.make_subplots(
        rows=n_clients,
        cols=n_images,
        subplot_titles=column_titles,
        vertical_spacing=0.05,
        horizontal_spacing=0.01,
    )

    r = 0
    for dataset in datasets:
        r += 1
        # Take the first images of the dataset, at most n_images of them.
        col = 1
        while col <= n_images and col - 1 < len(dataset):
            img, _label = dataset[col - 1]
            fig.add_trace(go.Image(z=tensor_to_numpy_image(img)), row=r, col=col)
            col += 1

        # The row's first y-axis carries the client number.
        fig.update_yaxes(title_text=f"Client {r}", row=r, col=1)

    fig.update_layout(
        height=250 * n_clients,
        width=250 * n_images,
        showlegend=False,
        title_text="Clients Dataset Images",
    )
    # Hide tick labels on every axis.
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    return fig


In [None]:
import plotly.express as px
import pandas as pd

# Create a DataFrame for Plotly. `features_pca` needs at least three
# components and `labels_np` one label per row; both come from the PCA cells.
df = pd.DataFrame({
    'PC1': features_pca[:, 0],
    'PC2': features_pca[:, 1],
    'PC3': features_pca[:, 2],
    'Digit': labels_np.astype(str)  # convert labels to strings for grouping
})

# The title lists the classes that are actually present in the data instead
# of a hard-coded list that can disagree with the filter used upstream.
digits_shown = ", ".join(sorted(df['Digit'].unique()))

# Create interactive 3D scatter plot
fig = px.scatter_3d(
    df,
    x='PC1', y='PC2', z='PC3',
    color='Digit',
    title=f'3D PCA of CNN Features for Digits {digits_shown}',
    opacity=0.7,
    symbol='Digit'
)

fig.update_layout(
    legend_title_text='Digit',
    margin=dict(l=0, r=0, b=0, t=40),
    scene=dict(
        xaxis_title='PC 1',
        yaxis_title='PC 2',
        zaxis_title='PC 3'
    )
)

fig.show()