In [None]:
import torch
import numpy as np
import os
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
import cv2
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential


In [None]:
#load dataset
#data already prepared before using in this notebook
# NOTE: torch.load unpickles the file, which can execute arbitrary code —
# only load dataset files from a trusted source.
data_dir = 'path to main directory'  # not `dir`, which would shadow the Python builtin
data_train = torch.load(os.path.join(data_dir, 'eeg train dataset file'))
data_test = torch.load(os.path.join(data_dir, 'eeg test dataset file'))
# Presumably lists of per-sample dicts (later cells read their 'eeg' and 'image'
# keys) — TODO confirm against the data-preparation script.
ds_train = data_train['dataset']
ds_test = data_test['dataset']

In [None]:
#load bilstm encoder model
# Load the trained BiLSTM EEG model and cut off its final layer, so the model's
# output is the penultimate layer's activations (used as the EEG embedding below).
_full_bilstm = load_model('path to encoder model file')
bilstmencoder = Model(
    inputs=_full_bilstm.inputs,
    outputs=_full_bilstm.layers[-2].output,
)
bilstmencoder.summary()

In [None]:
#preprocessing steps
# Build EEG inputs (X) and 64x64 RGB target images (Y) for train and test.
# EEG samples shorter than EEG_LENGTH columns are dropped, and the SAME samples are
# dropped from the targets, so X[i] and Y[i] always belong to the same pair.

EEG_LENGTH = 480  # number of EEG columns kept per sample
IMAGE_DIR = 'path to dir with images that are part of the eeg-image pairs dataset'

X_train = []
X_test = []

def preprocess_X(X, ds) -> None:
    """Append each sample's ``'eeg'`` tensor, converted to NumPy, to the list ``X`` in place.

    Args:
        X: list that receives one array per sample.
        ds: iterable of per-sample dicts holding an ``'eeg'`` tensor.
    """
    for sample in ds:
        X.append(sample['eeg'].numpy())

preprocess_X(X_train, ds_train)
preprocess_X(X_test, ds_test)

def trim(X, max_cols, return_mask=False):
    """Crop each 2-D array in ``X`` to its first ``max_cols`` columns and stack them.

    Arrays with fewer than ``max_cols`` columns are dropped, so the result can be
    shorter than ``X``.

    Args:
        X: sequence of 2-D arrays.
        max_cols: number of leading columns to keep.
        return_mask: if True, also return the boolean keep-mask (one entry per
            element of ``X``) so paired data can be filtered identically.

    Returns:
        The stacked array, or ``(array, keep_mask)`` when ``return_mask`` is True.
    """
    keep = np.array([arr.shape[1] >= max_cols for arr in X], dtype=bool)
    X_trimmed = np.array([arr[:, :max_cols] for arr, kept in zip(X, keep) if kept])
    return (X_trimmed, keep) if return_mask else X_trimmed

X_train, keep_train = trim(X_train, EEG_LENGTH, return_mask=True)
X_test, keep_test = trim(X_test, EEG_LENGTH, return_mask=True)

# ds_train / ds_test are per-sample lists (iterated above), so they cannot be indexed
# with 'images'. The image-name table is presumably the top-level 'images' entry of
# the loaded file, alongside 'dataset' — TODO confirm.
image_ids_train = data_train['images']
image_ids_test = data_test['images']

def resize_image(image, target_size=(64, 64)):
    """Convert a BGR image (as returned by cv2.imread) to RGB and resize it to ``target_size`` (width, height)."""
    return cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), target_size)

def load_target_images(samples, image_ids, image_dir=IMAGE_DIR):
    """Load the image paired with each sample, resized by ``resize_image`` and scaled by 1/255.

    Args:
        samples: per-sample dicts whose ``'image'`` entry indexes ``image_ids``.
        image_ids: table of image file names (without the ``.JPEG`` extension).
        image_dir: directory containing the image files.

    Returns:
        Array with one float image per sample.

    Raises:
        FileNotFoundError: if an image cannot be read (cv2.imread returned None).
    """
    images = []
    for sample in samples:
        path = os.path.join(image_dir, image_ids[sample['image']] + '.JPEG')
        bgr = cv2.imread(path)
        if bgr is None:  # imread signals failure by returning None, not by raising
            raise FileNotFoundError(f'could not read image: {path}')
        images.append(resize_image(bgr) / 255)
    return np.array(images)

# Only load targets for samples whose EEG survived trimming, keeping pairs aligned.
Y_train = load_target_images(
    [sample for sample, kept in zip(ds_train, keep_train) if kept], image_ids_train)
Y_test = load_target_images(
    [sample for sample, kept in zip(ds_test, keep_test) if kept], image_ids_test)

