In [None]:
import sys
!{sys.executable} -m pip install -q -U pip
!{sys.executable} -m pip install -q -U -r requirements.txt
!{sys.executable} -m pip install -q -U -r ./EPFD/requirements.txt
!{sys.executable} -m pip install -q -e ./EPFD/PyEnsemble

In [None]:
# Enable IPython's autoreload; mode 2 reloads imported modules before each cell
# runs, so edits to local helper modules (e.g. data_utils / utils) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import wandb

# Authenticate with Weights & Biases, then open the run that all later artifact
# uploads/downloads and the final run.finish() go through.
wandb.login()

project = "htcv"  # also used to build the artifact paths further down
run = wandb.init(entity="nicojahn", project=project)

# Upload the raw data to W&B as artifacts

In [None]:
# Log the ground-truth raster as the "labels" artifact and block on wait()
# until W&B has finished logging it.
label_artifact = wandb.Artifact("labels", type="raw_dataset")
label_artifact.add_file("data/2018_IEEE_GRSS_DFC_GT_TR.tif")
run.log_artifact(label_artifact)
label_artifact.wait()

# Log the whole data/ensemble/ directory (training-sample PNG plus one
# sub-directory per setting, as consumed further down) as "predictions".
data_artifact = wandb.Artifact("predictions", type="raw_dataset")
data_artifact.add_dir("data/ensemble/")
run.log_artifact(data_artifact)
data_artifact.wait()

# Download the artifacts again and use them as the data source

In [None]:
import os
import time

# Resolve labels:latest and make sure a complete local copy exists.
# Right after an upload the artifact may not be resolvable yet, so a
# ValueError / wandb.CommError is treated as "not ready" and retried once per
# second, for at most ~5 minutes.
max_attempts = 60 * 5

label_artifact = None
label_artifact_dir = None
for attempt in range(max_attempts):
    try:
        if label_artifact is None:
            label_artifact = run.use_artifact("nicojahn/" + project + "/labels:latest", type="raw_dataset")
        # Private W&B API; presumably the directory download() writes to by default.
        cached_dir = label_artifact._default_root()
        # Reuse the local copy only if it contains every file of the artifact.
        # A missing or partial copy never completes by itself, so call
        # download(), which writes the files that are not there yet.
        is_complete = os.path.exists(cached_dir) and all(
            os.path.exists(os.path.join(cached_dir, name)) for name in label_artifact._list()
        )
        label_artifact_dir = cached_dir if is_complete else label_artifact.download()
        break
    except (ValueError, wandb.CommError):
        time.sleep(1)
assert label_artifact_dir is not None, f"labels artifact not available after {max_attempts} attempts"

In [None]:
import os
import time

# Resolve predictions:latest and make sure a complete local copy exists.
# The artifact is a whole directory and may take longer to become available,
# so ValueError / wandb.CommError are retried once per second for at most
# ~15 minutes.
max_attempts = 60 * 15

data_artifact = None
data_artifact_dir = None
for attempt in range(max_attempts):
    try:
        if data_artifact is None:
            data_artifact = run.use_artifact("nicojahn/" + project + "/predictions:latest", type="raw_dataset")
        # Private W&B API; presumably the directory download() writes to by default.
        cached_dir = data_artifact._default_root()
        # Reuse the local copy only if it contains every file of the artifact.
        # A missing or partial copy never completes by itself, so call
        # download(), which writes the files that are not there yet.
        is_complete = os.path.exists(cached_dir) and all(
            os.path.exists(os.path.join(cached_dir, name)) for name in data_artifact._list()
        )
        data_artifact_dir = cached_dir if is_complete else data_artifact.download()
        break
    except (ValueError, wandb.CommError):
        time.sleep(1)
assert data_artifact_dir is not None, f"predictions artifact not available after {max_attempts} attempts"

# Read the data and transform it into label and prediction maps

In [None]:
# Resolution switch used below:
#   False -> downscale the label map with scale_image_down;
#   True  -> keep full-resolution labels and upscale the training-sample image.
scale_up = False

In [None]:
from data_utils import *
from utils import *
import pickle

In [None]:
from pathlib import Path

# The labels artifact is expected to hold exactly one entry: the ground-truth raster.
label_file = [str(p) for p in Path(label_artifact_dir).glob("*")]
assert len(label_file) == 1, f"expected exactly one label file, found {label_file}"
label_file = label_file[0]

# The predictions artifact is expected to hold exactly one top-level PNG
# (the training-sample image) ...
train_samples = [str(p) for p in Path(data_artifact_dir).glob("*png")]
assert len(train_samples) == 1, f"expected exactly one training-sample PNG, found {train_samples}"
train_samples = train_samples[0]

# ... plus one sub-directory per ensemble setting. glob() order depends on the
# filesystem, so sort to make everything derived from this list reproducible.
predictions = sorted(str(p) for p in Path(data_artifact_dir).glob("*") if p.is_dir())

In [None]:
# Ground-truth label map, read unscaled and then shrunk with scale_image_down
# unless scale_up is set (presumably to match the prediction resolution).
labels = read_image(label_file, scale=False)
labels = labels if scale_up else scale_image_down(labels)

print(labels.shape)
plt.imshow(labels)

In [None]:
# Class colour table from the project helper. Later cells use it as
# color_codes[class_id]["rgb"] (class id -> RGB triple); schema assumed from
# that usage, TODO confirm against get_color_codes.
color_codes = get_color_codes()

In [None]:
# Training-sample image at its native resolution.
training_samples = read_image(train_samples)
print(training_samples.shape)
plt.imshow(training_samples)
plt.show()

# With scale_up the labels stay at full resolution, so upsample this image as
# well and show it again.
if scale_up:
    training_samples = scale_image_up(training_samples)
    print(training_samples.shape)
    plt.imshow(training_samples)
    plt.show()

