In [None]:
import sys
import os

# Make the project importable from this notebook. Two entries are needed:
#   - `src_path` (<project_root>/src) so modules inside src/ can be imported by bare
#     name, e.g. `import data_tools`;
#   - `project_root` (one level above the working directory) so that the package-style
#     imports used below, e.g. `from src.settings import ...`, resolve.
notebook_dir = os.getcwd()  # assumes the kernel was started in the notebook's directory
project_root = os.path.abspath(os.path.join(notebook_dir, ".."))  # go up one level
src_path = os.path.join(project_root, "src")

# Append each entry only once so re-running this cell does not keep growing sys.path.
# src_path goes first to keep the original lookup precedence for bare-name imports.
for _path in (src_path, project_root):
    if _path not in sys.path:
        sys.path.append(_path)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from matplotlib.patches import Circle

# local modules
from src.settings import DIR_OUT, DATAGEN_GLOBALS
from src.data_tools import data_train_test_split_linear, data_train_test_split_clusters, data_train_test_split_manifold
from src.data_tools import proj_subsphere_estimator

### Case A - Linear problem ($d$-dim subspaces of $\mathbb{R}^n$)

### Case B - Clustering (Gaussian mixture models)

### Case C - Nonlinear manifolds ($d$-spheres in $\mathbb{R}^n$)

In all cases, consider the problem of being given $L \gg d$ tokens $x_1, \dots, x_{L}$ drawn from $p_X(x)$, followed by one more token $x_{L+1}$ from the same distribution that is corrupted by Gaussian noise: $x_\text{query} = x_{L+1} + \eta$.

The objective is to produce an MSE-minimizing estimate of the target $x_{L+1}$ from the corrupted token $x_{\text{query}}$, given the context $x_1, \dots, x_L$.

In [None]:
# Plot styling for the whole notebook: seaborn "notebook" context with a dark grid.
sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

NB_OUTPUT = DIR_OUT  # alias: every figure saved by this notebook is written here

# exist_ok=True makes this idempotent and avoids the exists()/makedirs() race
os.makedirs(NB_OUTPUT, exist_ok=True)

# Generic utility functions

In [None]:
colors_blue_context = '#27AAE1'
colors_white_ = '#D8E1F3'

def plot_dataset_2d_or_3d(X, target=None, pred=None, emphasize_final_token=True, arrow_origin_to_point=None, cluster_ids=None, sns_style=True, title=None):
    """Scatter-plot a sequence of 2D or 3D tokens, optionally annotating target and prediction.

    Args:
        X: array of shape (L, dim_n), one token per row. The last row is treated as the
            corrupted query token.
        target: optional point of length dim_n. It is drawn as a square and joined to the
            last token with a dashed red line.
        pred: optional point of length dim_n. It is drawn as a purple circle and joined to
            the last token with a dashed black line.
        emphasize_final_token: if True, highlight the last row of X in red (label 'corrupted').
        arrow_origin_to_point: optional point of length dim_n. A thin dashed line is drawn
            from the origin to it.
        cluster_ids: optional sequence of length L with one integer cluster ID per token.
            Points are coloured with the 'Set2' palette (IDs wrap modulo the palette size).
            An ID of None uses the first palette colour.
        sns_style: if False, reset Matplotlib to its default style. This is a global side
            effect that also affects later figures.
        title: optional axes title.

    Returns:
        The Axes drawn on. Returns None, and plots nothing, if dim_n is not 2 or 3.
    """
    _, dim_n = X.shape
    final_token = X[-1, :]

    if not sns_style:
        plt.style.use('default')

    if cluster_ids is None:
        scatter_colors = colors_blue_context
    else:
        # plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9
        set2 = plt.get_cmap('Set2').colors
        scatter_colors = [set2[0 if cid is None else int(cid) % len(set2)] for cid in cluster_ids]

    if dim_n not in (2, 3):
        print("plot_dataset_2d_or_3d(...) - can only plot 2D or 3D data, but dim_n=%d" % dim_n)
        return None

    print("plotting call - plot_dataset_2d_or_3d(...) dim_n=%d" % dim_n)
    fig = plt.figure()
    if dim_n == 2:
        ax = fig.add_subplot()
        ax.scatter(X[:, 0], X[:, 1], c=scatter_colors)
        ax.set_xlabel('X'); ax.set_ylabel('Y')
    else:
        ax = fig.add_subplot(111, projection='3d')
        # Use a single vectorised call. Scattering row by row would draw every point twice,
        # and it fails for per-point colour lists (L colours for a single point).
        ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=scatter_colors)
        ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')

    if emphasize_final_token:
        ax.scatter(*final_token, color='red', edgecolor='k', linewidth=1, s=50, zorder=10, label='corrupted')

    if target is not None:
        # dashed red line from the target to the corrupted token, then the target marker
        ax.plot(*zip(target, final_token), '--or', zorder=8)
        ax.scatter(*target, c=colors_blue_context, edgecolor='k', marker='s', s=75, zorder=10, label='target')

    if pred is not None:
        # dashed black line from the prediction to the corrupted token, then the prediction marker
        ax.plot(*zip(pred, final_token), '--ok', zorder=8)
        ax.scatter(*pred, c='mediumpurple', edgecolor='k', marker='o', s=50, zorder=10, label='pred')

    # The legend is only shown when a target or prediction was drawn
    if target is not None or pred is not None:
        ax.legend()

    if arrow_origin_to_point is not None:
        # thin dashed line from the origin to the given point
        ax.plot(*zip(np.zeros(dim_n), arrow_origin_to_point), '--k', zorder=8, linewidth=0.5)

    if title is not None:
        ax.set_title(title)
    ax.set_aspect('equal')

    return ax

