**Open this notebook from Google Drive (in Google Colab).**<br>
**Then go to "Edit" -> "Notebook settings" and enable the GPU.**


In [None]:
# Check if an NVIDIA GPU is enabled for this runtime.
# `!` runs `nvidia-smi` as a shell command; its report is shown as the cell output.
!nvidia-smi

**Connect Google Drive to Google Colab and authorize access:**

In [None]:
# Mount Google Drive at /content/gdrive so the project and data paths used
# in the following cells are reachable from this runtime.
from google.colab import drive
drive.mount('/content/gdrive')
# List the current working directory.
!ls

**Open our project "Galaxy Classifier" directory in Google Drive:**

In [None]:
# Change the working directory to the project folder on the mounted Drive.
# Disabled alternative that switches to the Drive root instead:
# %cd /content/gdrive/My Drive/
%cd /content/gdrive/My Drive/Colab Notebooks/galaxy_classifier/
# List the project folder to confirm the directory change worked.
!ls

**Open the data directory for our "Galaxy Classifier" project in Google Drive:**

In [None]:
# Change the working directory to the dataset folder on the mounted Drive.
# Every relative path used after this cell resolves against this folder.
%cd /content/gdrive/My Drive/data/galaxy_data/
# List the dataset folder contents.
!ls

**Install all required libraries for our project:**

In [None]:
# Dependency installation is disabled; uncomment the line below to run it.
# NOTE(review): when cells are run in order, the working directory here is the
# data folder set by the preceding %cd cell, so ./requirements.txt is looked up
# there rather than in the project folder -- confirm where the file lives.
# !pip install -r ./requirements.txt

In [None]:
import os, random, shutil
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline  

import tensorflow as tf
# Print the installed TensorFlow version.
print(tf.__version__)
# Last expression in the cell, so the notebook displays its return value
# (the name of the GPU device TensorFlow sees).
tf.test.gpu_device_name()

**Check that TensorFlow can use the GPU. The output should look similar to this:**
```
2.2.0
'/device:GPU:0'
```

In [None]:
# Show the directory the kernel is currently working in.
current_dir = os.getcwd()
print(current_dir)

# Root of the dataset on the mounted Drive, plus one sub-directory per split.
data_path = '/content/gdrive/My Drive/data/galaxy_data'
training_dir, valid_dir = (
    os.path.join(data_path, split_name)
    for split_name in ('training', 'validation')
)

**After training has finished, load the saved history and plot the accuracy and loss graphs:**

In [None]:
import pickle

# Load the pickled training-history dict stored next to the dataset.
# The path is built with os.path.join, matching how training_dir and
# valid_dir are derived from data_path.
# SECURITY: pickle.load can execute arbitrary code while unpickling, so only
# load a train_hist_dict.pkl that you produced yourself or otherwise trust.
history_file = os.path.join(data_path, 'train_hist_dict.pkl')
with open(history_file, 'rb') as f:
    history = pickle.load(f)

In [None]:
# Per-epoch metric series read from the loaded history dict.
# NOTE(review): presumably the `History.history` dict saved after training -- confirm.
acc = history['accuracy']
val_acc = history['val_accuracy']

loss = history['loss']
val_loss = history['val_loss']

epochs_range = range(len(acc))  # one x-position per recorded epoch, starting at 0
print(epochs_range)

# Explicit Figure/Axes interface: each panel is addressed by name instead of
# relying on pyplot's implicit "current axes" state.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(16, 8))

# Left panel: accuracy curves.
ax_acc.plot(epochs_range, acc, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc, label='Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set_title('Training and Validation Accuracy')

# Right panel: loss curves.
ax_loss.plot(epochs_range, loss, label='Training Loss')
ax_loss.plot(epochs_range, val_loss, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_title('Training and Validation Loss')

# Write the figure into the dataset directory, then display it.
fig.savefig(os.path.join(data_path, 'plots.png'))
plt.show()