In [7]:
import numpy as np
import matplotlib.pyplot as plt 
import scienceplots
from ofc_functions import *
from nn_functions import *
import os, glob

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

In [8]:
import torch.distributions.uniform as urand
from scipy.optimize import minimize
from torch.cuda.amp import autocast

In [9]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Runnning in {device} mãe")

Runnning in cuda mãe


In [10]:
# Parameters for the OFC signal.
# Every derived quantity below is computed from SpS and Rs, so change those two first.
ofc_args = parameters()
ofc_args.SpS = 64   # Samples per symbol. Determines the window size
ofc_args.Rs  = 10e9 # Symbol rate (baud rate - symbols per second); also used as the drive frequency (Hz) by the functions below
ofc_args.Ts  = 1/ofc_args.Rs # Symbol period (s)
ofc_args.Fa  = 1/(ofc_args.Ts/ofc_args.SpS)  # Sampling frequency of the signal (samples per second)
ofc_args.Ta  = 1/ofc_args.Fa   # Sampling period of the signal (s)
ofc_args.NFFT = ofc_args.SpS   # Number of points of the FFT (multiple of SpS); here exactly one symbol period
ofc_args.t = ofc_args.Ta * torch.arange(0, ofc_args.SpS*10 + 1, device=device).unsqueeze(0) # Time vector (s), shape [1, 10*SpS + 1], created on `device`
ofc_args.Vpi = 2             # Voltage required to achieve a π phase shift (V)
#ofc_args.Vpi_DDMZM = 3.5    # Voltage required to achieve a π phase shift (V)
#ofc_args.Vpi_MZM = 5.5      # Voltage required to achieve a π phase shift (V)
#ofc_args.Vpi_PM = 3.5       # Voltage required to achieve a π phase shift (V)

ofc_args.P = 10**(0/10)/1000 # Power of the signal (W) --> 0 dBm (dBm to W: 10**(dBm/10) / 1000)
ofc_args.n_peaks = 9 # Number of peaks to be found in the signal (odd, so the peaks are symmetric around the centre bin)

# Parameters for the modulators cascaded (search ranges for each tunable quantity)
mod_args = parameters()
mod_args.P_min = 1e-3*10**(-10/10) # Minimum power (W) --> -10 dBm
mod_args.P_max = 1e-3*10**(10/10)  # Maximum power (W) -->  10 dBm
mod_args.V_min = -5    # Minimum amplitude voltage (V)
mod_args.V_max = 5     # Maximum amplitude voltage (V)
mod_args.Phase_min = 0 # Minimum phase (rad)
mod_args.Phase_max = 0 # Maximum phase (rad) --> Phase is periodic with period 2*π; min == max pins the phase to 0
mod_args.Vb_min = -15  # Minimum bias voltage (V)
mod_args.Vb_max = 15   # Maximum bias voltage (V) --> Vb is periodic with period 4*Vπ

# Active bounds, in the order (V, V, Vb, Vb, P).
# NOTE(review): this list has 5 entries, whereas the analytical_function_* variants below read
# 8 columns from `params` -- presumably these bounds belong to a different modulator setup; confirm.
mod_args.bounds = [(mod_args.V_min,  mod_args.V_max),  (mod_args.V_min,  mod_args.V_max),  
                   (mod_args.Vb_min, mod_args.Vb_max), (mod_args.Vb_min, mod_args.Vb_max),
                   (mod_args.P_min,  mod_args.P_max)] # Bounds of the parameters


# Alternative bounds kept for reference: 3 amplitudes, 3 phases, 2 bias voltages (8 entries)
#mod_args.bounds = [(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max), 
#        (mod_args.Phase_min, mod_args.Phase_max),(mod_args.Phase_min, mod_args.Phase_max), (mod_args.Phase_min, mod_args.Phase_max),
#        (mod_args.Vb_min,mod_args.Vb_max), (mod_args.Vb_min,mod_args.Vb_max)] # Bounds of the parameters

# Alternative bounds kept for reference: 3 amplitudes, 3 phases, 1 bias voltage (7 entries)
#mod_args.bounds = [(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max),(mod_args.V_min,mod_args.V_max), 
#        (mod_args.Phase_min, mod_args.Phase_max),(mod_args.Phase_min, mod_args.Phase_max), (mod_args.Phase_min, mod_args.Phase_max),
#        (mod_args.Vb_min,mod_args.Vb_max)] # Bounds of the parameters

# n_peaks evenly spaced values from -(n_peaks//2) to +(n_peaks//2), created on the default (CPU) device.
# NOTE(review): the values are integer peak offsets (-4..4 here), not multiples of Rs -- confirm the GHz unit.
freqs_peaks_GHz = torch.linspace(-(ofc_args.n_peaks//2), ofc_args.n_peaks//2, ofc_args.n_peaks) # Frequency range in GHz for n_peaks


### Testing compact function with and without torch.jit.script decorator

In [61]:
def analytical_function_compact2(params: torch.Tensor, args) -> torch.Tensor:
    '''
    Compute the comb-line levels (in dB) of a PM-MZM-MZM frequency comb.

    A phase-modulation stage followed by two cosine (MZM) stages is driven by
    three sinusoids of frequency args.Rs. The resulting field is transformed
    with an NFFT-point FFT and the log-PSD at the centre bin and its
    neighbours is returned.

    Parameters:
    params: torch.Tensor
        2-D tensor, one row per parameter set. Columns 0-7 are read as
        V1, V2, V3 (drive amplitudes), Phase1, Phase2, Phase3 (drive phases, rad),
        Vb2, Vb3 (bias voltages of the two MZM stages).
    args: parameters
        Object exposing the attributes below.
        args.t: torch.Tensor
            Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
        args.Rs: float
            Symbol rate, used as the drive frequency (Hz)
        args.Vpi: float
            Half-wave voltage of the modulators (V)
        args.P: float
            Scale factor multiplying the field (documented upstream as power in W)
        args.NFFT: int
            Number of points of the FFT (multiple of SpS)
        args.Fa: float
            Sampling frequency of the signal (samples per second)
        args.SpS: int
            Samples per symbol
        args.n_peaks: int
            Number of peaks to be returned (odd number)

    Returns:
    peaks: torch.Tensor
        10*log10 of the PSD at the selected bins, one row per row of params.

    NOTE(review): bins whose PSD is exactly zero are replaced by the dtype's
    machine epsilon before the logarithm. Because the PSD is divided by
    NFFT*Fa, genuine values can be far below epsilon -- confirm the floor.
    '''
    # One [N, 1] column per physical quantity so that each broadcasts against t
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = (
        params[:, col].unsqueeze(1) for col in range(8)
    )

    omega_t = 2 * torch.pi * args.Rs * args.t
    K = torch.pi / args.Vpi

    # Sinusoidal drive applied to each of the three stages
    u1 = V1 * torch.cos(omega_t + Phase1)
    u2 = V2 * torch.cos(omega_t + Phase2)
    u3 = V3 * torch.cos(omega_t + Phase3)

    # Phase stage as cos + j*sin, then the two cosine stages
    pm_angle = u1 * K
    pm_term = torch.cos(pm_angle) + 1j * torch.sin(pm_angle)
    mzm2_term = torch.cos((0.5 * (u2 + Vb2)) * K)
    mzm3_term = torch.cos((0.5 * (u3 + Vb3)) * K)
    field = args.P * pm_term * mzm2_term * mzm3_term

    # Centred spectrum -> PSD -> dB (exact zeros are floored first to avoid -inf)
    spectrum = torch.fft.fftshift(torch.fft.fft(field, n=args.NFFT, dim=-1), dim=-1)
    psd = spectrum.abs().pow(2) / (args.NFFT * args.Fa)
    psd.masked_fill_(psd == 0, torch.finfo(psd.dtype).eps)
    log_Pxx = 10 * torch.log10(psd)

    # Bins of the comb lines: centre bin +/- k * (NFFT // SpS)
    half = args.n_peaks // 2
    offsets = torch.arange(-half, half + 1, device=params.device).to(torch.int)
    indx = (log_Pxx.size(-1) // 2) + offsets * (args.NFFT // args.SpS)
    return log_Pxx.index_select(1, indx)

In [62]:
def analytical_function_compact3(params: torch.Tensor, t: torch.Tensor, Rs: torch.Tensor, Vpi: torch.Tensor, P: torch.Tensor, NFFT: torch.Tensor, Fa: torch.Tensor, SpS: torch.Tensor, n_peaks: torch.Tensor, out_device: str = "cuda") -> torch.Tensor:
    '''
    Compute the comb-line levels (in dB) of a PM-MZM-MZM frequency comb.

    Same computation as analytical_function_compact2, but every argument is
    passed explicitly (as a tensor) instead of through an `args` object, and
    the result is moved to `out_device` before being returned.

    Parameters:
    params: torch.Tensor
        2-D tensor, one row per parameter set. Columns 0-7 are read as
        V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3.
    t: torch.Tensor
        Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
    Rs: torch.Tensor
        Symbol rate, used as the drive frequency (Hz)
    Vpi: torch.Tensor
        Half-wave voltage of the modulators (V)
    P: torch.Tensor
        Scale factor multiplying the field (documented upstream as power in W)
    NFFT: torch.Tensor
        Number of points of the FFT (integer scalar, multiple of SpS)
    Fa: torch.Tensor
        Sampling frequency of the signal (samples per second)
    SpS: torch.Tensor
        Samples per symbol (integer scalar)
    n_peaks: torch.Tensor
        Number of peaks to be returned (odd integer scalar)
    out_device: str
        Device the result is moved to. Defaults to "cuda", which is what this
        function always used; pass "cpu" on machines without a GPU.

    Returns:
    peaks: torch.Tensor
        10*log10 of the PSD at the selected bins, one row per row of params,
        located on `out_device`.

    NOTE(review): bins whose PSD is exactly zero are replaced by the dtype's
    machine epsilon before the logarithm. Because the PSD is divided by
    NFFT*Fa, genuine values can be far below epsilon -- confirm the floor.
    '''
    # One [N, 1] column per physical quantity so that each broadcasts against t
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = (
        params[:, col].unsqueeze(1) for col in range(8)
    )

    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi

    # Sinusoidal drive applied to each of the three stages
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)

    # Phase stage as cos + j*sin, then the two cosine stages
    pm_angle = u1 * K
    pm_term = torch.cos(pm_angle) + 1j * torch.sin(pm_angle)
    mzm2_term = torch.cos((0.5 * (u2 + Vb2)) * K)
    mzm3_term = torch.cos((0.5 * (u3 + Vb3)) * K)
    frequency_comb = P * pm_term * mzm2_term * mzm3_term

    # Centred spectrum -> PSD -> dB (exact zeros are floored first to avoid -inf)
    spectrum = torch.fft.fftshift(torch.fft.fft(frequency_comb, n=NFFT, dim=-1), dim=-1)
    psd = torch.abs(spectrum) ** 2 / (NFFT * Fa)
    psd.masked_fill_(psd == 0, torch.finfo(psd.dtype).eps)
    log_Pxx = 10 * torch.log10(psd)

    # Bins of the comb lines: centre bin +/- k * (NFFT // SpS)
    n_peaks_2 = n_peaks // 2
    offsets = torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int)
    indx = (log_Pxx.size(-1) // 2) + offsets * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)
    return peaks.to(out_device)

In [63]:
@torch.jit.script
def analytical_function_jit2(params: torch.Tensor, t: torch.Tensor, Rs: torch.Tensor, Vpi: torch.Tensor, P: torch.Tensor, NFFT: torch.Tensor, Fa: torch.Tensor, SpS: torch.Tensor, n_peaks: torch.Tensor) -> torch.Tensor:
    '''
    TorchScript version: comb-line levels (in dB) of a PM-MZM-MZM frequency comb.

    Parameters:
    params: torch.Tensor
        2-D tensor, one row per parameter set. Columns 0-7 are read as
        V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3.
    t: torch.Tensor
        Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
    Rs: torch.Tensor
        Symbol rate, used as the drive frequency (Hz)
    Vpi: torch.Tensor
        Half-wave voltage of the modulators (V)
    P: torch.Tensor
        Scale factor multiplying the field (documented upstream as power in W)
    NFFT: torch.Tensor
        Number of points of the FFT (integer scalar)
    Fa: torch.Tensor
        Sampling frequency of the signal (samples per second)
    SpS: torch.Tensor
        Samples per symbol (integer scalar)
    n_peaks: torch.Tensor
        Number of peaks to be returned (integer scalar)

    Returns:
    peaks: torch.Tensor
        10*log10 of the PSD at the selected bins, one row per row of params.
        Bins whose PSD is exactly zero are floored to 1e-8 before the logarithm.
    '''
    # Keep each parameter as an [N, 1] column so it broadcasts against t
    V1 = params[:, 0:1]
    V2 = params[:, 1:2]
    V3 = params[:, 2:3]
    Phase1 = params[:, 3:4]
    Phase2 = params[:, 4:5]
    Phase3 = params[:, 5:6]
    Vb2 = params[:, 6:7]
    Vb3 = params[:, 7:8]

    omega_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi

    # Sinusoidal drive applied to each of the three stages
    u1 = V1 * torch.cos(omega_t + Phase1)
    u2 = V2 * torch.cos(omega_t + Phase2)
    u3 = V3 * torch.cos(omega_t + Phase3)

    # The field is assembled from explicit real and imaginary parts
    pm_angle = u1 * K
    mzm2_term = torch.cos((0.5 * (u2 + Vb2)) * K)
    mzm3_term = torch.cos((0.5 * (u3 + Vb3)) * K)
    field = torch.complex(
        P * torch.cos(pm_angle) * mzm2_term * mzm3_term,
        P * torch.sin(pm_angle) * mzm2_term * mzm3_term,
    )

    # Centred spectrum -> PSD -> dB
    spectrum = torch.fft.fftshift(torch.fft.fft(field, n=NFFT, dim=-1), dim=-1)
    psd = spectrum.abs().pow(2) / (NFFT * Fa)
    psd.masked_fill_(psd == 0, 1e-8)
    log_Pxx = 10 * torch.log10(psd)

    # Bins of the comb lines: centre bin +/- k * (NFFT // SpS)
    half = n_peaks // 2
    offsets = torch.arange(-half, half + 1, device=params.device).to(torch.int)
    indx = (log_Pxx.size(-1) // 2) + offsets * (NFFT // SpS)

    return log_Pxx.index_select(1, indx)

In [68]:
@torch.jit.script
def analytical_function_jit3(params: torch.Tensor, t: torch.Tensor, Rs: torch.Tensor, Vpi: torch.Tensor, P: torch.Tensor, NFFT: torch.Tensor, Fa: torch.Tensor, SpS: torch.Tensor, n_peaks: torch.Tensor, out_device: str = "cuda") -> torch.Tensor:
    '''
    TorchScript version: comb-line levels (in dB) of a PM-MZM-MZM frequency comb,
    with the result moved to `out_device` before it is returned.

    Parameters:
    params: torch.Tensor
        2-D tensor, one row per parameter set. Columns 0-7 are read as
        V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3.
    t: torch.Tensor
        Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
    Rs: torch.Tensor
        Symbol rate, used as the drive frequency (Hz)
    Vpi: torch.Tensor
        Half-wave voltage of the modulators (V)
    P: torch.Tensor
        Scale factor multiplying the field (documented upstream as power in W)
    NFFT: torch.Tensor
        Number of points of the FFT (integer scalar)
    Fa: torch.Tensor
        Sampling frequency of the signal (samples per second)
    SpS: torch.Tensor
        Samples per symbol (integer scalar)
    n_peaks: torch.Tensor
        Number of peaks to be returned (integer scalar)
    out_device: str
        Device the result is moved to. Defaults to "cuda", which is what this
        function always used; pass "cpu" on machines without a GPU.

    Returns:
    peaks: torch.Tensor
        10*log10 of the PSD at the selected bins, one row per row of params,
        located on `out_device`. Bins whose PSD is exactly zero are floored to
        1e-8 before the logarithm.
    '''
    # Keep each parameter as an [N, 1] column so it broadcasts against t
    V1 = params[:, 0:1]
    V2 = params[:, 1:2]
    V3 = params[:, 2:3]
    Phase1 = params[:, 3:4]
    Phase2 = params[:, 4:5]
    Phase3 = params[:, 5:6]
    Vb2 = params[:, 6:7]
    Vb3 = params[:, 7:8]

    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi

    # Sinusoidal drive applied to each of the three stages
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)

    # The field is assembled from explicit real and imaginary parts
    pm_angle = u1 * K
    cos_u2_Vb2_K = torch.cos((0.5 * (u2 + Vb2)) * K)
    cos_u3_Vb3_K = torch.cos((0.5 * (u3 + Vb3)) * K)
    real_part = P * torch.cos(pm_angle) * cos_u2_Vb2_K * cos_u3_Vb3_K
    imag_part = P * torch.sin(pm_angle) * cos_u2_Vb2_K * cos_u3_Vb3_K
    frequency_comb = torch.complex(real_part, imag_part)

    # Centred spectrum -> PSD -> dB
    spectrum = torch.fft.fftshift(torch.fft.fft(frequency_comb, n=NFFT, dim=-1), dim=-1)
    psd = torch.abs(spectrum) ** 2 / (NFFT * Fa)
    psd.masked_fill_(psd == 0, 1e-8)
    log_Pxx = 10 * torch.log10(psd)

    # Bins of the comb lines: centre bin +/- k * (NFFT // SpS)
    n_peaks_2 = n_peaks // 2
    offsets = torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int)
    indx = (log_Pxx.size(-1) // 2) + offsets * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)

    return peaks.to(torch.device(out_device))

In [69]:
def analytical_function_compact4(params: torch.Tensor, t: torch.Tensor, Rs: float, Vpi: float, P: float, NFFT: int, Fa: float, SpS: int, n_peaks: int, out_device: str = "cuda") -> torch.Tensor:
    '''
    Compute the comb-line levels (in dB) of a PM-MZM-MZM frequency comb.

    Same computation as analytical_function_compact3, but the scalar settings
    are plain Python numbers instead of 0-dim tensors.

    Parameters:
    params: torch.Tensor
        2-D tensor, one row per parameter set. Columns 0-7 are read as
        V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3.
    t: torch.Tensor
        Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
    Rs: float
        Symbol rate, used as the drive frequency (Hz)
    Vpi: float
        Half-wave voltage of the modulators (V)
    P: float
        Scale factor multiplying the field (documented upstream as power in W)
    NFFT: int
        Number of points of the FFT (multiple of SpS). Must be an integer:
        it is passed as the FFT length.
    Fa: float
        Sampling frequency of the signal (samples per second)
    SpS: int
        Samples per symbol
    n_peaks: int
        Number of peaks to be returned (odd number)
    out_device: str
        Device the result is moved to. Defaults to "cuda", which is what this
        function always used; pass "cpu" on machines without a GPU.

    Returns:
    peaks: torch.Tensor
        10*log10 of the PSD at the selected bins, one row per row of params,
        located on `out_device`.

    NOTE(review): bins whose PSD is exactly zero are replaced by the dtype's
    machine epsilon before the logarithm. Because the PSD is divided by
    NFFT*Fa, genuine values can be far below epsilon -- confirm the floor.
    '''
    # One [N, 1] column per physical quantity so that each broadcasts against t
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = (
        params[:, col].unsqueeze(1) for col in range(8)
    )

    pi_Rs_t = 2 * torch.pi * Rs * t
    K = torch.pi / Vpi

    # Sinusoidal drive applied to each of the three stages
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)

    # Phase stage as cos + j*sin, then the two cosine stages
    pm_angle = u1 * K
    pm_term = torch.cos(pm_angle) + 1j * torch.sin(pm_angle)
    mzm2_term = torch.cos((0.5 * (u2 + Vb2)) * K)
    mzm3_term = torch.cos((0.5 * (u3 + Vb3)) * K)
    frequency_comb = P * pm_term * mzm2_term * mzm3_term

    # Centred spectrum -> PSD -> dB (exact zeros are floored first to avoid -inf)
    spectrum = torch.fft.fftshift(torch.fft.fft(frequency_comb, n=NFFT, dim=-1), dim=-1)
    psd = torch.abs(spectrum) ** 2 / (NFFT * Fa)
    psd.masked_fill_(psd == 0, torch.finfo(psd.dtype).eps)
    log_Pxx = 10 * torch.log10(psd)

    # Bins of the comb lines: centre bin +/- k * (NFFT // SpS)
    n_peaks_2 = n_peaks // 2
    offsets = torch.arange(-n_peaks_2, n_peaks_2 + 1, device=params.device).to(torch.int)
    indx = (log_Pxx.size(-1) // 2) + offsets * (NFFT // SpS)
    peaks = torch.index_select(log_Pxx, 1, indx)
    return peaks.to(out_device)

In [75]:
device = "cuda"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)

out = analytical_function_compact2(params, ofc_args)
print(out.shape, out.device)

%timeit analytical_function_compact2(params, ofc_args)


torch.Size([1, 9]) cuda:0
335 μs ± 9.12 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [76]:

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = torch.tensor(10e9, dtype=torch.float32, device=device)
Vpi = torch.tensor(2.0, dtype=torch.float32, device=device)
NFFT = torch.tensor(64, dtype=torch.int32, device=device)
Fa = torch.tensor(640e9, dtype=torch.float32, device=device)
SpS = torch.tensor(64, dtype=torch.int32, device=device)
n_peaks = torch.tensor(9, dtype=torch.int32, device=device)

out = analytical_function_compact3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_compact3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
196 μs ± 12.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [77]:

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = 1.0
Rs = 10e9
Vpi = 2.0
NFFT = 64
Fa = 640e9
SpS = 64
n_peaks = 9

out = analytical_function_compact4(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_compact4(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
206 μs ± 7.63 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [78]:

device = "cpu"

params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = torch.tensor(10e9, dtype=torch.float32, device=device)
Vpi = torch.tensor(2.0, dtype=torch.float32, device=device)
NFFT = torch.tensor(64, dtype=torch.int32, device=device)
Fa = torch.tensor(640e9, dtype=torch.float32, device=device)
SpS = torch.tensor(64, dtype=torch.int32, device=device)
n_peaks = torch.tensor(9, dtype=torch.int32, device=device)

out = analytical_function_jit3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_jit3(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
162 μs ± 6.94 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [79]:
device = "cuda"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32, device=device).unsqueeze(0) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32, device=device)
Rs = torch.tensor(10e9, dtype=torch.float32, device=device)
Vpi = torch.tensor(2.0, dtype=torch.float32, device=device)
NFFT = torch.tensor(64, dtype=torch.int32, device=device)
Fa = torch.tensor(640e9, dtype=torch.float32, device=device)
SpS = torch.tensor(64, dtype=torch.int32, device=device)
n_peaks = torch.tensor(9, dtype=torch.int32, device=device)

out = analytical_function_jit2(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)
print(out.shape, out.device)

%timeit analytical_function_jit2(params, t, Rs, Vpi, P, NFFT, Fa, SpS, n_peaks)

torch.Size([1, 9]) cuda:0
506 μs ± 14.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


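Conclusion: for this single-row `params` input, the `@torch.jit.script` version computed on the CPU (162 μs, including the final copy of the result to the GPU) is faster than the eager CPU versions (196 μs and 206 μs) and than both all-CUDA runs (335 μs eager, 506 μs scripted).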

### Testing some basic functions with the torch.jit.script decorator

In [84]:

@torch.jit.script
def pm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: torch.Tensor) -> torch.Tensor:
    K = torch.pi / Vpi
    #Ao = Ai * torch.exp(1j * K * u)
    Ao = Ai * (torch.cos(K * u) + 1j * torch.sin(K * u))
    return Ao

@torch.jit.script
def mzm_jit(Ai: torch.Tensor, u: torch.Tensor, Vpi: torch.Tensor, Vb: torch.Tensor) -> torch.Tensor:
    '''
    MZM stage: scale the input field by cos(0.5 * pi / Vpi * (u + Vb)).

    Ai: input field, u: drive signal (V), Vpi: half-wave voltage (V),
    Vb: bias voltage (V). Returns the output field.
    '''
    half_K = 0.5 * (torch.pi / Vpi)
    return Ai * torch.cos(half_K * (u + Vb))

@torch.jit.script
def frequencyCombGenerator_MZM_MZM_PM_jit(params: torch.Tensor, Rs: torch.Tensor, t: torch.Tensor, P:torch.Tensor, Vpi:torch.Tensor, out_device: str = "cuda") -> torch.Tensor:
    '''
    TorchScript generator of a frequency comb signal using a PM and two MZMs.

    The input P goes through pm_jit and then through mzm_jit twice, each stage
    driven by its own sinusoid of frequency Rs.

    Parameters:

    params: torch.Tensor
        2-D tensor, one row per parameter set. Columns 0-7 are read as
        V1, V2, V3 (drive amplitudes, V), Phase1, Phase2, Phase3 (drive
        phases, rad), Vb2, Vb3 (bias voltages of the two MZMs, V).
    Rs: torch.Tensor
        Symbol rate, used as the drive frequency (Hz)
    t: torch.Tensor
        Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
    P: torch.Tensor
        Input fed to the first stage (documented upstream as power in W)
    Vpi: torch.Tensor
        Half-wave voltage of the modulators (V)
    out_device: str
        Device the result is moved to. Defaults to "cuda", which is what this
        function always used; pass "cpu" on machines without a GPU.

    Returns:
    frequency_comb: torch.Tensor
        Complex frequency comb signal, located on `out_device`.
    '''
    # Each parameter as an [N, 1] column so it broadcasts against t
    V1 = params[:,0].unsqueeze(1)
    V2 = params[:,1].unsqueeze(1)
    V3 = params[:,2].unsqueeze(1)
    Phase1 = params[:,3].unsqueeze(1)
    Phase2 = params[:,4].unsqueeze(1)
    Phase3 = params[:,5].unsqueeze(1)
    Vb2 = params[:,6].unsqueeze(1)
    Vb3 = params[:,7].unsqueeze(1)

    omega_m = 2 * torch.pi * Rs
    pi_Rs_t = omega_m * t

    # Sinusoidal drive of each stage
    u1 = V1 * torch.cos(pi_Rs_t + Phase1)
    u2 = V2 * torch.cos(pi_Rs_t + Phase2)
    u3 = V3 * torch.cos(pi_Rs_t + Phase3)

    # Cascade: phase modulator first, then the two MZMs
    frequency_comb = pm_jit(P, u1, Vpi)
    frequency_comb = mzm_jit(frequency_comb, u2, Vpi, Vb2)
    frequency_comb = mzm_jit(frequency_comb, u3, Vpi, Vb3)

    return frequency_comb.to(torch.device(out_device))


# Presumably a warm-up: call the scripted function a few times before it is timed.
# The inputs use the same dtype (float32) and shapes as the timed calls, so the
# warm-up exercises the same path that the benchmark measures.
device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
Rs = torch.tensor(10e9, dtype=torch.float32).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Vpi = torch.tensor(2.0, dtype=torch.float32).to(device)

oioi1 = frequencyCombGenerator_MZM_MZM_PM_jit(params, Rs, t, P, Vpi)
oioi2 = frequencyCombGenerator_MZM_MZM_PM_jit(params, Rs, t, P, Vpi)
oioi3 = frequencyCombGenerator_MZM_MZM_PM_jit(params, Rs, t, P, Vpi)



In [85]:

def frequencyCombGenerator_MZM_MZM_PM2(params: torch.Tensor, Rs: torch.Tensor, t: torch.Tensor, P: torch.Tensor, Vpi:torch.Tensor) -> torch.Tensor:
    '''
    Generate a frequency comb signal using a PM followed by two MZMs (eager version).

    Parameters:

    params: torch.Tensor
        2-D tensor with exactly 8 columns, one row per parameter set:
        V1, V2, V3 (drive amplitudes, V), Phase1, Phase2, Phase3 (drive
        phases, rad), Vb2, Vb3 (bias voltages of the two MZMs, V).
    Rs: torch.Tensor
        Symbol rate, used as the drive frequency (Hz)
    t: torch.Tensor
        Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
    P: torch.Tensor
        Input fed to the first stage (documented upstream as power in W)
    Vpi: torch.Tensor
        Half-wave voltage of the modulators (V)

    Returns:
    frequency_comb: torch.Tensor
        Frequency comb signal, on the same device as the inputs.
    '''
    # [N, 8] -> [8, N, 1]: unpacking along dim 0 gives eight [N, 1] columns
    columns = params.T.unsqueeze(2)
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = columns

    carrier = (2 * torch.pi * Rs) * t

    # Sinusoidal drive of each stage
    u1 = V1 * torch.cos(carrier + Phase1)
    u2 = V2 * torch.cos(carrier + Phase2)
    u3 = V3 * torch.cos(carrier + Phase3)

    # Cascade: phase modulator first, then the two MZMs
    field = pm(P, u1, Vpi)
    field = mzm(field, u2, Vpi, Vb2)
    return mzm(field, u3, Vpi, Vb3)


def frequencyCombGenerator_MZM_MZM_PM22(params: torch.Tensor, Rs: torch.Tensor, t: torch.Tensor, P: torch.Tensor, Vpi:torch.Tensor, out_device: str = "cuda") -> torch.Tensor:
    '''
    Generate a frequency comb signal using a PM and two MZMs, then move the
    result to `out_device`.

    The signal itself is computed by frequencyCombGenerator_MZM_MZM_PM2 on the
    device of the inputs; only the final transfer is added here.

    Parameters:

    params: torch.Tensor
        2-D tensor with exactly 8 columns, one row per parameter set:
        V1, V2, V3 (drive amplitudes, V), Phase1, Phase2, Phase3 (drive
        phases, rad), Vb2, Vb3 (bias voltages of the two MZMs, V).
    Rs: torch.Tensor
        Symbol rate, used as the drive frequency (Hz)
    t: torch.Tensor
        Time vector
    P: torch.Tensor
        Input fed to the first stage (documented upstream as power in W)
    Vpi: torch.Tensor
        Half-wave voltage of the modulators (V)
    out_device: str
        Device the result is moved to. Defaults to "cuda", which is what this
        function always used; pass "cpu" on machines without a GPU.

    Returns:
    frequency_comb: torch.Tensor
        Frequency comb signal, located on `out_device`.
    '''
    frequency_comb = frequencyCombGenerator_MZM_MZM_PM2(params, Rs, t, P, Vpi)
    return frequency_comb.to(out_device)


In [86]:
def frequencyCombGenerator_MZM_MZM_PM3(params: torch.Tensor, Rs: float, t: torch.Tensor, P: torch.Tensor, Vpi:float) -> torch.Tensor:
    '''
    Generate a frequency comb signal using a PM followed by two MZMs
    (eager version with Rs and Vpi given as plain floats).

    Parameters:

    params: torch.Tensor
        2-D tensor with exactly 8 columns, one row per parameter set:
        V1, V2, V3 (drive amplitudes, V), Phase1, Phase2, Phase3 (drive
        phases, rad), Vb2, Vb3 (bias voltages of the two MZMs, V).
    Rs: float
        Symbol rate, used as the drive frequency (Hz)
    t: torch.Tensor
        Time vector; must broadcast against [N, 1] columns (e.g. shape [1, T])
    P: torch.Tensor
        Input fed to the first stage (documented upstream as power in W)
    Vpi: float
        Half-wave voltage of the modulators (V)

    Returns:
    frequency_comb: torch.Tensor
        Frequency comb signal, on the same device as the inputs.
    '''
    # [N, 8] -> [8, N, 1]: unpacking along dim 0 gives eight [N, 1] columns
    columns = params.T.unsqueeze(2)
    V1, V2, V3, Phase1, Phase2, Phase3, Vb2, Vb3 = columns

    carrier = (2 * torch.pi * Rs) * t

    # Sinusoidal drive of each stage
    u1 = V1 * torch.cos(carrier + Phase1)
    u2 = V2 * torch.cos(carrier + Phase2)
    u3 = V3 * torch.cos(carrier + Phase3)

    # Cascade: phase modulator first, then the two MZMs
    field = pm(P, u1, Vpi)
    field = mzm(field, u2, Vpi, Vb2)
    return mzm(field, u3, Vpi, Vb3)

def frequencyCombGenerator_MZM_MZM_PM33(params: torch.Tensor, Rs: float, t: torch.Tensor, P: torch.Tensor, Vpi:float, out_device: str = "cuda") -> torch.Tensor:
    '''
    Generate a frequency comb signal using a PM and two MZMs (Rs and Vpi given
    as plain floats), then move the result to `out_device`.

    The signal itself is computed by frequencyCombGenerator_MZM_MZM_PM3 on the
    device of the inputs; only the final transfer is added here.

    Parameters:

    params: torch.Tensor
        2-D tensor with exactly 8 columns, one row per parameter set:
        V1, V2, V3 (drive amplitudes, V), Phase1, Phase2, Phase3 (drive
        phases, rad), Vb2, Vb3 (bias voltages of the two MZMs, V).
    Rs: float
        Symbol rate, used as the drive frequency (Hz)
    t: torch.Tensor
        Time vector
    P: torch.Tensor
        Input fed to the first stage (documented upstream as power in W)
    Vpi: float
        Half-wave voltage of the modulators (V)
    out_device: str
        Device the result is moved to. Defaults to "cuda", which is what this
        function always used; pass "cpu" on machines without a GPU.

    Returns:
    frequency_comb: torch.Tensor
        Frequency comb signal, located on `out_device`.
    '''
    frequency_comb = frequencyCombGenerator_MZM_MZM_PM3(params, Rs, t, P, Vpi)
    return frequency_comb.to(out_device)

In [88]:

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = torch.tensor(10e9, dtype=torch.float32).to(device)
Vpi = torch.tensor(2.0, dtype=torch.float32).to(device)

print("Using CPU and JIT")
out = frequencyCombGenerator_MZM_MZM_PM_jit(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_MZM_MZM_PM_jit(params, Rs, t, P, Vpi)


device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = torch.tensor(10e9, dtype=torch.float32).to(device)
Vpi = torch.tensor(2.0, dtype=torch.float32).to(device)

print("\nUsing CPU, and all tensors")
out = frequencyCombGenerator_MZM_MZM_PM22(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_MZM_MZM_PM22(params, Rs, t, P, Vpi)

device = "cpu"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = 10e9
Vpi = 2.0

print("\nUsing CPU, and somes floats")
out = frequencyCombGenerator_MZM_MZM_PM33(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_MZM_MZM_PM33(params, Rs, t, P, Vpi)

device = "cuda"
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = torch.tensor(10e9, dtype=torch.float32).to(device)
Vpi = torch.tensor(2.0, dtype=torch.float32).to(device)

print("\nUsing CUDA, and all tensors")
out = frequencyCombGenerator_MZM_MZM_PM2(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_MZM_MZM_PM2(params, Rs, t, P, Vpi)

print("\nUsing CUDA, and some floats")
params = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).unsqueeze(0).to(device)
t = torch.arange(0, 641, dtype=torch.float32).unsqueeze(0).to(device) * 1.5625e-12
P = torch.tensor(1.0, dtype=torch.float32).to(device)
Rs = 10e9
Vpi = 2.0

out = frequencyCombGenerator_MZM_MZM_PM3(params, Rs, t, P, Vpi)
print(out.shape, out.device)

%timeit frequencyCombGenerator_MZM_MZM_PM3(params, Rs, t, P, Vpi)


Using CPU and JIT
torch.Size([1, 641]) cuda:0
86 μs ± 13.5 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Using CPU, and all tensors
torch.Size([1, 641]) cuda:0
120 μs ± 11.9 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Using CPU, and somes floats
torch.Size([1, 641]) cuda:0
92.7 μs ± 13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Using CUDA, and all tensors
torch.Size([1, 641]) cuda:0
223 μs ± 2.43 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Using CUDA, and some floats
torch.Size([1, 641]) cuda:0
144 μs ± 1.3 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


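Conclusion: again, for this single-row `params` input, computing on the CPU with `@torch.jit.script` is the fastest option (86 μs, including the final copy of the result to the GPU), ahead of the eager CPU versions (92.7 μs and 120 μs) and the eager CUDA versions (144 μs and 223 μs).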