In [2]:
import torch
import scipy.signal as signal
import numpy as np
from matplotlib import pyplot as plt
import time
from torch2trt import TRTModule
from torch2trt import torch2trt
from models.CNN import ErnNet
from utils import get_pytorch_model

In [3]:
def get_start_points_x(width, slice_width, overlap_x):
    """Return the left-edge x offsets of overlapping slices covering ``width``.

    Slices are ``slice_width`` wide and consecutive slices start
    ``int(slice_width * (1 - overlap_x))`` columns apart. The final slice is
    anchored to the right edge (``width - slice_width``) so the whole width is
    covered, which means the last two slices may overlap by more than
    ``overlap_x``.

    Args:
        width: Total number of columns to cover.
        slice_width: Number of columns in each slice; must be positive.
        overlap_x: Fractional overlap between consecutive slices; must yield
            a stride of at least one column (i.e. be strictly below 1).

    Returns:
        A list of start offsets in increasing order. When ``width`` does not
        exceed ``slice_width`` a single slice starting at 0 is returned.

    Raises:
        ValueError: If ``slice_width`` is not positive or the computed stride
            is not positive (the previous implementation looped forever in
            the latter case).
    """
    if slice_width <= 0:
        raise ValueError(f"slice_width must be positive, got {slice_width}")

    stride = int(slice_width * (1 - overlap_x))
    if stride <= 0:
        raise ValueError(
            f"overlap_x={overlap_x} with slice_width={slice_width} gives a "
            f"non-positive stride ({stride}); slices would never advance"
        )

    last_start = width - slice_width
    if last_start <= 0:
        # The input is no wider than one slice: a single slice at 0 covers it.
        # (Avoids the duplicate / negative start points produced previously.)
        return [0]

    # Regular strided starts strictly before the right-anchored final slice.
    x_points = list(range(0, last_start, stride))
    x_points.append(last_start)
    return x_points

def get_total_inference_time(self, model, num_loops):
    """Return the total wall-clock time, in seconds, of ``num_loops`` forward passes.

    Args:
        self: Object exposing ``input_data_batch``, the input fed to ``model``
            on every call.
        model: Callable invoked as ``model(self.input_data_batch)``.
        num_loops: Number of timed forward passes.

    Returns:
        Elapsed seconds (float) across all timed passes. Each pass is followed
        by a synchronization of the current CUDA stream, so the figure
        includes that per-iteration synchronization.
    """
    # Wait for previously queued work on the current CUDA stream to complete.
    torch.cuda.current_stream().synchronize()

    # GPU warmup
    for _ in range(10):
        _ = model(self.input_data_batch)

    # Drain the warm-up work before starting the clock; otherwise kernels still
    # in flight from the warm-up would be charged to the timed section.
    torch.cuda.current_stream().synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    for _ in range(num_loops):
        _ = model(self.input_data_batch)
        torch.cuda.current_stream().synchronize()
    t1 = time.perf_counter()

    return t1 - t0

In [6]:
# Read the saved state dict from disk, then restore it into a fresh TRTModule.
trt_checkpoint = torch.load('trt_models/model_trt_int8.pth')
model_trt = TRTModule()
# Left as the cell's final expression so the key-matching report is displayed.
model_trt.load_state_dict(trt_checkpoint)

<All keys matched successfully>

In [8]:
# Benchmark: for each iteration, synthesize a complex signal, compute its STFT,
# cut the dB spectrogram into overlapping slices and run them through model_trt.
# The measured interval (t0..t1) covers signal generation, STFT, slicing,
# host-to-device copy and inference.

# Loop-invariant parameters, hoisted out of the timed loop.
fs = 714  # Sampling frequency
nperseg = 128  # Length of each segment
noverlap = 64  # Number of overlapping points between segments
nfft = 128  # Length of FFT
slice_width = 45  # Spectrogram columns per model input
overlap = 0.5  # Fractional overlap between consecutive slices

num_imgs = 0

# Make sure no earlier GPU work is still pending when the clock starts.
torch.cuda.current_stream().synchronize()
t0 = time.perf_counter()
for _ in range(1000):
    x = np.random.random(400000) + np.random.random(400000) * 1j

    f, t, Zxx = signal.stft(x, fs=fs, window='hamming', nperseg=nperseg, noverlap=noverlap, nfft=nfft, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=- 1)

    # Magnitude in dB. NOTE(review): an exactly-zero STFT bin would give -inf
    # here; not expected with this random input, but confirm for real data.
    fft_dB = 20 * np.log10(np.abs(Zxx))

    width = fft_dB.shape[1]
    start_points = get_start_points_x(width, slice_width, overlap)

    segments = torch.from_numpy(np.array([fft_dB[:, start_point:start_point + slice_width] for start_point in start_points], dtype=np.float32))
    segments_shape = segments.shape

    # Insert a singleton dimension at axis 1 and move the batch to the GPU.
    reshaped_segments = torch.reshape(segments, (segments_shape[0], 1, segments_shape[1], segments_shape[2])).cuda()
    model_trt(reshaped_segments)
    num_imgs += len(reshaped_segments)

# CUDA work launched by the loop may still be in flight when Python gets here;
# wait for it so that t1 - t0 includes the inference itself.
torch.cuda.current_stream().synchronize()
t1 = time.perf_counter()



In [20]:
# Model inputs processed per second of wall-clock time in the benchmark loop.
elapsed_seconds = t1 - t0
throughput = num_imgs / elapsed_seconds

# Scale by 45 per image, which matches the slice width used in the benchmark.
# NOTE(review): the slices overlap, so this counts shared columns more than
# once — confirm that is intended.
point_throughput = throughput * 45

# Ratio of processed points per second to the sampling frequency.
supported_range_bins = point_throughput / fs

# NOTE(review): the meaning of the factor 5 is not stated in this notebook — confirm.
supported_range_bins * 5

1969.864781229934