In [None]:
from image_classification_simulation.models.autoencoder_baseline import ConvAutoEncoder
from image_classification_simulation.data.office31_loader import Office31Loader
from image_classification_simulation.data.mnist_loader import MNISTLoader
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the Office31 "amazon" domain images into a data module.
# NOTE(review): `hparams` is reused further down to rebuild ConvAutoEncoder;
# presumably it has to agree with the training config passed to `main`
# (../examples/conv_ae/config.yaml) — confirm before trusting the reload.
OFFICE31_AMAZON_DIR = "../examples/data/domain_adaptation_images/amazon/images"

hparams = dict(
    loss="MSELoss",
    optimizer="adam",
    num_channels=3,
    num_filters=16,
    batch_size=16,
)

test_loader = Office31Loader(OFFICE31_AMAZON_DIR, hyper_params=hparams)
# NOTE(review): only the "fit" stage is set up here; confirm that this also
# builds every split consumed downstream (e.g. `test_set`).
test_loader.setup(stage="fit")

In [None]:
# Train the conv autoencoder through the project's `main` CLI on the Office31
# "amazon" images, using hyper-parameters from ../examples/conv_ae/config.yaml.
# Results go under ./output; the checkpoint later read from
# ./output/last_model/model.ckpt presumably comes from this run — confirm.
# NOTE(review): assumes the package's `main` console script is on PATH for the
# subshell. `--start-from-scratch` presumably discards any earlier run, so
# re-running this cell (e.g. Restart & Run All) retrains from zero — confirm.
! main --data ../examples/data/domain_adaptation_images/amazon/images --output ./output --config ../examples/conv_ae/config.yaml --start-from-scratch

In [None]:
# Visualize one input image side by side with the autoencoder's reconstruction.

CHECKPOINT_PATH = "./output/last_model/model.ckpt"

# First sample of the test split's underlying dataset.
# NOTE(review): if `test_set` is a `torch.utils.data.Subset`, `.dataset` is the
# *full* dataset, so this sample may actually come from the training split —
# confirm against Office31Loader.
img, label = next(iter(test_loader.test_set.dataset))

# Rebuild the model with the notebook's hparams and restore the trained weights.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine; the input image is a CPU tensor either way.
model = ConvAutoEncoder(hparams)
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Inference only: no_grad skips building the autograd graph (so no .detach()).
# The image is a single channels-first (C, H, W) tensor; add a batch axis for
# the model (some layers reject unbatched input) and drop it afterwards.
with torch.no_grad():
    reconstruction = model(img.unsqueeze(0))[0]
print(f"Reconstruction shape: {tuple(reconstruction.shape)}")

# matplotlib expects channels-last (H, W, C) arrays.
input_hwc = np.transpose(img.numpy(), (1, 2, 0))
reconstruction_hwc = np.transpose(reconstruction.numpy(), (1, 2, 0))

fig, (ax_in, ax_out) = plt.subplots(1, 2)
ax_in.imshow(input_hwc)
ax_in.set_title("Original input image")
ax_out.imshow(reconstruction_hwc)
ax_out.set_title("Reconstructed image")
for ax in (ax_in, ax_out):
    ax.axis("off")
plt.show()
