In [1]:
import os
os.chdir('..')
%load_ext autoreload
%autoreload 2

In [2]:
from comm_agents.data.reference_experiments import RefExperimentMass, RefExperimentCharge
from comm_agents.data.optimal_answers import get_alpha_star, get_phi_star
import numpy as np

In [3]:
# initial experiment parameters
GOLF_HOLE_LOC_M = .1
GOLF_HOLE_LOC_C = .1
TOLERANCE = .1
PARAM_DICT = dict(
    m=[1e-20, 1e-20],
    q=[1e-16, -1e-15],
    m_ref_m=2e-20,
    v_ref_m=2,
    m_ref_c=2e-20,
    v_ref_c=1,
    N=10,
    alpha=[0, 0],
    phi=[0, 0],
    dt=.01,
    d=.1,
    gravity=True)
M_RANGES = [1e-20, 5e-20]
Q0_RANGE = [1e-16, 2e-16]
Q1_RANGE = [-1e-15, -2e-15]
SAMPLE_SIZE = 10

In [4]:
# ranges for experiments
np.random.seed(124)
m0_ls = np.random.uniform(*M_RANGES, SAMPLE_SIZE)
m1_ls = np.random.uniform(*M_RANGES, SAMPLE_SIZE)
q0_ls = np.random.uniform(*Q0_RANGE, SAMPLE_SIZE)
q1_ls = np.random.uniform(*Q1_RANGE, SAMPLE_SIZE)
m_ls = list(zip(m0_ls, m1_ls))
q_ls = list(zip(q0_ls, q1_ls))
param_dict = PARAM_DICT

In [28]:
param_dict.update(m=m_ls[2], q=q_ls[2])
def get_observations():
    rem_obs = RefExperimentMass(**param_dict)
    rem_obs.run()
    return rem_obs.x_series[:, 0], rem_obs.x_series[:, 1]

x0, x1 = get_observations()
x0

array([0.        , 0.01512501, 0.03025001, 0.04537502, 0.06050002,
       0.07562503, 0.09075004, 0.10587504, 0.12100005, 0.13612505])

In [29]:
def get_optimal_alpha():
    rem_opt = RefExperimentMass(**param_dict)
    rem_opt.N = 100
    rem_opt.dt = .001
    alpha_star, loss = get_alpha_star(rem_opt)
    rem_opt.set_initial_state()
    rem_opt.angle = alpha_star
    rem_opt.run()
    return alpha_star, loss, rem_opt.check_for_hole_in_one(), rem_opt

In [30]:
alpha_star, loss, success, rem_opt = get_optimal_alpha()

In [27]:
rem_opt.visualize(GOLF_HOLE_LOC_M)

In [31]:
rem_opt.visualize(GOLF_HOLE_LOC_M)

In [46]:
PARAM_DICT

{'m': (1.4242596238189308e-20, 3.396376655239924e-20),
 'q': (1.6226334141694085e-16, -1.2309561203855734e-15),
 'm_ref_m': 2e-20,
 'v_ref_m': 2,
 'm_ref_c': 2e-20,
 'v_ref_c': 1,
 'N': 100,
 'alpha': [0, 0],
 'phi': [0, 0],
 'dt': 0.01,
 'd': 0.1,
 'gravity': True}