In [19]:
import torch
from datasets.mnist import MNIST
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import torchvision
import os
from ipywidgets import interact, FloatSlider
import pandas as pd
import numpy as np


from nn.enforced_ae import EnforcedAE
from nn.ae import AE
from train import train
from schedule import cosine_schedule


# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [2]:
# Single place to change if the dataset directory moves.
DATA_ROOT = '../../datasets'

# All splits are placed on `device` and left unnormalised; augment=True is passed
# only for the training split.
train_dataset = MNIST(root=DATA_ROOT, split='train', download=False, device=device, normalize=False, augment=True)
val_dataset = MNIST(root=DATA_ROOT, split='val', download=False, device=device, normalize=False)
test_dataset = MNIST(root=DATA_ROOT, split='test', download=False, device=device, normalize=False)

                                                        

In [90]:
# Either restore a saved AE checkpoint (load=True) or train a new run and then
# reload the checkpoint that run wrote.
experiment_name = 'base'
trial_name = 'ae_cnn20_schedules'
load = True
load_run = 0  # run index restored when load=True

model = AE(in_channels=1, z_dim=20, cnn=True).to(device)
model_dir = f'out/models/{experiment_name}/{trial_name}'

if load:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(f'{model_dir}/run_{load_run}.pt', map_location=device))
else:
    optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3)
    num_epochs = 200
    batch_size = 256
    compute_dtype = torch.bfloat16

    # Per-epoch schedules handed to train() through epoch_hyperparams.
    hyperparams = {
        'lr': cosine_schedule(base=1e-3, end=1e-4, T=num_epochs, warmup=10, flat_end=10),
        'wd': cosine_schedule(base=0.004, end=0.1, T=num_epochs)
    }

    # Use the first unused run index so earlier logs/checkpoints are never overwritten.
    trial_log_dir = f'out/logs/{experiment_name}/{trial_name}'
    run_no = 0
    while os.path.exists(f'{trial_log_dir}/run_{run_no}'):
        run_no += 1
    writer = SummaryWriter(f'{trial_log_dir}/run_{run_no}')
    os.makedirs(model_dir, exist_ok=True)
    save_dir = f'{model_dir}/run_{run_no}.pt'  # checkpoint file path passed to train()

    train(model, train_dataset, val_dataset, optimiser, num_epochs=num_epochs, batch_size=batch_size, writer=writer, compute_dtype=compute_dtype, save_dir=save_dir, epoch_hyperparams=hyperparams)
    writer.close()  # flush TensorBoard events; SummaryWriter ignores a second close

    # Reload the checkpoint written by *this* run (previously hard-coded to run_0.pt,
    # which is a different model whenever run_no > 0).
    # NOTE(review): assumes train() writes the weights to keep at save_dir — confirm.
    model.load_state_dict(torch.load(save_dir, map_location=device))

# The cells below only run inference; eval() matters if the model has dropout/batch-norm.
model.eval()


In [96]:
# Either restore a saved EnforcedAE checkpoint (load=True) or train a new run and
# then reload the checkpoint that run wrote.
experiment_name = 'base'
trial_name = 'base_cnn20_schedules'
load = True
load_run = 0  # run index restored when load=True

model = EnforcedAE(in_channels=1, z_dim=20, cnn=True).to(device)
model_dir = f'out/models/{experiment_name}/{trial_name}'