# Use the data-generation tools from `data_tools.py`
### Build an example train/test dataset and visualize it

In [None]:
sns_style = False

# ---------------------------------------------------------------------------
# Settings shared by every example figure (constant across the loops below)
# ---------------------------------------------------------------------------
context_len = 100
dim_n = 2
test_ratio = 0.2
sigma2_corruption = 0.1
example_id = 12                                      # training example to visualise
datagen_labels = ['linear', 'clusters', 'manifold']  # datagen_choice -> label used in file names
title_example = 'example input sequence and target (dim n=%d)' % (dim_n)


def _set2_colors(num_colors):
    """Return `num_colors` colours from the 'Set2' palette, wrapping around when it runs out."""
    # plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9
    set2 = plt.get_cmap('Set2').colors
    return [set2[i % len(set2)] for i in range(num_colors)]


def _add_origin_circle(ax, radius):
    """Draw an unfilled black circle of the given radius centred at the origin (2D axes)."""
    ax.add_patch(Circle((0.0, 0.0), radius, fill=False, edgecolor='k', zorder=1))


def _decorate_axes(ax, datagen_choice, dim_n, radius_sphere=None):
    """Apply per-case decorations: origin cross-hairs (2D), fixed limits, and a sphere outline.

    radius_sphere is only used for the manifold case (datagen_choice == 2).
    """
    if dim_n == 2:
        ax.axhline(0, linestyle='--', c='grey')
        ax.axvline(0, linestyle='--', c='grey')
    if datagen_choice == 0:    # linear case
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-1.5, 1.5)
    elif datagen_choice == 1:  # clusters case
        ax.set_xlim(-1.75, 1.75)
        ax.set_ylim(-1.75, 1.75)
    elif datagen_choice == 2 and dim_n == 2:  # manifold case: outline the sphere itself
        _add_origin_circle(ax, radius_sphere)


