**Table of contents**<a id='toc0_'></a>    
- [Testing compact function with and without torch.jit.script decorator](#toc1_1_)    
  - [Testing some basic functions with the torch.jit.script decorator](#toc1_2_)    
  - [Test get_PSD and get_peaks functions with torch.jit.script decorator](#toc1_3_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

In [1]:
import numpy as np
import matplotlib.pyplot as plt 
import scienceplots
from ofc_functions import *
from nn_functions import *
import os, glob

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

In [2]:
import torch.distributions.uniform as urand
from scipy.optimize import minimize
from torch.cuda.amp import autocast

In [3]:
#Trying to use CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Runnning in {device} mãe")

Runnning in cuda mãe


In [4]:
# Parameters for the OFC signal
ofc_args = parameters()
ofc_args.SpS = 64   # Samples per symbol. Determines the window size
ofc_args.Rs  = 10e9 # Symbol rate (baud rate - symbols per second)
ofc_args.Ts  = 1/ofc_args.Rs # Symbol period (s)
ofc_args.Fa  = 1/(ofc_args.Ts/ofc_args.SpS)  # Sampling frequency of the signal (samples per second)
ofc_args.Ta  = 1/ofc_args.Fa   # Sampling period of the signal (s)
ofc_args.NFFT = ofc_args.SpS   # Number of points of the FFT (multiple of SpS) 
ofc_args.t = ofc_args.Ta * torch.arange(0, ofc_args.SpS*10 + 1, device=device).unsqueeze(0) # Time vector for the signal (s)
ofc_args.Vpi = 2             # Voltage required to achieve a π phase shift (V)
#ofc_args.Vpi_DDMZM = 3.5    # Voltage required to achieve a π phase shift (V)
#ofc_args.Vpi_MZM = 5.5      # Voltage required to achieve a π phase shift (V)
#ofc_args.Vpi_PM = 3.5       # Voltage required to achieve a π phase shift (V)

ofc_args.P = 10**(0/10)/1000 # Power of the signal (W) --> 0 dBm
ofc_args.n_peaks = 9 # Number of peaks to be found in the signal

# Parameters for the modulators cascaded
mod_args = parameters()
mod_args.P_min = 1e-3*10**(-10/10) # Minimum power (W) --> -10 dBm
mod_args.P_max = 1e-3*10**(10/10)  # Maximum power (W) -->  10 dBm
mod_args.V_min = -5    # Minimum amplitude voltage (V)
mod_args.V_max = 5     # Maximum amplitude voltage (V)
mod_args.Phase_min = 0 # Minimum phase (rad)
mod_args.Phase_max = 0 # Maximum phase (rad) --> Phase is periodic with period 2*π
mod_args.Vb_min = -15  # Minimum bias voltage (V)
mod_args.Vb_max = 15   # Maximum bias voltage (V) --> Vb is periodic with period 4*Vπ

mod_args.bounds = [(mod_args.V_min,  mod_args.V_max),  (mod_args.V_min,  mod_args.V_max),  
                   (mod_args.Vb_min, mod_args.Vb_max), (mod_args.Vb_min, mod_args.Vb_max),
                   (mod_args.P_min,  mod_args.P_max)] # Bounds of the parameters


#mod_args.bounds = [(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max), 
#        (mod_args.Phase_min, mod_args.Phase_max),(mod_args.Phase_min, mod_args.Phase_max), (mod_args.Phase_min, mod_args.Phase_max),
#        (mod_args.Vb_min,mod_args.Vb_max), (mod_args.Vb_min,mod_args.Vb_max)] # Bounds of the parameters

#mod_args.bounds = [(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max), 
#        (mod_args.Phase_min, mod_args.Phase_max),(mod_args.Phase_min, mod_args.Phase_max), (mod_args.Phase_min, mod_args.Phase_max),
#        (mod_args.Vb_min,mod_args.Vb_max)] # Bounds of the parameters

freqs_peaks_GHz = torch.linspace(-(ofc_args.n_peaks//2), ofc_args.n_peaks//2, ofc_args.n_peaks) # Frequency range in GHz for n_peaks


## <a id='toc1_1_'></a>[Testing compact function with and without torch.jit.script decorator](#toc0_)

In [5]:
def analytical_function_compact2(params: torch.Tensor, args) -> torch.Tensor:
    '''
    Function to generate the frequency comb signal peaks of PM-MZM-MZM

    Parameters:
    params: torch.Tensor
        Parameters to generate the frequency comb signal
    args: parameters
        Parameters object with the arguments to generate the frequency comb signal
        Should contain the following attributes: t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks
        args.t: torch.Tensor
            Time vector
        args.Rs: float
            Symbol rate (samples per second)
        args.Vpi: float
            Half-wave voltage of the MZM (V)
        args.P: float
            Power of the frequency comb signal (W)
        args.NFFT: int
            Number of points of the FFT (multiple of SpS)
        args.Fa: float
            Sampling frequency of the signal (samples per second)
        args.SpS: int
            Samples by each Rs
        args.n_peaks: int
            Number of peaks to be found (odd number)

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal

    '''
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]

    V1 = params[:,0].unsqueeze(1)
    V2 = params[:,1].unsqueeze(1)
    V3 = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2 = params[:,6].unsqueeze(1)
    Vb3 = params[:,7].unsqueeze(1)

    t = args.t#.unsqueeze(0)
    pi_Rs_t = 2 * torch.pi * args.Rs * t
    K = torch.pi / args.Vpi
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    cos_u1_K, sin_u1_K = torch.cos(u1 * K), torch.sin(u1 * K)
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    frequency_comb = args.P * (cos_u1_K + 1j * sin_u1_K) * cos_u2_Vb2_K * cos_u3_Vb3_K
    
    psd = torch.fft.fft(frequency_comb, n=args.NFFT, dim=-1)
    psd = torch.fft.fftshift(psd, dim=-1)
    psd = torch.abs(psd) ** 2 / (args.NFFT * args.Fa)
    psd[psd == 0] = torch.finfo(psd.dtype).eps
    log_Pxx = 10 * torch.log10(psd)
    
    n_peaks_2 = args.n_peaks // 2
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (args.NFFT // args.SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)
    return peaks

In [6]:
def analytical_function_compact3(params: torch.Tensor, t: torch.Tensor, Rs: torch.Tensor, Vpi: torch.Tensor, P: torch.Tensor, NFFT: torch.Tensor, Fa: torch.Tensor, SpS: torch.Tensor, n_peaks: torch.Tensor) -> torch.Tensor:
    '''
    Function to generate the frequency comb signal peaks of PM-MZM-MZM

    Parameters:
    params: torch.Tensor
        Parameters to generate the frequency comb signal
    args: parameters
        Parameters object with the arguments to generate the frequency comb signal
        Should contain the following attributes: t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks
        args.t: torch.Tensor
            Time vector
        args.Rs: float
            Symbol rate (samples per second)
        args.Vpi: float
            Half-wave voltage of the MZM (V)
        args.P: float
            Power of the frequency comb signal (W)
        args.NFFT: int
            Number of points of the FFT (multiple of SpS)
        args.Fa: float
            Sampling frequency of the signal (samples per second)
        args.SpS: int
            Samples by each Rs
        args.n_peaks: int
            Number of peaks to be found (odd number)

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal

    '''
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]

    V1 = params[:,0].unsqueeze(1)
    V2 = params[:,1].unsqueeze(1)
    V3 = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2 = params[:,6].unsqueeze(1)
    Vb3 = params[:,7].unsqueeze(1)

    #t = t.unsqueeze(0)
    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    cos_u1_K, sin_u1_K = torch.cos(u1 * K), torch.sin(u1 * K)
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    frequency_comb = P * (cos_u1_K + 1j * sin_u1_K) * cos_u2_Vb2_K * cos_u3_Vb3_K
    
    psd = torch.fft.fft(frequency_comb, n=NFFT, dim=-1)
    psd = torch.fft.fftshift(psd, dim=-1)
    psd = torch.abs(psd) ** 2 / (NFFT * Fa)
    psd[psd == 0] = torch.finfo(psd.dtype).eps
    log_Pxx = 10 * torch.log10(psd)
    
    n_peaks_2 = n_peaks // 2
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)
    return peaks.to('cuda')

In [35]:
@torch.jit.trace
def analytical_function_jit2(params: torch.Tensor, t: torch.Tensor, Rs: torch.Tensor, Vpi: torch.Tensor, P: torch.Tensor, NFFT: torch.Tensor, Fa: torch.Tensor, SpS: torch.Tensor, n_peaks: torch.Tensor) -> torch.Tensor:

    '''
    Function to generate the frequency comb signal peaks of PM-MZM-MZM using JIT

    Parameters:
    params: torch.Tensor
        Parameters to generate the frequency comb signal
    t: torch.Tensor
        Time vector
    Rs: torch.Tensor
        Symbol rate (samples per second)
    Vpi: torch.Tensor
        Half-wave voltage of the MZM (V)
    P: torch.Tensor
        Power of the frequency comb signal (W)
    NFFT: torch.Tensor
        Number of points of the FFT
    Fa: torch.Tensor
        Sampling frequency of the signal (samples per second)
    SpS: torch.Tensor
        Samples by each Rs
    n_peaks: torch.Tensor
        Number of peaks to be found

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal
    '''
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]

    V1 = params[:,0].unsqueeze(1)
    V2 = params[:,1].unsqueeze(1)
    V3 = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2 = params[:,6].unsqueeze(1)
    Vb3 = params[:,7].unsqueeze(1)

    #t = t.unsqueeze(0)
    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi

    #with autocast(enabled=True):
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    cos_u1_K = torch.cos(u1 * K)
    sin_u1_K = torch.sin(u1 * K)
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    real_part = P * cos_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    imag_part = P * sin_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    frequency_comb = torch.complex(real_part, imag_part)

    psd = torch.fft.fft(frequency_comb, n=NFFT, dim=-1)
    psd = torch.fft.fftshift(psd, dim=-1)
    psd = torch.abs(psd) ** 2 / (NFFT * Fa)
    psd[psd == 0] = 1.2e-7 #torch.finfo(psd.dtype).eps
    log_Pxx = 10 * torch.log10(psd)

    n_peaks_2 = n_peaks // 2
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)

    return peaks

TypeError: 'NoneType' object is not iterable

In [28]:
@torch.jit.script
#@torch.jit.trace
def analytical_function_jit3(params: torch.Tensor, t: torch.Tensor, Rs: torch.Tensor, Vpi: torch.Tensor, P: torch.Tensor, NFFT: torch.Tensor, Fa: torch.Tensor, SpS: torch.Tensor, n_peaks: torch.Tensor) -> torch.Tensor:

    '''
    Function to generate the frequency comb signal peaks of PM-MZM-MZM using JIT

    Parameters:
    params: torch.Tensor
        Parameters to generate the frequency comb signal
    t: torch.Tensor
        Time vector
    Rs: torch.Tensor
        Symbol rate (samples per second)
    Vpi: torch.Tensor
        Half-wave voltage of the MZM (V)
    P: torch.Tensor
        Power of the frequency comb signal (W)
    NFFT: torch.Tensor
        Number of points of the FFT
    Fa: torch.Tensor
        Sampling frequency of the signal (samples per second)
    SpS: torch.Tensor
        Samples by each Rs
    n_peaks: torch.Tensor
        Number of peaks to be found

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal
    '''
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]

    V1 = params[:,0].unsqueeze(1)
    V2 = params[:,1].unsqueeze(1)
    V3 = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2 = params[:,6].unsqueeze(1)
    Vb3 = params[:,7].unsqueeze(1)

    #t = t.unsqueeze(0)
    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi

    #with autocast(enabled=True):
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    cos_u1_K = torch.cos(u1 * K)
    sin_u1_K = torch.sin(u1 * K)
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    real_part = P * cos_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    imag_part = P * sin_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    frequency_comb = torch.complex(real_part, imag_part)

    psd = torch.fft.fft(frequency_comb, n=NFFT, dim=-1)
    psd = torch.fft.fftshift(psd, dim=-1)
    psd = torch.abs(psd) ** 2 / (NFFT * Fa)
    psd[psd == 0] = 1.2e-7 #torch.finfo(psd.dtype).eps
    log_Pxx = 10 * torch.log10(psd)

    n_peaks_2 = n_peaks // 2
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)

    return peaks.to('cuda')

