# Prepare train and test sets

In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import pickle
from sklearn.model_selection import train_test_split

## Data choice

In [None]:
# Fruit whose segmented runs are processed; it is embedded in the input
# pickle file name and in the exported split file names.
fruit = 'orange'

In [None]:
# Cut type to process; an 's' is appended to it when building the input
# and output file names (e.g. 'cut' -> 'cuts').
cut_type = 'cut'

In [None]:
# Fraction of the sampled dataset assigned to the training split; the
# remaining fraction goes to the test split.
train_percentage = 0.7
# Target total number of samples (train + test) drawn from the full data.
max_size_dataset = 1200

## Segmented data reading

In [None]:
# Load the segmented runs for the selected fruit and cut type.
# The context manager closes the file even if unpickling fails.
# NOTE(review): unpickling can execute arbitrary code — only load files
# that come from a trusted source.
with open('segmented_' + cut_type + 's_' + fruit + '.pickle', 'rb') as pickle_in:
    segmented_runs = pickle.load(pickle_in)

In [None]:
# Gather the phase-1 samples of every selected run into a single DataFrame
# holding only the columns needed downstream.
desired_headers = ['displacement', 'ee_force_x']

# Only runs whose sixth underscore-separated name field equals this value
# are kept.
# NOTE(review): presumably a run parameter encoded in the file name —
# confirm what it denotes.
selected_field_value = '0.005000'

# Collect the per-run frames first and concatenate once: calling pd.concat
# inside the loop re-copies the accumulated data on every iteration.
# A comprehension also keeps the loop variables out of the notebook namespace.
selected_frames = [
    run_data.loc[run_data['phase'] == 1, desired_headers]
    for run_name, run_data in segmented_runs.items()
    if run_name.split('_')[5] == selected_field_value
]

if selected_frames:
    all_data_df = pd.concat(selected_frames)
else:
    # pd.concat rejects an empty list; fall back to an empty frame that
    # still exposes the expected columns.
    all_data_df = pd.DataFrame(columns=desired_headers)

total_samples = len(all_data_df.index)
print(total_samples)

# Quick visual check of the gathered samples.
plt.plot(all_data_df['displacement'], all_data_df['ee_force_x'], '.')
plt.show()

## Downsample and split dataset


In [None]:
# Draw up to max_size_dataset samples and split them into train and test sets.
# Cap the request at the number of samples actually available, since
# train_test_split raises ValueError when asked for more samples than exist.
dataset_size = min(max_size_dataset, len(all_data_df.index))

# round() instead of int(): truncation silently drops a sample whenever the
# float product lands just below an integer (e.g. 1200 * (1 - 0.8)).
# The test size is the remainder, so train + test always equals dataset_size.
n_train = round(dataset_size * train_percentage)
n_test = dataset_size - n_train

X_train, X_test, y_train, y_test = train_test_split(
    all_data_df['displacement'].values,
    all_data_df['ee_force_x'].values,
    test_size=n_test,
    train_size=n_train,
    random_state=42,  # fixed seed so the split is reproducible
)

# Overlay both splits to check that they cover the same region.
plt.plot(X_train, y_train, '.')
plt.plot(X_test, y_test, '.')
plt.show()

## Export

In [None]:
# Export the splits to ./splits as one pickle holding all four arrays,
# plus one CSV per split.
# exist_ok=True avoids the check-then-create race of isdir() + mkdir().
os.makedirs("splits", exist_ok=True)

splits_dict = {'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test}
# The context manager closes (and flushes) the file even if dumping fails.
with open('splits/splits_' + fruit + '_' + cut_type + 's.pickle', 'wb') as pickle_out:
    pickle.dump(splits_dict, pickle_out)

# Train and test sets have different lengths, so each gets its own CSV.
splits_df = pd.DataFrame({'X_train': X_train, 'y_train': y_train})
splits_df.to_csv("splits/train_splits_" + fruit + '_' + cut_type + "s.csv", index=False)

splits_df = pd.DataFrame({'X_test': X_test, 'y_test': y_test})
splits_df.to_csv("splits/test_splits_" + fruit + '_' + cut_type + "s.csv", index=False)