In [None]:
# Mount Google Drive inside the Colab runtime. The dataset, checkpoint and log
# paths used later in this notebook all live under this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator and then re-execute this cell.')
else:
    print(gpu_info)

In [None]:
from psutil import virtual_memory

# Total memory reported by psutil, converted from bytes to gigabytes (1e9 bytes).
total_bytes = virtual_memory().total
ram = total_bytes / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM.'.format(ram))
# 20 GB is the cut-off this notebook uses to tell a high-RAM runtime from a standard one.
if ram >= 20:
    print('You are using a high-RAM runtime.')
else:
    print('To enable a high-RAM runtime, select the Runtime > "Change runtime type", then select High-RAM in the Runtime shape dropdown '
          'and then re-execute this cell.')

In [None]:
from tensorflow import config

# Ask TensorFlow to grow GPU memory on demand for the first visible GPU
# instead of using its default allocation behaviour.
physical_devices = config.list_physical_devices('GPU')
if not physical_devices:
    # Without this guard the cell would index an empty list and print only
    # "list index out of range", which hides the real problem.
    print('No GPU detected by TensorFlow; memory growth was not configured.')
else:
    try:
        config.experimental.set_memory_growth(physical_devices[0], True)
    except Exception as exception:
        # Best effort: report the problem (e.g. the runtime is already
        # initialised) and keep the notebook running.
        print(exception)

In [None]:
from tensorflow import device
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import models
from tensorflow.keras import optimizers

import csv
import datetime
import gc
import h5py
import numpy as np
import os
import re

In [None]:
# Every dataset, checkpoint and log location sits under the same Drive folder.
traffic4cast_root = '/content/gdrive/My Drive/Licenta/Traffic4Cast'

# Per-city directories holding the training files.
berlin_files = traffic4cast_root + '/Berlin/files/training'
istanbul_files = traffic4cast_root + '/Istanbul/files/training'
moscow_files = traffic4cast_root + '/Moscow/files/training'

# Directory with the saved models and the CSV file receiving the training log.
checkpoints = traffic4cast_root + '/All/checkpoints/auto_encoder_12'
logs = traffic4cast_root + '/All/logs/auto_encoder_12/training/logs.csv'

In [None]:
def get_file_names(berlin_files, istanbul_files, moscow_files):
    """Return the entries of the three city directories as one shuffled list.

    The directories are listed in the order Berlin, Istanbul, Moscow and the
    bare entry names (no directory prefix) are concatenated in that order.
    The combined list is then shuffled in place with NumPy's global random
    state, so the result is reproducible only if that state is seeded.

    Args:
        berlin_files: Path of the directory with the Berlin files.
        istanbul_files: Path of the directory with the Istanbul files.
        moscow_files: Path of the directory with the Moscow files.

    Returns:
        A list of file names in random order.
    """
    names = []
    for directory in (berlin_files, istanbul_files, moscow_files):
        names.extend(os.listdir(directory))
    np.random.shuffle(names)
    return names

In [None]:
def load_data(file_path):
    """Load one HDF5 file and return it as shuffled chunks scaled by 1/255.

    The first dataset found in the file is read as float32, the first 8
    entries of its last axis are kept, and the first axis is split into 12
    equal chunks which are shuffled with NumPy's global random state.

    Args:
        file_path: Path of the HDF5 file to read.

    Returns:
        A float32 array whose first axis has length 12 (one entry per chunk),
        with every value divided by 255.0.
        NOTE(review): the split comment implies a first axis of 288 frames,
        i.e. 24 frames per chunk; confirm against the dataset.

    Raises:
        ValueError: If the first axis cannot be split into 12 equal parts.
    """
    # The context manager closes the file even when reading or conversion
    # raises, so a bad file no longer leaks an open HDF5 handle.
    with h5py.File(file_path, 'r') as file:
        group_key = list(file.keys())[0]
        data = np.array(file[group_key][:], dtype=np.float32)
    data = np.take(data, np.arange(8), axis=-1)  # keep only the dynamic channels
    data = np.array(np.split(data, 12))  # 12 * (12 + 12) = 288
    np.random.shuffle(data)  # shuffle the 12 batches
    return data / 255.0

In [None]:
# Resume from the checkpoint named model_1.h5 in the checkpoints directory.
initial_checkpoint = os.path.join(checkpoints, 'model_1.h5')
with device('gpu:0'):
    model = models.load_model(initial_checkpoint)
    # Re-compile with Adam (learning rate 1e-4) and a mean-squared-error loss.
    adam = optimizers.Adam(learning_rate=0.0001)
    model.compile(optimizer=adam, loss=losses.mean_squared_error)
model.summary()

In [None]:
# Shuffled names of every training file across the three cities.
file_names = get_file_names(berlin_files, istanbul_files, moscow_files)

# Start a fresh CSV log (mode 'w' truncates an existing file) and write the
# header immediately so it reaches disk before training starts.
log_header = ['epoch', 'file', 'loss']
log_file = open(logs, mode='w', newline='')
log_writer = csv.writer(log_file)
log_writer.writerow(log_header)
log_file.flush()

In [None]:
# Train for epochs 2 and 3, visiting every file once per epoch. For each file
# the mean training loss is appended to the CSV log, and a checkpoint is
# written at the end of each epoch.
for epoch in range(2, 4):
    print('epoch:', epoch)
    for index, file_name in enumerate(file_names):
        print('file:', index)
        # file_names holds bare names, so the source directory is recovered
        # from the city substring; anything else is read from the Moscow directory.
        if 'berlin' in file_name:
            data = load_data(os.path.join(berlin_files, file_name))
        elif 'istanbul' in file_name:
            data = load_data(os.path.join(istanbul_files, file_name))
        else:
            data = load_data(os.path.join(moscow_files, file_name))
        # Deliberately NOT named `losses`: that name is the imported
        # tensorflow.keras losses module, which the model-compile cell reads
        # (losses.mean_squared_error). Rebinding it here would break a re-run
        # of that cell.
        batch_losses = np.zeros(shape=(4,), dtype=np.float64)
        # The 12 chunks of a file are consumed in 4 groups of 3.
        for batch in range(0, 12, 3):
            # Along axis 1, the first 12 entries are the model input and the
            # remainder is the target.
            inputs = data[batch:batch + 3, :12]
            outputs = data[batch:batch + 3, 12:]
            with device('gpu:0'):
                history = model.fit(inputs, outputs, epochs=1, batch_size=1)
                batch_losses[batch // 3] = history.history['loss'][0]
        log_writer.writerow([epoch, file_name, np.mean(batch_losses, dtype=np.float64)])
        # Flush after every file so the log survives a runtime disconnect.
        log_file.flush()
        gc.collect()
    model.save(os.path.join(checkpoints, 'model_{}.h5'.format(epoch)))

In [None]:
# Close the CSV training log opened before the training loop.
log_file.close()