## CPU or GPU

In [1]:
import torch
# Set to False to force CPU execution even when CUDA is available.
USE_GPU = True
if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using device:', device)

using device: cpu


## utils.py

In [5]:
import torch
from torch.utils import data
import numpy as np
import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score

def set_seed(seed=42, device=torch.device("cpu")):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator.
        device: Torch device; when its type is "cuda" the CUDA generator
            is seeded explicitly as well.
    """
    needs_cuda_seed = device.type == "cuda"
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not needs_cuda_seed:
        return
    torch.cuda.manual_seed(seed)

def force_float(X_numpy):
    """Convert a NumPy array into a float32 torch tensor."""
    as_float32 = X_numpy.astype(np.float32)
    return torch.from_numpy(as_float32)

def convert_to_torch_loaders(Xd, Zd, Yd, batch_size):
    """Wrap NumPy arrays in torch ``DataLoader`` objects, one per split.

    Args:
        Xd, Zd, Yd: Either dicts mapping a split name to a NumPy array, or
            bare arrays. When none of the three is a dict, each is treated
            as the single split "train".
        batch_size: Batch size passed to every ``DataLoader``.

    Returns:
        Dict mapping each split name found in ``Xd`` to a ``DataLoader``
        yielding ``(X, Z, Y)`` float32 batches. Only the "train" loader
        shuffles.
    """
    # isinstance (rather than an exact type comparison) so that dict
    # subclasses such as OrderedDict are recognised as split mappings.
    if not (isinstance(Xd, dict) or isinstance(Zd, dict) or isinstance(Yd, dict)):
        Xd, Zd, Yd = {"train": Xd}, {"train": Zd}, {"train": Yd}

    data_loaders = {}
    for split in Xd:
        dataset = data.TensorDataset(
            force_float(Xd[split]),
            force_float(Zd[split]),
            force_float(Yd[split]),
        )
        data_loaders[split] = data.DataLoader(
            dataset, batch_size, shuffle=(split == "train")
        )

    return data_loaders

def get_continuous_cols(X, unique_threshold = 10):
    """Return the indices of the columns of ``X`` treated as continuous.

    A column counts as continuous when its dtype is numeric and it holds
    more than ``unique_threshold`` distinct non-NaN values.
    """

    def _is_continuous(column):
        if not np.issubdtype(column.dtype, np.number):
            return False
        observed = column[~np.isnan(column)]
        return len(np.unique(observed)) > unique_threshold

    return [idx for idx in range(X.shape[1]) if _is_continuous(X[:, idx])]

def standardize_data(data, continuous_cols):
    """Standardize the selected columns of every split, in place.

    A ``StandardScaler`` is fitted on the "train" split only and then
    applied to the ``continuous_cols`` columns of every array in ``data``.
    When ``continuous_cols`` is empty nothing is changed.

    Returns:
        The same ``data`` dict that was passed in.
    """
    if not continuous_cols:
        return data

    scaler = StandardScaler()
    scaler.fit(data["train"][:, continuous_cols])
    for split_values in data.values():
        split_values[:, continuous_cols] = scaler.transform(split_values[:, continuous_cols])
    return data


def preprocess_data(
    X,
    Z,
    Y,
    valid_size=500,
    test_size=500,
    std_scale=False,
    unique_threshold = 10, # used to identify discrete variables
    batch_size=100,
):
    """Split (X, Z, Y) into train/val/test sets and build data loaders.

    The first ``n - valid_size - test_size`` rows form the training split,
    followed by the validation rows and then the test rows (no shuffling).

    Args:
        X, Z: 2-D NumPy arrays with one row per sample.
        Y: 1-D array of targets; a trailing axis of length 1 is added.
        valid_size: Number of validation rows.
        test_size: Number of test rows.
        std_scale: When True, columns with more than ``unique_threshold``
            distinct values are standardized using training-split
            statistics. Y is standardized under the same rule.
        unique_threshold: Distinct-value count above which a column is
            treated as continuous.
        batch_size: Batch size for the returned loaders.

    Returns:
        Dict with "train", "val" and "test" ``DataLoader`` objects.

    Raises:
        ValueError: If the inputs have different numbers of rows, a split
            size is negative, or no rows are left for training.
    """

    n = X.shape[0]
    if len(Z) != n or len(Y) != n:
        raise ValueError(
            "X, Z and Y must have the same number of rows, got %d, %d and %d"
            % (n, len(Z), len(Y))
        )
    if valid_size < 0 or test_size < 0:
        raise ValueError("valid_size and test_size must be non-negative")

    ## Make dataset splits
    ntrain, nval, ntest = n - valid_size - test_size, valid_size, test_size
    if ntrain <= 0:
        # A non-positive ntrain would make the slices below wrap around
        # and silently yield overlapping or empty splits.
        raise ValueError(
            "valid_size + test_size (%d) must be smaller than the number of samples (%d)"
            % (valid_size + test_size, n)
        )

    bounds = {
        "train": slice(0, ntrain),
        "val": slice(ntrain, ntrain + nval),
        "test": slice(ntrain + nval, ntrain + nval + ntest),
    }
    Xd = {split: X[rows] for split, rows in bounds.items()}
    Zd = {split: Z[rows] for split, rows in bounds.items()}
    Yd = {split: np.expand_dims(Y[rows], axis=1) for split, rows in bounds.items()}

    # If the std_scale is TRUE, find continuous columns to standardize
    if std_scale:

        def _scalable_copy(arr):
            # The splits above are views of the caller's arrays and
            # standardize_data writes in place, so work on copies. Integer
            # arrays are promoted so scaled values are not truncated.
            arr = np.array(arr, copy=True)
            if np.issubdtype(arr.dtype, np.integer):
                arr = arr.astype(np.float64)
            return arr

        X_continuous_cols = get_continuous_cols(X, unique_threshold)
        Z_continuous_cols = get_continuous_cols(Z, unique_threshold)
        Y_continuous_cols = [0] if np.issubdtype(Y.dtype, np.number) and len(np.unique(Y)) > unique_threshold else []
        Xd = standardize_data({k: _scalable_copy(v) for k, v in Xd.items()}, X_continuous_cols)
        Zd = standardize_data({k: _scalable_copy(v) for k, v in Zd.items()}, Z_continuous_cols)
        Yd = standardize_data({k: _scalable_copy(v) for k, v in Yd.items()}, Y_continuous_cols)

    return convert_to_torch_loaders(Xd, Zd, Yd, batch_size)


def get_auc(interactions, ground_truth):
    """Score a ranked interaction list against the true interactions.

    Args:
        interactions: Iterable of ``(interaction, strength)`` pairs.
        ground_truth: Collection of true interactions; an interaction is
            labelled positive when it compares equal to any of them.

    Returns:
        ROC AUC of the strengths with respect to the positive/negative
        labels, as computed by ``roc_auc_score``.
    """
    labels, scores = [], []
    for candidate, score in interactions:
        scores.append(score)
        is_true_interaction = False
        for gt in ground_truth:
            if candidate == gt:
                is_true_interaction = True
                break
        labels.append(1 if is_true_interaction else 0)

    return roc_auc_score(labels, scores)

def print_rankings(interactions, top_k=10, spacing=14):
    """Print the strongest interactions as left-justified columns.

    Args:
        interactions: Sequence of ``(interaction, strength)`` pairs,
            expected to be sorted already; the first ``top_k`` are shown.
        top_k: Maximum number of rows to print. Fewer rows are printed when
            ``interactions`` is shorter than ``top_k``.
        spacing: Column width passed to ``justify``.
    """
    print(
        justify(["Pairwise interactions"], spacing)
    )
    # Slicing (instead of indexing range(top_k)) avoids an IndexError when
    # fewer than top_k interactions are available; max() keeps a negative
    # top_k printing no rows, as range() did.
    for p_inter, p_strength in interactions[: max(top_k, 0)]:
        print(
            justify(
                [
                    p_inter,
                    "{0:.4f}".format(p_strength),
                    ""
                ],
                spacing,
            )
        )

def justify(row, spacing=14):
    """Concatenate the items of ``row``, each left-justified to ``spacing`` characters."""
    padded_cells = []
    for cell in row:
        padded_cells.append(str(cell).ljust(spacing))
    return "".join(padded_cells)


## TwinterNet

In [6]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import copy
import re
from sklearn.metrics import accuracy_score, roc_auc_score


class Twinter_Net(nn.Module):
    """Network with separate X, Z and X-Z interaction branches.

    The output is the sum of three scalar branches: an MLP over the X
    features, an MLP over the Z features, and an interaction branch fed by
    a masked linear layer over the concatenated ``[x, z]`` input. For
    ``task_type="classification"`` a sigmoid is applied to the sum.

    Args:
        X_num_features: Number of columns of x.
        Z_num_features: Number of columns of z.
        X_hidden_units: Hidden layer widths of the X branch.
        Z_hidden_units: Hidden layer widths of the Z branch.
        X_Z_pairs_repeats: Number of units of the masked layer dedicated to
            each interaction group.
        X_Z_hidden_units: Hidden layer widths of the interaction MLP(s).
        X_Z_pairwise: When True there is one group per (X_j, Z_k) pair,
            connected only to those two inputs.
        X_Z_parallel: When True each group feeds its own small MLP and the
            outputs are summed; otherwise one MLP consumes all units.
        X_allZ_layer: (non-pairwise only) one group per X feature,
            connected to that feature and to every Z feature.
        Z_allX_layer: (non-pairwise only) one group per Z feature,
            connected to that feature and to every X feature.
        task_type: "regression" or "classification".

    Raises:
        ValueError: If no interaction layout is enabled or ``task_type`` is
            not one of the two supported values.
    """

    def __init__(
        self,
        X_num_features,
        Z_num_features,
        X_hidden_units,
        Z_hidden_units,
        X_Z_pairs_repeats,
        X_Z_hidden_units,
        X_Z_pairwise = True,
        X_Z_parallel = True,
        X_allZ_layer = True,
        Z_allX_layer = True,
        task_type = "regression"
    ):
        super(Twinter_Net, self).__init__()

        if not X_Z_pairwise and not X_allZ_layer and not Z_allX_layer:
            raise ValueError("When X_Z_pairwise is False, at least one of X_allZ_layer or Z_allX_layer must be True. "
                              "You have three options:"
                              "\nOption 1: Set both X_allZ_layer and Z_allX_layer to True;"
                              "\nOption 2: Set X_allZ_layer to True and Z_allX_layer to False;"
                              "\nOption 3: Set X_allZ_layer to False and Z_allX_layer to True.")
        # Fail early: forward() has no output defined for other task types.
        if task_type not in ("regression", "classification"):
            raise ValueError(
                "task_type must be 'regression' or 'classification', got %r" % (task_type,)
            )

        self.X_num_features = X_num_features
        self.Z_num_features = Z_num_features
        self.X_Z_pairwise = X_Z_pairwise
        self.X_Z_parallel = X_Z_parallel
        self.X_allZ_layer = X_allZ_layer
        self.Z_allX_layer = Z_allX_layer
        self.X_Z_pairs_repeats = X_Z_pairs_repeats
        self.task_type = task_type

        # Accept any sequence (list or tuple) of hidden widths.
        X_hidden_units = list(X_hidden_units)
        Z_hidden_units = list(Z_hidden_units)
        X_Z_hidden_units = list(X_Z_hidden_units)

        # create the X net
        self.X_mlp = create_mlp([X_num_features] + X_hidden_units + [1])
        # create the Z net
        self.Z_mlp = create_mlp([Z_num_features] + Z_hidden_units + [1])
        # create the X_Z net: number of interaction groups for the layout
        if X_Z_pairwise:
            X_Z_pairs = X_num_features * Z_num_features
        elif X_allZ_layer and Z_allX_layer:
            X_Z_pairs = X_num_features + Z_num_features
        elif X_allZ_layer:
            X_Z_pairs = X_num_features
        else:
            X_Z_pairs = Z_num_features
        X_Z_layer_units = X_Z_pairs * X_Z_pairs_repeats
        self.X_Z_layer = nn.Linear(X_num_features + Z_num_features, X_Z_layer_units)
        self.X_Z_mask = self.create_mask(X_num_features, Z_num_features)
        # Zero the weights of the connections the mask disallows.
        with torch.no_grad():
            self.X_Z_layer.weight.mul_(self.X_Z_mask)
        self.X_Z_relu = nn.ReLU()
        if X_Z_parallel:
            self.X_Z_parallel_mlp = self.create_X_Z_nets(X_Z_pairs, X_Z_hidden_units)
        else:
            self.X_Z_mlp = create_mlp([X_Z_layer_units] + X_Z_hidden_units + [1])

    def forward(self, x, z):
        """Return the summed branch outputs for the batches ``x`` and ``z``."""
        output_X = self.X_mlp(x)
        output_Z = self.Z_mlp(z)
        x_z = torch.cat((x, z), dim=1)
        # The masked layer is evaluated once.
        x_z_layer = self.X_Z_relu(self.X_Z_layer(x_z))
        if self.X_Z_parallel:
            output_X_Z = self.forward_X_Z_nets(x_z_layer, self.X_Z_parallel_mlp)
        else:
            output_X_Z = self.X_Z_mlp(x_z_layer)
        output_sum = output_X + output_Z + output_X_Z
        if self.task_type == "regression":
            return output_sum
        elif self.task_type == "classification":
            return torch.sigmoid(output_sum)
        # Only reachable if task_type was reassigned after construction.
        raise ValueError(
            "task_type must be 'regression' or 'classification', got %r" % (self.task_type,)
        )

    def create_X_Z_nets(self, X_Z_pairs, X_Z_hidden_units):
        """Create one MLP per interaction group and register each on the module.

        Each MLP takes ``X_Z_pairs_repeats`` inputs and has no output bias.
        The MLPs are attached as attributes named ``X{j}_Z{k}_mlp``
        (pairwise), ``X{j}_Z_mlp`` (X-with-all-Z groups) or ``X_Z{k}_mlp``
        (Z-with-all-X groups).

        Returns:
            The list of MLPs, ordered like the groups of the masked layer.
        """
        x_z_mlp_list = [
            create_mlp([self.X_Z_pairs_repeats] + X_Z_hidden_units + [1], out_bias=False)
            for i in range(X_Z_pairs)
        ]

        if self.X_Z_pairwise:
            for j in range(self.X_num_features):
                for k in range(self.Z_num_features):
                    setattr(self, "X" + str(j) + "_Z" + str(k) + "_mlp", x_z_mlp_list[j * self.Z_num_features + k])
        else:
            if self.X_allZ_layer and self.Z_allX_layer:
                for j in range(self.X_num_features):
                    setattr(self, "X" + str(j) + "_Z" + "_mlp", x_z_mlp_list[j])
                for k in range(self.Z_num_features):
                    setattr(self, "X" + "_Z" + str(k) + "_mlp", x_z_mlp_list[self.X_num_features + k])
            elif self.X_allZ_layer:
                for j in range(self.X_num_features):
                    setattr(self, "X" + str(j) + "_Z" + "_mlp", x_z_mlp_list[j])
            elif self.Z_allX_layer:
                for k in range(self.Z_num_features):
                    setattr(self, "X" + "_Z" + str(k) + "_mlp", x_z_mlp_list[k])

        return x_z_mlp_list

    def forward_X_Z_nets(self, x_z_layer, mlps):
        """Feed each group's slice of ``x_z_layer`` to its MLP and sum the outputs."""
        forwarded_x_z_mlps = []
        for i, mlp in enumerate(mlps):
            # Group i owns X_Z_pairs_repeats consecutive units of the layer.
            x_z_layer_selected_columns = slice(i*self.X_Z_pairs_repeats, (i+1)*self.X_Z_pairs_repeats)
            forwarded_x_z_mlps.append(mlp(x_z_layer[:, x_z_layer_selected_columns]))
        forwarded_x_z_mlp = sum(forwarded_x_z_mlps)
        return forwarded_x_z_mlp

    def create_mask(self, p, q):
        """Build the 0/1 connectivity mask for the weights of ``X_Z_layer``.

        Args:
            p: Number of X features.
            q: Number of Z features.

        Returns:
            float32 tensor with ``p + q`` columns in which every group row
            is repeated ``X_Z_pairs_repeats`` times along axis 0.
        """
        if self.X_Z_pairwise:
            # Row j*q + k connects X_j and Z_k only.
            vec, identity = np.ones(q), np.eye(q)
            mask_x = np.zeros((p*q, p))
            for i in range(p):
                mask_x[i*q:(i+1)*q, i] = vec
            mask_z = np.vstack([identity] * p)
            mask = np.append(mask_x, mask_z, axis=1)
        else:
            if self.X_allZ_layer and self.Z_allX_layer:
                mask = np.block([[np.eye(p), np.ones((p, q))], [np.ones((q, p)), np.eye(q)]])
            elif self.X_allZ_layer:
                mask = np.block([np.eye(p), np.ones((p, q))])
            elif self.Z_allX_layer:
                mask = np.block([np.ones((q, p)), np.eye(q)])
        mask_repeated = np.repeat(mask, repeats=self.X_Z_pairs_repeats, axis=0)

        return torch.tensor(mask_repeated, dtype=torch.float32)


