In [None]:
from __future__ import print_function
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm

batch_size = 32
epochs = 100
seed = 1

context_num = 10  # number of observed (context) pixels per image

# Seed both RNGs: torch drives layer initialisation and sampling, numpy drives the
# context/target pixel shuffling (np.random in shuffle_dim). Seeding only torch
# left the context selection non-reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from Tars.distributions import Normal, Bernoulli
from Tars.distributions.divergences import KullbackLeibler
from Tars.models import VAE

In [None]:
# Pinned host memory only speeds up host->GPU copies, so request it only when
# training on CUDA (recent PyTorch warns when pin_memory=True without a GPU).
kwargs = {'num_workers': 1, 'pin_memory': device == "cuda"}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
# The evaluation loss is accumulated over every batch, so a shuffled order adds nothing.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False, **kwargs)

In [None]:
import numpy as np

def get_coordinated_mnist(x, img_shape=(28, 28)):
    """Turn a batch of images into (coordinate, intensity) sequences.

    Args:
        x (torch.Tensor): images of shape (batch_size, 1, H, W) (or any shape whose
            per-example size is H*W, e.g. (batch_size, H, W)).
        img_shape: (H, W) of each image; (28, 28) for MNIST.

    Returns:
        x_coordinate (np.ndarray): (H*W, batch_size, 2) integer coordinates. For the
            flattened pixel index p the entry is [p % W, p // W], i.e. (col, row).
        y_coordinate (np.ndarray): (H*W, batch_size, 1) pixel values, with
            y_coordinate[p, b, 0] == image_b[p // W, p % W]; dtype follows ``x``.
    """
    height, width = img_shape
    n_batch = x.shape[0]

    pix = np.arange(height * width)
    coords = np.stack([pix % width, pix // width], axis=-1)  # (H*W, 2)
    # same coordinates for every example in the batch
    x_coordinate = np.repeat(coords[:, np.newaxis, :], n_batch, axis=1)

    # detach/cpu so CUDA or grad-tracking tensors also convert to numpy
    flat = x.detach().cpu().reshape(n_batch, height * width)
    y_coordinate = flat.t().numpy()[:, :, np.newaxis]

    return x_coordinate, y_coordinate

def shuffle_dim(x, y, random_seed=1234, shuffle_each_example=True):
    """Randomly reorder the sequence (pixel) axis 0 of paired arrays ``x`` and ``y``.

    Args:
        x, y: arrays of shape (seq_len, batch_size, ...) sharing their first two dims.
        random_seed: seed of the single permutation used when ``shuffle_each_example``
            is False. A private RandomState is used, so the global numpy RNG is
            neither reseeded nor advanced. Ignored when ``shuffle_each_example`` is True.
        shuffle_each_example: if True, every example (index along axis 1) gets its own
            permutation drawn from the global ``np.random`` state; if False, one
            seeded permutation of axis 0 is shared by the whole batch.

    Returns:
        (x, y) reordered with the same permutation(s), so (x, y) pairs stay aligned.
    """
    if shuffle_each_example:
        seq_len, n_batch = x.shape[0], x.shape[1]
        # perm[i, b] = source position that lands at output position i of example b
        perm = np.stack([np.random.permutation(seq_len) for _ in range(n_batch)], axis=1)
        batch_idx = np.arange(n_batch)[np.newaxis, :]
        return x[perm, batch_idx], y[perm, batch_idx]

    # Previously this reseeded the global RNG with a hard-coded 1234 and ignored
    # random_seed; a local RandomState honours the argument (same default result).
    order = np.random.RandomState(random_seed).permutation(x.shape[0])
    return x[order], y[order]

def split_context_target(x, y, context_num=100, device="cpu", shuffle_each_example=True):
    """Shuffle the sequence axis, then split it into context and target sets.

    The first ``context_num`` shuffled positions form the context; all remaining
    positions form the target (context points are not repeated in the target).

    Args:
        x, y: numpy arrays with the sequence on axis 0, e.g. the outputs of
            get_coordinated_mnist.
        context_num: number of positions assigned to the context set.
        device: device the returned tensors are moved to.
        shuffle_each_example: forwarded to shuffle_dim.

    Returns:
        x_context, y_context, x_target, y_target as torch.Tensor (default dtype) on ``device``.
    """
    x_shuf, y_shuf = shuffle_dim(x, y, shuffle_each_example=shuffle_each_example)

    def to_tensor(arr):
        return torch.Tensor(arr).to(device)

    context = [to_tensor(arr[:context_num]) for arr in (x_shuf, y_shuf)]
    target = [to_tensor(arr[context_num:]) for arr in (x_shuf, y_shuf)]

    return context[0], context[1], target[0], target[1]

def convert_plot_img(x, y, img_size=28):
    """Scatter per-pixel values back onto a square canvas for plotting.

    Args:
        x: (seq_len, 2) coordinates; cast to int32 and used as (first, second) index
            into the canvas. Callers in this notebook display the result as ``img.T``.
        y: (seq_len, 1) or (seq_len,) values written at those coordinates.
        img_size: side length of the square canvas (28 for MNIST).

    Returns:
        float32 array of shape (img_size, img_size); pixels not listed in ``x`` are -1.
    """
    coords = np.asarray(x).astype("int32").reshape(-1, 2)
    values = np.asarray(y, dtype="float32").reshape(-1)

    canvas = np.full((img_size, img_size), -1., dtype="float32")
    # Vectorised scatter; also avoids assigning a size-1 array into a scalar slot,
    # which recent NumPy deprecates.
    canvas[coords[:, 0], coords[:, 1]] = values
    return canvas

In [None]:
import io
import PIL.Image
from torchvision.transforms import ToTensor

def plot_image(x_c, y_c, x_t, y_t, image_id=0, n_samples=3):
    """Render ground truth, model reconstructions and context for one example.

    Panels, left to right: the target pixels, ``n_samples`` reconstructions (each
    from a fresh z ~ q, decoded with the module-level ``_p.sample_mean``), and the
    context pixels. Uses the module-level models ``q`` and ``_p``.

    Args:
        x_c, y_c, x_t, y_t: context/target tensors shaped (points, batch, dim).
        image_id: index of the example (along axis 1) to draw.
        n_samples: number of reconstruction panels.

    Returns:
        io.BytesIO positioned at 0 that holds a JPEG of the figure.
    """
    def _example(t):
        return t[:, image_id].detach().cpu().numpy()

    target_coords = _example(x_t)
    n_cols = n_samples + 2

    # A dedicated figure that is always closed: drawing on the implicit global
    # figure stacked new images onto the same axes every call and leaked memory.
    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, n_cols, 1)
        ax.imshow(convert_plot_img(target_coords, _example(y_t)).T, cmap="gray")
        ax.set_title("target")

        with torch.no_grad():
            for k in range(n_samples):
                # fresh input dict each time, exactly as the sampler received before
                z = q.sample({"x_c": x_c, "y_c": y_c, "x_t": x_t, "y_t": y_t})["z"]
                pred_y_t = _p.sample_mean({"x_t": x_t, "z": z})
                ax = fig.add_subplot(1, n_cols, k + 2)
                ax.imshow(convert_plot_img(target_coords, _example(pred_y_t)).T, cmap="gray")
                ax.set_title("sample {}".format(k + 1))

        ax = fig.add_subplot(1, n_cols, n_cols)
        ax.imshow(convert_plot_img(_example(x_c), _example(y_c)).T, cmap="gray")
        ax.set_title("context")

        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg')
    finally:
        plt.close(fig)

    buf.seek(0)
    return buf

In [None]:
x_dim = 2    # size of a pixel coordinate
y_dim = 1    # size of a pixel value
z_dim = 64   # latent size
h_dim = 512  # hidden width


class Representation(nn.Module):
    """Point encoder: maps each concatenated (x, y) pair to an h_dim feature vector.

    Architecture: two blocks of Linear -> BatchNorm1d -> ReLU.
    """

    def __init__(self, x_dim, y_dim):
        super(Representation, self).__init__()
        in_features = x_dim + y_dim
        # creation order (fc1, fc2, then bn1, bn2) fixes which random draws each
        # layer's initial weights consume
        self.fc1 = nn.Linear(in_features, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.bn1 = nn.BatchNorm1d(h_dim)
        self.bn2 = nn.BatchNorm1d(h_dim)

    def forward(self, x, y):
        """
        Args:
            x: (n_points, x_dim)
            y: (n_points, y_dim)
        Returns:
            (n_points, h_dim) ReLU features.
        """
        feats = torch.cat([x, y], 1)
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            feats = F.relu(norm(linear(feats)))
        return feats


# one encoder instance, shared by the Inference network
rep = Representation(x_dim, y_dim)

# inference model q(z|x_c, y_c, x_t, y_t)
class Inference(Normal):
    """Gaussian posterior q(z | x_c, y_c, x_t, y_t).

    Every (x, y) point is encoded by the shared ``rep`` network. The encodings of
    the context and target points are averaged into one vector per example, and two
    linear heads map that vector to ``loc`` and a softplus-positive ``scale``.
    """

    def __init__(self):
        super(Inference, self).__init__(cond_var=["x_c", "y_c", "x_t", "y_t"], var=["z"])

        self.h = rep                          # shared module-level Representation
        self.fc31 = nn.Linear(h_dim, z_dim)   # head for loc
        self.fc32 = nn.Linear(h_dim, z_dim)   # head for scale (before softplus)

    def forward(self, x_c, y_c, x_t, y_t):
        """
        Args:
            x_c: (context_num, batch_size, x_dim)
            y_c: (context_num, batch_size, y_dim)
            x_t: (n_target, batch_size, x_dim)
            y_t: (n_target, batch_size, y_dim)
        Returns:
            {"loc": ..., "scale": ...}, each (batch_size, z_dim).
        """
        n_batch = x_c.shape[1]

        def encode(xs, ys):
            # merge point and batch axes for the encoder, then restore them
            codes = self.h(xs.view(-1, x_dim), ys.view(-1, y_dim))
            return codes.view(-1, n_batch, h_dim)

        # rep contains BatchNorm layers, so context is encoded before target
        pooled = torch.cat([encode(x_c, y_c), encode(x_t, y_t)], dim=0).mean(dim=0)
        # pooled : (batch_size, h_dim)
        return {"loc": self.fc31(pooled), "scale": F.softplus(self.fc32(pooled))}
    
# generative model p(y_t|x_t, z)
class Generator(Bernoulli):
    """Bernoulli decoder p(y_t | x_t, z): y_dim probabilities per target coordinate."""

    def __init__(self):
        super(Generator, self).__init__(cond_var=["x_t", "z"], var=["y_t"], sequential=True)

        self.fc1 = nn.Linear(z_dim + x_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, y_dim)
        self.bn1 = nn.BatchNorm1d(h_dim)
        self.bn2 = nn.BatchNorm1d(h_dim)

    def g(self, x_t, z):
        """Decode a single target position for the whole batch.

        Args:
            x_t: (batch_size, x_dim) coordinates of one target position.
            z: (batch_size, z_dim) latent sample.
        Returns:
            (batch_size, y_dim) probabilities in (0, 1).
        """
        h = F.relu(self.bn1(self.fc1(torch.cat([z, x_t], 1))))
        h = F.relu(self.bn2(self.fc2(h)))
        # torch.sigmoid: same values as the deprecated F.sigmoid, without the warning
        return torch.sigmoid(self.fc3(h))

    def forward(self, x_t, z):
        """
        Args:
            x_t: (n_target, batch_size, x_dim)
            z: (batch_size, z_dim)
        Returns:
            {"probs": (n_target, batch_size, y_dim)}
        """
        # Positions are decoded one at a time rather than flattened into one big
        # batch, so BatchNorm statistics are taken per position over the batch;
        # flattening would change the model's outputs.
        probs = torch.stack([self.g(pos, z) for pos in x_t], dim=0)
        return {"probs": probs}
    
# prior model p(z): Normal with scalar loc 0 and scale 1, declared with dim=z_dim
loc = torch.tensor(0., device=device)
scale = torch.tensor(1., device=device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim)

In [None]:
# Instantiate decoder _p and posterior q. Their Linear layers are created here,
# so this statement order determines the random initial weights under the seed.
_p = Generator()
q = Inference()

# Joint model built from the decoder and the prior (Tars product of distributions).
p = _p * prior

# Show the text form of each model's factorisation (Tars attribute).
print(p.prob_factorized_text, q.prob_factorized_text)

# Move parameters to the training device; presumably p.to also moves its factors
# (_p, prior) — confirm against Tars semantics.
p.to(device)
q.to(device)

In [None]:
# Regulariser KL(q(z | x_c, y_c, x_t, y_t) || p(z)).
kl = KullbackLeibler(q, prior)
# Tars VAE wrapper over posterior q and decoder _p, optimised with Adam at lr 1e-3.
model = VAE(
    q,
    _p,
    regularizer=[kl],
    optimizer=optim.Adam,
    optimizer_params={"lr": 1e-3},
)

In [None]:
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Each batch is converted to coordinate/value sequences, split into random
    context and target points, and passed to ``model.train``.

    Args:
        epoch: epoch number, used only for display.

    Returns:
        0-dim tensor: accumulated batch losses scaled by batch_size / dataset size
        (a per-example average if model.train returns a per-batch mean — assumed).
    """
    train_loss = 0
    t = tqdm(train_loader)
    t.set_description('Epoch: {}'.format(epoch))
    for data, _ in t:
        x, y = get_coordinated_mnist(data)
        x_c, y_c, x_t, y_t = split_context_target(x, y, context_num=context_num, device=device)

        _lower_bound, loss = model.train({"x_c": x_c, "y_c": y_c, "x_t": x_t, "y_t": y_t})
        # detach: summing the raw loss chains every batch's autograd graph into
        # train_loss and keeps it alive for the whole epoch (memory growth)
        train_loss += loss.detach()

        t.set_postfix(loss=loss.item())

    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

In [None]:
def test(epoch):
    """Evaluate ``model.test`` on every batch of ``test_loader``.

    Uses the same context/target preprocessing and loss normalisation as ``train``.

    Args:
        epoch: epoch number (unused; kept for a signature parallel to ``train``).

    Returns:
        0-dim tensor: accumulated batch losses scaled by batch_size / dataset size.
    """
    test_loss = 0
    for data, _ in test_loader:
        x, y = get_coordinated_mnist(data)
        x_c, y_c, x_t, y_t = split_context_target(x, y, context_num=context_num, device=device)

        _lower_bound, loss = model.test({"x_c": x_c, "y_c": y_c, "x_t": x_t, "y_t": y_t})
        # keep only the value so the running sum can never hold an autograd graph
        test_loss += loss.detach()

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    print('Test loss: {:.4f}'.format(test_loss))
    return test_loss

In [None]:
writer = SummaryWriter()

# Fixed visualisation batch: the first two test images, read through the dataset
# so its ToTensor transform scales pixels to [0, 1] like the training batches.
# (The raw ``test_data`` tensor used previously holds uint8 values in [0, 255].)
x_original = torch.stack([test_loader.dataset[i][0] for i in range(2)])
x, y = get_coordinated_mnist(x_original)
x_c, y_c, x_t, y_t = split_context_target(x, y, context_num=context_num, device=device)

try:
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)

        # Plot after training so the image logged at step `epoch` reflects the
        # model after `epoch` epochs (it was previously one epoch behind).
        plot_buf = plot_image(x_c, y_c, x_t, y_t)
        # ToTensor yields (C, H, W) — the layout add_image expects by default;
        # the extra batch axis added before is rejected by that format.
        image = ToTensor()(PIL.Image.open(plot_buf))
        writer.add_image('Image', image, epoch)

        # float() accepts both 0-dim tensors and plain numbers
        writer.add_scalar('train_loss', float(train_loss), epoch)
        writer.add_scalar('test_loss', float(test_loss), epoch)
finally:
    # flush and close event files even if training is interrupted
    writer.close()