In [None]:
import numpy as np
np.set_printoptions(precision=16)
import pickle

import tensorflow as tf
from tensorflow.keras.optimizers import Adam

from qkeras.utils import model_save_quantized_weights

from tf_data_pipeline.data import WaveToWaveData
from qkeras_version.qkeras_model import create_dilated_model, masked_mse

from fxpmath_version.fxpmath_model import FxpModel

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

## train qkeras model

In [None]:
from fxpmath_version.util import FxpUtil

# Project-local fixed-point helper; its semantics are not visible in this notebook.
fxp = FxpUtil()
# Sanity check: show -1.0 after single_width() next to the output of bits() on it.
# presumably single_width() yields the fixed-point form and bits() its bit string — TODO confirm
fxp.single_width(-1.0), fxp.bits(fxp.single_width(-1.0))

In [None]:
class Opts:
    """Hyper-parameters and dataset sizes shared by the cells of this notebook."""

    # optimiser step size (handed to Adam when the model is compiled)
    learning_rate = 0.001
    # configured number of training epochs
    epochs = 5

    # number of examples generated per data split
    num_train_egs = 20_000
    num_validate_egs = 100
    num_test_egs = 100


opts = Opts()


In [None]:
# Network shape.
IN_OUT_D = 4
NUM_LAYERS = 3
FILTER_SIZE = 4  # WIP value; the final design will use 8

# Kernel size (and the implied dilation rate) is always assumed to be 4,
# hence the receptive field of 4**NUM_LAYERS samples.
RECEPTIVE_FIELD_SIZE = 4 ** NUM_LAYERS

# Training sequences span five receptive fields; test sequences exactly one.
TRAIN_SEQ_LEN = 5 * RECEPTIVE_FIELD_SIZE
TEST_SEQ_LEN = RECEPTIVE_FIELD_SIZE

for _label, _value in (
    ("RECEPTIVE_FIELD_SIZE", RECEPTIVE_FIELD_SIZE),
    ("TRAIN_SEQ_LEN", TRAIN_SEQ_LEN),
    ("TEST_SEQ_LEN", TEST_SEQ_LEN),
):
    print(_label, _value)

In [None]:
# Build the tf datasets.
#
# Recall the WaveToWaveData layout:
#   x -> (tri, 0, 0, 0)
#   y -> (tri, square, zigzag, 0)
data = WaveToWaveData()

# Both splits use training-length sequences; only the split name and the
# example count differ.
train_ds, validate_ds = [
    data.tf_dataset_for_split(split_name, TRAIN_SEQ_LEN, num_egs)
    for split_name, num_egs in (('train', opts.num_train_egs),
                                ('validate', opts.num_validate_egs))
]

In [None]:
# Build the model used for training; its input length is TRAIN_SEQ_LEN.
train_model = create_dilated_model(TRAIN_SEQ_LEN,
        in_out_d=IN_OUT_D, num_layers=NUM_LAYERS, filter_size=FILTER_SIZE,
        all_outputs=False)
# summary() prints the layer table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the cell output.
train_model.summary()

In [None]:
# The loss only considers column 1 of the output, i.e. the square wave
# (targets are laid out as (tri, square, zigzag, 0)).
train_model.compile(
    optimizer=Adam(opts.learning_rate),
    loss=masked_mse(RECEPTIVE_FIELD_SIZE, filter_column_idx=1),
)

In [None]:
# Take the epoch count from the shared Opts config instead of a literal, so
# the hyper-parameters declared in one place actually control training.
# NOTE: this cell previously hard-coded 20 epochs; raise Opts.epochs if the
# longer schedule is wanted.
train_model.fit(train_ds,
                validation_data=validate_ds,
                epochs=opts.epochs)

