<a href="https://colab.research.google.com/github/laure-delisle/cs159-uq/blob/main/BEN_seco.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install rasterio > /dev/null

## Dataset

### Dataset class

In [2]:
'''
Code from: seasonal-contrast (https://github.com/ServiceNow/seasonal-contrast/)
         + [Laure Delisle] modifications to data downloading / extracting
Author: Oscar Mañas
Date: May 2021
'''

import json
from pathlib import Path

import numpy as np
import rasterio
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_and_extract_archive, download_url, extract_archive

# Band-name suffixes of the per-band GeoTIFFs (`{patch_id}_{band}.tif`) inside a
# patch folder. Not referenced elsewhere in this notebook.
ALL_BANDS = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
# Default band selection used by Bigearthnet when `bands=None`; stacking these
# three channels yields an RGB-mode PIL image.
RGB_BANDS = ['B04', 'B03', 'B02']

# Per-band mean / standard deviation consumed by `normalize` in
# `Bigearthnet.__getitem__` (window = mean +/- 2*std). Presumably dataset-level
# statistics inherited from the upstream seasonal-contrast code -- TODO confirm
# which split they were computed on.
BAND_STATS = {
    'mean': {
        'B01': 340.76769064,
        'B02': 429.9430203,
        'B03': 614.21682446,
        'B04': 590.23569706,
        'B05': 950.68368468,
        'B06': 1792.46290469,
        'B07': 2075.46795189,
        'B08': 2218.94553375,
        'B8A': 2266.46036911,
        'B09': 2246.0605464,
        'B11': 1594.42694882,
        'B12': 1009.32729131
    },
    'std': {
        'B01': 554.81258967,
        'B02': 572.41639287,
        'B03': 582.87945694,
        'B04': 675.88746967,
        'B05': 729.89827633,
        'B06': 1096.01480586,
        'B07': 1273.45393088,
        'B08': 1365.45589904,
        'B8A': 1356.13789355,
        'B09': 1302.3292881,
        'B11': 1079.19066363,
        'B12': 818.86747235
    }
}

# The 43 original class names. A label's position in this list is its index in
# the multi-hot target built by `Bigearthnet.get_multihot_old`.
# NOTE: the 'Land principally occupied ...' entry spans two source lines via
# implicit string-literal concatenation -- it is a single label.
LABELS = [
    'Agro-forestry areas', 'Airports',
    'Annual crops associated with permanent crops', 'Bare rock',
    'Beaches, dunes, sands', 'Broad-leaved forest', 'Burnt areas',
    'Coastal lagoons', 'Complex cultivation patterns', 'Coniferous forest',
    'Construction sites', 'Continuous urban fabric',
    'Discontinuous urban fabric', 'Dump sites', 'Estuaries',
    'Fruit trees and berry plantations', 'Green urban areas',
    'Industrial or commercial units', 'Inland marshes', 'Intertidal flats',
    'Land principally occupied by agriculture, with significant areas of '
    'natural vegetation', 'Mineral extraction sites', 'Mixed forest',
    'Moors and heathland', 'Natural grassland', 'Non-irrigated arable land',
    'Olive groves', 'Pastures', 'Peatbogs', 'Permanently irrigated land',
    'Port areas', 'Rice fields', 'Road and rail networks and associated land',
    'Salines', 'Salt marshes', 'Sclerophyllous vegetation', 'Sea and ocean',
    'Sparsely vegetated areas', 'Sport and leisure facilities',
    'Transitional woodland/shrub', 'Vineyards', 'Water bodies', 'Water courses'
]

# The 19 grouped class names. Position in this list = index in the multi-hot
# target built by `Bigearthnet.get_multihot_new`.
NEW_LABELS = [
    'Urban fabric',
    'Industrial or commercial units',
    'Arable land',
    'Permanent crops',
    'Pastures',
    'Complex cultivation patterns',
    'Land principally occupied by agriculture, with significant areas of natural vegetation',
    'Agro-forestry areas',
    'Broad-leaved forest',
    'Coniferous forest',
    'Mixed forest',
    'Natural grassland and sparsely vegetated areas',
    'Moors, heathland and sclerophyllous vegetation',
    'Transitional woodland/shrub',
    'Beaches, dunes, sands',
    'Inland wetlands',
    'Coastal wetlands',
    'Inland waters',
    'Marine waters'
]