assert len(X_train) == len(Y_train), 'train EEG/image pairs are misaligned'
assert len(X_test) == len(Y_test), 'test EEG/image pairs are misaligned'

In [None]:
#load the pretrained vae
# Rebuild the VAE architecture so the pretrained weights can be loaded into it.
# NOTE(review): the saved weights presumably expect exactly this layer sequence —
# do not reorder or change these layers without re-training the VAE.

input_shape = (64, 64, 3) 
encoder_input = Input(shape=input_shape, name='encoder_input')
# Output sizes (H x W x C) follow Keras 'same'/'valid' padding arithmetic.
x = Conv2D(64, 3, strides=1, activation = 'relu', padding = 'same')(encoder_input)  # 64x64x64
x = MaxPooling2D()(x)  # 32x32x64
x = Conv2D(64, 5, strides=2, activation = 'relu', padding = 'same')(x)  # 16x16x64
x = Conv2D(128, 3, strides=1, activation = 'relu')(x)  # 14x14x128
x = Conv2D(128, 5, strides=1, activation = 'relu')(x)  # 10x10x128
x = Conv2D(256, 3, strides=1, activation = 'relu')(x)  # 8x8x256
x = Conv2D(512, 3, strides=2, activation = 'relu')(x)  # 3x3x512
x = Conv2D(64, 3, strides=2, activation = 'relu', padding = 'same')(x)  # 2x2x64
x = Flatten()(x)  # 256 features

# Must equal 2*2*64 so the decoder can Reshape((2, 2, 64)) a latent vector
# (it also happens to match the flattened encoder size).
latent_dim = 2*2*64
# Two linear heads parameterizing q(z|x) = N(z_mean, exp(z_log_var)).
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

def sampling(args):
    """Reparameterization trick: z = z_mean + exp(0.5 * z_log_var) * eps, with eps ~ N(0, I).

    Args:
        args: pair ``(z_mean, z_log_var)`` of equally shaped tensors.

    Returns:
        A sampled tensor with the same shape as ``z_mean``.
    """
    z_mean, z_log_var = args
    # Take the noise shape from z_mean itself instead of the global `latent_dim`,
    # so this layer function does not silently depend on notebook state.
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.)
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Sample z with the reparameterization trick (see `sampling`).
z = tf.keras.layers.Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# Encoder: image -> (z_mean, z_log_var, sampled z).
encoder = Model(encoder_input, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

# Decoder: latent vector -> 64x64x3 image with values in [0, 1] (sigmoid output).
# Output sizes (H x W x C) follow Keras Conv2DTranspose 'same'/'valid' arithmetic.
latent_inputs = Input(shape=(latent_dim,), name='latent_inputs')
x = Reshape((2, 2, 64))(latent_inputs)  # 2x2x64
x = Conv2DTranspose(512, 3, strides=2, activation='relu')(x)  # 5x5x512
x = Conv2DTranspose(256, 5, strides=1, activation='relu')(x)  # 9x9x256
x = Conv2DTranspose(256, 3, strides=1, activation='relu')(x)  # 11x11x256
x = Conv2DTranspose(256, 5, strides=1, activation='relu', padding = 'same')(x)  # 11x11x256
x = Conv2DTranspose(128, 3, strides=1, activation='relu')(x)  # 13x13x128
x = Conv2DTranspose(128, 4, strides=1, activation='relu')(x)  # 16x16x128
x = UpSampling2D()(x)  # 32x32x128
x = Conv2DTranspose(64, 3, strides=2, padding = 'same', activation='relu')(x)  # 64x64x64
decoder_output = Conv2DTranspose(3, 3, strides=1,padding = 'same', activation='sigmoid')(x)  # 64x64x3



decoder = Model(latent_inputs, decoder_output, name='decoder')
decoder.summary()

# End-to-end VAE: feed the sampled z (encoder output index 2) into the decoder.
vae_outputs = decoder(encoder(encoder_input)[2])
vae = Model(encoder_input, vae_outputs, name='vae')

def custom_loss(y_true, y_pred):
    """VAE objective: mean squared reconstruction error plus the batch-mean KL divergence
    between N(z_mean, exp(z_log_var)) and the unit Gaussian prior.

    NOTE(review): `z_mean` and `z_log_var` are the symbolic encoder tensors captured
    from notebook scope, not arguments. Recent TF2/Keras versions typically reject
    losses that close over symbolic model tensors once fit() runs. In this notebook
    the VAE is only compiled before loading weights and is never fit, so the loss is
    not evaluated. If the VAE is retrained, move the KL term into `vae.add_loss(...)`.
    """
    reconstruction = K.mean(K.square(y_true - y_pred))
    kl_per_sample = -0.5 * K.sum(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
        axis=-1,
    )
    return reconstruction + K.mean(kl_per_sample)

# Compile (optimizer and loss only matter if the VAE is trained again), print the
# architecture, then restore the pretrained weights into the graph built above.
vae.compile(loss=custom_loss, optimizer='adam')
vae.summary()
vae.load_weights('path to vae weigths')


In [None]:
#add dense layer to encoder and pretrained vae decoder to create combined model
# Map BiLSTM EEG embeddings into the VAE latent space with one new Dense layer,
# then decode them to images with the pretrained VAE decoder.

# `vae.get_layer('decoder')` already is the pretrained Model from latent vector to
# image, so it is reused directly rather than re-wrapped in a new Model.
# Its layers keep the default trainable=True, so the decoder is fine-tuned too.
decoder_model = vae.get_layer('decoder')
decoder_model.summary()

bilstmencoder_input = bilstmencoder.input
bilstmencoder_output = bilstmencoder.output
# Width must match the decoder's latent input. No activation (linear): the decoder
# was trained on z = z_mean + exp(0.5*z_log_var)*eps, where z_mean comes from a
# linear Dense, so latent coordinates can be negative. A ReLU here would clamp the
# predicted latent to >= 0 and make half of each latent axis unreachable.
x = Dense(latent_dim, name='eeg_to_latent')(bilstmencoder_output)

modified_bilstmencoder = Model(inputs=bilstmencoder_input, outputs=x)

# Freeze the pretrained BiLSTM; of this sub-model, only the new Dense layer is trained.
for layer in modified_bilstmencoder.layers[:-1]:
    layer.trainable = False

combined_model = Sequential([modified_bilstmencoder, decoder_model])
combined_model.summary()
combined_model.compile(optimizer='adam', loss='mae')


In [None]:
#training loop
# Stop once val_loss has not improved for 8 epochs and roll back to the best epoch.
# NOTE(review): Keras takes `validation_split` from the *last* 5% of X_train/Y_train
# before shuffling. If the arrays are ordered (e.g. by class or subject), the
# validation set is not representative — consider shuffling X/Y jointly first.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    restore_best_weights=True,
)
history = combined_model.fit(
    X_train,
    Y_train,
    batch_size=32,
    epochs=100,
    validation_split=0.05,
    callbacks=[early_stopping],
)