if load:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(f'{model_dir}/run_{load_run}.pt', map_location=device))
else:
    optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3)
    num_epochs = 200
    batch_size = 256
    compute_dtype = torch.bfloat16

    # Per-epoch schedules handed to train() through epoch_hyperparams.
    hyperparams = {
        'lr': cosine_schedule(base=1e-3, end=1e-4, T=num_epochs, warmup=10, flat_end=10),
        'wd': cosine_schedule(base=0.004, end=0.1, T=num_epochs)
    }

    # Use the first unused run index so earlier logs/checkpoints are never overwritten.
    trial_log_dir = f'out/logs/{experiment_name}/{trial_name}'
    run_no = 0
    while os.path.exists(f'{trial_log_dir}/run_{run_no}'):
        run_no += 1
    writer = SummaryWriter(f'{trial_log_dir}/run_{run_no}')
    os.makedirs(model_dir, exist_ok=True)
    save_dir = f'{model_dir}/run_{run_no}.pt'  # checkpoint file path passed to train()

    train(model, train_dataset, val_dataset, optimiser, num_epochs=num_epochs, batch_size=batch_size, writer=writer, compute_dtype=compute_dtype, save_dir=save_dir, epoch_hyperparams=hyperparams)
    writer.close()  # flush TensorBoard events; SummaryWriter ignores a second close

    # Reload the checkpoint written by *this* run (previously hard-coded to run_0.pt,
    # which is a different model whenever run_no > 0).
    # NOTE(review): assumes train() writes the weights to keep at save_dir — confirm.
    model.load_state_dict(torch.load(save_dir, map_location=device))

# The cells below only run inference; eval() matters if the model has dropout/batch-norm.
model.eval()


In [100]:
# OFFSET LATENTS
# Take one test image, encode it, and use its first five latent values as the initial
# settings of a transform action applied via model.transform_images.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

images, _ = next(iter(test_loader))
with torch.no_grad():
    z = model.infer(images)


def plot_images_with_widgets(model, images, latents=None):
    """Show `images` next to model.transform_images(images, action), with the action set by sliders.

    Args:
        model: model exposing ``infer`` and ``transform_images``.
        images: image batch fed to both methods; slider start values come from its first sample.
        latents: optional precomputed ``model.infer(images)``; computed here when omitted.
    """
    if latents is None:
        with torch.no_grad():
            latents = model.infer(images)

    def update(z0=0., z1=0., z2=0., z3=0., z4=0.):
        action = torch.tensor([z0, z1, z2, z3, z4], device=device)
        with torch.no_grad():
            images_hat = model.transform_images(images, action)[0]

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        for ax, batch, title in zip(axes, (images, images_hat), ('Input', 'Target')):
            ax.imshow(torchvision.utils.make_grid(batch.cpu(), nrow=1, padding=2, normalize=True).permute(1, 2, 0))
            ax.set_title(title)
            ax.axis('off')
        plt.show()

    # Slider z{i} drives action component i; labels follow the original cell.
    descriptions = ('rotation', 'x offset', 'y offset', 'scale', 'shear')
    sliders = {
        f'z{i}': FloatSlider(min=-3.0, max=3.0, step=0.1, value=latents[0, i].item(), description=desc)
        for i, desc in enumerate(descriptions)
    }
    interact(update, **sliders)


plot_images_with_widgets(model, images, z)

interactive(children=(FloatSlider(value=-7.502201333409175e-05, description='rotation', max=3.0, min=-3.0), Fl…

In [None]:

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)

images, _ = next(iter(test_loader))


def plot_images_with_widgets(model, images):
    """Compare the inputs, their transformed targets, and the model's predictions for a slider-set action.

    Args:
        model: model exposing ``transform_images(images, action)`` and ``model(images, action)``.
        images: image batch shown in a 4-column grid in each panel.
    """
    def update(rotation=0., translate_x=0., translate_y=0., scale=0., shear=0.):
        # Defaults mirror the sliders' initial values below.
        action = torch.tensor([rotation, translate_x, translate_y, scale, shear], device=device)
        with torch.no_grad():  # display only: no autograd graph needed
            images_aug, _ = model.transform_images(images, action)
            images_hat = model(images, action)[0]

        panels = (
            (images, 'Original Inputs'),
            (images_aug, 'Transformed Targets'),
            (images_hat, 'Output Predictions'),
        )
        fig, axes = plt.subplots(1, 3, figsize=(12, 6))
        for ax, (batch, title) in zip(axes, panels):
            ax.imshow(torchvision.utils.make_grid(batch.cpu(), nrow=4, padding=2, normalize=True).permute(1, 2, 0))
            ax.set_title(title)
            ax.axis('off')
        plt.show()

    names = ('rotation', 'translate_x', 'translate_y', 'scale', 'shear')
    interact(update, **{name: FloatSlider(min=-1.0, max=1.0, step=0.1, value=0.) for name in names})


plot_images_with_widgets(model, images)

In [11]:
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)

