In [3]:
import pandas as pd
import rasterio
import transforms as tf
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from glob import glob
import warnings
warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)


import transforms as tf
import preprocessing as pp


class SentinelDataset(Dataset):
    '''Sentinel 1 & 2 dataset.'''

    def __init__(self, tile_file, dir_tiles, dir_target,
                 max_chips=None, transform=None, device='cpu'
                 ):
        '''
        Args:
            tile_file -- path to csv file specifying chipid and month for each tile to be loaded
            dir_tiles -- path to directory containing Sentinel data tiles
            dir_target -- path to directory containing target data (AGWB) tiles
            max_chips -- maximum number of chips to load, used for testing, None --> load all
            transform -- transforms to apply to each sample/batch
            device -- device to load data onto ('cpu', 'mps', 'cuda')
        '''

        if tile_file:
            self.df_tile_list = pd.read_csv(tile_file, index_col=0)
        else:
            self.df_tile_list = self._make_df_tile_list(dir_tiles)
        if max_chips:
            self.df_tile_list = self.df_tile_list[:max_chips]
        self.dir_tiles = dir_tiles
        self.dir_target = dir_target
        self.device = device
        self.transform_s2 = tf.Sentinel2Scale()
        self.transform_s1 = tf.Sentinel1Scale()
        self.transform = transform

    def __len__(self):
        return len(self.df_tile_list)

    def __getitem__(self, idx):
        chipid, month = self.df_tile_list.iloc[idx][['chipid','month']].values
        # Sentinel 1
        try:
            s1_tile = self._load_sentinel_tiles('S1', chipid, month)
            s1_tile_scaled = self.transform_s1(s1_tile)
        except:
            # print(f'Data load failure for S1: {chipid} {month}')
            s1_tile_scaled = torch.full([4, 256, 256], torch.nan, dtype=torch.float32, requires_grad=False, device=self.device)
        # Sentinel 2
        try:
            s2_tile = self._load_sentinel_tiles('S2', chipid, month)
            s2_tile_scaled = self.transform_s2(s2_tile)
        except:
            # print(f'Data load failure for S2: {chipid} {month}')
            s2_tile_scaled = torch.full([11, 256, 256], torch.nan, dtype=torch.float32, requires_grad=False, device=self.device)

        sentinel_tile = torch.cat([s2_tile_scaled, s1_tile_scaled], axis=0)

        if self.dir_target:
            target_tile = self._load_agbm_tile(chipid)
        else:
            target_tile = torch.full([1, 256, 256], torch.nan, dtype=torch.float32, requires_grad=False, device=self.device)

        sample = {'image': sentinel_tile, 'label': target_tile} # 'image' and 'label' are used by torchgeo

        if self.transform:
            sample = self.transform(sample)

        return sample

    def _read_tif_to_tensor(self, tif_path):
        with rasterio.open(tif_path) as src:
            X = torch.tensor(src.read().astype(np.float32),
                             dtype=torch.float32,
                             device=self.device,
                             requires_grad=False,
                             )
        return X

    def _load_sentinel_tiles(self, sentinel_type, chipid, month):
        file_name = f'{chipid}_{sentinel_type}_{str(month).zfill(2)}.tif'
        tile_path = os.path.join(self.dir_tiles, file_name)
        return self._read_tif_to_tensor(tile_path)

    def _load_agbm_tile(self, chipid):
        target_path = os.path.join(self.dir_target,
                                   f'{chipid}_agbm.tif')
        return self._read_tif_to_tensor(target_path)

    def _make_df_tile_list(self, dir_tiles):
        tile_files = [
            os.path.basename(f).split('.')[0] for f in glob(f'{dir_tiles}/*.tif')
        ]
        tile_tuples = []
        for tile_file in tile_files:
            chipid, _, month = tile_file.split('_')
            tile_tuples.append(tuple([chipid, int(month)]))
        tile_tuples = list(set(tile_tuples))
        tile_tuples.sort()
        return pd.DataFrame(tile_tuples, columns=['chipid', 'month'])

In [97]:
# MPS Docs
# https://pytorch.org/docs/master/notes/mps.html

if torch.backends.mps.is_available(): # Mac M1/M2
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

loader_device = device # found that using cpu for data loading was faster than gpu (for my device)
print(f'training device: {device}')
print(f'loader_device: {loader_device}')

training device: cuda
loader_device: cuda


In [4]:
dir_tiles = '/workspace/BioMassters/data/train_features'
dir_target = '/workspace/BioMassters/data/train_agbm'
dir_saved_models = './trained_models'

In [8]:
max_chips = 4000 # number of chips to use from training set, None = Use All  

# A custom dataloader for Sentinel data 
dataset = SentinelDataset(tile_file=None, 
                             dir_tiles=dir_tiles, 
                             dir_target=dir_target,
                             max_chips=max_chips,
                             transform=None,
                             device=loader_device
                            )

In [9]:
dataset.df_tile_list

Unnamed: 0,chipid,month
0,0003d2eb,0
1,0003d2eb,1
2,0003d2eb,2
3,0003d2eb,3
4,0003d2eb,4
...,...,...
3995,0a096068,11
3996,0a105fce,0
3997,0a105fce,1
3998,0a105fce,2


In [98]:


