In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sunpy.time import parse_time
from sklearn import preprocessing

from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.layers import *
import tensorflow_addons as tfa

import os
import shutil
import datetime
import sys
from pathlib import Path
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
from matplotlib import cm
import cv2
import imageio

sys.path.append(os.path.join(Path.cwd(), 'utils'))
sys.path.append(os.path.join(Path.cwd(), 'models'))

from utils.region_detector import *
from utils.im_utils import *
from utils.data_augmentation import *
from models.bidirectional_convlstm_model import *

In [3]:
# Set CUDA_VISIBLE_DEVICES to "0" so that CUDA-based libraries only see GPU index 0.
# NOTE(review): this runs after TensorFlow has been imported in the cell above;
# presumably it still takes effect because no GPU has been used yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [4]:
# returns a formatted file name of the closest AIA data file to the given datetime

def GetClosestDataFileByDate(dt):
    """Build the AIA ``.npz`` file name for the time slot closest to ``dt``.

    The name has the form ``AIA<YYYYMMDD>_<HHMM>_0094.npz`` where the minute
    is whatever ``GetClosestMultiple(dt.minute, 6)`` returns.

    Args:
        dt: object exposing ``year``, ``month``, ``day``, ``hour`` and
            ``minute`` (e.g. ``datetime.datetime``).

    Returns:
        str: the bare file name (no directory component).
    """
    # NOTE(review): if GetClosestMultiple can return 60 the name would carry an
    # invalid minute ("HH60") — confirm the helper's rounding behaviour.
    snapped_minute = GetClosestMultiple(dt.minute, 6)
    date_stamp = f'{dt.year}{dt.month:02d}{dt.day:02d}'
    time_stamp = f'{dt.hour:02d}{snapped_minute:02d}'
    return f'AIA{date_stamp}_{time_stamp}_0094.npz'

In [5]:
def SaveMaxValsForClass(flare_class):
    """Write the per-event maximum pixel value at flare peak time to a CSV.

    Reads ``./event_records/new_events_by_class/<flare_class>.csv``, and for
    every row loads the ``'x'`` array of the data file closest to the row's
    ``event_peaktime``. The maximum of each array is collected and written to
    ``./<flare_class>_max_vals.csv`` in a single ``val`` column. Events whose
    data file does not exist are skipped.

    Args:
        flare_class: label used in both the input and output file names.
    """
    events = pd.read_csv(f'./event_records/new_events_by_class/{flare_class}.csv')
    peak_maxima = []

    for _, event in events.iterrows():
        peak_dt = parse_time(event['event_peaktime']).datetime
        npz_name = GetClosestDataFileByDate(peak_dt)
        npz_path = f'../data_94/{peak_dt.year}/{peak_dt.month:02}/{peak_dt.day:02}/{npz_name}'
        try:
            frame = np.load(npz_path)['x']
        except FileNotFoundError:
            # No data recorded for this slot; leave the event out of the CSV.
            continue
        peak_maxima.append(frame.max())

    pd.DataFrame(peak_maxima, columns=['val']).to_csv(f'./{flare_class}_max_vals.csv', index=False)

In [6]:
# Filesystem locations used by the helpers in this notebook (relative to the working directory).
DATA_DIR = '../data_94'  # root of the <year>/<month>/<day> .npz tree read by GetAIAPathAtTime
VIDEOS_DIR = './videos'  # WriteVideo writes <video_name>.mp4 here
FRAMES_DIR = './frames'  # WriteImages saves numbered PNG frames here
FRAMES_MARKED_DIR = './frames/frames_marked'  # only referenced by a commented-out call in the visible cells
CUTOUT_FRAMES_MARKED = './frames/cutout_frames_marked/'  # not referenced in the visible cells
LSTM_CHECKPOINTS_DIR = './checkpoints/lstm_checkpoints'  # not referenced in the visible cells

In [7]:
def delete_files(folder):
    """Remove every entry inside ``folder`` while keeping ``folder`` itself.

    Regular files and symbolic links are unlinked; real sub-directories are
    removed recursively. A failure on one entry is printed and does not stop
    the remaining entries from being processed.

    Args:
        folder: path of the directory to empty. It must exist, since
            ``os.listdir`` is called on it outside the ``try`` block.
    """
    for entry in os.listdir(folder):
        target = os.path.join(folder, entry)
        try:
            if os.path.islink(target) or os.path.isfile(target):
                os.unlink(target)
                continue
            if os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as e:
            # Best effort: report and move on to the next entry.
            print('Failed to delete %s. Reason: %s' % (target, e))

In [8]:
# returns a formatted file name of the closest AIA data file to the given datetime

def GetClosestDataFileByDate(dt):
    """Return the ``AIA<YYYYMMDD>_<HHMM>_0094.npz`` name for the slot nearest ``dt``.

    This cell re-defines a function of the same name from an earlier cell;
    the later definition is the one in effect once the notebook has run.

    Args:
        dt: object with ``year``, ``month``, ``day``, ``hour`` and ``minute``
            attributes (e.g. ``datetime.datetime``).

    Returns:
        str: file name only, without any directory part.
    """
    # The minute is snapped by the external GetClosestMultiple helper (step 6).
    minute_slot = GetClosestMultiple(dt.minute, 6)
    day_part = f'{dt.year}{dt.month:02d}{dt.day:02d}'
    clock_part = f'{dt.hour:02d}{minute_slot:02d}'
    return f'AIA{day_part}_{clock_part}_0094.npz'

In [9]:
def GetAIAPathAtTime(dt):
    """Return the path of the data file closest to ``dt`` under ``DATA_DIR``.

    The file is looked up in ``DATA_DIR/<year>/<month>/<day>`` using the name
    produced by ``GetClosestDataFileByDate``.

    Args:
        dt: datetime-like object with ``year``, ``month`` and ``day``.

    Returns:
        str: path of the existing data file.

    Raises:
        FileNotFoundError: if that file does not exist on disk.
    """
    day_dir = os.path.join(DATA_DIR, f'{dt.year}/{dt.month:02d}/{dt.day:02d}')
    candidate = os.path.join(day_dir, GetClosestDataFileByDate(dt))
    if os.path.exists(candidate):
        return candidate
    raise FileNotFoundError

In [10]:
def GetFilepathsBetweenDates(start_datetime, end_datetime, cadence=6):
    """List the existing data files between two datetimes at a fixed step.

    Starting at ``start_datetime`` and advancing by ``cadence`` minutes while
    the current time is strictly before ``end_datetime``, each time is
    resolved with ``GetAIAPathAtTime``. Times with no file on disk are
    silently skipped.

    Args:
        start_datetime: first time to probe (inclusive).
        end_datetime: upper bound (exclusive).
        cadence: step between probes, in minutes.

    Returns:
        list[str]: paths of the files found, in chronological probe order.
    """
    found = []
    cursor = start_datetime

    while cursor < end_datetime:
        try:
            found.append(GetAIAPathAtTime(cursor))
        except FileNotFoundError:
            # Missing slot: nothing to record, just move to the next one.
            pass
        cursor = cursor + datetime.timedelta(minutes=cadence)

    return found

