# Semi-supervised learning with M2 model

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 128
epochs = 10
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# https://github.com/wohlert/semi-supervised-pytorch/blob/master/examples/notebooks/datautils.py

from functools import reduce
from operator import __or__
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
import numpy as np

labels_per_class = 10
n_labels = 10

root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])

mnist_train = MNIST(root=root, train=True, download=True, transform=transform)
mnist_valid = MNIST(root=root, train=False, transform=transform)


def get_sampler(labels, n=None):
    # Only choose digits in n_labels
    (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(n_labels)]))

    # Ensure uniform distribution of labels
    np.random.shuffle(indices)
    indices = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[:n] for i in range(n_labels)])

    indices = torch.from_numpy(indices)
    sampler = SubsetRandomSampler(indices)
    return sampler


# Dataloaders for MNIST
kwargs = {'num_workers': 1, 'pin_memory': True}
labelled = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                       sampler=get_sampler(mnist_train.targets.numpy(), labels_per_class),
                                       **kwargs)
unlabelled = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                         sampler=get_sampler(mnist_train.targets.numpy()), **kwargs)
validation = torch.utils.data.DataLoader(mnist_valid, batch_size=batch_size,
                                         sampler=get_sampler(mnist_valid.targets.numpy()), **kwargs)


In [3]:
from pixyz.distributions import Normal, Bernoulli, RelaxedCategorical, Categorical
from pixyz.models import Model
from pixyz.losses import ELBO
from pixyz.utils import print_latex

In [4]:
x_dim = 784
y_dim = 10
z_dim = 64


# inference model q(z|x,y)
class Inference(Normal):
    def __init__(self):
        super().__init__(cond_var=["x","y"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim+y_dim, 512)
        self.fc21 = nn.Linear(512, z_dim)
        self.fc22 = nn.Linear(512, z_dim)

    def forward(self, x, y):
        h = F.relu(self.fc1(torch.cat([x, y], 1)))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

    
# generative model p(x|z,y)    
class Generator(Bernoulli):
    def __init__(self):
        super().__init__(cond_var=["z","y"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim+y_dim, 512)
        self.fc2 = nn.Linear(512, x_dim)

    def forward(self, z, y):
        h = F.relu(self.fc1(torch.cat([z, y], 1)))
        return {"probs": torch.sigmoid(self.fc2(h))}


# classifier p(y|x)
class Classifier(RelaxedCategorical):
    def __init__(self):
        super(Classifier, self).__init__(cond_var=["x"], var=["y"], name="p")
        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, y_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.softmax(self.fc2(h), dim=1)
        return {"probs": h}


# prior model p(z)
prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)

# distributions for supervised learning
p = Generator().to(device)
q = Inference().to(device)
f = Classifier().to(device)
p_joint = p * prior

In [5]:
print(p_joint)
print_latex(p_joint)

Distribution:
  p(x,z|y) = p(x|z,y)p_{prior}(z)
Network architecture:
  p_{prior}(z):
  Normal(
    name=p_{prior}, distribution_name=Normal,
    var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([64])
    (loc): torch.Size([1, 64])
    (scale): torch.Size([1, 64])
  )
  p(x|z,y):
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z', 'y'], input_var=['z', 'y'], features_shape=torch.Size([])
    (fc1): Linear(in_features=74, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

In [6]:
print(q)
print_latex(q)

Distribution:
  q(z|x,y)
Network architecture:
  Inference(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc1): Linear(in_features=794, out_features=512, bias=True)
    (fc21): Linear(in_features=512, out_features=64, bias=True)
    (fc22): Linear(in_features=512, out_features=64, bias=True)
  )


<IPython.core.display.Math object>

In [7]:
print(f)
print_latex(f)

Distribution:
  p(y|x)
Network architecture:
  Classifier(
    name=p, distribution_name=RelaxedCategorical,
    var=['y'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (temperature): torch.Size([1])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=10, bias=True)
  )


<IPython.core.display.Math object>

In [8]:
# distributions for unsupervised learning
_q_u = q.replace_var(x="x_u", y="y_u")
p_u = p.replace_var(x="x_u", y="y_u")
f_u = f.replace_var(x="x_u", y="y_u")

q_u = _q_u * f_u
p_joint_u = p_u * prior

p_joint_u.to(device)
q_u.to(device)
f_u.to(device)

print(p_joint_u)
print_latex(p_joint_u)

Distribution:
  p(z,x_{u}|y_{u}) = p(x_{u}|z,y_{u})p_{prior}(z)
Network architecture:
  p_{prior}(z):
  Normal(
    name=p_{prior}, distribution_name=Normal,
    var=['z'], cond_var=[], input_var=[], features_shape=torch.Size([64])
    (loc): torch.Size([1, 64])
    (scale): torch.Size([1, 64])
  )
  p(x_{u}|z,y_{u}) -> p(x|z,y):
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z', 'y'], input_var=['z', 'y'], features_shape=torch.Size([])
    (fc1): Linear(in_features=74, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=784, bias=True)
  )


<IPython.core.display.Math object>

In [9]:
print(q_u)
print_latex(q_u)

Distribution:
  q(z,y_{u}|x_{u}) = q(z|x_{u},y_{u})p(y_{u}|x_{u})
Network architecture:
  p(y_{u}|x_{u}) -> p(y|x):
  Classifier(
    name=p, distribution_name=RelaxedCategorical,
    var=['y'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (temperature): torch.Size([1])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=10, bias=True)
  )
  q(z|x_{u},y_{u}) -> q(z|x,y):
  Inference(
    name=q, distribution_name=Normal,
    var=['z'], cond_var=['x', 'y'], input_var=['x', 'y'], features_shape=torch.Size([])
    (fc1): Linear(in_features=794, out_features=512, bias=True)
    (fc21): Linear(in_features=512, out_features=64, bias=True)
    (fc22): Linear(in_features=512, out_features=64, bias=True)
  )


