# Install libraries 


In [15]:
import sys, os, json
import mne, sklearn, wandb
import numpy as np
import pandas as pd

from scipy.interpolate import interp1d
from nilearn import datasets, image, masking, plotting
from nilearn.input_data import NiftiLabelsMasker


# animation part
from IPython.display import HTML
import matplotlib
import matplotlib.pyplot as plt
# from celluloid import Camera   # it is convinient method to animate
from matplotlib import animation, rc
from matplotlib.animation import FuncAnimation


## torch libraries 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader, Subset

from pytorch_model_summary import summary

%load_ext autoreload
%autoreload 2
sys.path.insert(1, os.path.realpath(os.path.pardir))

from utils import get_datasets
from utils import preproc
from utils import torch_dataset
from utils import train_utils
from utils import inference
from utils.models_arch import autoencoder_new_Artur

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Set all hyperparameters
- Cuda and GPU.
- Parameters of dataset. 
- random seed( if necessary). 


In [2]:
# import random

# torch.manual_seed(0)
# random.seed(0)  # python operation seed
# np.random.seed(0)

# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

print(torch.cuda.is_available(), torch.cuda.device_count())
torch.cuda.set_device(0)

True 1


In [29]:
config = dict(  
                dataset_name = 'CWL_raw', # CWL
                patients = ['trio1'], # work only with one patient
                fps = 1000,
                new_fps = 100, 
                crop_start = 5,
                freqs = [-1], 
    
                n_channels = 30, # 63 
                n_roi = 8,
                
                bold_delay = 6,
                to_many = True,
                random_subsample = True,
                sample_per_epoch = 512, 
                WINDOW_SIZE = 2048,
                    
                optimizer='adamW',
                lr=1e-5,  # 5e-5 is too big LMAO
                weight_decay=1e-4, 
                batch_size=16, 
                    
                mse_weight = 1.0,
                corr_weight = 0.0,
                
                preproc_type = 'dB_log',
                loss_function = 'corr', 
                model_type = 'Best_AE_Artur_Multi_Head'
                )


hp_autoencoder = dict(n_electrodes=config['n_channels'],
                      n_freqs = len(config['freqs']),
                      n_channels_out = config['n_roi'],

                     channels = [128, 128, 128, 128], 
                     kernel_sizes=[5, 5, 3],
                     strides=[8, 8, 4], 
                     dilation=[1, 1, 1], 
                     decoder_reduce=4, 
                     hidden_channels = 16,
                     )


config = {**hp_autoencoder, **config}

params_train = {'batch_size': config['batch_size'],
                'shuffle': True,
                'num_workers': 0}

params_val = {'batch_size': config['batch_size'],
              'shuffle': False}

# Upload preprocessed dataset from np files. 
It should accelerate speed of experiments.

In [4]:
# with open("../data/interim/labels_roi_17.json", 'r') as f:
#     labels_roi_17 = json.load(f)
    
labels_roi = ['Left Pallidum',
                 'Left Caudate',
                 'Left Putamen',
                 'Left Accumbens',

                 'Right Pallidum',
                 'Right Caudate',
                 'Right Putamen',
                 'Right Accumbens']

## Download CWL dataset

You can save it into numpy files and the train only on such files. 

In [30]:
path_to_dataset = '../data/eyes_open_closed_dataset/'
dataset_name = 'CWL'
remove_confounds = True

for patient in config['patients']:
    df_eeg_cwl_raw, df_fmri_cwl_raw, labels_roi_17 = get_datasets.download_cwl_dataset(patient, path_to_dataset, 
                                                                        remove_confounds=remove_confounds,
                                                                        verbose=True)
    df_eeg_cwl, df_fmri_cwl, fps = get_datasets.interpolate_df_eeg_fmri(df_eeg_cwl_raw, df_fmri_cwl_raw)
    
    
    config['fps'] = fps
    print('Original FPS', config['fps']) 
    
    # delete time columns. 
    # reshape [time, ch] -> [ch, time]
    eeg_np = df_eeg_cwl.drop(['time'], axis=1).to_numpy().T
    fmri_np = df_fmri_cwl.drop(['time'], axis=1).to_numpy().T
    
    
    eeg_np = preproc.low_level_preproc_eeg(eeg_np, fps)
    
    data = {'eeg': eeg_np, 
            'fmri': fmri_np}
    
    # np.savez(f'../data/preproc/{config["dataset_name"]}/{patient}_{config["fps"]}_filtered_data', 
    #          eeg=eeg_np, fmri=fmri_np)