def create_mlp(layer_sizes, out_bias=True):
    """Build a fully connected ReLU network as an ``nn.Sequential``.

    Args:
        layer_sizes: Layer widths from input to output; every layer except
            the last is followed by a ReLU.
        out_bias: Whether the final linear layer has a bias term.
    """
    widths = [int(size) for size in layer_sizes]
    modules = []
    # Hidden layers: consecutive width pairs, excluding the output layer.
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(nn.ReLU())
    modules.append(nn.Linear(widths[-2], widths[-1], bias=out_bias))
    return nn.Sequential(*modules)


In [7]:
# Weight tensors of the within-view networks, e.g. "X_mlp.0.weight".
_MMLP_WEIGHT_PATTERN = re.compile(r"^[XZ]_mlp\.\d+\.weight$")


def _l1_penalty(net, penalize_MMLP):
    """Return the un-scaled L1 penalty for the current weights of ``net``.

    Penalized weights:
      * the cross-view weights of ``X_Z_layer`` when the net is not pairwise;
      * all weights of the shared ``X_Z_mlp`` when the net is not parallel;
      * the weights of ``X_mlp`` / ``Z_mlp`` when ``penalize_MMLP`` is True.

    Returns 0 when none of these applies.
    """
    pairwise, parallel = net.X_Z_pairwise, net.X_Z_parallel
    if pairwise and parallel and not penalize_MMLP:
        return 0

    X_num_features = net.X_num_features
    # With both all-feature layers active, the first
    # X_num_features * X_Z_pairs_repeats rows of X_Z_layer belong to the X
    # groups and the remaining rows to the Z groups.
    x_rows = X_num_features * getattr(net, "X_Z_pairs_repeats", 1)

    penalty = 0
    for name, param in net.named_parameters():
        if name == "X_Z_layer.weight" and not pairwise:
            if net.X_allZ_layer and net.Z_allX_layer:
                penalty += (
                    torch.sum(torch.abs(param[:x_rows, X_num_features:]))
                    + torch.sum(torch.abs(param[x_rows:, :X_num_features]))
                )
            elif net.X_allZ_layer:
                penalty += torch.sum(torch.abs(param[:, X_num_features:]))
            elif net.Z_allX_layer:
                penalty += torch.sum(torch.abs(param[:, :X_num_features]))
        if not parallel and "X_Z_mlp" in name and "weight" in name:
            penalty += torch.sum(torch.abs(param))
        if penalize_MMLP and _MMLP_WEIGHT_PATTERN.match(name):
            penalty += torch.sum(torch.abs(param))
    return penalty


