In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

# Training hyper-parameters: mini-batch size and number of passes over the data.
batch_size, epochs = 128, 10

# Seed PyTorch's global RNG so weight initialisation and torch-side sampling
# start from the same state on every run.
seed = 1
torch.manual_seed(seed)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
from Tars.distributions import Normal, Bernoulli, RelaxedCategorical, Categorical
from Tars.models import CustomLossModel
from Tars.losses import ELBO, NLL

In [3]:
# https://github.com/wohlert/semi-supervised-pytorch/blob/master/examples/notebooks/datautils.py

from functools import reduce
from operator import __or__
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
import numpy as np
from itertools import cycle

# Number of labelled examples kept for each class when building the
# labelled loader (passed to get_sampler as `n`).
labels_per_class = 10
# Number of classes; get_sampler keeps only labels in range(n_labels).
n_labels = 10

# get_sampler() shuffles candidate indices with NumPy's global RNG. Only the
# torch RNG is seeded in the first cell, so seed NumPy here as well; otherwise
# a different labelled subset is drawn on every kernel restart.
np.random.seed(seed)

# MNIST train / test splits as tensors (ToTensor), stored under ../data.
# Only the training split requests a download.
mnist_train = MNIST('../data', train=True, download=True,
                    transform=transforms.ToTensor())
mnist_valid = MNIST('../data', train=False, transform=transforms.ToTensor())

def get_sampler(labels, n=None, n_classes=None):
    """Build a SubsetRandomSampler over a class-filtered subset of a dataset.

    Args:
        labels: 1-D sequence/array of integer class labels, one per dataset item.
        n: maximum number of indices to keep per class; ``None`` keeps all of them.
        n_classes: only labels in ``range(n_classes)`` are kept. Defaults to the
            module-level ``n_labels`` when ``None``.

    Returns:
        A ``SubsetRandomSampler`` over the selected dataset indices.

    Note:
        Candidate indices are shuffled with NumPy's global RNG, so which
        examples are picked per class depends on ``np.random``'s state.
    """
    if n_classes is None:
        n_classes = n_labels
    labels = np.asarray(labels)

    # Only choose items whose label is one of the first n_classes classes
    (indices,) = np.where(np.isin(labels, np.arange(n_classes)))

    # Shuffle first so the per-class slice below picks a random subset
    np.random.shuffle(indices)

    # Boolean masking keeps the integer dtype even when a class has no
    # examples (an empty Python list would promote the stacked array to
    # float64 and yield float indices), and avoids a Python-level filter per class.
    shuffled_labels = labels[indices]
    indices = np.hstack([indices[shuffled_labels == i][:n] for i in range(n_classes)])

    indices = torch.from_numpy(indices)
    sampler = SubsetRandomSampler(indices)
    return sampler

# Dataloaders for MNIST
# Options shared by all three loaders.
kwargs = dict(num_workers=1, pin_memory=True)

# Label array of the training split, reused by both training samplers.
train_targets = mnist_train.train_labels.numpy()

# Labelled loader: at most `labels_per_class` examples per class.
labelled_sampler = get_sampler(train_targets, labels_per_class)
labelled = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, sampler=labelled_sampler, **kwargs)

# Unlabelled loader: every training example whose label passes get_sampler's filter.
unlabelled_sampler = get_sampler(train_targets)
unlabelled = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, sampler=unlabelled_sampler, **kwargs)

# Validation loader over the test split, built the same way.
validation_sampler = get_sampler(mnist_valid.test_labels.numpy())
validation = torch.utils.data.DataLoader(
    mnist_valid, batch_size=batch_size, sampler=validation_sampler, **kwargs)


In [4]:
# Size of a flattened input image (28 * 28 MNIST pixels).
x_dim = 28 * 28
# Size of the one-hot label vector.
y_dim = 10
# Size of the latent variable z.
z_dim = 64


