## Define model architecture


## Load model

In [None]:
# Ask the project helper for the best checkpoint of `model` (verbose=True).
ckpt_path = get_best_ckpt(
    model,
    verbose=True,
)
# Actual loading is currently disabled; the two options are kept for reference:
# load_model(model, ckpt_path)
# load_best_model(model)

### Reconstructions of inputs from the training and validation data

In [None]:
# For one batch from each of the train and val dataloaders: reconstruct it with
# the model, display inputs and reconstructions, print tensor summaries, and
# log an "inputs | recons" image grid to TensorBoard.
train_mean, train_std = dm.train_mean, dm.train_std
model.eval()
with torch.no_grad():
    for mode in ('train', 'val'):
        # Resolves to model.train_dataloader() / model.val_dataloader().
        dl = getattr(model, f"{mode}_dataloader")()
        x, y = next(iter(dl))
        x = x.to(model.device)
        x_recon = model.generate(x)

        # Bring both tensors back to the CPU and undo the normalization so
        # they can be visualized.
        x, x_recon = x.cpu(), x_recon.cpu()
        x_unnormed = unnormalize(x, train_mean, train_std)
        x_recon_unnormed = unnormalize(x_recon, train_mean, train_std)
        show_timgs(x_unnormed, title=f"{mode} dataset", cmap='gray')
        show_timgs(x_recon_unnormed, title=f"{mode}: recon", cmap='gray')
        # Same reconstructions again, passed through LinearRescaler first.
        show_timgs(
            LinearRescaler()(x_recon_unnormed),
            title=f"{mode}: recon",
            cmap='gray',
        )

        # Summaries of the normalized tensors, then of the unnormalized ones.
        info(x, f"{mode}_x")
        info(x_recon, f"{mode}_x_recon")
        print("===")
        info(x_unnormed, f"{mode}_x_unnormed")
        info(x_recon_unnormed, f"{mode}_x_recon_unnormed")

        # Build one image grid per tensor, each of shape (C, gridh, gridw).
        input_grid = torchvision.utils.make_grid(x_unnormed)
        recon_grid = torchvision.utils.make_grid(x_recon_unnormed)
        # Unused variant kept for reference:
        # normed_recon_grid = torchvision.utils.make_grid(LinearRescaler()(x_recon_unnormed))

        # Concatenate along the last (width) axis: inputs | recons.
        grid = torch.cat((input_grid, recon_grid), dim=-1)
        tb_logger.experiment.add_image(f"{mode}/recons", grid, global_step=0)
                            


## Visualize embeddings
- Collect a batch of inputs -> encoder -> [mu, log_var] -> sample -> a batch of z's (embeddings)
- Log the embeddings with the TensorBoard logger


In [None]:
# Encode one training batch, sample latent codes z from (mu, log_var), and log
# them to TensorBoard's embedding projector with the input images as thumbnails
# and the labels as metadata.
model.eval()
with torch.no_grad():
    x, y = next(iter(trainer.train_dataloader))
    # Feed the encoder a copy of the batch on the model's device (the
    # reconstruction cell does the same); `x` itself is left where the
    # dataloader put it so it can be used directly for the thumbnails.
    mu, log_var = model.encode(x.to(model.device))
    z = model.reparameterize(mu, log_var)
    # Alternative kept for reference:
    #     out = model.get_embeddings(x) # dict of mu, log_var, z
    #     z = out['z']

    # Log embedding to tensorboard; z is moved to the CPU first so logging
    # does not depend on which device the model lives on.
    writer = model.logger.experiment
    writer.add_embedding(z.cpu(),
                         label_img=LinearRescaler()(x),
                         metadata=y.tolist(),
                         global_step=trainer.global_step,  # TODO
                         )
    
    



## Visualize original images of the close neighbors in the latent space
- Compute pairwise distances between the latent codes using the cosine metric
- For each row (i.e. a latent code), get the indices of the smallest values.
- Select the corresponding images in the batch x and visualize them (can do this all in show_timgs)



In [None]:
from sklearn.metrics import pairwise_distances