In [None]:
# Quantize the trained model's weights once and reuse the result both as the
# in-memory `weights` and as the pickle written to disk.
# (model_save_quantized_weights is already imported in the first cell.)
weights = model_save_quantized_weights(train_model)
with open('qkeras_weights.pkl', 'wb') as f:
    pickle.dump(weights, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the quantized weights from disk. The pickle is the one written by
# this notebook, so unpickling it is trusted.
with open('qkeras_weights.pkl', mode='rb') as weights_file:
    weights = pickle.load(weights_file)

In [None]:
# Display the quantized weights that were reloaded from qkeras_weights.pkl.
weights

## load weights into fxp model

In [None]:
# Build a second model with the same hyper-parameters as train_model but with
# input length TEST_SEQ_LEN, then copy the trained weights across.
inference_model = create_dilated_model(TEST_SEQ_LEN,
        in_out_d=IN_OUT_D, num_layers=NUM_LAYERS, filter_size=FILTER_SIZE,
        all_outputs=False)
inference_model.set_weights(train_model.get_weights())
# summary() prints the layer table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the cell output.
inference_model.summary()


## run fxp_model from weights files

see `python3 -m fxpmath_version.run_fxpmath_model`

In [None]:
# Dump the test inputs to /tmp/test_x.hex, one sample per line.
# NOTE(review): `x` and `fxp_model` are not defined in any cell visible in this
# notebook — presumably left over from an earlier session; confirm where they
# come from before relying on Restart & Run All.
with open("/tmp/test_x.hex", "w") as f:
    for i in range(len(x)):
        # first element of sample i (x[i] exposes .numpy())
        next_x = x[i].numpy()[0]
        # presumably the single-width fixed-point form of next_x — TODO confirm
        fp_x = fxp_model.fxp.single_width(next_x)
        # line layout: index, raw value, converted value, hex encoding
        print(i, next_x, fp_x, fp_x.hex(), file=f)
        

In [None]:
# Unsigned integer value of the 16-bit word 0xFFFD (65533); parsed directly
# instead of going through eval().
int('0xFFFD', 16)

In [None]:
# Scratch arithmetic: 2**-2 + 2**-3 = 0.375 (binary 0.011).
0.25+0.125

In [None]:

# a0=1, a1=a2=a3=0

# 
# veri [00000000.000011001100000000000000,  11111111.011001110000000000000000,  00000000.101100001100000000000000,  11111111.110000110100000000000000]
# fxp  [00000000.000011001100000000000000', 11111111.011001110000000000000000', 00000000.101100001100000000000000', '11111111.110000110100000000000000']
# after bias add...
# veri [00000000.011100010110000000000000,  11111111.010100010010000000000000, 00000000.001100000110000000000000, 11111111.001111011010000000000000]
# fxp  [00000000.011100010110000000000000', 11111111.010100010010000000000000, 00000000.001100000110000000000000, 11111111.001111011010000000000000']
# LGTM !

# a0=a1=1, a2=a3=0

# after bias add
# veri [00000000.101011000000000000000000, 11111110.011111011000000000000000, 11111111.110001000010000000000000, 11111111.000100110010000000000000]
# fxp  [00000000.101011000000000000000000, 11111110.011111011000000000000000, 11111111.110001000010000000000000, 11111111.000100110010000000000000']
# LGTM !

# a0=a1=a2=a3 = 1

# after bias add
# veri  [11111111.010100110110000000000000, 11111111.100111111000000000000000, 00000001.011001010100000000000000, 11111110.010110010100000000000000]
# fxp   [11111111.010100110110000000000000, 11111111.100111111000000000000000, 00000001.011001010100000000000000, 11111110.010110010100000000000000']
# LGTM

# a0=a1=a2=a3 = -1
# veri  [00000001.011101011110000000000000, 00000000.001101001100000000000000, 11111101.100110100000000000000000, 00000000.100110111000000000000000]
# fxp   [00000001.011101011110000000000000, 00000000.001101001100000000000000, 11111101.100110100000000000000000, 00000000.100110111000000000000000']
# LGTM

#     dut.a0.value = [
#        0x0400,   # 0000.010000000000 0.25
#        0xFDFC,   # 1111.110111111100 -0.1259765625
#        0x0506,   # 0000.010100000110 0.31396484375
#        0xF000    # 1111.000000000000 -1.0
# a0=a1=a2=a3

# veri [11111111.111001111110011101011100, 11111111.011100010110011101111100, 00000000.001100100110110101000100, 00000000.111001100001010110000000]
# fxp  [11111111.111001111110011101011100, 11111111.011100010110011101111100, 00000000.001100100110110101000100, 00000000.111001100001010110000000']

## plot values from verilog version

```
cat sverilog_version/tests/network/net.out \
 | grep ^OUT\ 1 \
 | cut -b26-41 \
 | python3 single_width_bin_to_decimal.py \
 > y_pred.sverilog.txt
```

In [None]:
# Predictions from the SystemVerilog run, one decimal value per line.
# (pandas / seaborn / matplotlib are already imported in the first cell.)
with open('y_pred.sverilog.txt') as f:
    # the context manager closes the file; blank lines (e.g. a trailing
    # newline) are skipped instead of crashing float()
    y_pred_sverilog = [float(line) for line in f if line.strip()]

# sample index `n` alongside each prediction, for plotting
df = pd.DataFrame({
    'y_pred': y_pred_sverilog,
    'n': range(len(y_pred_sverilog)),
})

# pass the frame by keyword so the call does not depend on lineplot's
# positional-argument order
sns.lineplot(data=df, x='n', y='y_pred')

Double check: plot what should be the triangle wave input.

```
cat sverilog_version/tests/network/net.out | grep ^next_x | uniq | cut -d' ' -f2 > test_x.txt
```

In [None]:
# Input samples extracted from the SystemVerilog run, one value per line.
with open('test_x.txt') as f:
    # the context manager closes the file; blank lines (e.g. a trailing
    # newline) are skipped instead of crashing float()
    test_x_sverilog = [float(line) for line in f if line.strip()]

# sample index `n` alongside each input value, for plotting
df = pd.DataFrame({
    'test_x': test_x_sverilog,
    'n': range(len(test_x_sverilog)),
})

# pass the frame by keyword so the call does not depend on lineplot's
# positional-argument order
sns.lineplot(data=df, x='n', y='test_x')

## receptive field for 192kHz

Given the 192kHz resampled data, how much of the signal does one receptive field cover?


In [None]:
# 96kHz variant of the same file, kept for quick switching:
#fname = '/data2/cached_dilated_causal_convolutions/2d_embed/96kHz/tri_squ_zigzag.ssv'
fname = '/data2/cached_dilated_causal_convolutions/2d_embed/192kHz_resampled/tri_squ_zigzag.ssv'

# Space-separated file read as three named columns (one per waveform), plus
# a running sample index `n` for plotting.
df = (
    pd.read_csv(fname, sep=' ', names=['tri', 'squ', 'zigzag'])
    .assign(n=lambda frame: range(len(frame)))
)

df.head()

In [None]:
# Figure size is set through seaborn's global settings, so it also applies to
# any plots drawn after this cell.
sns.set(rc={"figure.figsize": (12, 4)})

# NOTE(review): the name says 512, but the slice keeps the first 600 rows.
first_512_df = df.iloc[:600]

# One row per (n, waveform) pair so the three waves share one plot via `hue`.
wide_df = pd.melt(first_512_df, id_vars='n', value_vars=['tri', 'squ', 'zigzag'])

sns.lineplot(wide_df, x='n', y='value', hue='variable')