In [1]:
import math
import torch
import numpy as np
import pandas as pd
import tensorflow as tf
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn import preprocessing
from Training import model, utils, dataset, train
from sklearn.model_selection import train_test_split


import keras.backend as K

In [2]:
def check_acc(y_hat,y,margin=0.05):
    a_err = (np.abs(y_hat - y)) # get normalized error 
    err = np.divide(a_err, y, where=y!=0)
    assert(err.shape == y.shape)
    num_correct = 0
    for row in err:
        num_in_row = len(np.where(row < margin)[0]) # margin * 100 because 
        if num_in_row == len(row):
            num_correct += 1

    num_samples = y.shape[0]
    correct_idx = np.where(err < margin)
    num_part_correct = len(correct_idx[0])
    num_part_samples = y.shape[0] * y.shape[1]
    print(f"Correct = {num_correct} / {num_samples}")
    return (num_correct/num_samples)
from matplotlib.patches import Ellipse
def multivariate_gaussian_nll(ypreds, ytrue, var):
    
    diag = torch.exp(var[:,:2]) # convert log-scale var to
    n = ypreds.shape[1] #number of parameters ie number of means (2 gain and bandwidth)
    B = ypreds.shape[0] #Batch size
    
    z = torch.zeros(B)
    o = torch.ones(B)
    D = torch.stack((diag[:,0],z,z,diag[:,1]),dim=1).reshape(B,2,2) # form Diagnol matrix D for LDLT
    L = torch.stack((o,z,var[:,2],o),dim=1).reshape(B,2,2) # form L matrix 
    LT = torch.stack((o,var[:,2],z,o),dim=1).reshape(B,2,2) # form LT matrix (transpose of L)

    sigma = L @ D @ LT   # form sigma inv from LDLT decomp
    ximu =(ytrue-ypreds).reshape(B,2,1)  #true- minus 
    ximuT =(ytrue-ypreds).reshape(B,1,2) # true- minus  transpose

    loss = 0.5*torch.mean(ximuT@sigma@ximu + ((n/2)*(-torch.sum(var[:,:2],axis=1).reshape(B,1))))
    return loss
def formCovMatrix(var):
    diag = np.exp(var[:2])
    z = np.zeros(2)
    o = np.ones(2)
    D = np.array([diag[0],0,0,diag[1]]).reshape(2,2)
    L = np.array([1,0,var[2],1]).reshape(2,2)
    LT = np.array([1,var[2],0,1]).reshape(2,2)
    sigma = L @ D @ LT
    return np.linalg.pinv(sigma)

In [4]:
data_full_raw = utils.parseGainAndBWCsv2("Data/BW-3000.csv").astype(float)
full_data = preprocessing.MinMaxScaler((0,1)).fit_transform(data_full_raw)
full_X = full_data[:,:2]
full_Y = full_data[:,2:]

In [5]:
data_full_raw.shape

(3022, 4)

