# STAT 6385: Final Project

## Importing Libraries

In [1]:
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.metrics import roc_auc_score, pairwise_distances
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from scipy.stats import ks_2samp, chi2_contingency
from statsmodels.stats.multitest import multipletests
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, Union
import os
import sys
import math
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Union
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from statsmodels.stats.multitest import multipletests
from sklearn.utils import shuffle
from sklearn.preprocessing import QuantileTransformer, StandardScaler
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer

## Knockoff Functions

### Deep Knockoffs

In [2]:
# code downloaded from https://github.com/msesia/deepknockoffs/tree/master/DeepKnockoffs

import torch
import torch.nn.functional as F

# Lower bound applied to the variance estimate in _mmd2_and_ratio before its
# square root is taken, so the MMD^2 / sqrt(variance) ratio never divides by zero.
min_var_est = 1e-8

def linear_mmd2(f_of_X, f_of_Y):
    """Linear-time MMD^2-style statistic built from consecutive rows.

    :param f_of_X: 2-D tensor, one sample per row
    :param f_of_Y: tensor of the same shape as ``f_of_X``
    :return: scalar tensor; the mean over consecutive row pairs of the inner
             product between row ``i`` and row ``i+1`` of ``f_of_X - f_of_Y``
    """
    diff = f_of_X - f_of_Y
    earlier_rows, later_rows = diff[:-1], diff[1:]
    return torch.mean(torch.sum(earlier_rows * later_rows, dim=1))

def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
    """Linear-time MMD^2-style statistic with a polynomial kernel.

    The kernel ``(alpha * <u, v> + c) ** d`` is evaluated only on consecutive
    row pairs (row ``i`` of the first argument against row ``i+1`` of the
    second) and averaged.

    :param f_of_X: 2-D tensor, one sample per row
    :param f_of_Y: tensor of the same shape as ``f_of_X``
    :param d: polynomial degree
    :param alpha: scale applied to the inner product
    :param c: additive constant inside the kernel
    :return: scalar tensor ``mean k(X,X) + mean k(Y,Y) - mean k(X,Y) - mean k(Y,X)``
    """
    def _mean_kernel(left, right):
        # Inner products between row i of `left` and row i+1 of `right`
        inner = (left[:-1] * right[1:]).sum(1)
        return torch.mean((alpha * inner + c).pow(d))

    xx_mean = _mean_kernel(f_of_X, f_of_X)
    yy_mean = _mean_kernel(f_of_Y, f_of_Y)
    xy_mean = _mean_kernel(f_of_X, f_of_Y)
    yx_mean = _mean_kernel(f_of_Y, f_of_X)

    return xx_mean + yy_mean - xy_mean - yx_mean


def _mix_rbf_kernel(X, Y, sigma_list):
    assert(X.size(0) == Y.size(0))
    m = X.size(0)

    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.0
    for sigma in sigma_list:
        gamma = 1.0 / (2 * sigma**2)
        K += torch.exp(-gamma * exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)

def _mix_imq_kernel(X,
               Y,
               sigma_list):

    assert(X.size(0) == Y.size(0))
    m = X.size(0)
    h_dim = X.size(1)

    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)

    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.0
    for sigma in sigma_list:
        gamma = 2 * h_dim * 1.0 * sigma**2
        K += gamma / (gamma + exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)

def mix_imq_mmd2(X, Y, sigma_list, biased=True):
    """MMD^2 between ``X`` and ``Y`` under a mixture of IMQ kernels.

    :param X: 2-D tensor with m rows
    :param Y: 2-D tensor with m rows
    :param sigma_list: kernel scale parameters
    :param biased: forwarded to ``_mmd2`` to select the biased or unbiased estimator
    :return: scalar tensor with the MMD^2 estimate
    """
    kernel_blocks = _mix_imq_kernel(X, Y, sigma_list)
    k_xx, k_xy, k_yy = kernel_blocks[:3]
    # Diagonals are read from the kernel matrices (no constant-diagonal shortcut).
    return _mmd2(k_xx, k_xy, k_yy, const_diagonal=False, biased=biased)

def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
    """MMD^2 between ``X`` and ``Y`` under a mixture of Gaussian kernels.

    :param X: 2-D tensor with m rows
    :param Y: 2-D tensor with m rows
    :param sigma_list: kernel bandwidths
    :param biased: forwarded to ``_mmd2`` to select the biased or unbiased estimator
    :return: scalar tensor with the MMD^2 estimate
    """
    kernel_blocks = _mix_rbf_kernel(X, Y, sigma_list)
    k_xx, k_xy, k_yy = kernel_blocks[:3]
    # Diagonals are read from the kernel matrices (no constant-diagonal shortcut).
    return _mmd2(k_xx, k_xy, k_yy, const_diagonal=False, biased=biased)

def mix_rbf_mmd2_loss(X, Y, sigma_list, biased=True):
    """Square root of the (clamped) Gaussian-mixture MMD^2 between ``X`` and ``Y``.

    :param X: 2-D tensor with m rows
    :param Y: 2-D tensor with m rows
    :param sigma_list: kernel bandwidths
    :param biased: select the biased (default) or unbiased MMD^2 estimator.
                   Previously this argument was accepted but ignored and the
                   biased estimator was always used; the default keeps that
                   behaviour.
    :return: scalar tensor ``sqrt(relu(MMD^2))``
    """
    mmd_dist_ref = mix_rbf_mmd2(X, Y, sigma_list, biased=biased)
    # The estimate can be slightly negative; clamp at zero before the root.
    return torch.sqrt(F.relu(mmd_dist_ref))

def mix_rbf_mmd2_unbiased_loss(X, Y, sigma_list):
    """Square root of the absolute diagonal-free Gaussian-mixture MMD^2.

    :param X: 2-D tensor with m rows
    :param Y: 2-D tensor with m rows
    :param sigma_list: kernel bandwidths
    :return: scalar tensor ``sqrt(relu(|MMD^2|))`` where MMD^2 comes from
             ``_mmd2_ignore_diagonals`` with ``biased=False``
    """
    kernel_blocks = _mix_rbf_kernel(X, Y, sigma_list)
    k_xx, k_xy, k_yy = kernel_blocks[:3]
    raw_mmd2 = _mmd2_ignore_diagonals(k_xx, k_xy, k_yy, const_diagonal=False, biased=False)
    magnitude = torch.abs(raw_mmd2)
    return torch.sqrt(F.relu(magnitude))

def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
    """Gaussian-mixture MMD^2 together with its ratio to the estimated std.

    :param X: 2-D tensor with m rows
    :param Y: 2-D tensor with m rows
    :param sigma_list: kernel bandwidths
    :param biased: forwarded to ``_mmd2_and_ratio``
    :return: whatever ``_mmd2_and_ratio`` returns: ``(ratio, mmd2, var_est)``
    """
    kernel_blocks = _mix_rbf_kernel(X, Y, sigma_list)
    k_xx, k_xy, k_yy = kernel_blocks[:3]
    # Diagonals are read from the kernel matrices (no constant-diagonal shortcut).
    return _mmd2_and_ratio(k_xx, k_xy, k_yy, const_diagonal=False, biased=biased)


################################################################################
# Helper functions to compute variances based on kernel matrices
################################################################################


def _mmd2_ignore_diagonals(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = K_XX.size(0)    # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = diag_XY = const_diagonal
        sum_diag_X = sum_diag_Y = sum_diag_XY = m * const_diagonal
    else:
        diag_X = torch.diag(K_XX)            # (m,)
        diag_Y = torch.diag(K_YY)            # (m,)
        diag_XY = torch.diag(K_XY)           # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)
        sum_diag_XY = torch.sum(diag_XY)


    Kt_XX_sums = K_XX.sum(dim=1) - diag_X             # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y             # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0) - diag_XY                     # K_{XY}^T * e

    Kt_XX_sum = Kt_XX_sums.sum()                       # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()                       # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()                       # e^T * K_{XY} * e

    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * (K_XY_sum + sum_diag_XY) / (m * m))
    else:
        mmd2 = ((Kt_XX_sum ) / (m * (m-1))
            + (Kt_YY_sum) / (m * (m-1))
            - 2.0 * (K_XY_sum) / (m * (m-1)))

    return mmd2


def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    m = K_XX.size(0)    # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
    else:
        diag_X = torch.diag(K_XX)                       # (m,)
        diag_Y = torch.diag(K_YY)                       # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X             # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y             # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)                     # K_{XY}^T * e

    Kt_XX_sum = Kt_XX_sums.sum()                       # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()                       # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()                       # e^T * K_{XY} * e

    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m))
    else:
        mmd2 = (Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m))

    return mmd2


def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """MMD^2, its variance estimate, and their ratio ``MMD^2 / sqrt(variance)``.

    :param K_XX: m-by-m kernel matrix between X and X
    :param K_XY: m-by-m kernel matrix between X and Y
    :param K_YY: m-by-m kernel matrix between Y and Y
    :param const_diagonal: forwarded to ``_mmd2_and_variance``
    :param biased: forwarded to ``_mmd2_and_variance``
    :return: ``(ratio, mmd2, var_est)``; the variance is floored at
             ``min_var_est`` before the square root is taken
    """
    estimate, variance = _mmd2_and_variance(
        K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased)
    floored_variance = torch.clamp(variance, min=min_var_est)
    ratio = estimate / torch.sqrt(floored_variance)
    return ratio, estimate, variance