# inference model q(z|x,y)
class Inference(Normal):
    """Normal distribution over z conditioned on x and y, named "q".

    A single 512-unit hidden layer feeds two linear heads that produce the
    "loc" and "scale" entries returned by forward().
    """

    def __init__(self):
        super(Inference, self).__init__(cond_var=["x", "y"], var=["z"], name="q")

        hidden_dim = 512
        self.fc1 = nn.Linear(x_dim + y_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x, y):
        """Return {"loc", "scale"} computed from x and y concatenated along dim 1."""
        xy = torch.cat([x, y], 1)
        hidden = F.relu(self.fc1(xy))
        loc = self.fc21(hidden)
        # softplus keeps the scale output positive
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

    
# generative model p(x|z,y)
class Generator(Bernoulli):
    """Bernoulli distribution over x conditioned on z and y, named "p".

    A single 512-unit hidden layer maps the concatenated (z, y) to x_dim
    outputs, which are squashed into (0, 1) and returned as "probs".
    """

    def __init__(self):
        super().__init__(cond_var=["z","y"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim+y_dim, 512)
        self.fc2 = nn.Linear(512, x_dim)

    def forward(self, z, y):
        """Return {"probs"} computed from z and y concatenated along dim 1."""
        h = F.relu(self.fc1(torch.cat([z, y], 1)))
        # torch.sigmoid replaces the deprecated F.sigmoid (same values, no warning)
        return {"probs": torch.sigmoid(self.fc2(h))}

# discriminative model p(y|x)
class Discriminative(RelaxedCategorical):
    """RelaxedCategorical distribution over y conditioned on x, named "p".

    Constructed with temperature 0.1. A single 512-unit hidden layer produces
    y_dim scores that are normalised with a softmax and returned as "probs".
    """

    def __init__(self):
        super().__init__(cond_var=["x"], var=["y"], temperature=0.1, name="p")
        hidden_dim = 512
        self.fc1 = nn.Linear(x_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, y_dim)

    def forward(self, x):
        """Return {"probs"}: softmax over dim 1 of the network output for x."""
        hidden = F.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return {"probs": F.softmax(scores, dim=1)}
    
# prior model p(z)
# Scalar parameters (0 and 1) placed on the selected device; the Normal is
# declared over "z" with dim=z_dim and has no conditioning variables.
loc, scale = (torch.tensor(value).to(device) for value in (0., 1.))
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="prior")

In [5]:
# distributions for supervised learning
# Instantiate the three networks (generator p, inference q, classifier f)
# and combine the generator with the prior into the joint p(x,z|y).
p = Generator()
q = Inference()
f = Discriminative()
p_joint = p * prior

# Move every network to the selected device. The classifier `f` must be moved
# too: previously p_joint was moved twice and f was left out.
p_joint.to(device)
q.to(device)
f.to(device)

print(p_joint)
print(q)
print(f)

Distribution:
  p(x,z|y) = p(x|z,y)prior(z)
Network architecture:
  prior(z) (Normal): Normal()
  p(x|z,y) (Bernoulli): Generator(
    (fc1): Linear(in_features=74, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=784, bias=True)
  )
Distribution:
  q(z|x,y) (Normal)
Network architecture:
  Inference(
    (fc1): Linear(in_features=794, out_features=512, bias=True)
    (fc21): Linear(in_features=512, out_features=64, bias=True)
    (fc22): Linear(in_features=512, out_features=64, bias=True)
  )
Distribution:
  p(y|x) (RelaxedCategorical)
Network architecture:
  Discriminative(
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=10, bias=True)
  )


In [6]:
# distributions for unsupervised learning
# Variable renaming applied to all three networks: x -> x_u, y -> y_u.
unlabelled_vars = {"x": "x_u", "y": "y_u"}

q_z_u = q.replace_var(**unlabelled_vars)
p_u = p.replace_var(**unlabelled_vars)
f_u = f.replace_var(**unlabelled_vars)

# Approximate posterior over (z, y_u): inference network times classifier.
q_u = q_z_u * f_u
# Joint of the renamed generator and the prior.
p_joint_u = p_u * prior

