In [None]:
from google.colab import drive

# Mount Google Drive. Later cells read the dataset and save models through
# paths starting with 'drive/MyDrive/'. These are presumably relative to
# Colab's default working directory /content; confirm this if the notebook
# is run elsewhere.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import os
import numpy as np
from datetime import datetime, timedelta
import random

In [None]:
class DataLoader:
    """Yield (history, target) training pairs built from per-timestamp ``.npy`` files.

    Every ``.npy`` file found (recursively) under ``datapath`` is loaded into
    memory. Its key is the first two dot-separated parts of the filename
    joined together, e.g. ``20020201.1200.npy`` -> ``'200202011200'``. Keys of
    target files must parse with ``'%Y%m%d%H%M'``.

    Arrays are indexed as 5-D below:
    - axis 2 holds channels, and its last entry is treated as a validity mask;
    - axis 3 is iterated as 16 "beams";
    - only index 0 of axis 4 is used.
    NOTE(review): the meaning of axes 0, 1 and 4 is not visible here.

    Args:
        datapath: directory to scan for ``.npy`` files.
        shuffle: shuffle the target keys (global ``random`` state) on every call.
        max_missing_ratio: a target is skipped when more than this fraction
            of its history slots are missing or fully masked. The default
            0.3 equals the previously hard-coded threshold.
    """

    def __init__(self,
                 datapath,
                 shuffle=True,
                 max_missing_ratio=0.3):
        self.shuffle = shuffle
        self.max_missing_ratio = max_missing_ratio
        self.data = {}

        for root, _, files in os.walk(datapath):
            for name in files:
                # Skip anything that is not a .npy array (stray notes, Drive
                # metadata, ...). Such files would otherwise make np.load or
                # the key construction below raise.
                if os.path.splitext(name)[1].lower() != '.npy':
                    continue
                parts = name.split('.')
                self.data[parts[0] + parts[1]] = np.load(os.path.join(root, name))

    def __call__(self):
        """Generate ``(x, y)`` pairs, one per beam, for every usable target.

        A target is skipped when:
        - its own mask channel is all zero, or
        - more than ``max_missing_ratio`` of its history slots are missing
          or have an all-zero mask.

        Missing history slots are filled with zeros shaped like the target,
        so ``x`` always concatenates the full history along axis 1.

        Yields:
            x: concatenated history for one beam (mask channel kept).
            y: the target for that beam with the mask channel dropped.
        """
        keys = list(self.data.keys())
        if self.shuffle:
            random.shuffle(keys)

        for key in keys:
            target = self.data[key]
            # An all-zero mask means the target has no valid values.
            if np.all(target[:, :, -1, :, :] == 0.):
                continue

            arrays = []
            bad_count = 0
            for item in self.__getSequence(key):
                history = self.data.get(item)
                if history is None:
                    # Timestamp not on disk: pad with zeros and count it as bad.
                    bad_count += 1
                    arrays.append(np.zeros_like(target))
                    continue
                # Present but fully masked also counts as bad (still used as input).
                if np.all(history[:, :, -1, :, :] == 0.):
                    bad_count += 1
                arrays.append(history)

            if bad_count / len(arrays) > self.max_missing_ratio:
                continue

            x = np.concatenate(arrays, axis=1)
            for beam in range(16):
                yield x[:, :, :, beam, 0], target[:, :, :-1, beam, 0]

    def __getSequence(self, key):
        """Return the 41 history keys for ``key``, oldest first.

        The list holds:
        - 29 daily offsets (30 down to 2 days before, same time of day);
        - 12 two-hourly offsets (24 down to 2 hours before).

        The 1-day offset is not in the daily list; it is the same instant as
        the 24-hour entry.
        """
        key_dt = datetime.strptime(key, '%Y%m%d%H%M')
        offsets = [timedelta(days=d) for d in range(30, 1, -1)]
        offsets += [timedelta(hours=h) for h in range(24, 0, -2)]
        return [(key_dt - off).strftime('%Y%m%d%H%M') for off in offsets]

In [None]:
from tensorflow import keras
import tensorflow as tf

