In [368]:
import h5py
import numpy as np
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

gpu_idx = 4
device = torch.device(f"cuda:{gpu_idx}")
print(device)

cuda:4


In [369]:
def set_seed(seed: int):

    random.seed(seed)

    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_value = 42
set_seed(seed_value)

In [370]:
neg_rng, pos_rng = -0.3, 0.3
t_min, t_max = -0.6, 0.6
neg_obs, pos_obs = -0.9, 0.9
a_min, a_max = 0.5, 10
n_samples = 11
snr_dB = 15
t_rng = torch.linspace(neg_obs, pos_obs, n_samples).to(device)
scale = 5

In [371]:
class FRIModel(nn.Module):
    def __init__(self, n_inp, n_out, neg_rng, pos_rng, scale):
        super(FRIModel, self).__init__()

        self.n_inp = n_inp
        self.n_brd = scale * n_inp
        self.neg_rng = neg_rng
        self.pos_rng = pos_rng

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * n_inp, 128)
        self.fc2 = nn.Linear(128, n_out)
        self.brd_pts = torch.linspace(neg_rng, pos_rng, scale * n_inp).to(device)

        ## Gaussian
        # self.coeffs = torch.exp(-(self.brd_pts**2)/(2 * (0.18 ** 2)))
        
        ## Gaussian pair
        # self.coeffs = 1.4*(torch.exp(-(self.brd_pts + 0.2)**2 / (2 * (0.045 ** 2))) + torch.exp(-(self.brd_pts - 0.19)**2 / (2 * (0.045 ** 2))))
        # self.coeffs = 1.25 * torch.exp(-(self.brd_pts + 0.1)**2 / (2 * (0.06 ** 2))) + 0.7 * torch.exp(-(self.brd_pts - 0.16)**2 / (2 * (0.07 ** 2)))

        ## Learnable Kernel
        self.coeffs = nn.Parameter(torch.exp(-(self.brd_pts**2)/(2 * (self.pos_rng ** 2))))

    def forward(self, x):
        x = x.unsqueeze(1)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = F.gelu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

    def get_ker(self):
        return self.coeffs

    def get_brd_pts(self):
        return self.brd_pts


    def fn_val(self, c1, c2, strt, t_val, sep):

        ret_val = (c1 + ((c2 - c1) * (t_val - strt)) / sep).sum(dim=1)
        return ret_val


    def get_spikes(self, aks, tks):
        ker = self.get_ker().unsqueeze(0).unsqueeze(1)
        a = aks.unsqueeze(2) * ker
        t = tks.unsqueeze(2) + self.brd_pts

        return a, t

    def get_sig(self, t_samps, a, t):

        diffs = torch.abs(t.unsqueeze(-1) - t_samps.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(t.shape[0], t.shape[1], t.shape[2], -1))
        min_ids = torch.argmin(diffs, dim=-2)

        t_fs = torch.gather(t, 2, min_ids)
        t_samps_exp = t_samps.unsqueeze(0).unsqueeze(0).expand(t_fs.shape[0], t_fs.shape[1], -1)

        cond_1 = (t_fs < t_samps_exp) & (diffs[:, :, self.n_brd // 2, :] < self.pos_rng)
        cond_2 = (t_fs >= t_samps_exp) & (diffs[:, :, self.n_brd // 2, :] < self.pos_rng)
        cond_3 = (min_ids == 0) & (diffs[:,:,self.n_brd // 2, :] > self.pos_rng)
        cond_4 = (min_ids == self.n_brd - 1) & (diffs[:,:,self.n_brd // 2, :] > self.pos_rng)
        cond_5 = (min_ids == 0) & (diffs[:,:,self.n_brd // 2, :] < self.pos_rng)
        cond_6 = (min_ids == self.n_brd - 1) & (diffs[:,:,self.n_brd // 2, :] < self.pos_rng)

        min_ids[cond_3] = 1
        min_ids[cond_5] = 1
        min_ids[cond_4] = self.n_brd - 2
        min_ids[cond_6] = self.n_brd - 2

        strt_pts = torch.where(cond_3 | cond_4, -100, torch.where(cond_1, torch.where(cond_5, torch.gather(t, 2, min_ids - 1), torch.gather(t, 2, min_ids)), torch.where(cond_6, torch.gather(t, 2, min_ids), torch.gather(t, 2, min_ids - 1))))
        c1_vals = torch.where(cond_3 | cond_4, 0, torch.where(cond_1, torch.where(cond_5, torch.gather(a, 2, min_ids - 1), torch.gather(a, 2, min_ids)), torch.where(cond_6, torch.gather(a, 2, min_ids), torch.gather(a, 2, min_ids - 1))))
        c2_vals = torch.where(cond_3 | cond_4, 0, torch.where(cond_2, torch.where(cond_6, torch.gather(a, 2, min_ids + 1), torch.gather(a, 2, min_ids)), torch.where(cond_5, torch.gather(a, 2, min_ids), torch.gather(a, 2, min_ids + 1))))

        return self.fn_val(c1_vals, c2_vals, strt_pts, t_samps, self.brd_pts[1] - self.brd_pts[0])

    def get_signal(self, aks, tks, t_rng):
        a_, t_ = self.get_spikes(aks, tks)
        sig = self.get_sig(t_rng, a_, t_)
        return sig

In [372]:
def add_noise(signals, snr_dB):

    signal_power = torch.mean(signals ** 2, dim=1, keepdim=True)
    snr_linear = 10 ** (snr_dB / 10)
    noise_power = signal_power / snr_linear
    noise = torch.sqrt(noise_power) * torch.randn_like(signals)
    return signals + noise

def mse_db(pred, actual):
    mse = np.mean((pred - actual)**2, axis=1)
    signal_power = np.mean(actual**2, axis=1)

    return 10 * np.log10(mse / signal_power)

In [373]:
def get_batch(data, model, batch_size, idx, t_rng):

    amp_batch = data["amps"][idx * batch_size : (idx + 1) * batch_size]
    loc_batch = data["locs"][idx * batch_size : (idx + 1) * batch_size] 

    return model.get_signal(amp_batch, loc_batch, t_rng), loc_batch

In [374]:
model = FRIModel(n_samples, 3, neg_rng, pos_rng, scale).to(device)
model.load_state_dict(torch.load("l1_loss/kernel_locs3_samples11_15dB.pt", map_location=device, weights_only=True))
model.eval()

FRIModel(
  (conv1): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (fc1): Linear(in_features=704, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=3, bias=True)
)

In [375]:
savepath = 'test_40dB_3loc.h5'
with h5py.File(savepath, 'r') as f:

    amp3_test = f['/a_k'][:]
    loc3_test = f['/t_k'][:]
    
n_exps = 1000

tks_test = []
aks_test = []

for i in range(n_exps):

    a_values = amp3_test[i]
    t_values = loc3_test[i]

    tks_test.append(t_values)
    aks_test.append(a_values)

tks_test = torch.tensor(np.array(tks_test)).float().to(device)
aks_test = torch.tensor(np.array(aks_test)).float().to(device)


data_test = {"amps": aks_test, "locs": tks_test}

In [376]:
sig_test, loc_test = get_batch(data_test, model, len(aks_test), 0, t_rng)
sig_test = add_noise(sig_test, snr_dB)
model.eval()
tks_pred = model(sig_test).squeeze(0).cpu().detach().numpy()
tks_pred = torch.tensor(np.array(tks_pred)).float().to(device)

In [377]:
for param in model.parameters():
    param.requires_grad = False

In [378]:
aks_pred = 10 * (1 * torch.randn_like(aks_test) + 1) / 2
optimizer = optim.Adam([aks_pred], lr=0.05)
train = TensorDataset(tks_pred, sig_test)
train = DataLoader(train, batch_size=64, shuffle=False)

aks_test = aks_test.cpu().detach().numpy()
tks_pred = tks_pred.cpu().detach().numpy()
tks_test = tks_test.cpu().detach().numpy()

In [379]:
def mse_fit(coeffs, dataloader, loss_fn, epochs=30):
    
    coeffs.requires_grad_(True)
    
    for i in range(epochs):
        total_loss = 0
        for idx, batch in enumerate(dataloader):
            locs, sig_true = batch
            locs, sig_true = locs.to(device), sig_true.to(device)
            
            start_idx = idx*locs.shape[0]
            end_idx = min((idx + 1)*locs.shape[0], coeffs.shape[0])
            
            amp_batch = coeffs[start_idx:end_idx, :]
            
            sig_pred = model.get_signal(amp_batch, locs, t_rng)
            loss = loss_fn(sig_pred, sig_true)
            optimizer.zero_grad()

            loss.backward(retain_graph=True)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {i+1}/{epochs}, Loss: {total_loss/len(dataloader)}")
        
    coeffs.requires_grad_(False)
        
    return coeffs

In [380]:
aks_pred = mse_fit(aks_pred, train, F.mse_loss, epochs=300)
aks_pred = aks_pred.cpu().detach().numpy()
del train, sig_test

Epoch 1/300, Loss: 12.426310420036316


Epoch 2/300, Loss: 11.64056533575058
Epoch 3/300, Loss: 10.95687660574913
Epoch 4/300, Loss: 10.32086643576622
Epoch 5/300, Loss: 9.7272769510746
Epoch 6/300, Loss: 9.173227220773697
Epoch 7/300, Loss: 8.656035244464874
Epoch 8/300, Loss: 8.173083245754242
Epoch 9/300, Loss: 7.721881479024887
Epoch 10/300, Loss: 7.300097733736038
Epoch 11/300, Loss: 6.905582576990128
Epoch 12/300, Loss: 6.536349564790726
Epoch 13/300, Loss: 6.190564259886742
Epoch 14/300, Loss: 5.8665522038936615
Epoch 15/300, Loss: 5.56278158724308
Epoch 16/300, Loss: 5.277851298451424
Epoch 17/300, Loss: 5.010464504361153
Epoch 18/300, Loss: 4.75939616560936
Epoch 19/300, Loss: 4.523555517196655
Epoch 20/300, Loss: 4.301922723650932
Epoch 21/300, Loss: 4.093535169959068
Epoch 22/300, Loss: 3.8975168466567993
Epoch 23/300, Loss: 3.7130445688962936
Epoch 24/300, Loss: 3.5393883883953094
Epoch 25/300, Loss: 3.3758675158023834
Epoch 26/300, Loss: 3.2217941284179688
Epoch 27/300, Loss: 3.076582193374634
Epoch 28/300, Loss

In [381]:
idx = np.random.randint(0, len(aks_test))
print(aks_test[idx])
print(aks_pred[idx])

[3.5837219 3.1392334 6.524935 ]
[3.0671372 3.145503  6.706693 ]


In [382]:
tks_mse, aks_mse = mse_db(tks_pred, tks_test), mse_db(aks_pred, aks_test)
print("Location prediction MSE (dB) :", np.round(np.mean(tks_mse), 2))
print("Amplitude prediction MSE (dB) :", np.round(np.mean(aks_mse), 2))

Location prediction MSE (dB) : -20.22
Amplitude prediction MSE (dB) : -12.94
