# Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
import tensorflow as tf
import tensorflow.keras.backend as K
layers = keras.layers
tf.compat.v1.disable_eager_execution()  # gp loss won't work with eager
from functools import partial
from NuRadioReco.utilities import fft
from NuRadioReco.utilities import units
from NuRadioReco.framework import base_trace
import sys

In [None]:
sys.path.insert(1, '/lustre/fs22/group/radio/dhjelm/')
import data_preprocessing

# Data

## Load data from file and shuffle 

In [None]:
# Load the stored data set from disk.
dataset = np.load('/lustre/fs22/group/radio/dhjelm/data.npy')
# Shuffle in place along the first axis (presumably one trace per row - TODO confirm).
np.random.shuffle(dataset)

## Preprocess data

In [None]:
# RMS preprocessing (project helper from data_preprocessing)
rms = data_preprocessing.rms_preprocessing(dataset)

# L1 preprocessing (project helper from data_preprocessing)
l1 = data_preprocessing.l1_preprocessing(rms)

# Remove DC-offset.
# The DC component is frequency bin 0 of the spectrum. It has to be set to
# zero rather than sliced away: dropping the bin shifts every remaining
# component down by one bin, so the inverse transform would treat the first
# non-zero frequency as the new DC term and return traces with a shifted
# spectrum instead of offset-free traces.
fft_traces = fft.time2freq(l1, 3.2*units.GHz)
no_offset = np.array(fft_traces)  # copy, so fft_traces keeps the raw spectra
no_offset[:, 0] = 0
data_no_offset = fft.freq2time(no_offset, 3.2*units.GHz)

# Shorten the trace to its first 256 samples
short = data_no_offset[:,0:256]

# Normalization (project helper from data_preprocessing)
normalize = data_preprocessing.normalize(short)

# Set data to the normalized data
data = normalize


In [None]:
# Calculate the length of the trace (number of samples in the first trace).
# It is used below as the generator output size and the critic input size.
trace_length = len(data[0])

## Train on FFT data

In [None]:
# data = abs(fft.time2freq(data, 3.2*units.GHz))
# data = data[:,1:len(data[0])]
# trace_length = len(data[0])

## Plot a couple of traces to check that everything looks okay

In [None]:
# Time domain: overlay the first five preprocessed traces
for row in range(5):
    plt.plot(data[row])
plt.title("Time domain")
plt.show()

# Frequency domain: magnitude spectra of the same five traces
for row in range(5):
    plt.plot(np.abs(fft.time2freq(data[row], 3.2 * units.GHz)))
plt.title("Frequency domain")
plt.show()


# GAN implementation

## Generator

In [None]:
# Generator architecture
def generator_model(latent_size, number_of_outputs):
    """Build the generator network.

    The network is a stack of three Dense -> LeakyReLU -> Dropout(0.1)
    stages whose widths are 1x, 2x and 4x ``latent_size``, followed by a
    linear Dense layer with ``number_of_outputs`` units.

    Parameters
    ----------
    latent_size : int
        Size of the input vector; also the base width of the hidden layers.
    number_of_outputs : int
        Number of units in the final (output) layer.

    Returns
    -------
    keras.models.Sequential
        The uncompiled model, named "Generator".
    """
    model = keras.models.Sequential(name="Generator")

    for depth, width_factor in enumerate((1, 2, 4)):
        # Only the very first layer declares the input dimension
        first_layer_kwargs = {"input_dim": latent_size} if depth == 0 else {}
        model.add(layers.Dense(latent_size * width_factor,
                               kernel_initializer='he_uniform',
                               **first_layer_kwargs))
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.1))

    # Output layer without activation
    model.add(layers.Dense(number_of_outputs))
    return model

In [None]:
# Create the generator
latent_size = 128  # size of the input (noise) vector fed to the generator
g = generator_model(latent_size, trace_length)  # one output unit per trace sample
g.summary()

## Critic (Discriminator)

