# Standalone Notebook for SynFS

This notebook is an example notebook with code copied directly rather than imported

In this notebook, we demonstrate SynFS on the Syn2 experiment from the paper.

In [1]:
import numpy as np 
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score
import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
import math
from functools import reduce


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
config = {
      "views_dims": [250, 250],
      "hidden_dims": [32, 32], 
      "s_lam": 0.1,
      "n_lam": 1.07,
      "alpha": 0.25,
      "learning_rate": 0.001, 
      "s_learning_rate":0.001,
      "batch_size": 250, 
      "weight_decay": 1e-4,
      "epochs": 90,
      "threshold": 0.7}

# DATA

## Generating Synthetic Multi-view Data

In [4]:
def generate_views_gt(views):
    """
    Ground truth informative featues used in generating the data 
    """
    
    gt_views = [np.zeros((view.shape[1])) for view in views]
    
    # Syn2 GT informative features 
    gt_views[0][0] = 1
    gt_views[0][2] = 1
    gt_views[1][1] = 1
    gt_views[1][3] = 1
    
    syn_views = [np.zeros((view.shape[1])) for view in views]
    syn_views[0][0] = 1
    syn_views[1][1] = 1
    return gt_views, syn_views 
    
def generate_multi_dataset(n, dims, seed=0):
    """
    Generate mutli-view dataset 
    """
    np.random.seed(seed)
    # x generation
    views = [np.random.randn(n, dim) for dim in dims]
    # y generation rule 
    y = np.zeros((n, 2))
    logit = np.exp(views[0][:,0]*views[1][:,1] + views[0][:,2] + views[1][:,3])
    # Compute P(Y=0|X)
    prob_0 = np.reshape((logit / (1+logit)), [n, 1])
    # Sampling process
    y[:, 0] = np.reshape(np.random.binomial(1, prob_0), [n,])
    y[:,1] = 1-y[:, 0]
    print("validate:",np.unique(y[:,1], return_counts=True))
    y = y[:,1]
    a_gt, syn_gt = generate_views_gt(views)
    return views, y, (a_gt, syn_gt)


In [5]:
# Generate data with ground truth
views, y, (a_gt, syn_gt) = generate_multi_dataset(20000, [250, 250])

data = (views, y)
ns_gt = [ai - syn for ai, syn in zip(a_gt, syn_gt)]
ns_gt = np.where(np.concatenate(ns_gt))[0]
s_gt = np.where(np.concatenate(syn_gt))[0]
ground_truth_features = [ns_gt, s_gt]

validate: (array([0., 1.]), array([ 9876, 10124]))


## Data Split

In [6]:
#
val_size, test_size, seed = 0.2, 0.2, 0
n= len(data[1])

X_set = data[0]  # Tuple (v1, v2, ... ,v)
y = data[1]

# Generate the indices for the train/test 
# split 64/16/20 train/ val/ test
train_indices, test_indices = train_test_split(np.arange(n), test_size=test_size, random_state=seed)
train_indices, val_indices = train_test_split(train_indices, test_size=val_size, random_state=seed)

# Use these indices to split each view
tr_X_set = [X[train_indices] for X in X_set]
va_X_set = [X[val_indices] for X in X_set]
te_X_set = [X[test_indices] for X in X_set]

tr_y, va_y, te_y = y[train_indices], y[val_indices], y[test_indices]

In [7]:

class SimpleDataset(Dataset) :
    def __init__(self, data_set, y, device) :
        self.data_set = [torch.tensor(data, dtype=torch.float32).to(device) for data in data_set]
        self.y = torch.tensor(y).squeeze().long().to(device)
        self.device = device
        
    def __len__(self) :
        return len(self.data_set[0])
    
    def __getitem__(self, i) :
        xs = [data[i] for data in self.data_set]
        y = self.y[i]
        return (xs, y)

In [8]:
train_data = SimpleDataset(tr_X_set, tr_y, device=device)
trainloader = DataLoader(train_data, batch_size=config['batch_size'], drop_last=True)

# Metric

