In [None]:
# Colab-only step: mount Google Drive at /content/drive. The rest of the
# notebook reads its data from, and writes models/logs to, relative paths of the
# form 'drive/MyDrive/...' (presumably relying on the working directory being
# /content — TODO confirm).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import numpy as np
import re
from datetime import datetime, timedelta
import random

In [None]:
class DataLoader:
    """In-memory loader that yields (history, target) pairs, one per beam.

    Every file under ``datapath`` whose name starts with ``<YYYYMMDD>.<HHMM>.``
    is loaded with ``np.load`` and stored in ``self.data`` under the key
    ``YYYYMMDDHHMM``. Each array is indexed as a 5-D array: the last index of
    axis 2 is treated as the target-value mask, axis 3 is iterated as the beam
    index and only index 0 of axis 4 is used.

    For every stored timestamp the loader collects 41 earlier timestamps (see
    ``__getSequence``), concatenates their arrays along axis 1 and yields
    ``(x, y)`` where ``x`` is the concatenated history and ``y`` is the target
    array without its mask channel.

    Args:
        datapath: Directory that is walked recursively for data files.
        shuffle: Shuffle the order of target timestamps on every pass.
        max_missing: Largest tolerated fraction of history slots that are
            missing or have an all-zero mask; samples above it are skipped.
        beams: Number of beams (indices along axis 3) to emit per sample, or
            ``None`` to emit every beam present in the target array.

    The instance is both callable (as required by
    ``tf.data.Dataset.from_generator``) and directly iterable.
    """

    # Timestamp layout of the dictionary keys.
    _KEY_FORMAT = '%Y%m%d%H%M'

    def __init__(self,
                 datapath,
                 shuffle=True,
                 max_missing=0.3,
                 beams=16):
        self.shuffle = shuffle
        self.max_missing = max_missing
        self.beams = beams
        self.data = {}

        for root, _, files in os.walk(datapath):
            for name in files:
                parts = name.split('.')
                if len(parts) < 2:
                    continue
                key = parts[0] + parts[1]
                # Only files whose name encodes a timestamp are data files;
                # anything else (notes, OS metadata, ...) would otherwise break
                # np.load here or strptime later on.
                try:
                    datetime.strptime(key, self._KEY_FORMAT)
                except ValueError:
                    continue
                self.data[key] = np.load(os.path.join(root, name))

    def __iter__(self):
        """Allow ``for x, y in loader`` in addition to ``loader()``."""
        return self()

    def __call__(self):
        """Generate ``(x, y)`` samples; see the class docstring for details."""
        keys = list(self.data.keys())
        if self.shuffle:
            random.shuffle(keys)

        for key in keys:
            target = self.data[key]
            # A target whose mask is empty carries nothing to predict.
            if self._mask_is_empty(target):
                continue

            history = []
            bad_count = 0
            for stamp in self.__getSequence(key):
                frame = self.data.get(stamp)
                if frame is None:
                    # Missing history slot: keep the layout with a zero block.
                    bad_count += 1
                    frame = np.zeros_like(target)
                elif self._mask_is_empty(frame):
                    # Present but empty slots are kept, yet counted as bad.
                    bad_count += 1
                history.append(frame)

            # Skip the sample when too large a share of its history is bad.
            if bad_count / len(history) > self.max_missing:
                continue

            x = np.concatenate(history, axis=1)
            n_beams = target.shape[3] if self.beams is None else self.beams
            for beam in range(n_beams):
                yield x[:, :, :, beam, 0], target[:, :, :-1, beam, 0]

    @staticmethod
    def _mask_is_empty(array):
        """Return True when the mask channel (last index of axis 2) is all zero."""
        return np.all(array[:, :, -1, :, :] == 0.)

    def __getSequence(self, key):
        """Return the 41 history keys for ``key``, oldest first.

        The list holds the same time of day 30 down to 2 days earlier (29
        keys), followed by every second hour from 24 down to 2 hours earlier
        (12 keys). The one-day-earlier slot is the first of the hourly keys.
        """
        keyDT = datetime.strptime(key, self._KEY_FORMAT)
        monthBefore = [(keyDT - timedelta(days=days)).strftime(self._KEY_FORMAT)
                       for days in range(30, 1, -1)]
        dayBefore = [(keyDT - timedelta(hours=hours)).strftime(self._KEY_FORMAT)
                     for hours in range(24, 0, -2)]
        return monthBefore + dayBefore

In [None]:
# Directory holding the training arrays (a relative path; it points into the
# mounted Drive only if the working directory is /content — TODO confirm).
datapath = 'drive/MyDrive/2002-train'

