In [2]:
import numpy as np
import pandas as pd
from numpy import float32
import warnings
import os
import sys
import glob
import tensorflow as tf
import cv2
from sklearn import utils
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import shutil
from pathlib import Path

from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.applications import *
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.preprocessing import *
from tensorflow.keras.utils import *
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow_addons as tfa

sys.path.append(os.path.join(Path.cwd(), 'utils'))
sys.path.append(os.path.join(Path.cwd(), 'data_generators'))
sys.path.append(os.path.join(Path.cwd(), 'models'))

from utils.im_utils import *
from utils.data_augmentation import *
from utils.region_detector import *

from data_generators.FullImageBinaryGen import *
from data_generators.PairAllClassGen import *
from data_generators.FullImageAllClassGen import *
from data_generators.CutoutImageAllClassGen import *
from data_generators.FullImageSingleClassGen import *

from models.bidirectional_convlstm_model import *
from models.pair_convlstm_model import *

In [3]:
# Expose only the GPU with index 1 to CUDA for this process.
# NOTE(review): this runs after `import tensorflow`; it presumably still takes
# effect because no GPU work has happened yet at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [4]:
# Flare-class selector. Not referenced by any other cell visible in this notebook.
FLARE_CLASS = 'ALL'
# Directory used as the prefix of the ModelCheckpoint file path.
# Note the trailing '/': the checkpoint cell appends another '/' after it.
BEST_TRAINED_MODELS_DIR = './best_trained_models/'

# Checkpoint directories. In the visible cells only a commented-out
# `model.save_weights` call refers to LSTM_CHECKPOINTS_DIR.
LSTM_CHECKPOINTS_DIR = './checkpoints/lstm_checkpoints'
RESNET_CHECKPOINTS_DIR = './checkpoints/resnet_checkpoints'

In [5]:
def delete_files(folder):
    """Empty ``folder`` without removing the folder itself.

    Regular files and symbolic links are unlinked; sub-directories are removed
    recursively. Deletion is best-effort: a failure on one entry is printed
    and the remaining entries are still processed.

    Args:
        folder: Path of the directory whose contents should be removed.
    """
    entry_paths = (os.path.join(folder, entry) for entry in os.listdir(folder))
    for target in entry_paths:
        try:
            removable_as_file = os.path.isfile(target) or os.path.islink(target)
            if removable_as_file:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            # Report and keep going so one bad entry does not abort the cleanup.
            print('Failed to delete %s. Reason: %s' % (target, err))

In [6]:
def GetDataFolders(train_data_dir, val_data_dir):
    """Collect every sub-directory path below the train and validation roots.

    Directories named 'positive', 'negative' or '.ipynb_checkpoints' are left
    out of the result, but the walk still descends into them, so their own
    sub-directories can appear.

    Args:
        train_data_dir: Root directory of the training split.
        val_data_dir: Root directory of the validation split.

    Returns:
        Tuple ``(train_folders, val_folders)`` of NumPy arrays of path strings,
        in ``os.walk`` order.
    """
    skipped_names = ('positive', 'negative', '.ipynb_checkpoints')

    def collect(root_dir):
        # Same traversal for both splits: join each kept directory name onto
        # the directory that os.walk is currently visiting.
        return np.array([
            os.path.join(parent, child)
            for parent, children, _ in os.walk(root_dir)
            for child in children
            if child not in skipped_names
        ])

    return collect(train_data_dir), collect(val_data_dir)

In [7]:
def GetFlaresDataFolders(train_data_dir, val_data_dir, flare_classes):
    """Find the sample folders whose flare class is in ``flare_classes``.

    Every file below each root is inspected. The file path is split on '/';
    the fifth component from the end is taken as the flare class, and the path
    with its last two components removed is recorded as the sample folder.

    Compared with the previous version:
      * files whose path has fewer than five components are skipped instead
        of raising ``IndexError``;
      * the result is sorted, so the order no longer depends on set/hash
        ordering and is reproducible across runs.

    Args:
        train_data_dir: Root directory of the training split.
        val_data_dir: Root directory of the validation split.
        flare_classes: Collection of class names to keep.

    Returns:
        Tuple ``(train_folders, val_folders)`` of sorted lists of unique
        folder path strings.
    """
    def collect(root_dir):
        found = set()
        for subdir, _dirs, files in os.walk(root_dir):
            for f in files:
                parts = os.path.join(subdir, f).split('/')
                # Too shallow to carry a class component: ignore the file.
                if len(parts) < 5 or parts[-5] not in flare_classes:
                    continue
                # Drop the file name and its immediate parent directory.
                found.add('/'.join(parts[:-2]))
        return sorted(found)

    return collect(train_data_dir), collect(val_data_dir)

