# In [None]:
import os
import sys
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Make the parent directory importable (presumably the project root that
# contains the `src` package).
sys.path.append('../')

from src.models.net import *
from src.si import *
from src.util import util as ut

# The original interpolant was trained with MNIST at time zero and SVHN at time one.
# NOTE: despite the `_1000_` in the variable names, only `nsamples` test images are
# taken from each set; the names are kept because the sampling loop below uses them.
nsamples = 10
mnist_testloader = ut.get_mnist_test()
# Element [0] of each dataset item is stacked (presumably the image, with [1] the label).
mnist_1000_testsamples = torch.stack([mnist_testloader.dataset[i][0] for i in range(nsamples)])
svhn_testloader = ut.get_svhn_test()
svhn_1000_testsamples = torch.stack([svhn_testloader.dataset[i][0] for i in range(nsamples)])

# Save the real samples so they can be compared with the generated ones.
# squeeze(1) removes only the channel axis, so the batch axis survives even when
# nsamples == 1 (a bare squeeze() would drop it and break per-image indexing).
# Assumes single-channel (N, 1, H, W) tensors -- TODO confirm for SVHN.
svhn_1000_testsamples_np = svhn_1000_testsamples.squeeze(1).numpy()
mnist_1000_testsamples_np = mnist_1000_testsamples.squeeze(1).numpy()

for i, (svhn_img, mnist_img) in enumerate(zip(svhn_1000_testsamples_np, mnist_1000_testsamples_np)):
    plt.imsave(f'svhn_{i}.png', svhn_img, cmap='gray')
    plt.imsave(f'mnist_{i}.png', mnist_img, cmap='gray')


# Checkpoint suffixes (../models/mnist_model_<name>/) and, at the same position,
# the interpolant constructor each checkpoint is wrapped with. 'lin' and 'linsb'
# share LinearInterpolant; the sampling loop gives them different noise wrappers.
model_names = [
    'encdec',
    'lin',
    'linsb',
    'sqrt',
    'squared',
]
SIs = [
    EncoderDecoderInterpolant,              # encdec
    LinearInterpolant,                      # lin
    LinearInterpolant,                      # linsb
    partial(PolynomialInterpolant, p=0.5),  # sqrt
    partial(PolynomialInterpolant, p=2),    # squared
]


def _to_images(samples):
    """Convert a sampled batch tensor into a NumPy stack of 2-D grayscale images.

    squeeze(1) drops only the channel axis, so a batch of a single image keeps its
    batch dimension (a bare squeeze() would collapse it). Assumes `samples` is
    (N, 1, H, W) -- TODO confirm against si.sample.
    """
    return samples.squeeze(1).detach().cpu().numpy()


def _save_images(images, out_dir):
    """Write each image to ``out_dir/<index>.png``, creating ``out_dir`` if missing."""
    os.makedirs(out_dir, exist_ok=True)
    for i, img in enumerate(images):
        plt.imsave(os.path.join(out_dir, f'{i}.png'), img, cmap='gray')


# Iterate the registry pairwise instead of a hard-coded range(5), so adding or
# removing a model only requires editing `model_names` / `SIs`.
for name, make_si in zip(model_names, SIs):
    model = UNet(
        dim = 28,
        dim_mults = (1,2,),
        flash_attn = True,
        channels=1,
        resnet_block_groups=4,
        attn_dim_head=32,
        attn_heads=4,
    )

    # Everything below runs on the CPU (results go straight to .numpy()), so map the
    # checkpoint to the CPU. A GPU-saved checkpoint would otherwise fail to load on
    # a CPU-only machine.
    checkpoint_path = f'../models/mnist_model_{name}/epoch_500.pt'
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    # Wrap the network so its calls are tallied in `model.counter`, which is
    # reported below as the NFE (presumably one count per forward call).
    model = ut.model_counter(model)
    model.eval()

    si = make_si(model)
    si.loss_type = 'velocity'
    # 'linsb' uses make_noisy(..., 0.5); every other model uses make_sinsq_noisy
    # (presumably a sin^2-shaped noise schedule).
    if name == 'linsb':
        si = make_noisy(si, 0.5)
    else:
        si = make_sinsq_noisy(si)

    # NOTE(review): sampling runs with autograd enabled (hence the .detach() in
    # _to_images). If si.sample does not need gradients internally, wrapping it in
    # torch.no_grad() would avoid building the graph -- TODO confirm.

    # Forward direction 'f': MNIST (time zero) -> SVHN (time one).
    # Reset the counter explicitly so each direction's NFE starts from zero.
    model.counter = 0
    x_initial_mnist = mnist_1000_testsamples
    gen_svhn_samples = si.sample(x_initial=x_initial_mnist, direction='f')
    print('svhn samples generated')
    print(f'NFE for svhn for model {name}:', model.counter)
    gen_svhn_samples = _to_images(gen_svhn_samples)
    _save_images(gen_svhn_samples, f'../gen_samples/{name}/svhn')

    # Reverse direction 'r': SVHN (time one) -> MNIST (time zero).
    model.counter = 0
    x_initial_svhn = svhn_1000_testsamples
    gen_mnist_samples = si.sample(x_initial=x_initial_svhn, direction='r')
    print('mnist samples generated')
    print(f'NFE for mnist for model {name}:', model.counter)
    gen_mnist_samples = _to_images(gen_mnist_samples)
    _save_images(gen_mnist_samples, f'../gen_samples/{name}/mnist')


# Captured cell output (as recorded, only the first model's lines appear):
# svhn samples generated
# NFE for svhn for model encdec: 200
# mnist samples generated
# NFE for mnist for model encdec: 218