In [11]:
def key_func(x):
    """Sort key: the integer file stem of a path such as ``'./frames/12.png'``.

    The text between the last ``'/'`` and the last ``'.'`` is converted with
    ``int``; a non-numeric stem therefore raises ``ValueError`` and a path
    without any ``'.'`` raises ``IndexError``.
    """
    stem_with_dirs = x.split('.')[-2]
    frame_number = stem_with_dirs.split('/')[-1]
    return int(frame_number)

In [12]:
def WriteImages(filepaths):
    """Save the ``'x'`` array of each ``.npz`` file as ``FRAMES_DIR/<index>.png``.

    All arrays are loaded and converted to PIL images first (cast with
    ``np.uint8``), then written in input order with 0-based file names.

    Args:
        filepaths: iterable of paths readable by ``np.load``.
    """
    frames = []
    for path in filepaths:
        pixels = np.load(path)['x']
        frames.append(Image.fromarray(np.uint8(pixels)))

    for i in range(len(frames)):
        frames[i].save(f'{FRAMES_DIR}/{i}.png')

In [13]:
def GetCleanAIAFilename(filename):
    """Strip the directory part and the last extension from a ``'/'``-separated path.

    Args:
        filename: path string, e.g. ``'a/b/AIA20130501_0000_0094.npz'``.

    Returns:
        str: the base name without its final ``.ext`` (the example above gives
        ``'AIA20130501_0000_0094'``). A name without a dot is returned as is.
    """
    base_name = filename.rpartition('/')[2]
    stem, dot, _ = base_name.rpartition('.')
    # rpartition yields an empty separator when there is no '.', in which
    # case the whole base name is the answer.
    return stem if dot else base_name

In [14]:
def PosToClass(pos):
    """Map a class index to its one-letter label.

    ``0 -> 'N'``, ``1 -> 'C'``, ``2 -> 'M'``, ``3 -> 'X'``; any other value
    falls back to ``'N'``.
    """
    labels = ('N', 'C', 'M', 'X')
    # Compare with == in index order so numpy integers behave like plain ints.
    for index, label in enumerate(labels):
        if pos == index:
            return label
    return 'N'

In [15]:
def GetImageCutouts(image, cutout_size=64, stride=16):
    """Tile an image into square cutouts with a sliding window.

    Windows are generated row-major: the top edge moves down the rows
    (axis 0) in the outer loop and the left edge moves across the columns
    (axis 1) in the inner loop, both in steps of ``stride``.

    Args:
        image: array whose first two axes are (rows, columns).
        cutout_size: side length of each square cutout.
        stride: step between consecutive window origins along each axis.

    Returns:
        numpy.ndarray: the stacked cutouts, in row-major window order. Empty
        when the image is not larger than ``cutout_size`` along an axis.
    """
    # Bound each loop by the axis it actually indexes; bounding the row loop
    # by the column count (and vice versa) mis-tiles non-square images.
    n_rows, n_cols = image.shape[0], image.shape[1]
    # NOTE(review): the range end is exclusive, so the window flush with the
    # bottom/right edge (origin == size - cutout_size) is not produced. Kept
    # as is so the number of cutouts for square images does not change.
    cutouts = [
        image[top:top + cutout_size, left:left + cutout_size]
        for top in range(0, n_rows - cutout_size, stride)
        for left in range(0, n_cols - cutout_size, stride)
    ]

    return np.array(cutouts)

In [97]:
def WritePredictedCutouts(save_dir, filepaths, full_disk_model, cutout_model, flare_classes, image_size=64, sequence_length=5, stride=32):
    """Write annotated PNG frames using a full-disk model plus a cutout model.

    For every sliding window of ``sequence_length`` difference images the
    full-disk model is queried. When it predicts a class index > 0 and no
    flare is currently being tracked, the region returned by
    ``GetImageTopNRegionsCoords`` (N=1) on the latest raw difference image is
    cut out of the raw frames, concatenated with the downscaled full-disk
    sequence and classified by the cutout model. A confirmed flare is logged
    to ``./logs/flare_detections/cutout_model.txt`` and a 64x64 red box is
    drawn on every frame until the full-disk model predicts class 0 again.

    Args:
        save_dir: directory that is emptied first and then receives the
            frames as ``<index>.png``.
        filepaths: ordered paths of ``.npz`` files holding an ``'x'`` array.
        full_disk_model: model with a Keras-style ``predict`` taking the
            window of downscaled difference images.
        cutout_model: model with a Keras-style ``predict`` taking the merged
            cutout + full-disk sequence.
        flare_classes: sequence of labels indexed by predicted class index.
        image_size: side length the difference images are resized to.
        sequence_length: number of difference images per window.
        stride: unused; kept for interface compatibility.

    Returns:
        None.
    """
    delete_files(save_dir)
    images_raw = [np.load(x)['x'] for x in filepaths]
    images_raw_diff = [abs(images_raw[x]-images_raw[x-1]) for x in range(1, len(images_raw))]
    images_diff = [cv2.resize(x, (image_size, image_size), interpolation=cv2.INTER_AREA) for x in images_raw_diff]
    in_flaring_period = False
    # Centre of the currently tracked flare; None while no flare is tracked.
    flare_coords = None
    flare_class = 0
    # Half the side of the box drawn around the flare centre.
    half_box = 32

    # The context manager guarantees the log is flushed and closed even when
    # a prediction or an image operation raises.
    with open('./logs/flare_detections/cutout_model.txt', 'w') as month_logfile:
        for i in range(sequence_length, len(images_raw_diff)):
            past_images = np.array(images_diff[i-sequence_length:i])
            past_images = np.expand_dims(past_images, 0)
            past_images = np.expand_dims(past_images, 4)
            # shape: (1, sequence_length, image_size, image_size, 1) for 2-D frames
            prediction = full_disk_model.predict(past_images, verbose=0)
            pos = prediction[0].argmax()
            past_images_raw = images_raw[i-sequence_length:i]
            # Only the first frame of the window is annotated and saved.
            first_image_pil = Image.fromarray(np.uint8(past_images_raw[0])).convert("RGBA")
            draw = ImageDraw.Draw(first_image_pil)

            if pos > 0:
                if not in_flaring_period:
                    past_images_raw_diff = images_raw_diff[i-sequence_length:i]
                    top_regions_coords = GetImageTopNRegionsCoords(past_images_raw_diff[-1], N=1)
                    # (sequence, region, 64, 64) -> (region, sequence, 64, 64)
                    top_region_cutouts = np.array(
                        [[GetSafeCenteredCutout(im, 64, x) for x in top_regions_coords] for im in past_images_raw]
                    )
                    top_region_cutouts = np.array([top_region_cutouts[:, x] for x in range(top_region_cutouts.shape[1])])

                    # Swap the two components of each coordinate for drawing
                    # (presumably (row, col) -> (x, y); confirm against the helper).
                    top_regions_coords = [[x[1], x[0]] for x in top_regions_coords]

                    predictions = []
                    for region_cutout in top_region_cutouts:
                        cutouts = np.expand_dims(region_cutout, 0)
                        # Cutout frames first, then the full-disk frames, along the time axis.
                        cutout_sequence_merged = np.concatenate([cutouts, past_images[:, :, :, :, 0]], axis=1)
                        cutout_sequence_merged = np.expand_dims(cutout_sequence_merged, 4)
                        predictions.append(cutout_model.predict(cutout_sequence_merged, verbose=0))
                    top_region_prediction_classes = np.array([x.argmax() for x in predictions])
                    # Only the first region classified as a flare is used.
                    flare_cutout_indeces = np.where(top_region_prediction_classes > 0)[0]
                    if len(flare_cutout_indeces) == 0:
                        # The cutout model did not confirm a flare. Do not enter
                        # the flaring state: flare_coords is still None and the
                        # next window must retry the detection instead of
                        # drawing a box. The frame is skipped (not saved).
                        continue
                    in_flaring_period = True
                    flare_cutout_index = flare_cutout_indeces[0]
                    flare_coords = top_regions_coords[flare_cutout_index]
                    flare_class = top_region_prediction_classes[flare_cutout_index]

                    month_logfile.write(f'from: {GetCleanAIAFilename(filepaths[i-sequence_length])} to: {GetCleanAIAFilename(filepaths[i+1])}\n')
                    month_logfile.write(str(predictions)+'\n')
                    month_logfile.write(str(predictions[flare_cutout_index])+'\n')
                    month_logfile.write(str(top_regions_coords)+'\n')
                    month_logfile.write(str(flare_coords)+'\n')
                    month_logfile.write(str(flare_classes[flare_class])+'\n')
                    month_logfile.write('--------------------------------------------------------------------------------------\n')

                # Box around the tracked flare, on this and every following
                # frame of the same flaring period.
                tl = (flare_coords[0]-half_box, flare_coords[1]-half_box)
                br = (flare_coords[0]+half_box, flare_coords[1]+half_box)
                draw.rectangle((tl, br), outline="red")

            elif pos == 0:
                # Quiet sun: forget the tracked flare.
                in_flaring_period = False
                flare_coords = None
                flare_class = 0

            filename = filepaths[i-sequence_length].rsplit('/')[-1].rsplit('.', 1)[0]
            draw.text((12, 490), filename, (255, 255, 255))
            draw.text((400, 490), flare_classes[flare_class], (255, 255, 255))
            first_image_pil.save(f'{save_dir}/{i-sequence_length}.png', "PNG")

