# Data exploration

In this notebook we explore the dataset generated by the module `./comm_agents/data/data_generator.py`.

### Imports

In [None]:
import os
os.chdir('..')
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from comm_agents.data.reference_experiments import RefExperimentMass, RefExperimentCharge
from comm_agents.data.data_handler import RefExpDataset
from comm_agents.models.model_single_enc import SingleEncModel
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import interact

### Read data

In [None]:
DATA_PATH = './data/training/large_chunk_1.csv'
df = pd.read_csv(DATA_PATH)
# Divide the angle columns by pi so that they are easier to read
# (values are then expressed as multiples of pi).
df.loc[:, ['alpha_star0', 'alpha_star1', 'phi_star0', 'phi_star1']] = (
    df[['alpha_star0', 'alpha_star1', 'phi_star0', 'phi_star1']].div(np.pi)
)
# Product of the two charge columns, stored as its own column.
df['q0_t_q1'] = df['q0'].mul(df['q1'])
df.head(9)

In [None]:
# Print a summary of the frame: columns, dtypes, non-null counts, memory usage.
df.info()

## Mass experiment - observations

In [None]:
# Kept at module level for backward compatibility; the function below builds
# its own local time axis and does not read this variable.
t = np.linspace(0, 10, 10)


def get_ref_ex_a_obs(particle):
    """Plot randomly chosen experiment-A observation series for one particle.

    Draws ``num_samples`` row positions from the global ``df`` (with
    replacement, unseeded) and overlays the values of every column whose name
    contains ``o_a_<particle>`` as one semi-transparent line per row.

    Parameters
    ----------
    particle : int
        Particle number; used to build the column names ``m<particle>`` and
        ``o_a_<particle>...``.

    Notes
    -----
    The time axis has 10 points, so this assumes there are 10 matching
    observation columns per particle -- TODO confirm against the data
    generator.
    """
    p = particle
    t = np.linspace(0, 10, 10)
    num_samples = 100

    # Loop-invariant lookups: select the relevant columns once instead of
    # re-filtering the whole frame for every sampled row.
    observations = df[[c for c in df.columns if f'o_a_{p}' in c]].values
    masses = df[f'm{p}'].values

    fig = go.Figure()
    for i in np.random.choice(len(df), num_samples):
        m = masses[i]
        trace = go.Scatter(
            x=t, y=observations[i], mode='lines+markers', opacity=.1,
            # Label the mass with the selected particle (it used to read "m0"
            # even when particle 1 was plotted).
            hovertemplate=f'm{p}={m*1e20:.2f} e-20<extra></extra>',
            showlegend=False)
        fig.add_trace(trace)

    # The layout does not depend on the sampled row, so set it once.
    fig.update_layout(
        title=f'{num_samples} randomly selected observations of experiment A particle {p}',
        xaxis_title='time',
        yaxis_title=f'Particle {p} position in x direction',
    )
    fig.show()


# A list argument makes interact offer the two particles as a selection widget.
interact(get_ref_ex_a_obs, particle=[0, 1])

## Charge experiment - observations

