In [1]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
try:
    import matplotlib.pyplot as plt
except:
    !pip3 install matplotlib
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data.dataloader import default_collate
from typing import Union, Iterable, List, Dict, Tuple, Optional, cast
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
torch.manual_seed(42)

<torch._C.Generator at 0x7efd4c5fae90>

In [2]:
# z = torch.rand(size=(100,50))

# dirichlet = torch.distributions.Dirichlet(z)
# p_set = dirichlet.sample()
# N, K = p_set.size()

# mu1_tilde = torch.mean(p_set, axis=0)
# mu2_tilde = torch.mean(torch.pow(p_set,2), axis=0)

# S = 1/K * torch.sum((mu1_tilde-mu2_tilde) / (mu2_tilde-torch.pow(mu1_tilde,2)))

# alpha = S/N * torch.sum(p_set, axis=0)

In [3]:
# alpha.size()

In [4]:
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()

        self.dense1 = nn.Linear(784, 500)
        self.dense2 = nn.Linear(500, 500)
        self.dense3 = nn.Linear(500, latent_dim)

    def sample(self, alpha_hat):
        u = torch.rand(size=alpha_hat.size(), requires_grad=False).to(device)
        v = torch.pow(u * alpha_hat * torch.exp(torch.lgamma(alpha_hat)),1.0/alpha_hat)
        z = v / torch.sum(v)
        return z

    def forward(self, x):
        alpha_hat = x.view(-1, 28*28)
        alpha_hat = F.relu(self.dense1(alpha_hat))
        alpha_hat = F.relu(self.dense2(alpha_hat))
        alpha_hat = F.softplus(self.dense3(alpha_hat))
        z = self.sample(alpha_hat)
        return z, alpha_hat

class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.dense1 = nn.Linear(latent_dim, 500)
        self.dense2 = nn.Linear(500, 28*28)

    def forward(self, x):
        x_hat = F.relu(self.dense1(x))
        return nn.Sigmoid()(self.dense2(x_hat))

In [5]:
class DirVAE(nn.Module):
    def __init__(self, latent_dim):
        super(DirVAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
    
    def forward(self, x):
        z, alpha_hat = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, alpha_hat, z

In [6]:
def ELBO(x_hat, x, alpha_hat, alpha, epsilon=5e-16):
    
    #likelihood = F.binary_cross_entropy_with_logits(x_hat, x.view(-1, 28*28), reduction='sum') #remove output func
    #likelihood = torch.sum(torch.pow(x_hat - x.view(-1, 28*28), 2))
    likelihood = torch.abs(torch.sum(x.view(-1, 28*28) * torch.log(x_hat + epsilon) + (1.0 - x.view(-1, 28*28)) * torch.log(1.0-x_hat + epsilon)))
    
    lgamma_alpha = torch.lgamma(alpha).to(device)
    lgamma_alpha_hat = torch.lgamma(alpha_hat).to(device)
    digamma_alpha_hat = torch.digamma(alpha_hat).to(device)
    
    kld = torch.sum(lgamma_alpha - lgamma_alpha_hat + (alpha_hat - alpha) * digamma_alpha_hat)
    
#     if torch.isnan(likelihood):
#         print('LIKELIHOOD IS NAN')
        
#     if torch.isnan(kld):
#         print('KLD IS NAN') 

    return likelihood + kld

In [7]:
def update_alpha_mme(z):
    dirichlet = torch.distributions.Dirichlet(z)
    p_set = dirichlet.sample()
    N, K = p_set.size()

    mu1_tilde = torch.mean(p_set, axis=0)
    mu2_tilde = torch.mean(torch.pow(p_set,2), axis=0)

    S = 1/K * torch.sum((mu1_tilde-mu2_tilde) / (mu2_tilde-torch.pow(mu1_tilde,2)))

    alpha = S/N * torch.sum(p_set, axis=0)
    
    return alpha
    

In [8]:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()),
                                           batch_size=100, shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
                                          batch_size=100, shuffle=True, collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x)))
cuda = torch.cuda.is_available()

device = torch.device("cuda" if cuda else "cpu")

# nll_loss = nn.NLLLoss()

latent_dim = 50

model = DirVAE(latent_dim).to(device)

params = model.parameters()
optimizer = optim.Adam(params, lr=5e-4)

alpha =  ((1 - 1/latent_dim) * torch.ones(size=(latent_dim,))).to(device)

epochs = 1000

scaler = GradScaler()

for epoch in range(epochs):
    model.train()
    for batch_idx, (x, _) in enumerate(train_loader): 
        #x = x.to(device)
        optimizer.zero_grad()
        with autocast():
            x_hat, alpha_hat, z = model(x)
            loss = ELBO(x_hat, x, alpha_hat, alpha)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(params, 750.0)
        scaler.step(optimizer)
        scaler.update()

     # = [p.grad for p in if p.grad is not None]
    temp1 = model.encoder.dense1.weight.grad
    temp2 = model.encoder.dense2.weight.grad
    temp3 = model.encoder.dense3.weight.grad
    print(dir(params))
    print(params.next())
    
    print([torch.linalg.vector_norm(temp1).cpu().numpy().tolist(),
           torch.linalg.vector_norm(temp2).cpu().numpy().tolist(),
           torch.linalg.vector_norm(temp3).cpu().numpy().tolist()])
    print(f'loss at end of epoch {epoch}: {loss.item()}')
    
    model.eval()
    with torch.no_grad():
        for i, (val_x, _) in enumerate(test_loader):
            val_x = val_x.to(device)
            val_x_hat, val_alpha_hat, val_z = model(val_x)
            test_loss = ELBO(val_x_hat, val_x, val_alpha_hat, alpha)
    print(f'test loss at end of epoch {epoch}: {test_loss.item()}')
    
    if epoch == 0:
        print('ORIGINAL')
        plt.imshow(test_loader.dataset[0][0].numpy().reshape(28,28))
        plt.show()
    with torch.no_grad():
        sample = test_loader.dataset[0][0].to(device)
        img, img_alpha_hat, img_z = model(sample)
    img = torch.sigmoid(img)
    img = img.to('cpu').numpy().reshape(28,28)
    print('RECONSTRUCTED')
    plt.imshow(img)
    plt.show()
#     if epoch % 50 == 0 and epoch >= 200 and epoch < 299:
#         alpha = update_alpha_mme(z)
#         print('alpha:', alpha)

['__class__', '__del__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__lt__', '__name__', '__ne__', '__new__', '__next__', '__qualname__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'close', 'gi_code', 'gi_frame', 'gi_running', 'gi_yieldfrom', 'send', 'throw']


AttributeError: 'generator' object has no attribute 'next'