# Constructing the loader walks datapath and loads the arrays into memory.
loader = DataLoader(datapath)

In [None]:
# Smoke test: pull a single sample and show the shapes the models must accept.
# DataLoader exposes its generator through __call__, so the instance has to be
# called to obtain something iterable.
for x, y in loader():
    print(x.shape, y.shape)
    break

No key: 200205241400
No key: 200205251400
No key: 200205261400
No key: 200206071400
No key: 200206081400
No key: 200206091400
No key: 200206101400
No key: 200206111400
No key: 200206121400
No key: 200206191400
No key: 200206221600
No key: 200206221800
No key: 200206222000
No key: 200206222200
Bad count: 14
No key: 200206260800
No key: 200207051200
No key: 200207051400
No key: 200207051600
No key: 200207051800
No key: 200207052000
No key: 200207052200
Bad count: 7
(70, 2460, 7) (70, 60, 6)


In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import warnings
# Silences every Python warning for the rest of the session, which also hides
# deprecation warnings raised by the libraries used below.
warnings.filterwarnings("ignore") 

# generator
def get_generator(input_shape=(70, 2460, 7), forecast_steps=60, out_channels=6):
    """Build the generator: history tensor in, forecast tensor out.

    A single unpadded convolution spans axis 2 of the input so that exactly
    ``forecast_steps`` positions remain; a per-position dense layer then maps
    the 128 feature maps to ``out_channels`` values. With the defaults the
    model maps ``(70, 2460, 7)`` to ``(70, 60, 6)`` using a ``(1, 2401)``
    kernel, i.e. the original architecture.

    Args:
        input_shape: Shape of one history sample (without the batch axis).
        forecast_steps: Length of the output along axis 2.
        out_channels: Number of values predicted per position.

    Returns:
        An uncompiled ``tf.keras`` ``Model``; the output is bounded to
        [-1, 1] by the final ``tanh``.

    Raises:
        ValueError: If ``forecast_steps`` does not fit into the input width.
    """
    history_steps = input_shape[1]
    if not 0 < forecast_steps <= history_steps:
        raise ValueError(
            f'forecast_steps must be in 1..{history_steps}, got {forecast_steps}')

    # Valid convolution: output width = history_steps - kernel_width + 1.
    kernel_width = history_steps - forecast_steps + 1

    inp = layers.Input(shape=input_shape)
    conv = layers.Conv2D(128, kernel_size=(1, kernel_width), activation='tanh')(inp)
    norm = layers.BatchNormalization()(conv)
    out = layers.Dense(out_channels, activation='tanh')(norm)
    return Model(inp, out)

# discriminator
def get_discriminator(hist_shape=(70, 2460, 7),
                      forecast_shape=(70, 60, 6),
                      block_mask_shape=(2, 2)):
    """Build the discriminator that scores a forecast block by block.

    The history is projected to the forecast's channel count, the forecast is
    appended to it along axis 2, and an unpadded convolution reduces the joined
    tensor back to the forecast width. The result is flattened and mapped to
    one sigmoid score per block of ``block_mask_shape``. With the defaults
    this is the original architecture (``(1, 2461)`` kernel, 2x2 output).

    Args:
        hist_shape: Shape of one history sample (without the batch axis).
        forecast_shape: Shape of one forecast sample; axis 0 must match
            ``hist_shape`` so the two tensors can be concatenated.
        block_mask_shape: ``(rows, cols)`` of the grid of scores.

    Returns:
        An uncompiled ``tf.keras`` ``Model`` taking ``[history, forecast]``
        and returning a tensor of shape ``(batch, rows, cols, 1)``.
    """
    forecast_channels = forecast_shape[2]
    rows, cols = block_mask_shape

    hist_inp = layers.Input(shape=hist_shape)
    # 1x1 convolution so the history has as many channels as the forecast.
    hist_conv = layers.Conv2D(filters=forecast_channels, kernel_size=(1, 1))(hist_inp)
    gen_out = layers.Input(shape=forecast_shape)
    joined = layers.Concatenate(axis=2)([hist_conv, gen_out])

    # Joined width is history + forecast; a kernel of history + 1 leaves
    # exactly the forecast width after a valid convolution.
    kernel_width = hist_shape[1] + 1
    conv = layers.Conv2D(filters=128, kernel_size=(1, kernel_width), activation='tanh')(joined)
    conv = layers.Conv2D(filters=6, kernel_size=(1, 1), activation='tanh', padding='same')(conv)
    norm = layers.BatchNormalization()(conv)
    flat = layers.Flatten()(norm)
    dense = layers.Dense(rows * cols)(flat)
    reshape = layers.Reshape((rows, cols, 1))(dense)
    out = layers.Dense(1, activation='sigmoid')(reshape)
    return Model([hist_inp, gen_out], out)

