## Imports

In [1]:
import numpy as np
import tensorflow as tf
import pandas as pd
from tensorflow.python.keras.callbacks import EarlyStopping

# training.py

Contains the code from `training.py` that is relevant to this notebook.

In [2]:
# Build the network layer by layer: the flat 832-value input is reshaped to an
# 8x8 grid with 13 channels, passed through two conv -> batch-norm -> max-pool
# stages, flattened, and fed to a dense head that ends in one linear output.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Reshape((8, 8, 13), input_shape=(832,)))

# The two convolutional stages differ only in their filter count.
for n_filters in (32, 64):
    model.add(tf.keras.layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2)))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='linear'))

# Regression set-up: Adam optimiser, mean squared error as the loss and mean
# absolute error reported as an extra metric.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)  # Try different learning rates
model.compile(
    optimizer=optimizer,
    loss='MeanSquaredError',
    metrics=['MeanAbsoluteError'],
)

# Data Loader

A Keras `Sequence` that serves batches from the chunked `.npy` files, so that we can train on a large dataset.

In [2]:
class DataSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches from a list of ``.npy`` files.

    Rows are addressed as if all files were concatenated in the given order.
    Each file is expected to hold a 2-D array: ``__getitem__`` treats the last
    column as the target and every other column as a feature.
    """

    def __init__(self, files, batch_size):
        """Index the files so that batches can be located by global row number.

        Args:
            files: Paths of the ``.npy`` files, in the order they are read.
            batch_size: Number of rows per batch; must be positive.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.files = files
        self.batch_size = batch_size
        # Memory-map each file just to read its row count, so building the
        # index does not pull every chunk into RAM up front.
        self.file_lengths = [len(np.load(f, mmap_mode='r')) for f in files]
        self.total_length = sum(self.file_lengths)
        # cumulative_lengths[i] is the global row index at which file i starts.
        self.cumulative_lengths = np.cumsum([0] + self.file_lengths)
        # Maps file index -> loaded array. Entries are never evicted, so memory
        # use grows as more files are touched.
        self.file_cache = {}

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        # Integer ceiling division: round up so leftover rows form a last batch.
        return (self.total_length + self.batch_size - 1) // self.batch_size

    def _load_file(self, i):
        """Return the array stored in file ``i``, reading it from disk once."""
        if i not in self.file_cache:
            self.file_cache[i] = np.load(self.files[i])
        return self.file_cache[i]

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x, y)`` with its rows shuffled.

        Args:
            idx: Batch index in ``range(len(self))``.

        Returns:
            Tuple ``(x, y)``: ``x`` holds every column but the last, cast to
            ``int8``; ``y`` holds the last column.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"batch index {idx} out of range for {len(self)} batches")

        global_start = idx * self.batch_size
        # Clamp so the last batch may be shorter than batch_size.
        global_end = min(global_start + self.batch_size, self.total_length)

        # Collect the part of the batch that lives in each file. A batch may
        # straddle a file boundary when file lengths are not multiples of
        # batch_size, in which case the parts are stitched together below.
        pieces = []
        for i, (start, end) in enumerate(zip(self.cumulative_lengths[:-1], self.cumulative_lengths[1:])):
            lo = max(global_start, start)
            hi = min(global_end, end)
            if lo < hi:
                pieces.append(self._load_file(i)[lo - start:hi - start])

        data = pieces[0] if len(pieces) == 1 else np.concatenate(pieces)

        # Shuffle the rows within the batch (fancy indexing returns a copy, so
        # the cached array is left untouched).
        data = data[np.random.permutation(len(data))]
        x = data[:, :-1].astype(np.int8)
        y = data[:, -1]

        return (x, y)

### Fitting Model

**TODO:** re-run the `processedDataset` step so that the evaluations are correct.

In [3]:
# Fixed seed for drawing the validation subset, so the same rows are held out
# every time this cell runs and val_loss stays comparable between sessions.
VALIDATION_SEED = 42

# Chunks 0-11 are used for training.
files = [f'../data/createdData/npyFiles/preprocessedChunk{num}.npy' for num in range(12)]

# Make the DataSequence object to pass data to model
data = DataSequence(files, batch_size=64)

# Get validation data: a 10% sample of chunk 12, which is not in the training list
validation = pd.DataFrame(
    np.load("../data/createdData/npyFiles/preprocessedChunk12.npy")
).sample(frac=0.1, random_state=VALIDATION_SEED)

# Same split as DataSequence: last column is the target, the rest are features.
X_validate = validation.iloc[:, :-1].to_numpy().astype(np.int8)
y_validate = validation.iloc[:, -1].to_numpy()

# Drop the intermediate DataFrame once the arrays have been extracted.
del validation

In [5]:
# Use the public tf.keras callback: the model is a tf.keras model, and the
# tensorflow.python.keras package is a private API that should not be mixed
# with it.
# NOTE(review): with patience=5 and epochs=3 this callback can never stop the
# run early; raise epochs or lower patience if early stopping is wanted here.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

history = model.fit(data, epochs=3, validation_data=(X_validate, y_validate), callbacks=[early_stopping])

Epoch 1/3
Epoch 2/3
Epoch 3/3


In [7]:
# Persist the trained model. The path is relative to the notebook's working
# directory, and the directory name records the run settings
# (3 epochs, batch size 64, learning rate 0.0001).
model.save("../saved_models/corrected_12mil_3epoch_64batch_0.0001learnRate")



INFO:tensorflow:Assets written to: ../saved_models/corrected_12mil_3epoch_64batch_0.0001learnRate\assets


INFO:tensorflow:Assets written to: ../saved_models/corrected_12mil_3epoch_64batch_0.0001learnRate\assets


# Continue Training

Resumes from the saved 3-epoch model for an overnight training run (5/3/23).

In [4]:
# Resume from the model saved after the first 3 epochs.
# This cell relies on `data`, `X_validate` and `y_validate` from the
# "Fitting Model" cells, so run those first on a fresh kernel.
model = tf.keras.models.load_model("../saved_models/corrected_12mil_3epoch_64batch_0.0001learnRate")

# Use the public tf.keras callback: the model is a tf.keras model, and the
# tensorflow.python.keras package is a private API that should not be mixed
# with it.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# 22 further epochs on top of the 3 already trained, hence "25epoch" below.
history = model.fit(data, epochs=22, validation_data=(X_validate, y_validate), callbacks=[early_stopping])

model.save("../saved_models/corrected_12mil_25epoch_64batch_0.0001learnRate")

Epoch 1/22
    13/187500 [..............................] - ETA: 53:45:46 - loss: 159230.3594 - mean_absolute_error: 254.9128