def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Compute an MMD^2 estimate together with a variance estimate.

    :param K_XX: m-by-m kernel matrix between X and X
    :param K_XY: m-by-m kernel matrix between X and Y
    :param K_YY: m-by-m kernel matrix between Y and Y
    :param const_diagonal: ``False`` to read the diagonals of ``K_XX``/``K_YY``
                           from the matrices, otherwise the value assumed for
                           every diagonal entry
    :param biased: selects the biased or unbiased form of ``mmd2`` only;
                   ``var_est`` is computed the same way in both cases
    :return: ``(mmd2, var_est)`` as scalar tensors

    The expressions divide by ``m - 1``, so at least two samples are needed.
    NOTE(review): the closed form of ``var_est`` is presumably the variance
    estimator of the unbiased MMD^2 statistic from the upstream reference
    implementation -- confirm against that source before modifying it.
    """
    m = K_XX.size(0)    # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
        sum_diag2_X = sum_diag2_Y = m * const_diagonal**2
    else:
        diag_X = torch.diag(K_XX)                       # (m,)
        diag_Y = torch.diag(K_YY)                       # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)
        sum_diag2_X = diag_X.dot(diag_X)
        sum_diag2_Y = diag_Y.dot(diag_Y)

    # Per-row / per-column sums (vectors of length m)
    Kt_XX_sums = K_XX.sum(dim=1) - diag_X             # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y             # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)                     # K_{XY}^T * e
    K_XY_sums_1 = K_XY.sum(dim=1)                     # K_{XY} * e

    # Grand totals (scalars)
    Kt_XX_sum = Kt_XX_sums.sum()                       # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()                       # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()                       # e^T * K_{XY} * e

    # Sums of squared entries (off-diagonal only for the within-sample blocks)
    Kt_XX_2_sum = (K_XX ** 2).sum() - sum_diag2_X      # \| \tilde{K}_XX \|_F^2
    Kt_YY_2_sum = (K_YY ** 2).sum() - sum_diag2_Y      # \| \tilde{K}_YY \|_F^2
    K_XY_2_sum  = (K_XY ** 2).sum()                    # \| K_{XY} \|_F^2

    # Same estimator as in _mmd2: the cross term is always divided by m*m
    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m))
    else:
        mmd2 = (Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m))

    # Variance estimate assembled from the sums above; independent of `biased`
    var_est = (
        2.0 / (m**2 * (m - 1.0)**2) * (2 * Kt_XX_sums.dot(Kt_XX_sums) - Kt_XX_2_sum + 2 * Kt_YY_sums.dot(Kt_YY_sums) - Kt_YY_2_sum)
        - (4.0*m - 6.0) / (m**3 * (m - 1.0)**3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 4.0*(m - 2.0) / (m**3 * (m - 1.0)**2) * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
        - 4.0*(m - 3.0) / (m**3 * (m - 1.0)**2) * (K_XY_2_sum) - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2
        + 8.0 / (m**3 * (m - 1.0)) * (
            1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
            - Kt_XX_sums.dot(K_XY_sums_1)
            - Kt_YY_sums.dot(K_XY_sums_0))
        )
    return mmd2, var_est

In [3]:
# code downloaded from https://github.com/msesia/deepknockoffs/tree/master/DeepKnockoffs

def covariance_diff_biased(X, Xk, SigmaHat, Mask, scale=1.0):
    """ Second-order loss: mismatch between target and empirical covariances
    :param X: input data, one sample per row
    :param Xk: generated knockoffs, same shape as X
    :param SigmaHat: target covariance matrix
    :param Mask: element-wise mask applied to SigmaHat - Cov(X,Xk)
                 (used to ignore its diagonal)
    :param scale: divisor applied to both terms of the loss
    :return: ||SigmaHat - Cov(Xk,Xk)||_F^2 / scale
             + ||Mask * (SigmaHat - Cov(X,Xk))||_F^2 / scale
    """
    n_samples = Xk.shape[0]

    # Column-centre both matrices
    X_centered = X - X.mean(0, keepdim=True)
    Xk_centered = Xk - Xk.mean(0, keepdim=True)

    # Biased (divide-by-n) empirical covariances
    cov_kk = torch.mm(Xk_centered.t(), Xk_centered) / n_samples
    cov_xk = torch.mm(X_centered.t(), Xk_centered) / n_samples

    knockoff_term = (SigmaHat - cov_kk).pow(2).sum() / scale
    cross_term = (Mask * (SigmaHat - cov_xk)).pow(2).sum() / scale
    return knockoff_term + cross_term

def create_checkpoint_name(pars):
    """ Build the network filename from the training hyper-parameters
    :param pars: dictionary of training hyper-parameters; the entries of the
                 'alphas' value are written out one by one, every other value
                 is written with str()
    :return: 'net' followed by '_<key>_<value>' for each entry, in dict order
    """
    pieces = ['net']
    for key, value in pars.items():
        pieces.append(key)
        if key == 'alphas':
            pieces.extend(str(alpha) for alpha in value)
        else:
            pieces.append(str(value))
    return '_'.join(pieces)

def save_checkpoint(state, filename):
    """ Saves the most up-to-date network to filename and keeps the previous
    machine in the filename + '_prev.pth.tar' file
    :param state: training state of the machine
    :param filename: filename to save the current machine
    """

    # Keep the previous model. os.replace overwrites an existing backup on
    # every platform, whereas os.rename raises on Windows once a '_prev'
    # file already exists (i.e. from the second rotation onwards).
    if os.path.isfile(filename):
        os.replace(filename, filename + '_prev.pth.tar')
    # save new model
    torch.save(state, filename)

def gen_batches(n_samples, batch_size, n_reps):
    """ Draw index batches from repeated random permutations of the samples.
    :param n_samples: number of samples to index (indices 0..n_samples-1)
    :param batch_size: size of each batch
    :param n_reps: number of independent permutations to draw
    :return: list of index arrays of length batch_size; for each permutation
             the leftover indices that do not fill a whole batch are dropped
    """
    # Number of leading indices of each permutation that fit into full batches
    usable = math.floor(n_samples/batch_size)*batch_size

    batches = []
    for _ in range(n_reps):
        shuffled = np.random.permutation(n_samples)
        batches.extend(
            shuffled[start:start + batch_size].copy()
            for start in range(0, usable, batch_size)
        )
    return batches

class Net(nn.Module):
    """ Deep knockoff network: a fully connected generator that maps the data
    and a noise seed to knockoff copies
    """
    def __init__(self, p, dim_h, family="continuous"):
        """ Constructor
        :param p: dimensions of data
        :param dim_h: width of the hidden layers
        :param family: data type, either "continuous" (three hidden blocks and
                       a linear output) or "binary" (six hidden blocks, then a
                       linear layer, a sigmoid and a batch-norm on the output)
        """
        super().__init__()

        self.p = p
        self.dim_h = dim_h

        # Depth and batch-norm epsilon are the only per-family differences in
        # the hidden stack (1e-05 is the BatchNorm1d default).
        if family == "continuous":
            n_hidden, bn_eps = 3, 1e-05
        elif family == "binary":
            n_hidden, bn_eps = 6, 1e-02
        else:
            sys.exit("Error: unknown family")

        # Hidden blocks: Linear (no bias) -> BatchNorm -> PReLU
        layers = []
        width_in = 2*self.p
        for _ in range(n_hidden):
            layers.extend([
                nn.Linear(width_in, self.dim_h, bias=False),
                nn.BatchNorm1d(self.dim_h, eps=bn_eps),
                nn.PReLU(),
            ])
            width_in = self.dim_h

        # Output head
        layers.append(nn.Linear(self.dim_h, self.p))
        if family == "binary":
            layers.extend([nn.Sigmoid(), nn.BatchNorm1d(self.p, eps=1e-02)])

        self.main = nn.Sequential(*layers)

    def forward(self, x, noise):
        """ Sample knockoff copies of the data
        :param x: input data
        :param noise: random noise seed, same shape as x
        :returns the constructed knockoffs
        """
        # Interleave the columns: x_0, noise_0, x_1, noise_1, ...
        interleaved = torch.stack((x, noise), dim=2).flatten(start_dim=1)
        return self.main(interleaved)

def norm(X, p=2):
    """ p-norm of a tensor
    :param X: input tensor
    :param p: order of the norm; np.inf selects the largest absolute entry,
              anything else is passed on to torch.norm
    :return: scalar tensor with the norm of X
    """
    if p != np.inf:
        return torch.norm(X, p)
    return torch.max(torch.abs(X))

class KnockoffMachine:
    """ Deep Knockoff machine

    Bundles the generator network with its loss function, training loop,
    checkpointing and knockoff sampling.
    """

    # Suffix appended to the user-supplied base name for the rolling checkpoint
    _CHECKPOINT_SUFFIX = "_checkpoint.pth.tar"

    def __init__(self, pars, checkpoint_name=None, logs_name=None):
        """ Constructor
        :param pars: dictionary containing the following keys
                'family': data type, either "continuous" or "binary"
                'p': dimensions of data
                'epochs': number of training epochs
                'epoch_length': number of iterations over the full data per epoch
                'batch_size': batch size
                'test_size': size of test set
                'lr': learning rate for main training loop
                'lr_milestones': when to decrease learning rate, unused when equals to number of epochs
                'dim_h': width of the network
                'target_corr': target correlation between variables and knockoffs
                'LAMBDA': penalty encouraging second-order knockoffs
                'DELTA': decorrelation penalty hyper-parameter
                'GAMMA': penalty for MMD distance
                'alphas': kernel widths for the MMD measure (uniform weights)
        :param checkpoint_name: location to save the machine (base name; the
                                checkpoint suffixes are appended here)
        :param logs_name: location to save the logfile
        """
        # architecture parameters
        self.p = pars['p']
        self.dim_h = pars['dim_h']
        self.family = pars['family']

        # optimization parameters
        self.epochs = pars['epochs']
        self.epoch_length = pars['epoch_length']
        self.batch_size = pars['batch_size']
        self.test_size = pars['test_size']
        self.lr = pars['lr']
        self.lr_milestones = pars['lr_milestones']

        # loss function parameters
        self.alphas = pars['alphas']
        self.target_corr = torch.from_numpy(pars['target_corr']).float()
        self.DELTA = pars['DELTA']
        self.GAMMA = pars['GAMMA']
        self.LAMBDA = pars['LAMBDA']

        # noise seed
        self.noise_std = 1.0
        self.dim_noise = self.p

        # higher-order discrepancy function
        self.matching_loss = mix_rbf_mmd2_loss
        self.matching_param = self.alphas

        # Normalize learning rate to avoid numerical issues
        self.lr = self.lr / np.max([self.DELTA, self.GAMMA, self.LAMBDA, 1.0])

        self.pars = pars
        if checkpoint_name is None:
            self.checkpoint_name = None
            self.best_checkpoint_name = None
        else:
            self.checkpoint_name = checkpoint_name + self._CHECKPOINT_SUFFIX
            self.best_checkpoint_name = checkpoint_name + "_best.pth.tar"

        self.logs_name = logs_name

        self.resume_epoch = 0

        # init the network
        self.net = Net(self.p, self.dim_h, family=self.family)

    def compute_diagnostics(self, X, Xk, noise, test=False):
        """ Evaluates the different components of the loss function
        :param X: input data
        :param Xk: knockoffs of X
        :param noise: allocated noise tensor; only its number of rows is used,
                      to select how many samples enter the loss evaluation
        :param test: label the diagnostics as train (False) or test (True)
        :return diagnostics: a dictionary containing the following keys:
                 'Data': "train" or "test"
                 'Mean' : mean squared difference between the means of X and Xk
                 'Corr-Diag': mean correlation between X and Xk
                 'Corr-Full': norm(Corr(Xk,Xk) - Corr(X,X)) / norm(Corr(X,X))
                 'Corr-Swap': norm(M*(Corr(Xk,X) - Corr(X,X))) / norm(Corr(X,X))
                              where M is a mask that excludes the diagonal
                 'Loss': the value of the loss function
                 'MMD-Full': discrepancy between (X',Xk') and (Xk'',X'')
                 'MMD-Swap': discrepancy between (X',Xk') and (X'',Xk'')_swap(s)
        """
        # Initialize dictionary of diagnostics
        diagnostics = dict()
        diagnostics["Data"] = "test" if test else "train"

        ##############################
        # Second-order moments
        ##############################

        # Difference in means
        D_mean = X.mean(0) - Xk.mean(0)
        D_mean = (D_mean*D_mean).mean()
        diagnostics["Mean"] = D_mean.data.cpu().item()

        # Center and scale X, Xk
        mX = X - torch.mean(X,0,keepdim=True)
        mXk = Xk - torch.mean(Xk,0,keepdim=True)
        scaleX  = (mX*mX).mean(0,keepdim=True)
        scaleXk = (mXk*mXk).mean(0,keepdim=True)

        # Correlation between X and Xk
        scaleX[scaleX==0] = 1.0   # Prevent division by 0
        scaleXk[scaleXk==0] = 1.0 # Prevent division by 0
        mXs  = mX  / torch.sqrt(scaleX)
        mXks = mXk / torch.sqrt(scaleXk)
        corr = (mXs*mXks).mean()
        diagnostics["Corr-Diag"] = corr.data.cpu().item()

        # Cov(Xk,Xk)
        Sigma = torch.mm(torch.t(mXs),mXs)/mXs.shape[0]
        Sigma_ko = torch.mm(torch.t(mXks),mXks)/mXk.shape[0]
        DK_2 = norm(Sigma_ko-Sigma) / norm(Sigma)
        diagnostics["Corr-Full"] = DK_2.data.cpu().item()

        # Cov(Xk,X) excluding the diagonal elements
        SigIntra_est = torch.mm(torch.t(mXks),mXs)/mXk.shape[0]
        DS_2 = norm(self.Mask*(SigIntra_est-Sigma)) / norm(Sigma)
        diagnostics["Corr-Swap"] = DS_2.data.cpu().item()

        ##############################
        # Loss function
        ##############################
        _, loss_display, mmd_full, mmd_swap = self.loss(X[:noise.shape[0]], Xk[:noise.shape[0]], test=True)
        diagnostics["Loss"]  = loss_display.data.cpu().item()
        diagnostics["MMD-Full"] = mmd_full.data.cpu().item()
        diagnostics["MMD-Swap"] = mmd_swap.data.cpu().item()

        # Return dictionary of diagnostics
        return diagnostics

    def loss(self, X, Xk, test=False):
        """ Evaluates the loss function
        :param X: input data
        :param Xk: knockoffs of X
        :param test: evaluate the MMD, regardless the value of GAMMA
        :return loss: the value of the effective loss function
                loss_display: a copy of the loss variable that will be used for display
                mmd_full: discrepancy between (X',Xk') and (Xk'',X'')
                mmd_swap: discrepancy between (X',Xk') and (X'',Xk'')_swap(s)
        """

        # Divide the observations into two disjoint batches
        n = X.shape[0] // 2
        X1,Xk1 = X[:n], Xk[:n]
        X2,Xk2 = X[n:(2*n)], Xk[n:(2*n)]

        # Joint variables
        Z1 = torch.cat((X1,Xk1),1)
        Z2 = torch.cat((Xk2,X2),1)
        Z3 = torch.cat((X2,Xk2),1).clone()
        # Swap a random subset of variables with their knockoffs
        swap_inds = np.where(np.random.binomial(1,0.5,size=self.p))[0]
        Z3[:,swap_inds] = Xk2[:,swap_inds]
        Z3[:,swap_inds+self.p] = X2[:,swap_inds]

        # Compute the discrepancy between (X,Xk) and (Xk,X)
        mmd_full = 0.0
        # Compute the discrepancy between (X,Xk) and (X,Xk)_s
        mmd_swap = 0.0
        if self.GAMMA > 0 or test:
            # Evaluate the MMD by following section 4.3 in
            # Li et al. "Generative Moment Matching Networks". Link to
            # the manuscript -- https://arxiv.org/pdf/1502.02761.pdf
            mmd_full = self.matching_loss(Z1, Z2, self.matching_param)
            mmd_swap = self.matching_loss(Z1, Z3, self.matching_param)

        # Match first two moments
        loss_moments = 0.0
        if self.LAMBDA > 0:
            # First moment
            D_mean = X.mean(0) - Xk.mean(0)
            loss_1m = D_mean.pow(2).sum()
            # Second moments
            loss_2m = covariance_diff_biased(X, Xk, self.SigmaHat, self.Mask, scale=self.Sigma_norm)
            # Combine moments
            loss_moments = loss_1m + loss_2m

        # Penalize correlations between variables and knockoffs
        loss_corr = 0.0
        if self.DELTA > 0:
            # Center X and Xk
            mX  = X  - torch.mean(X,0,keepdim=True)
            mXk = Xk - torch.mean(Xk,0,keepdim=True)
            # Correlation between X and Xk
            eps = 1e-3
            scaleX  = mX.pow(2).mean(0,keepdim=True)
            scaleXk = mXk.pow(2).mean(0,keepdim=True)
            mXs  = mX / (eps+torch.sqrt(scaleX))
            mXks = mXk / (eps+torch.sqrt(scaleXk))
            corr_XXk = (mXs*mXks).mean(0)
            loss_corr = (corr_XXk-self.target_corr).pow(2).mean()

        # Combine the loss functions
        loss = self.GAMMA*mmd_full + self.GAMMA*mmd_swap + self.LAMBDA*loss_moments + self.DELTA*loss_corr
        loss_display = loss
        return loss, loss_display, mmd_full, mmd_swap

    def _checkpoint_state(self, epoch):
        """ Snapshot of everything needed to resume training after `epoch`
        :param epoch: zero-based index of the epoch that has just finished
        :return: dictionary passed to save_checkpoint
        """
        return {
            'epochs': epoch+1,
            'pars'  : self.pars,
            'state_dict': self.net.state_dict(),
            'optimizer' : self.net_optim.state_dict(),
            'scheduler' : self.net_sched.state_dict(),
        }

    def train(self, X_in, resume = False):
        """ Fit the machine to the training data
        :param X_in: input data (numpy array, one sample per row); the first
                     test_size rows are held out for test diagnostics
        :param resume: proceed the training by loading the last checkpoint
        """

        # Divide data into training/test set
        X = torch.from_numpy(X_in[self.test_size:]).float()
        if self.test_size > 0:
            X_test = torch.from_numpy(X_in[:self.test_size]).float()
        else:
            X_test = torch.zeros(0, self.p)

        # used to compute statistics and diagnostics
        self.SigmaHat = np.cov(X.numpy(),rowvar=False)
        self.SigmaHat = torch.from_numpy(self.SigmaHat).float()
        self.Mask = torch.ones(self.p, self.p) - torch.eye(self.p)

        # allocate a matrix for the noise realization
        noise = torch.zeros(self.batch_size,self.dim_noise)
        noise_test = torch.zeros(X_test.shape[0],self.dim_noise)
        use_cuda = torch.cuda.is_available()

        if resume:  # load the last checkpoint
            # self.checkpoint_name already carries the checkpoint suffix;
            # load() accepts both the base name and the full file name.
            self.load(self.checkpoint_name)
            self.net.train()
        else:  # start learning from scratch
            self.net.train()
            # Define the optimization method
            self.net_optim = optim.SGD(self.net.parameters(), lr = self.lr, momentum=0.9)
            # Define the scheduler
            self.net_sched = optim.lr_scheduler.MultiStepLR(self.net_optim, gamma=0.1,
                                                            milestones=self.lr_milestones)

        # bandwidth parameters of the Gaussian kernel
        self.matching_param = self.alphas

        # move data to GPU if available
        if use_cuda:
            self.SigmaHat = self.SigmaHat.cuda()
            self.Mask = self.Mask.cuda()
            self.net = self.net.cuda()
            X = X.cuda()
            X_test = X_test.cuda()
            noise = noise.cuda()
            noise_test = noise_test.cuda()
            self.target_corr = self.target_corr.cuda()

        Xk = 0*X
        self.Sigma_norm = self.SigmaHat.pow(2).sum()
        self.Sigma_norm_cross = (self.Mask*self.SigmaHat).pow(2).sum()

        # Store diagnostics
        diagnostics = pd.DataFrame()
        losses_test = []

        # main training loop
        for epoch in range(self.resume_epoch, self.epochs):
            # prepare for training phase
            self.net.train()
            # divide the data into batches
            batches = gen_batches(X.size(0), self.batch_size, self.epoch_length)

            losses = []
            losses_dist_swap = []
            losses_dist_full = []

            for batch in batches:
                # Extract data for this batch
                X_batch  = X[batch,:]

                self.net_optim.zero_grad()

                # Run the network
                Xk_batch = self.net(X_batch, self.noise_std*noise.normal_())

                # Compute the loss function
                loss, loss_display, mmd_full, mmd_swap = self.loss(X_batch, Xk_batch)

                # Compute the gradient
                loss.backward()

                # Take a gradient step
                self.net_optim.step()

                # Save history
                losses.append(loss_display.data.cpu().item())
                if self.GAMMA > 0:
                    losses_dist_swap.append(mmd_swap.data.cpu().item())
                    losses_dist_full.append(mmd_full.data.cpu().item())

                # Save the knockoffs
                Xk[batch, :] = Xk_batch.data

            # Advance the learning-rate schedule once the epoch's optimizer
            # steps are done. Since PyTorch 1.1 stepping the scheduler before
            # the optimizer skips the first value of the schedule.
            self.net_sched.step()

            ##############################
            # Compute diagnostics
            ##############################

            # Prepare for testing phase
            self.net.eval()

            # Evaluate the diagnostics on the training data, the following
            # function recomputes the loss on the training data
            diagnostics_train = self.compute_diagnostics(X, Xk, noise, test=False)
            diagnostics_train["Loss"] = np.mean(losses)
            if self.GAMMA > 0:
                diagnostics_train["MMD-Full"] = np.mean(losses_dist_full)
                diagnostics_train["MMD-Swap"] = np.mean(losses_dist_swap)
            diagnostics_train["Epoch"] = epoch
            diagnostics = pd.concat([diagnostics, pd.DataFrame([diagnostics_train])], ignore_index=True)

            # Evaluate the diagnostics on the test data if available
            if self.test_size > 0:
                # Diagnostics only: no gradients are needed for the test pass
                with torch.no_grad():
                    Xk_test = self.net(X_test, self.noise_std*noise_test.normal_())
                    diagnostics_test = self.compute_diagnostics(X_test, Xk_test, noise_test, test=True)
            else:
                diagnostics_test = {key:np.nan for key in diagnostics_train}
            diagnostics_test["Epoch"] = epoch
            diagnostics = pd.concat([diagnostics, pd.DataFrame([diagnostics_test])], ignore_index=True)

            # If the test loss is at a minimum, save the machine to
            # the location pointed by best_checkpoint_name
            losses_test.append(diagnostics_test["Loss"])
            best_machine = ((self.test_size > 0)
                            and (diagnostics_test["Loss"] == np.min(losses_test))
                            and (self.best_checkpoint_name is not None))
            if best_machine:
                save_checkpoint(self._checkpoint_state(epoch), self.best_checkpoint_name)

            ##############################
            # Print progress
            ##############################
            if self.test_size > 0:
                print("[%4d/%4d], Loss: (%.4f, %.4f)" %
                      (epoch + 1, self.epochs, diagnostics_train["Loss"], diagnostics_test["Loss"]), end=", ")
                print("MMD: (%.4f,%.4f)" %
                      (diagnostics_train["MMD-Full"]+diagnostics_train["MMD-Swap"],
                       diagnostics_test["MMD-Full"]+diagnostics_test["MMD-Swap"]), end=", ")
                print("Cov: (%.3f,%.3f)" %
                      (diagnostics_train["Corr-Full"]+diagnostics_train["Corr-Swap"],
                       diagnostics_test["Corr-Full"]+diagnostics_test["Corr-Swap"]), end=", ")
                print("Decorr: (%.3f,%.3f)" %
                      (diagnostics_train["Corr-Diag"], diagnostics_test["Corr-Diag"]), end="")
                if best_machine:
                    print(" *", end="")
            else:
                print("[%4d/%4d], Loss: %.4f" %
                      (epoch + 1, self.epochs, diagnostics_train["Loss"]), end=", ")
                print("MMD: %.4f" %
                      (diagnostics_train["MMD-Full"] + diagnostics_train["MMD-Swap"]), end=", ")
                print("Cov: %.3f" %
                      (diagnostics_train["Corr-Full"] + diagnostics_train["Corr-Swap"]), end=", ")
                print("Decorr: %.3f" %
                      (diagnostics_train["Corr-Diag"]), end="")

            print("")
            sys.stdout.flush()

            # Save diagnostics to logfile
            if self.logs_name is not None:
                diagnostics.to_csv(self.logs_name, sep=" ", index=False)

            # Save the current machine to location checkpoint_name
            if self.checkpoint_name is not None:
                save_checkpoint(self._checkpoint_state(epoch), self.checkpoint_name)

    @staticmethod
    def _read_checkpoint(path):
        """ Deserialize a checkpoint file onto the CPU
        :param path: checkpoint file written by save_checkpoint
        :return: the stored state dictionary
        """
        # NOTE: checkpoints are pickles that also hold the `pars` dictionary
        # (including a numpy array), so they cannot be read with
        # weights_only=True, the default since PyTorch 2.6. Only load
        # checkpoint files that you created yourself.
        try:
            return torch.load(path, map_location='cpu', weights_only=False)
        except TypeError:
            # Older PyTorch releases do not know the weights_only argument
            return torch.load(path, map_location='cpu')

    def _new_optimizer_and_scheduler(self):
        """ (Re)create the SGD optimizer and the multi-step LR scheduler """
        self.net_optim = optim.SGD(self.net.parameters(), lr = self.lr, momentum=0.9)
        self.net_sched = optim.lr_scheduler.MultiStepLR(self.net_optim, gamma=0.1,
                                                        milestones=self.lr_milestones)

    def load(self, checkpoint_name):
        """ Load a machine from a stored checkpoint
        :param checkpoint_name: checkpoint name of a trained machine; either the
                                base name given to the constructor or the full
                                file name ending in "_checkpoint.pth.tar".
                                If nothing can be loaded (including None), the
                                machine is reset to train from scratch.
        """
        # Accept an already-suffixed name so that train(resume=True), which
        # passes self.checkpoint_name, finds the file it wrote.
        if checkpoint_name is None:
            filename = None
        elif checkpoint_name.endswith(self._CHECKPOINT_SUFFIX):
            filename = checkpoint_name
        else:
            filename = checkpoint_name + self._CHECKPOINT_SUFFIX

        checkpoint = None
        if filename is not None and os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            sys.stdout.flush()
            try:
                checkpoint = self._read_checkpoint(filename)
            except Exception:
                print("error loading saved model, trying the previous version")
                sys.stdout.flush()
                try:
                    checkpoint = self._read_checkpoint(filename + '_prev.pth.tar')
                except Exception:
                    print("error loading prev model, starting from scratch")
                    sys.stdout.flush()
        else:
            print("=> no checkpoint found at '{}'".format(filename))
            sys.stdout.flush()

        if checkpoint is None:
            # Nothing usable on disk: start from scratch
            self.net.train()
            self._new_optimizer_and_scheduler()
            self.resume_epoch = 0
            return

        self.net.load_state_dict(checkpoint['state_dict'])
        if torch.cuda.is_available():
            self.net = self.net.cuda()

        self._new_optimizer_and_scheduler()
        self.net_optim.load_state_dict(checkpoint['optimizer'])
        # Restore the position in the LR schedule as well; without this the
        # milestones would be counted again from zero after resuming.
        if 'scheduler' in checkpoint:
            self.net_sched.load_state_dict(checkpoint['scheduler'])
        self.resume_epoch = checkpoint['epochs']

        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epochs']))
        sys.stdout.flush()

    def generate(self, X_in):
        """ Generate knockoff copies
        :param X_in: data samples (numpy array, one sample per row)
        :return Xk: knockoff copy per each sample in X (numpy array)
        """

        X = torch.from_numpy(X_in).float()
        self.net = self.net.cpu()
        self.net.eval()

        # Run the network in evaluation mode; no gradients are needed
        with torch.no_grad():
            Xk = self.net(X, self.noise_std*torch.randn(X.size(0),self.dim_noise))
        Xk = Xk.data.cpu().numpy()

        return Xk

### KNN Knockoffs

In [4]:
def construct_knockoffs_knn(X: pd.DataFrame,
                            types: List[str],
                            topk_features: int = 10,
                            nbrs_per_sample: int = 5,
                            add_noise_scale: float = 1e-6,
                            random_state: int = 0) -> pd.DataFrame:
    """Build nearest-neighbour knockoffs, one feature at a time.

    For every feature j the samples are embedded in the sub-space spanned by
    the ``topk_features`` other features most correlated (in absolute Pearson
    correlation) with j, each axis scaled by the square root of its normalized
    correlation.  The knockoff value of sample i is then drawn from the values
    feature j takes on i's nearest neighbours in that sub-space:

    * binary feature: Bernoulli draw with the inverse-distance weighted
      neighbour mean as success probability;
    * otherwise: one neighbour value picked with probability proportional to
      1/distance, plus Gaussian jitter of scale ``add_noise_scale * std_j``.

    Parameters
    ----------
    X : pd.DataFrame
        Numeric data, one row per sample.
    types : list of str
        One entry per column; ``"bin"`` marks a binary column.  Columns that
        ``is_binary_col`` recognises as binary are treated as binary too.
    topk_features : int
        Number of most-correlated other features defining the neighbourhood.
    nbrs_per_sample : int
        Number of neighbours (excluding the sample itself) to draw from.
    add_noise_scale : float
        Jitter scale for non-binary features, relative to the column std.
    random_state : int
        Seed passed to ``np.random.seed`` (global NumPy RNG is reseeded).

    Returns
    -------
    pd.DataFrame
        Same index as ``X``; columns are named ``"<col>_knock"``.
    """
    np.random.seed(random_state)
    # Work in float so the jitter is not truncated when X holds integers
    X_np = X.to_numpy(dtype=float, copy=True)
    n, p = X_np.shape
    X_tilde = np.zeros_like(X_np)

    # compute Pearson correlation (handle constant columns)
    corr = np.corrcoef(np.nan_to_num(X_np.T))
    # nan -> 0; atleast_2d keeps row indexing valid when there is one column
    corr = np.atleast_2d(np.nan_to_num(corr))

    for j in range(p):
        col_j = X_np[:, j]
        spread = np.nanstd(col_j)
        std_j = spread if spread > 0 else 1.0
        # Decided once per column instead of once per sample
        binary_target = types[j] == "bin" or is_binary_col(col_j)

        # get absolute correlations with other features, exclude self
        abs_corrs = np.abs(corr[j, :])
        abs_corrs[j] = -np.inf
        # indices of top-k correlated features; when p <= topk_features the
        # slice would otherwise still contain j itself (with weight -inf)
        topk_idx = np.argsort(abs_corrs)[::-1][:topk_features]
        topk_idx = topk_idx[topk_idx != j]

        if len(topk_idx) == 0:
            # degenerate: no other feature available, fall back to j itself
            subspace = col_j.reshape(-1, 1)
            weights = np.array([1.0])
        else:
            # weights for distances: proportional to correlation strength
            weights = abs_corrs[topk_idx].astype(float)
            if np.sum(weights) <= 0:
                # fallback: equal weights
                weights = np.ones_like(weights)
            weights = weights / np.max(weights)  # normalize so max is 1
            subspace = X_np[:, topk_idx]  # shape (n, k)

        # weighted distance d^2 = sum_k w_k * (x_ik - x_jk)^2, obtained by
        # scaling columns by sqrt(weights) and using Euclidean distance
        scaled_subspace = subspace * np.sqrt(weights)

        nbrs = NearestNeighbors(n_neighbors=min(n, nbrs_per_sample + 1),
                                algorithm='auto').fit(scaled_subspace)
        distances, indices = nbrs.kneighbors(scaled_subspace, return_distance=True)

        for i in range(n):
            # Drop the sample itself wherever it appears and keep each
            # remaining neighbour paired with its own distance
            not_self = indices[i, :] != i
            neighs = indices[i, not_self][:nbrs_per_sample]
            neigh_dist = distances[i, not_self][:nbrs_per_sample]

            if len(neighs) == 0:
                # No neighbour to borrow from: keep the value (binary) or
                # jitter it (continuous)
                if binary_target:
                    X_tilde[i, j] = col_j[i]
                else:
                    X_tilde[i, j] = col_j[i] + np.random.normal(scale=add_noise_scale * std_j)
                continue

            neigh_vals = col_j[neighs]
            # inverse-distance weights (avoid division by zero)
            inv_dist = 1.0 / np.maximum(neigh_dist, 1e-8)

            if binary_target:
                # neighbors may have NaNs: drop them together with their weights
                valid = ~np.isnan(neigh_vals)
                if not valid.any():
                    p_hat = 0.5
                else:
                    p_hat = np.average(neigh_vals[valid], weights=inv_dist[valid])
                X_tilde[i, j] = np.random.binomial(1, p_hat)
            else:
                # continuous: sample a neighbor's value with probability
                # proportional to 1/distance, then add small Gaussian jitter
                probs = inv_dist / inv_dist.sum()
                chosen_idx = np.random.choice(len(neigh_vals), p=probs)
                X_tilde[i, j] = neigh_vals[chosen_idx] + np.random.normal(scale=add_noise_scale * std_j)

    X_tilde_df = pd.DataFrame(X_tilde, columns=[f"{c}_knock" for c in X.columns], index=X.index)
    return X_tilde_df

### Gaussian Knockoff

In [5]:
def gaussian_knockoff(X):
    """Draw Gaussian samples matching the column means and covariance of X.

    One row is drawn per row of ``X`` from a multivariate normal whose mean
    and covariance are the sample mean and sample covariance of ``X``.  The
    draws use only those two summaries, not the individual rows, and consume
    the global NumPy RNG.

    :param X: 2-D numeric array-like (samples x features)
    :return: ``ndarray`` of shape ``X.shape``
    """
    X_arr = np.asarray(X, dtype=float)
    mean = X_arr.mean(axis=0)
    # np.cov collapses to a 0-d array for a single column; restore (1, 1)
    cov = np.atleast_2d(np.cov(X_arr, rowvar=False))
    return np.random.multivariate_normal(mean, cov, size=X_arr.shape[0])

### Defining Binary vs. Continuous Columns

In [6]:
def is_binary_col(col: np.ndarray, tol=1e-8):
    """Return True when every non-NaN entry of ``col`` is 0 or 1.

    NaN entries are ignored, so a column with no non-NaN entries also counts
    as binary.  ``tol`` is accepted but not used.
    """
    observed = col[np.logical_not(np.isnan(col))]
    distinct = np.unique(observed)
    if len(distinct) > 2:
        return False
    return all(value in (0, 1) for value in distinct)

## Permutation Functions for Obtaining P Values

### Method with Knockoffs

In [18]:
def permutation_pvalues(
    X,
    X_knockoffs,
    y,
    task="classification",
    ntree=500,
    mtry=None,
    B=50,
    random_state=123,
):
    """Permutation p-values for the knockoff importance gap of each feature.

    A random forest is fitted on the originals and their knockoffs side by
    side; the statistic of feature j is its impurity importance minus that of
    its knockoff.  The null distribution comes from ``B`` refits on randomly
    permuted ``y``.  The p-value is one-sided in the direction of the observed
    sign: upper tail when the observed gap is >= 0, lower tail otherwise,
    with the usual +1 correction.

    :param X: original features (samples x p), anything ``pd.DataFrame`` accepts
    :param X_knockoffs: knockoffs, same shape and column order as ``X``
    :param y: response
    :param task: ``"classification"`` selects a classifier, anything else a regressor
    :param ntree: number of trees
    :param mtry: ``max_features`` for the forest (``"sqrt"`` when None)
    :param B: number of label permutations
    :param random_state: seed for the global NumPy RNG and for every forest
    :return: DataFrame with columns ``variable``, ``importance_diff``, ``p_value``
    """
    np.random.seed(random_state)
    X = pd.DataFrame(X).reset_index(drop=True)
    X_knockoffs = pd.DataFrame(X_knockoffs).reset_index(drop=True)
    p = X.shape[1]

    # Originals first, knockoffs second, so importances split at position p
    orig_names = [f"{col}_orig" for col in X.columns]
    knock_names = [f"{col}_knockoff" for col in X.columns]
    combined_data = pd.concat([X, X_knockoffs], axis=1)
    combined_data.columns = orig_names + knock_names

    if task == "classification":
        Model = RandomForestClassifier
    else:
        Model = RandomForestRegressor
    model_kwargs = {
        "n_estimators": ntree,
        "random_state": random_state,
        "n_jobs": -1,
        "max_features": "sqrt" if mtry is None else mtry,
    }

    def importance_gap(targets):
        # Fit a fresh forest and return original-minus-knockoff importances
        forest = Model(**model_kwargs)
        forest.fit(combined_data, targets)
        scores = forest.feature_importances_
        return scores[:p] - scores[p:]

    W_obs = importance_gap(y)

    report_every = max(1, B // 10)
    W_perm = np.zeros((p, B))
    for b in range(B):
        W_perm[:, b] = importance_gap(np.random.permutation(y))
        if (b + 1) % report_every == 0:
            print(f"Permutation {b + 1}/{B} done")

    # Tail counts in the direction of the observed sign
    at_least = (W_perm >= W_obs[:, None]).sum(axis=1)
    at_most = (W_perm <= W_obs[:, None]).sum(axis=1)
    pvals = (np.where(W_obs >= 0, at_least, at_most) + 1) / (B + 1)

    return pd.DataFrame({
        "variable": X.columns,
        "importance_diff": W_obs,
        "p_value": pvals
    })


### Method without Knockoffs

In [8]:
def permutation_mda_pvalues(
    X,
    y,
    task="classification",
    ntree=500,
    mtry=None,
    B=50,
    random_state=123
):
    """Permutation p-values for mean-decrease-in-accuracy (MDA) importances.

    A random forest is fitted on ``X``; the importance of feature j is the
    drop in score (accuracy for classification, R^2 otherwise) on the same
    data after the rows of column j are permuted.  The null distribution is
    obtained by refitting the forest on ``B`` random permutations of ``y`` and
    recomputing the importances.  The p-value is one-sided in the direction
    of the observed sign, with the usual +1 correction.

    :param X: features (samples x p), anything ``pd.DataFrame`` accepts
    :param y: response
    :param task: ``"classification"`` or anything else for regression
    :param ntree: number of trees
    :param mtry: ``max_features`` for the forest (``"sqrt"`` when None)
    :param B: number of label permutations
    :param random_state: seed for the global NumPy RNG, the forest and the
                         row permutation used to scramble a column
    :return: DataFrame with columns ``variable``, ``mda_importance``, ``p_value``
    """
    np.random.seed(random_state)
    X = pd.DataFrame(X).reset_index(drop=True)
    n, p = X.shape

    Model = RandomForestClassifier if task == "classification" else RandomForestRegressor
    model_kwargs = dict(
        n_estimators=ntree,
        max_features=mtry if mtry is not None else "sqrt",
        random_state=random_state,
        n_jobs=-1
    )

    rf = Model(**model_kwargs)
    rf.fit(X, y)

    if task == "classification":
        from sklearn.metrics import accuracy_score
        score_fn = accuracy_score
    else:
        from sklearn.metrics import r2_score
        score_fn = r2_score
    baseline_score = score_fn(y, rf.predict(X))

    # One fixed row order, reused for every feature and every refit
    # (same permutation the seeded ``shuffle`` call produced before)
    row_order = shuffle(np.arange(n), random_state=random_state)

    def mda_feature(j, Xdata, y_true, baseline):
        # Score drop after scrambling column j.  The permuted values are
        # assigned as a plain array: a Series on the right-hand side can be
        # re-aligned on its index by pandas, which would undo the shuffle.
        X_permuted = Xdata.copy()
        X_permuted.iloc[:, j] = Xdata.iloc[row_order, j].to_numpy()
        score = score_fn(y_true, rf.predict(X_permuted))
        return baseline - score

    # observed MDA importances
    obs_imp = np.array([mda_feature(j, X, y, baseline_score) for j in range(p)])

    # permutation nulls (``rf`` is refitted in place; mda_feature sees the refit)
    null_imp = np.zeros((p, B))
    for b in range(B):
        y_perm = np.random.permutation(y)
        rf.fit(X, y_perm)
        baseline_perm = score_fn(y_perm, rf.predict(X))
        null_imp[:, b] = np.array([mda_feature(j, X, y_perm, baseline_perm) for j in range(p)])
        if (b+1) % 5 == 0 or b == 0:
            print(f"Permutation {b+1}/{B} done")

    # empirical p-values, tail chosen by the sign of the observed importance
    at_least = (null_imp >= obs_imp[:, None]).sum(axis=1)
    at_most = (null_imp <= obs_imp[:, None]).sum(axis=1)
    pvals = (np.where(obs_imp >= 0, at_least, at_most) + 1) / (B + 1)

    return pd.DataFrame({
        "variable": X.columns,
        "mda_importance": obs_imp,
        "p_value": pvals
    })


## Simulation 1

### Generate Data

In [None]:
# Simulation 1: eleven independently drawn predictors with mixed marginals and
# a binary outcome generated through a logistic link.
import numpy as np
import pandas as pd
from scipy.special import expit  # sigmoid

np.random.seed(47965636)
n = 1000  # number of samples

# --- 3 non-normal variables ---
X1 = np.random.uniform(-2, 2, n)       # uniform
X2 = np.random.exponential(1, n)       # exponential
X3 = np.random.binomial(1, 0.4, n)     # binary

# --- 4 standard normal variables plus one standardized count (X6) ---
X4 = np.random.normal(0, 1, n)
X5 = np.random.normal(0, 1, n)
X6 = np.clip(np.random.poisson(4, n), 0, 5)   # Poisson(4) counts capped at 5
X6 = (X6 - np.mean(X6)) / np.std(X6)          # standardized to mean 0, sd 1
X7 = np.random.normal(0, 1, n)
X8 = np.random.normal(0, 1, n)

# --- heavy-tailed variables ---
X9  = np.random.gamma(shape=2.0, scale=1.5, size=n)      # right-skewed (Gamma)
X10 = np.random.lognormal(mean=0.0, sigma=1.0, size=n)   # strong right skew (Log-normal)
X11 = np.random.pareto(a=2.5, size=n) + 1                # heavy tail (Pareto)

# --- Combine into a matrix ---
X = np.column_stack([X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11])

# --- Define base effects (linear ones) ---
beta = np.array([0.8, 0.8, 0.0, 0.8, 0.8, 0.8, 0.0, 0.0, 0.8, 0.8, 0.0]) # 1,2,4,5,6,9,10 all have effects
#beta = np.array([0.3, 0.3, 0.0, 0.3, 0.3, 0.3, 0.0, 0.0, 0.3, 0.3, 0.0]) # 1,2,4,5,6,9,10 all have effects

# --- Add nonlinear transformations ---
# X1 has sinusoidal effect
# X9 has log effect (heavy-tailed variable with diminishing effect)
# NOTE(review): beta[0] and beta[8] are 0.8, so X1 and X9 enter the predictor
# twice (nonlinear term + linear term in X @ beta) -- confirm this is intended.
lin_pred = (
    1.5 * np.sin(np.pi * X1 / 2) +           # nonlinear in X1
    0.8 * np.log1p(X9) +                     # nonlinear in X9
    X @ beta                                 # linear part, all columns
)

# --- Generate binary outcome via logistic link ---
p = expit(lin_pred)   # here p holds success probabilities, not a feature count
Y = np.random.binomial(1, p)

# --- Put into DataFrame ---
cols = [f"X{i+1}" for i in range(11)]
df = pd.DataFrame(X, columns=cols)
df["Y"] = Y

print(df.head())

# --- Determine variable types automatically ---
# NOTE: this rebinds the module-level is_binary_col defined earlier; unlike
# that version it does not drop NaN before checking.
def is_binary_col(col):
    return np.isin(np.unique(col), [0, 1]).all()

types = ["bin" if is_binary_col(df[c].values) else "cont" for c in cols]

# Expected logistic slope
avg_slope = np.mean(p * (1 - p))

# Compute variance contribution of each term
var_contrib = np.zeros(X.shape[1])

# NOTE(review): for X1 (j == 0) and X9 (j == 8) only the nonlinear term is
# counted here, not the additional 0.8 * X term present in lin_pred.
for j in range(X.shape[1]):
    if j == 0:
        term = 1.5 * np.sin(np.pi * X1 / 2)
    elif j == 8:
        term = 0.8 * np.log1p(X9)
    else:
        term = X[:, j] * beta[j]
    var_contrib[j] = np.var(term) * avg_slope

df_contrib = pd.DataFrame({
    "variable": [f"X{i+1}" for i in range(11)],
    "var_contrib": var_contrib
})

# Largest contribution first
df_contrib = df_contrib.sort_values("var_contrib", ascending=False).reset_index(drop=True)
print(df_contrib)
# From here on X is a labelled DataFrame (the knockoff builders read X.columns)
X = pd.DataFrame(X, columns=cols)


### KNN Knockoff P Values

In [None]:
# KNN knockoffs for the simulated predictors, drawing from 3 neighbours per sample
knockoff_knn = construct_knockoffs_knn(X, types,nbrs_per_sample=3)

In [None]:
# Permutation p-values for the original-minus-knockoff importance gap (300 label permutations)
perm_knn = permutation_pvalues(X, knockoff_knn, Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the KNN-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_knn['p_value'], method='fdr_bh')

# store the adjusted p-values next to the raw ones
perm_knn['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values (smallest first)
perm_knn_sorted = perm_knn.sort_values('p_value_bh').reset_index(drop=True)

perm_knn_sorted.head(20)

### Gaussian Knockoff P Values

In [None]:
# Gaussian draws matching the sample mean and covariance of X, one per row
knockoff_gauss = gaussian_knockoff(X)

In [None]:
# Permutation p-values using the Gaussian knockoffs (300 label permutations)
perm_gauss = permutation_pvalues(X, knockoff_gauss, Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the Gaussian-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_gauss['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_gauss['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_gauss_sorted = perm_gauss.sort_values('p_value_bh').reset_index(drop=True)

perm_gauss_sorted.head(20)

### Deep Knockoff P Values

In [None]:
# Deep-knockoff setup: map every column to normal scores, fit second-order
# Gaussian knockoffs on the transformed data to get target correlations, then
# configure the knockoff machine.
p = X.shape[1]   # from here on p is the number of features
qt = QuantileTransformer(output_distribution='normal')
X_scaled = qt.fit_transform(X)
X_scaled = np.array(X_scaled)
SigmaHat = np.cov(X_scaled, rowvar=False)
second_order = GaussianKnockoffs(SigmaHat, mu=np.mean(X_scaled,0), method="sdp")
# Per-feature target: 1 - diag(Ds) / diag(SigmaHat)
corr_g = (np.diag(SigmaHat) - np.diag(second_order.Ds)) / np.diag(SigmaHat)

# Set the parameters for training deep knockoffs
pars = dict()
# Number of epochs
pars['epochs'] = 100
# Number of iterations over the full data per epoch
pars['epoch_length'] = 50
# Data type, either "continuous" or "binary"
pars['family'] = "continuous"
# Dimensions of the data
pars['p'] = p
# Size of the test set (n is the sample size set in the data-generation cell)
pars['test_size']  = int(0.1*n)
# Batch size
pars['batch_size'] = int(0.45*n)
# Learning rate
pars['lr'] = 0.003
# When to decrease learning rate (unused when equal to number of epochs)
pars['lr_milestones'] = [pars['epochs']]
# Width of the network (number of layers is fixed to 6)
pars['dim_h'] = int(10*p)
# Penalty for the MMD distance
pars['GAMMA'] = 0.5
# Penalty encouraging second-order knockoffs
pars['LAMBDA'] = 0.5
# Decorrelation penalty hyperparameter
pars['DELTA'] = 0.5
# Target pairwise correlations between variables and knockoffs
pars['target_corr'] = corr_g
# Kernel widths for the MMD measure (uniform weights)
pars['alphas'] = [1.,2.,4.,8.,16.,32.,64.,128.]

# Initialize the machine
machine = KnockoffMachine(pars)

In [None]:
# Train on the normal-score data, then generate one knockoff row per sample
print("Fitting the knockoff machine...")
machine.train(X_scaled)
X_deepknockoff = machine.generate(X_scaled)
print("Size of the deep knockoff dataset: %d x %d." %(X_deepknockoff.shape))

In [None]:
# Deep knockoffs live on the normal-score scale, so they are paired with X_scaled
perm_deep = permutation_pvalues(X_scaled, X_deepknockoff, Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the deep-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_deep['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_deep['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_deep_sorted = perm_deep.sort_values('p_value_bh').reset_index(drop=True)

perm_deep_sorted.head(20)

### Permutation Only - No Knockoffs

In [None]:
# Baseline without knockoffs: permutation p-values for MDA importances
perm_only = permutation_mda_pvalues(X,Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the MDA permutation p-values
reject, pvals_corrected, _, _ = multipletests(perm_only['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_only['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_only_sorted = perm_only.sort_values('p_value_bh').reset_index(drop=True)

perm_only_sorted.head(20)

## Simulation 2

### Generate Data

In [None]:
# Simulation 2: same marginals and outcome model as Simulation 1, but the
# predictors are derived from correlated latent normals.
import numpy as np
import pandas as pd
from scipy.special import expit
from scipy.stats import norm
from scipy.stats import poisson


np.random.seed(47965636)
n = 1000
p = 11  # number of features

# --- correlation strengths ---
rho_block_strong = 0.3   # within strong block (X4–X6 + X9–X10)
rho_pair = 0.15          # X1,X2 pair
rho_cross = 0.05         # weak across blocks

# initialize covariance matrix
cov = np.full((p, p), rho_cross)
np.fill_diagonal(cov, 1.0)

# define correlated groups
pair_idx = [0, 1]                # X1, X2
strong_block_idx = [3, 4, 5, 8, 9]  # X4,X5,X6,X9,X10 (strong block)
independent_idx = [2, 6, 7, 10]     # weakly correlated (X3,X7,X8,X11)

# apply pair correlation
for i in pair_idx:
    for j in pair_idx:
        if i != j:
            cov[i, j] = rho_pair

# apply strong block correlation
for i in strong_block_idx:
    for j in strong_block_idx:
        if i != j:
            cov[i, j] = rho_block_strong

# symmetry and PSD correction
cov = (cov + cov.T) / 2.0
eigvals = np.linalg.eigvalsh(cov)
if eigvals.min() < -1e-8:
    jitter = abs(eigvals.min()) + 1e-8
    cov += np.eye(p) * jitter

# --- latent correlated normals ---
mean = np.zeros(p)
Z = np.random.multivariate_normal(mean, cov, size=n)

# --- transform to desired marginals ---
X1 = 4 * (Z[:, 0] - Z[:, 0].min()) / (Z[:, 0].max() - Z[:, 0].min()) - 2   # min-max rescaled to [-2, 2]
X2 = np.exp(Z[:, 1])            # positive, right-skewed (log-normal)
X3 = (Z[:, 2] > 0).astype(int)  # binary
X4 = Z[:, 3]                    # normal
X5 = Z[:, 4]
U6 = norm.cdf(Z[:, 5])                       # map to uniform(0,1)
X6 = poisson.ppf(U6, mu=4).clip(0, 5)        # Poisson quantile, capped at 5
X6 = (X6 - X6.mean()) / X6.std()             # normalize to mean 0, sd 1
X7 = Z[:, 6]
X8 = Z[:, 7]
X9  = np.random.gamma(shape=2.0, scale=np.exp(Z[:, 8])/3, size=n)      # gamma, correlated through Z
X10 = np.exp(Z[:, 9])                                                # log-normal, correlated through Z
X11 = (np.random.pareto(a=2.5, size=n) + 1) * (1 + 0.2 * Z[:, 10])    # weakly correlated Pareto

# combine into matrix
X = np.column_stack([X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11])

# --- true effects (X9 & X10 associated with Y) ---
beta = np.array([0.8, 0.8, 0.0, 0.8, 0.8, 0.8, 0.0, 0.0, 0.8, 0.8, 0.0]) #3, 7,8,11 no effect

# X1 has sinusoidal effect
# X9 has log effect (heavy-tailed variable with diminishing effect)
# NOTE(review): beta[0] and beta[8] are 0.8, so X1 and X9 enter the predictor
# twice (nonlinear term + linear term in X @ beta) -- confirm this is intended.
lin_pred = (
    1.5 * np.sin(np.pi * X1 / 2) +           # nonlinear in X1
    0.8 * np.log1p(X9) +                     # nonlinear in X9
    X @ beta                                 # linear part, all columns
)
p_prob = expit(lin_pred)
Y = np.random.binomial(1, p_prob)

# --- wrap up in DataFrame ---
cols = [f"X{i+1}" for i in range(X.shape[1])]
df = pd.DataFrame(X, columns=cols)
df["Y"] = Y

# identify binary vs continuous columns
# NOTE: this rebinds the module-level is_binary_col defined earlier.
def is_binary_col(col):
    u = np.unique(col)
    return set(u).issubset({0, 1})

types = ["bin" if is_binary_col(df[c]) else "cont" for c in cols]

# --- summaries ---
print("Correlation matrix (strong block):")
print(df[["X4", "X5", "X6", "X9", "X10"]].corr().round(2))

print("\nSkewed variable summaries (X9–X11):")
print(df[["X9", "X10", "X11"]].describe().round(2))

print("\nFirst 5 rows:")
print(df.head())

# Expected logistic slope: must use the success probabilities p_prob, not the
# feature count p (which would give 11 * (1 - 11) = -110)
avg_slope = np.mean(p_prob * (1 - p_prob))

# Compute variance contribution of each term
var_contrib = np.zeros(X.shape[1])

# NOTE(review): for X1 (j == 0) and X9 (j == 8) only the nonlinear term is
# counted here, not the additional 0.8 * X term present in lin_pred.
for j in range(X.shape[1]):
    if j == 0:
        term = 1.5 * np.sin(np.pi * X1 / 2)
    elif j == 8:
        term = 0.8 * np.log1p(X9)
    else:
        term = X[:, j] * beta[j]
    var_contrib[j] = np.var(term) * avg_slope

df_contrib = pd.DataFrame({
    "variable": cols,
    "var_contrib": var_contrib
})

# Largest contribution first (the slope is positive now, so descending order
# ranks by strength exactly as in Simulation 1)
df_contrib = df_contrib.sort_values("var_contrib", ascending=False).reset_index(drop=True)
print(df_contrib)

# From here on X is a labelled DataFrame (the knockoff builders read X.columns)
X = pd.DataFrame(X, columns=cols)


### KNN Knockoff P Values

In [None]:
# KNN knockoffs for the correlated predictors, drawing from 3 neighbours per sample
knockoff_knn = construct_knockoffs_knn(X, types, nbrs_per_sample=3)

In [None]:
# Permutation p-values for the original-minus-knockoff importance gap (300 label permutations)
perm_knn = permutation_pvalues(X, knockoff_knn, Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the KNN-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_knn['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_knn['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_knn_sorted = perm_knn.sort_values('p_value_bh').reset_index(drop=True)

perm_knn_sorted.head(20)

### Gaussian Knockoff P Values

In [None]:
# Gaussian draws matching the sample mean and covariance of X, one per row
knockoff_gauss = gaussian_knockoff(X)

In [None]:
# Permutation p-values using the Gaussian knockoffs (300 label permutations)
perm_gauss = permutation_pvalues(X, knockoff_gauss, Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the Gaussian-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_gauss['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_gauss['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_gauss_sorted = perm_gauss.sort_values('p_value_bh').reset_index(drop=True)

perm_gauss_sorted.head(20)

### Deep Knockoff P Values

In [None]:
# Deep-knockoff setup: map every column to normal scores, fit second-order
# Gaussian knockoffs on the transformed data to get target correlations, then
# configure the knockoff machine.
p = X.shape[1]
qt = QuantileTransformer(output_distribution='normal')
X_scaled = qt.fit_transform(X)
X_scaled = np.array(X_scaled)
SigmaHat = np.cov(X_scaled, rowvar=False)
second_order = GaussianKnockoffs(SigmaHat, mu=np.mean(X_scaled,0), method="sdp")
# Per-feature target: 1 - diag(Ds) / diag(SigmaHat)
corr_g = (np.diag(SigmaHat) - np.diag(second_order.Ds)) / np.diag(SigmaHat)

# Set the parameters for training deep knockoffs
pars = dict()
# Number of epochs
pars['epochs'] = 100
# Number of iterations over the full data per epoch
pars['epoch_length'] = 50
# Data type, either "continuous" or "binary"
pars['family'] = "continuous"
# Dimensions of the data
pars['p'] = p
# Size of the test set (n is the sample size set in the data-generation cell)
pars['test_size']  = int(0.1*n)
# Batch size
pars['batch_size'] = int(0.45*n)
# Learning rate
pars['lr'] = 0.003
# When to decrease learning rate (unused when equal to number of epochs)
pars['lr_milestones'] = [pars['epochs']]
# Width of the network (number of layers is fixed to 6)
pars['dim_h'] = int(10*p)
# Penalty for the MMD distance
pars['GAMMA'] = 0.5
# Penalty encouraging second-order knockoffs
pars['LAMBDA'] = 0.5
# Decorrelation penalty hyperparameter
pars['DELTA'] = 0.5
# Target pairwise correlations between variables and knockoffs
pars['target_corr'] = corr_g
# Kernel widths for the MMD measure (uniform weights)
pars['alphas'] = [1.,2.,4.,8.,16.,32.,64.,128.]

# Initialize the machine
machine = KnockoffMachine(pars)

In [None]:
# Train on the normal-score data, then generate one knockoff row per sample
print("Fitting the knockoff machine...")
machine.train(X_scaled)
X_deepknockoff = machine.generate(X_scaled)
print("Size of the deep knockoff dataset: %d x %d." %(X_deepknockoff.shape))

In [None]:
# Deep knockoffs live on the normal-score scale, so they are paired with X_scaled
perm_deep = permutation_pvalues(X_scaled, X_deepknockoff, Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the deep-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_deep['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_deep['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_deep_sorted = perm_deep.sort_values('p_value_bh').reset_index(drop=True)

perm_deep_sorted.head(20)

### Permutation Only - No Knockoffs

In [None]:
# Baseline without knockoffs: permutation p-values for MDA importances
perm_only = permutation_mda_pvalues(X,Y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the MDA permutation p-values
reject, pvals_corrected, _, _ = multipletests(perm_only['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_only['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_only_sorted = perm_only.sort_values('p_value_bh').reset_index(drop=True)

perm_only_sorted.head(20)

## Concussion Data Application

### Load Data

In [9]:
# Read the CSV
conc_data = pd.read_csv("ContexData_TimeToClinicLessThan2Weeks.csv")

# Function to detect binary
def is_binary_col(col: np.ndarray, tol=1e-8):
    uniques = np.unique(col[~np.isnan(col)])
    return set(uniques).issubset({0, 1}) and len(uniques) <= 2

# Make a copy
conc_data_imputed = conc_data.copy()
conc_data_imputed = conc_data_imputed.iloc[:,2:]

imputer = IterativeImputer(random_state=0, max_iter=10, sample_posterior=True)

# Fit and transform the data
conc_data_imputed[:] = imputer.fit_transform(conc_data_imputed)

# Check missingness
print(conc_data_imputed.isna().sum())

record_id                     0
age                           0
sex                           0
race                          0
ethnicity                     0
                             ..
gad_7_7                       0
gad7_total_score              0
gad7_difficult_to_function    0
psqi                          0
time_since_injury             0
Length: 77, dtype: int64


In [10]:
# Predictors: imputed data without constant columns
# NOTE(review): the printed column list starts with record_id, so an
# identifier column appears to be kept as a predictor -- confirm intended.
X = conc_data_imputed
X = X.loc[:, X.nunique() > 1]
# Automatically detect types
types = []
for col_name in X.columns:
    col = X[col_name].values
    if is_binary_col(col):
        types.append("bin")
    else:
        types.append("cont")
# Outcome: 1 when the second column of the raw data is <= 3, else 0
# NOTE(review): a missing value in that column presumably compares False and
# becomes 0 -- verify there are none.
y = (conc_data.iloc[:, 1] <= 3).astype(int)

In [12]:
# Column labels reused below when arrays are wrapped back into DataFrames
cols=X.columns

### KNN Knockoff P Values

In [15]:
# Split column names by detected type
cont_cols = [c for c, t in zip(X.columns, types) if t == "cont"]
bin_cols  = [c for c, t in zip(X.columns, types) if t == "bin"]

X_scaled = X.copy()

# Map only the continuous columns to normal scores; binary columns are left as they are
qt = QuantileTransformer(output_distribution='normal', random_state=0)
if cont_cols:  # avoid empty slice
    X_scaled[cont_cols] = qt.fit_transform(X_scaled[cont_cols])

# Ensure binary columns remain as 0/1 integers
X_scaled[bin_cols] = X_scaled[bin_cols].round().astype(int)
X_scaled = pd.DataFrame(X_scaled, columns=cols)
# KNN knockoffs built on the transformed data (default topk_features / nbrs_per_sample)
knockoff_knn = construct_knockoffs_knn(X_scaled, types)

In [16]:
# The KNN knockoffs were built from X_scaled, so pair them with that same
# transformed frame rather than with the raw X
perm_knn = permutation_pvalues(X_scaled, knockoff_knn, y, B=300)

NameError: name 'permutation_pvalues' is not defined

In [None]:
# Benjamini-Hochberg adjustment of the KNN-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_knn['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_knn['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_knn_sorted = perm_knn.sort_values('p_value_bh').reset_index(drop=True)

perm_knn_sorted.head(20)

### Gaussian Knockoff P Values

In [None]:
# Gaussian draws matching the sample mean and covariance of X, one per row
knockoff_gauss = gaussian_knockoff(X)

In [None]:
# Permutation p-values using the Gaussian knockoffs (300 label permutations)
perm_gauss = permutation_pvalues(X, knockoff_gauss, y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the Gaussian-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_gauss['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_gauss['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_gauss_sorted = perm_gauss.sort_values('p_value_bh').reset_index(drop=True)

perm_gauss_sorted.head(20)

### Deep Knockoff P Values

In [None]:
# Deep-knockoff setup for the concussion data: normal-score transform, a
# regularized covariance estimate, equicorrelated second-order knockoffs for
# the target correlations, then the knockoff machine configuration.
p = X.shape[1]

qt = QuantileTransformer(output_distribution='normal')
X_scaled = qt.fit_transform(X)
X_scaled = np.array(X_scaled)
# Sample size of THIS data set; test and batch sizes below must not rely on
# an `n` left over from (or never defined by) the simulation cells
n = X_scaled.shape[0]
SigmaHat = np.cov(X_scaled, rowvar=False)
SigmaHat += 1e-1 * np.eye(SigmaHat.shape[0])  # a bit stronger ridge for stability
def make_posdef(Sigma, eps=1e-3):
    """Return a symmetric copy of Sigma with eigenvalues clipped to >= eps."""
    # Symmetrize
    Sigma = (Sigma + Sigma.T) / 2
    # Eigen-decompose
    eigvals, eigvecs = np.linalg.eigh(Sigma)
    # Clip eigenvalues to be at least eps
    eigvals_clipped = np.clip(eigvals, eps, None)
    return eigvecs @ np.diag(eigvals_clipped) @ eigvecs.T

SigmaHat = make_posdef(SigmaHat, eps=1e-3)
second_order = GaussianKnockoffs(SigmaHat, mu=np.mean(X_scaled,0), method="equi")
# Compute correlation targets safely
corr_g = (np.diag(SigmaHat) - np.diag(second_order.Ds)) / (np.diag(SigmaHat) + 1e-8)
corr_g = np.clip(corr_g, 0.1, 0.4)  # a little easier target
# Small random perturbation of the targets (drawn from the global NumPy RNG)
corr_g += np.random.uniform(-0.05, 0.05, size=corr_g.shape)
# Set the parameters for training deep knockoffs
pars = dict()
# Number of epochs
pars['epochs'] = 100
# Number of iterations over the full data per epoch
pars['epoch_length'] = 50
# Data type, either "continuous" or "binary"
pars['family'] = "continuous"
# Dimensions of the data
pars['p'] = p
# Size of the test set
pars['test_size']  = int(0.1*n)
# Batch size
pars['batch_size'] = int(0.3*n)
# Learning rate
pars['lr'] = 0.003
# When to decrease learning rate (unused when equal to number of epochs)
pars['lr_milestones'] = [pars['epochs']]
# Width of the network (number of layers is fixed to 6)
pars['dim_h'] = int(3*p)
# Penalty for the MMD distance
pars['GAMMA'] = 0.5
# Penalty encouraging second-order knockoffs
pars['LAMBDA'] = 1
# Decorrelation penalty hyperparameter
pars['DELTA'] = 1
# Target pairwise correlations between variables and knockoffs
pars['target_corr'] = corr_g
# Kernel widths for the MMD measure (uniform weights)
pars['alphas'] = [1.,2.,4.,8.,16.,32.,64.,128.]

from sklearn.preprocessing import QuantileTransformer, StandardScaler


# Initialize the machine
machine = KnockoffMachine(pars)

In [None]:
# Sanity checks on the target correlations; the last line prints True when
# any entry is NaN or infinite
print("corr_g min:", np.min(corr_g))
print("corr_g max:", np.max(corr_g))
print("Any NaNs:", np.any(~np.isfinite(corr_g)))

In [None]:
# Train on the normal-score data, then generate one knockoff row per sample
print("Fitting the knockoff machine...")
machine.train(X_scaled)
X_deepknockoff = machine.generate(X_scaled)
print("Size of the deep knockoff dataset: %d x %d." %(X_deepknockoff.shape))

In [None]:
# Re-attach the column labels to the transformed data and its knockoffs
X_scaled = pd.DataFrame(X_scaled, columns=cols)
X_deepknockoff = pd.DataFrame(X_deepknockoff, columns=cols)

In [None]:
# The deep knockoffs were generated from X_scaled (normal scores), so pair
# them with X_scaled, as in the simulation sections, rather than the raw X
perm_deep = permutation_pvalues(X_scaled, X_deepknockoff, y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the deep-knockoff p-values
reject, pvals_corrected, _, _ = multipletests(perm_deep['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_deep['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_deep_sorted = perm_deep.sort_values('p_value_bh').reset_index(drop=True)
#perm_deep_sorted['variable_name'] = perm_deep_sorted['variable'].apply(lambda x: cols[x])

perm_deep_sorted.head(20)

### Permutation Only - No Knockoffs

In [None]:
# Baseline without knockoffs: permutation p-values for MDA importances
perm_only = permutation_mda_pvalues(X,y, B=300)

In [None]:
# Benjamini-Hochberg adjustment of the MDA permutation p-values
reject, pvals_corrected, _, _ = multipletests(perm_only['p_value'], method='fdr_bh')

# add the adjusted p-values to DF
perm_only['p_value_bh'] = pvals_corrected

# sort by the adjusted p-values
perm_only_sorted = perm_only.sort_values('p_value_bh').reset_index(drop=True)

perm_only_sorted.head(20)

## Sources

Sesia, M. (n.d.). DeepKnockoffs [Python package]. GitHub. https://github.com/msesia/deepknockoffs/tree/master/DeepKnockoffs

Romano, Y., Sesia, M., & Candès, E. (2020). Deep knockoffs. Journal of the American Statistical Association, 115(532), 1861-1872. https://doi.org/10.1080/01621459.2019.1660174