In [9]:
def tpr_fdr(true_groups, predicted_groups):
    # True positive rate and false discovery rate.

    if len(true_groups) == 0:  # Ground truth not known.
        return -1, -1

    if len(predicted_groups) == 0 or all(len(sg) == 0 for sg in predicted_groups):
        return 0.0, 0.0

    predicted_features = np.unique(reduce(np.union1d, predicted_groups))
    true_features = np.unique(reduce(np.union1d, true_groups))

    overlap = np.intersect1d(predicted_features, true_features).size
    tpr = 100 * overlap / len(true_features)
    fdr = (
        100 * (len(predicted_features) - overlap) / len(predicted_features)
    )  # If len(predicted_features) != 0 else 0.0.
    return tpr, fdr


def strict_jaccard(true_groups, predicted_groups):
    # return jacard for true groups (ns, s), predicted_groups
    # predicted group will be re-sorted to ns, s (if element==1 -> ns) folloiwng comfps def
    if len(true_groups) == 0:  # i.e. we don't know the ground truth.
        return -1, len(true_groups), len(predicted_groups)

    if len(predicted_groups) == 0 or all(len(sg) == 0 for sg in predicted_groups): # we didn't find anything
        return 0, len(true_groups), len(predicted_groups)
    
    ns_true, s_true = true_groups  # non-synergic, synergic
    ns_pred, s_pred = predicted_groups

    if len(ns_pred) > 0: 
        ns_jac = np.intersect1d(ns_true, ns_pred).size / np.union1d(ns_true, ns_pred).size
    elif len(ns_pred) == len(ns_true):  # no ground truth 
        ns_jac = 1
    else:
        ns_jac = 0
        
    if len(s_pred) > 0: 
        s_jac = np.intersect1d(s_true, s_pred).size / np.union1d(s_true, s_pred).size
    elif len(s_pred) == len(s_true):  # no ground truth 
        ns_jac = 1
    else:
        s_jac =0
    return (ns_jac + s_jac) / 2, len(true_groups), len(predicted_groups)

def standard_metrics(y_train, y_test, logits, verbose=False):
    if len(y_test) != len(logits):
        y_test = y_test[:len(logits)]
    if isinstance(y_test, np.object_) :
        y_test = np.array([y.numpy() for y in y_test])
    
    def np_softmax(target, all):
        softmax = np.exp(target) / np.sum(np.exp(all) ,axis=1)
        return softmax
        
    threshold = np.sum(y_train[y_train==1])/ len(y_train)
    threshold = threshold
    y_prob = np_softmax(logits[:,1], logits)
    y_pred = np.where(y_prob > threshold, 1, 0)
    
    auroc = roc_auc_score(y_test, logits[:,1])
    auprc = average_precision_score(y_test, logits[:,1])
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    if verbose:
        print(f"auroc | {auroc.item():.3f}, auprc | {auprc.item():.3f}, accuracy | {accuracy.item():.3f}, f1 | {f1.item():.3f}")
    return auroc, auprc, accuracy, f1 

## Modules

In [10]:
class Selector(nn.Module):
    """Stochastic gate to discover selected features 
    
    Args:
        input_dim : dimension of input features
        sigma : constant for gaussian distribution
        device : device for the tensors to be created 
    """
    def __init__(self, input_dim, sigma=0.5, mean=0.5) -> None:
        super(Selector, self).__init__()
        self.mean = mean # 0.5
        self.mu = 0.01*torch.randn(input_dim,)
        self.mu = torch.nn.Parameter(self.mu, requires_grad=True)
        self.sigma = sigma
    
    def forward(self, prev_v, X_mean) -> None: 

        self.noise = torch.randn(prev_v.size()).to(self.mu.device)
        z = self.mu + self.sigma*self.noise.normal_()*self.training # noise normal_ ~N(0,1)
        self.z = z  # save z for the same noise
        stochastic_gate = self.hard_sigmoid(self.z)
        new_v = prev_v*stochastic_gate + X_mean*(1-stochastic_gate)
        return new_v
    
    def hard_sigmoid(self, v):
        return torch.clamp(v + self.mean, 0.0, 1.0)
    
    def regularizer(self, v):
        #guassian CDF
        return 0.5*(1 + torch.erf(v/math.sqrt(2)))
    
    def get_gates(self, mode='prob') :
        if mode == 'raw':
            return self.mu.detach().cpu().numpy()
        elif mode == 'prob':
            return np.minimum(1.0, np.maximum(0.0, self.mu.detach().cpu().numpy()+self.mean))
        

