In [None]:
using Comrade
using Pyehtim

In [None]:
using Random
# Explicit RNG handle; it is passed to the sampling calls further down (e.g. `prior_sample`).
rng = Random.default_rng()

In [None]:
# Example observation shipped in the Comrade repository. Every path component is given to
# `joinpath` separately, so the platform's separator is used for the whole path, and the
# examples directory is resolved only once.
obs = let exdir = joinpath(dirname(pathof(Comrade)), "..", "examples", "PolarizedExamples")
    Pyehtim.load_uvfits_and_array(
        joinpath(exdir, "polarized_gaussian_nogains_withdterms_withfr.uvfits"),
        joinpath(exdir, "array.txt"),
    )
end

In [None]:
# Scan-average the observation (presumably reduces the number of data points — see the
# Comrade `scan_average` docs).
obs = scan_average(obs)

In [None]:
# Extract the coherency-matrix data product that the polarized likelihood is fitted to.
dvis = extract_table(obs, Coherencies())

In [None]:
# Construct the sky model
"""
    sky(θ, metadata)

Build the polarized sky model for the parameter set `θ`.

`θ` must provide the fields `c`, `f`, `p` and `angparams`. `metadata` must provide
`K`, `grid` and `cache`.

The Stokes I image is `K(f*c)`. `PoincareSphere2Map` then converts this image plus the
Poincaré-sphere parameters `p` and `angparams` into a full Stokes map on `grid`. The
result is wrapped in a `ContinuousImage` that uses `cache`.
"""
function sky(θ, metadata)
    (; K, grid, cache) = metadata
    # Total-intensity image from the pixel parameters `c` scaled by the flux `f`.
    stokesI = K(θ.f * θ.c)
    # Poincaré-sphere parameterization of polarization -> Stokes parameters.
    stokesmap = PoincareSphere2Map(stokesI, θ.p, θ.angparams, grid)
    return ContinuousImage(stokesmap, cache)
end

In [None]:
"""
    instrument(θ, metadata)

Build the instrumental (RIME) corruption model for the parameter set `θ`.

`θ` must provide the following fields:
- D-term real/imaginary parts `dRx`, `dRy`, `dLx`, `dLy`;
- gain-product log-amplitude `lgp` and phase `gpp`;
- gain-ratio log-amplitude `lgr` and phase `gpr`.

`metadata` must provide the caches `tcache`, `scancache`, `phasecache`, `trackcache` and
`trackcache_ratio`.

The Jones terms are multiplied in the fixed order
gain-amp-product * gain-phase-product * gain-amp-ratio * gain-phase-ratio * D * basis.
`tcache` sets the reference basis of the returned `CorruptionModel`.
"""
function instrument(θ, metadata)
    (; dRx, dRy, dLx, dLy, lgp, gpp, lgr, gpr) = θ
    (; tcache, scancache, phasecache, trackcache, trackcache_ratio) = metadata

    # Gain products: the same factor is applied to both feeds.
    ampprod = exp.(lgp/2 .+ 0im)
    phaseprod = exp.(1im.*gpp/2)
    Gpa = jonesG(ampprod, ampprod, scancache)
    Gpp = jonesG(phaseprod, phaseprod, phasecache)

    # Gain ratios: the first feed gets g, the second gets 1/g (amplitude) or conj(g) (phase).
    ampratio = exp.(lgr/2)
    phaseratio = exp.(1im.*gpr/2)
    Gra = jonesG(ampratio, inv.(ampratio), trackcache)
    Grp = jonesG(phaseratio, conj.(phaseratio), trackcache_ratio)

    # Polarization leakage (D-terms), assembled from real and imaginary parts.
    D = jonesD(complex.(dRx, dRy), complex.(dLx, dLy), trackcache)

    # Basis transformation defined by `tcache`.
    jT = jonesT(tcache)

    J = Gpa*Gpp*Gra*Grp*D*jT
    return CorruptionModel(J, tcache)
end

