In [None]:
import sys
# Install dependencies with the interpreter that backs this kernel
# (`sys.executable`), so packages land in the kernel's own environment.
# Upgrade pip itself first, then the requirement files, then PyEnsemble
# as an editable install from the local EPFD checkout.
!{sys.executable} -m pip install -q -U pip
!{sys.executable} -m pip install -q -U -r requirements.txt
!{sys.executable} -m pip install -q -U -r ./EPFD/requirements.txt
!{sys.executable} -m pip install -q -e ./EPFD/PyEnsemble

In [None]:
# Reload imported modules automatically before each cell execution, so edits
# to local source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import wandb
# Authenticate with Weights & Biases before starting a run.
wandb.login()

# `run` is the handle used by the later cells to fetch input artifacts,
# log the resulting artifact and finally close the run.
run = wandb.init(project="htcv", entity="nicojahn")

# Download the data again and use it as an artifact

In [None]:
import os
import time

# Retry budget: one try per second, so roughly five minutes in total.
attempt = 60 * 5

label_artifact = None
while attempt:
    try:
        time.sleep(1)
        # Resolve the artifact only once; later iterations reuse the handle.
        if label_artifact is None:
            label_artifact = run.use_artifact("nicojahn/htcv/validation_labels:latest", type="dataset")
        label_artifact_dir = label_artifact._default_root()

        # Nothing on disk yet: download the artifact ourselves and stop.
        if not os.path.exists(label_artifact_dir):
            label_artifact_dir = label_artifact.download()
            break

        # The directory already exists (presumably another process is
        # downloading it - TODO confirm): stop only once every listed file
        # is present, otherwise spend one attempt and poll again.
        incomplete = any(
            not os.path.exists(label_artifact_dir + "/" + entry)
            for entry in label_artifact._list()
        )
        if not incomplete:
            break
        attempt -= 1
    except (ValueError, wandb.CommError):
        # Treat lookup/communication failures as a failed attempt and retry.
        attempt -= 1

# Fails when the retry budget ran out without a complete download.
assert attempt > 0

In [None]:
import os
import time

# Retry budget: one try per second, so roughly five minutes in total.
attempt = 60 * 5

data_artifact = None
while attempt:
    try:
        time.sleep(1)
        # Resolve the artifact only once; later iterations reuse the handle.
        if data_artifact is None:
            data_artifact = run.use_artifact("nicojahn/htcv/validation_predictions:latest", type="dataset")
        data_artifact_dir = data_artifact._default_root()

        # Nothing on disk yet: download the artifact ourselves and stop.
        if not os.path.exists(data_artifact_dir):
            data_artifact_dir = data_artifact.download()
            break

        # The directory already exists (presumably another process is
        # downloading it - TODO confirm): stop only once every listed file
        # is present, otherwise spend one attempt and poll again.
        incomplete = any(
            not os.path.exists(data_artifact_dir + "/" + entry)
            for entry in data_artifact._list()
        )
        if not incomplete:
            break
        attempt -= 1
    except (ValueError, wandb.CommError):
        # Treat lookup/communication failures as a failed attempt and retry.
        attempt -= 1

# Fails when the retry budget ran out without a complete download.
assert attempt > 0

# Read the data and optimize on it

In [None]:
from pathlib import Path
import pickle
import numpy as np

# Each artifact directory must contain exactly one file: the pickle to load.
# NOTE(review): pickle.load can execute arbitrary code; this is only safe as
# long as the artifacts come from a trusted source.
label_files = [str(entry) for entry in Path(label_artifact_dir).glob("*")]
assert len(label_files) == 1
with open(label_files[0], "rb") as file:
    validation_labels = pickle.load(file)


prediction_files = [str(entry) for entry in Path(data_artifact_dir).glob("*")]
assert len(prediction_files) == 1
validation_predictions = prediction_files[0]

# `data` is kept as loaded (it is indexed by key in later cells);
# `validation_prediction` holds just its values, in iteration order.
with open(validation_predictions, "rb") as file:
    data = pickle.load(file)
validation_prediction = [value for _, value in data.items()]

In [None]:
# Put the local "EPFD" directory first on the import path so that its
# modules can be imported by bare name.
sys.path.insert(0, "EPFD")
# NOTE(review): the star import hides which names this cell provides; a later
# cell calls `main(...)`, which presumably comes from here - consider
# importing the needed names explicitly.
from main import *

In [None]:
def accuracy(l, p):
    """Return the fraction of positions where labels and predictions agree.

    Args:
        l: Ground-truth labels (NumPy array or any array-like).
        p: Predictions, comparable element-wise with ``l``.

    Returns:
        A scalar float between 0 and 1 (``nan`` for empty input).
    """
    # Dividing the match count by ``l.shape`` (a tuple) would yield one value
    # per dimension instead of a single number; the mean of the boolean
    # match mask gives the intended scalar and also accepts plain lists.
    return np.mean(np.equal(l, p))

def evaluate(solution_set, data, labels):
    """Score the ensemble formed by the models named in ``solution_set``.

    The per-model entries of ``data`` are stacked and combined with an
    element-wise median along the model axis; the combined prediction is
    then compared against ``labels`` via ``accuracy``.

    Args:
        solution_set: Iterable of keys into ``data``.
        data: Mapping from model key to that model's predictions.
        labels: Ground-truth labels passed on to ``accuracy``.

    Returns:
        Whatever ``accuracy`` returns for the combined prediction.
    """
    stacked = np.asarray([data[model] for model in solution_set])
    combined = np.median(stacked, axis=0)
    return accuracy(labels, combined)

In [None]:
from argparse import Namespace
config = wandb.config

# Read the run's hyper-parameters; only "lambda" is stored under a different
# local name (`lam`).
lam, nb_pru, name_pru, distributed, m = (
    config[key] for key in ("lambda", "nb_pru", "name_pru", "distributed", "m")
)

# Hand the settings to `main` as an argparse-style namespace.
setting = dict(nb_pru=nb_pru, name_pru=name_pru, distributed=distributed, lam=lam, m=m)
args = Namespace(**setting)
indices = main(args, validation_labels, validation_prediction)

# `indices` index into the keys of `data`, in the same order in which
# `validation_prediction` was built from it.
model_names = np.asarray(list(data.keys()))
solution_set = model_names[indices]
print(setting)
print(solution_set)

# Score the selected subset and report the result to W&B.
acc = evaluate(solution_set, data, validation_labels)
print(acc)
wandb.log({"accuracy": acc})

# Collect the images (predictions) of those models and store them in another artifact

In [None]:
# Resolve the predictions artifact once: the lookup does not depend on the
# loop variable, so repeating it for every file only adds redundant calls.
# A dedicated name also avoids overwriting the unrelated `label_artifact`.
predictions_artifact = run.use_artifact("nicojahn/htcv/predictions:latest", type="raw_dataset")

# All selected files go into one run-specific directory.
download_root = "/tmp/%s/" % run.name
for file in solution_set:
    # Fetch only the artifact entries named in `solution_set`.
    predictions_artifact.get_path(file).download(download_root)

In [None]:
# Bundle the contents of this run's download directory into a new artifact.
final_artifact = wandb.Artifact("final_predictions", type="predictions")
final_artifact.add_dir("/tmp/%s/" % run.name)
# `aliases` is documented as a list of strings, so pass the run name as a
# one-element list instead of relying on how a bare string is interpreted.
run.log_artifact(final_artifact, aliases=[run.name])
# Block until the artifact has finished logging.
final_artifact.wait()

In [None]:
# Close the W&B run started at the top of the notebook.
run.finish()