# data
data_path = 'drive/MyDrive/2002-train'
val_data_path = 'drive/MyDrive/2002-val'

# DataLoader instances are callables returning a fresh generator, which is the
# form tf.data.Dataset.from_generator expects.
train_generator = DataLoader(data_path)
val_generator = DataLoader(val_data_path)

batch_size = 8
epochs = 20

dataset = tf.data.Dataset.from_generator(train_generator,
                                         output_types=(tf.float64, tf.float64)).batch(batch_size)
val_dataset = tf.data.Dataset.from_generator(val_generator,
                                         output_types=(tf.float64, tf.float64)).batch(batch_size)


# the GAN itself
gen = get_generator()
dis = get_discriminator()

# Compiled without optimizer/loss: training is done by the manual loop below.
gen.compile()
dis.compile()

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated.
d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

loss_fn = tf.keras.losses.BinaryCrossentropy()
mae = tf.keras.losses.MeanAbsoluteError()

gen_loss_tracker = tf.keras.metrics.Mean(name='generator_loss')
disc_loss_tracker = tf.keras.metrics.Mean(name='discriminator_loss')
gen_mae_tracker = tf.keras.metrics.Mean(name='generator_mae')
val_mae_tracker = tf.keras.metrics.Mean(name='generator_validation_mae')

# One [g_loss, d_loss, g_mae, val_mae] record is appended per epoch.
history = []
# Values of the previous epoch; sentinel start so any real value is smaller.
old_gloss = 1e100
old_dloss = 1e100
old_mae = 1e100
old_val = 1e100
# Best (lowest) validation MAE seen so far; drives checkpointing.
min_val = 1e100

# The discriminator labels each forecast on a grid of this many blocks.
block_mask_shape = (2, 2)

def _reset_metric(metric):
    """Clear a Keras metric's accumulated state.

    The method is called ``reset_state`` in newer TensorFlow releases and
    ``reset_states`` in older ones, so whichever exists is used.
    """
    reset = getattr(metric, 'reset_state', None) or metric.reset_states
    reset()


def _blocks_to_pixel_mask(block_mask, height, width):
    """Expand a (batch, rows, cols) block mask to a (batch, height, width) mask.

    Every block value fills a ``height // rows`` by ``width // cols`` cell;
    pixels left over when the sizes are not divisible stay 0.
    """
    blocks = np.asarray(block_mask)
    rows, cols = blocks.shape[1], blocks.shape[2]
    cell_h, cell_w = height // rows, width // cols
    pixel_mask = np.zeros((blocks.shape[0], height, width), dtype=blocks.dtype)
    for i in range(rows):
        for j in range(cols):
            pixel_mask[:, i * cell_h:(i + 1) * cell_h, j * cell_w:(j + 1) * cell_w] = \
                blocks[:, i, j, None, None]
    return pixel_mask


def _track_mae(batches, tracker, label):
    """Feed the generator's per-batch MAE over ``batches`` into ``tracker``."""
    for step, (x, y) in enumerate(batches):
        print(f'\r {label} step:', step, end=' ')
        # Only full batches are used, as in the training passes.
        if y.shape[0] != batch_size:
            break
        tracker.update_state(mae(y, gen(x)))
    print()


