In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision
%matplotlib inline

def imshow(img, mean=0.1307, std=0.3081):
    """Display a normalized (C, H, W) image tensor, e.g. a `make_grid` output, with matplotlib.

    Args:
        img: torch tensor of shape (C, H, W) that was normalized as (x - mean) / std.
        mean: mean used by the Normalize transform; defaults to the MNIST value used in this notebook.
        std: std used by the Normalize transform; defaults to the MNIST value used in this notebook.
    """
    # undo Normalize, then clamp so matplotlib receives floats in its valid [0, 1] range
    img = (img * std + mean).clamp(0, 1)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC, as matplotlib expects
    plt.show()

In [None]:
class Net(nn.Module):
    """Small CNN for single-channel 28x28 images that outputs 10 unnormalized class logits.

    Layout: conv(1->32) -> ReLU -> conv(32->64) -> ReLU -> 2x2 max-pool -> dropout
    -> flatten -> fc(9216->128) -> ReLU -> dropout -> fc(128->10).
    The hard-coded `fc1` input size (64*12*12) ties the network to 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)  # no padding so we lose 2 pixels
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)  # no padding so we lose 2 pixels
        # channel-wise dropout on the 4-D pooled feature maps
        self.dropout1 = nn.Dropout2d(p=0.25)
        # dropout2 acts on the 2-D (batch, 128) fc1 output; Dropout2d is meant for 3-D/4-D input
        # and warns on 2-D input, so use plain element-wise Dropout here
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=64*12*12, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) raw logits (pair with F.cross_entropy)."""
        x = self.conv1(x)  # input becomes 32x26x26
        x = F.relu(x)
        x = self.conv2(x)  # input becomes 64x24x24
        x = F.relu(x)  # nonlinearity after conv2, matching conv1
        x = F.max_pool2d(x, kernel_size=2)  # input becomes 64x12x12
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # flatten into a single 9,216 dim vector
        x = self.fc1(x)  # input goes from 9216 to 128
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)  # input goes from 128 to 10
        return x

In [None]:
SEED = 0
# Seed torch's global RNG before any randomness (weight init, dropout, DataLoader shuffling,
# augmentation noise) so Restart & Run All reproduces the same results.
torch.manual_seed(SEED)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-3)  # plain SGD, no momentum

In [None]:
batch = 30  # size of the single labeled batch and of every unlabeled batch drawn later
nb_digits = 10  # number of classes (digits 0-9)

# define a transformation pipeline for our raw data: MNIST is single-channel, so a one-value mean/std
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# NOTE: the labeled/unlabeled pool is the MNIST *test* split (train=False); evaluation draws from
# the train split, so the data the model sees during training and at evaluation do not overlap.
dataset = torchvision.datasets.MNIST('../data', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
dataiter = iter(trainloader)

# The first batch is the only labeled data; later batches from `dataiter` are used as unlabeled.
# DataLoader iterators no longer provide `.next()`; use the built-in `next()`.
X_train, y_train = next(dataiter)
imshow(torchvision.utils.make_grid(X_train))

In [None]:
# Evaluation data: the MNIST *train* split (train=True), disjoint from the test split used as the
# labeled/unlabeled pool. MNIST is single-channel, so normalize with a one-value mean/std.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset = torchvision.datasets.MNIST('../data', train=True, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)
testdataiter = iter(testloader)

In [None]:
n_epochs = 350
ratio = 5  # supervised steps on the labeled batch per epoch
threshold = .2  # minimum softmax probability for a pseudo label to be trained on
accuracy = []


def _next_batch(iterator, loader):
    """Return ``(iterator, batch)``, restarting from ``loader`` once ``iterator`` is exhausted.

    The unlabeled pool can run out before ``n_epochs`` epochs (e.g. 10k images at batch size 30
    give ~334 batches), which would otherwise raise StopIteration mid-training.
    """
    try:
        return iterator, next(iterator)
    except StopIteration:
        iterator = iter(loader)
        return iterator, next(iterator)


for i in range(n_epochs):
    model.train()  # the evaluation step below switches to eval mode; re-enable dropout

    for _ in range(ratio):
        # step 1: train on labeled data, which is never updated
        optimizer.zero_grad()
        output = model(X_train)
        loss = F.cross_entropy(output, y_train)
        loss.backward()
        optimizer.step()

    # step 2: view probabilities on unlabeled data (its labels are discarded)
    dataiter, (X_unlabeled, _) = _next_batch(dataiter, trainloader)
    with torch.no_grad():  # pseudo labels are fixed targets, so no autograd graph is needed
        probs = F.softmax(model(X_unlabeled), dim=1)
    confidence, pseudo_labels = probs.max(dim=1)

    # step 3: train against pseudo labels only if it's high confidence
    for p, label, x in zip(confidence, pseudo_labels, X_unlabeled):
        if p > threshold:
            # step 4: generate strong augmentation data only if we'll use it
            x = x.unsqueeze(0)  # add a batch dimension
            X_strong = x + torch.rand_like(x)  # additive uniform [0, 1) noise

            # step 5: learn against pseudo labels
            optimizer.zero_grad()
            loss = F.cross_entropy(model(X_strong), label.unsqueeze(0))
            loss.backward()
            optimizer.step()

    # test: eval mode (dropout off) and no autograd, so accuracy reflects the deterministic model
    testdataiter, (X_test, y_test) = _next_batch(testdataiter, testloader)
    model.eval()
    with torch.no_grad():
        pred = model(X_test).argmax(dim=1)
    correct = pred.eq(y_test).sum().item()
    accuracy.append(correct / len(y_test))
    # integer interval avoids float-modulo surprises when n_epochs is not a multiple of 50
    if i % max(1, n_epochs // 50) == 0:
        print("epoch:", i, "accuracy", correct / len(y_test))
    

In [None]:
# Per-epoch accuracy on a small evaluation batch, using the explicit Axes interface.
fig, ax = plt.subplots()
ax.set_title("learning for a handful of labeled data")
ax.set_xlabel("epoch")
ax.set_ylabel("accuracy")
ax.plot(accuracy);