p_joint_u.to(device)
q_u.to(device)
f_u.to(device)

print(p_joint_u)
print(q_u)
print(f_u)

Distribution:
  p(x_u,z|y_u) = p(x_u|z,y_u)prior(z)
Network architecture:
  prior(z) (Normal): Normal()
  p(x_u|z,y_u) (Bernoulli): Generator(
    (fc1): Linear(in_features=74, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=784, bias=True)
  )
Distribution:
  p(z,y_u|x_u) = q(z|x_u,y_u)p(y_u|x_u)
Network architecture:
  p(y_u|x_u) (RelaxedCategorical): Discriminative(
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=10, bias=True)
  )
  q(z|x_u,y_u) (Normal): Inference(
    (fc1): Linear(in_features=794, out_features=512, bias=True)
    (fc21): Linear(in_features=512, out_features=64, bias=True)
    (fc22): Linear(in_features=512, out_features=64, bias=True)
  )
Distribution:
  p(y_u|x_u) (RelaxedCategorical)
Network architecture:
  Discriminative(
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=10, bias=True)
  )


In [7]:
# ELBO terms for the unlabelled and labelled data, and the classifier's NLL.
elbo_u = ELBO(p_joint_u, q_u)
elbo = ELBO(p_joint, q)
nll = NLL(f)

# Weight of the classification term: total number of batches divided by the
# number of labelled batches. The leading 1 looks like a tunable multiplier
# -- TODO confirm intent.
n_labelled_batches = len(labelled)
n_unlabelled_batches = len(unlabelled)
rate = 1 * (n_unlabelled_batches + n_labelled_batches) / n_labelled_batches

# Training objective: negative ELBOs plus the weighted classification loss.
loss_cls = -elbo_u.mean() - elbo.mean() + (rate * nll).mean()
print(loss_cls)

-(mean(E_p(z,y_u|x_u)[log p(x_u,z|y_u)/p(z,y_u|x_u)])) - mean(E_q(z|x,y)[log p(x,z|y)/q(z|x,y)]) + mean(log p(y|x) * 470.0)


In [8]:
# Wrap the objective in a trainable model: loss_cls is optimised with Adam
# (lr=1e-3) over the parameters of p, q and f; the mean NLL of the classifier
# is used as the test loss.
model = CustomLossModel(
    loss_cls,
    test_loss=nll.mean(),
    distributions=[p, q, f],
    optimizer=optim.Adam,
    optimizer_params={"lr": 1e-3},
)
print(model)

Distributions (for training): 
  p(x|z,y), q(z|x,y), p(y|x) 
Loss function: 
  -(mean(E_p(z,y_u|x_u)[log p(x_u,z|y_u)/p(z,y_u|x_u)])) - mean(E_q(z|x,y)[log p(x,z|y)/q(z|x,y)]) + mean(log p(y|x) * 470.0)


In [9]:
def train(epoch):
    """Run one pass over the unlabelled loader and return the scaled loss.

    Each unlabelled batch is paired with a labelled batch; `cycle` replays
    the labelled batches so the pairing lasts for the whole unlabelled loader.

    Args:
        epoch: epoch number, used only in the printed progress message.

    Returns:
        The sum of the values returned by model.train, multiplied by the
        loader's batch size and divided by the dataset size.
    """
    train_loss = 0
    # Identity matrix used to one-hot encode integer labels; built once
    # instead of on every iteration.
    one_hot = torch.eye(y_dim)
    # Labels of the "unlabelled" batch are discarded.
    for (x, y), (x_u, _) in tqdm(zip(cycle(labelled), unlabelled), total=len(unlabelled)):
        # Flatten images to (batch, x_dim) to match the networks' input size.
        x = x.view(-1, x_dim).to(device)
        y = one_hot[y].to(device)
        x_u = x_u.view(-1, x_dim).to(device)
        loss = model.train({"x": x, "y": y, "x_u": x_u})
        train_loss += loss

    train_loss = train_loss * unlabelled.batch_size / len(unlabelled.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))

    return train_loss