ALL path:  ../data/eyes_open_closed_dataset/trio1/CWL_Data/eeg/in-scan/trio1_mrcorrected_eoec_in-scan_hpump-off.set ../data/eyes_open_closed_dataset/trio1/CWL_Data/mri/epi_normalized/rwatrio1_eoec_in-scan_hpump-off.nii ../data/eyes_open_closed_dataset/trio1/CWL_Data/mri/epi_motionparams/rp_atrio1_eoec_in-scan_hpump-off.txt


  raw = mne.io.read_raw_eeglab(eeg_path_set_file)
  raw = mne.io.read_raw_eeglab(eeg_path_set_file)
  motion_confound = pd.read_csv(motion_params_path, sep = '  ', header=None)


Dimension of our EEG data:  (303601, 31)
Dimension of our fMRi data:  (61, 72, 61, 146)
Dimension of our fMRi Roi data:  (143, 18)
fMRI info :  1.95
RoI:  ['Left Lateral Ventricle', 'Left Thalamus', 'Left Caudate', 'Left Putamen', 'Left Pallidum', 'Brain-Stem', 'Left Hippocampus', 'Left Amygdala', 'Left Accumbens', 'Right Lateral Ventricle', 'Right Thalamus', 'Right Caudate', 'Right Putamen', 'Right Pallidum', 'Right Hippocampus', 'Right Amygdala', 'Right Accumbens', 'time']
Original FPS 1000
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 6601 samples (6.601 sec)



  return df_eeg, df_fmri, df_fmri.drop(['time'], 1).columns.to_list()


Setting up band-pass filter from 1 - 1e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 100.00 Hz
- Upper transition bandwidth: 25.00 Hz (-6 dB cutoff frequency: 112.50 Hz)
- Filter length: 3301 samples (3.301 sec)



In [4]:
df_eeg_cwl_raw['time'], df_fmri_cwl_raw['time']

NameError: name 'df_eeg_cwl_raw' is not defined

## Preprocessing of datasets

- get only useful fmri ROI 
- Cut starting points 
- Normalize 
- Splitting on train/test
- downsampling to new_fps 
- Shifting between EEG and FMRI 
- Dataset and DataLoader 

In [3]:
with open('../labels_roi_17.json', 'r') as file:
    labels_roi_17 = json.load(file)
    
# Define the file path using config and patient variables
file_path = f"../data/preproc/{config['dataset_name']}/{config['patients'][0]}_{config['fps']}_filtered_data.npz"



# Load the data from the .npz file
data = np.load(file_path)

eeg, fmri = data['eeg'], data['fmri']
df = pd.DataFrame(data = fmri.T, columns=labels_roi_17)
df_filter = df[labels_roi]
fmri = df_filter.to_numpy().T

# crop start
train_crop = config['crop_start']*config['fps']
eeg, fmri = eeg[..., train_crop:], fmri[..., train_crop:]

# normalize 
eeg = eeg / np.std(eeg)
fmri, fmri_means_stds = preproc.normalize_data(fmri)

# train/test split
test_time = int(60*config['fps'])
train_dataset_prep = (eeg[..., :-test_time], fmri[..., :-test_time])
test_dataset_prep = (eeg[..., -test_time:], fmri[..., -test_time:])


ds_factor = config['fps']/config['new_fps']
train_dataset_prep = preproc.downsample_dataset(train_dataset_prep, factor = ds_factor)
test_dataset_prep = preproc.downsample_dataset(test_dataset_prep, factor = ds_factor)