In [None]:
# Critic architecture
def critic_model(trace_length):
    """Build the critic network.

    The network is a stack of four Dense -> LeakyReLU -> Dropout stages
    with widths 128, 64, 32, 16 and dropout rates 0.2, 0.15, 0.1, 0.05,
    followed by a single linear output unit.

    Parameters
    ----------
    trace_length : int
        Number of values in one input trace.

    Returns
    -------
    keras.models.Sequential
        The uncompiled model, named "Critic".
    """
    # (number of units, dropout rate) for each hidden stage
    hidden_spec = ((128, 0.2), (64, 0.15), (32, 0.1), (16, 0.05))

    model = keras.models.Sequential(name="Critic")
    for position, (width, drop_rate) in enumerate(hidden_spec):
        # Only the very first layer declares the input dimension
        first_layer_kwargs = {"input_dim": trace_length} if position == 0 else {}
        model.add(layers.Dense(width, kernel_initializer='he_uniform',
                               **first_layer_kwargs))
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(drop_rate))

    model.add(layers.Dense(1))  # No activation!
    return model

In [None]:
# Create the critic; its input size is the number of samples per trace
critic = critic_model(trace_length)
critic.summary()

## Training pipelines

Pipelines that connect the two networks so that the GAN can be constructed and trained

In [None]:
def make_trainable(model, trainable):
    """Freeze or unfreeze the weights of every layer in ``model``.

    Batch-normalization layers (including subclasses of
    ``layers.BatchNormalization``) are the exception: they are always left
    trainable, whatever ``trainable`` is.

    Parameters
    ----------
    model : keras model
        Model whose ``layers`` are updated in place.
    trainable : bool
        ``True`` to unfreeze the layers, ``False`` to freeze them.
    """
    for layer in model.layers:
        # isinstance (not an exact type comparison) so that subclasses of
        # BatchNormalization are treated like the base class
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = True
        else:
            layer.trainable = trainable

In [None]:
# Set up the generator-training phase: the critic weights are frozen and the
# generator weights are trainable.
make_trainable(critic, False) 
make_trainable(g, True)

In [None]:
# Stack the generator on top of the critic and finalize the training pipeline of the generator:
# generator input -> generator -> critic -> critic score
gen_input = g.inputs
generator_training = keras.models.Model(gen_input, critic(g(gen_input)))
generator_training.summary()
keras.utils.plot_model(generator_training, show_shapes=True)

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Return the Wasserstein loss: the mean of label times critic output.

    The critic maximises the distance between its output for real and for
    generated samples. To achieve this, generated samples carry the label -1
    and real samples the label 1; multiplying the outputs by the labels
    yields the Wasserstein loss via the Kantorovich-Rubinstein duality.

    Parameters
    ----------
    y_true : tensor
        Labels (+1 or -1) that set the sign of each critic output.
    y_pred : tensor
        Critic outputs.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