In [10]:
def test(epoch):
    """Evaluate on the validation loader; return (scaled loss, accuracy in %).

    Args:
        epoch: epoch number (not used in the body; kept for symmetry with train).

    Returns:
        A tuple ``(test_loss, test_accuracy)``: the sum of the values returned
        by model.test multiplied by the loader's batch size and divided by the
        dataset size, and the percentage of examples whose predicted class
        (argmax of f.sample_mean) equals the true class.
    """
    test_loss = 0
    correct = 0
    total = 0
    # Identity matrix used to one-hot encode integer labels; built once.
    one_hot = torch.eye(y_dim)
    for x, y in validation:
        # Flatten images to (batch, x_dim) to match the networks' input size.
        x = x.view(-1, x_dim).to(device)
        y = one_hot[y].to(device)
        loss = model.test({"x": x, "y": y})
        test_loss += loss

        # Predictions are only compared, never back-propagated through, so
        # skip building the autograd graph.
        with torch.no_grad():
            pred_y = f.sample_mean({"x": x})
        total += y.size(0)
        correct += (pred_y.argmax(dim=1) == y.argmax(dim=1)).sum().item()

    test_loss = test_loss * validation.batch_size / len(validation.dataset)
    test_accuracy = 100 * correct / total
    print('Test loss: {:.4f}, Test accuracy: {:.4f}'.format(test_loss, test_accuracy))
    return test_loss, test_accuracy

In [11]:
# Train for `epochs` epochs, evaluating after each one and logging the
# losses and accuracy to TensorBoard.
writer = SummaryWriter()

try:
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        test_loss, test_accuracy = test(epoch)

        writer.add_scalar('train_loss', train_loss.item(), epoch)
        writer.add_scalar('test_loss', test_loss.item(), epoch)
        writer.add_scalar('test_accuracy', test_accuracy, epoch)
finally:
    # Close the writer even if an epoch raises or the cell is interrupted,
    # so the event file is not left open.
    writer.close()

100%|██████████| 469/469 [00:09<00:00, 48.53it/s]

Epoch: 1 Train loss: 329.3315





Test loss: 0.7689, Test accuracy: 81.4200


100%|██████████| 469/469 [00:09<00:00, 49.66it/s]

Epoch: 2 Train loss: 213.9175





Test loss: 0.8811, Test accuracy: 81.7200


100%|██████████| 469/469 [00:09<00:00, 49.57it/s]


Epoch: 3 Train loss: 197.6596
Test loss: 1.0261, Test accuracy: 81.5900


100%|██████████| 469/469 [00:09<00:00, 47.44it/s]


Epoch: 4 Train loss: 190.5572
Test loss: 1.1654, Test accuracy: 80.8900


100%|██████████| 469/469 [00:09<00:00, 49.62it/s]


Epoch: 5 Train loss: 186.5252
Test loss: 1.1687, Test accuracy: 81.7600


100%|██████████| 469/469 [00:09<00:00, 48.31it/s]


Epoch: 6 Train loss: 183.8721
Test loss: 1.3471, Test accuracy: 81.7100


100%|██████████| 469/469 [00:09<00:00, 51.87it/s]


Epoch: 7 Train loss: 181.9751
Test loss: 1.3701, Test accuracy: 82.4600


100%|██████████| 469/469 [00:08<00:00, 53.95it/s]


Epoch: 8 Train loss: 180.5036
Test loss: 1.6142, Test accuracy: 80.9400


100%|██████████| 469/469 [00:09<00:00, 52.00it/s]

Epoch: 9 Train loss: 179.2395





Test loss: 1.4760, Test accuracy: 83.5400


100%|██████████| 469/469 [00:08<00:00, 53.90it/s]

Epoch: 10 Train loss: 178.2590





Test loss: 1.7052, Test accuracy: 82.4500
