In [None]:
class TikhPGSolver:
    """Tikhonov + projected-gradient (PG) fit of a non-negative G(tau) to impedance spectra.

    Several spectra that share one frequency grid are processed together.
    Experimental arrays are laid out as ``(nfreq, nspec)`` (one column per
    spectrum), and every solution array ``gtau`` follows the same convention.

    Parameters
    ----------
    zexp_re, zexp_im : array, shape (nfreq, nspec)
        Real and imaginary parts of the measured spectra, one column per spectrum.
    omg : array, shape (nfreq,)
        Frequencies shared by all spectra. ``driver`` expects them in
        descending order. ``tau = 1 / omg``.
    lamT0, lampg0 : float
        Initial values of the two parameters ``(lamT, lampg)`` fitted in ``driver``.
    fname : str
        Prefix for plot file names.
    lb, ub :
        Stored as attributes; not used by the code in this excerpt.
    mode : str
        ``'real'`` fits the real part; any other value fits the imaginary part.
    """

    def __init__(self, zexp_re, zexp_im, omg, lamT0, lampg0, fname, lb, ub, mode='real'):
        # Span of Re(Z) between the last and first grid point, one value per
        # spectrum (shape (nspec,)). It is used to normalize each column.
        self.rpol = zexp_re[-1, :] - zexp_re[0, :]
        self.zexp_re = zexp_re
        self.zexp_im = zexp_im
        # Broadcasting over the last axis divides each column by its own rpol.
        self.zexp_re_norm = zexp_re / self.rpol
        self.zexp_im_norm = zexp_im / self.rpol
        self.omg = omg
        self.mode = mode
        self.lamT0 = lamT0
        self.lampg0 = lampg0
        self.niter = 80      # the PG iteration budget is niter * 1000
        self.flagiter = 0    # set to 1 by driver() when PG exhausts that budget
        self.fname = fname
        self.lb = lb
        self.ub = ub

        # The tau grid is shared by all spectra, so keep it 1-D (nfreq,).
        # A (nfreq, 1) column here made the kernel built by _am_row 3-D.
        self.tau = 1 / self.omg
        self.lntau = jnp.log(self.tau)
        # create_dmesh is not part of this excerpt. Flatten its result so that
        # dlntau is always a 1-D weight vector over the tau nodes.
        self.dlntau = jnp.ravel(self.create_dmesh(self.lntau))
        self.dtau = self.create_dmesh(self.tau)
        # Float identity: it is only ever scaled by a (float) lambda.
        self.Idm = jnp.identity(self.omg.size, dtype=jnp.float64)
        self.CreateTikhMatrix()
        self.jacobian = jax.jacobian(self.Tikh_residual)
        self.gfun_init = self.Tikh_solver(self.lamT0, self.amTam, self.brs, self.Idm)
        self.fsuffix = 'zre' if mode == 'real' else 'zim'

    def _am_row(self, omg):
        """Return the kernel row for the scalar frequency ``omg`` (one entry per tau node)."""
        prod = omg * self.tau
        if self.mode == 'real':
            return self.dlntau / (1 + prod**2)
        return prod * self.dlntau / (1 + prod**2)

    def CreateTikhMatrix(self):
        """Build the kernel and the normal-equation pieces shared by all spectra.

        Sets ``am`` (rows: frequencies, columns: tau nodes), ``amT``,
        ``amTam = amT @ am``, ``amTikh = amTam + lamT0 * Idm`` and
        ``brs = amT @ z`` with shape (nfreq, nspec). Here ``z`` is the
        normalized real or imaginary data, selected by ``mode``.
        """
        self.am = jax.vmap(self._am_row)(self.omg)
        self.amT = self.am.transpose()
        self.amTam = jnp.matmul(self.amT, self.am)
        self.amTikh = self.amTam + self.lamT0 * self.Idm

        zexp_norm = self.zexp_re_norm if self.mode == 'real' else self.zexp_im_norm
        # zexp_*_norm is already (nfreq, nspec). The previous `.T` only matched
        # shapes by accident when nspec == nfreq.
        self.brs = jnp.matmul(self.amT, zexp_norm)

    def Tikh_solver(self, lamt, amTam, brs, Idm):
        """Solve ``(amTam + lamt * Idm) g = brs`` in the least-squares sense.

        ``brs`` may be 1-D or (n, nspec). ``lstsq`` solves all columns at once,
        so the result has the same shape as ``brs``.
        """
        amTikh = amTam + lamt * Idm
        return jnp.linalg.lstsq(amTikh, brs)[0]

    def objective_fun(self, gtau, amTikh, brs):
        """Return the scalar PG objective ``sum((amTikh @ gtau - brs) ** 2)``.

        jaxopt needs a scalar objective. Spectra (columns) are independent, so
        summing over them keeps each column's minimizer unchanged.
        """
        residuals = jnp.matmul(amTikh, gtau) - brs
        return jnp.sum(residuals ** 2)

    def pg_solver(self, lamvec, amTikh, amTam, brs, dlntau, Idm):
        """Non-negative solution of the Tikhonov system for ``lamvec = (lamT, lampg)``.

        The ``amTikh`` argument is not used; the matrix is rebuilt from ``lamT``.
        ``lampg`` (``lamvec[1]``) is currently not used either: no step size is
        passed to jaxopt's ProjectedGradient.

        Returns
        -------
        params : array, same shape as ``brs`` — the non-negative solution
        rpoly : ``dlntau @ params`` — integral of each solution column over ln(tau)
        iter_num : number of PG iterations performed
        """
        lamT, _lampg = lamvec
        # Unconstrained Tikhonov solution as the PG starting point. With
        # implicit_diff=True the solution's derivative does not depend on the
        # initial guess, so stop_gradient avoids differentiating through lstsq's SVD.
        gtau = jax.lax.stop_gradient(self.Tikh_solver(lamT, amTam, brs, Idm))
        amTikh = amTam + lamT * Idm

        pg = jaxopt.ProjectedGradient(fun=jax.jit(self.objective_fun),
                                      projection=jaxopt.projection.projection_non_negative, tol=1e-8,
                                      maxiter=self.niter * 1000, implicit_diff=True, jit=True)
        solution = pg.run(init_params=gtau, amTikh=amTikh, brs=brs)

        # Quadrature over tau: scalar for 1-D params, (nspec,) for 2-D params.
        rpoly = jnp.matmul(jnp.ravel(dlntau), solution.params)

        return solution.params, rpoly, solution.state.iter_num

    # ... Rest of the class
    # (elided in this excerpt; create_dmesh, called in __init__, is presumably defined there)

    def jacoby(self, pvec, amTikh, amTam, brs, dlntau, Idm):
        """Jacobian of ``Tikh_residual`` with respect to ``pvec`` (used as ``jac`` by least_squares)."""
        return jax.jacobian(self.Tikh_residual, argnums=0)(jnp.array(pvec), amTikh, amTam, brs, dlntau, Idm)

    def Tikh_residual(self, lamvec, amTikh, amTam, brs, dlntau, Idm):
        """Residual ``amTikh @ g(lamvec) - brs`` flattened to 1-D.

        ``g(lamvec)`` is the PG solution. scipy.optimize.least_squares requires
        a 1-D residual, so the (n, nspec) array is raveled.
        """
        gfvec, _, _ = self.pg_solver(lamvec, amTikh, amTam, brs, dlntau, Idm)
        return jnp.ravel(jnp.matmul(amTikh, gfvec) - brs)

    def Tikh_residual_norm(self, gtau, lamT, amTam, brs, Idm):
        """Per-spectrum norms ``||amTikh @ gtau - brs||`` and ``||amTikh @ gtau||``.

        Here ``amTikh = amTam + lamT * Idm``; the norms are taken over axis 0.
        """
        amTikh = amTam + lamT * Idm
        work = jnp.matmul(amTikh, gtau)
        sumres = jnp.linalg.norm(work - brs, axis=0)
        sumlhs = jnp.linalg.norm(work, axis=0)
        return sumres, sumlhs

    def residual_norm(self, gtau):
        """Per-spectrum norm ``||amTam @ gtau - brs||`` (taken over axis 0)."""
        work = jnp.matmul(self.amTam, gtau)
        return jnp.linalg.norm(work - self.brs, axis=0)

    def Zmodel_imre(self, gtau, spectrum=None):
        """Model impedance (real or imaginary part, depending on ``mode``) for ``gtau``.

        Parameters
        ----------
        gtau : array, shape (nfreq,) or (nfreq, nspec)
            G(tau) on the tau grid; a 2-D array holds one spectrum per column.
        spectrum : int, optional
            Index into ``rpol`` used to de-normalize a 1-D ``gtau``. It is not
            needed for 2-D input or when there is a single spectrum.

        Returns
        -------
        Array with the shape of ``gtau``, scaled by ``rpol`` and with the
        frequency axis reversed (presumably to match the plotting order — confirm).

        Raises
        ------
        ValueError
            If ``gtau`` is 1-D, several spectra exist and ``spectrum`` is None.
        """
        gtau = jnp.asarray(gtau)
        if spectrum is not None:
            scale = self.rpol[spectrum]
        elif gtau.ndim == 1 and jnp.size(self.rpol) != 1:
            raise ValueError('Zmodel_imre: pass `spectrum` for a 1-D gtau when several spectra are fitted')
        else:
            scale = self.rpol

        prod = self.omg[:, None] * self.tau[None, :]  # (nfreq, ntau): omega_i * tau_j
        if self.mode == 'real':
            kernel = 1 / (1 + prod ** 2)
        else:
            kernel = prod / (1 + prod ** 2)
        zmod = jnp.matmul(kernel * self.dlntau, gtau)  # weighted sum over tau nodes
        return jnp.flip(scale * zmod, axis=0)

    def rpol_peaks(self, gtau, height=0.01):
        """Locate the peaks of G(tau) in every spectrum and integrate each peak.

        Parameters
        ----------
        gtau : array, shape (nfreq,) or (nfreq, nspec)
        height : float
            Minimum peak height passed to ``scipy.signal.find_peaks``.

        Returns
        -------
        Array of shape (2, nspec, npeaks_max):
        ``[0]`` holds the peak frequencies ``1 / (2*pi*tau_peak)``;
        ``[1]`` holds the sum of ``gtau * dlntau`` between the peak's width
        bounds (``peak_widths`` with ``rel_height=1``).
        Both are in reversed tau order. Spectra with fewer peaks are padded with NaN.
        """
        g_all = jnp.asarray(gtau)
        if g_all.ndim == 1:
            g_all = g_all[:, None]

        per_spectrum = []
        for k in range(g_all.shape[1]):
            g = g_all[:, k]
            # scipy works on concrete arrays, so this cannot be vmapped/jitted.
            peaks, _ = find_peaks(g, height)
            _, _, left_ips, right_ips = peak_widths(g, peaks, rel_height=1)
            areas = [jnp.sum(g[int(lo):int(hi)] * self.dlntau[int(lo):int(hi)])
                     for lo, hi in zip(left_ips, right_ips)]
            freqs = 1 / (2 * jnp.pi * self.tau[peaks])
            per_spectrum.append((jnp.flip(freqs), jnp.flip(jnp.asarray(areas))))

        npeaks = max((freqs.size for freqs, _ in per_spectrum), default=0)
        pparms = jnp.full((2, len(per_spectrum), npeaks), jnp.nan, dtype=jnp.float64)
        for k, (freqs, areas) in enumerate(per_spectrum):
            pparms = pparms.at[0, k, :freqs.size].set(freqs)   # peak frequencies
            pparms = pparms.at[1, k, :areas.size].set(areas)   # peak polarization fractions
        return pparms

    def find_lambda(self):
        """Scan lamT over 25 decades (about 1e-24 ... 1) for L-curve style plots.

        Returns
        -------
        resid : (kmax, nspec) — ``residual_norm`` of each Tikhonov solution
        solnorm : (kmax, nspec) — ``||g||`` of each Tikhonov solution
        lamT : (kmax,) — the scanned Tikhonov parameters
        lampg : (kmax,) — ``1 / ||amTam + lamT * Idm||_F`` for each scanned lamT
        """
        kmax, lam1 = 25, 1e-25
        nspec = self.brs.shape[1]
        solnorm = jnp.zeros((kmax, nspec), dtype=jnp.float64)
        resid = jnp.zeros((kmax, nspec), dtype=jnp.float64)
        lamT = jnp.zeros(kmax, dtype=jnp.float64)
        lampg = jnp.zeros(kmax, dtype=jnp.float64)
        for k in range(kmax):
            lam1 = lam1 * 10
            lamT = lamT.at[k].set(lam1)
            gfun = self.Tikh_solver(lam1, self.amTam, self.brs, self.Idm)
            resid = resid.at[k, :].set(self.residual_norm(gfun))
            solnorm = solnorm.at[k, :].set(jnp.linalg.norm(gfun, axis=0))
            # Frobenius norm (a scalar) of this lam1's Tikhonov matrix.
            # self.amTikh is fixed at lamT0, and norm(axis=0) gave a vector that
            # cannot be stored in lampg[k].
            amTikh = self.amTam + lam1 * self.Idm
            lampg = lampg.at[k].set(1 / jnp.linalg.norm(amTikh))
        return resid, solnorm, lamT, lampg

    def driver(self, lsq):   # --- omega must be in descending order!
        """Run the complete fit, printing and plotting results for every spectrum.

        Parameters
        ----------
        lsq : str
            ``'lm'`` fits ``(lamT, lampg)`` with unbounded Levenberg-Marquardt.
            Any other value (e.g. ``'trf'``) uses Trust Region Reflective with
            bounds ``[init / 10, init * 10]``.
        """
        myplots = Plotter(self.zexp_re, self.zexp_im, self.omg, self.mode)
        nspec = self.brs.shape[1]

        resid, solnorm, arrlamT, arrlampg = self.find_lambda()
        for i in range(nspec):
            myplots.plotLambda(resid[:, i], solnorm[:, i], arrlamT,
                               self.fname + f'_lambda_T_spectrum_{i}', r'$\lambda_T^0$', 0)
            myplots.plotLambda(resid[:, i], solnorm[:, i], arrlampg,
                               self.fname + f'_lambda_PG_spectrum_{i}', r'$\lambda_{pg}^0$', 0)

        for i in range(nspec):
            myplots.plotNyq(self.zexp_re[:, i] - 1j * self.zexp_im[:, i], f'Initial spectrum_{i}')

        for i in range(self.gfun_init.shape[1]):
            myplots.plotgamma(self.gfun_init[:, i], self.fname + f'_init_spectrum_{i}', 0, 'Tikhonov gamma')
            zmod = self.Zmodel_imre(self.gfun_init[:, i], spectrum=i)
            myplots.plotshow_Z(zmod, self.fname + f'_init_spectrum_{i}' + self.fsuffix, 0, 'Tikhonov solution')

        start = time.time()
        lamvecinit = jnp.array([self.lamT0, self.lampg0], dtype=jnp.float64)
        low, high = lamvecinit / 10, lamvecinit * 10
        residual = jax.jit(self.Tikh_residual)
        fit_args = (self.amTikh, self.amTam, self.brs, self.dlntau, self.Idm)

        # Previously lsq == 'trf' ran method='lm' (and vice versa).
        if lsq == 'lm':
            resparm = least_squares(residual, lamvecinit,
                                    method='lm', jac=self.jacoby, args=fit_args)
        else:
            resparm = least_squares(residual, lamvecinit, bounds=(low, high),
                                    method='trf', jac=self.jacoby, args=fit_args)

        res = resparm.x

        print(f"resparm.x = {res}")

        gfun, rpoly, nit = self.pg_solver(res, *fit_args)
        end = time.time()
        # Remember whether PG stopped because it ran out of iterations.
        self.flagiter = 1 if int(nit) >= self.niter * 1000 else 0
        print('Projected gradient iterations =', nit, ', rpol =', rpoly,
              ', Rpol = ', self.rpol)
        print('lamTfit, lampgfit =', res, ', elapsed: ', (end - start), ' sec')

        resinit, lhsinit = self.Tikh_residual_norm(self.gfun_init, self.lamT0, self.amTam, self.brs, self.Idm)
        resfin, lhsfin = self.Tikh_residual_norm(gfun, res[0], self.amTam, self.brs, self.Idm)
        print('Tikhonov residual: initial, final = ', resinit, resfin)
        print('Tikhonov lhs norm: initial, final =', lhsinit, lhsfin)
        if resparm.status > 0:
            print('Number of Jacobian evaluations =', resparm.njev, ', status = OK')
        if self.flagiter == 1:
            print('Warning, limiting number of iterations is achieved')

        for i in range(gfun.shape[1]):
            myplots.plotGfun(gfun[:, i], self.fname + f'Gfun_spectrum_{i}', 1, 'Final G-function')
            zmod = self.Zmodel_imre(gfun[:, i], spectrum=i)
            myplots.plotshow_Z(zmod, self.fname + f'spectrum_{i}' + self.fsuffix, 1, '')

        peakparms = self.rpol_peaks(gfun)

        for i in range(peakparms.shape[1]):  # axis 1 indexes spectra
            found = ~jnp.isnan(peakparms[0, i, :])  # drop NaN padding
            print(f'For spectrum_{i}:')
            print('Peak frequencies (beta):   ',
                  ''.join(['{:.5f}  '.format(float(item)) for item in peakparms[0, i, found]]))
            print('Peak polarizations (beta): ',
                  ''.join(['{:.5f}  '.format(float(item)) for item in peakparms[1, i, found]]))