In [None]:
from matplotlib.colors import ListedColormap

# One "#rrggbb" string per colour_codes entry, in its iteration order.
# With vmin=1 the i-th colour is assumed to belong to class id i + 1 — TODO confirm
# that color_codes is ordered by class id.
color_list = ["#%02x%02x%02x" % tuple(entry["rgb"]) for entry in color_codes.values()]
custom_cmap = ListedColormap(color_list)
plt.imshow(
    labels,
    cmap=custom_cmap,
    vmin=1,
    vmax=len(custom_cmap.colors),
    interpolation='none',
)

In [None]:
# Cut the region handled by the project helper extract_validation_patch out of
# the training-sample image (helper semantics not visible here) and show its shape.
which_training_samples = extract_validation_patch(training_samples)
which_training_samples.shape

In [None]:
# Pixel coordinates as a (2, N) array of [row indices; column indices]:
#   red  -> red channel == 255  (pixels used for training)
#   blue -> blue channel == 100 (pixels available for validation)
# The 255 / 100 marker values are assumed to be the encoding of the
# training-sample image — TODO confirm.
red = np.asarray(np.nonzero(which_training_samples[..., 0] == 255))
blue = np.asarray(np.nonzero(which_training_samples[..., 2] == 100))

In [None]:
# Binary map of the training (red) pixels, drawn red on white.
used = np.zeros(which_training_samples.shape[:2])
used[tuple(red)] = 1
plt.imshow(used, cmap=ListedColormap(["white", "red"]), vmin=0, vmax=1)

In [None]:
# Binary map of the validation (blue) pixels, drawn blue on white.
free = np.zeros(which_training_samples.shape[:2])
free[tuple(blue)] = 1
plt.imshow(free, cmap=ListedColormap(["white", "blue"]), vmin=0, vmax=1)

In [None]:
# Count check: the number of red plus blue pixels must equal the number of
# pixels in the patch, i.e. every pixel is marked as training or validation
# (by count only; this alone does not prove the two sets are disjoint).
# Indexing `labels` with the coordinates also fails loudly if the label map is
# smaller than the patch. NOTE(review): reshape(-1, 3) assumes 3 channels.
assert which_training_samples.reshape(-1,3).shape[0] == labels[red[0], red[1]].shape[0] + labels[blue[0], blue[1]].shape[0]

In [None]:
# Validation targets: ground-truth labels at the blue pixels, dropping pixels
# whose label is 0. The mask `pixels_with_classes` is kept so the prediction
# decoding can select exactly the same pixels.
pixels_with_classes = labels[tuple(blue)] != 0
validation_labels = labels[tuple(blue)][pixels_with_classes]

assert not np.any(validation_labels == 0)
print(validation_labels)
print(validation_labels.shape)

# Persist for the artifact upload at the end of the notebook.
with open("validation_labels", 'wb') as file:
    pickle.dump(validation_labels, file)

In [None]:
# Training pixels: keep only red coordinates whose ground-truth label is non-zero
# (re-running this cell is harmless, the filter is idempotent).
red = red[:, labels[tuple(red)] != 0]
labels[tuple(red)].shape

In [None]:
# Reverse lookup used to decode prediction images:
# repr() of a class's RGB values as a list -> class id.
class_list = dict((repr(list(entry["rgb"])), class_id) for class_id, entry in color_codes.items())
class_list

In [None]:
# Decode every prediction PNG into class ids on the validation pixels: the same
# blue pixels filtered by `pixels_with_classes` that `validation_labels` was
# built from, so predictions and labels are aligned element-wise.
UNMATCHED = -1  # sentinel for a pixel whose colour matches no class in class_list

# string_to_numpy(k) does not depend on the image, so convert each class colour once.
class_colors = [(string_to_numpy(color), class_id) for color, class_id in class_list.items()]

validation_predictions = {}
for setting in sorted(predictions):  # sorted -> deterministic key order
    for tree in sorted([str(p) for p in Path(setting).glob("*png")]):
        tree_path = Path(tree)
        # "<setting dir>/<file name>", built from path parts so it is OS-independent.
        key = f"{tree_path.parent.name}/{tree_path.name}"

        # map color patches to class patches
        color_patch = extract_validation_patch(read_image(tree))[blue[0], blue[1]][pixels_with_classes]
        # int16 working buffer so the -1 sentinel is representable; a uint8
        # buffer starting at 0 would make the unmatched-colour check vacuous.
        class_patch = np.full(color_patch.shape[0], UNMATCHED, dtype=np.int16)
        for color, class_id in class_colors:
            class_patch[np.all(color_patch == color, axis=1)] = class_id
        n_unmatched = int(np.count_nonzero(class_patch == UNMATCHED))
        assert n_unmatched == 0, f"{key}: {n_unmatched} pixels have a colour not present in class_list"
        validation_predictions[key] = class_patch.astype(np.uint8)

with open("validation_predictions", 'wb') as file:
    pickle.dump(validation_predictions, file)

# Upload the produced label and prediction maps again

In [None]:
# Publish the pickled validation labels written above as a "dataset"
# artifact and wait until W&B has finished logging it.
label_artifact = wandb.Artifact("validation_labels", type="dataset")
label_artifact.add_file("validation_labels")
run.log_artifact(label_artifact)
label_artifact.wait()

# Same for the pickled dict of decoded validation predictions.
data_artifact = wandb.Artifact("validation_predictions", type="dataset")
data_artifact.add_file("validation_predictions")
run.log_artifact(data_artifact)
data_artifact.wait()

In [None]:
# Close the W&B run opened by wandb.init at the top of the notebook.
run.finish()