def train(
    net,
    data_loaders,
    nepochs=100,
    verbose=False,
    early_stopping=True,
    patience=5,
    l1_const=5e-5,
    l2_const=0,
    learning_rate=1e-2,
    penalize_MMLP = False,
    opt_func=optim.Adam,
    device=torch.device("cpu"),
):
    """Train ``net`` and evaluate it on the held-out split.

    Args:
        net: Model exposing ``X_num_features``, ``X_Z_pairwise``,
            ``X_Z_parallel``, ``X_allZ_layer``, ``Z_allX_layer``,
            ``X_Z_layer``, ``X_Z_mask`` and ``task_type``, callable as
            ``net(x, z)``.
        data_loaders: Dict with a "train" iterable of ``(X, Z, Y)`` batches
            and optionally "val" and "test".
        nepochs: Maximum number of epochs.
        verbose: Print progress every other epoch and the final metrics.
        early_stopping: Stop when the validation loss has not improved for
            more than ``patience`` consecutive epochs. Disabled when there
            is no "val" loader.
        patience: See ``early_stopping``.
        l1_const: Multiplier of the L1 penalty.
        l2_const: ``weight_decay`` passed to the optimizer.
        learning_rate: Optimizer learning rate.
        penalize_MMLP: Also apply the L1 penalty to the X and Z networks.
        opt_func: Optimizer constructor.
        device: Device the batches and the mask are moved to.

    Returns:
        ``(net, loss)`` for regression, or ``(net, loss, accuracy, auc)``
        for classification, measured on "test" if present, else "val",
        else "train". Loss, accuracy and AUC are means over batches. When
        early stopping triggers, ``net`` is a deep copy of the model with
        the best validation loss.

    Raises:
        ValueError: If ``net.task_type`` is not "regression" or
            "classification".
    """

    task_type = net.task_type
    if task_type not in ("regression", "classification"):
        raise ValueError(
            "task_type must be 'regression' or 'classification', got %r" % (task_type,)
        )
    mask = net.X_Z_mask.to(device)

    optimizer = opt_func(net.parameters(), lr=learning_rate, weight_decay=l2_const)
    criterion = nn.MSELoss(reduction="mean") if task_type == "regression" else nn.BCELoss(reduction="mean")

    def evaluate_loss(net, data_loader, criterion, device):
        # Mean of the per-batch losses, computed without autograd tracking.
        losses = []
        with torch.no_grad():
            for X_inputs, Z_inputs, targets in data_loader:
                outputs = net(X_inputs.to(device), Z_inputs.to(device))
                losses.append(criterion(outputs, targets.to(device)).cpu())
        return torch.stack(losses).mean()

    def evaluate_metric(net, data_loader, device, metric, round_outputs):
        # Mean of a per-batch sklearn metric. Tensors are moved to the CPU
        # before the NumPy conversion so this also works for CUDA models.
        scores = []
        with torch.no_grad():
            for X_inputs, Z_inputs, targets in data_loader:
                outputs = net(X_inputs.to(device), Z_inputs.to(device))
                y_true = targets.squeeze(1).cpu().numpy()
                y_pred = outputs.squeeze(1).cpu().numpy()
                if round_outputs:
                    y_pred = y_pred.round()
                scores.append(metric(y_true, y_pred))
        return np.mean(scores)

    best_loss = float("inf")
    best_net = None

    if "val" not in data_loaders:
        early_stopping = False

    patience_counter = 0

    if verbose:
        print("starting to train")
        if early_stopping:
            print("early stopping enabled")

    monitor_key = "val" if "val" in data_loaders else "train"

    for epoch in range(nepochs):
        running_loss = 0.0
        run_count = 0
        for X_inputs, Z_inputs, targets in data_loaders["train"]:
            X_inputs = X_inputs.to(device)
            Z_inputs = Z_inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(X_inputs, Z_inputs)
            loss = criterion(outputs, targets)

            reg_loss = _l1_penalty(net, penalize_MMLP)

            (loss + reg_loss * l1_const).backward()
            # mask gradients for the X_Z_layer
            with torch.no_grad():
                net.X_Z_layer.weight.grad.mul_(mask)
            optimizer.step()
            running_loss += loss.item()
            run_count += 1

        val_loss = evaluate_loss(net, data_loaders[monitor_key], criterion, device)

        if verbose and epoch % 2 == 0:
            print(
                "[epoch %d, total %d] train loss: %.4f, val loss: %.4f"
                % (epoch + 1, nepochs, running_loss / run_count, val_loss)
            )
        if early_stopping:
            if val_loss < best_loss:
                best_loss = val_loss
                best_net = copy.deepcopy(net)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter > patience:
                    net = best_net
                    val_loss = best_loss
                    if verbose:
                        print("early stopping!")
                    break

    if "test" in data_loaders:
        key = "test"
    elif "val" in data_loaders:
        key = "val"
    else:
        key = "train"

    test_loss = evaluate_loss(net, data_loaders[key], criterion, device).item()
    if task_type == "regression":
        output = (net, test_loss)
        if verbose:
            print("Finished Training. Test loss: ", test_loss)
    else:
        test_accu = evaluate_metric(net, data_loaders[key], device, accuracy_score, True)
        test_auc = evaluate_metric(net, data_loaders[key], device, roc_auc_score, False)
        output = (net, test_loss, test_accu, test_auc)
        if verbose:
            print("Finished Training. Test loss: %.4f, Test accuracy: %.4f, Test auc: %.4f" % (test_loss, test_accu, test_auc))

    return output

