In [382]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sunpy.time import parse_time
from sklearn import preprocessing

from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.layers import *

import os
import shutil
import datetime
import sys
from pathlib import Path
from PIL import Image, ImageDraw, ImageEnhance
from matplotlib import cm
import cv2
import imageio

sys.path.append(os.path.join(Path.cwd(), 'utils'))
from utils.resnet_model import *
from utils.lstm_model import *
from utils.region_detector import *
from utils.im_utils import *
from utils.data_augmentation import *
from utils.convlstm_model import *

In [383]:
DATA_DIR = '../data'
VIDEOS_DIR = './videos'
FRAMES_DIR = './frames'
FRAMES_MARKED_DIR = './frames_marked'
LSTM_CHECKPOINTS_DIR = './checkpoints/lstm_checkpoints'
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [397]:
def delete_files(folder):
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))

In [384]:
# returns a formatted file name of the closest AIA data file to the given datetime

def GetClosestDataFileByDate(dt):
    AIA_data_date = f'{dt.year}{dt.month:02d}{dt.day:02d}'
    tmp_dt = dt
    minute = 0
    minute=GetClosestMultiple(tmp_dt.minute, 6)
    AIA_data_time = f'{tmp_dt.hour:02d}{minute:02d}'
    AIA_data_filename = f'AIA{AIA_data_date}_{AIA_data_time}_0094.npz'
    
    return AIA_data_filename

In [385]:
def GetAIAPathAtTime(dt):
    dt_data_dir = os.path.join(DATA_DIR, f'{dt.year}/{dt.month:02d}/{dt.day:02d}')
    closest_data_file = GetClosestDataFileByDate(dt)
    file_path = os.path.join(dt_data_dir, closest_data_file)
    if not os.path.exists(file_path):
        raise FileNotFoundError
    
    return file_path

In [386]:
def GetFilepathsBetweenDates(start_datetime, end_datetime):
    filepaths = []
    loop_date = start_datetime
    
    while(loop_date < end_datetime):
        closest_filepath = None
        try:
            closest_filepath = GetAIAPathAtTime(loop_date)
        except(FileNotFoundError):
            loop_date = loop_date + datetime.timedelta(minutes=6)
            continue
        filepaths.append(closest_filepath)
        loop_date = loop_date + datetime.timedelta(minutes=6)
    
    return filepaths

In [387]:
def key_func(x):
    return int(x.split('.')[-2].split('/')[-1])

In [388]:
def WriteImages(filepaths):
    frames = [Image.fromarray(np.uint8(np.load(x)['x'])) for x in filepaths]
    for i, frame in enumerate(frames):
        frame.save(f'{FRAMES_DIR}/{i}.png')