for datagen_seed in range(3):
    for datagen_choice in (0, 1, 2):  # {0, 1, 2} -> {linear, clusters, manifold}
        datagen_label = datagen_labels[datagen_choice]
        case_globals = DATAGEN_GLOBALS[datagen_choice]

        base_kwargs = dict(
            context_len=context_len,
            dim_n=dim_n,
            num_W_in_dataset=1000,
            context_examples_per_W=1,
            samples_per_context_example=1,
            test_ratio=test_ratio,
            verbose=True,
            as_torch=False,
            savez_fname=None,
            seed=datagen_seed,
            style_subspace_dimensions=case_globals['style_subspace_dimensions'],
            style_origin_subspace=case_globals['style_origin_subspace'],
            style_corruption_orthog=case_globals['style_corruption_orthog'],
            sigma2_corruption=sigma2_corruption,
        )

        print('='*20)
        ################################################################################
        # Build data
        ################################################################################
        if datagen_choice == 0:
            x_train, y_train, x_test, y_test, train_data_subspaces, test_data_subspaces = data_train_test_split_linear(
                **base_kwargs,
                sigma2_pure_context=case_globals['sigma2_pure_context'],
            )
        elif datagen_choice == 1:
            x_train, y_train, x_test, y_test, train_data_subspaces, test_data_subspaces = data_train_test_split_clusters(
                **base_kwargs,
                style_cluster_mus=case_globals['style_cluster_mus'],
                style_cluster_vars=case_globals['style_cluster_vars'],
                num_cluster=4,
                cluster_var=0.01,
            )
        else:
            assert datagen_choice == 2
            x_train, y_train, x_test, y_test, train_data_subspaces, test_data_subspaces = data_train_test_split_manifold(
                **base_kwargs,
                radius_sphere=case_globals['radius_sphere'],
            )

        print('x_train.shape', x_train.shape)
        print('TODO samples_per_context parameter (currently 1)')

        # context_example is (dim_n, context_len): it is plotted transposed, and its last
        # column is the corrupted query token
        context_example = x_train[example_id, :, :]
        target_example = y_train[example_id, :]

        radius_sphere = None  # only set for the manifold case

        if datagen_choice == 1:
            # ---- Clustering case: colour points by cluster and mark the cluster centres ----
            example_info = train_data_subspaces[example_id]
            num_k = example_info['num_cluster']
            print('For example %d, there were %d clusters' % (example_id, num_k))
            cluster_centers = example_info['sampled_mus']
            cluster_membership = example_info['sampled_cluster_id']

            ax = plot_dataset_2d_or_3d(context_example.T, target=target_example,
                                       emphasize_final_token=True,
                                       cluster_ids=cluster_membership,
                                       title=title_example, sns_style=sns_style)
            # square markers at the cluster centres, coloured like their member points
            ax.scatter(*cluster_centers, edgecolor='k', marker='s', s=100, zorder=10, c=_set2_colors(num_k))
            if dim_n == 2:
                _add_origin_circle(ax, 1.0)
            axes_to_decorate = [ax]
        else:
            # ---- Linear / manifold cases ----
            baseline_pred_A = None
            baseline_pred_B = None
            arrow_origin_to_point = None

            if datagen_choice == 2:
                radius_sphere = case_globals['radius_sphere']
                X_seq = context_example[:, :-1]
                x_corrupt = context_example[:, -1]
                # NOTE(review): these style flags are hard-coded rather than read from
                # case_globals (which were used to generate the data) -- confirm intended
                estimator_kwargs = dict(eval_cutoff=1e-6, verbose_vis=True,
                                        style_origin_subspace=True, style_corruption_orthog=False)
                baseline_pred_A = proj_subsphere_estimator(X_seq, x_corrupt, sigma2_corruption, shrunken=False,
                                                           **estimator_kwargs)
                baseline_pred_B = proj_subsphere_estimator(X_seq, x_corrupt, sigma2_corruption, shrunken=True,
                                                           **estimator_kwargs)

                guess_most_recent = context_example[:, -1]

                print('='*20)
                print('baseline_pred_A', baseline_pred_A)
                print('norm baseline_pred_A', np.linalg.norm(baseline_pred_A), '(should be %.2f)' % radius_sphere)
                print('='*20)
                print('baseline_pred_B', baseline_pred_B)
                print('norm baseline_pred_B', np.linalg.norm(baseline_pred_B))
                print('='*20)
                print('guess_most_recent', guess_most_recent)
                print('norm guess_most_recent', np.linalg.norm(guess_most_recent))
                print('='*20)

                # rescale baseline_pred_B so that its norm matches the sphere radius
                point_on_sphere = baseline_pred_B / np.linalg.norm(baseline_pred_B) * radius_sphere
                arrow_origin_to_point = point_on_sphere

            # One figure per available baseline prediction (none in the linear case),
            # then a final figure with only context and target
            axes_to_decorate = []
            for baseline_pred in (baseline_pred_A, baseline_pred_B):
                if baseline_pred is not None:
                    axes_to_decorate.append(
                        plot_dataset_2d_or_3d(context_example.T, target=target_example, pred=baseline_pred,
                                              arrow_origin_to_point=arrow_origin_to_point,
                                              emphasize_final_token=True,
                                              cluster_ids=None,
                                              title=title_example, sns_style=sns_style))

            ax = plot_dataset_2d_or_3d(context_example.T, target=target_example, pred=None,
                                       arrow_origin_to_point=arrow_origin_to_point,
                                       emphasize_final_token=True,
                                       cluster_ids=None,
                                       title=title_example, sns_style=sns_style)
            axes_to_decorate.append(ax)

        for case_ax in axes_to_decorate:
            _decorate_axes(case_ax, datagen_choice, dim_n, radius_sphere)

        # Save only the final (context + target) figure of this case; any prediction
        # figures are displayed but not saved
        fig_out = ax.figure
        fig_out.savefig(os.path.join(NB_OUTPUT, 'example_denoise_%s_seed%s.pdf' % (datagen_label, datagen_seed)))
        fig_out.savefig(os.path.join(NB_OUTPUT, 'example_denoise_%s_seed%s.svg' % (datagen_label, datagen_seed)))
        plt.show()