In [8]:
import bisect
import operator
import numpy as np
import torch
from torch.utils import data
import re
from collections import defaultdict

def get_weights(model):
    """Collect the interaction-branch weight matrices of ``model``.

    Returns:
        A pair ``(X_Z_layer_weights, X_Z_mlp_weights)``: the weight matrix
        of the parameter whose name contains "X_Z_layer", and a dict
        mapping each interaction MLP name to the list of its weight
        matrices in ``named_parameters`` order. All values are NumPy arrays.
    """
    # Tried in this order; the first pattern found in a name decides its key.
    mlp_name_patterns = (
        r"X\d+_Z\d+_mlp",  # X_Z_parallel = True, X_Z_pairwise = True
        r"X\d+_Z+_mlp",    # X_Z_parallel = True, group of one X feature with all Z
        r"X+_Z\d+_mlp",    # X_Z_parallel = True, group of one Z feature with all X
        r"X_Z_mlp",        # X_Z_parallel = False
    )

    X_Z_mlp_weights = defaultdict(list)

    for name, param in model.named_parameters():
        if "weight" not in name:
            continue
        if "X_Z_layer" in name:
            X_Z_layer_weights = param.cpu().detach().numpy()
            continue
        for pattern in mlp_name_patterns:
            found = re.search(pattern, name)
            if found:
                X_Z_mlp_weights[found.group(0)].append(param.cpu().detach().numpy())
                break

    return X_Z_layer_weights, X_Z_mlp_weights