In [79]:
def TestWritePredictedCutouts(save_dir, filepaths, full_disk_model, cutout_model, flare_classes, image_size=64, sequence_length=5, stride=32):
    """Debug variant of the cutout pipeline that returns the first flare window.

    Empties ``save_dir``, truncates ``./logs/flare_detections/cutout_model.txt``
    and slides a window of ``sequence_length`` difference images over the
    data. For windows where ``full_disk_model`` predicts class 0 the first raw
    frame of the window is saved as ``<save_dir>/<index>.png``. At the first
    window with a predicted class > 0 the function returns early.

    NOTE(review): because of that early ``return`` everything after it in the
    same branch (cutout-model prediction, coordinate computation, rectangle
    drawing) is unreachable, and the log file opened at the top is not closed
    on that path.

    Returns:
        tuple ``(past_images, past_images_cutouts)`` for the first window with
        a predicted class > 0, or ``None`` if no such window exists.
    """
    delete_files(save_dir)
    images_raw = [np.load(x)['x'] for x in filepaths]
    images_raw_diff = [abs(images_raw[x]-images_raw[x-1]) for x in range(1, len(images_raw))]
    # NOTE(review): `images` is computed but never used below.
    images = [cv2.resize(x, (image_size, image_size), interpolation = cv2.INTER_AREA) for x in images_raw]
    images_diff = [cv2.resize(x, (image_size, image_size), interpolation = cv2.INTER_AREA) for x in images_raw_diff]
    in_flaring_period = False
    # coords for the center of the flare
    flare_coords = None
    flare_class = 0
    base_logfile_name = f'./logs/flare_detections/cutout_model'
    # Opened in 'w' mode: an existing log is truncated; nothing is written to it in this function.
    month_logfile = open(f'{base_logfile_name}.txt', 'w')
    all_flare_coords = None
    
    for i in range(sequence_length, len(images_raw_diff)):
        past_images = images_diff[i-sequence_length:i]
        past_images = np.array(past_images)
        past_images = np.expand_dims(past_images, 0)
        past_images = np.expand_dims(past_images, 4)
        # print(past_images.shape)
        # shape is now (1, sequence_length, image_size, image_size, 1), assuming 2-D frames
        prediction = full_disk_model.predict(past_images, verbose=0)
        pos = prediction[0].argmax()
        past_images_pil = [Image.fromarray(x).convert("RGBA") for x in np.uint8(images_raw[i-sequence_length:i])]
        past_images_raw = images_raw[i-sequence_length:i]
        # Only the first frame of the window is annotated and saved.
        first_image_pil = past_images_pil[0]
        draw = ImageDraw.Draw(first_image_pil)
            
        if pos > 0:
            if not in_flaring_period:
                in_flaring_period = True
                past_images_raw_diff = images_raw_diff[i-sequence_length:i]
                past_images_cutouts = np.array([GetImageCutouts(x, 64, stride) for x in past_images_raw_diff])
                # shape = sequence_length, N, 64, 64
                # Reorder so that the cutout index comes first.
                past_images_cutouts = np.array([past_images_cutouts[:, x] for x in range(past_images_cutouts.shape[1])])
                # NOTE(review): debugging early exit — the rest of this branch never runs
                # and month_logfile is left open.
                return past_images, past_images_cutouts
                # print(past_images_cutouts.shape)
                # return past_images_cutouts
                # shape = N, sequence_length, 64, 64
                predictions = []
                for cutout_sequence in past_images_cutouts:
                    cutout_sequence_merged = np.concatenate([past_images[:, :, :, :, 0], np.expand_dims(cutout_sequence, 0)], axis=1)
                    predictions.append(cutout_model.predict(np.expand_dims(cutout_sequence_merged, 4), verbose=0))
                    
                # month_logfile.write(f'from: {GetCleanAIAFilename(filepaths[i-sequence_length])} to: {GetCleanAIAFilename(filepaths[i+1])}\n')
                # month_logfile.write(str(predictions)+'\n')
                # month_logfile.write(str(predictions[flare_cutout_index])+'\n')
                # month_logfile.write(str(top_regions_coords)+'\n')
                # month_logfile.write(str(flare_coords)+'\n')
                # month_logfile.write(str(flare_classes[flare_class])+'\n')
                # month_logfile.write('--------------------------------------------------------------------------------------\n')
                
                # for idx, c in enumerate(past_images_cutouts):
                #     cutouts = np.expand_dims(c, 0)
                #     cutouts = np.expand_dims(cutouts, 4)
                #     # shape = 1, 5, 64, 64, 1
                #     predictions.append(cutout_model.predict(cutouts, verbose=0))
                predictions = np.array(predictions)
                predicted_classes = np.array([x.argmax() for x in predictions])
                flare_cutout_indeces = np.where(predicted_classes > 0)[0]
                # NOTE(review): uses true division and assumes 512/stride cutouts per row;
                # GetImageCutouts produces fewer per row for a 512-wide image — verify before reviving.
                flare_coords = [(x/(512/stride)*stride, x%(512/stride)*stride) for x in flare_cutout_indeces]
                
            for coord in flare_coords:
                tl, br = (coord[0]-32, coord[1]-32), (coord[0]+32, coord[1]+32)
                draw = ImageDraw.Draw(first_image_pil)
                draw.rectangle((tl, br), outline="red")
            
        elif pos == 0:
            # Quiet window: reset the tracking state.
            in_flaring_period = False
            flare_coords = None
            flare_class = 0
            all_flare_coords = None
            
        filename = filepaths[i-sequence_length].rsplit('/')[-1].rsplit('.', 1)[0]
        draw.text((12, 490),filename,(255,255,255))
        draw.text((400, 490), flare_classes[flare_class], (255,255,255))
        first_image_pil.save(f'{save_dir}/{i-sequence_length}.png', "PNG")
        
    month_logfile.close()