In [None]:
# Generator
def get_generator():
    """Build the forecast generator.

    - Input: a history tensor of shape ``(70, 2460, 7)`` (batch dim implicit).
    - A valid-padding Conv2D with kernel ``(1, 2401)`` slides only along
      axis 2. It shrinks that axis from 2460 to 60 (2460 - 2401 + 1) and
      produces 128 tanh channels.
    - After batch normalisation, a Dense layer maps the channels at each
      position to 6 tanh outputs.
    - Output: a ``(70, 60, 6)`` forecast.
    """
    history = keras.layers.Input(shape=(70, 2460, 7))
    h = keras.layers.Conv2D(128, kernel_size=(1, 2401), activation='tanh')(history)
    h = keras.layers.BatchNormalization()(h)
    forecast = keras.layers.Dense(6, activation='tanh')(h)
    return keras.models.Model(history, forecast, name='generator')

# Discriminator
def get_discriminator():
    """Build the conditional discriminator.

    - Inputs: a history tensor ``(70, 2460, 7)`` and a candidate forecast
      ``(70, 60, 6)``.
    - The history is projected to 6 channels with a 1x1 convolution and
      concatenated with the candidate along axis 2 (2460 + 60 = 2520).
    - A valid ``(1, 2461)`` convolution reduces axis 2 to 60.
    - A ``(70, 1)`` convolution collapses axis 1.
    - A sigmoid Dense layer gives one probability per axis-2 position,
      reshaped to ``(60,)``.
    """
    history = keras.layers.Input(shape=(70, 2460, 7))
    history_6ch = keras.layers.Conv2D(filters=6, kernel_size=(1, 1))(history)
    candidate = keras.layers.Input(shape=(70, 60, 6))
    h = keras.layers.Concatenate(axis=2)([history_6ch, candidate])
    h = keras.layers.Conv2D(filters=128, kernel_size=(1, 2461), activation='tanh')(h)
    h = keras.layers.BatchNormalization()(h)
    h = keras.layers.Conv2D(filters=6, kernel_size=(70, 1), activation='tanh')(h)
    h = keras.layers.Dense(1, activation='sigmoid')(h)
    scores = keras.layers.Reshape((60,))(h)
    return keras.models.Model([history, candidate], scores, name='discriminator')

