In [None]:
# Mount the user's Google Drive inside the Colab VM; the paths used by this
# notebook all live under this mount point.
from google.colab import drive

gdrive_mount_point = '/content/gdrive'
drive.mount(gdrive_mount_point)

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator and then re-execute this cell.')
else:
    print(gpu_info)

In [None]:
from psutil import virtual_memory

# Total memory of the runtime in decimal gigabytes (bytes / 1e9). Note that
# the message below says "available", but the figure reported is the total.
ram = virtual_memory().total / 1e9
high_ram_threshold_gb = 20

print('Your runtime has {:.1f} gigabytes of available RAM.'.format(ram))
if ram >= high_ram_threshold_gb:
    print('You are using a high-RAM runtime.')
else:
    print('To enable a high-RAM runtime, select the Runtime > "Change runtime type", then select High-RAM in the Runtime shape dropdown '
          'and then re-execute this cell.')

In [None]:
from tensorflow import config

# Enable memory growth on the first GPU so TensorFlow allocates GPU memory
# as needed rather than claiming it all when the device is initialised.
physical_devices = config.list_physical_devices('GPU')
if not physical_devices:
    # Without this guard the cell indexed an empty list and only printed the
    # unhelpful "list index out of range".
    print('No GPU detected by TensorFlow; memory growth was not configured.')
else:
    try:
        config.experimental.set_memory_growth(physical_devices[0], True)
    except Exception as exception:
        # Best effort: report the problem (e.g. the runtime was already
        # initialised) and let the notebook continue.
        print(exception)

In [None]:
from tensorflow import device
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import models
from tensorflow.keras import optimizers

import csv
import datetime
import gc
import h5py
import numpy as np
import os
import re

In [None]:
town = 'Berlin'  #@param ['Berlin', 'Istanbul', 'Moscow']

# Everything for the selected town lives under one Drive directory; derive
# the training files directory, the checkpoint directory and the CSV log path
# from that common root.
town_dir = '/content/gdrive/My Drive/Licenta/Traffic4Cast/{}'.format(town)
files = town_dir + '/files/training'
checkpoints = town_dir + '/checkpoints/seq2seq3'
logs = town_dir + '/logs/seq2seq3/training/logs.csv'

In [None]:
def get_date(file_name):
    """Extract the calendar date embedded in a file name.

    Args:
        file_name: A file name (or path) containing a date written as
            YYYY-MM-DD.

    Returns:
        The first YYYY-MM-DD occurrence in ``file_name`` as a
        ``datetime.date``.

    Raises:
        ValueError: If ``file_name`` contains no YYYY-MM-DD pattern, or if
            the matched digits do not form a valid calendar date.
    """
    match = re.search(r'\d{4}-\d{2}-\d{2}', file_name)
    if match is None:
        # Previously a name without a date fell through to ``None.group()``
        # and surfaced as an AttributeError that did not name the file.
        raise ValueError('no YYYY-MM-DD date found in file name: {!r}'.format(file_name))
    return datetime.datetime.strptime(match.group(), '%Y-%m-%d').date()

In [None]:
def get_file_names(files, excluded_dates=None, max_files=45):
    """Return a random selection of file names from a directory.

    The directory listing is shuffled in place with NumPy's global random
    state, names dated on any of ``excluded_dates`` are dropped, and at most
    ``max_files`` of the remaining names are returned.

    Args:
        files: Path of the directory to list.
        excluded_dates: Optional iterable of 'YYYY-MM-DD' strings; files whose
            name carries one of these dates are left out. ``None`` (the
            default) excludes nothing.
        max_files: Maximum number of names to return. Defaults to 45, the
            value that used to be hard-coded; ``None`` returns every name.

    Returns:
        A list of file names (not full paths) in shuffled order.

    Raises:
        ValueError: If an entry of ``excluded_dates`` is not a valid
            'YYYY-MM-DD' string. Errors raised by ``get_date`` for a listed
            name it cannot parse also propagate.
    """
    file_names = os.listdir(files)
    np.random.shuffle(file_names)
    # A set gives constant-time membership checks; ``or ()`` replaces the old
    # mutable ``[]`` default argument.
    excluded = {
        datetime.datetime.strptime(excluded_date, '%Y-%m-%d').date()
        for excluded_date in (excluded_dates or ())
    }
    file_names = [file_name for file_name in file_names if get_date(file_name) not in excluded]
    return file_names[:max_files]