In [29]:
def analytical_function_compact4(params: torch.Tensor, t: torch.Tensor, Rs: float, Vpi: float, P: float, NFFT: float, Fa: float, SpS: float, n_peaks: float) -> torch.Tensor:
    '''
    Function to generate the frequency comb signal peaks of PM-MZM-MZM

    Parameters:
    params: torch.Tensor
        Parameters to generate the frequency comb signal
    args: parameters
        Parameters object with the arguments to generate the frequency comb signal
        Should contain the following attributes: t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks
        args.t: torch.Tensor
            Time vector
        args.Rs: float
            Symbol rate (samples per second)
        args.Vpi: float
            Half-wave voltage of the MZM (V)
        args.P: float
            Power of the frequency comb signal (W)
        args.NFFT: int
            Number of points of the FFT (multiple of SpS)
        args.Fa: float
            Sampling frequency of the signal (samples per second)
        args.SpS: int
            Samples by each Rs
        args.n_peaks: int
            Number of peaks to be found (odd number)

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal

    '''
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]

    V1 = params[:,0].unsqueeze(1)
    V2 = params[:,1].unsqueeze(1)
    V3 = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2 = params[:,6].unsqueeze(1)
    Vb3 = params[:,7].unsqueeze(1)

    #t = t.unsqueeze(0)
    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    cos_u1_K, sin_u1_K = torch.cos(u1 * K), torch.sin(u1 * K)
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    frequency_comb = P * (cos_u1_K + 1j * sin_u1_K) * cos_u2_Vb2_K * cos_u3_Vb3_K
    
    psd = torch.fft.fft(frequency_comb, n=NFFT, dim=-1)
    psd = torch.fft.fftshift(psd, dim=-1)
    psd = torch.abs(psd) ** 2 / (NFFT * Fa)
    psd[psd == 0] = torch.finfo(psd.dtype).eps
    log_Pxx = 10 * torch.log10(psd)
    
    n_peaks_2 = n_peaks // 2
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)
    return peaks.to('cuda')