def preprocess_weights(weights):
    """Reduce the raw weights to absolute input weights and per-MLP influence.

    Args:
        weights: Pair ``(input_layer_weights, mlp_weights)`` where
            ``mlp_weights`` maps a name to a list of weight matrices
            ordered from first to last layer.

    Returns:
        ``(w_input, w_later)``: the absolute input-layer weights, and a
        dict holding, per name, the product of the absolute weight matrices
        multiplied from the last layer back to the first.
    """
    X_Z_input_weights, X_Z_later_weights = weights
    w_input = np.abs(X_Z_input_weights)

    w_later = {}
    for name, mlp_weights in X_Z_later_weights.items():
        aggregated = np.abs(mlp_weights[-1])
        for layer_weights in reversed(mlp_weights[:-1]):
            aggregated = np.matmul(aggregated, np.abs(layer_weights))
        w_later[name] = aggregated

    return w_input, w_later

def interpret_interactions(w_input, w_later, X_num_features, Z_num_features, X_Z_incoming, X_allZ_layer, Z_allX_layer, X_Z_pairs_repeats):
    """Rank (X feature, Z feature) pairs by interaction strength.

    For every unit of the masked layer and every (X_j, Z_k) pair it
    connects, the strength is the "mean" or "min" of the unit's absolute
    weights on X_j and Z_k, multiplied by the unit's downstream influence
    from ``w_later``.

    Args:
        w_input: Absolute weights of the masked layer, with the X columns
            first and the Z columns after them.
        w_later: Dict of downstream influence arrays, ordered like the
            groups of the masked layer.
        X_num_features: Number of X features.
        Z_num_features: Number of Z features.
        X_Z_incoming: "mean" or "min".
        X_allZ_layer: Whether groups of one X feature with all Z exist
            (non-pairwise layouts only).
        Z_allX_layer: Whether groups of one Z feature with all X exist
            (non-pairwise layouts only).
        X_Z_pairs_repeats: Units per group.

    Returns:
        List of ``((x_index, z_index), strength)`` sorted by decreasing
        strength, with zero-based indices. For non-pairwise layouts the
        strengths of a pair are summed over all units that connect it; for
        the pairwise layout there is one entry per unit.

    Raises:
        ValueError: If ``X_Z_incoming`` is unknown, or the layout is
            non-pairwise and both layer flags are False.
    """
    if X_Z_incoming == "mean":
        combine = np.mean
    elif X_Z_incoming == "min":
        combine = np.min
    else:
        raise ValueError("X_Z_incoming must be 'mean' or 'min', got %r" % (X_Z_incoming,))

    w_later_list = []
    for value in w_later.values():
        if len(value.shape) == 2:
            value = value.flatten()
        w_later_list.extend(value.tolist())

    # The layout is inferred from the number of downstream weights.
    X_Z_pairwise = len(w_later_list) == X_num_features * Z_num_features * X_Z_pairs_repeats
    X_w_input, Z_w_input = w_input[:, :X_num_features], w_input[:, X_num_features:]
    x_w_index, z_w_index = np.arange(X_num_features), np.arange(Z_num_features)

    if X_Z_pairwise:
        X_index, Z_index = np.repeat(x_w_index, Z_num_features * X_Z_pairs_repeats), np.tile(np.repeat(z_w_index, X_Z_pairs_repeats), X_num_features)
        row_index = np.arange(w_input.shape[0])
    else:
        # part1: groups of one X feature with every Z feature
        # part2: groups of one Z feature with every X feature
        X_index_part1, X_index_part2 = np.repeat(x_w_index, Z_num_features * X_Z_pairs_repeats), np.tile(x_w_index, Z_num_features * X_Z_pairs_repeats)
        Z_index_part1, Z_index_part2 = np.tile(z_w_index, X_num_features * X_Z_pairs_repeats), np.repeat(z_w_index, X_num_features * X_Z_pairs_repeats)
        row_index_part1, row_index_part2 = np.repeat(np.arange(X_num_features * X_Z_pairs_repeats), Z_num_features), np.repeat(np.arange(Z_num_features * X_Z_pairs_repeats), X_num_features)

        if X_allZ_layer and Z_allX_layer:
            X_index = np.concatenate((X_index_part1, X_index_part2))
            Z_index = np.concatenate((Z_index_part1, Z_index_part2))
            # Rows of the Z groups follow all rows of the X groups.
            row_index = np.concatenate((row_index_part1, row_index_part2 + X_num_features * X_Z_pairs_repeats))
            w_later_list1, w_later_list2 = np.repeat(w_later_list[:X_num_features*X_Z_pairs_repeats], Z_num_features), np.repeat(w_later_list[X_num_features*X_Z_pairs_repeats:], X_num_features)
            w_later_list = np.concatenate((w_later_list1, w_later_list2))
        elif X_allZ_layer:
            X_index, Z_index, row_index = X_index_part1, Z_index_part1, row_index_part1
            w_later_list = np.repeat(w_later_list, Z_num_features)
        elif Z_allX_layer:
            X_index, Z_index, row_index = X_index_part2, Z_index_part2, row_index_part2
            w_later_list = np.repeat(w_later_list, X_num_features)
        else:
            raise ValueError(
                "For a non-pairwise layout at least one of X_allZ_layer or Z_allX_layer must be True"
            )

    strength = combine([X_w_input[row_index, X_index], Z_w_input[row_index, Z_index]], axis=0) * w_later_list

    if X_Z_pairwise:
        interaction_ranking = list(zip(zip(X_index, Z_index), strength))
    else:
        totals = defaultdict(int)
        for x_i, z_i, value in zip(X_index, Z_index, strength):
            totals[(x_i, z_i)] += value
        interaction_ranking = list(totals.items())

    interaction_ranking.sort(key=lambda x: x[1], reverse=True)

    return interaction_ranking