In [34]:
def WritePredictedImages(save_dir, title, filepaths, model, flare_classes, image_size=64, sequence_length=5, rect_size=64):
    """Write PNG frames annotated with a single model's class prediction.

    A window of ``sequence_length`` downscaled difference images is fed to
    ``model``; the first raw frame of the window is saved with its file name
    and the ``PosToClass`` label drawn on it. When the predicted class index
    is greater than 1, a red ``rect_size`` box is drawn around the top region
    reported by ``GetImageTopNRegionsCoords`` for that frame. Every prediction
    is also logged to ``./videos/logs/<title>.txt``.

    NOTE(review): a later cell defines another ``WritePredictedImages`` with a
    different signature, which replaces this one once that cell has run.

    Args:
        save_dir: directory that is emptied first and then receives
            ``<index>.png`` frames.
        title: base name of the log file.
        filepaths: ordered paths of ``.npz`` files holding an ``'x'`` array.
        model: object with a Keras-style ``predict``.
        flare_classes: unused here (labels come from ``PosToClass``); kept
            for interface compatibility.
        image_size: side length the difference images are resized to.
        sequence_length: number of difference images per window.
        rect_size: side length of the box drawn around a detection.

    Returns:
        None.
    """
    delete_files(save_dir)
    images_raw = [np.load(x)['x'] for x in filepaths]
    images_raw_diff = [abs(images_raw[x]-images_raw[x-1]) for x in range(1, len(images_raw))]
    images_diff = [cv2.resize(x, (image_size, image_size), interpolation=cv2.INTER_AREA) for x in images_raw_diff]
    half = rect_size//2

    # Context manager so the log is closed even if a prediction or save fails.
    with open(f'./videos/logs/{title}.txt', 'w') as logfile:
        for i in range(sequence_length, len(images_diff)):
            past_images = np.array(images_diff[i-sequence_length:i])
            past_images = np.expand_dims(past_images, 0)
            past_images = np.expand_dims(past_images, 4)
            # shape: (1, sequence_length, image_size, image_size, 1) for 2-D frames
            prediction = model.predict(past_images, verbose=0)
            pos = prediction[0].argmax()

            # Only the first frame of the window is annotated and saved.
            first_image_raw = images_raw[i-sequence_length]
            first_image_pil = Image.fromarray(np.uint8(first_image_raw)).convert("RGBA")

            logfile.write(f'from: {GetCleanAIAFilename(filepaths[i-sequence_length])} to: {GetCleanAIAFilename(filepaths[i+1])}\n')
            logfile.write(str(prediction[0])+'\n')
            logfile.write(str(pos)+'\n')
            logfile.write('--------------------------------------------------------------------------------------\n')

            draw = ImageDraw.Draw(first_image_pil)
            if pos > 1:
                # The helper's coordinate is used as (second, first) for drawing
                # (presumably (row, col) -> (x, y); confirm against the helper).
                center_coord = GetImageTopNRegionsCoords(first_image_raw, 1)[0]
                tl = (center_coord[1]-half, center_coord[0]-half)
                br = (center_coord[1]+half, center_coord[0]+half)
                draw.rectangle((tl, br), outline="red")
            filename = filepaths[i-sequence_length].rsplit('/')[-1].rsplit('.', 1)[0]
            draw.text((12, 490), filename, (255, 255, 255))
            draw.text((400, 490), PosToClass(pos), (255, 255, 255))
            first_image_pil.save(f'{save_dir}/{i-sequence_length}.png', "PNG")

In [19]:
def WriteStartEndImages(save_dir, title, filepaths, start_model, end_model, image_size=64, sequence_length=5, rect_size=64):
    """Write annotated PNG frames while alternating between two models.

    The function starts with ``start_model``. When the active model is
    ``start_model`` and it predicts a class index > 0, a flare location is
    taken from ``GetImageTopNRegionsCoords`` on the first raw frame of the
    window, a red box is drawn and ``end_model`` becomes active. While
    ``end_model`` is active, a class index > 0 clears the location and
    switches back to ``start_model``; otherwise the stored box is redrawn.
    Every prediction is logged to ``./videos/logs/<title>.txt``.

    Args:
        save_dir: directory that is emptied first and then receives
            ``<index>.png`` frames.
        title: base name of the log file.
        filepaths: ordered paths of ``.npz`` files holding an ``'x'`` array.
        start_model: model with a Keras-style ``predict``, active initially.
        end_model: model with a Keras-style ``predict``, active after a
            detection by ``start_model``.
        image_size: side length the difference images are resized to.
        sequence_length: number of difference images per window.
        rect_size: side length of the box drawn around the flare.

    Returns:
        None.
    """
    delete_files(save_dir)

    model = start_model
    cur_flare_location = None
    cur_flare_class = 'N'
    half = rect_size//2

    images_raw = [np.load(x)['x'] for x in filepaths]
    images_raw_diff = [abs(images_raw[x]-images_raw[x-1]) for x in range(1, len(images_raw))]
    images_diff = [cv2.resize(x, (image_size, image_size), interpolation=cv2.INTER_AREA) for x in images_raw_diff]

    # Context manager so the log is closed even if a prediction or save fails.
    with open(f'./videos/logs/{title}.txt', 'w') as logfile:
        for i in range(sequence_length, len(images_diff)):
            past_images = np.array(images_diff[i-sequence_length:i])
            past_images = np.expand_dims(past_images, 0)
            past_images = np.expand_dims(past_images, 4)
            # shape: (1, sequence_length, image_size, image_size, 1) for 2-D frames
            prediction = model.predict(past_images, verbose=0)
            pos = prediction[0].argmax()
            cur_flare_class = PosToClass(pos)

            # Only the first frame of the window is annotated and saved.
            first_image_raw = images_raw[i-sequence_length]
            first_image_pil = Image.fromarray(np.uint8(first_image_raw)).convert("RGBA")

            logfile.write(f'from: {GetCleanAIAFilename(filepaths[i-sequence_length])} to: {GetCleanAIAFilename(filepaths[i+1])}\n')
            logfile.write(str(prediction[0])+'\n')
            logfile.write(str(pos)+'\n')
            logfile.write('--------------------------------------------------------------------------------------\n')

            draw = ImageDraw.Draw(first_image_pil)
            if model == start_model:
                if pos > 0:
                    # Flare detected: remember where, then hand over to the end model.
                    model = end_model
                    cur_flare_location = GetImageTopNRegionsCoords(first_image_raw, 1)[0]
                    tl = (cur_flare_location[1]-half, cur_flare_location[0]-half)
                    br = (cur_flare_location[1]+half, cur_flare_location[0]+half)
                    draw.rectangle((tl, br), outline="red")
            elif model == end_model:
                if pos > 0:
                    # End model fired: the flare is over, go back to start detection.
                    cur_flare_location = None
                    cur_flare_class = 'N'
                    model = start_model
                else:
                    # Flare still ongoing: keep marking the stored location.
                    tl = (cur_flare_location[1]-half, cur_flare_location[0]-half)
                    br = (cur_flare_location[1]+half, cur_flare_location[0]+half)
                    draw.rectangle((tl, br), outline="red")
            filename = filepaths[i-sequence_length].rsplit('/')[-1].rsplit('.', 1)[0]
            draw.text((12, 490), filename, (255, 255, 255))
            draw.text((400, 490), cur_flare_class, (255, 255, 255))
            first_image_pil.save(f'{save_dir}/{i-sequence_length}.png', "PNG")