In [None]:
def get_ref_ex_b_obs(particle):
    """Plot randomly chosen experiment-B observation series for one particle.

    Draws ``num_samples`` row positions from the global ``df`` (with
    replacement, unseeded) and overlays the values of every column whose name
    contains ``o_b_<particle>`` as one semi-transparent line per row. The
    hover text shows the particle's mass, a charge product and their ratio.

    Parameters
    ----------
    particle : int
        Particle number; used to build the column names ``m<particle>``,
        ``q<particle>`` and ``o_b_<particle>...``.

    Notes
    -----
    The time axis has 10 points, so this assumes there are 10 matching
    observation columns per particle -- TODO confirm against the data
    generator.
    """
    p = particle
    t = np.linspace(0, 10, 10)
    num_samples = 100

    # Loop-invariant lookups: select the relevant columns once instead of
    # re-filtering the whole frame for every sampled row.
    observations = df[[c for c in df.columns if f'o_b_{p}' in c]].values
    masses = df[f'm{p}'].values
    charges = df[f'q{p}'].values

    fig = go.Figure()
    for i in np.random.choice(len(df), num_samples):
        m = masses[i]
        # NOTE(review): the hover label calls this "q0_t_q1", but the value is
        # q<particle> * -1e-17 (presumably the product with a fixed reference
        # charge) -- confirm which quantity is intended.
        q0_t_q1 = charges[i] * -1e-17
        trace = go.Scatter(
            x=t, y=observations[i], mode='lines+markers', opacity=.1,
            # Label the mass with the selected particle (it used to read "m0"
            # even when particle 1 was plotted).
            hovertemplate=f'm{p}={m*1e20:.2f} e-20;,'
                          f' q0_t_q1={q0_t_q1*1e32:.2f} e-32;'
                          f' m / q0_t_q1={m/q0_t_q1/1e12:.2f} 1e-12<extra></extra>',
            showlegend=False)
        fig.add_trace(trace)

    # The layout does not depend on the sampled row, so set it once.
    fig.update_layout(
        title=f'{num_samples} randomly selected observations of experiment B particle {p}',
        xaxis_title='time',
        yaxis_title=f'Particle {p} position in x direction',
    )
    fig.show()


# A list argument makes interact offer the two particles as a selection widget.
interact(get_ref_ex_b_obs, particle=[0, 1])

## Mass experiments - optimal answers
### Scatter plots for optimal answers and influence factors

In [None]:
# take subsample for plotting
# NOTE: this overwrites `df` with an unseeded 10% random sample, so every
# re-run of this cell shrinks the frame again and yields different rows.
# reset_index() is called without drop=True, so the previous row labels are
# kept as an extra column (named 'index' unless that name is already taken).
df = df.sample(frac=.1).reset_index()

In [None]:
# Pairwise scatter plots of m0 and v_ref_a, coloured by alpha_star0.
fig = px.scatter_matrix(
    df.loc[:, ['m0', 'v_ref_a']],
    color=df['alpha_star0'],
    opacity=1,
    title='Pairsplot for reference experiment A particle 0, color: alpha_star0',
)
fig.show()

In [None]:
# 3D view of alpha_star0 (height and colour) over m0 and v_ref_a.
fig = px.scatter_3d(
    data_frame=df,
    x=df['m0'],
    y=df['v_ref_a'],
    z=df['alpha_star0'],
    color=df['alpha_star0'],
    title='Alpha_star over v_ref_a and m0',
)
fig.show()

## Charge experiments - optimal answers
### Scatter plots for optimal answers and influence factors

In [None]:
# Pairwise scatter plots of m0, v_ref_b and q0_t_q1 for particle 0.
# The colour now uses phi_star0 so that it matches both the title and the
# particle-0 columns being plotted (it previously used phi_star1).
fig = px.scatter_matrix(df[['m0', 'v_ref_b', 'q0_t_q1']],
                        color=df.phi_star0, opacity=.5,
                        title='Pairsplot for reference experiment B particle 0, color: phi_star0')

fig.show()

In [None]:
# 3D view of phi_star0 (height and colour) over v_ref_b and q0_t_q1.
fig = px.scatter_3d(
    data_frame=df,
    x=df['v_ref_b'],
    y=df['q0_t_q1'],
    z=df['phi_star0'],
    color=df['phi_star0'],
    title='Phi_star over v_ref_b and q0 * q1',
)
fig.show()

### Feasibility check for optimal answers

In [None]:
DATA_PATH = './data/training/large_chunk_1.csv'
# Load the CSV again into a separate frame that is independent of `df`.
df_check = pd.read_csv(DATA_PATH)
# The division of the angle columns by pi is commented out here, so the angle
# columns of df_check keep the values stored in the file.
# df_check.loc[:, ['alpha_star0', 'alpha_star1', 'phi_star0', 'phi_star1']] = \
#     df_check[['alpha_star0', 'alpha_star1', 'phi_star0', 'phi_star1']] / np.pi
df_check.head()

In [None]:
# Position of the df_check row currently displayed. It starts at -1 so that
# the very first call shows row 0 (starting at 0 and incrementing before use
# skipped the first row).
i = -1


