In [None]:
import torch
import scipy.signal as signal
import numpy as np
from matplotlib import pyplot as plt
import time
from torch2trt import TRTModule
from torch2trt import torch2trt
from models.CNN import Model4
from scipy import signal
import matplotlib.pyplot as plt

In [None]:


# --- Notch-filter design parameters ---------------------------------------
fs = 10000/14  # Sample frequency (Hz)
f0 = 0.5       # Frequency to be removed from signal (Hz)
Q = 0.3        # Quality factor

# Second-order IIR notch: `b` holds the numerator and `a` the denominator
# coefficients. Both are reused later when filtering the benchmark input.
b, a = signal.iirnotch(w0=f0, Q=Q, fs=fs)

# Complex frequency response of the filter; because `fs` is supplied,
# `freq` is expressed in Hz rather than rad/sample.
freq, h = signal.freqz(b=b, a=a, fs=fs)

# Plot the notch filter's magnitude (top panel) and phase (bottom panel).
fig, ax = plt.subplots(2, 1, figsize=(8, 6))
mag_ax, phase_ax = ax

# Magnitude response in decibels.
magnitude_db = 20 * np.log10(np.abs(h))
mag_ax.plot(freq, magnitude_db, color='blue')
mag_ax.set_title("Frequency Response")
mag_ax.set_ylabel("Amplitude (dB)", color='blue')
mag_ax.set_xlim([0, 400])
mag_ax.set_ylim([-25, 10])
mag_ax.grid()

# Unwrapped phase response, converted from radians to degrees.
phase_deg = np.unwrap(np.angle(h))*180/np.pi
phase_ax.plot(freq, phase_deg, color='green')
phase_ax.set_ylabel("Angle (degrees)", color='green')
phase_ax.set_xlabel("Frequency (Hz)")
phase_ax.set_xlim([0, 100])
phase_ax.set_yticks([-90, -60, -30, 0, 30, 60, 90])
phase_ax.set_ylim([-90, 90])
phase_ax.grid()

plt.show()

In [None]:
# Restore the saved TensorRT module state from disk into a fresh TRTModule.
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
trt_checkpoint_path = 'trt_models/model_trt_int8.pth'
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(trt_checkpoint_path))

In [None]:
# ---------------------------------------------------------------------------
# End-to-end latency benchmark.
#
# For each candidate number of range bins, repeatedly:
#   1. generate random complex samples,
#   2. notch-filter them with the (b, a) coefficients designed earlier,
#   3. take the STFT and convert its magnitude to dB,
#   4. run the TensorRT model on the result.
# The wall-clock time of steps 1-4 is recorded per iteration. `times` ends up
# as one list of per-iteration durations (seconds) per entry of
# `num_range_bins_array`.
# ---------------------------------------------------------------------------
fs = 714 # Sampling frequency
nperseg = 128 # Length of each segment
noverlap = 64 # Number of overlapping points between segments
nfft = 128 # Length of FFT
dwell_time = 4
num_points = np.int64(np.floor(dwell_time* fs))

# Not read anywhere in this cell; kept in case other cells rely on them.
slice_width = 45
overlap = 0.5
classified_imgs = 0
num_loops = 10

# Timed repetitions per range-bin setting (previously a bare literal 100).
iters_per_setting = 100

times = []
num_range_bins_array = np.arange(3_000, 21_000, 3_000)

# GPU warmup
for _ in range(10):
    _ = model_trt(torch.randn((1, 1, 128, 45)).cuda())
# Wait for the warm-up kernels to finish so they cannot bleed into the first
# timed iteration.
torch.cuda.synchronize()

for num_range_bins in num_range_bins_array:
    range_bin_times = []

    print(f'looking at {num_range_bins} range bins ......')

    for _ in range(iters_per_setting):
        # perf_counter is monotonic and higher resolution than time.time().
        t2 = time.perf_counter()

        # Creating the samples (deliberately inside the timed region, as before)
        x = torch.randn(num_range_bins, num_points, dtype=torch.cfloat)
        filtered_x = signal.filtfilt(b, a, x)

        # Get the short-term fourier transform of the signal
        f, t, Zxx = signal.stft(filtered_x, fs=fs, window='hamming', nperseg=nperseg, noverlap=noverlap, nfft=nfft, return_onesided=False, padded=False, axis=-1)
        Zxx_tensor = torch.from_numpy(Zxx)

        # Converting to dB
        fft_dB = 20*torch.log10(torch.abs(Zxx_tensor))

        # filtfilt promotes the data to double precision, so fft_dB arrives as
        # float64 without a channel axis. The warm-up feeds the engine float32
        # (N, 1, 128, 45) tensors, so match that layout and dtype here.
        # NOTE(review): assumes the engine expects a single-channel float32
        # input, as the warm-up suggests - confirm against the exported model.
        model_input = fft_dB.unsqueeze(1).to(torch.float32).cuda()
        _ = model_trt(model_input)

        # CUDA work is launched asynchronously; block until inference has
        # actually finished before stopping the clock.
        torch.cuda.synchronize()
        t3 = time.perf_counter()

        range_bin_times.append(t3 - t2)

    times.append(range_bin_times)

In [None]:
# Range-bin counts used as the x-axis categories of the timing plot:
# 3000, 6000, ..., 18000 (the stop value 21000 is exclusive).
num_range_bins_array = np.arange(start=3_000, stop=21_000, step=3_000)


In [None]:
import matplotlib.pyplot as plt

# Box plot of per-iteration processing time, one box per range-bin count.
fig, ax = plt.subplots(figsize=(10, 10))
ax.grid()

# `times` holds one list of durations per entry of `num_range_bins_array`.
ax.boxplot(times)
ax.set_xticklabels(num_range_bins_array)

# Fixed y-range (seconds).
ax.set_ylim((0, 6))

# Optional title, left disabled:
# ax.set_title('Plot of number of ranges bins vs processing time (10 times for each number of range bins)')
ax.set_xlabel("Number of Range Bins")
ax.set_ylabel("Processing Time (s)")