In [30]:
device = "cuda"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)

out = analytical_function_compact2(params, ofc_args)
print(out.shape, out.device)

%timeit analytical_function_compact2(params, ofc_args)


torch.Size([1, 9]) cuda:0
762 μs ± 29.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [31]:

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = torch.tensor(10e9, dtype=torch.float32, device=device)
Vpi = torch.tensor(2.0, dtype=torch.float32, device=device)
NFFT = torch.tensor(64, dtype=torch.int32, device=device)
Fa = torch.tensor(640e9, dtype=torch.float32, device=device)
SpS = torch.tensor(64, dtype=torch.int32, device=device)
n_peaks = torch.tensor(9, dtype=torch.int32, device=device)

out = analytical_function_compact3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_compact3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
801 μs ± 58.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [32]:

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = 1.0
Rs = 10e9
Vpi = 2.0
NFFT = 64
Fa = 640e9
SpS = 64
n_peaks = 9

out = analytical_function_compact4(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_compact4(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
760 μs ± 31.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [33]:

device = "cpu"

params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = torch.tensor(10e9, dtype=torch.float32, device=device)
Vpi = torch.tensor(2.0, dtype=torch.float32, device=device)
NFFT = torch.tensor(64, dtype=torch.int32, device=device)
Fa = torch.tensor(640e9, dtype=torch.float32, device=device)
SpS = torch.tensor(64, dtype=torch.int32, device=device)
n_peaks = torch.tensor(9, dtype=torch.int32, device=device)

out = analytical_function_jit3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_jit3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
523 μs ± 17.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [34]:
device = "cuda"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = torch.tensor(10e9, dtype=torch.float32, device=device)
Vpi = torch.tensor(2.0, dtype=torch.float32, device=device)
NFFT = torch.tensor(64, dtype=torch.int32, device=device)
Fa = torch.tensor(640e9, dtype=torch.float32, device=device)
SpS = torch.tensor(64, dtype=torch.int32, device=device)
n_peaks = torch.tensor(9, dtype=torch.int32, device=device)

out = analytical_function_jit2(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_jit2(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
1.08 ms ± 235 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


Using the decorator torch.jit.script with 'cpu' is faster!!!

## <a id='toc1_2_'></a>[Testing some basic functions with the torch.jit.script decorator](#toc0_)

In [15]:

@torch.jit.script
def pm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: torch.Tensor) -> torch.Tensor:
    K = torch.pi / Vpi
    #Ao = Ai * torch.exp(1j * K * u)
    Ao = Ai * (torch.cos(K * u) + 1j * torch.sin(K * u))
    return Ao

@torch.jit.script
def mzm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: torch.Tensor, Vb: torch.Tensor) -> torch.Tensor:

    K = torch.pi/Vpi
    Ao = Ai * torch.cos(0.5 * K * (u + Vb))

    return Ao

@torch.jit.script
def frequencyCombGenerator_PM_MZM_MZM_jit(params: torch.Tensor, Rs: torch.Tensor, t: torch.Tensor, P:torch.Tensor, Vpi:torch.Tensor) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''

    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting
    
    # Extract individual elements from params using indexing

    V1 = params[:,0].unsqueeze(1)
    V2 = params[:,1].unsqueeze(1)
    V3 = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2 = params[:,6].unsqueeze(1)
    Vb3 = params[:,7].unsqueeze(1)

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm_jit(frequency_comb, u1, Vpi)
    frequency_comb = mzm_jit(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm_jit(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb.to("cuda")


device = "cpu"
params = torch.tensor([1,  1,  1,  1, 1,1,1,1]).unsqueeze(0).to(device)
Rs = torch.tensor(10e9).to(device)
t = torch.arange(0, 641).unsqueeze(0).to(device)*1.5625e-12
P = torch.tensor([1], device=device)
Vpi = torch.tensor([2], device=device)

oioi1 = frequencyCombGenerator_PM_MZM_MZM_jit(params, Rs, t, P, Vpi);
oioi2 = frequencyCombGenerator_PM_MZM_MZM_jit(params, Rs, t, P, Vpi);
oioi3 = frequencyCombGenerator_PM_MZM_MZM_jit(params, Rs, t, P, Vpi);



In [16]:

def frequencyCombGenerator_PM_MZM_MZM2(params: torch.Tensor, Rs: torch.Tensor, t: torch.Tensor, P: torch.Tensor, Vpi:torch.Tensor) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm(frequency_comb, u1, Vpi)
    frequency_comb = mzm(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb


def frequencyCombGenerator_PM_MZM_MZM22(params: torch.Tensor, Rs: torch.Tensor, t: torch.Tensor, P: torch.Tensor, Vpi:torch.Tensor) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm(frequency_comb, u1, Vpi)
    frequency_comb = mzm(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb.to("cuda")


In [17]:
def frequencyCombGenerator_PM_MZM_MZM3(params: torch.Tensor, Rs: float, t: torch.Tensor, P: torch.Tensor, Vpi:float) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm(frequency_comb, u1, Vpi)
    frequency_comb = mzm(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb

def frequencyCombGenerator_PM_MZM_MZM33(params: torch.Tensor, Rs: float, t: torch.Tensor, P: torch.Tensor, Vpi:float) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm(frequency_comb, u1, Vpi)
    frequency_comb = mzm(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb.to("cuda")

In [None]:

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = torch.tensor(10e9, dtype=torch.float32).to(device)
Vpi = torch.tensor(2.0, dtype=torch.float32).to(device)

print("Using CPU and JIT")
out = frequencyCombGenerator_PM_MZM_MZM_jit(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_PM_MZM_MZM_jit(params, Rs, t, P, Vpi)


device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = torch.tensor(10e9, dtype=torch.float32).to(device)
Vpi = torch.tensor(2.0, dtype=torch.float32).to(device)

print("\nUsing CPU, and all tensors")
out = frequencyCombGenerator_PM_MZM_MZM22(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_PM_MZM_MZM22(params, Rs, t, P, Vpi)

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = 10e9
Vpi = 2.0

print("\nUsing CPU, and somes floats")
out = frequencyCombGenerator_PM_MZM_MZM33(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_PM_MZM_MZM33(params, Rs, t, P, Vpi)

device = "cuda"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = torch.tensor(10e9, dtype=torch.float32).to(device)
Vpi = torch.tensor(2.0, dtype=torch.float32).to(device)

print("\nUsing CUDA, and all tensors")
out = frequencyCombGenerator_PM_MZM_MZM2(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_PM_MZM_MZM2(params, Rs, t, P, Vpi)

print("\nUsing CUDA, and some floats")
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = 10e9
Vpi = 2.0

out = frequencyCombGenerator_PM_MZM_MZM3(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_PM_MZM_MZM3(params, Rs, t, P, Vpi)


Using CPU and JIT
torch.Size([1, 641]) cuda:0
309 μs ± 7.35 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Using CPU, and all tensors
torch.Size([1, 641]) cuda:0
454 μs ± 60.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Using CPU, and somes floats
torch.Size([1, 641]) cuda:0
335 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Using CUDA, and all tensors
torch.Size([1, 641]) cuda:0
501 μs ± 40.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Using CUDA, and some floats
torch.Size([1, 641]) cuda:0
339 μs ± 17.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Using device = "cpu" with @torch.jit.script is faster!!!

## <a id='toc1_3_'></a>[Test get_PSD and get_peaks functions with torch.jit.script decorator](#toc0_)

In [21]:

@torch.jit.script
def pm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: float) -> torch.Tensor:

    K = torch.pi/Vpi
    #Ao = Ai * (torch.cos(u*K) + 1j * torch.sin(u*K))
    Ao = Ai * torch.exp(1j * K * u) 
    
    return Ao

@torch.jit.script
def mzm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: float, Vb: torch.Tensor) -> torch.Tensor:

    K = torch.pi/Vpi
    Ao = Ai * torch.cos(0.5 * K * (u + Vb))

    return Ao

@torch.jit.script
def frequencyCombGenerator_PM_MZM_MZM_jit(params: torch.Tensor, Rs:float, t: torch.Tensor, P: torch.Tensor, Vpi: float) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''
    
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting
    
    # Extract individual elements from params using indexing

    V1     = params[:,0].unsqueeze(1)
    V2     = params[:,1].unsqueeze(1)
    V3     = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2    = params[:,6].unsqueeze(1)
    Vb3    = params[:,7].unsqueeze(1)

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm_jit(frequency_comb, u1, Vpi)
    frequency_comb = mzm_jit(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm_jit(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb


@torch.jit.script
def get_psd_ByFFT_jit(signal: torch.Tensor, Fa: float, NFFT: int = 16*1024) -> tuple[torch.Tensor, torch.Tensor]:
    '''
    Calculate the Power Spectral Density (PSD) of a signal using FFT in PyTorch.

    Parameters:
    signal: torch.Tensor
        Signal to calculate the PSD
    Fa: float
        Sampling frequency of the signal (samples per second)
    NFFT: int
        Number of points of the FFT (default is 16*1024)
        
    Returns:
    psd: torch.Tensor
        Power Spectral Density of the signal
    freqs: torch.Tensor
        Frequencies of the PSD    
    '''

    fft_result = torch.fft.fftshift(torch.fft.fft(signal, n=NFFT, dim=-1), dim=-1)  # get the fft
    power_spectrum = torch.abs(fft_result) ** 2  # get the power spectrum
    psd = power_spectrum / (NFFT * Fa)  # get the power spectral density
    psd[psd == 0] = 1.2e-7 #torch.finfo(psd.dtype).eps # to avoid log(0) issues
    freqs = torch.fft.fftshift(torch.fft.fftfreq(NFFT, 1 / Fa)).to(psd.device) # get the frequencies

    return psd, freqs

@torch.jit.script
def get_indx_peaks_jit(log_Pxx: torch.Tensor, SpRs: float, n_peaks: int) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Function to get the indexes of the peaks in the power spectrum of the frequency comb signal

    Parameters:
    log_Pxx: torch.Tensor
        Power spectrum of the frequency comb signal
    SpRs: int
        Samples by each Rs. Each peak is separated by the Rs frequency. SpRs = (NFFT/SpS)
    n_peaks: int
        Number of peaks to be found

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal
    indx: torch.Tensor
        Indexes of the peaks in the power spectrum of the frequency comb signal
    '''

    center_indx = log_Pxx.size(-1) // 2  # get the center index
    offsets = torch.arange(-(n_peaks // 2), (n_peaks // 2) + 1, device=log_Pxx.device) * SpRs  # create an array of offsets for the peaks
    offsets = offsets.to(torch.int)  # Ensure offsets are integers
    indx = center_indx + offsets  # calculate the indices of the peaks
    peaks = log_Pxx[:, indx]  # extract the peaks from log_Pxx

    return peaks, indx

In [22]:

@torch.jit.script
def pm_jit_tensors(Ai: torch.Tensor, u: torch.Tensor, Vpi: torch.Tensor) -> torch.Tensor:

    K = torch.pi/Vpi
    #Ao = Ai * (torch.cos(u*K) + 1j * torch.sin(u*K))
    Ao = Ai * torch.exp(1j * K * u) 
    
    return Ao

@torch.jit.script
def mzm_jit_tensors(Ai: torch.Tensor, u: torch.Tensor, Vpi: torch.Tensor, Vb: torch.Tensor) -> torch.Tensor:

    K = torch.pi/Vpi
    Ao = Ai * torch.cos(0.5 * K * (u + Vb))

    return Ao

@torch.jit.script
def frequencyCombGenerator_PM_MZM_MZM_jit_tensors(params: torch.Tensor, Rs:torch.Tensor, t: torch.Tensor, P: torch.Tensor, Vpi: torch.Tensor) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''
    
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting
    
    # Extract individual elements from params using indexing

    V1     = params[:,0].unsqueeze(1)
    V2     = params[:,1].unsqueeze(1)
    V3     = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2    = params[:,6].unsqueeze(1)
    Vb3    = params[:,7].unsqueeze(1)

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm_jit_tensors(frequency_comb, u1, Vpi)
    frequency_comb = mzm_jit_tensors(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm_jit_tensors(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb


@torch.jit.script
def get_psd_ByFFT_jit_tensors(signal: torch.Tensor, Fa: torch.Tensor, NFFT: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    '''
    Calculate the Power Spectral Density (PSD) of a signal using FFT in PyTorch.

    Parameters:
    signal: torch.Tensor
        Signal to calculate the PSD
    Fa: float
        Sampling frequency of the signal (samples per second)
    NFFT: int
        Number of points of the FFT (default is 16*1024)
        
    Returns:
    psd: torch.Tensor
        Power Spectral Density of the signal
    freqs: torch.Tensor
        Frequencies of the PSD    
    '''

    fft_result = torch.fft.fftshift(torch.fft.fft(signal, n=NFFT, dim=-1), dim=-1)  # get the fft
    power_spectrum = torch.abs(fft_result) ** 2  # get the power spectrum
    psd = power_spectrum / (NFFT * Fa)  # get the power spectral density
    psd[psd == 0] = 1.2e-7 #torch.finfo(psd.dtype).eps # to avoid log(0) issues
    freqs = torch.fft.fftshift(torch.fft.fftfreq(NFFT, 1 / Fa)).to(psd.device) # get the frequencies

    return psd, freqs

@torch.jit.script
def get_indx_peaks_jit_tensors(log_Pxx: torch.Tensor, SpRs: torch.Tensor, n_peaks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Function to get the indexes of the peaks in the power spectrum of the frequency comb signal

    Parameters:
    log_Pxx: torch.Tensor
        Power spectrum of the frequency comb signal
    SpRs: int
        Samples by each Rs. Each peak is separated by the Rs frequency. SpRs = (NFFT/SpS)
    n_peaks: int
        Number of peaks to be found

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal
    indx: torch.Tensor
        Indexes of the peaks in the power spectrum of the frequency comb signal
    '''

    center_indx = log_Pxx.size(-1) // 2  # get the center index
    offsets = torch.arange(-(n_peaks // 2), (n_peaks // 2) + 1, device=log_Pxx.device) * SpRs  # create an array of offsets for the peaks
    offsets = offsets.to(torch.int)  # Ensure offsets are integers
    indx = center_indx + offsets  # calculate the indices of the peaks
    peaks = log_Pxx[:, indx]  # extract the peaks from log_Pxx

    return peaks, indx

In [23]:
device = "cpu"

params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = torch.tensor(10e9, dtype=torch.float32, device=device)
Vpi = torch.tensor(2.0, dtype=torch.float32, device=device)
NFFT = torch.tensor(64, dtype=torch.int32, device=device)
Fa = torch.tensor(640e9, dtype=torch.float32, device=device)
SpS = torch.tensor(64, dtype=torch.int32, device=device)
n_peaks = torch.tensor(9, dtype=torch.int32, device=device)


frequency_comb = frequencyCombGenerator_PM_MZM_MZM_jit_tensors(params, Rs, t, P, Vpi) # Generate the frequency comb signal
Pxx, _ = get_psd_ByFFT_jit_tensors(frequency_comb, Fa, NFFT) # Get the power spectrum of the frequency comb signal
log_Pxx = 10*torch.log10(Pxx) # Convert the power spectrum to dB
peaks, _ = get_indx_peaks_jit_tensors(log_Pxx, NFFT/SpS, n_peaks) # Get the indexes of the peaks

print(frequency_comb.shape, Pxx.shape, log_Pxx.shape, peaks.shape)
print(frequency_comb.device, Pxx.device, log_Pxx.device, peaks.device)

%timeit frequencyCombGenerator_PM_MZM_MZM_jit_tensors(params, Rs, t, P, Vpi)
%timeit get_psd_ByFFT_jit_tensors(frequency_comb, Fa, NFFT)
%timeit get_indx_peaks_jit_tensors(log_Pxx, NFFT/SpS, n_peaks)

torch.Size([1, 641]) torch.Size([1, 64]) torch.Size([1, 64]) torch.Size([1, 9])
cpu cpu cpu cpu
396 μs ± 54.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
392 μs ± 69.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
245 μs ± 24.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [24]:
device = "cpu"

params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = 10e9 # float
Vpi = 2.0 # float
NFFT = 64 # int
Fa = 640e9 #float
SpS = 64 # int
n_peaks = 9 # int


frequency_comb = frequencyCombGenerator_PM_MZM_MZM_jit(params, Rs, t, P, Vpi) # Generate the frequency comb signal
Pxx, _ = get_psd_ByFFT_jit(frequency_comb, Fa, NFFT) # Get the power spectrum of the frequency comb signal
log_Pxx = 10*torch.log10(Pxx) # Convert the power spectrum to dB
peaks, _ = get_indx_peaks_jit(log_Pxx, NFFT/SpS, n_peaks) # Get the indexes of the peaks

print(frequency_comb.shape, Pxx.shape, log_Pxx.shape, peaks.shape)
print(frequency_comb.device, Pxx.device, log_Pxx.device, peaks.device)

%timeit frequencyCombGenerator_PM_MZM_MZM_jit(params, Rs, t, P, Vpi)
%timeit get_psd_ByFFT_jit(frequency_comb, Fa, NFFT)
%timeit get_indx_peaks_jit(log_Pxx, NFFT/SpS, n_peaks)

torch.Size([1, 641]) torch.Size([1, 64]) torch.Size([1, 64]) torch.Size([1, 9])
cpu cpu cpu cpu
379 μs ± 23.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
373 μs ± 85 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
172 μs ± 7.37 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Instead of using all the parameters as tensors, using a few other simple types makes the function faster!!!

## Test unpacking

In [45]:
@torch.jit.script
def pm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: float) -> torch.Tensor:

    K = torch.pi/Vpi
    #Ao = Ai * (torch.cos(u*K) + 1j * torch.sin(u*K))
    Ao = Ai * torch.exp(1j * K * u) 
    
    return Ao

@torch.jit.script
def mzm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: float, Vb: torch.Tensor) -> torch.Tensor:

    K = torch.pi/Vpi
    Ao = Ai * torch.cos(0.5 * K * (u + Vb))

    return Ao

@torch.jit.script
def frequencyCombGenerator_PM_MZM_MZM_jit2(params: torch.Tensor, Rs:float, t: torch.Tensor, P: torch.Tensor, Vpi: float) -> torch.Tensor:

    '''
    This function generates a frequency comb signal using a PM and two MZMs.
    
    Parameters:

    params: torch.Tensor
        Parameters to generate the frequency comb signal
        params.V1, params.V2, params.V3: float
            Amplitude of the signals (V)
        params.Phase1, params.Phase2, params.Phase3: float
            Phase of the signals (rad)
        params.Vb2, params.Vb3: float
            Bias voltage of the MZMs (V)
    Rs: float
        Symbol rate (samples per second)
    t: torch.Tensor
        Time vector
    P: float
        Power of the frequency comb signal (W)
    Vpi: float
        Half-wave voltage of the MZM (V)

    Returns:
    frequency_comb: array
        Frequency comb signal
    '''
    
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting

    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = torch.split(params, 1, dim=-1)
    
    # Extract individual elements from params using indexing
    '''
    V1     = params[:,0].unsqueeze(1)
    V2     = params[:,1].unsqueeze(1)
    V3     = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2    = params[:,6].unsqueeze(1)
    Vb3    = params[:,7].unsqueeze(1)
    #'''

    #print(V1.shape, V2.shape, V3.shape, Phase1.shape, Phase2.shape, Phase3.shape, Vb2.shape, Vb3.shape)

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t
    u1, u2, u3 = [V * torch.cos(pi_Rs_t + Phase) for V, Phase in zip([V1, V2, V3], [Phase1, Phase2, Phase3])]

    frequency_comb = P
    frequency_comb =  pm_jit(frequency_comb, u1, Vpi)
    frequency_comb = mzm_jit(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm_jit(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb

In [46]:
device = "cpu"

params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = 10e9 # float
Vpi = 2.0 # float
NFFT = 64 # int
Fa = 640e9 #float
SpS = 64 # int
n_peaks = 9 # int


frequency_comb = frequencyCombGenerator_PM_MZM_MZM_jit2(params, Rs, t, P, Vpi) # Generate the frequency comb signal

In [44]:
params = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.float32).to(device)
params.shape

torch.Size([2, 8])

## Test dataset creation

In [None]:
zero_mean = False

In [None]:
dataset = FrequencyCombDataset(frequencyCombPeaks, n_samples, ofc_args, mod_args.bounds, device = device, zero_mean=zero_mean)

In [None]:
dataset = FrequencyCombDataset(analytical_function_compact_DDMZM, n_samples, ofc_args, mod_args.bounds, device = device, zero_mean=zero_mean)

In [None]:
dataset = FrequencyCombDataset_jit(analytical_function_compact_DDMZM_jit, n_samples, ofc_args, mod_args.bounds, device = device, zero_mean=zero_mean)

## Test use group parameters even with jit

In [9]:
@torch.jit.script
def analytical_function_compact_group_params(params: torch.Tensor, args: parameters) -> torch.Tensor:

    '''
    Function to generate the frequency comb signal peaks of PM-MZM-MZM using JIT

    Parameters:
    params: torch.Tensor
        Parameters to generate the frequency comb signal
    t: torch.Tensor
        Time vector
    Rs: float
        Symbol rate (samples per second)
    Vpi: float
        Half-wave voltage of the MZM (V)
    P: torch.Tensor
        Power of the frequency comb signal (W)
    NFFT: int
        Number of points of the FFT
    Fa: float
        Sampling frequency of the signal (samples per second)
    SpS: int
        Samples by each Rs
    n_peaks: int
        Number of peaks to be found

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal
    '''
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting

    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = torch.split(params, 1, dim=-1)

    pi_Rs_t = 2 * torch.pi * args.Rs * args.t
    K = torch.pi / args.Vpi

    #with autocast(enabled=True):
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    cos_u1_K = torch.cos(u1 * K)
    sin_u1_K = torch.sin(u1 * K)
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    real_part = args.P * cos_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    imag_part = args.P * sin_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    frequency_comb = torch.complex(real_part, imag_part)

    psd = torch.fft.fft(frequency_comb, n=args.NFFT, dim=-1)
    psd = torch.fft.fftshift(psd, dim=-1)
    psd = torch.abs(psd) ** 2 / (args.NFFT * args.Fa)
    psd[psd == 0] = 1.2e-7 #torch.finfo(psd.dtype).eps
    log_Pxx = 10 * torch.log10(psd)

    n_peaks_2 = args.n_peaks // 2
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (args.NFFT // args.SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)

    return peaks

OSError: Can't get source for <class '__main__.parameters'>. TorchScript requires source access in order to carry out compilation, make sure original .py files are available.
'__torch__.parameters' is being compiled since it was called from 'analytical_function_compact_group_params'
  File "C:\Users\ferna\AppData\Local\Temp\ipykernel_1720\4223521918.py", line 2
@torch.jit.script
def analytical_function_compact_group_params(params: torch.Tensor, args: parameters) -> torch.Tensor:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    '''
    ~~~
    Function to generate the frequency comb signal peaks of PM-MZM-MZM using JIT
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    Parameters:
    ~~~~~~~~~~~
    params: torch.Tensor
    ~~~~~~~~~~~~~~~~~~~~
        Parameters to generate the frequency comb signal
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    t: torch.Tensor
    ~~~~~~~~~~~~~~~
        Time vector
        ~~~~~~~~~~~
    Rs: float
    ~~~~~~~~~
        Symbol rate (samples per second)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Vpi: float
    ~~~~~~~~~~
        Half-wave voltage of the MZM (V)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    P: torch.Tensor
    ~~~~~~~~~~~~~~~
        Power of the frequency comb signal (W)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    NFFT: int
    ~~~~~~~~~
        Number of points of the FFT
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Fa: float
    ~~~~~~~~~
        Sampling frequency of the signal (samples per second)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    SpS: int
    ~~~~~~~~
        Samples by each Rs
        ~~~~~~~~~~~~~~~~~~
    n_peaks: int
    ~~~~~~~~~~~~
        Number of peaks to be found
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~

    Returns:
    ~~~~~~~~
    peaks: torch.Tensor
    ~~~~~~~~~~~~~~~~~~~
        Peaks of the power spectrum of the frequency comb signal
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    '''
    ~~~
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = torch.split(params, 1, dim=-1)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    pi_Rs_t = 2 * torch.pi * args.Rs * args.t
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    K = torch.pi / args.Vpi
    ~~~~~~~~~~~~~~~~~~~~~~~

    #with autocast(enabled=True):
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    cos_u1_K = torch.cos(u1 * K)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    sin_u1_K = torch.sin(u1 * K)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    real_part = args.P * cos_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    imag_part = args.P * sin_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    frequency_comb = torch.complex(real_part, imag_part)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    psd = torch.fft.fft(frequency_comb, n=args.NFFT, dim=-1)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    psd = torch.fft.fftshift(psd, dim=-1)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    psd = torch.abs(psd) ** 2 / (args.NFFT * args.Fa)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    psd[psd == 0] = 1.2e-7 #torch.finfo(psd.dtype).eps
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    log_Pxx = 10 * torch.log10(psd)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    n_peaks_2 = args.n_peaks // 2
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (args.NFFT // args.SpS)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    peaks = torch.index_select(log_Pxx, 1, indx)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    return peaks
    ~~~~~~~~~~~~ <--- HERE


In [6]:
@torch.jit.script
def analytical_function_compact_separated_params(params: torch.Tensor, t: torch.Tensor, Rs: float, Vpi: float, P: torch.Tensor, NFFT: int, Fa: float, SpS: int, n_peaks: int) -> torch.Tensor:

    '''
    Function to generate the frequency comb signal peaks of PM-MZM-MZM using JIT

    Parameters:
    params: torch.Tensor
        Parameters to generate the frequency comb signal
    t: torch.Tensor
        Time vector
    Rs: float
        Symbol rate (samples per second)
    Vpi: float
        Half-wave voltage of the MZM (V)
    P: torch.Tensor
        Power of the frequency comb signal (W)
    NFFT: int
        Number of points of the FFT
    Fa: float
        Sampling frequency of the signal (samples per second)
    SpS: int
        Samples by each Rs
    n_peaks: int
        Number of peaks to be found

    Returns:
    peaks: torch.Tensor
        Peaks of the power spectrum of the frequency comb signal
    '''
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = [p.contiguous() for p in params.T.unsqueeze(2)]
    #V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = params.T.unsqueeze(2) # Reshape each one to [N, 1] for broadcasting

    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = torch.split(params, 1, dim=-1)

    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi

    #with autocast(enabled=True):
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)
    cos_u1_K = torch.cos(u1 * K)
    sin_u1_K = torch.sin(u1 * K)
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    real_part = P * cos_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    imag_part = P * sin_u1_K * cos_u2_Vb2_K * cos_u3_Vb3_K
    frequency_comb = torch.complex(real_part, imag_part)

    psd = torch.fft.fft(frequency_comb, n=NFFT, dim=-1)
    psd = torch.fft.fftshift(psd, dim=-1)
    psd = torch.abs(psd) ** 2 / (NFFT * Fa)
    psd[psd == 0] = 1.2e-7 #torch.finfo(psd.dtype).eps
    log_Pxx = 10 * torch.log10(psd)

    n_peaks_2 = n_peaks // 2
    indx = (log_Pxx.size(-1) // 2) + torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int) * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)

    return peaks

In [8]:
device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = 1.0
Rs = 10e9
Vpi = 2.0
NFFT = 64
Fa = 640e9
SpS = 64
n_peaks = 9

class parameters:
    pass

args = parameters()
args.t = t
args.Rs = Rs
args.Vpi = Vpi
args.P = P
args.NFFT = NFFT
args.Fa = Fa
args.SpS = SpS
args.n_peaks = n_peaks

# Define the parameters class with all attributes initialized in __init__
@torch.jit.script
class parameters:
    def __init__(self, t: torch.Tensor, Rs: float, Vpi: float, P: torch.Tensor, NFFT: int, Fa: float, SpS: int, n_peaks: int):
        self.t = t
        self.Rs = Rs
        self.Vpi = Vpi
        self.P = P
        self.NFFT = NFFT
        self.Fa = Fa
        self.SpS = SpS
        self.n_peaks = n_peaks

args = parameters(t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

out = analytical_function_compact_group_params(params, args)
print(out.shape, out.device)

%timeit analytical_function_compact_group_params(params, args)

OSError: Can't get source for <class '__main__.parameters'>. TorchScript requires source access in order to carry out compilation, make sure original .py files are available.

In [None]:
device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = 1.0
Rs = 10e9
Vpi = 2.0
NFFT = 64
Fa = 640e9
SpS = 64
n_peaks = 9

out = analytical_function_compact_separated_params(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_compact_separated_params(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)