In [2]:
from __future__ import annotations

from dataclasses import dataclass
from logging import getLogger
from typing import cast, Optional, Union
from numbers import Number
import numpy

from grand import io, ECEF, LTP, CartesianRepresentation, SphericalRepresentation
from grand.simu import TabulatedAntennaModel

__all__ = ["Antenna", "AntennaModel", "ElectricField", "MissingFrameError", "Voltage"]


_logger = getLogger(__name__)


@dataclass
class ElectricField:
    """Sampled electric field with an optional position and reference frame.

    Fields:
        t: sample times (``dump`` writes them with unit "ns").
        E: field components (``dump`` writes them with unit "uV/m").
        r: optional position (``dump`` writes it with unit "m").
        frame: optional reference frame the data is expressed in.
    """

    t: Number
    E: CartesianRepresentation
    r: Union[CartesianRepresentation, None] = None
    # NOTE(review): GRANDCS is not imported in the visible code; this annotation
    # only works because annotations are lazy (``from __future__ import annotations``).
    frame: Union[ECEF, LTP, GRANDCS, None] = None

    @classmethod
    def load(cls, node: io.DataNode):
        """Build an ``ElectricField`` from the datasets stored under *node*.

        ``t`` and ``E`` are required (read as float64). ``r`` and ``frame``
        are optional: a ``KeyError`` while reading them leaves the field ``None``.
        """
        _logger.debug(f"Loading E-field from {node.filename}:{node.path}")

        t = node.read("t", dtype="f8")
        E = node.read("E", dtype="f8")

        try:
            r = node.read("r", dtype="f8")
        except KeyError:
            r = None

        try:
            frame = node.read("frame")
        except KeyError:
            frame = None

        return cls(t, E, r, frame)

    def dump(self, node: io.DataNode):
        """Write the field to *node* as float32 datasets with unit metadata.

        ``r`` and ``frame`` are written only when they are set.
        """
        _logger.debug(f"Dumping E-field to {node.filename}:{node.path}")

        node.write("t", self.t, unit="ns", dtype="f4")
        node.write("E", self.E, unit="uV/m", dtype="f4")

        if self.r is not None:
            node.write("r", self.r, unit="m", dtype="f4")

        if self.frame is not None:
            node.write("frame", self.frame)


@dataclass
class Voltage:
    """Sampled antenna voltage.

    The field annotations are unit labels: ``t`` is annotated "s" and
    ``V`` is annotated "uV".
    """

    t: "s"
    V: "uV"

    @classmethod
    def load(cls, node: io.DataNode):
        """Build a ``Voltage`` from the ``t`` and ``V`` datasets under *node* (float64)."""
        _logger.debug(f"Loading voltage from {node.filename}:{node.path}")
        t = node.read("t", dtype="f8")
        V = node.read("V", dtype="f8")
        return cls(t, V)

    def dump(self, node: io.DataNode):
        """Write ``t`` and ``V`` to *node* as float32 datasets."""
        # The log message previously said "E-field" (copy-paste from ElectricField).
        _logger.debug(f"Dumping voltage to {node.filename}:{node.path}")
        node.write("t", self.t, dtype="f4")
        node.write("V", self.V, dtype="f4")


class AntennaModel:
    """Placeholder antenna-model interface exposing ``effective_length``.

    The base implementation performs no computation.
    """

    def effective_length(
        self, xmax: LTP, Efield: ElectricField, frame: Union[ECEF, LTP, None] = None
    ) -> CartesianRepresentation:
        """Stub effective-length computation; always returns ``None``."""
        return None


class MissingFrameError(ValueError):
    """``ValueError`` subclass signalling that a required reference frame is missing."""


