<a href="https://colab.research.google.com/github/cnolascof/SpellCheck/blob/main/JAX_Trials_V0/PA_ZTx_Opt_Trial_NB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PA Impedance Optimization for Maximum Gain at User



In [34]:
# Necessary imports
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import grad
from jax import device_put, device_get
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec
import jax.lax
import autograd
from autograd import grad as agrad

import time
from scipy.io import loadmat
import numpy as np
from google.colab import drive
from dataclasses import dataclass, field
from flax import struct

In [2]:
# Check JAX version
# Note: most recent version is 0.5.0
# If output is older version, run next cell to update JAX and restart runtime
print(f"JAX version: {jax.__version__}")

# Enable 64 precision
from jax import config
config.update("jax_enable_x64", True)

# List available devices
devices = jax.devices()
print("Available devices:", devices)

JAX version: 0.5.0
Available devices: [CpuDevice(id=0)]


In [3]:
# Import latest version of JAX (if necessary) : latest version is 0.5.0
# Note: after updating will need to restart runtime
!pip install --upgrade "jax[cpu]"  # For CPU-only support
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Collecting jax[cpu]
  Downloading jax-0.5.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.0,>=0.5.0 (from jax[cpu])
  Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)
Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl (102.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.0/102.0 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jax-0.5.0-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m55.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.33
    Uninstalling jaxlib-0.4.33:
      Successfully uninstalled jaxlib-0.4.33
  Attempting uninstall: jax
    Found existing installation: jax 0.4.33
    Uninstalling jax-0.4.33:
      Successfully uninstalled jax-0.4.33
Successfully installed jax-0.5.0 jaxlib-0.5.0


Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax-cuda12-plugin<=0.5.0,>=0.5.0 (from jax-cuda12-plugin[with_cuda]<=0.5.0,>=0.5.0; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_plugin-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-pjrt==0.5.0 (from jax-cuda12-plugin<=0.5.0,>=0.5.0->jax-cuda12-plugin[with_cuda]<=0.5.0,>=0.5.0; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_pjrt-0.5.0-py3-none-manylinux2014_x86_64.whl.metadata (348 bytes)
Collecting nvidia-cuda-nvcc-cu12>=12.6.85 (from jax-cuda12-plugin[with_cuda]<=0.5.0,>=0.5.0; extra == "cuda"->jax[cuda])
  Downloading nvidia_cuda_nvcc_cu12-12.8.61-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Downloading jax_cuda12_plugin-0.5.0-cp311-cp311-manylinux2014_x86_64.whl (16.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.7/16.7 MB[0m [31m85.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ja

In [35]:
# Import Google Drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Define Architecture parameters

*   RF Frontend: 3 PAs w/ tunable impedances to optimize, z_tx_i
*   Tunning Network: none, short connection between RF and Rad Structure
*   Radiating Structure: 3 dipole antennas at d=lambda/2

In [36]:
# Define System Parameters

@dataclass(frozen=True)
class SystemParams:
    N: int    # Number of PAs
    M: int    # Number of Antennas
    R0: float # Reference impedance
    p_ctr: float  # Power constraint
    f: float  # Frequency (Hz)
    d_ratio: float  # Ratio of lambda for antenna distance
    c: float = 299792458  # Speed of light (m/s), default value

    # Dependent variables are computed
    lambda_: float = field(init=False)  # Wavelength
    k: float = field(init=False)  # Wave number
    d: float = field(init=False)  # Antenna distance

    def __post_init__(self):
        # Compute dependent variables
        object.__setattr__(self, "lambda_", self.c / self.f)  # Wavelength = c / f
        object.__setattr__(self, "k", 2 * jnp.pi / self.lambda_)  # Wave number = 2π / λ
        object.__setattr__(self, "d", self.d_ratio * self.lambda_)  # Wave number = 2π / λ

sys_params = SystemParams(
    N = 3,
    M = 3,
    R0 = 50,
    p_ctr =1,
    f = 12e9,
    d_ratio=0.5,
    c = 299792458,
)

In [37]:
# Define REMS parameters

# Note: Z_Tx defined separately as these values will be dynamically changed
@struct.dataclass
class REMSParams:
    S_T_tt: jnp.ndarray         # Tunning network S parameters
    S_T_tr: jnp.ndarray
    S_T_rt: jnp.ndarray
    S_T_rr: jnp.ndarray
    Omega: jnp.ndarray          # Sampled FF space
    S_R_RR: jnp.ndarray         # Radiating structure S parameters
    s_R_FR_theta: jnp.ndarray
    s_R_FR_phi: jnp.ndarray

# Load the mat file
mat_file_path = "/content/drive/MyDrive/MST_Carolina/dipole_3_d05_s_R_degstep4_inter0.5.mat"
s_R_FR_data = loadmat(mat_file_path)

# Create the REMSParams object
REMS_params = REMSParams(
    S_T_tt=jnp.zeros((sys_params.N, sys_params.N)),
    S_T_tr=jnp.identity(sys_params.N),
    S_T_rt=jnp.identity(sys_params.M),
    S_T_rr=jnp.zeros((sys_params.M, sys_params.M)),
    Omega=s_R_FR_data["Omega"],
    S_R_RR=jnp.array(s_R_FR_data["S_R_RR"]),
    s_R_FR_theta=jnp.array(s_R_FR_data["s_R_FR"][:, :, 0]).T,
    s_R_FR_phi=jnp.array(s_R_FR_data["s_R_FR"][:, :, 1]).T,
)

## Far Field functions

Inputs - Same for all functions

*   Sp: system parameters class (static)
*   Rp: REMS parameters class
*   z_tx: array containing PA impedances
*   v_Tx_exact: exact beamsteering vector
*   idx_Omega: index in sampled far field corresponding to target user location
*   s: symbol to be sent to the user (by default s = 1)

Outputs - Depending on the function

*   far_field_at_user_only: (g_A, g_T, g_R) tuple for the gain at the user location
*   far_field_at_user_only_tograd: g_A at the user location (function modified to be able to apply jit gradient)
*   far_field_at_user_full: (g_A, g_T, g_R) tuple for the gain at all sampled space locations
*   v_Tx_calc: v_Tx, exact beamsteering vector for given user location




In [38]:
# Function to calculate Far Field ONLY at user location
def far_field_at_user_only(Sp, Rp, z_tx, v_Tx_exact, idx_Omega, s):

    # Auxiliary variables for far field results
    In = jnp.identity(Sp.N)
    Im = jnp.identity(Sp.M)
    Z_tx = jnp.diag(z_tx)

    P_A_mat = jnp.asarray(jnp.linalg.solve(jnp.real(Z_tx), jnp.eye(jnp.real(Z_tx).shape[0])))
    S_RF = jnp.asarray(jnp.linalg.solve(Z_tx + Sp.R0 * In, Z_tx - Sp.R0 * In))

    L1 = S_RF @ Rp.S_T_tt
    L2 = Rp.S_T_rr @ Rp.S_R_RR
    aux1 = jnp.linalg.solve(Im - L2, Rp.S_T_rt)
    L3 = S_RF @ Rp.S_T_tr @ Rp.S_R_RR @ aux1
    K_vTx = jnp.linalg.solve(Z_tx + Sp.R0 * In, In) * jnp.sqrt(Sp.R0)

    # G operators
    G_vTx_aT = jnp.linalg.solve(In - L1 - L3, K_vTx)
    G_vTx_bT = (Rp.S_T_tr @ Rp.S_R_RR @ aux1 + Rp.S_T_tt) @ G_vTx_aT

    G_vTx_aR = aux1 @ G_vTx_aT
    G_vTx_bR = Rp.S_R_RR @ G_vTx_aR

    # Get s_R_FR slice we want
    s_R_FR_theta_user = jax.lax.dynamic_slice(Rp.s_R_FR_theta, (idx_Omega, 0), (1, Sp.M))
    s_R_FR_phi_user = jax.lax.dynamic_slice(Rp.s_R_FR_phi, (idx_Omega, 0), (1, Sp.M))

    G_vTx_aF_theta = s_R_FR_theta_user @ G_vTx_aR
    G_vTx_aF_phi = s_R_FR_phi_user @ G_vTx_aR

    # EXACT BEAMSTEERING VECTOR

    # Compute Power Metrics
    P_A_exact_exact = jnp.real((1/4))*((jnp.conjugate(v_Tx_exact).T @ P_A_mat) @ v_Tx_exact)

    aT_exact = G_vTx_aT @ v_Tx_exact
    bT_exact = G_vTx_bT @ v_Tx_exact
    P_T_exact_exact = jnp.linalg.norm(aT_exact)**2 - jnp.linalg.norm(bT_exact)**2

    aR_exact = G_vTx_aR @ v_Tx_exact
    bR_exact = G_vTx_bR @ v_Tx_exact
    P_R_exact_exact = jnp.linalg.norm(aR_exact)**2 - jnp.linalg.norm(bR_exact)**2

    # EXACT MODEL w/ EXACT BEAMSTEERING VECTOR

    # Far field
    aF_theta_exact = G_vTx_aF_theta @ v_Tx_exact
    aF_phi_exact = G_vTx_aF_phi @ v_Tx_exact

    # Normalize
    aF_normsq_exact = jnp.abs(aF_theta_exact)**2 + jnp.abs(aF_phi_exact)**2;

    # Gain metrics
    g_A_exact_vTx_exact = (4 * jnp.pi / P_A_exact_exact) * aF_normsq_exact
    g_T_exact_vTx_exact = (4 * jnp.pi / P_T_exact_exact) * aF_normsq_exact
    g_R_exact_vTx_exact = (4 * jnp.pi / P_R_exact_exact) * aF_normsq_exact

    return (g_A_exact_vTx_exact, g_T_exact_vTx_exact, g_R_exact_vTx_exact)

# Function to calculate Far Field ONLY at user location
# This is the function that will be used for the gradient, meaning
# a) it only computes gA
# b) uses .item() to return a scalar (doesn't work w JIT compilation)
# small note: g_A will be a positive scalar, and grad requires the output to be a scalar,
# which is why the norm is computed at the end
def far_field_at_user_only_tograd(Sp, Rp, z_tx, v_Tx_exact, idx_Omega, s):

    # Auxiliary variables for far field results
    In = jnp.identity(Sp.N)
    Im = jnp.identity(Sp.M)
    Z_tx = jnp.diag(z_tx)

    P_A_mat = jnp.asarray(jnp.linalg.solve(jnp.real(Z_tx), jnp.eye(jnp.real(Z_tx).shape[0])))
    S_RF = jnp.asarray(jnp.linalg.solve(Z_tx + Sp.R0 * In, Z_tx - Sp.R0 * In))

    L1 = S_RF @ Rp.S_T_tt
    L2 = Rp.S_T_rr @ Rp.S_R_RR
    aux1 = jnp.linalg.solve(Im - L2, Rp.S_T_rt)
    L3 = S_RF @ Rp.S_T_tr @ Rp.S_R_RR @ aux1
    K_vTx = jnp.linalg.solve(Z_tx + Sp.R0 * In, In) * jnp.sqrt(Sp.R0)

    # G operators
    G_vTx_aT = jnp.linalg.solve(In - L1 - L3, K_vTx)
    G_vTx_aR = aux1 @ G_vTx_aT

    # Get s_R_FR slice we want
    s_R_FR_theta_user = jax.lax.dynamic_slice(Rp.s_R_FR_theta, (idx_Omega, 0), (1, Sp.M))
    s_R_FR_phi_user = jax.lax.dynamic_slice(Rp.s_R_FR_phi, (idx_Omega, 0), (1, Sp.M))

    G_vTx_aF_theta = s_R_FR_theta_user @ G_vTx_aR
    G_vTx_aF_phi = s_R_FR_phi_user @ G_vTx_aR

    # EXACT BEAMSTEERING VECTOR

    # Compute Power Metrics
    P_A_exact_exact = jnp.real((1/4))*((jnp.conjugate(v_Tx_exact).T @ P_A_mat) @ v_Tx_exact)

    # EXACT MODEL w/ EXACT BEAMSTEERING VECTOR

    # Far field
    aF_theta_exact = G_vTx_aF_theta @ v_Tx_exact
    aF_phi_exact = G_vTx_aF_phi @ v_Tx_exact

    # Normalize
    aF_normsq_exact = jnp.abs(aF_theta_exact)**2 + jnp.abs(aF_phi_exact)**2;

    # Gain metrics
    g_A_exact_vTx_exact = jnp.real((4 * jnp.pi / P_A_exact_exact) * aF_normsq_exact).item()
    g_A_exact_vTx_exact = jnp.real((4 * jnp.pi / P_A_exact_exact) * aF_normsq_exact)

    return jnp.linalg.norm(g_A_exact_vTx_exact)

# Function to calculate full Far Field for user location at specific location
def far_field_at_user_full(Sp, Rp, z_tx, v_Tx_exact, idx_Omega, s):

    # Auxiliary variables for far field results
    In = jnp.identity(Sp.N)
    Im = jnp.identity(Sp.M)
    Z_tx = jnp.diag(z_tx)

    P_A_mat = jnp.asarray(jnp.linalg.solve(jnp.real(Z_tx), jnp.eye(jnp.real(Z_tx).shape[0])))
    S_RF = jnp.asarray(jnp.linalg.solve(Z_tx + Sp.R0 * In, Z_tx - Sp.R0 * In))

    L1 = S_RF @ Rp.S_T_tt
    L2 = Rp.S_T_rr @ Rp.S_R_RR
    aux1 = jnp.linalg.solve(Im - L2, Rp.S_T_rt)
    L3 = S_RF @ Rp.S_T_tr @ Rp.S_R_RR @ aux1
    K_vTx = jnp.linalg.solve(Z_tx + Sp.R0 * In, In) * jnp.sqrt(Sp.R0)

    # G operators
    G_vTx_aT = jnp.linalg.solve(In - L1 - L3, K_vTx)
    G_vTx_bT = (Rp.S_T_tr @ Rp.S_R_RR @ aux1 + Rp.S_T_tt) @ G_vTx_aT

    G_vTx_aR = aux1 @ G_vTx_aT
    G_vTx_bR = Rp.S_R_RR @ G_vTx_aR

    G_vTx_aF_theta = jnp.matmul(Rp.s_R_FR_theta, G_vTx_aR)
    G_vTx_aF_phi = jnp.matmul(Rp.s_R_FR_phi, G_vTx_aR)

    # Compute Power Metrics
    P_A_exact_exact = jnp.real((1/4))*((jnp.conjugate(v_Tx_exact).T @ P_A_mat) @ v_Tx_exact)

    aT_exact = G_vTx_aT @ v_Tx_exact
    bT_exact = G_vTx_bT @ v_Tx_exact
    P_T_exact_exact = jnp.linalg.norm(aT_exact)**2 - jnp.linalg.norm(bT_exact)**2

    aR_exact = G_vTx_aR @ v_Tx_exact
    bR_exact = G_vTx_bR @ v_Tx_exact
    P_R_exact_exact = jnp.linalg.norm(aR_exact)**2 - jnp.linalg.norm(bR_exact)**2

    # EXACT MODEL w/ EXACT BEAMSTEERING VECTOR

    # Far field
    aF_theta_exact = G_vTx_aF_theta @ v_Tx_exact
    aF_phi_exact = G_vTx_aF_phi @ v_Tx_exact

    # Normalize
    aF_normsq_exact = jnp.abs(aF_theta_exact)**2 + jnp.abs(aF_phi_exact)**2;

    # Gain metrics
    g_A_exact_vTx_exact = (4 * jnp.pi / P_A_exact_exact) * aF_normsq_exact
    g_T_exact_vTx_exact = (4 * jnp.pi / P_T_exact_exact) * aF_normsq_exact
    g_R_exact_vTx_exact = (4 * jnp.pi / P_R_exact_exact) * aF_normsq_exact

    return (g_A_exact_vTx_exact, g_T_exact_vTx_exact, g_R_exact_vTx_exact)

# Function to calculate v_Tx (only the first time)
def v_Tx_calc(Sp, Rp, z_tx, idx_Omega, s):

    # Auxiliary variables for far field results
    In = jnp.identity(Sp.N)
    Im = jnp.identity(Sp.M)
    Z_tx = jnp.diag(z_tx)

    P_A_mat = jnp.asarray(jnp.linalg.solve(jnp.real(Z_tx), jnp.eye(jnp.real(Z_tx).shape[0])))
    S_RF = jnp.asarray(jnp.linalg.solve(Z_tx + Sp.R0 * In, Z_tx - Sp.R0 * In))

    L1 = S_RF @ Rp.S_T_tt
    L2 = Rp.S_T_rr @ Rp.S_R_RR
    aux1 = jnp.linalg.solve(Im - L2, Rp.S_T_rt)
    L3 = S_RF @ Rp.S_T_tr @ Rp.S_R_RR @ aux1
    K_vTx = jnp.linalg.solve(Z_tx + Sp.R0 * In, In) * jnp.sqrt(Sp.R0)

    # G operators
    G_vTx_aT = jnp.linalg.solve(In - L1 - L3, K_vTx)
    G_vTx_bT = (Rp.S_T_tr @ Rp.S_R_RR @ aux1 + Rp.S_T_tt) @ G_vTx_aT

    G_vTx_aR = aux1 @ G_vTx_aT
    G_vTx_bR = Rp.S_R_RR @ G_vTx_aR

    # Get s_R_FR slice we want
    s_R_FR_theta_user = jax.lax.dynamic_slice(Rp.s_R_FR_theta, (idx_Omega, 0), (1, Sp.M))
    s_R_FR_phi_user = jax.lax.dynamic_slice(Rp.s_R_FR_phi, (idx_Omega, 0), (1, Sp.M))

    G_vTx_aF_theta = s_R_FR_theta_user @ G_vTx_aR
    G_vTx_aF_phi = s_R_FR_phi_user @ G_vTx_aR

    # EXACT BEAMSTEERING VECTOR

    # Extract the values of the G operator where the user is
    G_vTx_aF_user = jnp.vstack([G_vTx_aF_theta, G_vTx_aF_phi])

    # Create A matrix
    inv_sqrt_P_A_mat = jnp.linalg.solve(jnp.sqrt(P_A_mat), jnp.eye(jnp.sqrt(P_A_mat).shape[0]))
    G_vTx_aF_user_H = jnp.conjugate(G_vTx_aF_user).T

    A = (inv_sqrt_P_A_mat @ G_vTx_aF_user_H) @ (G_vTx_aF_user @ inv_sqrt_P_A_mat)

    # Extract eigenvector corresponding to largest eigenvalue
    A_eigenvals, A_eigenvecs = jnp.linalg.eig(A)

    idx_A = jnp.argmax(jnp.abs(A_eigenvals))  # Index of the largest eigenvalue
    x1 = A_eigenvecs[:, idx_A]

    # Compute v_Tx
    v_Tx_exact = (jnp.sqrt(Sp.p_ctr)*(1/jnp.linalg.norm(x1)))*(inv_sqrt_P_A_mat @ x1)

    return v_Tx_exact

In [88]:
# Define transmission parameters

# Symbol to send
s = 1

# User
theta_user = 90
phi_user = 90
idx = jnp.where((REMS_params.Omega[:, 0] == theta_user) & (REMS_params.Omega[:, 1] == phi_user))[0]
idx_Omega = idx[0].item()

## Optimization loop

while (improvement > min)

- update rule on z_tx: z_tx_t1 = z_tx_t + learning_rate*grad_wrt_ztx(z_tx_t)
- improvement: ff(z_tx_t1) - ff(z_tx_t)
- if loop iteration is multiple of 5, update v_Tx

In [92]:
# Preliminary parameters
z_tx_init = 75                                   # initial PA impedances
z_tx_t = 0                                       # t current parameter
z_tx_t1 = z_tx_init*jnp.ones(sys_params.N)       # t + 1 next parameter
v_Tx = v_Tx_calc(sys_params, REMS_params, z_tx_t1, idx_Omega,s)   # exact beamsteering vector

lr = 0.01                  # learning rate
diff_min = 10**-5         # improvement difference that breaks the loop
diff = 1                  # initiate diff
epsilon = 10**-8          # small constant to avoid dividing by zero

# Preliminary functions
ff_grad_wrt_z_tx = grad(far_field_at_user_only_tograd, 2)                 # gradient wrt z_tx
jit_ff_user_only = jit(far_field_at_user_only, static_argnums=(0,4,5))    # jit compiled ff
jit_vTx_calc = jit(v_Tx_calc, static_argnums=(0,3,4))                     # jit compiled v_Tx

In [93]:
cont = 0

while diff > diff_min:

  # Update parameter
  z_tx_t = z_tx_t1
  grad_t = ff_grad_wrt_z_tx(sys_params, REMS_params, z_tx_t, v_Tx, idx_Omega, s)
  z_tx_t1 = z_tx_t + lr*grad_t/(jnp.linalg.norm(grad_t)+ epsilon)

  # Calculate improvement
  arg1 = jnp.real(jit_ff_user_only(sys_params, REMS_params, z_tx_t1, v_Tx, idx_Omega, s)[0])
  arg2 = jnp.real(jit_ff_user_only(sys_params, REMS_params, z_tx_t, v_Tx, idx_Omega, s)[0])
  diff =  arg1 - arg2

  # Update v_Tx
  if cont % 5 == 0:
    v_Tx = jit_vTx_calc(sys_params, REMS_params, z_tx_t1, idx_Omega, s)

  # Update counting
  cont = cont + 1

  if cont % 100 == 0:
    print(cont)
  """
  if cont == 1000:
    break
  """


print(z_tx_t1)
print(cont)
print(diff)

100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
[58.06242239 64.71658802 58.06233969]
2626
[9.99790366e-06]


In [72]:
# Run one z_tx = 50
# gradient ascent, lr = 0.1, diff_min e-6 (convergence is ff diff)
# ran for about half an hour, stopped runtime myself
# User at theta 90 phi 130
print(z_tx_t1)
print(cont)
print(diff)

[55.1840549  51.75328353 49.4989001 ]
31559
[2.19906498e-06]


In [87]:
# Run two z_tx = 50
# gradient w/ normalization, lr = 0.01, diff_min e-5, epsilon = e-8 (convergence is ff diff)
# ran fully, took 6 minutes
# User at theta 90 phi 130
print(z_tx_t1)
print(cont)
print(diff)
print(jit_ff_user_only(sys_params, REMS_params, z_tx_t1, v_Tx, idx_Omega, s))
print(jit_ff_user_only(sys_params, REMS_params, z_tx_init*jnp.ones(sys_params.N), v_Tx, idx_Omega, s))

[106.77800893  82.88925294  36.50540418]
6750
[9.98742642e-06]
(Array([4.42996662-5.92699032e-17j], dtype=complex128), Array([4.48507259], dtype=float64), Array([4.48507259], dtype=float64))
(Array([3.85916624-5.01942175e-17j], dtype=complex128), Array([4.4105986], dtype=float64), Array([4.4105986], dtype=float64))


In [91]:
# Run three z_tx = 50
# gradient w/ normalization, lr = 0.01, diff_min e-5, epsilon = e-8 (convergence is ff diff)
# ran fully, took 1 minute
# User at theta 90 phi 90
print(z_tx_t1)
print(cont)
print(diff)
print(jit_ff_user_only(sys_params, REMS_params, z_tx_t1, v_Tx, idx_Omega, s))
print(jit_ff_user_only(sys_params, REMS_params, z_tx_init*jnp.ones(sys_params.N), v_Tx, idx_Omega, s))

[54.03621578 62.26139974 54.05124506]
1353
[9.98008414e-06]
(Array([5.91222042+1.00949987e-16j], dtype=complex128), Array([6.47625787], dtype=float64), Array([6.47625787], dtype=float64))
(Array([5.85114147+5.26962363e-17j], dtype=complex128), Array([6.455929], dtype=float64), Array([6.455929], dtype=float64))


In [94]:
# Run four z_tx = 75
# gradient w/ normalization, lr = 0.01, diff_min e-5, epsilon = e-8 (convergence is ff diff)
# ran fully, took 1 minute
# User at theta 90 phi 90
print(z_tx_t1)
print(cont)
print(diff)
print(jit_ff_user_only(sys_params, REMS_params, z_tx_t1, v_Tx, idx_Omega, s))
print(jit_ff_user_only(sys_params, REMS_params, z_tx_init*jnp.ones(sys_params.N), v_Tx, idx_Omega, s))

[58.06242239 64.71658802 58.06233969]
2626
[9.99790366e-06]
(Array([5.91139802+5.3643704e-17j], dtype=complex128), Array([6.47736536], dtype=float64), Array([6.47736536], dtype=float64))
(Array([5.80509127+2.41182871e-18j], dtype=complex128), Array([6.4668544], dtype=float64), Array([6.4668544], dtype=float64))


## Trials - Getting the gradients to work
Checking that new function definition works, can be jit compiled and can be differentiated

*   Note 1: updated function so that it accepts z_tx array as an input, cant calculate the gradient of a class element, need to pass it as a separate input
*   Note 2: also passing vTx as we cannot compute the gradient of an eigenvector (but yes if i where using eigenvalues)

In [None]:
# Input PA impedances
z_tx_init = 25*jnp.ones(sys_params.N)

# v_Tx_init
v_Tx_init = v_Tx_calc(sys_params, REMS_params, z_tx_init, idx_Omega,s)

In [8]:
# Normal execution - WORKS
res_normal = far_field_at_user_only(sys_params, REMS_params,z_tx_init, v_Tx_init, idx_Omega, s)
print(res_normal)

(Array([3.96164179-3.66769966e-17j], dtype=complex128), Array([4.45140864], dtype=float64), Array([4.45140864], dtype=float64))


In [9]:
# Jit compilation - WORKS
jit_ff_user_only = jit(far_field_at_user_only, static_argnums=(0,4,5))
res_jit = jit_ff_user_only(sys_params, REMS_params,z_tx_init, v_Tx_init, idx_Omega, s)
print(res_jit)

(Array([3.96164179-3.66769966e-17j], dtype=complex128), Array([4.45140864], dtype=float64), Array([4.45140864], dtype=float64))


In [12]:
# Gradient - WORKS
z_tx_grad_ff_only = grad(far_field_at_user_only_tograd, 2)
z_tx_v_tx_grad_ff_only = grad(far_field_at_user_only_tograd, (2,3))

# compare results
res_grad = z_tx_v_tx_grad_ff_only(sys_params, REMS_params,z_tx_init, v_Tx_init, idx_Omega, s)
print(f"Grad w.r.t z_tx {res_grad[0]}")
print(f"Grad w.r.t v_tx {res_grad[1]}")
res_grad = z_tx_grad_ff_only(sys_params, REMS_params,z_tx_init, v_Tx_init, idx_Omega, s)
print(f"Grad w.r.t z_tx {res_grad}")

Grad w.r.t z_tx [ 0.01701424  0.00534366 -0.00145055]
Grad w.r.t v_tx [0.00000000e+00-3.57469344e-17j 1.11022302e-16+1.66533454e-16j
 1.11022302e-16-1.66533454e-16j]
Grad w.r.t z_tx [ 0.01701424  0.00534366 -0.00145055]


  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
