In [10]:
from sunpy.visualization.colormaps import cm
from torchvision.utils import make_grid
from pathlib import Path
import numpy as np
import os 
import sunpy
from sunpy.visualization.colormaps import cm
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import torchvision
import torch
import pandas as pd
import torchvision.transforms.functional as F
from sdo.sood.data.sdo_ml_v2_dataset import get_default_transforms, SDOMLv2NumpyDataset
from torch.utils.data import DataLoader

# Inspect an image.

# Channel names of the HMI magnetogram components; channel_to_map gives these
# the diverging colormap below instead of an AIA wavelength colormap.
HMI_WL = ['Bx', 'By', 'Bz']
# Diverging blue -> black -> red colormap used to display HMI channels.
HMI_CM = LinearSegmentedColormap.from_list(
    "bwrblack",
    ["#0000ff", "#000000", "#ff0000"],
)

def channel_to_map(name):
    """Return the display colormap for an SDO channel name.

    Names listed in ``HMI_WL`` share the diverging ``HMI_CM`` map. Any other
    name is converted with ``int()`` and looked up as ``sdoaia<name>`` in
    sunpy's ``cm.cmlist``; the lookup uses ``dict.get``, so an unknown
    wavelength yields ``None``.
    """
    if name in HMI_WL:
        return HMI_CM
    aia_key = 'sdoaia%d' % int(name)
    return cm.cmlist.get(aia_key)

def vis(X, cm):
    """Colorize an image and return it as 8-bit RGB pixels.

    Args:
        X: image array, passed unchanged to ``cm``.
        cm: callable returning an array whose third axis holds RGBA values
            (this parameter shadows the module-level ``cm`` import here).

    Returns:
        ``numpy.uint8`` array equal to ``cm(X)[:, :, :3] * 255`` with the
        alpha channel dropped and values truncated toward zero.
    """
    colored = cm(X)
    rgb_only = colored[:, :, :3]
    scaled = 255 * rgb_only
    return scaled.astype(np.uint8)

