In [None]:
import tensorflow as tf

import numpy as np
from tensorflow.keras import layers
import tensorflow.keras.backend as K

import ssl
from tensorflow.keras import datasets, layers, models

In [None]:
# Epoch budget constant.
# NOTE(review): not referenced by the cells shown here; train_first_stage passes its own epoch counts to fit() -- confirm whether this is still needed.
NUM_EPOCH = 200

# Channel settings used when the Channel layer is built for training.
channel_snr = 15 # signal-to-noise ratio in dB
channel_type = 'awgn' # "fading" or "awgn"
power_normalization = 'average' # "average" or "max" power constraint

# ##### KSG MI estimator #######################
# settings = {'kraskov_k': 3}
# gpu_est = est.OpenCLKraskovMI(settings = settings)
# # usage: MI = estimator.estimate(var1, var2)

##### setting for communication channel ######

def getnoisevariance(SNRdB,P=1):
    """Convert an SNR in dB into a noise variance.

    The SNR is first converted to linear scale, the noise power is taken as
    ``P / snr`` and half of that value (N0/2) is returned.

    Args:
        SNRdB: signal-to-noise ratio in decibels.
        P: signal power (defaults to 1).
    Returns:
        Half of the noise power, ``(P / 10**(SNRdB/10)) / 2``.
    """
    linear_snr = 10.0 ** (SNRdB / 10.0)
    return (P / linear_snr) / 2

# Noise variance (the N0/2 value returned by getnoisevariance) at the training SNR, for signal power P=1.
noise_var = getnoisevariance(channel_snr,P=1)


In [None]:
# Load MNIST; y_* keep the integer digit labels that are used later for filtering.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Rescale pixel values by 1/255, then flatten every image into one row vector.
x_train = (x_train / 255.0).reshape(len(x_train), -1)
x_test = (x_test / 255.0).reshape(len(x_test), -1)

# One-hot encode the labels over the 10 digit classes.
Y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
Y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)


In [None]:
# Label selection for phase 1:
# only half of the ten digit labels are selected.
##############################################

label_permutation = np.arange(10) # fixed the chosen label
# label_permutation = np.random.permutation(10) # random permutation
LABEL_FIRST_HALF = label_permutation[:5]

# Candidate label sets for phase 2: row k holds label_permutation[k:k+5], so
# consecutive rows are sliding windows of five labels over the permutation.
# The float dtype matches the array previously created with np.zeros.
label2_list = np.array(
    [label_permutation[start:start + 5] for start in range(6)], dtype=float
)

# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0; for the 1-D
# label vectors used here both return the same boolean mask.
train_filter = np.where(np.isin(y_train, LABEL_FIRST_HALF))
test_filter = np.where(np.isin(y_test, LABEL_FIRST_HALF))

# Phase-1 subsets: images and one-hot targets restricted to LABEL_FIRST_HALF.
x_1train = x_train[train_filter]
Y_1train = Y_train[train_filter]
x_1test = x_test[test_filter]
Y_1test = Y_test[test_filter]

In [None]:
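### Unused GaussianNoiseLayer draft (presumably expects the input to be normalized beforehand); kept commented out for reference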
# class GaussianNoiseLayer(layers.Layer):
#     def __init__(self, stddev, **kwargs):
#         super(GaussianNoiseLayer, self).__init__(**kwargs)
#         self.stddev = stddev
    
#     def call(self, inputs, training=None):
#         if training:
#             noise = tf.random.normal(shape=tf.shape(inputs), mean=0.0, stddev=self.stddev)
#             return inputs + noise
#         else:
#             noise = tf.random.normal(shape=tf.shape(inputs), mean=0.0, stddev=self.stddev)
#             return inputs + noise
    
#     def get_config(self):
#         config = super(GaussianNoiseLayer, self).get_config()
#         config.update({'noise_var': self.stddev})
#         return config

In [None]:
## fading channel

def fading(x, stddev, h=None):
    """Apply a fading channel: multiplicative complex gain plus additive
    complex white Gaussian noise.

    Args:
        x: complex channel input symbols.
        stddev: scale factor applied to the unit-variance complex noise.
        h: optional channel gain to reuse. When None, a new gain of shape
            [batch, 1] is sampled, with real and imaginary parts each drawn
            from a normal distribution with standard deviation 1/sqrt(2).
    Returns:
        A tuple ``(y, h)`` with the noisy channel output ``h * x + stddev * awgn``
        and the channel gain that was applied.
    """
    # Standard deviation of each real/imaginary component.
    component_std = 1 / np.sqrt(2)

    # channel gain: one complex coefficient per row of x unless supplied
    if h is None:
        gain_real = tf.random.normal([tf.shape(x)[0], 1], 0, component_std)
        gain_imag = tf.random.normal([tf.shape(x)[0], 1], 0, component_std)
        h = tf.complex(gain_real, gain_imag)

    # additive white gaussian noise, same shape as x
    noise_real = tf.random.normal(tf.shape(x), 0, component_std)
    noise_imag = tf.random.normal(tf.shape(x), 0, component_std)
    awgn = tf.complex(noise_real, noise_imag)

    y = h * x + stddev * awgn
    return y, h

def real_awgn(x, stddev):
    """Add real-valued white Gaussian noise to the channel input.

    Args:
        x: real channel input symbols.
        stddev: standard deviation of the additive noise.
    Returns:
        Tensor with the same shape as ``x`` holding the noisy symbols.
    """
    noise = tf.random.normal(tf.shape(x), mean=0.0, stddev=stddev, dtype=x.dtype)
    return x + noise


class Channel(layers.Layer):
    """Keras layer simulating a noisy communication channel.

    The layer flattens the encoded input, applies a power normalization,
    passes the result through either a real AWGN channel or a complex fading
    channel, and reshapes the output back to the input's shape.

    Args:
        channel_type: "awgn" or "fading".
        channel_snr: channel signal-to-noise ratio in dB; the noise standard
            deviation is ``sqrt(10 ** (-channel_snr / 10))``.
        power_normalization: "average" or "max"; only consulted on the
            "awgn" path.
        name: layer name.
    """

    def __init__(self, channel_type, channel_snr, power_normalization="average", name="channel", **kwargs):
        super(Channel, self).__init__(name=name, **kwargs)
        self.channel_type = channel_type
        self.channel_snr = channel_snr
        self.power_normalization = power_normalization

    def call(self, inputs):
        """Send the encoded tensor through the channel.

        Args:
            inputs: tuple ``(encoded_img, prev_h)``. ``prev_h`` is a channel
                gain to reuse on the fading path, or None to sample a new one.
        Returns:
            Tuple ``(z_out, h)``: the noisy signal reshaped to the shape of
            ``encoded_img`` and the channel gain (all ones for "awgn").
        Raises:
            ValueError: if ``channel_type`` is unknown, or if
                ``power_normalization`` is unknown on the "awgn" path.
        """
        (encoded_img, prev_h) = inputs
        inter_shape = tf.shape(encoded_img)
        # reshape array to [-1, dim_z]
        z = layers.Flatten()(encoded_img)
        # convert from snr (dB) to noise standard deviation
        noise_stddev = np.sqrt(10 ** (-self.channel_snr / 10))

        # Add channel noise
        if self.channel_type == "awgn":
            dim_z = tf.shape(z)[1]
            if self.power_normalization == "max":
                # bound every symbol to (-1, 1)
                z_in = tf.tanh(z)
            elif self.power_normalization == "average":
                # normalize latent vector so that the average power is 1
                z_in = tf.sqrt(tf.cast(dim_z, dtype=tf.float32)) * tf.nn.l2_normalize(
                    z, axis=1)
            else:
                # Previously fell through and failed later with an unbound z_in.
                raise ValueError(
                    "Unknown power_normalization {!r}; expected 'average' or 'max'".format(
                        self.power_normalization))
            z_out = real_awgn(z_in, noise_stddev)
            h = tf.ones_like(z_in)  # h just makes sense on fading channels

        elif self.channel_type == "fading":
            dim_z = tf.shape(z)[1] // 2
            # convert z to complex representation: first half real, second half imaginary
            z_in = tf.complex(z[:, :dim_z], z[:, dim_z:])
            # normalize the latent vector so that the average power is 1
            z_norm = tf.reduce_sum(
                tf.math.real(z_in * tf.math.conj(z_in)), axis=1, keepdims=True
            )
            z_in = z_in * tf.complex(
                tf.sqrt(tf.cast(dim_z, dtype=tf.float32) / z_norm), 0.0
            )
            z_out, h = fading(z_in, noise_stddev, prev_h)
            # convert back to real
            z_out = tf.concat([tf.math.real(z_out), tf.math.imag(z_out)], 1)

        else:
            # Previously fell through and failed later with an unbound z_out.
            raise ValueError(
                "Unknown channel_type {!r}; expected 'awgn' or 'fading'".format(
                    self.channel_type))

        # convert signal back to intermediate shape
        z_out = tf.reshape(z_out, inter_shape)

        return z_out, h

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(Channel, self).get_config()
        config.update({
            "channel_type": self.channel_type,
            "channel_snr": self.channel_snr,
            "power_normalization": self.power_normalization,
        })
        return config


In [None]:
### custom callback function for MI calculation
import tensorflow.keras.backend as K

def normalize_data_new(data,C):
    """Center the samples and rescale them to a fixed average norm.

    Each sample (first axis) is flattened to a row, the per-feature mean over
    samples is removed, and the result is scaled so that the root of the mean
    squared row norm equals ``C``.

    Args:
        data: array whose first axis indexes samples.
        C: target value for the root-mean-square row norm.
    Returns:
        Tuple ``(normalized_data, power)`` where ``normalized_data`` has shape
        (num_samples, features) and ``power`` is the mean squared row norm of
        the centered data before scaling.
    """
    num_samples = data.shape[0]
    flat = np.reshape(data, (num_samples, data.size // num_samples))

    # Broadcasting the column means is equivalent to tiling them per row.
    centered = flat - flat.mean(axis=0, keepdims=True)

    row_energy = np.square(centered).sum(axis=1)
    norm = np.sqrt(row_energy.mean())

    return C * centered / norm, norm ** 2
class getMIOutput(tf.keras.callbacks.Callback):
    """Keras callback that snapshots per-layer activations during training.

    When training starts, one backend function per layer of
    ``embedding_network`` is built. At the end of every epoch accepted by
    ``do_save_func``, the first ``num_selection`` samples of ``trn`` are passed
    through each function and the outputs are appended to ``self.datalist``.

    Args:
        trn: input samples fed to the layer functions.
        tst: stored on the instance; not read by this class.
        snr: stored on the instance; not read by this class.
        num_selection: number of leading samples of ``trn`` to evaluate.
        embedding_network: model whose layers are probed.
        do_save_func: optional predicate ``epoch -> bool``; when None every
            epoch is recorded.
    """

    def __init__(self, trn, tst, snr, num_selection, embedding_network, do_save_func=None, *kargs, **kwargs):
        super(getMIOutput, self).__init__(*kargs, **kwargs)
        self.layer_values = []
        self.trn = trn
        self.tst = tst
        self.snr = snr
        self.datalist = []  # one {'epoch', 'data'} record per saved epoch
        self.num_selection = num_selection
        self.embedding_network = embedding_network
        self.do_save_func = do_save_func

    def on_train_begin(self, logs=None):
        """Build one activation-extraction function per layer of the probed model."""
        self.layer_values = []
        self.layerixs = []
        self.layerfuncs = []

        for lndx, l in enumerate(self.embedding_network.layers):
            self.layerixs.append(lndx)
            self.layer_values.append(lndx)
            self.layerfuncs.append(K.function(self.embedding_network.inputs, [l.output,]))

    def on_epoch_end(self, epoch, logs=None):
        """Record the activations of every probed layer for this epoch."""
        if self.do_save_func is not None and not self.do_save_func(epoch):
            return

        # Same sample batch for every layer function.
        samples = self.trn[:self.num_selection]

        data = {
            # NOTE: the key says "tst", but the activations are computed on
            # self.trn; the key is kept unchanged for downstream consumers.
            'activity_tst': [
                self.layerfuncs[lndx]([samples,])[0]
                for lndx, _ in enumerate(self.layerixs)
            ]
        }

        save_dic = {'epoch': epoch, 'data': data}
        self.datalist.append(save_dic)



In [None]:

def do_report(epoch):
    """Decide whether activations should be saved at ``epoch``.

    The reporting interval widens as training progresses: every 5th epoch
    below 100, every 10th epoch below 200, and every 100th epoch afterwards.

    Args:
        epoch: zero-based epoch index.
    Returns:
        True when ``epoch`` falls on the reporting grid, False otherwise.
    """
    if epoch < 100:
        interval = 5
    elif epoch < 200:
        interval = 10
    else:
        interval = 100
    return epoch % interval == 0
    
    
def train_first_stage(lambda_val, label_second_half=None, stage1_epochs=150, stage2_epochs=50):
    """Run the two-phase training experiment and report both accuracies.

    Phase 1 trains an encoder -> channel -> two-head decoder (classification
    and reconstruction) on the phase-1 label subset (x_1train / Y_1train).
    Phase 2 freezes the encoder, attaches a fresh classification head and
    trains it on the samples whose labels are in ``label_second_half``; the
    reconstruction loss has weight 0 in this phase.

    Args:
        lambda_val: weight of the reconstruction (MSE) loss in phase 1.
        label_second_half: labels used in phase 2. When None, the module-level
            ``LABEL_SECOND_HALF`` is used (the original behavior).
        stage1_epochs: number of phase-1 epochs.
        stage2_epochs: number of phase-2 epochs.
    Returns:
        Tuple ``(acc1, acc2)``: last-epoch validation classification accuracy
        of phase 1 and of phase 2.
    """
    if label_second_half is None:
        # Backward-compatible fallback to the global set by the calling cell.
        label_second_half = LABEL_SECOND_HALF

    def make_optimizer(initial_lr):
        # SGD with the exponential learning-rate decay used in both phases.
        schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=initial_lr,
            decay_steps=10000,
            decay_rate=0.8)
        return tf.keras.optimizers.SGD(learning_rate=schedule)

    tf.keras.backend.clear_session()
    tf.random.set_seed(42)
    prev_chn_gain = None  # no previous fading gain to reuse

    ############ Model Structure ###############################

    input_layer = tf.keras.layers.Input((x_train.shape[1],))
    # Keep references to the encoder layers so they can be frozen in phase 2
    # without relying on their position in model.layers.
    encoder_dense_1 = tf.keras.layers.Dense(128, activation='relu')
    encoder_dense_2 = tf.keras.layers.Dense(40, activation='relu')
    encoder_1 = encoder_dense_1(input_layer)
    encoder_2 = encoder_dense_2(encoder_1)

    normalized_x = tf.keras.layers.Activation('tanh')(encoder_2)
    # Pass the configured power normalization explicitly; it was previously
    # left at the layer default regardless of the notebook setting.
    fadingchannel = Channel(channel_type, channel_snr,
                            power_normalization=power_normalization,
                            name="channel_output")
    noise_layer, chn_gain = fadingchannel((normalized_x, prev_chn_gain))

    ## first loss: classification
    CE_decoder_1 = tf.keras.layers.Dense(40, activation='relu')(noise_layer)
    CE_decoder_2 = tf.keras.layers.Dense(40, activation='relu')(CE_decoder_1)
    CE_output = tf.keras.layers.Dense(10, activation='softmax', name='CE')(CE_decoder_2)

    ## second loss: reconstruction
    mse_decoder_1 = tf.keras.layers.Dense(256, activation='relu')(noise_layer)
    mse_output = tf.keras.layers.Dense(x_train.shape[1], activation='sigmoid', name='mse')(mse_decoder_1)

    model = tf.keras.Model(inputs=input_layer, outputs=[CE_output, mse_output])
    model.compile(optimizer=make_optimizer(5e-2),
                  loss={'CE': 'categorical_crossentropy',
                        'mse': 'mse'},
                  metrics={'CE': 'accuracy',
                           'mse': tf.keras.metrics.RootMeanSquaredError()},
                  # Keyed by output name so the weights cannot be mismatched
                  # with the dict of losses.
                  loss_weights={'CE': 1, 'mse': lambda_val})

    history = model.fit(x=x_1train, y=(Y_1train, x_1train),
                        batch_size=128,
                        epochs=stage1_epochs,
                        verbose=0,
                        validation_data=(x_1test, (Y_1test, x_1test)))
    print("acc:", history.history['val_CE_accuracy'][-1])

    ################### second stage #################
    print("stage 2!")

    # np.isin replaces the deprecated np.in1d; identical for 1-D label arrays.
    train_filter = np.where(np.isin(y_train, label_second_half))
    test_filter = np.where(np.isin(y_test, label_second_half))

    x_retrain = x_train[train_filter]
    x_retest = x_test[test_filter]
    Y_retrain = Y_train[train_filter]
    Y_retest = Y_test[test_filter]

    # Fresh classification head on top of the (shared) channel output.
    CE_dense_3 = layers.Dense(40, activation='relu')(noise_layer)
    CE_dense_3 = layers.Dense(40, activation='relu')(CE_dense_3)
    CE_output1 = tf.keras.layers.Dense(10, activation='softmax', name='CE')(CE_dense_3)

    reconstructed_model = tf.keras.models.Model(inputs=input_layer, outputs=[CE_output1, mse_output])

    # Freeze the encoder learned in phase 1.
    encoder_dense_1.trainable = False
    encoder_dense_2.trainable = False

    reconstructed_model.compile(optimizer=make_optimizer(2e-2),
                                loss={'CE': 'categorical_crossentropy',
                                      'mse': 'mse'},
                                loss_weights={'CE': 1, 'mse': 0},
                                metrics={'CE': 'accuracy',
                                         'mse': tf.keras.metrics.RootMeanSquaredError()})

    history1 = reconstructed_model.fit(x=x_retrain, y=(Y_retrain, x_retrain),
                                       batch_size=256,
                                       epochs=stage2_epochs,
                                       verbose=0,
                                       validation_data=(x_retest, (Y_retest, x_retest)))
    # Kept for its printed evaluation summary; the returned metrics are not used.
    reconstructed_model.evaluate(x=x_retest, y=(Y_retest, x_retest))
    print("test accuracy for phase 2: ", history1.history['val_CE_accuracy'][-1])
    return history.history['val_CE_accuracy'][-1], history1.history['val_CE_accuracy'][-1]
    

In [None]:
# Noise standard deviation corresponding to noise_var.
sd = np.sqrt(noise_var)

# Reconstruction-loss weights to sweep, and the accuracies collected per run
# (ordered by overlap first, then by lambda).
lambda_list = [0, 1, 3, 10]
acc1_list, acc2_list = [], []

for overlap in range(6):
    print("current overlap is: ", overlap)
    # Row 5 - overlap of label2_list is the phase-2 label set for this setting;
    # train_first_stage reads it through this module-level name.
    LABEL_SECOND_HALF = label2_list[5 - overlap]
    for lambda_val in lambda_list:
        print("lambda: ", lambda_val)
        phase1_acc, phase2_acc = train_first_stage(lambda_val)
        acc1_list.append(phase1_acc)  # accuracy for phase 1
        acc2_list.append(phase2_acc)  # accuracy for phase 2