In [None]:
# Square 50 μas field of view, stored in radians.
fovx = fovy = μas2rad(50.0)
# Pixel response used by the continuous image (B-spline of order 3).
pulse = BSplinePulse{3}()
# Pixel counts: `nx` is fixed and `ny` follows the field-of-view aspect ratio.
nx = 5
ny = floor(Int, fovy/fovx*nx)
grid = imagepixels(fovx, fovy, nx, ny)      # image-model pixel grid
buffer = IntensityMap(zeros(nx, ny), grid)  # zero-filled scratch image on `grid`
cache = create_cache(NFFTAlg(dvis), buffer, pulse); # NFFT transform cache for the data in `dvis`

In [None]:
using VLBIImagePriors # Load some special VLBI priors

In [None]:
# Image transform applied inside `sky` as `K(f*c)`. CenterImage presumably recentres the
# image on `grid` (see the VLBIImagePriors docs).
K = CenterImage(grid)
# Metadata bundle that `sky` destructures as `(; K, grid, cache)`.
skymeta = (; K, cache, grid);

In [None]:
# Basis-transformation cache; feed rotation is added (`add_fr=true`) without the ehtim
# feed-rotation convention.
tcache = TransformCache(dvis; add_fr=true, ehtim_fr_convention=false);

In [None]:
# Per-scan segmentation, used for the gain-amplitude product.
scancache = jonescache(dvis, ScanSeg());

In [None]:
# Per-scan segmentation with station AA held fixed at unity (AA acts as phase reference).
phase_segs = station_tuple(dvis, ScanSeg(); AA=FixedSeg(complex(1.0, 0.0)))
phasecache = jonescache(dvis, phase_segs);

In [None]:
# Per-track segmentation: plain, and with AA held fixed for the phase ratio.
trackcache = jonescache(dvis, TrackSeg());
trackcache_ratio = jonescache(
    dvis,
    station_tuple(dvis, TrackSeg(); AA=FixedSeg(complex(1.0, 0.0))),
);

# Metadata bundle destructured by `instrument`.
instrumentmeta = (tcache=tcache, scancache=scancache, phasecache=phasecache,
                  trackcache=trackcache, trackcache_ratio=trackcache_ratio);

In [None]:
using Distributions
using DistributionsAD
st = stations(dvis)
# Log-amplitude gain prior (used for `lgp`/`lgr`): zero-mean Normal with σ = 0.1.
distamp = station_tuple(st, Normal(0.0, 0.1))

In [None]:
using VLBIImagePriors
# Gain-phase prior. The second argument is used as a concentration κ = 1/σ² (assumed
# parameterization — confirm against the VLBIImagePriors docs); σ = π makes it very broad.
# `reference=:AA` presumably pins station AA as the phase reference.
distphase = station_tuple(st, DiagonalVonMises(0.0, inv(π^2)); reference=:AA)

In [None]:
# Gain phase-ratio prior, using the same κ = 1/σ² form with σ = 0.1 rad.
# (The former `inv(0.1^1)` had a no-op exponent and gave σ ≈ 0.32 rad, inconsistent with
# the squared form used above.)
distphase_ratio = station_tuple(st, DiagonalVonMises(0.0, inv(0.1^2)); reference=:AA)

In [None]:
# D-term prior: each real/imaginary component ~ Normal(0, 0.1).
distD = station_tuple(dvis, Normal(0.0, 0.1))

In [None]:
prior = let
    # Image parameters consumed by `sky` (Stokes I is built from f*c).
    skyprior = (
        c = ImageDirichlet(1.0, nx, ny),
        f = Uniform(0.7, 1.2),
        p = ImageUniform(nx, ny),
        angparams = ImageSphericalUniform(nx, ny),
    )
    # Instrument parameters consumed by `instrument`, each tied to its segmentation cache.
    calprior = (
        dRx = CalPrior(distD, trackcache),
        dRy = CalPrior(distD, trackcache),
        dLx = CalPrior(distD, trackcache),
        dLy = CalPrior(distD, trackcache),
        lgp = CalPrior(distamp, scancache),
        gpp = CalPrior(distphase, phasecache),
        lgr = CalPrior(distamp, trackcache),
        gpr = CalPrior(distphase_ratio, trackcache_ratio),
    )
    # `merge` keeps the field order: sky fields first, then calibration fields.
    merge(skyprior, calprior)