# Compile generator for training using the Wasserstein loss as loss function
# (Adam with learning rate 1e-4, beta_1=0.5, beta_2=0.9 and no learning-rate decay)
generator_training.compile(keras.optimizers.Adam(
    0.0001, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[wasserstein_loss])


## Gradient penalty

To obtain the Wasserstein distance, we have to use the gradient penalty to enforce the Lipschitz constraint.
Therefore, we need to design a layer that samples on straight lines between real and fake samples.

In [None]:
# Size of the batches used in training
BATCH_SIZE = 64


class UniformLineSampler(tf.keras.layers.Layer):
    """Layer that returns random points on the lines between two batches.

    For every sample a weight ``w`` is drawn with ``K.random_uniform`` and
    the layer returns ``w * inputs[0] + (1 - w) * inputs[1]``.
    """

    def __init__(self, batch_size):
        """Store the fixed number of samples per batch."""
        super().__init__()
        self.batch_size = batch_size

    def call(self, inputs, **kwargs):
        """Mix the first two entries of ``inputs`` sample by sample."""
        first, second = inputs[0], inputs[1]
        # One weight per sample; the trailing axis of size 1 broadcasts
        # the weight over the remaining axis of the inputs
        mix = K.random_uniform((self.batch_size, 1))
        return (mix * first) + ((1 - mix) * second)

    def compute_output_shape(self, input_shape):
        """Return the shape of the first input as the output shape."""
        return input_shape[0]

In [None]:
# Set up the critic-training phase
make_trainable(critic, True)  # unfreeze the critic during the critic training
make_trainable(g, False)  # freeze the generator during the critic training

# Critic scores for generated traces and for real traces
g_out = g(g.inputs)
critic_out_fake_samples = critic(g_out)
critic_out_data_samples = critic(critic.inputs)
# Random points on the lines between generated and real traces, and the
# critic score at those points (used by the gradient penalty)
averaged_batch = UniformLineSampler(BATCH_SIZE)([g_out, critic.inputs[0]])
averaged_batch_out = critic(averaged_batch)

# Inputs: [generator input, real traces]
# Outputs: [score of generated traces, score of real traces, score of interpolated traces]
critic_training = keras.models.Model(inputs=[g.inputs, critic.inputs], outputs=[critic_out_fake_samples, critic_out_data_samples, averaged_batch_out])

In [None]:
# Print the layer overview of the critic training model and draw its graph with tensor shapes
critic_training.summary()
keras.utils.plot_model(critic_training, show_shapes=True)

### Gradient penalty loss

In [None]:
def gradient_penalty_loss(y_true, y_pred, averaged_batch, penalty_weight):
    """Return the gradient penalty of improved WGANs.

    The 1-Lipschitz constraint is enforced by adding a term that penalizes
    a gradient norm in the critic unequal to 1.

    Parameters
    ----------
    y_true : tensor
        Unused; present only so the function has the signature of a loss.
    y_pred : tensor
        Critic output for ``averaged_batch``.
    averaged_batch : tensor
        Critic input with respect to which the gradient is taken.
    penalty_weight : float
        Factor applied to the squared deviation of the gradient norm from 1.
    """
    # K.gradients returns a list; take the gradient w.r.t. averaged_batch
    grad = K.gradients(y_pred, averaged_batch)[0]
    # Per-sample norm of the gradient (sum of squares over axis 1)
    grad_norm = K.sqrt(K.sum(K.square(grad), axis=1))
    return K.mean(penalty_weight * K.square(1 - grad_norm))

# Construct the gradient penalty
gradient_penalty_weight = 5
# Bind the interpolated batch and the weight so that the remaining signature
# is (y_true, y_pred), like the other loss functions.
gradient_penalty = partial(gradient_penalty_loss, averaged_batch=averaged_batch, penalty_weight=gradient_penalty_weight)  
# partial objects carry no __name__ of their own; presumably Keras needs one to label the loss - TODO confirm
gradient_penalty.__name__ = 'gradient_penalty'

In [None]:
# Compile critic: one loss per model output, in output order
# (generated traces, real traces, interpolated traces)
critic_training.compile(keras.optimizers.Adam(0.00005, beta_1=0.5, beta_2=0.9, decay=0.0), loss=[wasserstein_loss, wasserstein_loss, gradient_penalty])

In [None]:
# Target arrays passed to train_on_batch, one value per sample in a batch.
# With wasserstein_loss, +1 / -1 set the sign of the critic output in the loss.
# Keras throws an error when calculating a loss without having a label -> the
# all-zero `dummy` array is needed for using the gradient penalty loss, which
# does not read its y_true argument.
positive_y = np.ones(BATCH_SIZE)
negative_y = -positive_y
dummy = np.zeros(BATCH_SIZE) 


# Training

In [None]:
# Create lists for generator and critic loss; the training loop appends the
# return value of every train_on_batch call to them.
generator_loss = []
critic_loss = []

In [None]:
# Training parameters
EPOCHS = 100
nsamples = len(data)  # number of training traces
critic_iterations = 5  # critic updates per generator update
# Generator updates per epoch. Each one uses critic_iterations batches of
# BATCH_SIZE real traces, so one epoch draws about 4 * nsamples real traces.
iterations_per_epoch = nsamples*4//(BATCH_SIZE*critic_iterations)
iters = 0  # running count of generator updates across all epochs
print(iterations_per_epoch)

In [None]:

# Number of complete, non-overlapping batches available in the data set.
# Only full batches are used because the critic training model (the
# UniformLineSampler layer and the label arrays) is built for exactly
# BATCH_SIZE samples.
batches_per_pass = nsamples // BATCH_SIZE
if batches_per_pass == 0:
    raise ValueError("The data set must contain at least BATCH_SIZE traces")

for epoch in range(EPOCHS):

    print("Epoch: ", epoch)

    for iteration in range(iterations_per_epoch):

        for j in range(critic_iterations):

            # Pick the next batch of real traces and generate noise.
            # Every critic update gets its own batch: the running batch
            # number advances by one per update and wraps around at the end
            # of the data. Indexing with (j + iteration) instead would make
            # consecutive generator iterations reuse 4 of their 5 batches,
            # never reach the last part of the data, and could slice past
            # the end of the array and return a short batch.
            batch_index = (iteration * critic_iterations + j) % batches_per_pass
            bunch = data[BATCH_SIZE * batch_index:BATCH_SIZE * (batch_index + 1)]
            noise_batch = np.random.randn(BATCH_SIZE, latent_size)

            # Train critic: generated traces get label -1, real traces +1,
            # the gradient-penalty output gets the dummy label
            critic_loss.append(critic_training.train_on_batch([noise_batch, bunch], [negative_y, positive_y, dummy]))

        # Generate noise batch for generator
        noise_batch = np.random.randn(BATCH_SIZE, latent_size)

        # Train the generator
        generator_loss.append(generator_training.train_on_batch([noise_batch], [positive_y]))
        iters += 1

        # Printing errors and plotting example traces every 300 generator updates
        if iters % 300 == 1:
            print("Iteration", iters)
            print("Critic loss:", critic_loss[-1])
            print("Generator loss:", generator_loss[-1])

            # Generate signals
            generated_signals = g.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size))
            print(np.shape(generated_signals[0]))
            print(len(abs(fft.time2freq(generated_signals[0], 3.2*units.GHz))))

            # Plot data: time domain on the left, frequency domain on the right
            fig, (ax1, ax2) = plt.subplots(1, 2)
            fig.set_size_inches(18.5, 10.5, forward=True)
            ax1.title.set_text('Time domain')
            ax2.title.set_text('Frequency domain')

            # Plot time traces: one real trace and three generated ones
            ax1.plot(data[0], label = "Example trace")
            for i in range(3):
                ax1.plot(generated_signals[i], alpha=0.5)

            # Plot frequency spectra (magnitude) of the same traces
            ax2.plot(abs(fft.time2freq(data[0], 3.2*units.GHz)), label = "Example trace")
            for i in range(3):
                ax2.plot(abs(fft.time2freq(generated_signals[i], 3.2*units.GHz)),alpha=0.5)

            ax1.legend()
            ax2.legend()
            plt.show()



        
       

        #generated_signal = g.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size))
    
    #print("Critic loss:", critic_loss[-1])
    #print("Generator loss:", generator_loss[-1])