@dataclass
class Antenna:
    """Antenna described by a tabulated response model placed in its own frame.

    Fields:
        model: tabulated model; ``effective_length`` reads ``model.table``
            (``frequency``, ``theta``, ``phi``, ``leff_theta``, ``phase_theta``,
            ``leff_phi``, ``phase_phi``).
        frame: the antenna frame; required by ``effective_length``.

    NOTE(review): in this notebook run the import of ``TabulatedAntennaModel``
    from ``grand.simu`` failed (ImportError); confirm its module path.
    """

    model: TabulatedAntennaModel
    frame: Union[ECEF, LTP, None] = None

    @staticmethod
    def _bracket(value, grid, periodic):
        """Return ``(i0, i1, w0, w1)``: neighbour indices and linear weights of *value* on a uniform *grid*.

        When the upper neighbour falls past the last node it wraps to index 0
        if *periodic*; otherwise it is clamped onto ``i0`` with weight 0.
        """
        step = grid[1] - grid[0]
        r1 = (value - grid[0]) / step
        i0 = int(numpy.floor(r1) % grid.size)
        i1 = i0 + 1
        if i1 == grid.size:
            if periodic:
                i1 = 0  # results are periodic along this axis
                r1 -= numpy.floor(r1)
            else:
                i1, r1 = i0, 0  # prevent overflow past the last node
        else:
            r1 -= numpy.floor(r1)
        return i0, i1, 1 - r1, r1

    def effective_length(
        self, xmax: LTP, Efield: ElectricField, frame: Union[ECEF, LTP, None] = None
    ) -> CartesianRepresentation:
        """Complex effective length toward *xmax*, expressed in the shower *frame*.

        The tabulated response is interpolated linearly in theta and phi and
        then linearly in frequency, onto the rfft frequency bins of
        ``Efield.E.x`` (zero outside the tabulated band).

        Args:
            xmax: shower maximum, must be an ``LTP``; converted into the antenna
                frame to obtain the arrival direction (theta, phi in degrees).
            Efield: field whose sampling (``t`` spacing, length of ``E.x``)
                defines the frequency grid.
            frame: shower frame in which the result is expressed.

        Returns:
            ``CartesianRepresentation`` with one complex value per rfft bin.

        Raises:
            TypeError: if *xmax* is not an ``LTP``.
            MissingFrameError: if the antenna frame or the shower *frame* is missing.
        """
        if not isinstance(xmax, LTP):
            raise TypeError("Provide Xmax in LTP frame instead of %s" % type(xmax))
        if self.frame is None:
            raise MissingFrameError("missing antenna frame")
        if frame is None:
            raise MissingFrameError("missing shower frame")

        # Arrival direction: shower frame --> antenna frame.
        direction = xmax.ltp_to_ltp(self.frame)
        direction_sphr = SphericalRepresentation(CartesianRepresentation(direction))
        theta, phi = direction_sphr.theta, direction_sphr.phi  # deg (converted with deg2rad below)

        table = self.model.table
        it0, it1, rt0, rt1 = self._bracket(theta, table.theta, periodic=False)
        ip0, ip1, rp0, rp1 = self._bracket(phi, table.phi, periodic=True)

        # Frequencies of the rfft bins of the field. rfftfreq on the sample count
        # is the grid matching numpy.fft.rfft output (fftfreq on the number of rfft
        # bins gives the wrong spacing and negative upper bins).
        # NOTE(review): assumes Efield.t and table.frequency use consistent units
        # (t in s -> Hz); confirm.
        dt = Efield.t[1] - Efield.t[0]
        freqs = numpy.fft.rfftfreq(numpy.size(Efield.E.x), dt)

        def interp(values):
            # Weighted sum of the four (phi, theta) neighbours, then linear in frequency.
            fp = (
                rp0 * rt0 * values[:, ip0, it0]
                + rp1 * rt0 * values[:, ip1, it0]
                + rp0 * rt1 * values[:, ip0, it1]
                + rp1 * rt1 * values[:, ip1, it1]
            )
            return numpy.interp(freqs, table.frequency, fp, left=0, right=0)

        ltr = interp(table.leff_theta)  # magnitude, m
        lta = interp(numpy.deg2rad(table.phase_theta))  # phase, rad
        lpr = interp(table.leff_phi)  # magnitude, m
        lpa = interp(numpy.deg2rad(table.phase_phi))  # phase, rad

        lt = ltr * numpy.exp(1j * lta)
        lp = lpr * numpy.exp(1j * lpa)

        # Spherical (theta, phi) components --> Cartesian components in the antenna frame.
        t, p = numpy.deg2rad(theta), numpy.deg2rad(phi)
        ct, st = numpy.cos(t), numpy.sin(t)
        cp, sp = numpy.cos(p), numpy.sin(p)
        lx = lt * ct * cp - sp * lp
        ly = lt * ct * sp + cp * lp
        lz = -st * lt

        # Rotate Leff as a vector: antenna frame --> ECEF --> shower frame.
        Leff = CartesianRepresentation(x=lx, y=ly, z=lz)
        Leff = numpy.matmul(self.frame.basis.T, Leff)  # antenna --> ECEF
        Leff = numpy.matmul(frame.basis, Leff)  # ECEF --> shower

        return CartesianRepresentation(x=Leff.x, y=Leff.y, z=Leff.z)

    def compute_voltage(
        self, xmax: LTP, Efield: ElectricField, frame: Union[ECEF, LTP, None] = None
    ) -> Voltage:
        """Voltage from projecting the field spectrum onto the effective length.

        Leff (from ``effective_length``) and ``Efield.E`` are both taken in the
        shower *frame*. The product of their spectra is transformed back to the
        time domain with the original number of samples.

        Raises:
            Whatever ``effective_length`` raises (``TypeError``, ``MissingFrameError``).
        """
        Leff = self.effective_length(xmax, Efield, frame)
        E = Efield.E
        n_samples = numpy.size(E.x)
        Ex, Ey, Ez = (numpy.fft.rfft(component) for component in (E.x, E.y, E.z))

        # Here we have to do an ugly patch for Leff values to be correct: the first
        # (zero-frequency) bin of each component is subtracted.
        # NOTE(review): rationale not visible here; confirm it is still needed.
        spectrum = (
            Ex * (Leff.x - Leff.x[0])
            + Ey * (Leff.y - Leff.y[0])
            + Ez * (Leff.z - Leff.z[0])
        )
        # Pass n so odd-length signals keep all their samples.
        V = numpy.fft.irfft(spectrum, n_samples)

        t = Efield.t[: V.size]
        return Voltage(t=t, V=V)

ImportError: cannot import name 'TabulatedAntennaModel' from 'grand.simu' (/home/projet/grand_wk/dc1/grand/grand/simu/__init__.py)