def get_expample_exp(next):
    """Run and visualise both reference experiments for the next data row.

    Each call advances the module-level counter ``i`` by one row of
    ``df_check`` (wrapping around to the first row after the last one), then
    runs the charge experiment followed by the mass experiment with that
    row's parameters and stored optimal angles, and visualises each.

    Parameters
    ----------
    next : bool
        Not read by the function. It only exists so that ``interact`` renders
        a checkbox; toggling it triggers another call and thereby steps to
        the following row.
    """
    global i
    # Wrap around instead of indexing past the end of the frame.
    i = (i + 1) % len(df_check)
    row = df_check.iloc[i]

    m0, m1 = row.m0, row.m1
    q0, q1 = row.q0, row.q1

    # Charge experiment, driven by the stored phi_star answers.
    v_ref_c = row.v_ref_b
    phi0, phi1 = row.phi_star0, row.phi_star1
    req = RefExperimentCharge(m=[m0, m1], q=[q0, q1], m_ref_c=2e-20,
                              v_ref_c=v_ref_c, q_ref=[None, None], d=.1, N=1000,
                              phi=[phi0, phi1],
                              dt=.001, is_golf_game=True, y_cap=True)
    req.run()
    req.visualize(golf_hole_loc=0.1, tolerance=.01)

    # Mass experiment, driven by the stored alpha_star answers.
    v_ref_m = row.v_ref_a
    alpha0, alpha1 = row.alpha_star0, row.alpha_star1
    rem = RefExperimentMass(m=[m0, m1], m_ref_m=2e-20, v_ref_m=v_ref_m, N=1000,
                            alpha=[alpha0, alpha1], dt=.001, gravity=True)
    # NOTE(review): the angles are already passed via `alpha=`; this explicit
    # assignment is kept as in the original -- confirm whether it is needed.
    rem.angle = np.array([alpha0, alpha1])
    rem.run()
    rem.visualize(golf_hole_loc=0.1, tolerance=.01)


interact(get_expample_exp, next=False)

### Oversampling

In [None]:
# import pytorch data set and create instance
from comm_agents.data.data_handler import RefExpDataset
# The instance is created with oversample=False; presumably this keeps the
# constructor from oversampling on its own -- confirm against RefExpDataset.
ds = RefExpDataset(oversample=False)

In [None]:
# Oversample in two passes: first with the particle-0 answer columns on `df`,
# then with the particle-1 answer columns on the result of the first pass.
# NOTE(review): the positional arguments presumably are the columns to
# inspect, the value ranges to oversample and a repetition factor -- confirm
# against RefExpDataset.oversample.
# The angle columns of `df` were divided by pi when the data was read, so the
# ranges (0, .25) and (.5, .75) refer to those rescaled values.
df_oversample = ds.oversample(df, ['alpha_star0', 'phi_star0'],
                                 [(0, .25), (.5, .75)], 10, frac=.5)
df_oversample = ds.oversample(df_oversample, ['alpha_star1', 'phi_star1'],
                                 [(0, .25), (.5, .75)], 10, frac=.5)

In [None]:
# Columns offered for the x and y axes of the 2D histogram.
l_drop = [
    'q0', 'q1', 'm0', 'm1', 'v_ref_a', 'v_ref_b', 'q0_t_q1',
    'alpha_star0', 'alpha_star1', 'phi_star0', 'phi_star1',
]


def get_2d_hist(x, y, oversammple):
    """Show a 2D density heatmap of two columns with marginal histograms.

    Parameters
    ----------
    x, y : str
        Column names plotted on the horizontal and vertical axis.
    oversammple : bool
        When truthy the global ``df_oversample`` is plotted, otherwise the
        global ``df``.
    """
    if oversammple:
        frame = df_oversample
    else:
        frame = df
    heatmap = px.density_heatmap(
        frame,
        x=x,
        y=y,
        marginal_x="histogram",
        marginal_y="histogram",
    )
    heatmap.show()


interact(get_2d_hist, x=l_drop, y=l_drop, oversammple=False)