# Maps an original LABELS name to the NEW_LABELS group it merges into. Original
# labels that are neither keys here nor members of NEW_LABELS are dropped by
# `get_multihot_new`.
GROUP_LABELS = {
    'Continuous urban fabric': 'Urban fabric',
    'Discontinuous urban fabric': 'Urban fabric',
    'Non-irrigated arable land': 'Arable land',
    'Permanently irrigated land': 'Arable land',
    'Rice fields': 'Arable land',
    'Vineyards': 'Permanent crops',
    'Fruit trees and berry plantations': 'Permanent crops',
    'Olive groves': 'Permanent crops',
    'Annual crops associated with permanent crops': 'Permanent crops',
    'Natural grassland': 'Natural grassland and sparsely vegetated areas',
    'Sparsely vegetated areas': 'Natural grassland and sparsely vegetated areas',
    'Moors and heathland': 'Moors, heathland and sclerophyllous vegetation',
    'Sclerophyllous vegetation': 'Moors, heathland and sclerophyllous vegetation',
    'Inland marshes': 'Inland wetlands',
    'Peatbogs': 'Inland wetlands',
    'Salt marshes': 'Coastal wetlands',
    'Salines': 'Coastal wetlands',
    'Water bodies': 'Inland waters',
    'Water courses': 'Inland waters',
    'Coastal lagoons': 'Marine waters',
    'Estuaries': 'Marine waters',
    'Sea and ocean': 'Marine waters'
}


def normalize(img, mean, std):
    min_value = mean - 2 * std
    max_value = mean + 2 * std
    img = (img - min_value) / (max_value - min_value) * 255.0
    img = np.clip(img, 0, 255).astype(np.uint8)
    return img


