Skip to content

Commit

Permalink
Merge pull request #71 from iraikov/fix/resample_duplicates
Browse files Browse the repository at this point in the history
Ensure that solutions returned in the resample_x field are not duplicates
  • Loading branch information
iraikov committed Dec 28, 2023
2 parents 787d557 + cdda685 commit 2727ecd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
14 changes: 8 additions & 6 deletions dmosopt/MOASMO.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ def epoch(
if Xinit is None:
Xinit, Yinit, C = yield

x = Xinit.copy().astype(np.float32)
y = Yinit.copy().astype(np.float32)
x_0 = Xinit.copy().astype(np.float32)
y_0 = Yinit.copy().astype(np.float32)

fsbm = None
if C is not None:
Expand All @@ -244,8 +244,8 @@ def epoch(
if feasibility_model:
logger.info(f"Constructing feasibility model...")
fsbm = LogisticFeasibilityModel(x, C)
x = x[feasible, :]
y = y[feasible, :]
x_0 = x_0[feasible, :]
y_0 = y_0[feasible, :]
except:
e = sys.exc_info()[0]
logger.warning(f"Unable to fit feasibility model: {e}")
Expand Down Expand Up @@ -309,7 +309,7 @@ def epoch(
nOutput,
xlb,
xub,
initial=(x, y),
initial=(x_0, y_0),
feasibility_model=fsbm,
logger=logger,
popsize=pop,
Expand Down Expand Up @@ -357,6 +357,9 @@ def epoch(
x_gen = res

if surrogate_method_name is not None:
is_duplicate = MOEA.get_duplicates(best_x, x_0)
best_x = best_x[~is_duplicate]
best_y = best_y[~is_duplicate]
D = MOEA.crowding_distance(best_y)
idxr = D.argsort()[::-1][:N_resample]
x_resample = best_x[idxr, :]
Expand Down Expand Up @@ -418,7 +421,6 @@ def train(
if logger is not None:
logger.info(f"Found {len(feasible)} feasible solutions")


x, y = MOEA.remove_duplicates(x, y)

# resolve shorthands
Expand Down
8 changes: 5 additions & 3 deletions dmosopt/MOEA.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,11 @@ def remove_worst(
return population_parm[0:pop, :], population_obj[0:pop, :], rank[0:pop]


def get_duplicates(X, eps=1e-16):
D = cdist(X, X)
D[np.triu_indices(len(X))] = np.inf
def get_duplicates(X, Y=None, eps=1e-16):
if Y is None:
Y = X
D = cdist(X, Y)
D[np.triu_indices(len(X), m=len(Y))] = np.inf
D[np.isnan(D)] = np.inf

is_duplicate = np.zeros((len(X),), dtype=bool)
Expand Down

0 comments on commit 2727ecd

Please sign in to comment.