## Loss functions

In [None]:
# Turn the list of per-update critic losses into an array (one row per critic update).
# NOTE(review): after this line critic_loss is no longer a list, so the training
# cell cannot append to it again without re-creating the list first.
critic_loss = np.array(critic_loss)
plt.subplots(1, figsize=(10, 5))
# Column 0 is drawn as the total loss, columns 1 and 2 (summed) as the
# Wasserstein part and column 3 as the gradient penalty.
plt.plot(np.arange(len(critic_loss)), critic_loss[:,0], color='red', markersize=12, label=r'Total')
plt.plot(np.arange(len(critic_loss)), critic_loss[:,1] + critic_loss[:, 2], color='green', label=r'Wasserstein', linestyle='dashed')
plt.plot(np.arange(len(critic_loss)), critic_loss[:, 3], color='royalblue', markersize=12, label=r'Gradient penalty', linestyle='dashed')
plt.legend(loc='upper right')
plt.xlabel(r'Iterations')
plt.ylabel(r'Critic Loss')
#plt.ylim(-6, 3)

# Generator losses: one value per generator update
generator_loss = np.array(generator_loss)

plt.subplots(1, figsize=(10, 5))
plt.plot(np.arange(len(generator_loss)), generator_loss, color='red', markersize=12, label=r'Total')
plt.legend(loc='upper right')
plt.xlabel(r'Iterations')
plt.ylabel(r'Loss')