class Bigearthnet(Dataset):
    """BigEarthNet-S2 patches served as ``(PIL image, multi-hot target)`` pairs.

    The split is defined by ``{root}/{split}.txt`` (one patch id per line).
    Patch ids listed in the seasonal-snow and cloud/shadow CSVs are excluded.
    Each remaining patch is read from ``{root}/BigEarthNet-v1.0/{patch_id}/``.

    Args:
        root: directory holding the split list, the bad-patch CSVs and the
            extracted ``BigEarthNet-v1.0`` folder.
        split: key of ``list_file`` (``'train'``, ``'val'`` or ``'test'``).
        archive_root: directory holding the ``.tar.gz`` archive; defaults to ``root``.
        bands: band names to stack as image channels; defaults to ``RGB_BANDS``.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the target vector.
        download_archive: download the archive into ``archive_root`` and extract it there.
        download_split_files: download the split list and bad-patch CSVs into ``root``.
        extract: extract ``archive_root/archive_name`` into ``root``.
        use_new_labels: if True, targets use the 19 ``NEW_LABELS`` groups;
            otherwise the 43 original ``LABELS``.
    """
    url = 'https://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz'
    subdir = 'BigEarthNet-v1.0'
    archive_name = 'BigEarthNet-S2-v1.0.tar.gz'
    list_file = {
        'train': 'https://storage.googleapis.com/remote_sensing_representations/bigearthnet-train.txt',
        'val': 'https://storage.googleapis.com/remote_sensing_representations/bigearthnet-val.txt',
        'test': 'https://storage.googleapis.com/remote_sensing_representations/bigearthnet-test.txt'
    }
    bad_patches = [
        'http://bigearth.net/static/documents/patches_with_seasonal_snow.csv',
        'http://bigearth.net/static/documents/patches_with_cloud_and_shadow.csv'
    ]

    def __init__(self, root, split, archive_root=None, bands=None, transform=None, target_transform=None,
                 download_archive=False, download_split_files=False, extract=False, use_new_labels=True):
        self.root = Path(root)
        self.archive_root = archive_root if archive_root is not None else self.root
        self.split = split
        self.bands = bands if bands is not None else RGB_BANDS
        self.transform = transform
        self.target_transform = target_transform
        self.use_new_labels = use_new_labels

        if download_archive:
            download_and_extract_archive(self.url, self.archive_root)

        if download_split_files:
            download_url(self.list_file[self.split], self.root, f'{self.split}.txt')
            for url in self.bad_patches:
                download_url(url, self.root)

        if extract:
            archive = os.path.join(self.archive_root, self.archive_name)
            extract_archive(from_path=archive, to_path=self.root)

        # Patch ids flagged for seasonal snow or cloud/shadow; each CSV was
        # saved under the basename of its URL.
        excluded = set()
        for url in self.bad_patches:
            with open(self.root / Path(url).name) as f:
                excluded.update(f.read().splitlines())

        with open(self.root / f'{self.split}.txt') as f:
            self.samples = [self.root / self.subdir / patch_id
                            for patch_id in f.read().splitlines()
                            if patch_id not in excluded]

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample `index`, after the optional transforms."""
        path = self.samples[index]
        patch_id = path.name

        channels = []
        for b in self.bands:
            # Context manager closes the GeoTIFF handle right after the read,
            # instead of leaving one open file per band per sample.
            with rasterio.open(path / f'{patch_id}_{b}.tif') as src:
                ch = src.read(1)
            channels.append(normalize(ch, mean=BAND_STATS['mean'][b], std=BAND_STATS['std'][b]))
        img = np.dstack(channels)
        if img.shape[-1] == 1:
            # Pillow's fromarray has no mode for (H, W, 1) uint8 arrays; a 2-D
            # array gives a grayscale ('L') image for single-band requests.
            img = img[..., 0]
        img = Image.fromarray(img)

        with open(path / f'{patch_id}_labels_metadata.json', 'r') as f:
            labels = json.load(f)['labels']
        if self.use_new_labels:
            target = self.get_multihot_new(labels)
        else:
            target = self.get_multihot_old(labels)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of patches in the split after excluding bad patches."""
        return len(self.samples)

    @staticmethod
    def get_multihot_old(labels):
        """Encode label names as a float32 multi-hot vector over the 43 ``LABELS``.

        Raises ValueError for a name that is not in ``LABELS``.
        """
        target = np.zeros((len(LABELS),), dtype=np.float32)
        for label in labels:
            target[LABELS.index(label)] = 1
        return target

    @staticmethod
    def get_multihot_new(labels):
        """Encode label names as a float32 multi-hot vector over the 19 ``NEW_LABELS``.

        Original names are first mapped through ``GROUP_LABELS``. Names that
        neither group into nor appear in ``NEW_LABELS`` are ignored.
        """
        target = np.zeros((len(NEW_LABELS),), dtype=np.float32)
        for label in labels:
            grouped = GROUP_LABELS.get(label, label)
            if grouped in NEW_LABELS:
                target[NEW_LABELS.index(grouped)] = 1
        return target

In [4]:
import google.colab.drive as drive

# Mount Google Drive so paths under /content/drive (e.g. archive_root) are readable.
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Dataset configuration.
# NOTE: despite its name, `val_set` holds whichever split `split` selects
# (currently 'train'); later cells refer to it as `val_set`.
data_root = '/content/dataset'  # split list, bad-patch CSVs and extracted patches
archive_root = '/content/drive/MyDrive/datasets/big_earth_net'  # holds the .tar.gz archive
split = 'train'
bands = None  # None -> the class default (RGB bands)
# Keep the 43 original classes: label_to_id / id_to_label further down index
# targets by position in the original label list.
use_new_labels = False

val_set = Bigearthnet(root=data_root, split=split, archive_root=archive_root, bands=bands,
                      transform=None, target_transform=None,
                      download_archive=False, download_split_files=True, extract=True,
                      use_new_labels=use_new_labels)

Downloading https://storage.googleapis.com/remote_sensing_representations/bigearthnet-train.txt to /content/dataset/train.txt


100%|██████████| 11565074/11565074 [00:00<00:00, 23268707.16it/s]


Downloading https://bigearth.net/static/documents/patches_with_seasonal_snow.csv to /content/dataset/patches_with_seasonal_snow.csv


100%|██████████| 2065439/2065439 [00:00<00:00, 19913201.62it/s]


Downloading https://bigearth.net/static/documents/patches_with_cloud_and_shadow.csv to /content/dataset/patches_with_cloud_and_shadow.csv


100%|██████████| 313331/313331 [00:00<00:00, 5771247.81it/s]