In [20]:
def GetClosestFlareToDatetime(dt):
    """Return the event row whose ``event_starttime`` is nearest to ``dt``.

    Events are read from
    ``./event_records/new_events_by_date/<year>/<month>.csv`` (month not
    zero-padded). On ties the earliest row in the file wins.

    Args:
        dt: ``datetime.datetime`` to compare against.

    Returns:
        pandas.Series: the closest row of the CSV.
    """
    events = pd.read_csv(f'./event_records/new_events_by_date/{dt.year}/{dt.month}.csv')

    best_row = events.iloc[0]
    best_gap = abs(dt - parse_time(best_row['event_starttime'], precision=0).datetime)
    for _, candidate in events.iterrows():
        gap = abs(dt - parse_time(candidate['event_starttime'], precision=0).datetime)
        # Strict comparison keeps the first row among equally distant events.
        if gap < best_gap:
            best_row, best_gap = candidate, gap

    return best_row

In [21]:
def CalculateDetectionError(filepaths, start_model, end_model, image_size=64, sequence_length=5):
    """Log start/end timing errors of a two-model flare detector.

    The active model starts as ``start_model``. On a window with predicted
    class index > 0 while ``start_model`` is active, the window's first frame
    time is taken as the predicted flare start, the closest catalogued flare
    is looked up with ``GetClosestFlareToDatetime`` and ``end_model`` becomes
    active. On a window with class index > 0 while ``end_model`` is active,
    the frame time is taken as the predicted flare end and ``start_model``
    becomes active again. Errors against the catalogued start/end times are
    accumulated and written to
    ``./logs/flare_detections/flare_detections.txt``; every raw prediction
    goes to ``./logs/flare_detections/flare_detections_full_log.txt``.

    Args:
        filepaths: ordered paths of ``.npz`` files holding an ``'x'`` array.
        start_model: model with a Keras-style ``predict`` used to find starts.
        end_model: model with a Keras-style ``predict`` used to find ends.
        image_size: side length the difference images are resized to.
        sequence_length: number of difference images per window.

    Returns:
        None.
    """
    model = start_model
    closest_flare = None
    start_total_error_time = datetime.timedelta(seconds=0)
    end_total_error_time = datetime.timedelta(seconds=0)
    predicted_flares = 0
    images_raw = [np.load(x)['x'] for x in filepaths]
    images_raw_diff = [abs(images_raw[x]-images_raw[x-1]) for x in range(1, len(images_raw))]
    images_diff = [cv2.resize(x, (image_size, image_size), interpolation=cv2.INTER_AREA) for x in images_raw_diff]

    # Both logs are closed on every exit path, including exceptions.
    with open('./logs/flare_detections/flare_detections.txt', 'w') as logfile, \
            open('./logs/flare_detections/flare_detections_full_log.txt', 'w') as pred_logfile:
        for i in range(sequence_length, len(images_diff)):
            past_images = np.array(images_diff[i-sequence_length:i])
            past_images = np.expand_dims(past_images, 0)
            past_images = np.expand_dims(past_images, 4)
            # shape: (1, sequence_length, image_size, image_size, 1) for 2-D frames
            prediction = model.predict(past_images, verbose=0)
            pos = prediction[0].argmax()
            cur_flare_class = PosToClass(pos)

            pred_logfile.write(f'from: {GetCleanAIAFilename(filepaths[i-sequence_length])} to: {GetCleanAIAFilename(filepaths[i+1])}\n')
            pred_logfile.write(str(prediction[0])+'\n')
            pred_logfile.write(str(pos)+'\n')
            pred_logfile.write('--------------------------------------------------------------------------------------\n')

            if pos > 0:
                if model == start_model:
                    # Flare detected
                    # File stem "AIA<date>_<time>_0094" -> "<date>_<time>"
                    predicted_start_time = parse_time(filepaths[i-sequence_length].rsplit('.', 1)[0].rsplit('/', 1)[-1].rsplit('_', 1)[0][3:]).datetime
                    closest_flare = GetClosestFlareToDatetime(predicted_start_time)
                    model = end_model
                    start_predict_error = abs(predicted_start_time-parse_time(closest_flare['event_starttime'], precision=0).datetime)
                    start_total_error_time += start_predict_error
                    predicted_flares += 1

                    logfile.write(f'predicted flare START: {predicted_start_time} class {cur_flare_class}\n')
                    logfile.write(f"closest true flare START: {closest_flare['event_starttime']} class {closest_flare['fl_goescls']}\n")
                    logfile.write(f"error: {start_predict_error}\n")
                    logfile.write('--------------------------------------------------------------\n')
                elif model == end_model:
                    # Flare end detected
                    predicted_end_time = parse_time(filepaths[i-sequence_length].rsplit('.', 1)[0].rsplit('/', 1)[-1].rsplit('_', 1)[0][3:]).datetime
                    model = start_model
                    # The end error must be measured from the predicted END
                    # time, not from the start time of the same flare.
                    end_predict_error = abs(predicted_end_time-parse_time(closest_flare['event_endtime'], precision=0).datetime)
                    end_total_error_time += end_predict_error

                    logfile.write(f'predicted flare END: {predicted_end_time} class {cur_flare_class}\n')
                    logfile.write(f"closest true flare END: {closest_flare['event_endtime']} class {closest_flare['fl_goescls']}\n")
                    logfile.write(f"error: {end_predict_error}\n")
                    logfile.write('--------------------------------------------------------------\n')

        if predicted_flares:
            # Both averages are taken over the number of detected starts.
            logfile.write(f'average start error time: {start_total_error_time/predicted_flares}\n')
            logfile.write(f'average end error time: {end_total_error_time/predicted_flares}\n')
        else:
            # Dividing a timedelta by zero raises; report the empty case instead.
            logfile.write('average start error time: n/a (no flares predicted)\n')
            logfile.write('average end error time: n/a (no flares predicted)\n')