In [6]:
percentages = [0.1, 0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
parameter_accuracy = []
performace_accuracy = []

In [None]:
for percentage in percentages:
    data_raw = data_full_raw[np.random.choice(data_full_raw.shape[0], int(percentage * data_full_raw.shape[0]), replace = False),:]
    print(data_raw.shape)
    data = preprocessing.MinMaxScaler((0,1)).fit_transform(data_raw)
    scaler = preprocessing.MinMaxScaler((0,1))
    data2 = scaler.fit_transform(data_raw)
    assert(np.allclose(data,data2))
    X = data[:,:2]
    Y = data[:,2:]
    test_model = model.DistModelBatchNorm(2,5)
    optimizer = optim.Adagrad(test_model.parameters(),lr=0.001)
    loss_fn = multivariate_gaussian_nll

    dataset1 = dataset.CircuitSynthesisGainAndBandwidthManually(Y, X)
    train_dataset, val_dataset = utils.splitDataset(dataset1, 0.95)

    train_data = DataLoader(train_dataset,batch_size = 500)
    validation_data = DataLoader(val_dataset, batch_size = 500)
    epochs = 2000
    loss_list, val_loss_list = train.trainProbModel(test_model, train_data, loss_fn, optimizer, num_epochs=epochs, print_every=10, validation_data=validation_data)
    
    x_preds = test_model(torch.Tensor(full_Y))
    mock_simulator = tf.keras.models.load_model('mock_simulator2.0')
    print(x_preds.shape)
    means = x_preds.detach().numpy()[:,:2]
    final_preds = mock_simulator(means).numpy()    
    # for i,d in enumerate(final_preds):
    #     print(Y[i],d)
    print(final_preds)
    parameter_acc = check_acc(full_X,means,margin=.05)
    performance_acc = check_acc(full_Y,final_preds,margin=.05)
    parameter_accuracy.append(parameter_acc)
    performace_accuracy.append(performance_acc)

(302, 4)
t = 10, loss = -0.3865, val loss = -0.3865
t = 20, loss = -0.7696, val loss = -0.7696
t = 30, loss = -0.9797, val loss = -0.9797
t = 40, loss = -1.1797, val loss = -1.1797
t = 50, loss = -1.3174, val loss = -1.3174
t = 60, loss = -1.4689, val loss = -1.4689
t = 70, loss = -1.6038, val loss = -1.6038
t = 80, loss = -1.7054, val loss = -1.7054
t = 90, loss = -1.7949, val loss = -1.7949
t = 100, loss = -1.8711, val loss = -1.8711
t = 110, loss = -1.9927, val loss = -1.9927
t = 120, loss = -2.0271, val loss = -2.0271
t = 130, loss = -2.1644, val loss = -2.1644
t = 140, loss = -2.2238, val loss = -2.2238
t = 150, loss = -2.2727, val loss = -2.2727
t = 160, loss = -2.3882, val loss = -2.3882
t = 170, loss = -2.3807, val loss = -2.3807
t = 180, loss = -2.5299, val loss = -2.5299
t = 190, loss = -2.5766, val loss = -2.5766
t = 200, loss = -2.6110, val loss = -2.6110
t = 210, loss = -2.6902, val loss = -2.6902
t = 220, loss = -2.7643, val loss = -2.7643
t = 230, loss = -2.7198, val los

t = 1860, loss = -6.6326, val loss = -6.6326
t = 1870, loss = -6.2544, val loss = -6.2544
t = 1880, loss = -6.3928, val loss = -6.3928
t = 1890, loss = -6.6283, val loss = -6.6283
t = 1900, loss = -6.2235, val loss = -6.2235
t = 1910, loss = -6.3556, val loss = -6.3556
t = 1920, loss = -6.8483, val loss = -6.8483
t = 1930, loss = -6.5242, val loss = -6.5242
t = 1940, loss = -5.6786, val loss = -5.6786
t = 1950, loss = -6.4818, val loss = -6.4818
t = 1960, loss = -6.6025, val loss = -6.6025
t = 1970, loss = -6.8151, val loss = -6.8151
t = 1980, loss = -6.8082, val loss = -6.8082
t = 1990, loss = -6.7940, val loss = -6.7940
t = 2000, loss = -6.5596, val loss = -6.5596
torch.Size([3022, 5])
[[0.39352566 0.95009834]
 [0.48219824 0.7727458 ]
 [0.5592391  0.6276409 ]
 ...
 [0.2281191  0.55199766]
 [0.14832    0.67931163]
 [0.5322176  0.24223928]]
Correct = 1033 / 3022
Correct = 2290 / 3022
(604, 4)
t = 10, loss = -0.3888, val loss = -0.4315
t = 20, loss = -0.7983, val loss = -0.8396
t = 30, 

t = 1660, loss = -5.7467, val loss = -5.9371
t = 1670, loss = -7.2528, val loss = -7.4892
t = 1680, loss = -7.5587, val loss = -7.7615
t = 1690, loss = -7.5247, val loss = -7.7031
t = 1700, loss = -7.0605, val loss = -7.2537
t = 1710, loss = -6.3715, val loss = -7.0215
t = 1720, loss = -7.1467, val loss = -7.3118
t = 1730, loss = -7.1986, val loss = -7.4277
t = 1740, loss = -7.4618, val loss = -7.7867
t = 1750, loss = -7.3257, val loss = -7.5194
t = 1760, loss = -6.6979, val loss = -6.9264
t = 1770, loss = -6.5929, val loss = -6.9578
t = 1780, loss = -7.4389, val loss = -7.7058
t = 1790, loss = -7.5414, val loss = -7.8277
t = 1800, loss = -6.6877, val loss = -6.9480
t = 1810, loss = -7.2436, val loss = -7.4821
t = 1820, loss = -7.4554, val loss = -7.7159
t = 1830, loss = -6.9802, val loss = -7.4609
t = 1840, loss = -7.4876, val loss = -7.7205
t = 1850, loss = -5.9674, val loss = -6.4135
t = 1860, loss = -7.2186, val loss = -7.5893
t = 1870, loss = -7.3855, val loss = -7.6105
t = 1880, 

In [None]:
print(parameter_accuracy)
print(performace_accuracy)