def make_one_indexed(interaction_ranking):
    """Return a copy of the ranking with every feature index increased by one."""
    shifted = []
    for indices, score in interaction_ranking:
        one_based = tuple(np.array(indices) + 1)
        shifted.append((one_based, score))
    return shifted

def get_interactions(weights, X_num_features, Z_num_features, X_Z_pairs_repeats, X_Z_incoming = "min", X_allZ_layer = True, Z_allX_layer = True, one_indexed=False):
    """Turn model weights into a ranked list of X-Z interactions.

    ``weights`` is first reduced by ``preprocess_weights`` and then scored
    by ``interpret_interactions``. With ``one_indexed=True`` the feature
    indices of the result are shifted to start at 1.
    """
    w_input, w_later = preprocess_weights(weights)
    ranking = interpret_interactions(
        w_input,
        w_later,
        X_num_features,
        Z_num_features,
        X_Z_incoming,
        X_allZ_layer,
        Z_allX_layer,
        X_Z_pairs_repeats,
    )
    return make_one_indexed(ranking) if one_indexed else ranking

# Synthetic function

In [11]:
def synth_asym_func1_ver1(X, Z, task_type):
    """Generate a synthetic response with known X-Z interactions.

    Only the first 15 columns of ``X`` and of ``Z`` enter the response.

    Args:
        X, Z: 2-D arrays with at least 15 columns each.
        task_type: "regression" returns the raw response; "classification"
            passes it through ``generate_binary_response`` (not defined in
            this section of the file).

    Returns:
        ``(Y, ground_truth)`` where ``ground_truth`` lists the one-indexed
        ``(X feature, Z feature)`` pairs that interact.

    Raises:
        ValueError: If ``task_type`` is not one of the two supported values.
    """
    if task_type not in ("regression", "classification"):
        raise ValueError(
            "task_type must be 'regression' or 'classification', got %r" % (task_type,)
        )

    # Extract individual features from X and Z matrices
    X1, X2, X3, X4, X5  = X[:,0], X[:,1], X[:,2], X[:,3], X[:,4]
    X6, X7, X8, X9, X10 = X[:,5], X[:,6], X[:,7], X[:,8], X[:,9]
    X11, X12, X13, X14, X15 = X[:,10], X[:,11], X[:,12], X[:,13], X[:,14]
    Z1, Z2, Z3, Z4, Z5 = Z[:,0], Z[:,1], Z[:,2], Z[:,3], Z[:,4]
    Z6, Z7, Z8, Z9, Z10 = Z[:,5], Z[:,6], Z[:,7], Z[:,8], Z[:,9]
    Z11, Z12, Z13, Z14, Z15 = Z[:,10], Z[:,11], Z[:,12], Z[:,13], Z[:,14]

    # Define two-view interactions between X and Z
    interaction1 = + X1 * Z1
    interaction2 = - X2 * Z1
    interaction3 = + X3 * Z1
    interaction4 = - X4 * Z2
    interaction5 = + X4 * Z3
    interaction6 = - X4 * Z4
    interaction7 = + X5 * Z5
    interaction8 = - X6 * Z5
    interaction9 = + X7 * Z6
    interaction10 = - X8 * Z7
    interaction11 = + X8 * Z8
    interaction12 = - X9 * Z9
    interaction13 = + X9 * Z10
    interaction14 = - X10 * Z11
    interaction15 = + X11 * Z11
    interaction16 = - X12 * Z12
    interaction17 = + X13 * Z12

    # Define within-view effects for X and Z
    X_within_effects = X1 - 2*X2 + X3 - X4 * X5 + X6 - X7 + X8 - X9 + X10 * X11 - X12 + X13 - X14 + X15 - X10 * X13 + X14 * X15
    Z_within_effects = Z1 - Z2 + Z3 - Z4 + 2*Z5 - Z6 * Z7 + Z8 - Z9 + Z10 - Z11 + Z12 - Z13 + Z14 - Z15 + Z13 * Z14 - Z3 * Z15 + Z14 * Z15

    # Generate the linear output
    linear_output = (
        interaction1 + interaction2 + interaction3 + interaction4 + interaction5 +
        interaction6 + interaction7 + interaction8 + interaction9 + interaction10 +
        interaction11 + interaction12 + interaction13 + interaction14 + interaction15 +
        interaction16 + interaction17 + X_within_effects + Z_within_effects
    )

    # Ground truth for the two-view (X-Z) interactions above, one-indexed
    ground_truth = [ (1,1), (2,1), (3,1), (4,2), (4,3), (4,4), (5,5), (6,5), (7,6), (8,7), (8,8), (9,9), (9,10), (10,11), (11,11), (12,12), (13,12)]

    # Generate the response variable Y based on the task type (regression or classification)
    if task_type == "regression":
        Y = linear_output
    else:
        Y = generate_binary_response(linear_output)

    return Y, ground_truth


