In [108]:
import os 
import glob

import rasterio
import pandas as pd 
import numpy as np
import torch
from tfrecord.torch.dataset import TFRecordDataset

In [121]:
# Input and output locations, relative to the notebook's directory. The
# trailing "" makes os.path.join end the two directory paths with a separator,
# so they can be prefixed directly onto glob patterns and file names.
CSV = os.path.join("..", "data", "wealth_index.csv")
RECORDS_DIR = os.path.join("..", "data", "landsat_7", "")
TIF_DIR = os.path.join("..", "data", "landsat_tif", "")

# Names of the image bands stored in each record, in the order they are
# written to the output GeoTIFF.
BANDS = ['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2', 'TEMP1']

# Feature name -> type tag, handed to TFRecordDataset when a record file is
# opened. Metadata features come first; every band is read as "float".
DESCRIPTOR = {
    'filename': "byte",
    'wealthpooled': "float",
    'system:index': "byte",
    'bounding_box': "byte",
    **{band: "float" for band in BANDS},
}

In [3]:
# Load the table at CSV into a DataFrame. Later cells read its `year` and
# `country` columns and write into its `geometry` and `filename` columns.
# NOTE(review): the name `csv` would shadow the stdlib `csv` module if that
# were ever imported in this notebook.
csv = pd.read_csv(CSV)

In [4]:
# Build records[year][country] -> list of TFRecord paths under RECORDS_DIR
# whose file name contains "<country>_<year>". Only (year, country) pairs that
# actually occur together in the table get an entry; a pair with no matching
# file maps to an empty list.
records = dict()
for year in csv.year.unique():
    records[year] = dict()
    countries_in_year = csv.loc[csv.year == year, "country"].unique()
    for country in countries_in_year:
        pattern = f"{RECORDS_DIR}*{country}_{year}*.tfrecord"
        # glob returns matches in arbitrary, filesystem-dependent order; sort
        # them so that positional access into the list is reproducible.
        records[year][country] = sorted(glob.glob(pattern))

In [5]:
# Print every (year, country) pair for which no TFRecord file was found.
for year, by_country in records.items():
    for country, matches in by_country.items():
        if not matches:
            print(year, country)

2014 kenya
2018 nigeria
2016 ethiopia
2019 ethiopia


In [44]:
# Temporary: work on a single hard-coded (year, country) pair. The plan is to
# loop over every (year, country) pair in `records` instead.
year = 2011
files = records[year]['angola']

In [122]:
# Open the first TFRecord file of the selected (year, country) pair.
if not files:
    raise FileNotFoundError("no TFRecord file found for the selected year/country")
tfrecord_path = files[0]
# `index_path` must be defined before it is passed below; previously its
# assignment was commented out, so this cell raised NameError on a fresh
# kernel. None means no index file accompanies the record file.
index_path = None
dataset = TFRecordDataset(tfrecord_path, index_path, DESCRIPTOR)


In [125]:
def tensor_to_string(data, variable):
    """Decode a byte-valued feature of a batch into a Python string.

    Parameters
    ----------
    data : mapping
        Batch mapping feature names to tensors; only ``data[variable]`` is read.
    variable : str
        Name of the feature to decode. Its tensor must support ``.numpy()``
        and hold one sequence of byte values per batch item.

    Returns
    -------
    str
        The text stored in the first item of the batch.
    """
    # Only the first batch item is decoded.
    byte_values = data[variable].numpy()[0]
    payload = byte_values.astype("uint8").tobytes()
    try:
        # Decode the bytes as one UTF-8 string; mapping each byte through
        # chr() would garble any multi-byte character.
        return payload.decode("utf-8")
    except UnicodeDecodeError:
        # Not valid UTF-8: fall back to one character per byte, which is what
        # a per-byte chr() conversion yields.
        return payload.decode("latin-1")

def update_csv(csv, idx, bounding_box, filename):
    """Store a record's bounding box and file name in one row of the table.

    The row is addressed by position: ``idx`` is converted with ``int()`` and
    used as a positional (``iloc``) row index. The frame is modified in place
    and also returned.

    Parameters
    ----------
    csv : pandas.DataFrame
        Table with ``geometry`` and ``filename`` columns.
    idx : str or int
        Positional row number, in any form ``int()`` accepts.
    bounding_box : object
        Value written to the ``geometry`` column.
    filename : object
        Value written to the ``filename`` column.

    Returns
    -------
    pandas.DataFrame
        The same ``csv`` object that was passed in.
    """
    row = int(idx)
    for column, value in (('geometry', bounding_box), ('filename', filename)):
        csv.iloc[row, csv.columns.get_loc(column)] = value
    return csv

def tfrecord_to_tif(data, filename, patch_size=255):
    """Write the bands of one record to a multi-band GeoTIFF under TIF_DIR.

    Parameters
    ----------
    data : mapping
        Batch mapping each name in BANDS to a tensor; the first batch item of
        every band is reshaped to ``(patch_size, patch_size)``.
    filename : str
        Name appended to TIF_DIR to form the output path.
    patch_size : int, optional
        Side length, in pixels, of the square patch stored per band
        (default 255).

    Returns
    -------
    None
        The function's only effect is the file it writes.
    """
    bands = [
        data[band][0].numpy().reshape((patch_size, patch_size))
        for band in BANDS
    ]
    # (band, row, col) -> (col, row, band).
    # NOTE(review): swapping axes 0 and 2 also exchanges rows and columns, so
    # every band is written transposed. Kept as-is to leave the produced files
    # unchanged; confirm whether np.moveaxis(..., 0, -1) was intended.
    arr = np.swapaxes(np.array(bands), 0, 2)
    tif_path = TIF_DIR + filename
    # Identity affine transform: the output carries no real georeferencing.
    transform = rasterio.Affine(1, 0, 0, 0, 1, 0)
    # The context manager closes the dataset even if a band write raises.
    with rasterio.open(tif_path, 'w', driver='GTiff',
                       height=arr.shape[0], width=arr.shape[1],
                       count=len(BANDS), dtype=str(arr.dtype),
                       crs='epsg:3857',
                       transform=transform) as tif:
        for band_index in range(len(BANDS)):
            # rasterio band numbers are 1-based.
            tif.write(arr[:, :, band_index], band_index + 1)

In [126]:
# Walk through the dataset one record at a time (batch_size=1). For each
# record, copy its bounding box and file name into the table row given by its
# "system:index", then write its bands to a GeoTIFF.
loader = torch.utils.data.DataLoader(dataset, batch_size=1)
for data in loader:
    idx = tensor_to_string(data, "system:index")
    filename = tensor_to_string(data, "filename")
    bounding_box = tensor_to_string(data, "bounding_box")
    csv = update_csv(csv, idx, bounding_box, filename)
    tfrecord_to_tif(data, filename)