In [None]:
class RadarGAN(keras.Model):
    """Adversarial wrapper that trains a forecast generator against a discriminator.

    The generator maps a history batch ``x`` to a forecast shaped like ``y``.
    The discriminator takes ``[x, forecast]`` and returns one probability per
    index of axis 2 of the forecast (1 = real, 0 = generated).

    Only ``train_step`` is customised. No ``call`` or ``test_step`` is
    defined here, so ``evaluate``/``predict`` on this wrapper presumably need
    one added. TODO confirm before using them.

    Args:
        discriminator: Keras model ``[x, forecast] -> (batch, forecast.shape[2])``.
        generator: Keras model ``x -> forecast``.
    """

    def __init__(self, discriminator, generator):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
        self.gen_mae_tracker = keras.metrics.Mean(name="generator_mae")

    @property
    def metrics(self):
        """Trackers listed here so Keras resets them at the start of each epoch."""
        return [self.gen_loss_tracker, self.disc_loss_tracker, self.gen_mae_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Attach per-player optimizers and the adversarial loss.

        Args:
            d_optimizer: optimizer applied to the discriminator weights.
            g_optimizer: optimizer applied to the generator weights.
            loss_fn: loss on discriminator outputs, used for both players.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        # Monitoring only: MAE between the generated forecast and y.
        self.mae = keras.losses.MeanAbsoluteError()

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: ``(x, y)`` batch. ``y`` is used as 4-D below
                (batch, axis1, axis2, channels).

        Returns:
            Dict with the running means of ``g_loss``, ``d_loss`` and ``g_mae``.
        """
        # Label 1 = real data, label 0 = generated data.
        x, y = data

        y_shape = tf.shape(y)

        # Discriminator target: an independent coin flip per (sample, axis-2 index).
        real = tf.math.round(tf.random.uniform(shape=(y_shape[0], y_shape[2]), minval=0, maxval=1, dtype=tf.dtypes.float64))

        # Broadcast the labels to y's full shape to get masks that pick real
        # vs generated values per axis-2 index.
        real_mask = tf.reshape(real, shape=(y_shape[0], 1, y_shape[2], 1))
        real_mask = tf.repeat(real_mask, repeats=y_shape[1], axis=1)
        real_mask = tf.repeat(real_mask, repeats=y_shape[3], axis=3)
        fake_mask = tf.math.subtract(tf.ones(shape=tf.shape(real_mask), dtype=tf.dtypes.float64), real_mask)

        # training=True is required inside a custom train_step. Without it,
        # BatchNormalization runs in inference mode: it never uses batch
        # statistics and never updates its moving averages.
        generated = tf.cast(self.generator(x, training=True), dtype=tf.dtypes.float64)

        # Mix real and generated values according to the masks.
        mixed = tf.math.add(tf.math.multiply(real_mask, y), tf.math.multiply(fake_mask, generated))

        # Discriminator update.
        with tf.GradientTape() as tape:
            predictions = self.discriminator([x, mixed], training=True)
            d_loss = self.loss_fn(real, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Generator target: the discriminator should call every position real.
        misleading_labels = tf.ones(shape=tf.shape(real), dtype=tf.dtypes.float64)

        # Generator update (only generator weights receive gradients).
        with tf.GradientTape() as tape:
            fake_forecast = self.generator(x, training=True)
            predictions = self.discriminator([x, fake_forecast], training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
            # Tracked for monitoring only; not part of the generator gradient.
            g_mae = self.mae(y, fake_forecast)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        self.gen_mae_tracker.update_state(g_mae)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
            "g_mae": self.gen_mae_tracker.result()
            }

In [None]:
# Assemble the GAN. The discriminator is built before the generator; Keras
# auto-names layers in creation order, so keep this order to keep the saved
# layer names stable. Each player gets its own Adam optimizer, and both use
# binary cross-entropy as the adversarial loss.
radar_gan = RadarGAN(
    discriminator=get_discriminator(),
    generator=get_generator(),
)
radar_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    g_optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

In [None]:
# Build tf.data pipelines around the DataLoaders. Each element is an (x, y)
# pair of 3-D float64 arrays (DataLoader yields per-beam 3-D slices).
# NOTE(review): if x matches the generator input (70, 2460, 7), one float64
# element is ~9.6 MB and a 512-element batch is ~4.9 GB. Confirm this fits
# in memory.
batch_size = 512

train_loader = DataLoader('drive/MyDrive/2002-train')
val_loader = DataLoader('drive/MyDrive/2002-val')


def make_dataset(loader, batch_size):
    """Return a batched tf.data.Dataset over the (x, y) pairs produced by calling ``loader``.

    Uses ``output_signature`` instead of the deprecated ``output_types``.
    Only rank (3) and dtype (float64) are fixed; the dimension sizes are
    left unknown because they depend on the files on disk.
    """
    element_spec = (
        tf.TensorSpec(shape=(None, None, None), dtype=tf.float64),
        tf.TensorSpec(shape=(None, None, None), dtype=tf.float64),
    )
    return tf.data.Dataset.from_generator(loader, output_signature=element_spec).batch(batch_size)


train_dataset = make_dataset(train_loader, batch_size)
val_dataset = make_dataset(val_loader, batch_size)

In [None]:
# Train with RadarGAN.train_step for 20 epochs. hist.history collects the
# per-epoch g_loss / d_loss / g_mae values returned by train_step.
# NOTE(review): val_dataset is not passed because RadarGAN defines no
# call/test_step.
# NOTE(review): the recorded output shows no per-epoch metrics, which may
# mean the loader yielded no samples. Confirm it produces data.
hist = radar_gan.fit(x=train_dataset, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [None]:
# Save both trained sub-models to Drive as .hdf5 files (generator first).
for sub_model, out_path in (
    (radar_gan.generator, 'drive/MyDrive/radargan_generator.hdf5'),
    (radar_gan.discriminator, 'drive/MyDrive/radargan_discriminator.hdf5'),
):
    sub_model.save(out_path)