In [400]:
def WritePredictedImages(filepaths, model, threshold=0.9, rect_size=64):
    images_raw = [np.load(x)['x'] for x in filepaths]
    images = [cv2.resize(x, (64, 64), interpolation = cv2.INTER_AREA) for x in images_raw]
    for i in range(6, len(images)):
        past_images = images[i-6:i]+images[i-6:i]
        past_images = np.array(past_images)
        past_images = np.expand_dims(past_images, 0)
        past_images = np.expand_dims(past_images, 4)
        # now the shape is 1, 12, 64, 64, 1 of full sun images
        prediction = model.predict(past_images)[0][0]
        print(prediction)
        past_images_pil = [Image.fromarray(np.uint8(images_raw[i-x])).convert("RGBA") for x in range(6)]
        past_images_raw = [images_raw[i-x] for x in range(6)]
        if prediction > threshold:
            print(i)
            last_image_raw = past_images_raw[-1]
            last_image_pil = past_images_pil[-1]
            previous_images_pil = past_images_pil[:-1]
            center_coord = GetImageTopNRegionsCoords(last_image_raw, 1)[0]
            tl, br = (center_coord[1]-rect_size//2, center_coord[0]-rect_size//2), (center_coord[1]+rect_size//2, center_coord[0]+rect_size//2)
            draw = ImageDraw.Draw(last_image_pil)
            draw.rectangle((tl, br), outline="red")
            past_images_marked = previous_images_pil
            past_images_marked.append(last_image_pil)
            last_image_pil.save(f'{FRAMES_MARKED_DIR}/{i}.png', 'PNG')
            # for idx, img in enumerate(past_images_marked):
            #     img.save(f'{FRAMES_MARKED_DIR}/{i+idx}.png', "PNG")
        else:
            last_image_pil = past_images_pil[-1]
            last_image_pil.save(f'{FRAMES_MARKED_DIR}/{i}.png', "PNG")

In [401]:
def WriteVideo(frames_dir):
    paths = []
    for subdir, dirs, files in os.walk(frames_dir):
        for f in files:
            if f.rsplit('.', 1)[-1] == 'png':
                paths.append(os.path.join(subdir, f))
    paths = sorted(paths, key=key_func)
    writer = imageio.get_writer('trial.mp4', fps=20)
    for file in paths:
        im = imageio.imread(file)
        if im.shape != (512, 512, 4):
            print(im.shape)
            continue
        writer.append_data(im)
    writer.close()

In [402]:
def CreateAIAVideo(start_datetime, end_datetime):
    # filepaths = GetFilepathsBetweenDates(start_datetime, end_datetime)
    # WriteImages(filepaths)
    WriteVideo(FRAMES_DIR)

In [403]:
def DrawRectangle(img, top_left, bottom_right, save_dir):
    result = img.copy()
    result = cv2.rectangle(result, coord1, coord2, color=(0, 0, 255), thickness=3)
    cv2.imwrite(save_dir, result)

In [404]:
def CreateAIAVideoPrediction(start_datetime, end_datetime, model):
    filepaths = GetFilepathsBetweenDates(start_datetime, end_datetime)
    return WritePredictedImages(filepaths, model)

In [407]:
start_datetime, end_datetime = datetime.datetime(2016, 7, 10), datetime.datetime(2016, 7, 15)
model = ConvLSTMModel(64)
model.load_weights(f'{LSTM_CHECKPOINTS_DIR}/conv_lstm_trial_2')
end_model = ConvLSTMModel(64)
end_model.load_weights(f'{LSTM_CHECKPOINTS_DIR}/conv_lstm_trial_end_full')
# delete_files(FRAMES_MARKED_DIR)
CreateAIAVideoPrediction(start_datetime, end_datetime, end_model)
# WriteVideo(FRAMES_MARKED_DIR)

(None, 12, 64, 64, 1)
8.50987e-06
1.8577066e-05
3.2336673e-05
5.3288586e-05
0.00014753822
0.87758714
0.99499655
12
0.9998965
13
0.9999058
14
0.98686105
15
0.20175973
3.2021177e-05
7.415123e-05
5.771739e-05
4.8283814e-06
3.513906e-06
6.0418274e-06
7.1213954e-06
9.776375e-06
2.161619e-05
3.4450542e-05
9.772142e-06
1.2681395e-05
7.564563e-06
9.497829e-06
2.2620305e-05
9.348612e-05
0.003175824
0.002041967
0.006338264
0.004507764
0.0008284978
0.0003758087
7.116699e-05
0.00030851742
0.0005211669
0.0006006884
0.0008632874
0.00068029197
0.0005634989
0.00018583304
0.00012607097
0.000117069394
1.2842003e-05
1.3340848e-05
3.528565e-05
3.6677718e-06
2.0101925e-06
2.9665289e-06
7.05233e-06
6.957086e-06
1.36547915e-05
2.4298455e-05
1.03298635e-05
8.057801e-06
1.1733071e-05
1.9148993e-05
2.54623e-05
0.00013904642
0.0021843382
0.0006085763
9.281203e-05
3.7512287e-05
0.00016128384
6.263782e-05
3.9799303e-05
4.983123e-06
3.074619e-06
1.1938993e-06
4.4017872e-07
2.1627476e-07
4.3426041e-07
8.5477683e-07


  im = imageio.imread(file)