images, _ = next(iter(test_loader))


def plot_images_with_widgets(model, images):
    """Compare the inputs, their transformed targets, and the model's predictions for a slider-set action.

    Args:
        model: model exposing ``transform_images(images, action)`` and ``model(images, action)``.
        images: image batch shown in a 4-column grid in each panel.
    """
    # (The unused `latents = model.infer(images)` forward pass was dropped.)
    def update(rotation=0., translate_x=0., translate_y=0., scale=0., shear=0.):
        # Defaults mirror the sliders' initial values below.
        action = torch.tensor([rotation, translate_x, translate_y, scale, shear], device=device)
        with torch.no_grad():  # display only: no autograd graph needed
            images_aug, _ = model.transform_images(images, action)
            images_hat = model(images, action)[0]

        panels = (
            (images, 'Original Inputs'),
            (images_aug, 'Transformed Targets'),
            (images_hat, 'Output Predictions'),
        )
        fig, axes = plt.subplots(1, 3, figsize=(12, 6))
        for ax, (batch, title) in zip(axes, panels):
            ax.imshow(torchvision.utils.make_grid(batch.cpu(), nrow=4, padding=2, normalize=True).permute(1, 2, 0))
            ax.set_title(title)
            ax.axis('off')
        plt.show()

    names = ('rotation', 'translate_x', 'translate_y', 'scale', 'shear')
    interact(update, **{name: FloatSlider(min=-1.0, max=1.0, step=0.1, value=0.) for name in names})


plot_images_with_widgets(model, images)

interactive(children=(FloatSlider(value=0.0, description='rotation', max=1.0, min=-1.0), FloatSlider(value=0.0…

In [92]:

def plot_images_with_widgets(model, dataset=None, seed=None):
    """Reconstruct one random image and expose its first 10 latent values as sliders.

    Args:
        model: autoencoder exposing ``infer`` (encode) and ``decode``.
        dataset: dataset to sample from; defaults to the notebook-level ``test_dataset``.
        seed: optional int; when given, the sampled index is reproducible.
    """
    if dataset is None:
        dataset = test_dataset
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    idx = torch.randint(0, len(dataset), (1,), generator=generator)
    image = dataset[idx][0]  # tensor index kept as before; presumably yields a batch of 1 — confirm
    with torch.no_grad():
        z = model.infer(image)

    def update(z0=0., z1=0., z2=0., z3=0., z4=0., z5=0., z6=0., z7=0., z8=0., z9=0.):
        # Edit a copy so the encoded latent is never mutated between slider moves,
        # and build the override on the latent's own device/dtype.
        latents = z.clone()
        latents[:, :10] = torch.tensor([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9], dtype=latents.dtype, device=latents.device)
        with torch.no_grad():
            x_hat = model.decode(latents)

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        axes[0].imshow(image.cpu().squeeze(0).permute(1, 2, 0), cmap='gray')
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        axes[1].imshow(x_hat.cpu().squeeze(0).permute(1, 2, 0), cmap='gray')
        axes[1].set_title('Reconstructed Image')
        axes[1].axis('off')

        plt.show()

    # One slider per edited latent dimension, starting at the encoded value.
    sliders = {f'z{i}': FloatSlider(min=-3.0, max=3.0, step=0.01, value=z[0, i].item()) for i in range(10)}
    interact(update, **sliders)


plot_images_with_widgets(model)

interactive(children=(FloatSlider(value=-1.7956098318099976, description='z0', max=3.0, min=-3.0, step=0.01), …