# Running Example

In [21]:
import numpy as np

# Set the number of features for X and Z
X_num_features, Z_num_features = 100, 20

# Generate random data for X and Z
n = 10000
X = np.random.rand(n, X_num_features)
Z = np.random.rand(n, Z_num_features)

# Generate the synthetic response variable Y and ground truth interactions using synth_asym_func1_ver1
task_type = "regression"
Y, ground_truth = synth_asym_func1_ver1(X, Z, task_type)

# Set the sizes for test, validation, and batch
valid_size = 1125
test_size = 2500
batch_size=100

# Preprocess the data
data_loaders = preprocess_data(X, Z, Y, valid_size=valid_size, test_size=test_size, batch_size=batch_size, std_scale=True)

# Architecture settings used by both the model and the interaction
# detection step; defined once so the two cannot drift apart.
X_Z_pairs_repeats = 10
X_allZ_layer, Z_allX_layer = True, False

# Initialize the TwinterNet model
model = Twinter_Net(
    X_num_features, Z_num_features,
    X_hidden_units = [30, 10, 5],
    Z_hidden_units = [8, 5, 3],
    X_Z_pairs_repeats=X_Z_pairs_repeats,
    X_Z_hidden_units=[10, 10],
    X_Z_pairwise=False,
    X_Z_parallel=True,
    X_allZ_layer=X_allZ_layer,
    Z_allX_layer=Z_allX_layer,
    task_type=task_type
).to(device)

