In [257]:
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from glob import glob
import numpy as np
import xarray as xr
import pandas as pd
import yaml
import imageio.v3 as iio
from PIL import Image

from convml_tt.data.dataset import TRIPLET_TILE_FILENAME_FORMAT, TRIPLET_TILE_IDENTIFIER_FORMAT

# Sample tiles from granules

- Specify the time span and region.
- Load the scene IDs.
- Functions needed from convml:
    - `split_scene_ids` – divides scenes into sets for training, study, etc. Is this needed? Perhaps just divide by year.
    - `tile_scene_splits` – works out which scenes to take the anchor and distant tiles from.
    - `generate_tile_locations` – writes a YAML file of triplet locations for each scene and for the whole dataset.
- New functions needed:
    - Split scenes for tiling based on the number of tiles available.
    - Sample an anchor and a neighbour from the tile dataframe.
    - (TBD)


In [2]:
# Root data directory on the Sense group workspace
save_dir = "/gws/nopw/j04/sensecdt/users/flojo/data/"

# Working folder for this run, inside save_dir
folder = "2020_test"
filepath = Path(save_dir) / folder

In [114]:
month = 7
satellite = 'aqua'

# per-month, per-satellite sub-directories of the working folder
meta_filepath = filepath.joinpath(str(month), satellite, "meta")
granule_folder = filepath.joinpath(str(month), satellite, "granule_data")

In [4]:
# load the per-granule metadata for this month/satellite
with open(meta_filepath / "scene_ids.yaml", "r") as f:
    scene_ids = yaml.safe_load(f)

threshold = 0.25

# keep daytime granules with more than `threshold` valid pixels
filtered_scene_ids = dict(
    (key, val)
    for key, val in scene_ids.items()
    if val["day_night_flag"] == "DAY" and val['valid_pixel_fraction'] > threshold
)
filtered_scene_ids