In [22]:
def _ParseAIAFrameTime(filepath):
    """Return the datetime encoded in a path ending in ``AIA<YYYYMMDD>_<HHMM>_<suffix>.<ext>``."""
    time_string = filepath.rsplit('.', 1)[0].rsplit('/', 1)[-1].rsplit('_', 1)[0][3:]
    return datetime.datetime(
        int(time_string[:4]), int(time_string[4:6]), int(time_string[6:8]),
        int(time_string[9:11]), int(time_string[11:13]),
    )


def CalculateDetectionErrorSingleModel(months_filepaths, start_model, flare_classes, image_size=64, sequence_length=5):
    """Log start/end timing errors of a single-model flare detector, per month.

    For each list of file paths in ``months_filepaths`` the model is run over
    sliding windows of ``sequence_length`` difference images. The first
    window with class index > 0 after a quiet period marks a predicted flare
    start; the first class-0 window after at least two consecutive flare
    windows marks the predicted end (one 6-minute step before that window's
    first frame). Errors against the closest catalogued flare are written to
    ``./logs/flare_detections/flare_detections_single_model_<n>.txt``, raw
    predictions to ``..._single_model_full_log_<n>.txt`` and overall averages
    to ``./logs/flare_detections/flare_detections_averages.txt``.

    Args:
        months_filepaths: iterable of lists of ``.npz`` paths, one per month.
        start_model: model with a Keras-style ``predict``.
        flare_classes: sequence of labels indexed by predicted class index.
        image_size: side length the difference images are resized to.
        sequence_length: number of difference images per window.

    Returns:
        None.
    """
    base_logfile_name = './logs/flare_detections/flare_detections_single_model'
    base_pred_logfile_name = './logs/flare_detections/flare_detections_single_model_full_log'
    start_total_error_time = datetime.timedelta(seconds=0)
    end_total_error_time = datetime.timedelta(seconds=0)
    total_predicted_flares = 0
    no_flares_note = 'n/a (no flares predicted)'

    # All log files are closed on every exit path, including exceptions.
    with open('./logs/flare_detections/flare_detections_averages.txt', 'w') as global_logfile:
        for mi, filepaths in enumerate(months_filepaths):
            with open(f'{base_logfile_name}_{mi+1}.txt', 'w') as month_logfile, \
                    open(f'{base_pred_logfile_name}_{mi+1}.txt', 'w') as month_pred_logfile:
                print(f'processing month {mi+1}')
                model = start_model
                closest_flare = None
                predicted_flares = 0
                detected_mode = False
                consecutive_preds = 0
                images_raw = [np.load(x)['x'] for x in filepaths]
                images_raw_diff = [abs(images_raw[x]-images_raw[x-1]) for x in range(1, len(images_raw))]
                images_diff = [cv2.resize(x, (image_size, image_size), interpolation=cv2.INTER_AREA) for x in images_raw_diff]
                month_average_start_error = datetime.timedelta(seconds=0)
                month_average_end_error = datetime.timedelta(seconds=0)

                for i in range(sequence_length, len(images_diff)):
                    past_images = np.array(images_diff[i-sequence_length:i])
                    past_images = np.expand_dims(past_images, 0)
                    past_images = np.expand_dims(past_images, 4)
                    # shape: (1, sequence_length, image_size, image_size, 1) for 2-D frames
                    prediction = model.predict(past_images, verbose=0)
                    pos = prediction[0].argmax()
                    cur_flare_class = flare_classes[pos]

                    month_pred_logfile.write(f'from: {GetCleanAIAFilename(filepaths[i-sequence_length])} to: {GetCleanAIAFilename(filepaths[i+1])}\n')
                    month_pred_logfile.write(str(prediction[0])+'\n')
                    month_pred_logfile.write(str(pos)+'\n')
                    month_pred_logfile.write('--------------------------------------------------------------------------------------\n')

                    if pos > 0:
                        if detected_mode:
                            # Still inside the same flare: just count the window.
                            consecutive_preds += 1
                            continue
                        # flare start detected
                        detected_mode = True
                        predicted_start_time = _ParseAIAFrameTime(filepaths[i-sequence_length])

                        try:
                            closest_flare = GetClosestFlareToDatetime(predicted_start_time)
                        except Exception:
                            # No usable event catalogue for this month: stop
                            # processing it. `Exception` (rather than a bare
                            # except) lets KeyboardInterrupt propagate.
                            break
                        start_predict_error = abs(predicted_start_time-parse_time(closest_flare['event_starttime'], precision=0).datetime)
                        start_total_error_time += start_predict_error
                        month_average_start_error += start_predict_error
                        predicted_flares += 1
                        total_predicted_flares += 1
                        consecutive_preds += 1
                        month_logfile.write(f'predicted flare START: {predicted_start_time} class {cur_flare_class}\n')
                        month_logfile.write(f"closest true flare START: {closest_flare['event_starttime']} class {closest_flare['fl_goescls']}\n")
                        month_logfile.write(f"error: {start_predict_error}\n")
                        month_logfile.write('--------------------------------------------------------------\n')

                    elif pos == 0:
                        if detected_mode:
                            if consecutive_preds < 2:
                                # A single flare window is not treated as a
                                # flare with a measurable end.
                                consecutive_preds = 0
                                detected_mode = False
                                continue
                            # flare end detected: one 6-minute step before this frame
                            predicted_end_time = _ParseAIAFrameTime(filepaths[i-sequence_length]) + datetime.timedelta(minutes=-6)

                            end_predict_error = abs(predicted_end_time-parse_time(closest_flare['event_endtime'], precision=0).datetime)
                            end_total_error_time += end_predict_error
                            month_average_end_error += end_predict_error

                            month_logfile.write(f'predicted flare END: {predicted_end_time} class {cur_flare_class}\n')
                            month_logfile.write(f"closest true flare END: {closest_flare['event_endtime']} class {closest_flare['fl_goescls']}\n")
                            month_logfile.write(f"error: {end_predict_error}\n")
                            month_logfile.write('--------------------------------------------------------------\n')
                        consecutive_preds = 0
                        detected_mode = False

                if predicted_flares:
                    month_logfile.write(f'\naverage month start error time: {month_average_start_error/predicted_flares}\n')
                    month_logfile.write(f'average month end error time: {month_average_end_error/predicted_flares}\n')
                else:
                    # Dividing a timedelta by zero raises and would abort the
                    # remaining months; report the empty case instead.
                    month_logfile.write(f'\naverage month start error time: {no_flares_note}\n')
                    month_logfile.write(f'average month end error time: {no_flares_note}\n')

        if total_predicted_flares:
            global_logfile.write(f'average start error time: {start_total_error_time/total_predicted_flares}\n')
            global_logfile.write(f'average end error time: {end_total_error_time/total_predicted_flares}\n')
        else:
            global_logfile.write(f'average start error time: {no_flares_note}\n')
            global_logfile.write(f'average end error time: {no_flares_note}\n')