In [6]:
# First sample: (PIL image, multi-hot target over the 43 original labels,
# because use_new_labels=False).
val_set[0]

(<PIL.Image.Image image mode=RGB size=120x120 at 0x7EFD84E5A290>,
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))

In [7]:
# Number of patches in the split after dropping the snow / cloud-shadow patches.
len(val_set)

311667

### Dataset subsets

In [8]:
# Class names in target-vector order. Derived from LABELS (not re-typed) so the
# position -> name mapping can never drift from the one used by
# Bigearthnet.get_multihot_old, which indexes targets with LABELS.index.
labels = list(LABELS)

In [9]:
# Bidirectional lookup between class names and their index in the target vector.
label_to_id = {label: idx for idx, label in enumerate(labels)}
id_to_label = dict(enumerate(labels))

In [10]:
# Label-name families used to carve out dataset subsets.
forest_labels = [
    'Agro-forestry areas',
    'Broad-leaved forest',
    'Coniferous forest',
    'Mixed forest',
]
water_labels = ['Water bodies', 'Water courses']
bare_labels = ['Bare rock']
mine_labels = ['Mineral extraction sites']

# Matching target-vector indices for each family.
forest_ids, water_ids, bare_ids, mine_ids = (
    [label_to_id[name] for name in family]
    for family in (forest_labels, water_labels, bare_labels, mine_labels)
)

In [11]:
def intersection(cat_labels, img_labels):
    """Return the items present in both iterables, as a list in no particular order."""
    shared = set(cat_labels) & set(img_labels)
    return [*shared]

def get_labels(img):
    """Return the names of the classes whose entry in the sample's target equals 1.

    `img` is a dataset sample; element 1 is its target vector. Indices are
    decoded with the module-level `id_to_label`, so the target is expected to
    use the 43-class ordering (a dataset built with use_new_labels=False).
    """
    target = img[1]
    names = []
    for position, value in enumerate(target):
        if value == 1:
            names.append(id_to_label[position])
    return names

In [12]:
# core
def has_X(img, labels):
    """Return True if sample `img` carries at least one of the class names in `labels`.

    `img` is a dataset sample; its target is decoded with `get_labels`.
    """
    return bool(intersection(labels, get_labels(img)))

# facades: one predicate per label family
def has_forest(img):
    """True if the sample has any of `forest_labels`."""
    return has_X(img, forest_labels)

def has_water(img):
    """True if the sample has any of `water_labels`."""
    return has_X(img, water_labels)

def has_bare(img):
    """True if the sample has any of `bare_labels`."""
    return has_X(img, bare_labels)

def has_mine(img):
    """True if the sample has any of `mine_labels`."""
    return has_X(img, mine_labels)

In [13]:
# Decode the first sample's target back into class names.
get_labels(val_set[0])

['Discontinuous urban fabric',
 'Industrial or commercial units',
 'Mixed forest',
 'Moors and heathland',
 'Pastures',
 'Road and rail networks and associated land']

In [14]:
# Decode the second sample's target back into class names.
get_labels(val_set[1])

['Non-irrigated arable land', 'Olive groves', 'Permanently irrigated land']

In [15]:
print("forest", [has_forest(val_set[i]) for i in range(2)])
print("water", [has_water(val_set[i]) for i in range(2)])
print("bare", [has_bare(val_set[i]) for i in range(2)])
print("mine", [has_mine(val_set[i]) for i in range(2)])

forest [True, False]
water [False, False]
bare [False, False]
mine [False, False]


In [None]:
# Uncached full pass: loads every sample (all bands) just to read its target.
# The pickle-loading cell further down performs the same computation with caching.
forest_img_ids = [idx for idx in range(len(val_set)) if has_forest(val_set[idx])]

In [None]:
import pickle
from tqdm import tqdm

# One-off cache writer: uncomment to persist the subset id lists as pickles under
# archive_root (the same `{name}.pkl` paths that load_from_pickle reads in the
# next cell). The mine and bare+water lists are filtered from forest_img_ids.
# with open(os.path.join(archive_root,'forest_img_ids.pkl'), 'wb') as f:
#     pickle.dump(forest_img_ids, f)

# mine_img_ids = [i for i in tqdm(forest_img_ids) if has_mine(val_set[i])]
# with open(os.path.join(archive_root,'mine_img_ids.pkl'), 'wb') as f:
#     pickle.dump(mine_img_ids, f)

