Skip to content

Commit

Permalink
plot reconstructions
Browse files Browse the repository at this point in the history
  • Loading branch information
dribnet committed Oct 1, 2023
1 parent f1f3f3d commit b1b4bdd
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion examples/autoencoder_fsq.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def iterate_dataset(data_loader):
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_dataset = DataLoader(
datasets.FashionMNIST(
datasets.MNIST(
root="~/data/fashion_mnist", train=True, download=True, transform=transform
),
batch_size=256,
Expand All @@ -92,3 +92,47 @@ def iterate_dataset(data_loader):
model = SimpleFSQAutoEncoder(levels).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)

# ---- 8< -----

batch = next(iter(train_dataset))
img, _ = batch
img = img.to(device)
rec_x2 = model(img)

# Extracting recorded information
temp = rec_x2[0].cpu().detach().numpy()

import matplotlib.pyplot as plt

# Initializing subplot counter
counter = 1

# Plotting first five images of the last batch
for idx in range(5):
plt.subplot(2, 5, counter)
plt.title(f"index {idx}")
plt.imshow(temp[idx].reshape(28,28), cmap= 'gray')
plt.axis('off')

# Incrementing the subplot counter
counter+=1

# Iterating over first five
# images of the last batch

# Obtaining image from the dictionary
val = img.cpu()

for idx in range(5):
# Plotting image
plt.subplot(2,5,counter)
plt.imshow(val[idx].reshape(28, 28), cmap = 'gray')
plt.title("Original Image")
plt.axis('off')

# Incrementing subplot counter
counter+=1

plt.tight_layout()
plt.savefig('figgy2.png')

0 comments on commit b1b4bdd

Please sign in to comment.