In [None]:
# Plotting (figure/plot/text/savefig) and LaTeX-string labels (L"...").
using PyPlot
using Colors
using LaTeXStrings
# NOTE(review): Colors, JLD, Dates and LA are not used in the cells shown here;
# presumably they are used elsewhere in the notebook — confirm before removing.
using JLD
import Dates

using Random
# Global RNG; reseeded in the problem-setup cell and used there to draw P.
rng = MersenneTwister(1234)
import LinearAlgebra as LA

In [None]:
"""
    altGDA_trapezoidal(P, Q, R, eta, T; xinit=nothing, yinit=nothing, rng=MersenneTwister(1234))

Run `T` steps of alternating gradient descent-ascent with step size `eta`:

    x ← x - eta*Q*x - eta*P*y
    y ← y - eta*R*y + eta*P'*x      (uses the freshly updated x)

Presumably this is GDA on f(x, y) = ½xᵀQx + xᵀPy - ½yᵀRy for symmetric `Q`, `R`
(TODO confirm against the accompanying write-up).

`P` is `dimx × dimy`, `Q` is `dimx × dimx` and `R` is `dimy × dimy`; the
dimensions are taken from `size(P)`.

# Keywords
- `xinit`, `yinit`: starting points. When `nothing`, they are drawn with
  `randn(rng, dimx)` / `randn(rng, dimy)`.
- `rng`: generator used only to draw the default starting points.

# Returns
`(copies_x, copies_y)`, `Float64` matrices of size `(dimx, T+1)` and
`(dimy, T+1)`. Column `t` of `copies_x` is the iterate before step `t`;
column `T+1` is the final iterate.

`copies_y` is post-processed with a trapezoidal average: for `t = 1:T`,
column `t` holds `(y_t + y_{t+1}) / 2`, and column `T+1` is the raw final `y`.
"""
function altGDA_trapezoidal(P, Q, R, eta, T; xinit=nothing, yinit=nothing,
                            rng=MersenneTwister(1234))
    # Derive dimensions from the coupling matrix instead of relying on globals.
    dimx, dimy = size(P)
    x = isnothing(xinit) ? randn(rng, dimx) : xinit
    y = isnothing(yinit) ? randn(rng, dimy) : yinit

    copies_x = Array{Float64}(undef, dimx, T + 1)
    copies_y = Array{Float64}(undef, dimy, T + 1)
    for t in 1:T
        copies_x[:, t] = x
        copies_y[:, t] = y

        # Matrix-vector products first, so no scaled copy of Q/P/R is allocated per step.
        x = x .- eta .* (Q * x) .- eta .* (P * y)
        # Alternating: the y-step sees the x just computed above.
        y = y .- eta .* (R * y) .+ eta .* (P' * x)
    end
    copies_x[:, T + 1] = x
    copies_y[:, T + 1] = y

    # Trapezoidal averaging of consecutive y iterates. Going forward in t is safe:
    # column t+1 has not been overwritten yet when column t is updated.
    for t in 1:T
        @views copies_y[:, t] .= (copies_y[:, t] .+ copies_y[:, t + 1]) ./ 2
    end
    return copies_x, copies_y
end


### Illustrative setting with dimx = dimy = 2

In [None]:
# Reseed the shared generator so that the random coupling matrix P is
# reproducible (the seed was picked by hand to give a clear illustration).
Random.seed!(rng, 1243)

# Dimensions of x and y, plus rank settings.
dimx = 2
dimy = 2
rankx = 1
ranky = 1

alph = 0.4
# alph = 0

# Q only acts on the first x-coordinate; R is identically zero.
Q = [1 0; 0 0]
R = zeros(Int, 2, 2)

# Random dimx × dimy coupling matrix, rounded to two decimals for readability.
P = round.(randn(rng, dimx, dimy); digits=2)

In [None]:
T0 = 10000       # must tweak this!  (T0 = T * eta)
eta = 1e-2       # and also this!    (step size)
# Number of iterations. `round` instead of `Int(...)` avoids an InexactError when
# T0/eta is not exactly integral in floating point (e.g. eta = 0.03).
T = round(Int, T0 / eta)

# Plot only the second half of the run. `÷` stays valid for odd T, where
# `Int(T/2)` would throw an InexactError.
T_min = T ÷ 2
# T_min = 1
T_max = T
# plottrajevery = max(1, Int(floor((T_max-T_min)/5000)))
plottrajevery = 1

# Plot settings that do not depend on alpha; computed once.
cm = get_cmap(:tab20)
colorrange = (0:19) ./ 20
plotrange = T_min:plottrajevery:T_max
props = Dict("alpha"=>0.5, "facecolor"=>"white")

for (i, alph) in enumerate([0.01, 0.04, 0.1, 0.4])
    # Same start x = y = [1, 1] for every alpha; only Q is rescaled.
    copies_x, copies_y = altGDA_trapezoidal(P, alph*Q, R, eta, T; xinit=[1, 1], yinit=[1, 1])

    figure(figsize=[3, 3])
    # Phase-plane trajectories (x_1, y_1) and (x_2, y_2); legend labels are only
    # attached on the first figure, which is the only one that draws a legend.
    plot(copies_x[1, plotrange], copies_y[1, plotrange], lw=1,
        color=cm(colorrange[1]),
        label=(i == 1 ? L"$[x_1^k, y_1^k]$" : ""), zorder=10)
    plot(copies_x[2, plotrange], copies_y[2, plotrange], lw=1,
        color=cm(colorrange[3]),
        label=(i == 1 ? L"$[x_2^k, y_2^k]$" : ""))
    # Mark the first plotted point (column T_min) of each trajectory.
    plot([copies_x[1, T_min]], [copies_y[1, T_min]], "o", markersize=5,
        color=cm(colorrange[1]))
    plot([copies_x[2, T_min]], [copies_y[2, T_min]], "o", markersize=5,
        color=cm(colorrange[3]))

    # plot([copies_x[2, T_max]], [copies_y[2, T_max]], "s", markersize=8,
    #     color=cm(colorrange[3]))
    # plot([copies_x[1, T_max]], [copies_y[1, T_max]], "s", markersize=8,
    #     color=cm(colorrange[1]))

    text(0.025, 0.095, L"\alpha=%$(alph)", transform=gca().transAxes,
        fontsize=12,
        verticalalignment="top", bbox=props)

    if i == 1
        legend(fontsize=12)
    end
    grid("on")
    xticks([0], [])
    yticks([0], [])
    # savefig("trajectory__alpha$(alph)T$(T0)eta$(eta)__Tmin$(T_min)max$(T_max)every$(plottrajevery).png", bbox_inches="tight", dpi=200)
    savefig("trajectory__alpha$(alph)T$(T0)eta$(eta).png", bbox_inches="tight", dpi=200)
end