In [None]:
# Prints a fixed marker string; presumably a quick check that the kernel is executing cells.
print("Running")

In [None]:
import gym
import glob
import torch
import matplotlib.pyplot as plt

In [None]:
def deg2rad(x):
    """Convert an angle from degrees to radians.

    The truncated constant 3.1415926536 and the multiply-then-divide order
    are kept exactly as they were, so numeric results are unchanged.
    Works on any value that supports ``*`` and ``/`` with floats.
    """
    return x * 3.1415926536 / 180.


def rad2deg(x):
    """Convert an angle from radians to degrees (inverse of ``deg2rad``).

    Uses the same truncated constant 3.1415926536 as ``deg2rad``.
    """
    return x * 180. / 3.1415926536

In [None]:
# Create the 'swirl-v1' environment; the 'gym_swirl:' prefix names the package
# that provides it (presumably imported by gym.make so the id gets registered).
env = gym.make('gym_swirl:swirl-v1')
# Fixed seed (42), presumably so the environment's randomness is reproducible.
# NOTE(review): `Env.seed` no longer exists in newer gym releases (seeding moved
# to `reset(seed=...)`) — confirm the gym version this notebook is pinned to.
env.seed(42)

In [None]:
def getOR(states, start=1):
    """Collect the mean of ``O_R`` for every state from index ``start`` onward.

    Args:
        states: Sequence of state objects, each exposing an ``O_R`` attribute
            that has a ``.mean()`` method.
        start: Index of the first state to include (default 1, i.e. the
            first state is skipped).

    Returns:
        list: One ``O_R.mean()`` value per selected state, in order.
    """
    means = []
    for snapshot in states[start:]:
        means.append(snapshot.O_R.mean())
    return means

def getDelta(states, start=1):
    """Collect the mean of ``Deltas`` per state, converted with ``rad2deg``.

    Args:
        states: Sequence of state objects, each exposing a ``Deltas``
            attribute that has a ``.mean()`` method.
        start: Index of the first state to include (default 1, i.e. the
            first state is skipped).

    Returns:
        list: ``rad2deg(state.Deltas.mean())`` for each selected state, in order.
    """
    degrees = []
    for snapshot in states[start:]:
        mean_delta = snapshot.Deltas.mean()
        degrees.append(rad2deg(mean_delta))
    return degrees

def getTimes(states, start=1):
    """Collect the ``T`` attribute of every state from index ``start`` onward.

    Args:
        states: Sequence of state objects, each exposing a ``T`` attribute.
        start: Index of the first state to include (default 1, i.e. the
            first state is skipped).

    Returns:
        list: The ``T`` value of each selected state, in order.
    """
    timestamps = []
    for snapshot in states[start:]:
        timestamps.append(snapshot.T)
    return timestamps

def _epoch_from_path(path):
    """Parse the epoch number embedded in a run path.

    Takes the text after the last ``"epoch"`` and reads the integer that ends
    at the next underscore, e.g. ``"runs/acdd012_epoch37_2021-..."`` -> ``37``.
    Splitting on ``"_"`` instead of the literal ``"_2021"`` keeps the parse
    working for runs whose suffix does not start with 2021.

    Raises:
        ValueError: If the text after ``"epoch"`` does not start with an integer.
    """
    return int(path.split("epoch").pop().split("_", 1)[0])


def _final_quarter_means(paths, stride, getter):
    """Load every ``stride``-th run and average ``getter`` over its last quarter.

    Each path is loaded into the notebook-level ``env`` (so ``env`` is left
    holding the last loaded run). ``getter`` is called as
    ``getter(env.states, start=...)`` with ``start`` at three quarters of the
    state count, and its values are averaged with ``torch``.

    Returns:
        tuple: ``(epochs, means)`` — parallel lists of epoch numbers (int) and
        0-d tensors holding the per-run mean.
    """
    epochs = []
    means = []
    for path in paths[::stride]:
        env.load(path)
        tail_start = len(env.states) * 3 // 4  # first index of the final quarter
        tail_values = getter(env.states, start=tail_start)
        means.append(torch.tensor(tail_values).mean())
        epochs.append(_epoch_from_path(path))
    return epochs, means