# bare_water_img_ids = [i for i in tqdm(forest_img_ids) if (has_water(val_set[i]) and has_bare(val_set[i]))]
# with open(os.path.join(archive_root,'bare_water_img_ids.pkl'), 'wb') as f:
#     pickle.dump(bare_water_img_ids, f)

100%|██████████| 202219/202219 [1:01:54<00:00, 54.44it/s]
100%|██████████| 202219/202219 [1:08:54<00:00, 48.91it/s]


In [16]:
import pickle
from tqdm import tqdm

## Load subset id lists from pickle if existing, otherwise compute and cache them
def load_from_pickle(list_name):
    """Return the list cached at ``{archive_root}/{list_name}.pkl``, or None if absent.

    NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
    """
    file_path = os.path.join(archive_root, '{}.pkl'.format(list_name))
    if not os.path.exists(file_path):
        return None
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def save_to_pickle(id_list, list_name):
    """Write `id_list` to ``{archive_root}/{list_name}.pkl``, the path load_from_pickle reads."""
    file_path = os.path.join(archive_root, '{}.pkl'.format(list_name))
    with open(file_path, 'wb') as f:
        pickle.dump(id_list, f)


def load_or_compute(list_name, compute):
    """Return cached list `list_name`; on a cache miss, build it with `compute()` and cache it.

    The miss test is `is None`, so a cached *empty* list counts as a hit and is
    not recomputed (a full pass takes about an hour).
    """
    id_list = load_from_pickle(list_name)
    if id_list is None:
        id_list = compute()
        save_to_pickle(id_list, list_name)
    return id_list


def _has_bare_and_water(i):
    # Load sample i once and test both predicates on it (has_bare only runs
    # when has_water is True, same short-circuit as before).
    sample = val_set[i]
    return has_water(sample) and has_bare(sample)


# Forest (full pass over the dataset)
forest_img_ids = load_or_compute(
    'forest_img_ids',
    lambda: [i for i in tqdm(range(len(val_set))) if has_forest(val_set[i])])

# Mine + forest
mine_img_ids = load_or_compute(
    'mine_img_ids',
    lambda: [i for i in tqdm(forest_img_ids) if has_mine(val_set[i])])

# Bare + water + forest
bare_water_img_ids = load_or_compute(
    'bare_water_img_ids',
    lambda: [i for i in tqdm(forest_img_ids) if _has_bare_and_water(i)])

In [22]:
print("Forest tiles:", len(forest_img_ids))
print("..with mines:", len(mine_img_ids))
print("..bare+water:", len(bare_water_img_ids))

Forest tiles: 202219
..with mines: 1649
..bare+water: 38


In [24]:
# First forest tile that also has bare rock and a water label
# (raises IndexError if that subset is empty).
val_set[bare_water_img_ids[0]]

(<PIL.Image.Image image mode=RGB size=120x120 at 0x7EFD84E5B1C0>,
 array([0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 1., 0.], dtype=float32))

In [None]:
from torch.utils.data import Subset

# Forest-only view of the dataset: forest_dataset[k] is val_set[forest_img_ids[k]].
# (The original line referenced the undefined name `val_se`, a NameError.)
forest_dataset = Subset(val_set, forest_img_ids)

## TODO:
- [x] create set of tiles with forest (at least one forest label)
- [x] create subset of forest with mines
- [x] create subset of forest tiles with bare rock + water (bare rock and at least one water label)
- [ ] plotting method
- [ ] evaluate overlap of subsets 1 and 2
- [ ] create prompts for mine
- [ ] CLIP embedding machinery
- [ ] cosine sim + radius + extraction of confidence score
- [ ] find a way to map cosine sim to [0,1] interval
- [ ] evaluate prompts vs mine --> is it enough for zero-shot
- [ ] evaluate prompts vs bare rock + water bodies --> does it pick up on these features? (check assumption that it has good intermediate representations)

Use pseudo-labels + confidence scores to train:
- [ ] task 0: predict [1,0] (baseline, no uncertainty used)
- [ ] task 1: predict [uncertainty, 1-uncertainty]
- [ ] task 2: predict [1,0] + weight loss with uncertainty