In [None]:
import wandb
# Authenticate with Weights & Biases before a run is created.
wandb.login()

# Project name; later cells also use it to build the artifact paths.
project = "htcv"
# Start the run that every later cell logs to and fetches artifacts through.
run = wandb.init(project=project, entity="nicojahn")
# Run configuration; later cells read their parameters from it
# (e.g. "data_set_size", "models", "nb_pru", "timeout").
config = wandb.config

# Download the data again and use it as an artifact

In [None]:
import os
import time

# Retry budget. Every iteration sleeps one second, so the loop gives up after
# roughly five minutes of waiting (plus whatever the W&B calls themselves take).
attempt = 60*5

label_artifact = None
# Description of the most recent failed attempt, reported if the budget runs out.
last_failure = None
while attempt:
    try:
        time.sleep(1)
        # Resolve the artifact only once; later iterations reuse the handle.
        if label_artifact is None:
            label_artifact = run.use_artifact("nicojahn/" + project + "/validation_labels:latest", type="dataset")
        # NOTE(review): `_default_root()` and `_list()` are private wandb APIs and may
        # change between wandb releases.
        label_artifact_dir = label_artifact._default_root()

        if not os.path.exists(label_artifact_dir):
            # Nothing on disk yet: download the artifact ourselves.
            label_artifact_dir = label_artifact.download()
            break

        # The directory already exists, so it is only accepted once every file the
        # artifact lists is present. Presumably another process may still be
        # downloading into it -- TODO confirm.
        missing = [
            name
            for name in label_artifact._list()
            if not os.path.exists(os.path.join(label_artifact_dir, name))
        ]
        if not missing:
            break
        last_failure = f"{len(missing)} file(s) still missing in {label_artifact_dir}"
        attempt -= 1
    except (ValueError, wandb.CommError) as error:
        # Treated as transient: remember the cause and try again.
        last_failure = repr(error)
        attempt -= 1
assert attempt > 0, f"validation_labels artifact not available; last failure: {last_failure}"

In [None]:
import os
import time

# Retry budget. Every iteration sleeps one second, so the loop gives up after
# roughly five minutes of waiting (plus whatever the W&B calls themselves take).
attempt = 60*5

data_artifact = None
# Description of the most recent failed attempt, reported if the budget runs out.
last_failure = None
while attempt:
    try:
        time.sleep(1)
        # Resolve the artifact only once; later iterations reuse the handle.
        if data_artifact is None:
            data_artifact = run.use_artifact("nicojahn/" + project + "/validation_predictions:latest", type="dataset")
        # NOTE(review): `_default_root()` and `_list()` are private wandb APIs and may
        # change between wandb releases.
        data_artifact_dir = data_artifact._default_root()

        if not os.path.exists(data_artifact_dir):
            # Nothing on disk yet: download the artifact ourselves.
            data_artifact_dir = data_artifact.download()
            break

        # The directory already exists, so it is only accepted once every file the
        # artifact lists is present. Presumably another process may still be
        # downloading into it -- TODO confirm.
        missing = [
            name
            for name in data_artifact._list()
            if not os.path.exists(os.path.join(data_artifact_dir, name))
        ]
        if not missing:
            break
        last_failure = f"{len(missing)} file(s) still missing in {data_artifact_dir}"
        attempt -= 1
    except (ValueError, wandb.CommError) as error:
        # Treated as transient: remember the cause and try again.
        last_failure = repr(error)
        attempt -= 1
assert attempt > 0, f"validation_predictions artifact not available; last failure: {last_failure}"

# Read the data and optimize on it

In [None]:
from pathlib import Path
import pickle
import re

import numpy as np

# The label artifact directory must hold exactly one file: the pickled labels.
# NOTE(review): pickle.load can execute arbitrary code; only use artifacts you trust.
label_files = [str(entry) for entry in Path(label_artifact_dir).glob("*")]
assert len(label_files) == 1
with open(label_files[0], "rb") as file:
    validation_labels = pickle.load(file)