# apply time dealy corrected
train_dataset_prep = preproc.bold_time_delay_align(train_dataset_prep, 
                                                   config['new_fps'],
                                                   config['bold_delay'])
test_dataset_prep = preproc.bold_time_delay_align(test_dataset_prep, 
                                                  config['new_fps'],
                                                  config['bold_delay'])


print('Size of train dataset:', train_dataset_prep[0].shape, train_dataset_prep[1].shape)
print('Size of test dataset:', test_dataset_prep[0].shape, test_dataset_prep[1].shape)

# torch dataset creation 
torch_dataset_train = torch_dataset.CreateDataset_eeg_fmri(train_dataset_prep, 
                                                            random_sample=config['random_subsample'], 
                                                            sample_per_epoch=config['sample_per_epoch'], 
                                                            to_many=config['to_many'], 
                                                            window_size = config['WINDOW_SIZE'])

torch_dataset_test = torch_dataset.CreateDataset_eeg_fmri(test_dataset_prep, 
                                                            random_sample=False, 
                                                            sample_per_epoch=None, 
                                                            to_many=config['to_many'], 
                                                            window_size = config['WINDOW_SIZE'])
print('Size of test dataset:', len(torch_dataset_test))
# because you do not have strid for val data. 
torch_dataset_test = Subset(torch_dataset_test, np.arange(len(torch_dataset_test))[::100])

# init dataloaders for training
train_loader = torch.utils.data.DataLoader(torch_dataset_train, **params_train)
val_loader = torch.utils.data.DataLoader(torch_dataset_test, **params_val)




NameError: name 'json' is not defined

# Init Model, Loss, optimizers

In [10]:
model = autoencoder_new_Artur.AutoEncoder1D_Artur_MultiHead(hp_autoencoder)

print(summary(model, torch.zeros(4, config['n_channels'],
                                 config['WINDOW_SIZE']), show_input=False))


-----------------------------------------------------------------------------
            Layer (type)        Output Shape         Param #     Tr. Param #
   AutoEncoder1D_Artur-1        [4, 1, 2048]         245,057         245,057
   AutoEncoder1D_Artur-2        [4, 1, 2048]         245,057         245,057
   AutoEncoder1D_Artur-3        [4, 1, 2048]         245,057         245,057
   AutoEncoder1D_Artur-4        [4, 1, 2048]         245,057         245,057
   AutoEncoder1D_Artur-5        [4, 1, 2048]         245,057         245,057
   AutoEncoder1D_Artur-6        [4, 1, 2048]         245,057         245,057
   AutoEncoder1D_Artur-7        [4, 1, 2048]         245,057         245,057
   AutoEncoder1D_Artur-8        [4, 1, 2048]         245,057         245,057
Total params: 1,960,456
Trainable params: 1,960,456
Non-trainable params: 0
-----------------------------------------------------------------------------


# Model training

In [23]:
n_runs = 1

for i in range(n_runs):
    
    model = autoencoder_new_Artur.AutoEncoder1D_Artur_MultiHead(hp_autoencoder)

    loss_func = train_utils.make_complex_loss_function(mse_weight = config['mse_weight'], 
                                                       corr_weight = config['corr_weight'],
                                                       manifold_weight = 0,
                                                       bound=1)
    
    train_step = train_utils.train_step

    optimizer = optim.AdamW(model.parameters(), 
                       lr=config['lr'], 
                       weight_decay=config['weight_decay'])
    
    
    parameters = {
        'EPOCHS': 1500,
        'model': model, 
        'train_loader': train_loader, 
        'val_loader': val_loader, 
        'loss_function': loss_func,
        'train_step': train_step,
        'optimizer': optimizer, 
        'device': 'cuda', 
        'raw_test_data': test_dataset_prep,
        'show_info': 30, 
        'num_losses': 5,
        'labels': labels_roi,
        'inference_function': inference.model_inference_function, 
        'to_many': config['to_many']
    }

    path_to_save_wandb = '/home/lyz6/Documents'

    with wandb.init(project="eeg_fmri", config=config, save_code=True):
        wandb.define_metric("val/corr_mean", summary="max")

        if i == 0: 
            exp_name = wandb.run.name
        
        wandb.run.name = exp_name +'_run_' + str(i)
        
        print(config)
        print(parameters['model'])
        print(summary(model, torch.zeros(4, config['n_channels'], config['WINDOW_SIZE']), show_input=False))
        
        model = train_utils.wanb_train_regression(**parameters)
        