def show_grid(imgs, ordered_dates, df, ncols=4, channel="171", max_images=10):
    """Plot images in a grid, titled with observation time and normalized score.

    Args:
        imgs: dict mapping an observation-time string to an image accepted by
            ``imshow``. Keys must equal the string built from each entry of
            ``ordered_dates`` (see the NOTE below).
        ordered_dates: observation timestamps in plotting order; only the first
            ``max_images`` are plotted.
        df: frame indexed by observation time with a ``score_norm`` column.
        ncols: number of grid columns; reduced to ``len(imgs)`` when there are
            fewer images than columns.
        channel: label inserted into each title as ``"<channel>A"``.
        max_images: maximum number of dates to plot (``None`` for all).

    Raises:
        KeyError: if a formatted date is missing from ``imgs`` or ``df``.
    """
    if not imgs:
        return
    ncols = min(ncols, len(imgs))
    # ceil division so every image gets its own cell
    nrows = -(-len(imgs) // ncols)
    # squeeze=False keeps `axs` 2-D even for a single row/column
    _, axs = plt.subplots(figsize=(20, 9), ncols=ncols, nrows=nrows, squeeze=False)

    n_plotted = 0
    for i, t_obs in enumerate(ordered_dates[:max_images]):
        # NOTE(review): this only produces a "Z" suffix when the millisecond
        # field ends in 0; assumes the dataset's T_OBS strings use that form.
        t_obs = t_obs.isoformat(timespec='milliseconds').replace("0+00:00", "Z")
        img = imgs[t_obs]
        row = df.loc[t_obs]
        ax = axs[divmod(i, ncols)]
        ax.imshow(img)
        img_name = f"{t_obs} {channel}A"
        score = row["score_norm"]
        ax.set_title(f"{img_name}\n with score " + "%.5f" % score)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        n_plotted = i + 1

    # hide grid cells left over when the image count doesn't fill the last row
    for ax in axs.flat[n_plotted:]:
        ax.set_axis_off()
    plt.show()

def visualize_batch(loader, ordered_dates, img_df, channel="171", ncols=5):
    """Colorize and plot the first batch of ``loader`` without rescaling.

    Args:
        loader: iterable yielding ``(X, y)`` where ``X`` iterates over
            channel-first image tensors and ``y["T_OBS"]`` holds one
            observation-time key per image.
        ordered_dates: observation timestamps in plotting order (forwarded to
            ``show_grid``).
        img_df: score table forwarded to ``show_grid``.
        channel: channel name used for the colormap and plot titles.
        ncols: number of grid columns.

    Does nothing if ``loader`` yields no batches.
    """
    first_batch = next(iter(loader), None)
    if first_batch is None:
        return
    X, y = first_batch
    cmap = channel_to_map(channel)  # same colormap for every image
    V = {}
    for x, t_obs in zip(X, y["T_OBS"]):
        # (C, H, W) tensor -> (H, W, C) array; squeeze drops a singleton channel axis
        x = np.squeeze(x.permute(1, 2, 0).numpy())
        V[t_obs] = Image.fromarray(vis(x, cmap))

    # use the frame passed in, not whatever global `df` happens to exist
    show_grid(V, ordered_dates, img_df, ncols=ncols, channel=channel)
        
        
def visualize_batch_norm(loader, ordered_times, df):
    """Plot the first batch of ``loader`` rescaled from [-1, 1], in the AIA 171 colormap.

    Each image goes through ``make_grid(normalize=True, value_range=(-1, 1))``,
    is converted to an 8-bit HWC array, and its first channel is colorized
    with the ``sdoaia171`` map before being handed to ``show_grid`` (5 columns).
    Does nothing if ``loader`` yields no batches.
    """
    batch = next(iter(loader), None)
    if batch is None:
        return
    images, meta = batch
    aia171 = channel_to_map(171)
    panels = {}
    for img, key in zip(images, meta["T_OBS"]):
        rgb8 = (
            make_grid(img, normalize=True, value_range=(-1.0, 1.0))
            .mul(255)
            .add_(0.5)
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to("cpu", torch.uint8)
            .numpy()
        )
        # first channel as a 2-D index image, colorized and reduced to RGB
        panels[key] = Image.fromarray(vis(np.squeeze(rgb8[:, :, 0]), aia171))
    show_grid(panels, ordered_times, df, ncols=5)
        
def anomaly_threshold(loader, ordered_times, df, n_std=2):
    """Plot the first batch with all pixels below ``mean + n_std * std`` blanked.

    Each image is passed through ``make_grid(normalize=True)`` and converted to
    an 8-bit HWC array. Pixels below ``mean + n_std * std`` (statistics over
    the whole array) are set to 0. The array is then inverted, so suppressed
    pixels render white and the remaining bright pixels render dark. The lower
    and upper bounds are printed for each image.

    Args:
        loader: iterable yielding ``(X, y)`` with ``y["T_OBS"]`` holding one
            observation-time key per image in ``X``.
        ordered_times: observation timestamps in plotting order (forwarded to
            ``show_grid``).
        df: score table forwarded to ``show_grid``.
        n_std: width of the bounds in standard deviations (default 2).

    Does nothing if ``loader`` yields no batches.
    """
    first_batch = next(iter(loader), None)
    if first_batch is None:
        return
    X, y = first_batch
    V = {}
    for x, t_obs in zip(X, y["T_OBS"]):
        grid = make_grid(x, normalize=True)
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(
            1, 2, 0).to("cpu", torch.uint8).numpy()
        mean, std = ndarr.mean(), ndarr.std()
        lower = mean - n_std * std
        upper = mean + n_std * std
        print(lower, upper)

        ndarr[ndarr < upper] = 0
        # uint8 bitwise invert: 0 -> 255, v -> 255 - v
        ndarr = np.invert(ndarr)
        # the raw thresholded RGB array is shown (no colormap applied)
        V[t_obs] = Image.fromarray(ndarr)
    show_grid(V, ordered_times, df, ncols=5)
        
def spaced_obs_times(df, min_diff_seconds=24*60*60, min_size = 100):
    """Greedily select observation times that are spaced apart from each other.

    Walks ``df["t_obs"]`` in row order and keeps a time only if it is at least
    ``min_diff_seconds`` away from every time already kept. In this notebook the
    frame is sorted by descending score beforehand, so higher scores win ties.

    Args:
        df: frame with a ``t_obs`` column of ``pandas.Timestamp`` values.
        min_diff_seconds: minimum absolute spacing in seconds (default one day).
        min_size: despite its name, the *maximum* number of times returned;
            scanning stops once this many have been collected.

    Returns:
        list of the selected ``t_obs`` values, in scan order.
    """
    obs_times = []
    for t_obs in df["t_obs"]:
        has_close_neighbour = any(
            abs((t_obs - kept).total_seconds()) < min_diff_seconds
            for kept in obs_times
        )
        if not has_close_neighbour:
            obs_times.append(t_obs)
        if len(obs_times) >= min_size:
            break
    return obs_times

def get_data_loader(obs_times,
                    storage_root="/home/marius/data/sdomlv2_full/sdomlv2.zarr",
                    channel="171A",
                    target_size=256,
                    max_obs=10,
                    batch_size=32,
                    cache_max_size=2 * 1024 ** 3):
    """Build an unshuffled, single-process DataLoader over selected SDO-ML v2 frames.

    Args:
        obs_times: observation times to load; only the first ``max_obs`` are
            passed to the dataset.
        storage_root: path of the SDO-ML v2 zarr store (local filesystem).
        channel: AIA channel to load, e.g. ``"171A"``.
        target_size: image size passed to ``get_default_transforms``.
        max_obs: how many of ``obs_times`` to keep (``None`` keeps all).
        batch_size: loader batch size.
        cache_max_size: dataset cache limit in bytes (default 2 GiB; the
            previous literal ``1*1024*1024*2014`` appears to be a typo of 2048).

    Returns:
        ``torch.utils.data.DataLoader`` over an ``SDOMLv2NumpyDataset``.
    """
    storage_driver = "fs"
    test_year = ["2011", "2014", "2016", "2019"]
    mask_limb = False
    mask_limb_radius_scale_factor = 1.0
    transforms = get_default_transforms(
        target_size=target_size, channel=channel, mask_limb=mask_limb,
        radius_scale_factor=mask_limb_radius_scale_factor)
    selected_obs_times = obs_times if max_obs is None else obs_times[:max_obs]
    dataset = SDOMLv2NumpyDataset(
        storage_root=storage_root,
        storage_driver=storage_driver,
        cache_max_size=cache_max_size,
        year=test_year,
        start=None,
        end=None,
        freq=None,
        obs_times=selected_obs_times,
        irradiance=None,
        irradiance_channel=None,
        goes_cache_dir=None,
        channel=channel,
        transforms=transforms,
        reduce_memory=True
    )

    print(f"found dataset with size {len(dataset)}")
    # prefetch_factor is only valid with worker processes; since PyTorch 2.0,
    # passing it together with num_workers=0 raises ValueError, so it is omitted.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                        pin_memory=False, drop_last=False)
    return loader

In [5]:
import logging

# Attach a default stderr handler and let INFO-level records through.
logging.basicConfig()
logging.root.setLevel(logging.INFO)

# Predictions file of the 20220803-111810_cevae run (default-256 outputs).
sample_pred_path = "/home/marius/sdo-cli/output/predictions/20220803-111810_cevae/predictions.txt"

In [6]:
folder_time_format = "%Y%m%d-%H%M%S"

# Load per-image scores, rank them (highest first) and add a min-max
# normalized copy of the score in [0, 1] as `score_norm`.
df = (
    pd.read_csv(sample_pred_path)
    .sort_values(by=['score'], ascending=False)
    .assign(
        score_norm=lambda d: (d["score"] - d["score"].min()) / (d["score"].max() - d["score"].min()),
        t_obs=lambda d: pd.to_datetime(d["t_obs"]),
    )
    # somehow some images in 2019 are duplicates? Rows are already sorted by
    # descending score, so keep='first' retains the highest-scoring copy.
    .drop_duplicates(subset=['t_obs'], keep='first')
    .set_index('t_obs', drop=False, verify_integrity=True)
)

# Observation times of the 1000 highest-scoring images.
top_obs_times = df["t_obs"].iloc[:1000].tolist()

# Greedy selection of top-scoring times that are at least 7 days apart.
spaced_top_obs_times_7d = spaced_obs_times(df, min_diff_seconds=24*60*60*7)

In [7]:
# Preview the first ten selected observation times (7-day spacing, best scores first).
spaced_top_obs_times_7d[:10]

[Timestamp('2014-08-01 14:48:13.730000+0000', tz='UTC'),
 Timestamp('2014-01-07 10:12:13.730000+0000', tz='UTC'),
 Timestamp('2014-11-25 12:48:12.340000+0000', tz='UTC'),
 Timestamp('2014-10-30 04:42:13.470000+0000', tz='UTC'),
 Timestamp('2014-10-03 03:54:12.340000+0000', tz='UTC'),
 Timestamp('2014-12-21 13:00:13.740000+0000', tz='UTC'),
 Timestamp('2014-08-10 01:12:12.340000+0000', tz='UTC'),
 Timestamp('2014-07-11 03:00:13.360000+0000', tz='UTC'),
 Timestamp('2014-02-25 00:54:12.990000+0000', tz='UTC'),
 Timestamp('2014-12-08 19:12:12.340000+0000', tz='UTC')]

In [9]:
# Build a loader over the spaced top-scoring observation times.
# NOTE(review): the saved run of this cell ended in KeyboardInterrupt, so dataset
# construction appears to be slow — expect a long wait here.
spaced_7d_top_loader = get_data_loader(spaced_top_obs_times_7d)


KeyboardInterrupt



In [None]:
# Plot the first batch of the spaced top-scoring images, titled with their normalized scores.
visualize_batch_norm(spaced_7d_top_loader, spaced_top_obs_times_7d, df)

In [None]:
# Scratch cell: colorize the first batch of the loader into the global dict `V`
# (observation-time key -> PIL image) without plotting it.
# TODO(review): this duplicates the body of visualize_batch_norm; consider reusing it.
for batch_idx, samples in enumerate(spaced_7d_top_loader):
    X, y = samples
    V = {}
    for x, t_obs in zip(X, y["T_OBS"]):
        # rescale value_range (-1, 1) to [0, 1], then convert to an HWC uint8 array
        grid = make_grid(x, normalize=True, value_range=(-1.0, 1.0))
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(
                    1, 2, 0).to("cpu", torch.uint8).numpy()
        m = cm.cmlist.get('sdoaia%d' % int(171))
        # colorize the first channel with the sdoaia171 map and drop alpha
        v = np.squeeze(ndarr[:, :, 0])
        v = m(v)
        v = (v[:, :, :3]*255).astype(np.uint8)
        V[t_obs] = Image.fromarray(v)
    # only the first batch is needed
    break

In [None]:
# Overlay HEK event annotations and anomaly-map bounding boxes on every source
# image of one AIA wavelength, saving one figure per image.
# NOTE(review): this cell relies on names not defined in the visible notebook:
# get_meta_info, HEKEventManager, db_connection_string, date_format,
# hek_event_types, logger, convert_events_to_pixelunits, sood_map_path,
# extract_bounding_boxes_from_anomaly_map,
# save_fig_with_hek_bounding_boxes_and_anomalies and out_dir. Define them in an
# earlier cell before running this one.
import datetime as dt

# TODO: set to the directory containing the "*__<aia_wave>_src.png" images and meta.csv.
src_img_path = None
aia_wave = "171A"

if src_img_path is None:
    raise ValueError("set src_img_path to the directory with the *_src.png images before running this cell")

src_dir = Path(src_img_path)
# channel id used to filter HEK events, e.g. "171A" -> "171"
aia_channel = aia_wave.rstrip("A")

src_images = list(src_dir.rglob(f'*__{aia_wave}_src.png'))
# one event manager serves all images; it does not depend on the loop variable
hek_manager = HEKEventManager(db_connection_string)
for src_img in src_images:
    header = get_meta_info(src_img.name, src_dir / "meta.csv")
    timestamp_str = src_img.name.split("__")[0]
    timestamp = dt.datetime.strptime(timestamp_str, date_format)
    events_df = hek_manager.find_events_at(
        timestamp, observatory="SDO", instrument="AIA", event_types=hek_event_types)
    if len(events_df) < 1:
        logger.warning("no events found")
        continue
    # keep events observed in this wavelength (literal match; a missing channel
    # id counts as no match), possibly also filter by feature extraction method
    events_df = events_df[events_df['obs_channelid'].str.contains(
        aia_channel, regex=False, na=False)]
    logger.info(f"after filter {len(events_df)} events")
    hek_bboxes, hek_polygons = convert_events_to_pixelunits(events_df, header)
    map_path = sood_map_path / Path(src_img.name)
    anomaly_boxes = extract_bounding_boxes_from_anomaly_map(
        map_path, mode="otsu", scale=8, gaussian_filter=False)
    save_fig_with_hek_bounding_boxes_and_anomalies(
        src_img, hek_bboxes, hek_polygons, anomaly_boxes, out_dir)