<IPython.core.display.Math object>

In [10]:
print(f_u)
print_latex(f_u)

Distribution:
  p(y_{u}|x_{u})
Network architecture:
  p(y_{u}|x_{u}) -> p(y|x):
  Classifier(
    name=p, distribution_name=RelaxedCategorical,
    var=['y'], cond_var=['x'], input_var=['x'], features_shape=torch.Size([])
    (temperature): torch.Size([1])
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=10, bias=True)
  )


<IPython.core.display.Math object>

In [11]:
elbo_u = ELBO(p_joint_u, q_u)
elbo = ELBO(p_joint, q)
nll = -f.log_prob() # or -LogProb(f)

rate = 1 * (len(unlabelled) + len(labelled)) / len(labelled)

loss_cls = -elbo_u.mean() -elbo.mean() + (rate * nll).mean() 
print(loss_cls)
print_latex(loss_cls)

mean \left(- 470.0 \log p(y|x) \right) - mean \left(\mathbb{E}_{q(z,y_{u}|x_{u})} \left[\log p(z,x_{u}|y_{u}) - \log q(z,y_{u}|x_{u}) \right] \right) - mean \left(\mathbb{E}_{q(z|x,y)} \left[\log p(x,z|y) - \log q(z|x,y) \right] \right)


<IPython.core.display.Math object>

In [12]:
model = Model(loss_cls,test_loss=nll.mean(),
              distributions=[p, q, f], optimizer=optim.Adam, optimizer_params={"lr":1e-3})
print(model)
print_latex(model)

Distributions (for training): 
  p(x|z,y), q(z|x,y), p(y|x) 
Loss function: 
  mean \left(- 470.0 \log p(y|x) \right) - mean \left(\mathbb{E}_{q(z,y_{u}|x_{u})} \left[\log p(z,x_{u}|y_{u}) - \log q(z,y_{u}|x_{u}) \right] \right) - mean \left(\mathbb{E}_{q(z|x,y)} \left[\log p(x,z|y) - \log q(z|x,y) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.001
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [13]:
def train(epoch):
    train_loss = 0
    for x_u, y_u in tqdm(unlabelled):
        x, y = iter(labelled).next()
        x = x.to(device)
        y = torch.eye(10)[y].to(device)
        x_u = x_u.to(device)
        loss = model.train({"x": x, "y": y, "x_u": x_u})
        train_loss += loss
        
    train_loss = train_loss * unlabelled.batch_size / len(unlabelled.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    
    return train_loss

In [14]:
def test(epoch):
    test_loss = 0
    correct = 0
    total = 0    
    for x, y in validation:
        x = x.to(device)
        y = torch.eye(10)[y].to(device)        
        loss = model.test({"x": x, "y": y})
        test_loss += loss
        
        pred_y = f.sample_mean({"x": x})
        total += y.size(0)
        correct += (pred_y.argmax(dim=1) == y.argmax(dim=1)).sum().item()      

    test_loss = test_loss * validation.batch_size / len(validation.dataset)
    test_accuracy = 100 * correct / total
    print('Test loss: {:.4f}, Test accuracy: {:.4f}'.format(test_loss, test_accuracy))
    return test_loss, test_accuracy

In [15]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [16]:
import pixyz
v = pixyz.__version__
writer = SummaryWriter("runs/" + v + ".m2"  + exp_time)

import time
start = time.time()
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss, test_accuracy = test(epoch)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)
    writer.add_scalar('test_accuracy', test_accuracy, epoch)    
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time)
writer.close()

100%|██████████| 469/469 [01:09<00:00,  6.77it/s]

Epoch: 1 Train loss: 341.2683



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 1.8495, Test accuracy: 71.9900


100%|██████████| 469/469 [01:10<00:00,  6.66it/s]

Epoch: 2 Train loss: 215.2050



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.1879, Test accuracy: 70.2200


100%|██████████| 469/469 [01:13<00:00,  6.42it/s]


Epoch: 3 Train loss: 199.4764


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.2798, Test accuracy: 71.2500


100%|██████████| 469/469 [01:14<00:00,  6.27it/s]

Epoch: 4 Train loss: 192.5271



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7242, Test accuracy: 69.5000


100%|██████████| 469/469 [01:13<00:00,  6.38it/s]

Epoch: 5 Train loss: 188.6577



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.7264, Test accuracy: 71.0500


100%|██████████| 469/469 [01:13<00:00,  6.36it/s]

Epoch: 6 Train loss: 186.1241



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.8207, Test accuracy: 72.5300


100%|██████████| 469/469 [01:13<00:00,  6.36it/s]


Epoch: 7 Train loss: 184.3111


  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.9221, Test accuracy: 72.4700


100%|██████████| 469/469 [01:13<00:00,  6.40it/s]

Epoch: 8 Train loss: 182.8273



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 3.0312, Test accuracy: 74.4200


100%|██████████| 469/469 [01:13<00:00,  6.38it/s]

Epoch: 9 Train loss: 181.6874



  0%|          | 0/469 [00:00<?, ?it/s]

Test loss: 2.5889, Test accuracy: 77.2500


100%|██████████| 469/469 [01:13<00:00,  6.40it/s]


Epoch: 10 Train loss: 180.7162
Test loss: 2.5414, Test accuracy: 77.8800
