In [8]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.cuda.amp import autocast, GradScaler
torch.manual_seed(42)

<torch._C.Generator at 0x7beec544a2b0>

In [15]:
torch.__version__

'2.0.0'

In [9]:
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        # Define the architecture of the encoder
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding='same')
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding='same')
        self.maxpool1 = nn.MaxPool2d(kernel_size=7, stride=7, padding='same', return_indices=True)
        self.flatten = nn.Flatten()
        self.softplus = nn.Softplus()
        
    def sample(self, alpha_hat):
        u = torch.rand(size=alpha_hat.size(), requires_grad=True).to(device)
        v = torch.pow(u * alpha_hat * torch.exp(torch.lgamma(alpha_hat)),1.0/alpha_hat)
        z = v / torch.sum(v)
        return z

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x, mask_1 = self.maxpool1(x)
        x = l2m_pool_1
        x_size = x.size()
        alpha_hat = self.flatten(x)
        alpha_hat_size = alpha_hat.size()
        alpha_hat = self.softplus(nn.Linear(alpha_hat.size()[1],latent_dim)(alpha_hat))
        z = self.sample(alpha_hat)
        return z, mask_1, l2m_pool_1, x_size, alpha_hat, alpha_hat_size


In [10]:
class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        # Define the architecture of the decoder
        self.maxunpool = nn.MaxUnpool2D(kernel_size=7,padding='same')
        self.conv1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding='same')
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(in_channels=16, out_channels=1, kernel_size=5, stride=1, padding='same')
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        z, mask_1, x_size, alpha_hat_size = inputs
        x_hat = nn.Linear(latent_dim, alpha_hat_size[1])(z)
        x_hat = torch.reshape(x_hat, x_size)
        x_hat = l2m_unpool_1
        x_hat = self.maxunpool(x, mask_1)
        x_hat = self.relu(self.conv1(x_hat))
        x_hat = self.sigmoid(self.conv2(x_hat))
        return x_hat, l2m_unpool_1

In [11]:
class SWWAE(nn.Module):
    def __init__(self, latent_dim):
        super(SWWAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        z, mask_1, l2m_pool_1, x_size, alpha_hat, alpha_hat_size = self.encoder(x)
        x_hat, l2m_unpool_1 = self.decoder([z, mask_1, x_size, alpha_hat_size])
        return x_hat, l2m_pool_1, l2m_unpool_1, alpha_hat

In [12]:
def loss_fn(x, x_hat, l2m_pool_1, l2m_unpool_1, alpha_hat, alpha):
    ll_loss = F.binary_cross_entropy_with_logits(x_hat.view(-1, 28*28), x.view(-1, 28*28), reduction='sum')
    
    l2m_loss = F.mse_loss(l2m_pool_1, l2m_unpool_1)
    
    lgamma_alpha = torch.lgamma(alpha).to(device)
    lgamma_alpha_hat = torch.lgamma(alpha_hat).to(device)
    digamma_alpha_hat = torch.digamma(alpha_hat).to(device)
    
    kld = torch.sum(lgamma_alpha - lgamma_alpha_hat + (alpha_hat - alpha) * digamma_alpha_hat)
    
    total_loss = ll_loss + l2m_loss + kld

In [13]:
def update_alpha_mme(z):
    dirichlet = torch.distributions.Dirichlet(z)
    p_set = dirichlet.sample()
    N, K = p_set.size()

    mu1_tilde = torch.mean(p_set, axis=0)
    mu2_tilde = torch.mean(torch.pow(p_set,2), axis=0)

    S = 1/K * torch.sum((mu1_tilde-mu2_tilde) / (mu2_tilde-torch.pow(mu1_tilde,2)))

    alpha = S/N * torch.sum(p_set, axis=0)
    
    return alpha

In [14]:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()),
                                           batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
                                          batch_size=100, shuffle=True)
cuda = torch.cuda.is_available()

device = torch.device("cuda" if cuda else "cpu")

latent_dim = 50

model = SWWAE(latent_dim).to(device)

params = model.parameters()
optimizer = optim.Adam(params, lr=5e-4)

alpha =  ((1 - 1/latent_dim) * torch.ones(size=(latent_dim,))).to(device)

epochs = 300

scaler = GradScaler()

for epoch in range(epochs):
    model.train()
    for batch_idx, (x, _) in enumerate(train_loader): 
        x = x.to(device)
        optimizer.zero_grad()
        with autocast():
            x_hat, l2m_pool_1, l2m_unpool_1, alpha_hat = model(x)
            loss = loss_fn(x, x_hat, l2m_pool_1, l2m_unpool_1, alpha_hat, alpha)
        scaler.scale(loss).backward()
        #torch.nn.utils.clip_grad_norm_(params, 1.0)
        scaler.step(optimizer)
        scaler.update()
    print(f'loss at end of epoch {epoch}: {loss.item()}')
    
    model.eval()
    with torch.no_grad():
        for i, (val_x, _) in enumerate(test_loader):
            val_x = val_x.to(device)
            val_x_hat, val_l2m_pool_1, val_l2m_unpool_1, val_alpha_hat = model(val_x)
            test_loss = loss_fn(val_x, val_x_hat, val_l2m_pool_1, val_l2m_unpool_1, val_alpha_hat, alpha)
    print(f'test loss at end of epoch {epoch}: {test_loss.item()}')
    
    if epoch == 0:
        print('ORIGINAL')
        plt.imshow(test_loader.dataset[0][0].numpy().reshape(28,28))
        plt.show()
    with torch.no_grad():
        sample = test_loader.dataset[0][0].to(device)
        img, img_l2m_pool_1, img_l2m_unpool_1, img_alpha_hat = model(sample)
    img = torch.sigmoid(img)
    img = img.to('cpu').numpy().reshape(28,28)
    print('RECONSTRUCTED')
    plt.imshow(img)
    plt.show()
    
    if epoch % 50 == 0 and epoch >= 200 and epoch < 299:
        alpha = update_alpha_mme(z)
        print('alpha:', alpha)

AttributeError: module 'torch.nn' has no attribute 'MaxUnpool2D'