class SentinelDatasetLstm(Dataset):
    '''Sentinel 1 & 2 dataset.'''

    def __init__(self, tile_file, dir_tiles, dir_target,
                 max_chips=None, transform=None, device='cpu'
                 ):
        '''
        Args:
            tile_file -- path to csv file specifying chipid and month for each tile to be loaded
            dir_tiles -- path to directory containing Sentinel data tiles
            dir_target -- path to directory containing target data (AGWB) tiles
            max_chips -- maximum number of chips to load, used for testing, None --> load all
            transform -- transforms to apply to each sample/batch
            device -- device to load data onto ('cpu', 'mps', 'cuda')
        '''

        self.df_tile_list = self._make_df_tile_list(dir_tiles)
        self.df_chip_list = self._make_df_chip_list()
        
        if max_chips:
            self.df_tile_list = self.df_tile_list[:max_chips]
        self.dir_tiles = dir_tiles
        self.dir_target = dir_target
        self.device = device
        self.transform_s2 = tf.Sentinel2Scale()
        self.transform_s1 = tf.Sentinel1Scale()
        self.transform = transform

    def __len__(self):
        return len(self.df_tile_list)

    def __getitem__(self, idx):
        chipid = self.df_chip_list.iloc[idx].values[0]
        
        chips = self.df_tile_list.query("chipid == '{}'".format(chipid))
        
        target_tile = None
        images = []
        
        pixels = {(i,j): {"data": [], 'target': None} for i in range(256) for j in range(256)}
        
        for idx in chips.index:
            sample = self.get_month(idx)
            img = sample['image']
            for i in range(256):
                for j in range(256):
                    pixels[(i, j)]['data'].append(img[:,i,j])
            
        for i in range(256):
            for j in range(256):
                label = sample['label']
                pixels[(i, j)]['target'] = sample['label'][:,i,j]
                    
        return pixels
        
    
    def get_month(self, idx):
        chipid, month = self.df_tile_list.iloc[idx][['chipid','month']].values
        # Sentinel 1
        try:
            s1_tile = self._load_sentinel_tiles('S1', chipid, month)
            s1_tile_scaled = self.transform_s1(s1_tile)
        except:
            # print(f'Data load failure for S1: {chipid} {month}')
            s1_tile_scaled = torch.full([4, 256, 256], torch.nan, dtype=torch.float32, requires_grad=False, device=self.device)
        # Sentinel 2
        try:
            s2_tile = self._load_sentinel_tiles('S2', chipid, month)
            s2_tile_scaled = self.transform_s2(s2_tile)
        except:
            # print(f'Data load failure for S2: {chipid} {month}')
            s2_tile_scaled = torch.full([11, 256, 256], torch.nan, dtype=torch.float32, requires_grad=False, device=self.device)

        sentinel_tile = torch.cat([s2_tile_scaled, s1_tile_scaled], axis=0)

        if self.dir_target:
            target_tile = self._load_agbm_tile(chipid)
        else:
            target_tile = torch.full([1, 256, 256], torch.nan, dtype=torch.float32, requires_grad=False, device=self.device)

        sample = {'image': sentinel_tile, 'label': target_tile} # 'image' and 'label' are used by torchgeo

        if self.transform:
            sample = self.transform(sample)

        return sample

    def _read_tif_to_tensor(self, tif_path):
        with rasterio.open(tif_path) as src:
            X = torch.tensor(src.read().astype(np.float32),
                             dtype=torch.float32,
                             device=self.device,
                             requires_grad=False,
                             )
        return X

    def _load_sentinel_tiles(self, sentinel_type, chipid, month):
        file_name = f'{chipid}_{sentinel_type}_{str(month).zfill(2)}.tif'
        tile_path = os.path.join(self.dir_tiles, file_name)
        return self._read_tif_to_tensor(tile_path)

    def _load_agbm_tile(self, chipid):
        target_path = os.path.join(self.dir_target,
                                   f'{chipid}_agbm.tif')
        return self._read_tif_to_tensor(target_path)

    def _make_df_tile_list(self, dir_tiles):
        tile_files = [
            os.path.basename(f).split('.')[0] for f in glob(f'{dir_tiles}/*.tif')
        ]
        tile_tuples = []
        for tile_file in tile_files:
            chipid, _, month = tile_file.split('_')
            tile_tuples.append(tuple([chipid, int(month)]))
        tile_tuples = list(set(tile_tuples))
        tile_tuples.sort()
        return pd.DataFrame(tile_tuples, columns=['chipid', 'month'])
    
    
    def _make_df_chip_list(self):
       
        values = self.df_tile_list['chipid'].unique()
        
        df = df_chip_list = pd.DataFrame()
        df['chipid'] = values
        
        return df

In [99]:
max_chips = 4000 # number of chips to use from training set, None = Use All  

# A custom dataloader for Sentinel data 
dataset = SentinelDatasetLstm(tile_file=None, 
                             dir_tiles=dir_tiles, 
                             dir_target=dir_target,
                             max_chips=max_chips,
                             transform=None,
                             device=loader_device
                            )

In [76]:
dataset.df_chip_list

Unnamed: 0,chipid
0,0003d2eb
1,000aa810
2,000d7e33
3,00184691
4,001b0634
...,...
6189,b6634e26
6190,b674c32b
6191,b67701d2
6192,b678118d


In [100]:
t = dataset.__getitem__(1)

In [102]:
torch.vstack(t[0,0]['data']).view(15,12)[0,:]

tensor([   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
           nan,    nan, 0.5120], device='cuda:0')

In [60]:
len(t[0])

12

In [63]:
tensor = t[0][0]

In [65]:
tensor.shape

torch.Size([15, 256, 256])

In [66]:
tensor[:,0,0]

tensor([0.0240, 0.0892, 0.0868, 0.2828, 0.5828, 0.6388, 0.7395, 0.7517, 0.4695,
        0.2295, 0.0400, 0.5294, 0.4532, 0.5579, 0.4590])

In [73]:
t[1][:,0,0].float()

tensor([6.7300])