In [23]:
def WriteVideo(frames_dir, video_name, fps=8, frame_shape=(512, 512, 4)):
    """Encode the PNG frames found directly inside ``frames_dir`` into an MP4.

    The video is written to ``{VIDEOS_DIR}/{video_name}.mp4``. Frames are
    ordered with the notebook-level ``key_func`` sort key.

    Args:
        frames_dir: Directory whose top level is scanned for ``.png`` files
            (sub-directories are not visited).
        video_name: Output file name without the ``.mp4`` extension.
        fps: Frame rate handed to the imageio writer.
        frame_shape: Array shape every frame must have; frames with any
            other shape are reported on stdout and skipped.
    """
    # os.walk yields the top directory first; taking only that entry keeps the
    # scan non-recursive and produces no frames (instead of raising) when the
    # directory does not exist.
    _, _, top_level_files = next(os.walk(frames_dir), (frames_dir, [], []))
    frame_paths = sorted(
        (os.path.join(frames_dir, name)
         for name in top_level_files
         if name.endswith('.png')),
        key=key_func,
    )
    expected_shape = tuple(frame_shape)

    writer = imageio.get_writer(f'{VIDEOS_DIR}/{video_name}.mp4', fps=fps)
    try:
        for frame_path in frame_paths:
            frame = imageio.imread(frame_path)
            if frame.shape != expected_shape:
                # Report the unexpected shape and leave this frame out.
                print(frame.shape)
                continue
            writer.append_data(frame)
    finally:
        # Always release the writer, even if reading or encoding a frame fails.
        writer.close()