# NOTE(review): gen and dis are always called without `training=True`, so
# their BatchNormalization layers presumably run in inference mode during the
# training passes as well — confirm this is intended.
for epoch in range(epochs):
    # Start every epoch from empty metrics so the reported numbers describe
    # this epoch only rather than a running mean over all epochs so far.
    for tracker in (gen_loss_tracker, disc_loss_tracker,
                    gen_mae_tracker, val_mae_tracker):
        _reset_metric(tracker)

    print(f'+++++++++++++++++++++++++++++++++++++++ Epoch {epoch} +++++++++++++++++++++++++++++++++++++++')

    # discriminator training
    for step, (x, y) in enumerate(dataset):
        print('\r D step:', step, end=' ')

        # Only full batches are used; a short batch can only be the last one.
        if y.shape[0] != batch_size:
            break

        # Random 0/1 label per block: 1 = block taken from the real target,
        # 0 = block taken from the generator output.
        cmp_mask = tf.math.round(tf.random.uniform(
            (y.shape[0], block_mask_shape[0], block_mask_shape[1]),
            minval=0, maxval=1, dtype=tf.dtypes.float64))
        # Trailing axis of size 1 broadcasts the mask over the value channels.
        true_mask = tf.constant(
            _blocks_to_pixel_mask(cmp_mask, y.shape[1], y.shape[2]))[..., tf.newaxis]
        fake_mask = 1.0 - true_mask

        generated = tf.cast(gen(x), dtype=tf.dtypes.float64)
        d_input_mixed = true_mask * y + fake_mask * generated

        with tf.GradientTape() as tape:
            prediction_mask = dis([x, d_input_mixed])
            d_loss = loss_fn(cmp_mask, prediction_mask)
        # Gradients are taken after leaving the tape context so the tape does
        # not record the gradient computation itself.
        grads = tape.gradient(d_loss, dis.trainable_weights)
        d_optimizer.apply_gradients(zip(grads, dis.trainable_weights))
        disc_loss_tracker.update_state(d_loss)

    print()

    # generator training (discriminator weights are not updated here)
    for step, (x, y) in enumerate(dataset):
        print('\r G step:', step, end=' ')

        if y.shape[0] != batch_size:
            break

        # The generator is rewarded when every block is judged "real".
        misleading_mask = tf.ones(
            shape=(y.shape[0], block_mask_shape[0], block_mask_shape[1]),
            dtype=tf.dtypes.float64)

        with tf.GradientTape() as tape:
            fake_forcast = gen(x)
            prediction_mask = dis([x, fake_forcast])
            g_loss = loss_fn(misleading_mask, prediction_mask)
        grads = tape.gradient(g_loss, gen.trainable_weights)
        g_optimizer.apply_gradients(zip(grads, gen.trainable_weights))
        gen_loss_tracker.update_state(g_loss)

    print()

    # MAE of the updated generator on the training and validation data
    _track_mae(dataset, gen_mae_tracker, 'MAE')
    _track_mae(val_dataset, val_mae_tracker, 'VAL_MAE')

    # Per-epoch means over all processed batches.
    epoch_g_loss = float(gen_loss_tracker.result())
    epoch_d_loss = float(disc_loss_tracker.result())
    g_mae = float(gen_mae_tracker.result())
    val_mae = float(val_mae_tracker.result())

    print("g_loss:", epoch_g_loss,
          "d_loss:", epoch_d_loss,
          "g_mae:", g_mae,
          "val_mae:", val_mae)

    history.append([epoch_g_loss, epoch_d_loss, g_mae, val_mae])

    # Checkpoint whenever this epoch's validation MAE is the best so far.
    if val_mae < min_val:
        min_val = val_mae
        gen.save('drive/MyDrive/best_generator.hdf5')
        dis.save('drive/MyDrive/best_discriminator.hdf5')
        # These are references to the live models, not snapshots; the saved
        # files are the actual record of the best weights.
        best_generator = gen
        best_discriminator = dis

    old_dloss = epoch_d_loss
    old_gloss = epoch_g_loss
    old_mae = g_mae
    old_val = val_mae

# Write the per-epoch metrics to the log, one line per epoch with the four
# values separated by spaces (generator loss, discriminator loss, train MAE,
# validation MAE).
with open('drive/MyDrive/log.txt', 'w') as fp:
    for record in history:
        # Terminate each record so epochs do not run together on one line.
        fp.write(f'{record[0]} {record[1]} {record[2]} {record[3]}\n')

+++++++++++++++++++++++++++++++++++++++ Epoch 0 +++++++++++++++++++++++++++++++++++++++
 D step: 4181 
 G step: 4181 
 MAE step: 4181 
 VAL_MAE step: 1461 
g_loss: 2.403637409210205 d_loss: 0.04296540468931198 g_mae: 42970.83984375 val_mae: 15377.775390625
+++++++++++++++++++++++++++++++++++++++ Epoch 1 +++++++++++++++++++++++++++++++++++++++
 D step: 4181 
 G step: 4181 
 MAE step: 4181 
 VAL_MAE step: 1461 
g_loss: 1.539368987083435 d_loss: 0.022736212238669395 g_mae: 43092.8359375 val_mae: 15393.7548828125
+++++++++++++++++++++++++++++++++++++++ Epoch 2 +++++++++++++++++++++++++++++++++++++++
 D step: 4181 
 G step: 4181 
 MAE step: 4181 
 VAL_MAE step: 1461 
g_loss: 1.2651041746139526 d_loss: 0.02319999225437641 g_mae: 43002.11328125 val_mae: 15363.9052734375
+++++++++++++++++++++++++++++++++++++++ Epoch 3 +++++++++++++++++++++++++++++++++++++++
 D step: 4181 
 G step: 4181 
 MAE step: 4181 
 VAL_MAE step: 1461 
g_loss: 0.9658921957015991 d_loss: 0.017713144421577454 g_mae: 43107.9