end

In [None]:
# Likelihood: data plus sky and instrument models, each with its metadata bundle.
lklhd = RadioLikelihood(sky, instrument, dvis;
                        skymeta=skymeta, instrumentmeta=instrumentmeta)
post = Posterior(lklhd, prior)

In [None]:
# Reparameterize the posterior onto a flat (unconstrained) space.
tpost = asflat(post)

In [None]:
ndim = dimension(tpost)

In [None]:
# Sanity check: evaluate the log density at a random point of the flat space.
logdensityof(tpost, randn(rng, ndim))

In [None]:
using ComradeOptimization
using OptimizationOptimJL
using Zygote
# Optimization objective built from the flattened posterior, differentiated with Zygote.
f = OptimizationFunction(tpost, Optimization.AutoZygote())
# Log-density functor; used only to log progress in the callback below.
ℓ = logdensityof(tpost)
# Start from a random draw of the flattened prior.
prob = Optimization.OptimizationProblem(f, prior_sample(rng, tpost), nothing)
# LBFGS run; the callback logs ℓ(x) each iteration and returns `false` so it never halts
# the optimizer.
sol = solve(
    prob, LBFGS();
    g_tol=1e-1,
    maxiters=15_000,
    callback=(x, p) -> begin
        @info ℓ(x)
        false
    end,
)

In [None]:
# Map the optimizer's solution from the flat space back to model parameters.
xopt = transform(tpost, sol)

In [None]:
using Plots
# Residuals of the optimized model against the data.
vlbimodel(post, xopt) |> (m -> residual(m, dvis))

In [None]:
using AxisKeys
# Ground-truth image shipped with the examples. Each path component is passed to
# `joinpath` separately so the platform's separator is used for the whole path.
imgtrue = Comrade.load(
    joinpath(dirname(pathof(Comrade)), "..", "examples", "PolarizedExamples",
             "polarized_gaussian.fits"),
    StokesIntensityMap,
)
# Crop the truth image to the modelled field of view.
imgtruesub = imgtrue(Interval(-fovx/2, fovx/2), Interval(-fovy/2, fovy/2))
# Axis limits presumably in μas, matching half the 50 μas field of view.
plot(imgtruesub, title="True Image", xlims=(-25.0,25.0), ylims=(-25.0,25.0))

In [None]:
# Render the optimized model onto a copy of the cropped truth grid.
img = intensitymap!(copy(imgtruesub), vlbimodel(post, xopt))
plot(img, title="Reconstructed Image", xlims=(-25.0,25.0), ylims=(-25.0,25.0))

In [None]:
using Comrade.ComradeBase: linearpol
# Total Stokes fluxes of the truth and reconstructed images.
ftrue = flux(imgtruesub);
frecon = flux(img);
# Fractional linear polarization |Q + iU| / I for each image.
for (label, fl) in (("true", ftrue), ("recon", frecon))
    @info "Linear polarization $label image: $(abs(linearpol(fl))/fl.I)"
end


In [None]:
# Fractional circular polarization V / I for each image.
for (label, fl) in (("true", ftrue), ("recon", frecon))
    @info "Circular polarization $label image: $(fl.V/fl.I)"
end


In [None]:
# Recovered D-terms of the first feed (dRx + i·dRy), tabulated on the track segmentation.
dR = caltable(
    trackcache,
    complex.(xopt.dRx, xopt.dRy),
)


In [None]:
# Recovered D-terms of the second feed (dLx + i·dLy), tabulated on the track segmentation.
dL = caltable(
    trackcache,
    complex.(xopt.dLx, xopt.dLy),
)