In [8]:
def GetSingleClassDataFolders(train_data_dir, val_data_dir, flare_class):
    """Find the sample folders belonging to ``flare_class`` or to class 'N'.

    Every file below each root is inspected. The file path is split on '/';
    the fifth component from the end is taken as the class, and the path with
    its last two components removed is recorded as the sample folder.

    Compared with the previous version:
      * files whose path has fewer than five components are skipped instead
        of raising ``IndexError``;
      * the result is sorted, so the order no longer depends on set/hash
        ordering and is reproducible across runs.

    Args:
        train_data_dir: Root directory of the training split.
        val_data_dir: Root directory of the validation split.
        flare_class: Class name to keep alongside 'N'.

    Returns:
        Tuple ``(train_folders, val_folders)`` of sorted lists of unique
        folder path strings.
    """
    def collect(root_dir):
        found = set()
        for subdir, _dirs, files in os.walk(root_dir):
            for f in files:
                parts = os.path.join(subdir, f).split('/')
                # Too shallow to carry a class component: ignore the file.
                if len(parts) < 5:
                    continue
                cur_class = parts[-5]
                if cur_class != flare_class and cur_class != 'N':
                    continue
                # Drop the file name and its immediate parent directory.
                found.add('/'.join(parts[:-2]))
        return sorted(found)

    return collect(train_data_dir), collect(val_data_dir)

In [9]:
def get_labels(generator, feature_extractor):
    """Gather the label batches of ``generator`` into a single array.

    Each item yielded by ``generator`` is indexed with ``[1]`` to obtain its
    label batch, and the batches are joined along the first axis. Joining
    with ``np.concatenate`` (rather than stacking and reshaping) also works
    when the final batch is smaller than the others.

    Args:
        generator: Iterable whose items hold a label batch at index 1.
            Batches are expected to be 2-D (samples, classes) and must agree
            on every dimension except the first.
        feature_extractor: Unused; kept so existing callers keep working.

    Returns:
        Array of shape ``(total_samples, ...)`` containing all labels in
        iteration order.

    Raises:
        ValueError: If ``generator`` yields no batches.
    """
    label_batches = [np.asarray(sample[1]) for sample in generator]
    if not label_batches:
        raise ValueError('get_labels received a generator that yielded no batches')
    return np.concatenate(label_batches, axis=0)

In [27]:
# Experiment configuration: classes to train on, batching, and dataset location.
flare_classes = ['N', 'M', 'X']
num_classes = len(flare_classes)
batch_size = 64
sequence_length = 6
data_dir = 'cadence6_frame6'

# Tag used in output file names: the joined class letters, then the dataset name.
output_name = ''.join(flare_classes) + '_' + data_dir

# Both splits live under the same dataset directory.
train_dir, val_dir = (
    os.path.join(f'./new_data/{data_dir}/', split) for split in ('train', 'val')
)
train_folders, val_folders = GetFlaresDataFolders(train_dir, val_dir, flare_classes)

In [28]:
# The train and validation generators are configured identically; only the
# list of sample folders differs. The training generator is built first.
traingen, valgen = (
    FullImageAllClassGen(
        split_folders,
        batch_size,
        flare_classes,
        image_size=64,
        sequence_length=sequence_length,
    )
    for split_folders in (train_folders, val_folders)
)

In [29]:
# model = tf.keras.models.load_model(f'./best_trained_models/NCMX_cadence6_frame6.h5')

# true_vals = [x[1] for x in valgen]
# true_vals = np.concatenate(true_vals)
# true_labels = [x.argmax() for x in true_vals]

# preds = model.predict(valgen)
# pred_labels = [x.argmax() for x in preds]

# c = 0
# for i, l in enumerate(pred_labels):
#      if l == true_labels[i]:
#             c+=1
# print(c/len(pred_labels))

# tf.math.confusion_matrix(
#     true_labels,
#     pred_labels,
#     num_classes=num_classes,
# )

0.7736443883984867


