In [1]:
from tqdm import trange
from collections import defaultdict
import torch
import numpy as np
import torch.nn.functional as F
from pymoo.factory import get_performance_indicator
from load_data import Dataset
@torch.no_grad()
def evaluate(hypernet, targetnet, loader, rays, device,epoch,name,n_tasks):
    hypernet.eval()
    results = defaultdict(list)
    loss_total = None
    for ray in rays:
        ray = torch.from_numpy(ray.astype(np.float32)).to(device)

        ray /= ray.sum()

        total = 0.0
        full_losses = []
        for batch in loader:
            hypernet.zero_grad()

            batch = (t.to(device) for t in batch)
            xs, ys = batch
            bs = len(ys)

            weights = hypernet(ray)
            pred = targetnet(xs, weights)

            # loss
            curr_losses = get_losses(pred, ys)
            # metrics
            ray = ray.squeeze(0)

            # losses
            full_losses.append(curr_losses.detach().cpu().numpy())
            total += bs
        if loss_total is None:
            loss_total = np.array(np.array(full_losses).mean(0).tolist(),dtype='float32')
        else:
            loss_total += np.array(np.array(full_losses).mean(0).tolist(),dtype='float32')
        results["ray"].append(ray.cpu().numpy().tolist())
        results["loss"].append(np.array(full_losses).mean(0).tolist())
    print("\n")
    print(str(name)+" losses at "+str(epoch)+":",loss_total/len(rays))
    hv = get_performance_indicator(
        "hv",
        ref_point=np.ones(
            n_tasks,
        ),
    )
    hv_result = hv.do(np.array(results["loss"]))
    results["hv"] = hv_result

    return results

def get_losses(pred, label):
    return F.mse_loss(pred, label, reduction="none").mean(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_set, val_set, test_set = Dataset("/home/ubuntu/long.hp/Jura/data/jura.arff").get_data()

bs = 32

train_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=bs, shuffle=True, num_workers=0
    )
val_loader = torch.utils.data.DataLoader(
        dataset=val_set, batch_size=bs, shuffle=True, num_workers=0
    )
test_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=bs, shuffle=False, num_workers=0
    )

In [3]:
from torch import nn
from models import HyperNet, TargetNet
from utils import get_device

In [4]:
n_mo_obj = 4
ref_point = [1]*n_mo_obj
n_tasks = n_mo_obj

device=get_device(gpus='3')


hnet = torch.load("/home/ubuntu/long.hp/Jura/save_models/Jura_MH_Freely_8_0.01_best.pt")
net: nn.Module = TargetNet()
hnet = hnet.to(device)
net = net.to(device)

In [5]:
from pymoo.factory import get_reference_directions
test_rays = get_reference_directions("das-dennis", 4, n_partitions=10).astype(
    np.float32
)

In [6]:
evaluate(hnet, net, test_loader, test_rays, device,0,'abc',n_tasks)





abc losses at 0: [0.05404604 0.08816412 0.07411575 0.08936162]

Compiled modules for significant speedup can not be used!
https://pymoo.org/installation.html#installation

from pymoo.config import Config
Config.show_compile_hint = False



defaultdict(list,
            {'ray': [[0.0, 0.0, 0.0, 1.0],
              [0.0, 0.0, 0.10000000149011612, 0.8999999761581421],
              [0.0, 0.0, 0.20000000298023224, 0.800000011920929],
              [0.0, 0.0, 0.30000001192092896, 0.699999988079071],
              [0.0, 0.0, 0.4000000059604645, 0.6000000238418579],
              [0.0, 0.0, 0.5, 0.5],
              [0.0, 0.0, 0.6000000238418579, 0.4000000059604645],
              [0.0, 0.0, 0.699999988079071, 0.30000001192092896],
              [0.0, 0.0, 0.800000011920929, 0.20000000298023224],
              [0.0, 0.0, 0.8999999761581421, 0.10000000149011612],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.10000000149011612, 0.0, 0.8999999761581421],
              [0.0,
               0.10000000149011612,
               0.10000000149011612,
               0.800000011920929],
              [0.0,
               0.10000000149011612,
               0.20000000298023224,
               0.699999988079071],
              [0