In [None]:
def load_data(file_path):
    """Load one HDF5 file and arrange it into shuffled, scaled batches.

    The first dataset in the file is read as float32, the first 8 entries of
    the last axis are kept, the leading axis is cut into 48 equal parts, the
    channel axis is moved to position 2, the 48 parts are shuffled with
    NumPy's global random state, and all values are divided by 255.

    Args:
        file_path: Path of the HDF5 file to read.

    Returns:
        A float32 array laid out as (batches, timestamps, channels, rows,
        columns) with 48 batches. Presumably the stored values are 8-bit, so
        the result lies in [0, 1] -- TODO confirm against the dataset.

    Raises:
        ValueError: From ``np.split`` if the leading axis cannot be divided
            into 48 equal parts.
    """
    # The context manager closes the file even when reading fails; the old
    # open()/close() pair leaked the handle on any exception in between.
    with h5py.File(file_path, 'r') as file:
        group_key = list(file.keys())[0]
        data = np.array(file[group_key][:], dtype=np.float32)
    data = np.take(data, np.arange(8), axis=-1)  # keep only the dynamic channels
    data = np.array(np.split(data, 48))  # 48 * (3 + 3) = 288
    data = np.moveaxis(data, -1, 2)  # transpose to (batches, timestamps, channels, rows, columns)
    np.random.shuffle(data)  # shuffle the 48 batches
    return data / 255.0

In [None]:
# Resume from a saved checkpoint and re-compile it with a fresh Adam
# optimizer (learning rate 1e-4) and mean-squared-error loss.
checkpoint_path = os.path.join(checkpoints, 'model_3.h5')
with device('gpu:0'):
    model = models.load_model(checkpoint_path)
    optimizer = optimizers.Adam(learning_rate=0.0001)
    model.compile(optimizer=optimizer, loss=losses.mean_squared_error)
model.summary()

In [None]:
# Open the CSV training log in append mode so rows from earlier runs are
# kept. newline='' is what the csv module expects of files handed to a
# writer, so it can control line endings itself.
log_file = open(logs, mode='a', newline='')
log_writer = csv.writer(log_file)

In [None]:
# Train for epochs 4..7. Each epoch draws a fresh random selection of files,
# fits the model on one file at a time, logs the loss per file and saves a
# checkpoint once the epoch is finished.
for epoch in range(4, 8):
    print('epoch:', epoch)
    file_names = get_file_names(files)
    for index, file_name in enumerate(file_names):
        print('file:', index)
        data = load_data(os.path.join(files, file_name))
        # Along axis 1, the first 3 entries are the model input and the
        # remaining ones are the target.
        inputs = data[:, :3]
        outputs = data[:, 3:]
        with device('gpu:0'):
            history = model.fit(inputs, outputs, epochs=1, batch_size=3)
        log_writer.writerow([epoch, file_name, history.history['loss'][0]])
        # Flush after every file so the log survives an interrupted runtime.
        log_file.flush()
        # Drop the references before collecting: while these names were still
        # bound, gc.collect() could not release the arrays and the previous
        # file stayed in memory while the next one was being loaded.
        del data, inputs, outputs
        gc.collect()
    model.save(os.path.join(checkpoints, 'model_{}.h5'.format(epoch)))

In [None]:
# Close the CSV training log; rows were already flushed after each file.
log_file.close()