In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive; the data, checkpoint and log paths used by this
# notebook all live under '/content/gdrive/My Drive'.
drive.mount('/content/gdrive')

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator and then re-execute this cell.')
else:
    print(gpu_info)

In [None]:
from psutil import virtual_memory

# Total memory of the runtime expressed in gigabytes (10^9 bytes).
ram = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM.'.format(ram))
# 20 GB is the threshold used here to tell a high-RAM runtime from a standard one.
if ram >= 20:
    print('You are using a high-RAM runtime.')
else:
    print('To enable a high-RAM runtime, select the Runtime > "Change runtime type", then select High-RAM in the Runtime shape dropdown '
          'and then re-execute this cell.')

In [None]:
from tensorflow import config

# Let TensorFlow grow its GPU memory allocation on demand instead of reserving it all up front.
physical_devices = config.list_physical_devices('GPU')
if not physical_devices:
    # Say so explicitly instead of surfacing an opaque "list index out of range".
    print('No GPU detected; memory growth was not configured.')
# Configure every visible GPU: TensorFlow expects the setting to be the same on all of them.
for physical_device in physical_devices:
    try:
        config.experimental.set_memory_growth(physical_device, True)
    except (RuntimeError, ValueError) as exception:
        # RuntimeError: the runtime was already initialised; ValueError: invalid device.
        # Best effort only: report the problem and carry on with the default allocation.
        print(exception)

In [None]:
!pip install tensorflow-addons

In [None]:
from tensorflow import device
from tensorflow_addons import layers as new_layers
from tensorflow.keras import losses
from tensorflow.keras import models
from tensorflow.keras import optimizers

import csv
import gc
import h5py
import numpy as np
import os

In [None]:
town = 'Moscow'  #@param ['Berlin', 'Istanbul', 'Moscow']

# Path templates on the mounted Drive; '{}' is filled in with the selected town.
files, checkpoints, logs = (
    template.format(town)
    for template in (
        # directory that holds the training files
        '/content/gdrive/My Drive/Licenta/Traffic4Cast/{}/files/training',
        # directory the model checkpoints are read from and written to
        '/content/gdrive/My Drive/Licenta/Traffic4Cast/{}/checkpoints/UNet3_2',
        # CSV file the training losses are appended to
        '/content/gdrive/My Drive/Licenta/Traffic4Cast/{}/logs/UNet3_2/training/logs.csv',
    )
)

In [None]:
def get_file_names(files):
    """Return the names of all entries in a directory, in random order.

    Args:
        files: path of the directory to list.

    Returns:
        A list with the entry names reported by ``os.listdir``, shuffled using
        NumPy's global random state.
    """
    shuffled_names = os.listdir(files)
    # In-place shuffle; seed `np.random` beforehand to make the order repeatable.
    np.random.shuffle(shuffled_names)
    return shuffled_names

In [None]:
def get_data(file_path, index):
    """Read 72 consecutive entries from an HDF5 file as float32.

    Args:
        file_path: path of the HDF5 file to read.
        index: position, along the first axis, of the first entry to read.

    Returns:
        A float32 NumPy array with the entries ``index`` up to (not including)
        ``index + 72`` of the object stored under the file's first top-level key.

    Raises:
        IndexError: if the file has no top-level keys.
    """
    # The context manager closes the file even if the key lookup or the read raises.
    with h5py.File(file_path, 'r') as file:
        group_key = list(file.keys())[0]
        return np.array(file[group_key][index:index + 72], dtype=np.float32)  # (72, 495, 436, 9)