Problem at: /tmp/tmp.9tdT0GhyR8/ipykernel_3837830/1333472953.py 38 <module>


KeyboardInterrupt: 

In [20]:
# DEEPONET
import json

with open('../labels_roi_17.json', 'r') as file:
    labels_roi_17 = json.load(file)
    
# Define the file path using config and patient variables
file_path = f"../data/preproc/{config['dataset_name']}/{config['patients'][0]}_{config['fps']}_filtered_data.npz"

# Load the data from the .npz file
data = np.load(file_path)

eeg, fmri = data['eeg'], data['fmri']
df = pd.DataFrame(data = fmri.T, columns=labels_roi_17)
df_filter = df[labels_roi]
fmri = df_filter.to_numpy().T

# crop start
train_crop = config['crop_start']*config['fps']
eeg, fmri = eeg[..., train_crop:], fmri[..., train_crop:]

# normalize 
eeg = eeg / np.std(eeg)
fmri, fmri_means_stds = preproc.normalize_data(fmri)

# train/test split
test_time = int(60*config['fps'])
train_dataset_prep = (eeg[..., :-test_time], fmri[..., :-test_time])
test_dataset_prep = (eeg[..., -test_time:], fmri[..., -test_time:])


ds_factor = config['fps']/config['new_fps']
train_dataset_prep = preproc.downsample_dataset(train_dataset_prep, factor = ds_factor)
test_dataset_prep = preproc.downsample_dataset(test_dataset_prep, factor = ds_factor)

# apply time dealy corrected
train_dataset_prep = preproc.bold_time_delay_align(train_dataset_prep, 
                                                   config['new_fps'],
                                                   config['bold_delay'])
test_dataset_prep = preproc.bold_time_delay_align(test_dataset_prep, 
                                                  config['new_fps'],
                                                  config['bold_delay'])


print('Size of train dataset:', train_dataset_prep[0].shape, train_dataset_prep[1].shape)
print('Size of test dataset:', test_dataset_prep[0].shape, test_dataset_prep[1].shape)

# torch dataset creation 
torch_dataset_train = torch_dataset.CreateDeepONetDataset(train_dataset_prep, 
                                                            random_sample=config['random_subsample'], 
                                                            sample_per_epoch=config['sample_per_epoch'], 
                                                            to_many=config['to_many'], 
                                                            window_size = config['WINDOW_SIZE'])

torch_dataset_test = torch_dataset.CreateDeepONetDataset(test_dataset_prep, 
                                                            random_sample=False, 
                                                            sample_per_epoch=None, 
                                                            to_many=config['to_many'], 
                                                            window_size = config['WINDOW_SIZE'])
print('Size of test dataset:', len(torch_dataset_test))
# because you do not have strid for val data. 
torch_dataset_test = Subset(torch_dataset_test, np.arange(len(torch_dataset_test))[::100])

# init dataloaders for training
train_loader = torch.utils.data.DataLoader(torch_dataset_train, **params_train)
val_loader = torch.utils.data.DataLoader(torch_dataset_test, **params_val)

Size of train dataset: (30, 20590) (8, 20590)
Size of test dataset: (30, 5400) (8, 5400)
Size of test dataset: 3351


In [30]:


# ========== MODEL DEFINITION ==========
import torch
import torch.nn as nn
import torch.nn.functional as F