# Draw a random subset (without replacement) of the configured size; the same
# `indices` are applied to the labels and to every model's predictions.
data_set_size = config["data_set_size"]
assert validation_labels.shape[0] >= data_set_size
indices = np.random.choice(validation_labels.shape[0], data_set_size, replace=False)
validation_labels = validation_labels[indices]

# Likewise, the prediction artifact directory must hold exactly one pickle file.
prediction_files = [str(entry) for entry in Path(data_artifact_dir).glob("*")]
assert len(prediction_files) == 1
validation_predictions = prediction_files[0]

# Regular expressions selecting which models take part; an empty list keeps all.
models_regex = list(config["models"])
with open(validation_predictions, "rb") as file:
    all_predictions = pickle.load(file)

# Keep a model when no filter is configured or when any pattern matches its name.
data = {
    name: values[indices]
    for name, values in all_predictions.items()
    if not models_regex or any(re.search(pattern, name) for pattern in models_regex)
}
validation_prediction = np.asarray(list(data.values()))

print(f"Pruning from {validation_prediction.shape[0]} models")

In [None]:
import sys
# Put the relative "EPFD" directory first on the module search path,
# presumably so that the `main` module imported below is found there.
sys.path.insert(0, "EPFD")
# NOTE(review): star import. The `main(...)` function called in a later cell is
# not defined in this notebook, so it presumably comes from here -- confirm.
from main import *

In [None]:
from scipy.stats import mode

def accuracy(l,p):
    """Return the fraction of entries in which labels `l` and predictions `p` agree.

    Parameters:
        l: array-like of ground-truth labels.
        p: array-like of predicted labels, comparable element-wise with `l`.

    Returns:
        A scalar in [0, 1] (NaN for empty input).
    """
    # Averaging the element-wise matches yields one scalar for any input shape.
    # Dividing the match count by the tuple `l.shape` instead produced a
    # one-element array for 1-D labels and a per-dimension vector for N-D labels.
    return np.mean(np.equal(l, p))

def get_prediction(v):
    """Reduce every row of `v` to its most frequent value.

    Parameters:
        v: 2-D array; the vote is taken along axis 1 (across the columns of a row).

    Returns:
        A flat array holding one winning value per row of `v`.
    """
    row_modes = mode(v, axis=1)[0]
    # Flatten so the result is 1-D whether or not scipy keeps the reduced axis.
    return row_modes.reshape(-1)

def evaluate(solution_set, data, labels):
    """Score the combined vote of the models named in `solution_set`.

    Parameters:
        solution_set: iterable of keys into `data`, one per selected model.
        data: mapping from model key to that model's predictions
            (assumed to be one prediction per sample -- TODO confirm).
        labels: ground-truth labels the voted prediction is compared against.

    Returns:
        A tuple `(accuracy, prediction)`, where `prediction` is the voted
        prediction and `accuracy` is its score against `labels`.
    """
    member_outputs = [data[model] for model in solution_set]
    # Transpose so each row collects the members' outputs for a single sample.
    votes = np.asarray(member_outputs).T
    prediction = get_prediction(votes)
    return accuracy(labels, prediction), prediction

In [None]:
from argparse import Namespace
import time as t

# Pruning parameters, all taken from the run configuration.
lam = config["lambda"]
nb_pru = config["nb_pru"]
name_pru = config["name_pru"]
distributed = config["distributed"]
m = config["m"]
# Maximum number of seconds to wait for the pruning process to deliver a result.
timeout = config["timeout"]

# You can't prune an ensemble to n models from less than n models
assert validation_prediction.shape[0] > nb_pru

setting = {"nb_pru" : nb_pru, "name_pru": name_pru, "distributed": distributed, "lam": lam, "m":m}
args = Namespace(**setting)

# Process timing/stopping copied from: https://stackoverflow.com/a/14924210
import multiprocessing
import queue

Q = multiprocessing.Queue()
def main_wrapper(*args, **kwargs):
    """Call `main` with the given arguments and put its return value on `Q`."""
    Q.put(main(*args, **kwargs))

