# Texture synthesis on MNIST digits (SSD vs. SSD+SMT)

In [1]:
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from tqdm import tqdm

# Convert images to float tensors in [0, 1]. No mean/std normalization is applied,
# which is consistent with the vmin=0 / vmax=1 used when plotting later.
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
])

# Optional reproducibility: set SEED to an int to fix the shuffled training subset
# (and, when the cells are run in order, the randomly chosen seed image as well).
# None keeps the original behavior of a fresh random draw on every run.
SEED = None
if SEED is not None:
    torch.manual_seed(SEED)
# Printed so a run can presumably be reproduced by copying this value into SEED
# (assuming nothing consumed the torch RNG earlier in the kernel) — TODO confirm.
print('torch initial seed:', torch.initial_seed())

# Load the MNIST dataset (downloaded to ./data on first use)
mnist_train = MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test = MNIST(root='./data', train=False, download=True, transform=transform)

DATASET_IMAGES = 10000   # size of the random training subset passed to the synthesis as `sample`
TESTSET_IMAGES = 5000    # number of test images the seed image is drawn from

# Each loader is only used for a single batch: a shuffled subset of the training
# set, and the first TESTSET_IMAGES test images in dataset order.
train_loader = DataLoader(mnist_train, batch_size=DATASET_IMAGES, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=TESTSET_IMAGES, shuffle=False)

# Pull one batch from each loader. Nothing in this cell moves the tensors to a GPU;
# they stay on the device the DataLoader produces them on.
training_images, training_labels = next(iter(train_loader))
test_images, test_labels = next(iter(test_loader))


### Define synthesis parameters

In [2]:
import datetime
from synthesis_mnist import *
import os

# Synthesis parameters, passed to synthesize_texture and used to label the plots.
window_size = (28, 28)   # (height, width) of the generated image
kernel_size = 5          # history window size
seed_size = 7            # side length of the square seed patch (seed_size x seed_size)


In [None]:

# Timestamped output directory. Time fields are separated with '-' rather than ':'
# because ':' is not allowed in Windows file names.
current_time = datetime.datetime.now()
formatted_time = current_time.strftime("%m-%d_%H-%M-%S")  # month-day_hour-minute-second
out_dir = f'outputs/mnist/SSD+SMT_{formatted_time}'
os.makedirs(out_dir, exist_ok=True)  # no-op if the directory already exists

# Pick one test image at random; the index is printed so the choice can be traced.
ix = torch.randint(0, test_images.shape[0], (1,)).item()
print('randomly selected seed image:', ix)
test_image = test_images[ix].unsqueeze(0)  # add a leading (batch) dimension

# Show and save the chosen image. squeeze() drops every size-1 dimension, so a
# single call is enough to get a 2-D array for imshow.
seed_fig, seed_ax = plt.subplots(figsize=(2, 2))
# 'gray' works in every matplotlib version; the 'grey' spelling is only accepted
# by recent releases.
seed_ax.imshow(test_image.squeeze().cpu(), vmin=0, vmax=1, cmap='gray')
seed_ax.set_title('original image')
seed_ax.axis('off')
seed_fig.savefig(f'{out_dir}/original_image.png')




out_path = f'{out_dir}/final.png'  # destination of the 3-panel comparison figure

# Both runs receive identical inputs; they differ only in whether the SMT filter
# is enabled, so the results can be compared side by side.
shared_args = {
    'sample': training_images,
    'test_sample': test_image,
    'window_size': window_size,
    'kernel_size': kernel_size,
    'seed_size': seed_size,
    'out_dir': out_dir,
}
synthesized_texture = synthesize_texture(use_SMT_filter=False, **shared_args)
synthesized_texture_SMT = synthesize_texture(use_SMT_filter=True, **shared_args)







fig, [ax1, ax2, ax3] = plt.subplots(1, 3)

# Crop the seed_size x seed_size region where the seed sits inside the window.
# (W - s + 1) // 2 rounds the offset up when window_size - seed_size is odd.
# NOTE(review): this offset must match how synthesize_texture positions the seed;
# confirm against synthesis_mnist. The crop is taken from the synthesized output,
# so it shows the seed only if synthesis leaves that region unchanged.
ph, pw = (window_size[0] - seed_size + 1) // 2, (window_size[1] - seed_size + 1) // 2
original_seed = synthesized_texture[ph:ph+seed_size, pw:pw+seed_size]

# One (axis, image, title) entry per panel, all drawn the same way.
panels = [
    (ax1, original_seed, f'seed: {seed_size}x{seed_size}'),
    (ax2, synthesized_texture, f'SSD Only. kernel={kernel_size}'),
    (ax3, synthesized_texture_SMT, f'SSD+SMT.  kernel={kernel_size}'),
]
for ax, image, title in panels:
    # 'gray' works in every matplotlib version; the 'grey' spelling is only
    # accepted by recent releases.
    ax.imshow(image.to('cpu'), vmin=0, vmax=1, cmap='gray')
    ax.set(title=title)
    ax.set_xticks([])
    ax.set_yticks([])

fig.savefig(out_path)


randomly selected seed image: 2849
Synthesis finished. Time used: 2.6s


  saved_variables = torch.load('SMT_tensors.pt', map_location=torch.device('cuda:0'))


printing SMT nearest neighbours... current window size: 7
printing SMT nearest neighbours... current window size: 9
printing SMT nearest neighbours... current window size: 11
printing SMT nearest neighbours... current window size: 13
printing SMT nearest neighbours... current window size: 15
printing SMT nearest neighbours... current window size: 17


In [4]:
# Write the comparison figure to out_path again (same target as the synthesis cell's final save).
fig.savefig(fname=out_path)