# #save the generated networks
# g.save('generator')
# critic.save('critic')

In [None]:
# Draw BATCH_SIZE random input vectors and let the generator turn them into traces
generated_signals = g.predict_on_batch(np.random.randn(BATCH_SIZE, latent_size))

# Measures

In [None]:
# Mean over all values of the generated traces and of the data
print(f"Mean generated: {np.mean(generated_signals)}")
print(f"Mean data: {np.mean(data)}\n")

# Standard deviation over all values of the generated traces and of the data
print(f"Std generated: {np.std(generated_signals)}")
print(f"Std data: {np.std(data)}\n")





# Plotting

In [None]:
plt.rcParams["figure.dpi"] = 100

# First data trace for reference, overlaid with the first two generated traces
plt.plot(data[0], label="Data")
for idx in (0, 1):
    plt.plot(generated_signals[idx], alpha=0.5)
# plt.plot(np.random.randn(trace_length), label = 'Random noise')
plt.legend()


In [None]:
plt.rcParams["figure.dpi"] = 100

# First data trace for reference, overlaid with every fifth generated trace
# (indices 0, 5, ..., 45), drawn faintly
plt.plot(data[0], label="Example data")
for idx in range(0, 50, 5):
    plt.plot(generated_signals[idx], alpha=0.2)
# plt.plot(np.random.randn(trace_length), label = 'Random noise')
plt.legend()

In [None]:
# Magnitude spectrum of one data trace next to that of one generated trace
plt.plot(np.abs(fft.time2freq(data[1], 3.2 * units.GHz)), label="Data")
plt.plot(np.abs(fft.time2freq(generated_signals[0], 3.2 * units.GHz)), label="Generated")
plt.legend()

In [None]:


# Spectra of all data traces
data_freq = fft.time2freq(data, 3.2 * units.GHz)

print(nsamples)
# Spectra of an equally large set of freshly generated traces
generated_signals = g.predict_on_batch(np.random.randn(nsamples, latent_size))
generator_freq = fft.time2freq(generated_signals, 3.2 * units.GHz)

# Magnitude per frequency bin, averaged over all traces
avg_freq_data = np.mean(np.abs(data_freq), axis=0)
avg_freq_generator = np.mean(np.abs(generator_freq), axis=0)

# An all-zero trace with the right length and sampling rate, used only to
# obtain the frequency axis
dummy_trace = base_trace.BaseTrace()
dummy_trace.set_trace(np.zeros(trace_length), sampling_rate=3.2 * units.GHz)

plt.plot(dummy_trace.get_frequencies() / units.MHz, avg_freq_data, label="Data")
plt.plot(dummy_trace.get_frequencies() / units.MHz, avg_freq_generator, label="Generator")
plt.xlabel("Frequency [MHz]")
# plt.ylabel("Square root of power per MHZ")
plt.title("Average frequencies for data and the generated data")
plt.legend()

#plt.semilogy()




In [None]:
# Generate as many traces as there are training traces and print the shape
# of the generated array next to that of the raw (unprocessed) data set
generated_signals = g.predict_on_batch(np.random.randn(nsamples, latent_size))
print(np.shape(generated_signals))
print(np.shape(dataset))