<tf.Tensor: shape=(4, 4), dtype=int32, numpy=
array([[491, 279,   3,   2],
       [ 50, 716,   1,   0],
       [  1,  19,  17,   3],
       [  0,   1,   0,   3]], dtype=int32)>

In [67]:
# Build the pair ConvLSTM classifier. Arguments are positional: batch_size,
# the literal 64, sequence_length-1 and num_classes.
# NOTE(review): 64 presumably has to match the generators' image_size=64, and
# sequence_length-1 is presumably the number of frame pairs — confirm against
# PairConvLSTMModel's signature.
model = PairConvLSTMModel(batch_size, 64, sequence_length-1, num_classes)

(None, 10, 64, 64, 1)
(None, 5, 64, 64, 1)
(None, 5, 64, 64, 1)
(None, 5, 64, 64, 1)
(None, 5, 64, 64, 1)


In [69]:
# Checkpoint callback: monitors val_loss and only overwrites the .h5 file when
# the monitored value improves (save_best_only=True).
# NOTE(review): BEST_TRAINED_MODELS_DIR already ends with '/', so the resulting
# path contains '//'. The '_binary' suffix is used although num_classes is
# len(flare_classes) — confirm the intended file name.
mc = ModelCheckpoint(f'{BEST_TRAINED_MODELS_DIR}/{output_name}_binary.h5', monitor='val_loss', save_best_only=True)

In [70]:
# Callbacks handed to model.fit; only the checkpoint callback `mc` is registered.
callbacks_list = [mc]
# Metrics handed to model.compile.
metrics = [
    tf.keras.metrics.CategoricalAccuracy(),
    # NOTE(review): Precision and Recall are created with default arguments;
    # confirm the default thresholding suits the multi-class one-hot output.
    tf.keras.metrics.Precision(),
    tf.keras.metrics.Recall(),
    # F1 score from TensorFlow Addons, configured with `num_classes`.
    tfa.metrics.F1Score(num_classes=num_classes)
]

In [71]:
# Adam optimizer; this is the optimizer actually passed to model.compile below.
adam_fine = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, decay=0.0002, amsgrad=False)
# NOTE(review): `lr_schedule` and the SGD `optimizer` built from it are NOT
# passed to model.compile in this cell, so they have no effect on training
# here. Remove them or switch the compile call if SGD was intended.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=1e-3,
decay_steps=10000,
decay_rate=0.9)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
# Compile with categorical cross-entropy, the Adam optimizer and the metrics
# list defined in the previous cell.
model.compile(
    loss="categorical_crossentropy", optimizer=adam_fine, metrics=metrics
)

In [1]:
# Train for a fixed number of epochs using `traingen`, with `valgen` as the
# validation data and the callbacks from `callbacks_list`.
# Depends on `model`, `traingen`, `valgen` and `callbacks_list` from earlier
# cells, so those cells must have been executed first.
epochs=10
history = model.fit(traingen, validation_data=valgen, epochs=epochs, callbacks=callbacks_list)

In [None]:
# plt.plot(history.history['accuracy'])
# plt.plot(history.history['val_accuracy'])
# plt.title('model accuracy')
# plt.ylabel('accuracy')
# plt.xlabel('epoch')
# plt.legend(['train', 'val'], loc='upper left')
# plt.show()

In [None]:
# plt.plot(history.history['loss'])
# plt.plot(history.history['val_loss'])
# plt.title('model loss')
# plt.ylabel('loss')
# plt.xlabel('epoch')
# plt.legend(['train', 'val'], loc='upper left')
# plt.show()

In [None]:
# model.save_weights(f'{LSTM_CHECKPOINTS_DIR}/{output_name}')

In [None]:
# data_folder = './new_data/ALL_lstm_data_during_leftout2013/train/M/AIA20100807_1748_0094/0/full'
# paths = []
# for subdir, dirs, files in os.walk(data_folder):
#     for f in files:
#         paths.append(os.path.join(subdir, f))
# paths = sorted(paths)
# for p in paths:
#     print(p)

# fig, axes = plt.subplots(2, 3, figsize=(10, 8))

# for idx, ax in enumerate(axes.flat):
#     ax.imshow(np.squeeze(preprocessing.normalize(np.load(paths[idx]))), cmap='jet')
#     ax.set_title(f"Frame {idx + 1}")
#     ax.axis("off")

# plt.show()