In [1]:
import os 
import glob
import gzip
import shutil
import torch
import rasterio
import pandas as pd 
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
CSV              = os.path.join( "..", "data", "dataset_viirs_only.csv" )
RECORDS_DIR      = os.path.join( "..", "data", "landsat_7_less","" )
TIF_DIR          = os.path.join( "..", "data", "landsat_7_less","" )
BANDS            = ['BLUE','GREEN','RED','NIR','SWIR1','SWIR2','TEMP1','NIGHTLIGHTS']
DESCRIPTOR       = {
                'cluster':"float",
                'lat':"float", 
                "lon":"float",
                'wealthpooled':"float",
                'BLUE':"float",
                'GREEN':"float",
                'RED':"float",
                'NIR':"float",
                'SWIR1':"float",
                'SWIR2':"float",
                'TEMP1':"float",
                'NIGHTLIGHTS':"float"
              }   

In [3]:
csv=pd.read_csv(CSV)

In [4]:
records = dict()
for year in csv.year.unique():
    records[year]=dict()
    sub_year = csv[ csv.year == year ]
    for country in sub_year.country.unique():
        sub_country = sub_year[ sub_year.country == country ].copy()
        pattern = RECORDS_DIR+"*"+str(country)+"_"+str(year)+"/*.tfrecord*"
        records[year][country] = glob.glob(pattern)
records[2015]['angola'][:5]

['../data/landsat_7_less/angola_2015/604.tfrecord.gz',
 '../data/landsat_7_less/angola_2015/605.tfrecord.gz',
 '../data/landsat_7_less/angola_2015/606.tfrecord.gz',
 '../data/landsat_7_less/angola_2015/607.tfrecord.gz',
 '../data/landsat_7_less/angola_2015/608.tfrecord.gz']

In [5]:
def decompress_tfrecord(tfrecord_archive):
    with gzip.open(tfrecord_archive, 'rb') as f_in:
        # WITHOUT .GZ
        with open(tfrecord_archive[:-3], 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    return tfrecord_archive[:-3]

def tensor_to_string(data, variable):
    filename = (data[variable].numpy())[0][0]
    return str(filename).replace(".","")

def tfrecord_to_tif(data, filename, mins, maxs, minnl, maxnl):
    arrays = []
    for i in range(len(BANDS)):
        new_arr = data[BANDS[i]][0].numpy().reshape((255,255))
        arrays.append(new_arr)
        if i == len(BANDS)-1:
            minnl = min(minnl, new_arr.min())
            maxnl = max(maxnl, new_arr.max())
        else:
            mins[i] = min(mins[i], new_arr.min())
            maxs[i] = max(maxs[i], new_arr.max())

    arr = np.swapaxes(np.array(arrays), 0, 2 )
    tif_path = TIF_DIR + filename
    tif = rasterio.open(tif_path, 'w', driver='GTiff',
                            height = arr.shape[0], width = arr.shape[1],
                            count=8, dtype=str(arr.dtype),
                            crs='epsg:3857',
                            transform=None)
    for i in range(len(BANDS)):
        tif.write(arr[:,:,i],i+1)
    tif.close()

    return mins, maxs, minnl, maxnl

In [8]:
mins=[1e3, 1e3, 1e3, 1e3, 1e3, 1e3, 1e3] 
minviirs=1e3
maxs=[-1e3, -1e3, -1e3, -1e3, -1e3, -1e3, -1e3]
maxviirs=1e3

for year in records:
    print(year)
    for country in records[year]:
        if records[year][country]==[]:
            continue
        for tfrecord_archive in records[year][country]:
            if tfrecord_archive[-3:] == '.gz':
                tfrecord = decompress_tfrecord(tfrecord_archive=tfrecord_archive)
                tfrecord = tfrecord_archive[:-3]
            else:
                tfrecord = tfrecord_archive
            dataset = TFRecordDataset(tfrecord, index_path=None, description=DESCRIPTOR)
            loader = torch.utils.data.DataLoader(dataset, batch_size=1)
            iterator = iter(loader)
            while (data := next(iterator, None)) is not None:
                filename = str(country)+"_"+str(year)+"/"+tensor_to_string(data, "cluster")[:-1]+".tif"
                mins, maxs, minviirs, maxviirs = tfrecord_to_tif(data, filename, mins, maxs,minviirs, maxviirs)

2015
2013
2017
2014
2018
2016
2019


In [9]:
print(mins, maxs, minviirs, maxviirs)

[-0.2, -0.0641, -0.0866, -0.0308, -0.00245, 0.0, 0.0] [0.9576, 0.9212, 0.97355, 1.2277, 1.48375, 1.57635, 316.9] -0.07087274 3104.1401


In [None]:
# CHECK INTEGRITY
records = dict()
for year in csv.year.unique():
    records[year]=dict()
    sub_year = csv[ csv.year == year ]
    for country in sub_year.country.unique():
        sub_country = sub_year[ sub_year.country == country ].copy()
        pattern = RECORDS_DIR+"*"+str(country)+"_"+str(year)+"/*.tif"
        records[year][country] = glob.glob(pattern)
for year in records:
    print(year)
    for country in records[year]:
        if records[year][country]==[]:
            continue
        for tif in records[year][country]:
            tile = rasterio.open(tif) 
            tile_= tile.read()