In [None]:
def get_training_data(file_path, index):
    """Build shuffled (input, target) samples from 72 consecutive frames.

    The 72 frames starting at ``index`` are cut into 12 consecutive windows of
    6 frames. Within each window the first 3 frames form the input and the
    last 3 frames form the target; each is flattened frame by frame into
    ``3 * 8 = 24`` channels.

    Args:
        file_path: path of the HDF5 file to read.
        index: position of the first of the 72 frames inside the file.

    Returns:
        A tuple ``(inputs, outputs)`` of float32 arrays, both shaped
        ``(12, height, width, 24)`` and scaled by ``1 / 255``. The 12 samples
        are in random order (NumPy global random state); inputs and outputs
        stay aligned.
    """
    frames = get_data(file_path, index)[..., :8]  # keep only the dynamic channels
    windows = np.stack(np.split(frames, 12))  # (12, 6, height, width, 8)
    _, _, height, width, _ = windows.shape
    # Put the time axis directly in front of the channel axis so that flattening is
    # frame-major: flat channel = frame * 8 + channel. With the time axis moved to the
    # very end the layout is channel-major, and the cut at 24 below separates channels
    # 0-3 from channels 4-7 over all 6 frames instead of frames 0-2 from frames 3-5.
    # NOTE: checkpoints trained with that channel-major layout learned a different
    # input/target mapping and should be retrained.
    windows = np.moveaxis(windows, 1, -2).reshape((12, height, width, -1))
    np.random.shuffle(windows)  # shuffle the samples along the first axis
    windows /= 255.0
    inputs = windows[:, :, :, :24]  # frames 0-2 of every window
    outputs = windows[:, :, :, 24:]  # frames 3-5 of every window
    return inputs, outputs

In [None]:
# Restore the model saved as 'model_4.h5' and re-compile it with a new Adam optimizer
# (learning rate 1e-4) and mean squared error as the loss.
checkpoint_path = os.path.join(checkpoints, 'model_4.h5')
with device('gpu:0'):
    model = models.load_model(checkpoint_path)
    model.compile(
        loss=losses.mean_squared_error,
        optimizer=optimizers.Adam(learning_rate=0.0001),
    )
model.summary()

In [None]:
# Create the log directory if needed so that opening the file cannot fail on a missing
# folder (for instance the first time a different town is selected).
log_directory = os.path.dirname(logs)
if log_directory:
    os.makedirs(log_directory, exist_ok=True)
# Append mode keeps the rows written by earlier runs; newline='' is what the csv module
# expects for files it writes to.
log_file = open(logs, 'a', newline='')
log_writer = csv.writer(log_file)

In [None]:
# Each file is consumed in chunks of 72 frames covering frames [0, 288).
# NOTE(review): 288 is presumably the number of frames stored per file - confirm.
FRAMES_PER_FILE = 288
FRAMES_PER_CHUNK = 72
chunk_starts = range(0, FRAMES_PER_FILE, FRAMES_PER_CHUNK)

# Epochs 5-9: continues from the checkpoint 'model_4.h5' loaded earlier in the notebook.
for epoch in range(5, 10):
    print('epoch:', epoch)
    file_names = get_file_names(files)  # new random file order for every epoch
    for file_index, file_name in enumerate(file_names):
        print('file:', file_index)
        file_path = os.path.join(files, file_name)
        # Deliberately NOT named `losses`: that would rebind the `tensorflow.keras.losses`
        # module imported at the top of the notebook and break a later re-run of the
        # cell that compiles the model with `losses.mean_squared_error`.
        chunk_losses = np.zeros(shape=(len(chunk_starts),), dtype=np.float64)
        for chunk_index, chunk_start in enumerate(chunk_starts):
            inputs, outputs = get_training_data(file_path, chunk_start)
            with device('gpu:0'):
                history = model.fit(inputs, outputs, epochs=1, batch_size=3)
            chunk_losses[chunk_index] = history.history['loss'][0]
        # One CSV row per file: epoch, file name, mean loss over the file's chunks.
        log_writer.writerow([epoch, file_name, np.mean(chunk_losses, dtype=np.float64)])
        log_file.flush()  # keep the log on disk up to date if the run is interrupted
        gc.collect()
    # Save a checkpoint at the end of every epoch.
    model.save(os.path.join(checkpoints, 'model_{}.h5'.format(epoch)))

In [None]:
# Close the CSV log file that the training loop has been writing to.
log_file.close()