# VAE training and processing

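Sample code to train a new variational auto-encoder (VAE) on CSI data (or load pre-trained weights) and use its encoder to map the CSI data to a latent-space representation.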

In [1]:
import os
import math
import string

import pickle
import numpy as np
import scipy.io as sio
import tensorflow as tf

In [17]:
ANTENNAS = 1
BATCH_SIZE = 25

antenna = 0  # if ANTENNAS==1, this value selects the antenna ID (from 0 to 3)
latent_dim = 2
num_activities = 12

# One model folder per (antenna selection, latent size) combination.
if ANTENNAS == 1:
    vae_name = f'vae_s1a_a{antenna}_ls{latent_dim}'
elif ANTENNAS == 4:
    vae_name = f'vae_s1a_f_ls{latent_dim}'
else:
    # Fail fast: every later cell needs vae_name / folder_name to be defined.
    raise ValueError(f'Invalid number of antennas: {ANTENNAS} (expected 1 or 4)')
folder_name = f'VAE_models_12activities/{vae_name}'

## Download the dataset

The original dataset is hosted on Zenodo. Download the dataset and place it in the `dataset` directory (the cell below does this automatically).

In [None]:
# Fetch the S1 archive from Zenodo and move its extracted S1 folder to ./dataset.
# NOTE: any existing ./dataset directory is deleted before extraction.
!wget https://zenodo.org/record/7732595/files/S1.zip
!rm -rf dataset
!unzip S1.zip
!mv S1 dataset
!rm S1.zip

## CSI data generator

In [3]:
class CsiDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving sliding windows of CSI amplitude data with activity labels.

    Each MATLAB file in ``files`` holds one activity; the file's position in ``files`` is used as
    its integer label. For every file, ``num_samples - window_size`` windows of ``window_size``
    consecutive rows are produced, all taken from the file's first ``num_samples`` rows, so a
    window never spans two files.

    Args:
        files: list of paths to ``.mat`` files, each containing a ``'csi'`` variable.
        num_samples: number of CSI rows per file from which windows are taken.
        window_size: number of consecutive rows in one window.
        antennas: 1 to keep a single antenna (chosen by ``antenna_select``); otherwise the size of
            the antenna axis kept in the data.
        batch_size: number of windows per batch.
        antenna_select: antenna index used when ``antennas == 1``.

    Raises:
        ValueError: if ``window_size >= num_samples`` or a file has fewer than ``num_samples`` rows.
    """

    def __init__(self, files, num_samples=12000, window_size=450, antennas=1, batch_size=25, antenna_select=0):
        super().__init__()
        if window_size >= num_samples:
            raise ValueError(f'window_size ({window_size}) must be smaller than num_samples ({num_samples})')

        self.window_size = window_size
        self.batch_size = batch_size
        self.antennas = antennas

        # Start from empty tensors so an empty ``files`` list still yields a valid (empty) dataset.
        if antennas == 1:
            csi_parts = [tf.zeros([0, 2048], dtype=tf.float32)]
        else:
            csi_parts = [tf.zeros([0, 2048, antennas], dtype=tf.float32)]
        label_parts = [tf.zeros([0], dtype=tf.int32)]
        index_parts = [tf.zeros([0], dtype=tf.int32)]

        windows_per_file = num_samples - window_size
        index_offset = 0  # row of the concatenated CSI tensor where the current file starts
        for activity_label, file in enumerate(files):  # labels depend on file position
            # Load CSI data from MATLAB file (I/O exceptions are not handled, for simplicity)
            mat = sio.loadmat(file)
            data = np.array(mat['csi'])
            if data.shape[0] < num_samples:
                raise ValueError(f'{file}: expected at least {num_samples} CSI samples, found {data.shape[0]}')
            if self.antennas == 1:
                data = data[:num_samples, ..., int(antenna_select)]
            data = np.round(np.abs(data))

            csi_parts.append(tf.convert_to_tensor(data, dtype=tf.float32))
            label_parts.append(tf.fill([windows_per_file], tf.constant(activity_label, dtype=tf.int32)))
            # Row at which each window of this file starts in the concatenated CSI tensor.
            index_parts.append(tf.range(index_offset, index_offset + windows_per_file, dtype=tf.int32))
            index_offset += data.shape[0]

        # Concatenate once rather than growing tensors in the loop (avoids quadratic copying).
        self.csi = tf.concat(csi_parts, axis=0)
        self.labels = tf.concat(label_parts, axis=0)
        self.indices = tf.concat(index_parts, axis=0)

        # Normalize the whole CSI dataset by its global (scalar) maximum; divide_no_nan yields
        # zeros instead of NaNs if every amplitude is zero.
        self.csi = tf.math.divide_no_nan(self.csi, tf.math.reduce_max(self.csi))

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller than ``batch_size``)."""
        return int(np.ceil(self.indices.shape[-1] / float(self.batch_size)))

    def __getitem__(self, batch_idx):
        """Return batch ``batch_idx`` as ``(data, labels)``.

        ``data`` is (B, window_size, 2048, 1) for a single antenna and
        (B, window_size, 2048, antennas) otherwise; ``labels`` is (B, 1, 1) for a single antenna
        and (B, 1) otherwise.
        """
        first_idx = batch_idx * self.batch_size
        last_idx = (batch_idx + 1) * self.batch_size  # slicing clamps this for the last batch

        # self.indices maps each window to the CSI row where it starts; using it (rather than the
        # window's position in the batch sequence) keeps windows inside one file and aligned
        # with their labels.
        starts = self.indices[first_idx:last_idx]
        window_rows = starts[:, tf.newaxis] + tf.range(self.window_size, dtype=tf.int32)[tf.newaxis, :]
        data_batch = tf.gather(self.csi, window_rows)  # (B, window_size, 2048[, antennas])
        labels_batch = tf.expand_dims(self.labels[first_idx:last_idx], axis=1)  # (B, 1)

        if self.antennas == 1:
            data_batch = tf.expand_dims(data_batch, 3)
            labels_batch = tf.expand_dims(labels_batch, 2)

        return data_batch, labels_batch

In [4]:
# One .mat file per activity, named with consecutive capital letters (A, B, C, ...).
file_list = [f'./dataset/S1a_{x}.mat' for x in string.ascii_uppercase[:num_activities]]
print(file_list)

# Pass ANTENNAS through so the generated data matches the configured model
# (one selected antenna, or all antennas).
csi_generator = CsiDataGenerator(file_list, antennas=ANTENNAS, batch_size=BATCH_SIZE, antenna_select=antenna)

['./dataset/S1a_A.mat', './dataset/S1a_B.mat', './dataset/S1a_C.mat', './dataset/S1a_D.mat', './dataset/S1a_E.mat', './dataset/S1a_F.mat', './dataset/S1a_G.mat', './dataset/S1a_H.mat', './dataset/S1a_I.mat', './dataset/S1a_J.mat', './dataset/S1a_K.mat', './dataset/S1a_L.mat']


## Variational Auto-Encoder

In [5]:
class Sampling(tf.keras.layers.Layer):
    """Reparameterization trick: returns z = z_mean + exp(z_log_var / 2) * epsilon, epsilon ~ N(0, I)."""

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise
    
def create_csi_encoder(input_shape, latent_dim):
    """Build the convolutional VAE encoder.

    Returns a ``tf.keras.Model`` named 'encoder' mapping an input of ``input_shape`` to
    ``[z_mean, z_log_var, z]``, each of width ``latent_dim``.
    """
    # (filters, kernel) per conv layer; stride equals kernel so patches do not overlap.
    conv_specs = [(32, (5, 8)), (32, (5, 8)), (32, (2, 4))]

    inputs = tf.keras.Input(shape=input_shape)
    hidden = inputs
    for filters, kernel in conv_specs:
        hidden = tf.keras.layers.Conv2D(filters, kernel, activation='relu', strides=kernel, padding='valid')(hidden)
    hidden = tf.keras.layers.Flatten()(hidden)
    hidden = tf.keras.layers.Dense(16, activation='relu')(hidden)

    z_mean = tf.keras.layers.Dense(latent_dim, name='z_mean')(hidden)
    z_log_var = tf.keras.layers.Dense(latent_dim, name='z_log_var')(hidden)
    z = Sampling()([z_mean, z_log_var])
    return tf.keras.Model(inputs, [z_mean, z_log_var, z], name='encoder')


def create_csi_decoder(input_shape, latent_dim, out_filter):
    """Build the transposed-convolution VAE decoder.

    A latent vector of size ``latent_dim`` is projected by a dense layer, reshaped to
    ``input_shape`` and upsampled; the sigmoid output layer has ``out_filter`` channels.
    """
    # (filters, kernel) per upsampling layer; stride equals kernel.
    upsampling_specs = [(32, (2, 4)), (32, (5, 8)), (32, (5, 8))]

    latent_inputs = tf.keras.Input(shape=(latent_dim,))
    hidden = tf.keras.layers.Dense(math.prod(input_shape), activation='relu')(latent_inputs)
    hidden = tf.keras.layers.Reshape(input_shape)(hidden)
    for filters, kernel in upsampling_specs:
        hidden = tf.keras.layers.Conv2DTranspose(filters, kernel, activation='relu', strides=kernel, padding='same')(hidden)
    # NOTE: out_filter is used both as the channel count and as the kernel size of the last layer.
    outputs = tf.keras.layers.Conv2DTranspose(out_filter, out_filter, activation='sigmoid', padding='same')(hidden)
    return tf.keras.Model(latent_inputs, outputs, name='decoder')

In [6]:
class VAE(tf.keras.Model):
    """Variational auto-encoder for CSI windows, trained through a custom ``train_step``.

    The training loss is a binary cross-entropy reconstruction term (summed over axes 1 and 2,
    then averaged over the batch) plus the KL divergence of the approximate posterior from
    N(0, I) (summed over latent dimensions, averaged over the batch).

    Args:
        enc_input_shape: shape of one encoder input; its last entry is used as the number of
            decoder output channels.
        dec_input_shape: shape the decoder reshapes its first dense output into before the
            transposed convolutions.
        latent_dim: size of the latent space.
    """

    def __init__(self, enc_input_shape=(450, 2048, 1), dec_input_shape=(9, 8, 32), latent_dim=2, **kwargs):
        super().__init__(**kwargs)
        self.encoder = create_csi_encoder(enc_input_shape, latent_dim)
        self.decoder = create_csi_decoder(dec_input_shape, latent_dim, enc_input_shape[-1])
        self.total_loss_tracker = tf.keras.metrics.Mean(name='total_loss')
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name='reconstruction_loss')
        self.kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')

        self.encoder.summary()
        self.decoder.summary()

    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset them between epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step and return the running mean of each loss term."""
        # The data generator yields (inputs, labels); the labels are not used by the VAE.
        inputs = data[0] if isinstance(data, (tuple, list)) else data
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(inputs)
            reconstruction = self.decoder(z)

            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(inputs, reconstruction), axis=(1, 2)
                )
            )
            # Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)).
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            'loss': self.total_loss_tracker.result(),
            'reconstruction_loss': self.reconstruction_loss_tracker.result(),
            'kl_loss': self.kl_loss_tracker.result(),
        }

    def call(self, inputs, training=None, mask=None):
        """Encode ``inputs``, sample a latent vector and return the decoder's reconstruction."""
        _, _, z = self.encoder(inputs, training=training)
        return self.decoder(z, training=training)


In [7]:

# Callbacks used when training from scratch: per-epoch weight checkpoints, early stopping when
# the training loss stops improving for 3 epochs, and a CSV log of per-epoch metrics.
checkpoint_path = f'./{folder_name}/cp-{{epoch:04d}}.ckpt'  # {epoch} is filled in by Keras
checkpoint_dir = os.path.dirname(checkpoint_path)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    verbose=1,
    save_weights_only=True,
)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
csv_logger_cb = tf.keras.callbacks.CSVLogger(
    f'./{folder_name}/model_history_log.csv',
    append=True,
)

Training the VAEs from scratch can take a very long time, so it is advisable to load the pre-trained models instead (set `load_pretrained_models = False` below to train from scratch).

In [11]:
load_pretrained_models = True

if load_pretrained_models:
#     !wget https://zenodo.org/record/8239343/files/VAE_models_12activities.zip
    !unzip -o VAE_models_12activities.zip
    !rm VAE_models_12activities.zip
else:
    # Train from scratch
    !mkdir {folder_name}
    vae = VAE()
    vae.compile(optimizer=tf.keras.optimizers.Adam())
    vae.save_weights(checkpoint_path.format(epoch=0))
    vae.fit(csi_generator, epochs=20, shuffle=True,
            callbacks=[checkpoint_cb, early_stopping_cb, csv_logger_cb])
    vae.save_weights(f'./{folder_name}/weights_vae')

Archive:  VAE_models_12activities.zip
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0001.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0019.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0005.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0005.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0014.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0009.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0008.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0011.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/checkpoint  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0013.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0017.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls3/cp-0020.ckpt.data-00000-of-00001  
  inflatin

  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0017.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0020.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/weights_vae.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0000.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0016.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0007.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0012.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0016.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0004.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0000.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0013.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls5/cp-0002.ckpt.index  
  inflating: VAE_models_12a

  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0018.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0010.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0014.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0001.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0018.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/weights_vae.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0004.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0002.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0015.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0006.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/cp-0009.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/model_history_log.csv  
  inflating: VAE_models_12activities/vae_s1a_a3_ls2/

  inflating: VAE_models_12activities/vae_s1a_a0_ls4/cp-0003.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls4/cp-0006.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a0_ls4/cp-0019.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls4/cp-0015.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls4/cp-0011.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls4/cp-0017.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a0_ls4/cp-0020.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a0_ls3/cp-0001.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls3/cp-0019.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a0_ls3/cp-0005.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls3/cp-0005.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a0_ls3/cp-0014.ckpt.index  
  inflating: VAE_models_12activities/vae

  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0020.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/weights_vae.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0000.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0016.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0007.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0012.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0016.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0004.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0000.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0013.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0002.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a2_ls4/cp-0008.ckpt.data-00000-of-00001  
  inflating: VAE_models_12a

  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0018.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0010.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0014.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0001.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0018.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/weights_vae.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0004.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0002.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0015.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0006.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/cp-0009.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/model_history_log.csv  
  inflating: VAE_models_12activities/vae_s1a_a1_ls2/

  inflating: VAE_models_12activities/vae_s1a_a0_ls5/cp-0019.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls5/cp-0015.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls5/cp-0011.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a0_ls5/cp-0017.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a0_ls5/cp-0020.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls4/cp-0001.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls4/cp-0019.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls4/cp-0005.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls4/cp-0005.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls4/cp-0014.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_a3_ls4/cp-0009.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_a3_ls4/cp-0008.ckpt.index  
  inflating: VAE_models_12activities/vae

  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0017.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0020.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/weights_vae.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0000.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0016.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0007.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0012.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0016.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0004.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0000.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0013.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls4/cp-0002.ckpt.index  
  inflating: VAE_models_12activities/va

  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0014.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0001.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0018.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/weights_vae.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0004.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0002.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0015.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0006.ckpt.data-00000-of-00001  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0009.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/model_history_log.csv  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0012.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0003.ckpt.index  
  inflating: VAE_models_12activities/vae_s1a_f_ls2/cp-0007.ckpt.data-00000-of-00001  
  infl

## Use the VAE to process CSI data

In [14]:
# Rebuild the VAE with the configured input shape and latent size, then load the trained weights;
# expect_partial() silences warnings about checkpoint values that are not restored.
vae = VAE(enc_input_shape=(450, 2048, ANTENNAS), latent_dim=latent_dim)
vae.compile(optimizer=tf.keras.optimizers.Adam())
vae.load_weights(f'./{folder_name}/weights_vae').expect_partial()

# Encode every batch; each latent row is [z_mean, z_log_var], i.e. 2 * latent_dim values.
# The empty seeds keep the output shapes valid even if the generator yields no batches.
z_parts = [np.zeros([0, 2 * latent_dim])]
label_parts = [np.zeros([0])]
for data, labels in csi_generator:
    z_mean, z_log_var, _ = vae.encoder.predict(data, verbose=0)
    z_parts.append(np.concatenate([z_mean, z_log_var], axis=1))
    # Flatten with reshape (not squeeze) so a final batch of size 1 stays 1-D.
    label_parts.append(np.reshape(np.asarray(labels), [-1]))

# Concatenate once instead of growing the arrays inside the loop.
z_data = np.concatenate(z_parts, axis=0)
z_labels = np.concatenate(label_parts, axis=0)

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 450, 2048,   0           []                               
                                1)]                                                               
                                                                                                  
 conv2d_3 (Conv2D)              (None, 90, 256, 32)  1312        ['input_3[0][0]']                
                                                                                                  
 conv2d_4 (Conv2D)              (None, 18, 32, 32)   40992       ['conv2d_3[0][0]']               
                                                                                                  
 conv2d_5 (Conv2D)              (None, 9, 8, 32)     8224        ['conv2d_4[0][0]']         

2023-08-11 21:32:26.013852: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [19]:
# Store the latent space representation of CSI data to file.
# exist_ok avoids the "File exists" error when the cell is re-run.
os.makedirs('latent_space_dataset_12activities', exist_ok=True)
with open(f'./latent_space_dataset_12activities/{vae_name}.pkl', 'wb') as f:
    pickle.dump([z_data, z_labels], f)

mkdir: latent_space_dataset_12activities: File exists