{'MYD021KM.A2020001.1350': {'day_night_flag': 'DAY',
  'filepaths': {'data': '/neodc/modis/data/MYD021KM/collection61/2020/01/01/MYD021KM.A2020001.1350.061.2020002190357.hdf',
   'geolocation': '/neodc/modis/data/MYD03/collection61/2020/01/01/MYD03.A2020001.1350.061.2020002170245.hdf',
   'seaice': '/badc/ecmwf-era5/data/oper/an_sfc/2020/01/01/ecmwf-era5_oper_an_sfc_202001011300.ci.nc'},
  'regions': {'barents_fraction': 0.0,
   'gin_fraction': 0.187,
   'labrador_fraction': 0.747},
  'valid_pixel_fraction': 0.934,
  'valid_tiles_64px': {'barents_sea': 0,
   'gin_seas': 170,
   'labrador_sea': 681,
   'total': 851}},
 'MYD021KM.A2020001.1525': {'day_night_flag': 'DAY',
  'filepaths': {'data': '/neodc/modis/data/MYD021KM/collection61/2020/01/01/MYD021KM.A2020001.1525.061.2020002190551.hdf',
   'geolocation': '/neodc/modis/data/MYD03/collection61/2020/01/01/MYD03.A2020001.1525.061.2020002170331.hdf',
   'seaice': '/badc/ecmwf-era5/data/oper/an_sfc/2020/01/01/ecmwf-era5_oper_an_sfc_202001

In [20]:
def filter_scene_ids(scene_ids, threshold=0.25, day_night_flag="DAY"):
    """Select granules from a scene_ids.yaml mapping.

    Keeps the entries whose ``day_night_flag`` equals ``day_night_flag`` and
    whose ``valid_pixel_fraction`` is strictly greater than ``threshold``.
    Returns a new dict; the input mapping is not modified.
    """
    def _keep(meta):
        return (meta["day_night_flag"] == day_night_flag
                and meta['valid_pixel_fraction'] > threshold)

    return {scene: meta for scene, meta in scene_ids.items() if _keep(meta)}

### Check how much usable data there is for a given year


In [18]:
def check_valid_granules(filepath, tile_size=64, region='total', threshold=0.25,
                         day_night_flag="DAY"):
    ''' Check how many granules contain valid tiles
    for a given region, tile size and threshold.

    For every month (1-12) and satellite ('aqua', 'terra') this reads
    ``<filepath>/<month>/<satellite>/meta/scene_ids.yaml``. It then counts the
    granules that pass `filter_scene_ids` (matching ``day_night_flag``, with
    ``valid_pixel_fraction > threshold``) and also report more than zero
    ``valid_tiles_<tile_size>px`` for ``region``.

    region = 'labrador_sea', 'gin_seas', 'barents_sea', 'total'
    day_night_flag : granule flag to keep (default "DAY", as before)

    Returns ``{month: {satellite: n_granules}}``.
    '''
    granule_totals = {}
    for month in range(1, 13):
        granule_totals[month] = {}
        for satellite in ['aqua', 'terra']:
            meta_filepath = filepath / str(month) / satellite / "meta"
            with open(meta_filepath / "scene_ids.yaml", "r") as f:
                # safe_load returns None for an empty file; treat that as "no granules"
                scene_ids = yaml.safe_load(f) or {}
            # reuse the shared filter so the day/night and threshold rules stay in one place
            usable = filter_scene_ids(scene_ids, threshold=threshold,
                                      day_night_flag=day_night_flag)
            granule_totals[month][satellite] = sum(
                1 for val in usable.values()
                if val[f'valid_tiles_{tile_size}px'][region] > 0
            )
    return granule_totals

In [15]:
def check_available_tiles(filepath, tile_size=64, region='total', threshold=0.25,
                          day_night_flag="DAY"):
    ''' Given a particular tile_size, check number of available tiles
    each month from each satellite.

    For every month (1-12) and satellite ('aqua', 'terra') this reads
    ``<filepath>/<month>/<satellite>/meta/scene_ids.yaml``. It keeps the
    granules that pass `filter_scene_ids` and sums their
    ``valid_tiles_<tile_size>px[region]`` counts.

    region = 'labrador_sea', 'gin_seas', 'barents_sea', 'total'
    day_night_flag : granule flag to keep (default "DAY", as before)

    Returns ``{month: {satellite: n_tiles}}``.
    '''
    available_tiles = {}
    for month in range(1, 13):
        available_tiles[month] = {}
        for satellite in ['aqua', 'terra']:
            meta_filepath = filepath / str(month) / satellite / "meta"
            with open(meta_filepath / "scene_ids.yaml", "r") as f:
                # safe_load returns None for an empty file; treat that as "no granules"
                scene_ids = yaml.safe_load(f) or {}
            usable = filter_scene_ids(scene_ids, threshold=threshold,
                                      day_night_flag=day_night_flag)
            available_tiles[month][satellite] = sum(
                val[f"valid_tiles_{tile_size}px"][region] for val in usable.values()
            )
    return available_tiles
    


In [19]:
# granules per month/satellite with at least one valid 64 px Labrador Sea tile
check_valid_granules(filepath, tile_size=64, region='labrador_sea')

{1: {'aqua': 67, 'terra': 72},
 2: {'aqua': 85, 'terra': 88},
 3: {'aqua': 97, 'terra': 103},
 4: {'aqua': 92, 'terra': 102},
 5: {'aqua': 102, 'terra': 108},
 6: {'aqua': 104, 'terra': 108},
 7: {'aqua': 111, 'terra': 112},
 8: {'aqua': 69, 'terra': 117},
 9: {'aqua': 100, 'terra': 112},
 10: {'aqua': 108, 'terra': 82},
 11: {'aqua': 77, 'terra': 40},
 12: {'aqua': 59, 'terra': 28}}

In [63]:
# total valid 64 px Labrador Sea tiles per month/satellite
check_available_tiles(filepath, tile_size=64, region='labrador_sea')

{1: {'aqua': 36671, 'terra': 38580},
 2: {'aqua': 42525, 'terra': 41610},
 3: {'aqua': 46234, 'terra': 45271},
 4: {'aqua': 45253, 'terra': 45736},
 5: {'aqua': 48410, 'terra': 49402},
 6: {'aqua': 47854, 'terra': 48667},
 7: {'aqua': 50668, 'terra': 50282},
 8: {'aqua': 32189, 'terra': 51642},
 9: {'aqua': 47962, 'terra': 49659},
 10: {'aqua': 51130, 'terra': 41229},
 11: {'aqua': 41233, 'terra': 20594},
 12: {'aqua': 32927, 'terra': 13037}}

### Generate random tiles

In [43]:
# single seeded generator shared by all sampling cells below
rng = np.random.default_rng(42)

In [None]:
# tile edge length in pixels (also selects the all_tiles_<tile_size>px.pkl tile list)
tile_size = 64

In [65]:
def open_tile_list(filepath, month, satellite, tile_size=64):
    ''' Load the pickled tile list for one month and satellite.

    Reads ``<filepath>/<month>/<satellite>/meta/all_tiles_<tile_size>px.pkl``
    and returns it as a pandas DataFrame.
    NOTE: uses pickle, so only open files produced by this pipeline.
    '''
    pickle_path = filepath / str(month) / satellite / "meta" / f"all_tiles_{tile_size}px.pkl"
    return pd.read_pickle(pickle_path)

In [115]:
# tile list (one row per candidate tile) for the selected month/satellite
all_tiles_df = open_tile_list(filepath, month, satellite, tile_size=tile_size)
all_tiles_df

Unnamed: 0,scene_id,x_c,y_c,valid_tile,study_region,latitude,longitude
0,MYD021KM.A2020183.0015,31,31,False,0,82.645821,75.190819
1,MYD021KM.A2020183.0015,31,63,False,0,82.480171,73.370934
2,MYD021KM.A2020183.0015,31,95,False,0,82.305954,71.624756
3,MYD021KM.A2020183.0015,31,127,False,0,82.126732,69.960426
4,MYD021KM.A2020183.0015,31,159,False,0,81.939819,68.366127
...,...,...,...,...,...,...,...
277135,MYD021KM.A2020213.1645,479,1855,False,0,67.397057,-74.676147
277136,MYD021KM.A2020213.1645,479,1887,False,0,67.652031,-75.043175
277137,MYD021KM.A2020213.1645,479,1919,False,0,67.905800,-75.413780
277138,MYD021KM.A2020213.1645,479,1951,False,0,68.152893,-75.791489


In [185]:
def select_random_tile(all_tiles_df, region=None, rng=None):
    ''' Draw one random valid tile (as a single-row DataFrame).

    region restricts the draw to one study region:
    1 = Labrador Sea
    2 = Greenland-Iceland-Norwegian Seas
    3 = Barents Sea
    None draws from any region.
    rng is passed to ``DataFrame.sample`` as ``random_state``.
    '''
    candidates = all_tiles_df['valid_tile']
    if region is not None:
        candidates = candidates & (all_tiles_df['study_region'] == region)
    return all_tiles_df[candidates].sample(1, random_state=rng)

### Make singlets

In [None]:
# total number of single tiles ("singlets") to sample
# NOTE(review): not used by any visible cell yet
N_tiles = 24000


In [None]:
# TODO: the output should be a dict with:
# - tile_id
# - scene_id
# - x_c and y_c

# The per-month/per-satellite tile lists need to be merged together first.

### Make triplets

In [204]:
# Anchor: a random valid tile from the Labrador Sea (study_region == 1)
anchor = select_random_tile(all_tiles_df, region=1, rng=rng)
anchor_x = anchor['x_c'].item()
anchor_y = anchor['y_c'].item()
anchor_x, anchor_y

(415, 895)

In [214]:
# show the full tile-list row of the sampled anchor
anchor

Unnamed: 0,scene_id,x_c,y_c,valid_tile,study_region,latitude,longitude
37041,MYD021KM.A2020187.1425,415,895,True,1,45.723568,-23.50029


In [221]:
def find_possible_neighbours(anchor, all_tiles_df, tile_size=64):
    ''' Find possible valid neighbour tiles for a given anchor tile.

    Candidates are the four positions offset from the anchor centre by
    ``tile_size // 2`` pixels, in the order +y, -y, +x, -x. A candidate is
    kept only if ``all_tiles_df`` has a row with the anchor's ``scene_id`` at
    that ``(x_c, y_c)`` and its ``valid_tile`` is True.

    anchor : single-row DataFrame (e.g. from ``select_random_tile``)
    all_tiles_df : tile list with ``scene_id``, ``x_c``, ``y_c``, ``valid_tile``
    tile_size : tile edge length in pixels

    Returns a (possibly empty) list of ``(x_c, y_c)`` tuples.
    Raises ValueError if a candidate position matches several rows in the
    anchor's scene.
    '''
    anchor_scene_id = anchor['scene_id'].item()
    anchor_x, anchor_y = anchor['x_c'].item(), anchor['y_c'].item()
    offset = tile_size // 2
    shifts = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    candidates = [(anchor_x + dx * offset, anchor_y + dy * offset) for dx, dy in shifts]

    # restrict to the anchor's scene once instead of querying the whole list per candidate
    scene_tiles = all_tiles_df[all_tiles_df['scene_id'] == anchor_scene_id]

    # Build a new list: removing from the list being iterated would skip the
    # element following each removal and let invalid tiles through.
    neighbour_coords = []
    for x_c, y_c in candidates:
        match = scene_tiles.loc[(scene_tiles['x_c'] == x_c) & (scene_tiles['y_c'] == y_c),
                                'valid_tile']
        # missing tiles (empty match) and invalid tiles are both dropped
        if not match.empty and bool(match.item()):
            neighbour_coords.append((x_c, y_c))
    return neighbour_coords

In [None]:
def generate_triplet_locations(month, satellite, N_triplets, region=None, rng=None):
    """ generate triplet location

    NOTE(review): unfinished stub. It only loads the tile list for
    ``month``/``satellite``, using the notebook-level ``filepath`` and
    ``tile_size`` rather than parameters, and returns None.
    ``N_triplets``, ``region`` and ``rng`` are not used yet.
    """
    tiles_df = open_tile_list(filepath, month, satellite, tile_size)
    # TODO: sample anchor / neighbour / distant tiles from tiles_df
    return None

In [222]:
# valid neighbours of the anchor, using the notebook-level tile size rather
# than silently relying on the function's default
find_possible_neighbours(anchor, all_tiles_df, tile_size=tile_size)

[(415, 927), (415, 863), (447, 895), (383, 895)]

In [46]:
# candidate neighbours: half a tile away from the anchor in +y, -y, +x, -x
offset = tile_size // 2
shifts = [(0, 1), (0, -1), (1, 0), (-1, 0)]
neighbours = [(anchor_x + dx * offset, anchor_y + dy * offset) for dx, dy in shifts]

neighbours

[(255, 735), (255, 671), (287, 703), (223, 703)]

In [None]:
# check neighbour validity: keep only candidates that exist in the *anchor's
# scene* and are flagged valid. The tile list holds several scenes, so matching
# on (x_c, y_c) alone can return several rows and make `.item()` raise.
# Build a new list rather than calling `.remove()` while iterating, which
# skips the element after each removal.
anchor_tiles = all_tiles_df[all_tiles_df['scene_id'] == anchor['scene_id'].item()]
valid_neighbours = []
for neighbour in neighbours:
    match = anchor_tiles.loc[(anchor_tiles['x_c'] == neighbour[0])
                             & (anchor_tiles['y_c'] == neighbour[1]), 'valid_tile']
    if not match.empty and bool(match.item()):
        valid_neighbours.append(neighbour)
neighbours = valid_neighbours

# TODO: if len(neighbours) == 0, choose a new anchor;
# otherwise choose one of the neighbours

In [None]:
# pick one valid neighbour at random and show its row from the anchor's scene
if not neighbours:
    raise ValueError("anchor has no valid neighbours - choose a new anchor")
# rng.choice on a list of (x, y) tuples returns a length-2 array; convert to plain ints
neighbour = tuple(int(c) for c in rng.choice(neighbours))
all_tiles_df[(all_tiles_df['scene_id'] == anchor['scene_id'].item())
             & (all_tiles_df['x_c'] == neighbour[0])
             & (all_tiles_df['y_c'] == neighbour[1])]

In [56]:
# the distant tile must come from a different month: all months 1-12 except `month`
remaining_months = sorted(set(range(1, 13)) - {month})


In [57]:
# months the distant tile may be drawn from
remaining_months

[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

In [59]:

# draw the distant tile's month, then its satellite (draw order fixes the rng stream)
distant_month = rng.choice(remaining_months)
distant_satellite = rng.choice(('aqua', 'terra'))
(distant_month, distant_satellite)

(np.int64(6), np.str_('aqua'))

In [223]:
N_triplets = 24000 # 2000 per month, 1000 per satellite 
# tile edge length in pixels (same value as set earlier in the notebook)
tile_size = 64

In [224]:
# placeholder: not populated by any visible cell yet
all_scene_ids = {}
# TODO / open question: how to pick an even number of tiles across seasons
# while also picking the distant tile from another month?

### Make tile images

In [225]:
# the anchor row whose tile image is extracted below
anchor

Unnamed: 0,scene_id,x_c,y_c,valid_tile,study_region,latitude,longitude
37041,MYD021KM.A2020187.1425,415,895,True,1,45.723568,-23.50029


In [227]:
# unpack the single anchor row into plain Python scalars
anchor_scene_id = anchor['scene_id'].item()
anchor_x = anchor['x_c'].item()
anchor_y = anchor['y_c'].item()

In [228]:
# open the anchor's granule (one NetCDF file per scene id)
granule_data = xr.open_dataset(granule_folder.joinpath(f"{anchor_scene_id}.nc"))
granule_data

In [243]:
# Crop a tile_size x tile_size window around the anchor using positional indices.
# NOTE(review): tile centres in the tile list appear to start at 31 for 64 px
# tiles, so this window may be one pixel off from the one used to build the
# tile list - confirm the x_c/y_c centre convention.
x_start = anchor_x - tile_size // 2
y_start = anchor_y - tile_size // 2
if x_start < 0 or y_start < 0:
    # isel treats a negative start as counting from the end of the axis
    raise ValueError(f"tile at ({anchor_x}, {anchor_y}) extends past the granule edge")
tile_rgb = granule_data["true_color"].isel(x=slice(x_start, x_start + tile_size),
                                           y=slice(y_start, y_start + tile_size))
# slicing silently truncates at the far edge, so check the size explicitly
if tile_rgb.sizes["x"] != tile_size or tile_rgb.sizes["y"] != tile_size:
    raise ValueError(f"tile at ({anchor_x}, {anchor_y}) is truncated: {dict(tile_rgb.sizes)}")
tile_rgb

In [244]:
# assumes true_color is a uint8 RGB array PIL can interpret directly - TODO confirm dtype/dim order
img = Image.fromarray(np.asarray(tile_rgb))

In [258]:
# upscale to 256x256 with nearest-neighbour so each source pixel stays a solid block
new_img = img.resize(size=(256, 256), resample=Image.Resampling.NEAREST)

In [253]:
# quick check: write the upscaled tile to the current working directory
# NOTE(review): hard-coded scratch filename
new_img.save("test2.png")

In [259]:
# confirm the resized dimensions (width, height)
new_img.size

(256, 256)

In [None]:
# tile meta: what needs to be included?
# Built from the anchor tile selected above (the previous version referenced
# undefined names `scene_ds`, `x_c`, `y_c` and `scene_id`).
# Assumes the granule dataset has `latitude`/`longitude` variables on the same
# (x, y) grid as `true_color` - TODO confirm. Positional `isel` matches the
# crop used for the tile image.
tile_meta = dict(
    loc=dict(
        x_c=anchor_x,
        y_c=anchor_y,
        central_latitude=float(granule_data.latitude.isel(x=int(anchor_x), y=int(anchor_y)).values),
        central_longitude=float(granule_data.longitude.isel(x=int(anchor_x), y=int(anchor_y)).values),
    ),
    scene_id=anchor_scene_id,
)
tile_meta