In [1]:
from fastai import *
from fastai.vision import *

In [2]:
# Show the URL constant fastai uses for the MNIST archive.
URLs.MNIST

In [3]:
# Fetch/extract MNIST with fastai's untar_data; `path` is the dataset root directory.
# The DataBunch uses the 'training' folder as the train split and 'testing' as validation.
path = untar_data(URLs.MNIST)
data = ImageDataBunch.from_folder(path, train='training', valid='testing')

In [32]:
# Inspect the extracted directory layout.
!ls {path}

In [33]:
# Inspect the per-class folders of the training split.
!ls {path/'training'}

In [34]:
# Visual sanity check of a few training images with their labels.
data.show_batch()

In [35]:
# NOTE(review): exploratory check. An ImageDataBunch presumably wraps DataLoaders rather
# than being one, so this likely evaluates to False. Consider removing it from the final notebook.
isinstance(data, torch.utils.data.DataLoader)

In [36]:
class InvertedMNISTDataset(Dataset):
  def __init__(self, base_ds, exclude=(1,2,3,4,5), exclude_in_w=True):
    self.base_ds = base_ds
    self.exclude = exclude
    self.exclude_in_w = exclude_in_w
  def __len__(self):
    return len(self.base_ds)
  def __getitem__(self, index):
    img, cond = self.base_ds[index]
    cond = cond.data
    domain = 0
    if index % 2 and int(cond) not in self.exclude:
      img = 1 - img.data[0] # invert the image
      domain = 1
    else:
      img = img.data[0]
    cond_oh = torch.zeros(10, requires_grad=False)
    cond_oh[cond] = 1
    domain = torch.torch.tensor([domain], dtype=torch.float, requires_grad=False)
    return (img, cond_oh, domain)

In [37]:
# Wrap fastai's training split so each item is (image, one-hot label, domain tensor).
inv_ds = InvertedMNISTDataset(data.train_ds)


In [38]:
from torchsupport.training.trvae import TransformingVAETraining
from torchsupport.modules.basic import MLP
from torch import nn

In [39]:
# Peek at the domain tensor of the first item. Index 0 is even, so this is domain 0.
inv_ds[0][2]

In [40]:
class Encoder(nn.Module):
  """MLP encoder: flattened 28x28 input + condition + domain -> latent statistics.

  Args:
    z: latent dimensionality.
    condition_size: length of the condition vector.

  ``forward(inputs, condition, domain)`` flattens ``inputs`` to
  ``(batch, 28 * 28)`` and concatenates the 128-d image features with
  ``condition`` and the 1-wide ``domain`` along the last axis. It returns
  ``(features, mean, logvar)``, each ``(batch, z)`` for 2-D condition/domain.
  """
  def __init__(self, z=16, condition_size=10):
    super().__init__()
    self.cond_size = condition_size
    self.z = z
    # Attribute names and Sequential indices define the state_dict keys; keep them fixed.
    self.encoder = nn.Sequential(
      nn.Linear(28 * 28, 256), nn.LeakyReLU(),
      nn.Linear(256, 128), nn.LeakyReLU(),
    )
    # The extra +1 input slot holds the scalar domain value.
    self.conditioning = nn.Sequential(
      nn.Linear(128 + condition_size + 1, 64), nn.LeakyReLU(),
      nn.Linear(64, z), nn.LeakyReLU(),
    )
    self.mean = nn.Linear(z, z)
    self.logvar = nn.Linear(z, z)

  def forward(self, inputs, condition, domain):
    batch_size = inputs.size(0)
    hidden = self.encoder(inputs.view(batch_size, -1))
    joint = torch.cat([hidden, condition, domain], dim=-1)
    features = self.conditioning(joint)
    return features, self.mean(features), self.logvar(features)

In [41]:
class Decoder(nn.Module):
  """MLP decoder: latent sample + condition + domain -> 28x28 output.

  Args:
    z: latent dimensionality (the last-axis size of ``sample``).
    condition_size: length of the condition vector.

  ``forward(sample, condition, domain)`` returns ``(data, cond_latent)``:
    - ``data``: the decoded tensor viewed as ``(-1, 28, 28)``. These are raw
      outputs of the final Linear layer, with no output activation.
    - ``cond_latent``: the 64-d conditioned hidden layer it was decoded from.
  """
  def __init__(self, z=16, condition_size=10):
    super().__init__()
    self.cond_size = condition_size
    self.z = z
    # Attribute names and Sequential indices define the state_dict keys; keep them fixed.
    self.decoder = nn.Sequential(
      nn.Linear(64, 128), nn.LeakyReLU(),
      nn.Linear(128, 256), nn.LeakyReLU(),
      nn.Linear(256, 28 * 28),
    )
    # The extra +1 input slot holds the scalar domain value.
    self.conditioning = nn.Sequential(
      nn.Linear(z + condition_size + 1, 64), nn.LeakyReLU(),
      nn.Linear(64, 64), nn.LeakyReLU(),
    )

  def forward(self, sample, condition, domain):
    joint = torch.cat([sample, condition, domain], dim=-1)
    cond_latent = self.conditioning(joint)
    flat = self.decoder(cond_latent)
    return flat.view(-1, 28, 28), cond_latent