In [24]:
def WritePredictedImages(save_dir, title, filepaths, flare_classes, rect_size=64):
    """Annotate frames with a region box, file name and class label and save them as PNGs.

    Calls ``delete_files(save_dir)``, loads the ``'x'`` array of every file in
    ``filepaths`` with ``np.load``, then for each loaded image draws on
    ``first_image_pil`` and saves it under ``save_dir``.

    NOTE(review): this body looks incomplete. It reads ``first_image_pil``,
    ``pos``, ``past_images_raw``, ``i``, ``sequence_length`` and ``logfile``,
    none of which are assigned inside this function, so it raises NameError
    unless those names happen to exist at notebook level. ``title`` and
    ``flare_classes`` are accepted but never used, and ``irradiance`` is
    computed but never read. Commented-out callers in this notebook also pass
    ``model``, ``image_size`` and ``sequence_length``, which this signature
    does not accept -- confirm the intended signature before use.

    Args:
        save_dir: Output directory; passed to ``delete_files`` first and then
            used as the prefix of every saved PNG path.
        title: Unused in the visible body.
        filepaths: Paths of the ``.npz`` files to load; also indexed to build
            the caption drawn on each frame.
        flare_classes: Unused in the visible body.
        rect_size: Side length, in pixels, of the square drawn around the
            detected region.
    """
    delete_files(save_dir)
    # Every file is loaded up front, so the whole date range is held in memory.
    images_raw = [np.load(x)['x'] for x in filepaths]
    
    for im in images_raw:
        # Peak pixel value of the frame; not used further below.
        irradiance = im.max()
        

        # NOTE(review): first_image_pil is not defined in this function.
        draw = ImageDraw.Draw(first_image_pil)
        # NOTE(review): pos is not defined in this function; presumably the
        # predicted class index -- confirm.
        if pos > 1:
            first_image_raw = past_images_raw[0]
            # Take the first of the regions returned by the detector helper.
            center_coord = GetImageTopNRegionsCoords(first_image_raw, 1)[0]
            # center_coord is swapped when building the corners, so it is
            # presumably (row, col) while PIL wants (x, y) -- confirm.
            tl, br = (center_coord[1]-rect_size//2, center_coord[0]-rect_size//2), (center_coord[1]+rect_size//2, center_coord[0]+rect_size//2)
            draw.rectangle((tl, br), outline="red")
        # Caption is the file's base name without directory or extension.
        filename = filepaths[i-sequence_length].rsplit('/')[-1].rsplit('.', 1)[0]
        draw.text((12, 490),filename,(255,255,255))
        draw.text((400, 490),PosToClass(pos),(255,255,255))
        first_image_pil.save(f'{save_dir}/{i-sequence_length}.png', "PNG")
    # NOTE(review): logfile is never opened in this function.
    logfile.close()

In [101]:
# Run configuration: class labels and the date window whose data files are
# collected. The trailing 6 is presumably the cadence in minutes -- confirm
# against GetFilepathsBetweenDates.
flare_classes = ['H', 'M', 'X']
start_datetime = datetime.datetime(2013, 5, 1)
end_datetime = datetime.datetime(2013, 5, 31)
filepaths = GetFilepathsBetweenDates(start_datetime, end_datetime, 6)

In [102]:
# Load the two saved Keras models used by the cutout-prediction run below.
# `keras` is bound explicitly by the notebook's import cell, whereas no cell
# shown there imports `tf`, so going through `keras` avoids depending on a
# name leaked by a star-import.
full_disk_model = keras.models.load_model('./best_trained_models/ALL_lstm_data_nmx_during_leftout2013_cadence6.h5')
cutout_model = keras.models.load_model('./best_trained_models/HMX_cadence6_frame6_binary.h5')

In [103]:
# WritePredictedImages(FRAMES_MARKED_DIR, 'new_NMX_test', filepaths, full_disk_model, flare_classes, image_size=64, sequence_length=5, rect_size=64)



In [104]:
# data_point = 'AIA20130109_1434_0094'
# full = []
# for subdir, dirs, files in os.walk(f'./new_data/cadence6_frame6/val/H/{data_point}/0/full'):
#     for f in files:
#         full.append(os.path.join(subdir, f))
# cutouts = []
# for subdir, dirs, files in os.walk(f'./new_data/cadence6_frame6/val/H/{data_point}/0/sequence'):
#     for f in files:
#         cutouts.append(os.path.join(subdir, f))
        
# full = sorted(full)
# cutouts = sorted(cutouts)
# full = [np.load(x) for x in full]
# cutouts = [np.load(x) for x in cutouts]
# full = [abs(abs(full[x]) - abs(full[x - 1])) for x in range(1, 6)]
# cutouts = [abs(abs(cutouts[x]) - abs(cutouts[x - 1])) for x in range(1, 6)]
# full = [cv2.resize(x, (64, 64), interpolation = cv2.INTER_AREA) for x in full]
# merged = np.concatenate([full, cutouts])
# merged = np.expand_dims(merged, 0)
# merged = np.expand_dims(merged, 4)
# cutout_model.predict(merged)

In [105]:
# Run the cutout prediction pass over the selected files with both models;
# the frame directory is the first argument.
WritePredictedCutouts(
    CUTOUT_FRAMES_MARKED,
    filepaths,
    full_disk_model,
    cutout_model,
    flare_classes,
    image_size=64,
    sequence_length=5,
    stride=16,
)

In [106]:
# Encode the frames in CUTOUT_FRAMES_MARKED into the 'final_demo' video.
WriteVideo(frames_dir=CUTOUT_FRAMES_MARKED, video_name='final_demo')

  im = imageio.imread(file)


In [36]:
# fig, axes = plt.subplots(2, 3, figsize=(10, 8))

# for idx, ax in enumerate(axes.flat):
#     ax.imshow(trc[2][idx], cmap='jet')
#     ax.set_title(f"Frame {idx + 1}")
#     ax.axis("off")

# plt.show()

In [None]:
# CalculateDetectionError(filepaths, start_model, end_model)

In [None]:
# CalculateDetectionErrorSingleModel(filepaths, start_model, flare_classes)

In [105]:
# data_title = 'ALL_lstm_data_nmx_during_leftout2013_cadence6'
# # start_model = tf.keras.models.load_model(f'./best_trained_models/ALL_lstm_data_hmx_new_during_leftout2013_cadence6.h5')
# # end_model = tf.keras.models.load_model(f'./best_trained_models/ALL_lstm_data_nmx_end_leftout2013_cadence6.h5')
# batch_size, sequence_length, image_size, num_classes = 128, 6, 64, 3
# # start_datetime, end_datetime = datetime.datetime(2016, 1, 1), datetime.datetime(2016, 1, 3)
# start_datetime, end_datetime = datetime.datetime(2013, 5, 2), datetime.datetime(2013, 5, 31)
# filepaths = GetFilepathsBetweenDates(start_datetime, end_datetime, 6)
# model = tf.keras.models.load_model(f'./best_trained_models/{data_title}.h5')
# # model = ConvLSTMModelAllClass(batch_size, image_size, sequence_length-1, num_classes)
# # model.load_weights(f'{LSTM_CHECKPOINTS_DIR}/full_image_duringflare_nmx_leftout2013_bidirectional_convlstm')
# # WriteStartEndImages(FRAMES_MARKED_DIR, data_title, filepaths, start_model, end_model, image_size=64, sequence_length=sequence_length-1, rect_size=64)

# WritePredictedImages(FRAMES_MARKED_DIR, data_title, filepaths, model, flare_classes, image_size=64, sequence_length=sequence_length-1, rect_size=64)
# # WritePredictedCutouts(CUTOUT_FRAMES_MARKED, filepaths, model, image_size=64, sequence_length=sequence_length-1)
# # video_title = f'full_image_{parse_time(start_datetime, precision=0).fits}~{parse_time(end_datetime, precision=0).fits}'
# WriteVideo(FRAMES_MARKED_DIR, data_title)

In [106]:
# model = ConvLSTMModelAllClass(batch_size, image_size, sequence_length)
# model.load_weights(f'{LSTM_CHECKPOINTS_DIR}/cutout_image_duringflare_bidirectional_convlstm')
# # c = WritePredictedCutouts(CUTOUT_FRAMES_MARKED, filepaths, model, image_size=64, sequence_length=5)
# video_title = f'full_image_{parse_time(start_datetime, precision=0).fits}~{parse_time(end_datetime, precision=0).fits}'
# WriteVideo(CUTOUT_FRAMES_MARKED, video_title)

In [22]:
def CreateAIAVideo(start_datetime, end_datetime, video_name=None):
    """Encode the frames currently in ``FRAMES_DIR`` into a video.

    Frame generation is disabled (see the commented-out lines), so this only
    encodes whatever PNGs already exist in ``FRAMES_DIR``; the two datetimes
    are used solely to build the default output name.

    Args:
        start_datetime: ``datetime.datetime`` marking the start of the range.
        end_datetime: ``datetime.datetime`` marking the end of the range.
        video_name: Output name (without extension) forwarded to
            ``WriteVideo``. Defaults to
            ``aia_<start:YYYYmmdd_HHMM>_<end:YYYYmmdd_HHMM>``.
    """
    # filepaths = GetFilepathsBetweenDates(start_datetime, end_datetime)
    # WriteImages(filepaths)
    if video_name is None:
        # WriteVideo requires an output name; derive a filesystem-safe one
        # from the requested date range.
        video_name = f'aia_{start_datetime:%Y%m%d_%H%M}_{end_datetime:%Y%m%d_%H%M}'
    WriteVideo(FRAMES_DIR, video_name)

In [None]:
def DrawRectangle(img, top_left, bottom_right, save_dir):
    """Draw a rectangle outline on a copy of ``img`` and write the result to disk.

    Args:
        img: Image to annotate. It is copied first, so the caller's object is
            not drawn on.
        top_left: First corner of the rectangle, passed to ``cv2.rectangle``.
        bottom_right: Opposite corner of the rectangle.
        save_dir: Destination handed to ``cv2.imwrite``; despite the name it
            is used as the output file path.
    """
    annotated = img.copy()
    # Use the function's own corner parameters. (0, 0, 255) is red when the
    # image is in OpenCV's usual BGR channel order.
    annotated = cv2.rectangle(annotated, top_left, bottom_right, color=(0, 0, 255), thickness=3)
    cv2.imwrite(save_dir, annotated)

In [None]:
def CreateAIAVideoPrediction(start_datetime, end_datetime, model):
    """Gather the data files between two datetimes and pass them, with ``model``, to ``WritePredictedImages``.

    Returns whatever ``WritePredictedImages`` returns.

    NOTE(review): the ``WritePredictedImages`` defined in this notebook takes
    ``(save_dir, title, filepaths, flare_classes, rect_size=64)``, so this
    two-argument call does not match it. ``GetFilepathsBetweenDates`` is also
    called with a third argument elsewhere in the notebook. Confirm the
    intended arguments before relying on this helper.
    """
    return WritePredictedImages(
        GetFilepathsBetweenDates(start_datetime, end_datetime),
        model,
    )

In [107]:
# start_datetime, end_datetime = datetime.datetime(2016, 1, 1), datetime.datetime(2016, 1, 3)
# model = ConvLSTMModel(64)
# model.load_weights(f'{LSTM_CHECKPOINTS_DIR}/conv_lstm_trial_5')
# # end_model = ConvLSTMModel(64)
# # end_model.load_weights(f'{LSTM_CHECKPOINTS_DIR}/conv_lstm_trial_end_full')
# delete_files(FRAMES_MARKED_DIR)
# CreateAIAVideoPrediction(start_datetime, end_datetime, model)
# WriteVideo(FRAMES_MARKED_DIR)