def initialize_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class BranchNet(nn.Module):
    def __init__(self, input_channels, window_size, output_dim=8):  # Reduced output_dim
        super(BranchNet, self).__init__()
        self.input_channels = input_channels
        self.window_size = window_size
        
        self.conv1 = nn.Sequential(
            nn.Conv1d(input_channels, 128, kernel_size=5, stride=2, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.conv2 = nn.Conv1d(128, 64, kernel_size=5, stride=2, padding=3)
        
        self.pool = nn.AdaptiveAvgPool1d(1)
        
        self.fc = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, output_dim)
        )

        self.apply(initialize_weights)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        
        x = self.pool(x).squeeze(-1)
        x = self.fc(x)
        return x

class TrunkNet(nn.Module):
    def __init__(self, input_dim, output_dim=8):  # Keep same output_dim
        super(TrunkNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Dropout(0.1),  # Dropout added
            nn.Linear(16, output_dim)
        )

        self.apply(initialize_weights)

    def forward(self, x):
        return self.fc(x)

class DeepONet(nn.Module):
    def __init__(self, 
                 input_channels, 
                 window_size, 
                 n_roi, 
                 branch_output_dim=128,   
                 trunk_output_dim=128):    
        super(DeepONet, self).__init__()
        
        self.branch = BranchNet(
            input_channels=input_channels, 
            window_size=window_size, 
            output_dim=branch_output_dim
        )
        
        self.trunk = TrunkNet(
            input_dim=window_size * n_roi, 
            output_dim=trunk_output_dim
        )
        
        self.window_size = window_size
        self.n_roi = n_roi

    def forward(self, inputs):
        eeg, trunk_inputs = inputs

        branch_out = self.branch(eeg)
        trunk_out = self.trunk(trunk_inputs)

        output = torch.einsum('bi,bj->bij', branch_out, trunk_out)
        output = output.view(-1, self.n_roi, self.window_size)
        
        return output

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_step(x_batch, y_batch, model, optimizer, loss_function):
    optimizer.zero_grad()
    
    # Gradient clipping
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    outputs = model(x_batch)
    losses = loss_function(outputs, y_batch)
    
    #losses = [loss + 1e-8 if torch.any(loss == 0) else loss for loss in losses]
    #print(losses)

    losses[0].backward()
    optimizer.step()
    
    return losses

n_runs = 1

for i in range(n_runs):
    # Model initialization
    model = DeepONet(
        input_channels=config['n_channels'], 
        window_size=config['WINDOW_SIZE'], 
        n_roi=config['n_roi']
    )

    loss_func = train_utils.make_new_complex_loss_function(mse_weight = config['mse_weight'], 
                                                       corr_weight = config['corr_weight'])
    
    train_step = train_step

    optimizer = optim.AdamW(model.parameters(), 
                       lr=config['lr'], 
                       weight_decay=config['weight_decay'])
    
    
    parameters = {
        'EPOCHS': 4500,
        'model': model, 
        'train_loader': train_loader, 
        'val_loader': val_loader, 
        'loss_function': loss_func,
        'train_step': train_step,
        'optimizer': optimizer, 
        'device': 'cuda', 
        'raw_test_data': test_dataset_prep,
        'show_info': 50, 
        'num_losses': 5,
        'labels': labels_roi,
        'inference_function': inference.new_model_inference_function, 
        'to_many': config['to_many']
    }

    path_to_save_wandb = '/home/lyz6/Documents'

    with wandb.init(project="eeg_fmri", config=config, save_code=True):
        wandb.define_metric("val/corr_mean", summary="max")

        if i == 0: 
            exp_name = wandb.run.name
        
        wandb.run.name = exp_name +'_run_' + str(i)
        
        print(config)
        print(parameters['model'])
        
        # Create dummy inputs for the model: EEG signals and time indices
        dummy_eeg = torch.zeros(4, config['n_channels'], config['WINDOW_SIZE'])  # Batch of 4 EEG samples
        dummy_time = torch.zeros(4, config['WINDOW_SIZE'] * 8)  # Batch of 4 normalized time indices
        print(dummy_eeg.shape)
        print(dummy_time.shape)
        
        # Call the summary function with both inputs
        print(summary(model, (dummy_eeg, dummy_time), show_input=False))

        model = train_utils.new_train_regression(**parameters)
        