# Train the model; the batches must be moved to the same device as the model
model, loss = train(model, data_loaders, verbose=True, device=device)

# Detect interactions from the model weights
model_weights = get_weights(model)
interactions = get_interactions(
    model_weights, X_num_features, Z_num_features, X_Z_incoming="min",
    X_Z_pairs_repeats=X_Z_pairs_repeats, X_allZ_layer=X_allZ_layer, Z_allX_layer=Z_allX_layer, one_indexed=True
)

# Evaluate the model by calculating AUC
auc = get_auc(interactions, ground_truth)
print("AUC:", auc)


starting to train
early stopping enabled
[epoch 1, total 100] train loss: 0.6389, val loss: 0.0981
[epoch 3, total 100] train loss: 0.0283, val loss: 0.0347
[epoch 5, total 100] train loss: 0.0155, val loss: 0.0240
[epoch 7, total 100] train loss: 0.0111, val loss: 0.0160
[epoch 9, total 100] train loss: 0.0107, val loss: 0.0151
[epoch 11, total 100] train loss: 0.0082, val loss: 0.0148
[epoch 13, total 100] train loss: 0.0092, val loss: 0.0127
[epoch 15, total 100] train loss: 0.0086, val loss: 0.0145
[epoch 17, total 100] train loss: 0.0126, val loss: 0.0151
[epoch 19, total 100] train loss: 0.0075, val loss: 0.0107
[epoch 21, total 100] train loss: 0.0121, val loss: 0.0101
[epoch 23, total 100] train loss: 0.0111, val loss: 0.0160
[epoch 25, total 100] train loss: 0.0075, val loss: 0.0108
[epoch 27, total 100] train loss: 0.0065, val loss: 0.0182
[epoch 29, total 100] train loss: 0.0067, val loss: 0.0136
[epoch 31, total 100] train loss: 0.0080, val loss: 0.0099
[epoch 33, total 100