In [None]:
#loss curves
# Training vs. validation loss per epoch, as recorded by fit().
fig, ax = plt.subplots()
ax.plot(history.history['loss'], color='teal', label='loss')
ax.plot(history.history['val_loss'], color='orange', label='val_loss')
fig.suptitle('Loss', fontsize=20)
ax.legend(loc="upper left")
plt.show()

In [None]:
# load test images to visualize ground truths
# Full-resolution RGB copies of the test targets, for display only.
# Keep only samples whose EEG has at least 480 columns. This is the same filter
# applied when X_test was trimmed, so Y_visualize[i] matches the i-th generated image.
# ds_test holds per-sample dicts (the preprocessing cell reads their 'eeg' key), so
# it is iterated directly. The image-name table is presumably the top-level 'images'
# entry of the loaded test file — TODO confirm.
visualize_image_dir = 'path to dir with images that are part of the eeg-image pairs dataset'
test_image_names = data_test['images']
Y_visualize = []
for sample in ds_test:
    if sample['eeg'].shape[1] < 480:  # must match the EEG trim length used in preprocessing
        continue
    path = os.path.join(visualize_image_dir, test_image_names[sample['image']] + '.JPEG')
    bgr = cv2.imread(path)
    if bgr is None:  # imread signals failure by returning None, not by raising
        raise FileNotFoundError(f'could not read image: {path}')
    Y_visualize.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

In [None]:
#show generated images
# Left column: images decoded from test EEG; right column: the matching ground truth.
generated_images = combined_model.predict(X_test)
# Show at most 50 rows, and never more than the available samples (avoids IndexError).
num_images_to_display = min(50, len(generated_images), len(Y_visualize))

if num_images_to_display > 0:
    fig, axes = plt.subplots(
        num_images_to_display, 2,
        figsize=(6, 2 * num_images_to_display),
        squeeze=False,  # keep `axes` 2-D even when only one row is shown
    )
    for row in range(num_images_to_display):
        axes[row, 0].imshow(generated_images[row])
        axes[row, 1].imshow(Y_visualize[row])
        axes[row, 0].axis('off')
        axes[row, 1].axis('off')
    axes[0, 0].set_title('generated')
    axes[0, 1].set_title('ground truth')
    plt.show()