def getEndOR(paths, stride=1):
    """Absolute mean ``O_R`` over the final quarter of each selected run.

    Args:
        paths: Run paths accepted by ``env.load``; each must contain
            ``"epoch<N>"`` followed by ``"_"`` or the end of the string.
        stride: Use every ``stride``-th path (default 1 = all).

    Returns:
        tuple: ``(epochs, finalORs)`` — epoch numbers and, per run, the
        absolute value of the mean of ``getOR`` over the last quarter of states.
    """
    epochs, means = _final_quarter_means(paths, stride, getOR)
    # The sign is dropped after averaging, not per state.
    finalORs = [mean.abs() for mean in means]
    return epochs, finalORs


def getEndDeltas(paths, stride=1):
    """Mean ``Deltas`` (as returned by ``getDelta``) over the final quarter of each run.

    Args:
        paths: Run paths accepted by ``env.load``; each must contain
            ``"epoch<N>"`` followed by ``"_"`` or the end of the string.
        stride: Use every ``stride``-th path (default 1 = all).

    Returns:
        tuple: ``(epochs, finalDeltas)`` — epoch numbers and, per run, the mean
        of ``getDelta`` over the last quarter of states.
    """
    return _final_quarter_means(paths, stride, getDelta)

In [None]:
attempt = 12  # run number; zero-padded to three digits in the glob pattern that selects runs
type_rl = "acdd"  # run-name prefix used in the same glob pattern; previously used alternative: "mfmaac"

In [None]:
# Collect the saved runs for the selected algorithm/attempt, ordered by epoch.
# The epoch is the integer between the last "epoch" and the next "_" in the
# path; splitting on "_" (rather than the literal "_2021") keeps the sort
# working for runs whose suffix does not start with 2021.
paths = sorted(
    glob.glob(f"runs/{type_rl}{attempt:03d}_*"),
    key=lambda run: int(run.split("epoch").pop().split("_", 1)[0]),
)[-50:]  # keep only the 50 highest-epoch runs
len(paths)

In [None]:
# (epochs, mean final Delta per run) for every selected run; each run is loaded via env.load in turn.
end_deltas = getEndDeltas(paths, stride=1)

In [None]:
# Average final Delta per epoch, drawn against the 76-degree reference line.
plt.rcParams.update({'font.size': 22})

delta_opt_deg = 76  # height of the dashed red reference line
delta_epochs, delta_means = end_deltas
x_limits = [-5, delta_epochs[-1] + 5]  # 5-epoch margin on both sides of the data

fig, ax = plt.subplots(figsize=(20, 10))
# Raw strings keep the LaTeX backslashes literal; the previous non-raw '\D'
# was an invalid escape sequence that newer Python versions warn about.
optimum_line, = ax.plot(x_limits, [delta_opt_deg] * 2, "--", c="red", label=r'$\Delta_{opt}$ = 76°')
learned_line, = ax.plot(delta_epochs, delta_means, c="black", label=r'Avg. final $\Delta$')
ax.set_xlabel("Epoch #")
ax.set_ylabel(r"Avg. final $\Delta$")
ax.set_xlim(x_limits)
ax.legend(handles=[optimum_line, learned_line])
#fig.savefig('avg_last_or_ac.pdf', bbox_inches='tight', pad_inches=0)

In [None]:
# Inspect the last entry of `paths` (the highest epoch, since `paths` is sorted by epoch).
path = paths[-1]

env.load(path)
# All three helpers use their default start=1, so the series have equal length.
times = getTimes(env.states)
# Mean Delta (converted by getDelta via rad2deg) and mean O_R over time, on the same axes.
plt.plot(times, getDelta(env.states))
plt.plot(times, getOR(env.states))

In [None]:
# |mean O_R| over the last quarter of each run versus epoch, using every 5th entry of `paths`.
plt.plot(*getEndOR(paths, stride=5))