In [11]:
class MLP(nn.Module):
    def __init__(self, input_dim , hidden_dims, output_dim=2, batch_norm=True, dropout=True, activation='relu'):
        super(MLP, self).__init__()
        modules = self.build_layers(input_dim, hidden_dims, output_dim, batch_norm, dropout, activation)
        self.layers = nn.Sequential(*modules)
        
    def base_layer(self, in_features, out_features , batch_norm, dropout, activation):
        modules = [nn.Linear(in_features, out_features, bias=True)]
        if batch_norm : 
            modules.append(nn.BatchNorm1d(out_features))
        if dropout : 
            modules.append(nn.Dropout(0.5, True))
        if activation :
            modules.append(nn.ReLU(True))
        layer = nn.Sequential(*modules)
        return layer
    
    def build_layers(self, input_dim, hidden_dims, output_dim, batch_norm, dropout, activation) :
        dims = [input_dim]
        dims.extend(hidden_dims)
        dims.append(output_dim)
        nr_hiddens = len(hidden_dims)
        modules = []
        for i in range(nr_hiddens) :
            layer = self.base_layer(dims[i], dims[i+1], batch_norm, dropout, activation)
            modules.append(layer)
        layer = nn.Linear(dims[-2], dims[-1], bias=True)
        modules.append(layer)
        return modules
        
    def reset_parameters(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.reset_parameters()
                
    def forward(self, input):
        out = self.layers(input)
        return out 

In [12]:
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.s_selectors = nn.ModuleList()
        self.views_dims = config['views_dims']
        num_views = len(self.views_dims)
        for i in range(num_views):
            v_dim = self.views_dims[i]
            s_selector = Selector(v_dim)
            self.s_selectors.append(s_selector)
            setattr(self, f's_selector_{i}', s_selector)

        self.shared_predictor = MLP(sum(config['views_dims']), config['hidden_dims'],
                                )

In [13]:
def _standard_truncnorm_sample(lower_bound, upper_bound, sample_shape=torch.Size()):
    r"""
    Weight initialization usggested in STG
    """
    x = torch.randn(sample_shape)
    done = torch.zeros(sample_shape).byte()
    while not done.all():
        proposed_x = lower_bound + torch.rand(sample_shape) * (upper_bound - lower_bound)
        if (upper_bound * lower_bound).lt(0.0):  # of opposite sign
            log_prob_accept = -0.5 * proposed_x**2
        elif upper_bound < 0.0:  # both negative
            log_prob_accept = 0.5 * (upper_bound**2 - proposed_x**2)
        else:  # both positive
            assert(lower_bound.gt(0.0))
            log_prob_accept = 0.5 * (lower_bound**2 - proposed_x**2)
        prob_accept = torch.exp(log_prob_accept).clamp_(0.0, 1.0) #inplace
        accept = torch.bernoulli(prob_accept).byte() & ~done # return the prob_accept shape matrix where done is 0 
        if accept.any():
            accept = accept.bool()
            x[accept] = proposed_x[accept]
            accept = accept.byte()
            done |= accept # |=, in-place bitwise OR operator to done 
    return x

## SynFS

In [14]:
class SynFS(object):
    def __init__(self, config):
        """
        construct 
        s_model : a set of synergistic selector + predictor
        n_model : a set of non-synergistic selecotr + predictor
        inf_model : predictor
        """
        
        self.config = config
        self.batch_size = self.config['batch_size']
        self.device = self.get_device()
        self.loss = nn.CrossEntropyLoss(reduction='none')
        self.cos_loss = nn.CosineSimilarity(dim=0)

        self.s_model = Model(config)
        self.s_model.apply(self.init_weights)
        self.s_model.to(self.device)

        self.n_model = Model(config)
        self.n_model.apply(self.init_weights)
        self.n_model.to(self.device)
        
        self.inf_model = MLP(sum(config['views_dims']), config['hidden_dims'])
        self.inf_model.apply(self.init_weights)
        self.inf_model.to(self.device)
        
        self.sp_params = list(self.s_model.shared_predictor.parameters()) 
        self.np_params = list(self.n_model.shared_predictor.parameters()) 

        self.s_params = [p for selector in self.s_model.s_selectors for p in selector.parameters()] 
        self.n_params = [p for selector in self.n_model.s_selectors for p in selector.parameters()] 
        
        self.p_opt = torch.optim.Adam(self.sp_params+self.np_params, 
                                      lr=config['learning_rate'], 
                                      weight_decay=config['weight_decay']) # 
        self.synergy_opt = torch.optim.Adam(self.s_params, lr=config['s_learning_rate'])# list of optimizers
        self.nsynergy_opt = torch.optim.Adam(self.n_params, lr=config['s_learning_rate'])
        self.inf_opt = torch.optim.Adam(self.s_params + self.n_params + list(self.inf_model.parameters()),
                                            lr=config['learning_rate'], 
                                            weight_decay=config['weight_decay'])
    def get_device(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return device
    
    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            stddev = torch.tensor(0.1)
            shape = m.weight.shape
            m.weight = nn.Parameter(_standard_truncnorm_sample(lower_bound = -2*stddev, upper_bound = 2 * stddev, sample_shape = shape))
            torch.nn.init.zeros_(m.bias)
    
    def get_reg(self, selector):
        reg = selector.regularizer
        #reg = reg((selector.mu + selector.mean) / selector.sigma) #cdf
        reg = torch.mean(reg(selector.mu / selector.sigma))
        return reg
    
    def get_gates(self, model):
        S = [selector.hard_sigmoid(selector.mu.detach()) for selector in model.s_selectors]
        return S

    def mask_generator(self, batch_size=None):
        """
        Generate mask to use shared predictor with V marginal selectors
        """
        #generating mask for latent space
        # return a list of masks for each view 
        if batch_size is None:
            batch_size = self.config['batch_size']
        masks = []
        view_sum = sum(self.config['views_dims'])
        cumsum = np.cumsum(self.config['views_dims'])
        blank_mask = torch.zeros(batch_size, view_sum, device=self.device)
        for v in range(len(self.config['views_dims'])):
            v_mask = blank_mask.clone()
            if v == 0 :
                v_mask[:, :cumsum[v]] = 1
            else:
                v_mask[:, cumsum[v-1]:cumsum[v]] = 1
            masks.append(v_mask)
        return masks
    
    def simple_forward(self, model, S, views, X_mean_set):
        """
        simple forward to get P(Y|X_g_s) or P(Y|X_g_n)
        """
        masks = self.mask_generator(batch_size=views[0].shape[0])
        s_z = [S[i]*views[i]+(1-S[i])*X_mean_set[i] for i in range(len(views))]
        logits = []
        for v in range(len(self.config['views_dims'])):
            logit = model.shared_predictor(torch.cat(s_z, dim=1)*masks[v])
            logits.append(logit)
        bar_logit = model.shared_predictor(torch.cat(s_z, dim=1))
        logits.append(bar_logit)
        return logits 

    def train_step(self, data, X_mean_set):
        
        """"
        Train Predictor -> synergistic selector -> non-syergistic selector
        only corresponding weights specified for the loss are updated 
        
        """
        self.s_model.train()
        self.s_model.train()
        self.inf_model.train()
        
        views, y = data[:-1][0], data[-1]
        masks = self.mask_generator()
        
        # new v 
        v_s = [selector(views[i], X_mean_set[i]) for i, selector in enumerate(self.s_model.s_selectors)] 
        v_n = [selector(views[i], X_mean_set[i]) for i, selector in enumerate(self.n_model.s_selectors)]
        z_s = [selector.z for selector in self.s_model.s_selectors]
        z_n = [selector.z for selector in self.n_model.s_selectors]
        
        ###############################
        ###        Predictor       ###
        ###############################
        # with the same noise as with selectors  
        #  non-synergic
        
        n_logits, s_logits = [], []
        for v in range(len(self.config['views_dims'])):
            logit = self.n_model.shared_predictor(torch.cat(v_n, dim=1)*masks[v])
            n_logits.append(logit)
        n_v_bar_logits = self.n_model.shared_predictor(torch.cat(v_n, dim=1))
        n_logits.append(n_v_bar_logits)
        #  synergic
        for v in range(len(self.config['views_dims'])):
            logit = self.s_model.shared_predictor(torch.cat(v_s, dim=1)*masks[v])
            s_logits.append(logit)
        s_v_bar_logits = self.s_model.shared_predictor(torch.cat(v_s, dim=1))
        s_logits.append(s_v_bar_logits)
        
        n_losses = [self.loss(logit, y) for logit in n_logits]
        s_losses = [self.loss(logit, y) for logit in s_logits]
        p_loss = torch.mean(torch.sum(torch.stack(n_losses+s_losses, dim=1), dim=1))

        self.p_opt.zero_grad() 
        p_loss.backward(retain_graph=True) # s_selector, n_selector, predictors 
        self.p_opt.step()  # update predictors
        
        #####################################
        ###   Synergic Selector and Inf   ###
        #####################################
        
        n_logits, s_logits = [], []
        # informative gate (max(synergic gate , non-synergic gate))
        all_gates = [self.s_model.s_selectors[0].hard_sigmoid(torch.max(s, n)) for s, n in zip(z_s, z_n)]  # with noise
        all_v = [views[i]*gate + X_mean_set[i]*(1-gate) for i, gate in enumerate(all_gates)]
        gate_s, gate_n = [self.s_model.s_selectors[0].hard_sigmoid(s) for s in z_s], [self.s_model.s_selectors[0].hard_sigmoid(n) for n in z_n] 
        s_regs =[self.get_reg(selector) for selector in self.s_model.s_selectors]
        
        all_bar_logits = self.inf_model(torch.cat(all_v, dim=1))
        all_bar_loss = self.loss(all_bar_logits, y)

        for v in range(len(self.config['views_dims'])):
            logit = self.s_model.shared_predictor(torch.cat(v_s, dim=1)*masks[v])
            s_logits.append(logit)
        s_v_bar_logits = self.s_model.shared_predictor(torch.cat(v_s, dim=1))

        s_losses = [self.loss(logit, y) for logit in s_logits]
        s_v_bar_loss = self.loss(s_v_bar_logits, y)
        inf_loss = torch.mean(all_bar_loss) 
        
        synergy_loss = torch.mean(s_v_bar_loss - torch.sum(torch.stack(s_losses, dim=1), dim=1) \
                                + self.config['s_lam']*torch.mean(torch.stack(s_regs)) \
                                )

        self.inf_opt.zero_grad()
        self.synergy_opt.zero_grad()
        inf_loss.backward(retain_graph=True) #s selecotr, n selector, ai predictor 
        synergy_loss.backward() # s selector, predictors 

        self.inf_opt.step()  
        self.synergy_opt.step()
        
        ##########################
        ###   Non-syneristic   ###
        ##########################

        sim = torch.nn.functional.cosine_similarity(torch.cat(gate_s, dim=1), torch.cat(gate_n, dim=1), dim=1) #(batch, alldim)
        ns_regs =[self.get_reg(selector) for selector in self.n_model.s_selectors]
        
        for v in range(len(self.config['views_dims'])):
            logit = self.n_model.shared_predictor(torch.cat(v_n, dim=1)*masks[v])
            n_logits.append(logit)
        n_v_bar_logits = self.n_model.shared_predictor(torch.cat(v_n, dim=1))
        
        n_losses = [self.loss(logit, y) for logit in n_logits]
        n_v_bar_loss = self.loss(n_v_bar_logits, y)

        nsynergy_loss = torch.mean(-n_v_bar_loss + torch.sum(torch.stack(n_losses, dim=1), dim=1) 
                                   + self.config['n_lam']*torch.mean(torch.stack(ns_regs)) # same mu for all batch
                                   + self.config['alpha']*sim 
                                   )
        
        self.nsynergy_opt.zero_grad()
        nsynergy_loss.backward() # n selector, predcitor 
        self.nsynergy_opt.step()
        
        return synergy_loss, nsynergy_loss, inf_loss

    def predict(self, views, te_X_set):
        self.s_model.eval()
        self.n_model.eval()
        self.inf_model.eval()
        
        X_mean_set = [torch.mean(torch.Tensor(X).to(self.device), dim=0) for X in te_X_set] 

        with torch.no_grad():
            S = self.get_gates(self.s_model)
            NS = self.get_gates(self.s_model)
            all_mu = [torch.max(s, ns) for s, ns in zip(S, NS)]
            all_z = [all_mu[i]*views[i]+(1-all_mu[i])*X_mean_set[i] for i in range(len(views))]
            all_bar_logits = self.inf_model(torch.cat(all_z, dim=1))
        return all_bar_logits


In [15]:
synfs = SynFS(config)
X_mean_set = [torch.mean(torch.Tensor(X).to(device), dim=0) for X in tr_X_set]

for epoch in range(config['epochs']):
    syn_loss = 0
    for batch in trainloader:
        synergy_loss, nsynergy_loss, inf_loss = synfs.train_step(batch, X_mean_set)
        syn_loss +=synergy_loss
    avg_syn_loss = syn_loss/len(trainloader)
    if epoch % 30 == 0:
        print(f'=====epoch {epoch} syn_loss {avg_syn_loss.item():.3f}=====')
        print('predicted synergistic features')
        print(np.where(np.concatenate([g.cpu().numpy() for g in synfs.get_gates(synfs.s_model)])>0.7)[0])
        print('predicted non-synergistic features')
        print(np.where(np.concatenate([g.cpu().numpy() for g in synfs.get_gates(synfs.n_model)])>0.7)[0])

=====epoch 0 syn_loss -0.693=====
predicted synergistic features
[]
predicted non-synergistic features
[]
=====epoch 30 syn_loss -0.665=====
predicted synergistic features
[  0 251]
predicted non-synergistic features
[  2 253]
=====epoch 60 syn_loss -0.681=====
predicted synergistic features
[  0 251]
predicted non-synergistic features
[  2 253]


## Evaluation

### Feature Discovery

In [16]:
# Get group similarity and group structure.
s = [gate.cpu().numpy() for gate in synfs.get_gates(synfs.s_model)]
n = [gate.cpu().numpy() for gate in synfs.get_gates(synfs.n_model)]
n_predicted = np.where(np.concatenate(n)>0.7)[0]
s_predicted = np.where(np.concatenate(s)>0.7)[0]
predicted_features = [n_predicted, s_predicted]

# Get group similarity and group structure.
tpr, fdr = tpr_fdr(ground_truth_features, predicted_features)
j_index, ntrue, npredicted = strict_jaccard(ground_truth_features, predicted_features)

print("ground_truth_syn |", ground_truth_features[0], ",ground_truth_non-syn | ", ground_truth_features[1])
print("predicted_syn |", predicted_features[0], ",predicted_non-syn | ", predicted_features[1] )
print(
    "Jaccard Index: {:.3f}, True Positive Rate: {:.3f}%, False Discovery Rate: {:.3f}%".format(
        j_index, tpr, fdr
    )
)


ground_truth_syn | [  2 253] ,ground_truth_non-syn |  [  0 251]
predicted_syn | [  2 253] ,predicted_non-syn |  [  0 251]
Jaccard Index: 1.000, True Positive Rate: 100.000%, False Discovery Rate: 0.000%


### Predictive Performance

In [17]:
data = SimpleDataset(te_X_set, te_y, device=device)
testloader = DataLoader(data, batch_size=len(te_y))
           
res = []
for x, y in testloader:
    logits = synfs.predict(x, X_set)
    res.append(logits.detach().cpu().numpy())
logits = np.concatenate(res)

auroc, auprc, accuracy, f1 = standard_metrics(tr_y, te_y, logits, verbose=True)

auroc | 0.655, auprc | 0.660, accuracy | 0.604, f1 | 0.652