In [None]:
# Encode one training batch, compute the pairwise distance matrix of the
# sampled latent codes, display it, and for a random selection of images show
# the batch images whose latent codes are closest.
model.eval()
with torch.no_grad():
    x, y = next(iter(trainer.train_dataloader))
    # Encode a copy of the batch on the model's device (as the reconstruction
    # cell does); `x` is kept as-is because it is indexed for plotting below.
    mu, log_var = model.encode(x.to(model.device))
    z = model.reparameterize(mu, log_var)
    # Alternative kept for reference:
    #     out = model.get_embeddings(x) # dict of mu, log_var, z
    #     z = out['z']

    # The assignment below had been merged into the comment above, which left
    # `metric` undefined when pairwise_distances was called.
    metric = 'cosine'
    # .cpu() is needed because Tensor.numpy() only works on CPU tensors.
    pdists = pairwise_distances(z.cpu().numpy(), metric=metric)
    plt.imshow(pdists, cmap='gray')
    plt.title("Pairwise dists of z's")
    plt.axis('off')
    plt.show()

    # smaller values means closer in distance
    n_ngbrs = 5
    n_rows = 100

    # Rows are drawn with replacement (np.random.choice default), so the same
    # image may be shown more than once.
    selected_rows = np.random.choice(len(x), size=n_rows)
    for idx in selected_rows:
        # Indices of the n_ngbrs smallest distances in this row.
        args = np.argsort(pdists[idx])[:n_ngbrs]
        show_timgs(LinearRescaler()(x[args]), cmap='gray', factor=2,
                   nrows=1, title=f'Nearest of img {idx}')

In [None]:
# Show, for a handful of randomly chosen batch images, the images whose latent
# codes have the smallest pairwise distance to them (smaller = closer).
# Relies on `x` and `pdists` produced by the pairwise-distance cell.
n_ngbrs = 5
n_rows = 10
selected_rows = np.random.choice(len(x), size=n_rows)
for idx in selected_rows:
    # Sort this row's distances ascending and keep the first n_ngbrs indices.
    args = np.argsort(pdists[idx])[:n_ngbrs]
    print(args)
    show_timgs(
        LinearRescaler()(x[args]),
        cmap='gray',
        factor=2,
        nrows=1,
        title=f'Nearest of img {idx}',
    )

In [None]:
# Indices of the n_ngbrs smallest distances in row 1 of the distance matrix.
pdists[1].argsort()[:n_ngbrs]

In [None]:
# Nearest-neighbour indices for every latent code at once.
n_ngbrs = 5
# Sort each row's distances ascending, then keep the first n_ngbrs columns of
# every row. Slicing with [:n_ngbrs] alone would instead keep the first
# n_ngbrs rows in full.
args = np.argsort(pdists, axis=1)[:, :n_ngbrs]
print(args.shape)  # (number of rows in pdists, n_ngbrs)
# show_timgs(LinearRescaler()(x[args]), cmap='gray', factor=2, nrows=1)

## Latent Space Traversal
1. Linear traversal in a single dimension

In [None]:
# Set up a linear traversal along one latent dimension: one column of evenly
# spaced values for the chosen dimension, and one random vector for the other
# latent_dim - 1 dimensions repeated for every sample.
chosen_dim = 0  # must be in range(latent_dim)
# n_samples has to be assigned before it is used by `repeat` below.
n_samples = 16
zi_min, zi_max = -2, 2

# A single random row for the non-traversed dimensions, tiled n_samples times.
fixed_vec = torch.randn((1, model.latent_dim - 1))
fixed_values = fixed_vec.repeat((n_samples, 1))
# Column vector of n_samples values from zi_min to zi_max.
varying = torch.linspace(zi_min, zi_max, n_samples).view((-1, 1))

varying.shape, fixed_values.shape





In [None]:
def construct_from(a_col:torch.Tensor, other_cols:torch.Tensor, ind):
    """
    Make a tensor from a column vector and a matrx containing all the other columns
    by inserting the `onc_column` at the final matrix's `ind`th column.
    """
    assert a_
    n_cols = 1 + 
    out = a_col.new_zeros((