[34m[1mwandb[0m: wandb version 0.18.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


{'n_electrodes': 30, 'n_freqs': 1, 'n_channels_out': 8, 'channels': [128, 128, 128, 128], 'kernel_sizes': [5, 5, 3], 'strides': [8, 8, 4], 'dilation': [1, 1, 1], 'decoder_reduce': 4, 'hidden_channels': 16, 'dataset_name': 'CWL_raw', 'patients': ['trio1'], 'fps': 1000, 'new_fps': 100, 'crop_start': 5, 'freqs': [-1], 'n_channels': 30, 'n_roi': 8, 'bold_delay': 6, 'to_many': True, 'random_subsample': True, 'sample_per_epoch': 512, 'WINDOW_SIZE': 2048, 'optimizer': 'adamW', 'lr': 1e-05, 'weight_decay': 0.0001, 'batch_size': 16, 'mse_weight': 1.0, 'corr_weight': 0.0, 'preproc_type': 'dB_log', 'loss_function': 'corr', 'model_type': 'Best_AE_Artur_Multi_Head'}
DeepONet(
  (branch): BranchNet(
    (conv1): Sequential(
      (0): Conv1d(30, 128, kernel_size=(5,), stride=(2,), padding=(3,))
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Conv1d(128, 64, kernel_size=(5,), stride=(2,), padding=(3,))
    (pool): Adapti

In [None]:

"""
class DeepONet(nn.Module):
    def __init__(self, n_channels, window_size, n_roi, latent_dim=16):
        super(DeepONet, self).__init__()
        self.branch = BranchNet(input_channels=n_channels, output_dim=latent_dim)
        self.trunk = TrunkNet(input_dim=window_size, output_dim=latent_dim)
        self.n_roi = n_roi  # Number of regions of interest (output dimension)

    def forward(self, inputs):
        eeg, trunk_inputs = inputs
        batch_size, time_steps = trunk_inputs.size(0), trunk_inputs.size(1)

        # Outputs from BranchNet and TrunkNet
        branch_out = self.branch(eeg)  # [batch_size, latent_dim]
        trunk_out = self.trunk(trunk_inputs)  # [batch_size, time_steps, latent_dim]

        # Compute the inner product across the latent_dim
        branch_out = branch_out.unsqueeze(1)  # [batch_size, 1, latent_dim]
        output = torch.sum(branch_out * trunk_out, dim=-1)  # [batch_size, time_steps]

        # Add extra dimension for ROI-specific outputs
        output = output.unsqueeze(1).expand(-1, self.n_roi, -1)  # [batch_size, n_roi, time_steps]
        return output

"""
"""
class DeepONet(nn.Module):
    def __init__(self, 
                 n_channels, 
                 window_size, 
                 n_roi,
                 branch_out_dim=16, 
                 trunk_out_dim=8):
        super(DeepONet, self).__init__()
        self.branch = BranchNet(input_channels=n_channels, output_dim=branch_out_dim)
        
        self.trunk = TrunkNet(
            input_dim=window_size,  # Entire time series as input
            output_dim=trunk_out_dim
        )
        
        # Prediction layer
        self.fc = nn.Linear(branch_out_dim + trunk_out_dim, n_roi)

    def forward(self, inputs):
        eeg, trunk_inputs = inputs  # Unpack the inputs tuple
        branch_out = self.branch(eeg)  # [batch_size, branch_out_dim]
        trunk_out = self.trunk(trunk_inputs)  # [batch_size, trunk_out_dim]
        combined = torch.cat([branch_out, trunk_out], dim=-1)
        return self.fc(combined)
"""

# 