In [42]:
def normalize(image):
  """Min-max rescale ``image`` into [0, 1] (e.g. for TensorBoard image logging).

  A constant image (max == min) would compute 0/0 and turn every pixel into
  NaN. In that case the result is all zeros instead.

  Args:
    image: tensor/array supporting ``min()``, ``max()`` and broadcasting arithmetic.
  Returns:
    An array of the same shape, produced by true division like the general case.
  """
  lo = image.min()
  span = image.max() - lo
  shifted = image - lo
  # Divide by 1 for a flat image: shifted is already all zeros, and NaN is avoided.
  return shifted / (span if span != 0 else 1)

In [43]:
class ImageTRVAETraining(TransformingVAETraining):
  """trVAE training loop that also logs one target/reconstruction image pair per step.

  ``run_networks`` delegates to the parent implementation. As a side effect it
  writes the first target image and its reconstruction to ``self.writer`` at
  ``self.step_id``. It returns the parent's five outputs unchanged.
  """
  def run_networks(self, data, *args):
    outputs = super().run_networks(data, *args)
    mean, logvar, reconstruction, target, domain = outputs
    # [None] adds a leading channel axis for add_image.
    first_target = target[0][None]
    # NOTE(review): sigmoid suggests the decoder emits logits; TODO confirm.
    first_recon = reconstruction[0][0][None].sigmoid()
    self.writer.add_image("target", normalize(first_target), self.step_id)
    self.writer.add_image("reconstruction", normalize(first_recon), self.step_id)
    return mean, logvar, reconstruction, target, domain

In [44]:
# Remove output left by a previous run; the name matches network_name='mnist_trvae'.
# shutil.rmtree with ignore_errors (instead of `!rm -r`) is silent when the directory
# does not exist yet (fresh kernel / Run All) and needs no POSIX shell.
import shutil
shutil.rmtree('mnist_trvae', ignore_errors=True)

In [45]:
# Build the trVAE trainer: Encoder/Decoder networks trained on the two-domain dataset.
# NOTE(review): keyword arguments are torchsupport TransformingVAETraining options;
# domain_scale presumably weights the domain-alignment loss term. Confirm against torchsupport.
training = ImageTRVAETraining(
    Encoder(), Decoder(), inv_ds, 
    network_name='mnist_trvae',
    device='cpu',  # later cells feed CPU tensors from a DataLoader straight into the networks
    batch_size=256,
    max_epochs=1,
    verbose=True,
    domain_scale=100
)

In [46]:
# Run training (calls ImageTRVAETraining.run_networks, which also logs images).
training.train()


In [0]:
# NOTE(review): leftover directory listing, presumably to check what training wrote. Safe to remove.
!ls


In [0]:
def run_domain_change(trvae, data, target_domain):
    """Encode a batch in its own domain, then decode it into ``target_domain``.

    Args:
        trvae: object exposing ``encoder(x, cond, domain) -> (_, mu, logvar)``,
            ``sample(mu, logvar) -> z`` and ``decoder(z, cond, domain) -> (out, ...)``.
        data: ``(images, cond, orig_domain)`` batch.
        target_domain: Python number or tensor. It is broadcast to ``orig_domain``'s
            shape and cast to its dtype/device.

    Returns:
        The first element of the decoder output. It is computed under
        ``torch.no_grad()``, so it carries no autograd graph.
    """
    dat, cond, orig_domain = data
    # as_tensor accepts numbers and existing tensors alike (torch.tensor warns
    # when handed a tensor). expand_as is a no-op view when shapes already match.
    target_domain = torch.as_tensor(
        target_domain, dtype=orig_domain.dtype, device=orig_domain.device
    ).expand_as(orig_domain)
    with torch.no_grad():
        _, mu, logvar = trvae.encoder(dat, cond, orig_domain)
        z = trvae.sample(mu, logvar)
        return trvae.decoder(z, cond, target_domain)[0]
    

In [0]:
# Batch the two-domain dataset for the translation demo that follows.
# NOTE(review): no shuffle argument is passed. Assuming DataLoader's default (no shuffling),
# the first batch is items 0-63, which mixes both domains (odd, non-excluded items are inverted).
inv_dl = DataLoader(inv_ds, batch_size=64)

In [0]:
# First collated batch: the (img, cond_oh, domain) triple consumed by run_domain_change.
test_batch = next(iter(inv_dl))

In [0]:
# run_domain_change calls trvae.sample, which presumably draws a random latent.
# Seed torch so re-running this cell reproduces the same translations.
torch.manual_seed(0)
zeroed = run_domain_change(training, test_batch, 0)  # decode every image into domain 0
oned = run_domain_change(training, test_batch, 1)    # decode every image into domain 1

In [0]:
for i in zeroed[:8]:
    show_image(i[None])

In [0]:
for i in oned[:8]:
    show_image(i[None])