p = multiprocessing.Process(target=main_wrapper, args=(args, validation_labels, validation_prediction))
start = t.time()
p.start()

# Stays empty when no result arrives in time; later cells then have nothing to collect.
solution_set = []
try:
    # Wait for n seconds or until process finishes.
    # The result is read before joining, so the child is never joined while
    # its queued item is still unread.
    indices = Q.get(timeout=timeout)

    print(setting)
    p.join()
    # Translate the returned indices into the corresponding keys of `data`.
    solution_set = np.asarray(list(data.keys()))[indices]
    print(solution_set)

except queue.Empty:
    still_alive = p.is_alive()
    if still_alive:
        print("Process is still running. Killing it.")

        # kill() rather than terminate(): terminate may not work if the process
        # is stuck for good; kill gives it no chance to finish nicely.
        p.kill()
    else:
        # The child ended without ever putting a result on the queue.
        print(f"Process exited without a result (exit code {p.exitcode}).")
    # Reap the child so it does not linger as a zombie process.
    p.join()

wandb.log({"time": t.time()-start})

# Collect the images (predictions) of those models and store them in another artifact

In [None]:
# One-column table listing the models that survived pruning.
table = wandb.Table(columns=["model_name"])

# The predictions artifact is the same for every model, so it is resolved once, on
# first use; with an empty `solution_set` it is never requested at all. It gets its
# own name instead of overwriting `label_artifact`, which holds the validation labels.
prediction_artifact = None
for file in solution_set:
    if prediction_artifact is None:
        prediction_artifact = run.use_artifact("nicojahn/" + project + "/predictions:latest", type="raw_dataset")
    # Download only this model's entry into the per-run scratch directory.
    path = prediction_artifact.get_path(file)
    downloaded_prediction = path.download("/tmp/%s/" % run.name)

    # NOTE(review): this stores a one-element list as the cell value; confirm
    # whether the bare name was intended.
    table.add_data([file])

In [None]:
from data_utils import *
from utils import *
color_codes = get_color_codes()
# Maps the repr of each entry's RGB list to that entry's key in `color_codes`.
class_list = {repr(list(v["rgb"])):k for k,v in color_codes.items()}

# Every PNG downloaded into this run's scratch directory.
images = [str(p) for p in Path("/tmp/%s/" % run.name).rglob("*png")]

if len(images):
    # calculate the accuracy on the pruning data
    acc, prediction = evaluate(solution_set, data, validation_labels)
    print(acc)

    wandb.log({"accuracy": acc, "model_names": table})

    predictions = []
    for member in images:
        # Flatten the image to one RGB triple per row, then give each pixel the
        # class whose colour matches on all three channels (default is 0).
        color_patch = read_image(member).reshape(-1, 3)
        class_patch = np.zeros(color_patch.shape[0], dtype=np.uint8)
        for k,v in class_list.items():
            class_patch[np.all(color_patch==string_to_numpy(k), axis=1)] = v
        predictions += [class_patch]

    # Vote per pixel across all member images.
    predictions = np.asarray(predictions)
    predictions = get_prediction(predictions.T)
    # NOTE(review): hard-coded shape; presumably the height x width of the
    # prediction images -- confirm.
    predictions = predictions.reshape(1202, 4172)
    predictions = scale_image_up(predictions)

    from PIL import Image
    predictions = Image.fromarray(predictions)
    # Build the output path once and reuse it for saving and uploading.
    final_path = "/tmp/%s/final_prediction.tif" % run.name
    predictions.save(final_path)

    final_artifact = wandb.Artifact("final_prediction", type="predictions")
    final_artifact.add_file(final_path)
    # `aliases` is documented as a list of strings, so the single alias is wrapped.
    run.log_artifact(final_artifact, aliases=[run.name])
    # Block until the artifact has finished logging.
    final_artifact.wait()

In [None]:
# Close the W&B run now that all metrics and artifacts have been logged.
run.finish()