Skip to content

Commit

Permalink
fixed example
Browse files Browse the repository at this point in the history
  • Loading branch information
ldeecke committed Jul 30, 2018
1 parent 351874d commit 4893f3d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
Binary file modified example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 25 additions & 12 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@


def main():
n, d = 400, 2
n, d = 300, 2

# generate some data points ..
data = torch.Tensor(n, d).normal_()
# .. as well as a random partition that ..
ids = np.random.choice(n, n//2, replace=False)
# .. is permuted to come from a non-standard Gaussian N(7, 16)
# .. and shift them around to non-standard Gaussians
data[:n//2] -= 1
data[:n//2] *= sqrt(3)
data[n//2:] += 1
Expand All @@ -29,25 +27,40 @@ def main():
n_components = 2
model = GaussianMixture(n_components, d)
model.fit(data)
# .. used to predict the data points that where shifted
# .. used to predict the data points as they where shifted
y = model.predict(data)
c = np.isin(np.where(y > 0), ids)

plot(data, y)


def plot(data, y):
n = y.shape[0]

fig, ax = plt.subplots(1, 1, figsize=(1.61803398875*4, 4))
ax.set_facecolor('#bbbbbb')
ax.set_xlabel("Dimension 1")
ax.set_ylabel("Dimension 2")

# plot the locations of all data points ..
ax.scatter(*data[:n//2].data.t(), color="#000000", s=3, alpha=.75, label="Ground-truth 1")
ax.scatter(*data[n//2:].data.t(), color="#ffffff", s=3, alpha=.75, label="Ground-truth 2")
for i, point in enumerate(data.data):
if i <= n//2:
# .. separating by ground truth ..
ax.scatter(*point, color="#000000", s=3, alpha=.75, zorder=n+i)
else:
ax.scatter(*point, color="#ffffff", s=3, alpha=.75, zorder=n+i)

if y[i] == 0:
# .. as well as predicted classes
ax.scatter(*point, zorder=i, color="#dbe9ff", alpha=.6, edgecolors=colors[1])
else:
ax.scatter(*point, zorder=i, color="#ffdbdb", alpha=.6, edgecolors=colors[5])

# .. and circle them according to their classification
ax.scatter(*data[np.where(y == 0)].data.t(), zorder=0, color="#dbe9ff", alpha=.6, edgecolors=colors[1], label="Predicted 1")
ax.scatter(*data[np.where(y == 1)].data.t(), zorder=0, color="#ffdbdb", alpha=.6, edgecolors=colors[5], label="Predicted 2")
handels = [plt.Line2D([0], [0], color='w', lw=4, label='Ground Truth 1'), plt.Line2D([0], [0], color='black', lw=4, label='Ground Truth 2'), plt.Line2D([0], [0], color=colors[1], lw=4, label='Predicted 1'), plt.Line2D([0], [0], color=colors[5], lw=4, label='Predicted 2')]
legend = ax.legend(loc="best", handles=handels)

ax.legend(loc="best")
plt.tight_layout()
plt.savefig("example.pdf")


if __name__ == "__main__":
main()

0 comments on commit 4893f3d

Please sign in to comment.