In [47]:
#################### Import Modules ####################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import FashionMNIST, MNIST
from torchvision.transforms import ToTensor
from sklearn.svm import SVC
from sklearn.utils import check_random_state
import pickle
from prettytable import PrettyTable


# Set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: ", device)
# Set the device globally
torch.set_default_device(device)

Device:  cpu


In [48]:
##################### Random Seed ######################
torch.manual_seed(42)
random_state = check_random_state(42)

In [49]:
###################### Load Data #######################
def load_data(labeled_size=100, batch_size=64, dataset="FashionMNIST"):
    if dataset == "FashionMNIST":
        train_dataset = FashionMNIST(root='./data', train=True, transform=ToTensor(), download=True)
        test_dataset = FashionMNIST(root='./data', train=False, transform=ToTensor())
    elif dataset == "MNIST":
        train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
        test_dataset = MNIST(root='./data', train=False, transform=ToTensor())
    unlabeled_size = len(train_dataset) - labeled_size
    ul_train_dataset, l_train_dataset = random_split(train_dataset, [unlabeled_size, labeled_size])

    ul_train_loader = DataLoader(ul_train_dataset, batch_size=batch_size, shuffle=True)
    l_train_loader = DataLoader(l_train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return ul_train_loader, l_train_loader, test_loader

3. Semi-Supervised Variational Autoencoder
http://arxiv.org/abs/1406.5298

In [50]:
##################### Build Model ######################
class M1(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(M1, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, 2 * latent_dim),  # Two outputs for mean and log variance
            nn.Softplus()
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Assuming input values are normalized between 0 and 1
        )

    def forward(self, x):
        # Encode
        h = self.encoder(x)
        mu, log_var = torch.chunk(h, 2, dim=1)
        z = self.reparameterize(mu, log_var)

        # Decode
        x_recon = self.decoder(z)

        return x_recon, mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

class SVM:
    def __init__(self, kernel):
        self.model = None
        self.kernel = kernel

    def train(self, X, y):
        self.model = SVC(kernel=self.kernel)
        self.model.fit(X, y)

    def predict(self, X):
        return self.model.predict(X)

In [51]:
####################### Helpers ########################
def compute_loss(recon_x, x, mu, logvar):
    # Reconstruction loss
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

def extract_m1_features(model, train_loader, device):
    X_train_features = []
    y_train_labels = []

    with torch.no_grad():
        for data, labels in train_loader:
            data = data.view(-1, 784).to(device)
            recon_x, mu, logvar = model(data)
            X_train_features.append(mu)
            y_train_labels.append(labels)

    X_train_features = torch.cat(X_train_features)
    y_train_labels = torch.cat(y_train_labels)

    return X_train_features, y_train_labels

def train_svm(svm, X_train_features, y_train_labels):
    svm.train(X_train_features, y_train_labels)

def evaluate_svm(model, svm, test_loader, device):
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.view(-1, 784).to(device)
            recon_x, mu, logvar = model(data)
            features = mu.cpu().numpy()
            predicted_labels = svm.predict(features)
            total += labels.size(0)
            correct += (predicted_labels == labels.numpy()).sum().item()

    return correct, total

In [52]:
##################### Train Model ######################
def train_m1_model(model, optimizer, train_loader, device, num_epochs, log_interval):
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.view(-1, 784).to(device)

            recon_x, mu, logvar = model(data)

            # Compute loss and optimize
            loss = compute_loss(recon_x, data, mu, logvar)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))

In [53]:
################### Hyperparameters ####################
input_dim = 784
hidden_dim = 600
latent_dim = 50
batch_size = 64
learning_rate = 3e-4
momentum = 0.1
num_epochs = 100
log_interval = 100
N_labeled = (100, 600, 1000, 3000)
kernel = 'rbf'

In [54]:
######################## Main ##########################
results = []
for dataset in ["FashionMNIST", "MNIST"]:
    for n in N_labeled:
        # Initialize M1 model, Transductive SVM, optimizer, and data loaders
        m1_model = M1(input_dim, hidden_dim, latent_dim)
        svm = SVM(kernel)
        optimizer = torch.optim.RMSprop(m1_model.parameters(), lr=learning_rate, momentum=momentum)

        ul_train_loader, l_train_loader, test_loader = load_data(labeled_size=n, batch_size=batch_size, dataset=dataset)
        # Train M1 model
        train_m1_model(m1_model, optimizer, ul_train_loader, device, num_epochs, log_interval)
        torch.save(m1_model.state_dict(), f'./checkpoints/M1_+_SVM/{dataset}/M1_{n}labeled.pth')

        # Extract features from M1 model for SVM
        X_train_features, y_train_labels = extract_m1_features(m1_model, l_train_loader, device)

        # Train SVM
        train_svm(svm, X_train_features, y_train_labels)
        with open(f'./checkpoints/M1_+_SVM/{dataset}/svm_{n}labeled.pkl','wb') as f:
            pickle.dump(svm,f)

        # Evaluate SVM
        correct, total = evaluate_svm(m1_model, svm, test_loader, device)
        accuracy = 100 * correct / total
        print(f'Accuracy of the SVM on the test images with {n} labeled images: {accuracy}%')

        results.append([dataset, n, accuracy, round(100 - accuracy,2)])
        

Epoch [1/100], Step [1/936], Loss: 38252.8516
Epoch [1/100], Step [101/936], Loss: 21494.0391
Epoch [1/100], Step [201/936], Loss: 19981.9492
Epoch [1/100], Step [301/936], Loss: 21149.5254
Epoch [1/100], Step [401/936], Loss: 20780.9219
Epoch [1/100], Step [501/936], Loss: 19600.6875
Epoch [1/100], Step [601/936], Loss: 19764.6328
Epoch [1/100], Step [701/936], Loss: 18280.0723
Epoch [1/100], Step [801/936], Loss: 18461.3223
Epoch [1/100], Step [901/936], Loss: 18068.8477
Epoch [2/100], Step [1/936], Loss: 18606.6367
Epoch [2/100], Step [101/936], Loss: 17816.9629
Epoch [2/100], Step [201/936], Loss: 19391.3887
Epoch [2/100], Step [301/936], Loss: 18458.3906
Epoch [2/100], Step [401/936], Loss: 18550.2168
Epoch [2/100], Step [501/936], Loss: 17516.8926
Epoch [2/100], Step [601/936], Loss: 18852.3008
Epoch [2/100], Step [701/936], Loss: 17672.6562
Epoch [2/100], Step [801/936], Loss: 17923.3203
Epoch [2/100], Step [901/936], Loss: 18566.1211
Epoch [3/100], Step [1/936], Loss: 18058.544

In [None]:
#################### Print Results #####################
def print_results(results):
    table = PrettyTable()
    table.field_names = ["Dataset", "N", "Accuracy", "Percentage Error"]
    for r in results:
        table.add_row(r)

    print("Results table for M1 model + SVM")
    print(table)

print_results(results)

In [56]:
##################### Test Model #######################
def test():
    results = []
    for dataset in ["FashionMNIST", "MNIST"]:
        for n in N_labeled:
            ul_train_loader, l_train_loader, test_loader = load_data(labeled_size=n, batch_size=batch_size, dataset=dataset)

            m1_model = M1(input_dim, hidden_dim, latent_dim)
            svm = SVM(kernel)
            optimizer = torch.optim.RMSprop(m1_model.parameters(), lr=learning_rate, momentum=momentum)

            m1_model.load_state_dict(torch.load(f'./checkpoints/M1_+_SVM/{dataset}/M1_{n}labeled.pth'))
            m1_model.eval()
            with open(f'./checkpoints/M1_+_SVM/{dataset}/svm_{n}labeled.pkl', 'rb') as f:
                svm = pickle.load(f)

            correct, total = evaluate_svm(m1_model, svm, test_loader, device)
            accuracy = 100 * correct / total
            print(f'Accuracy of the SVM on the test images with {n} labeled images: {accuracy}%')

            results.append([dataset, n, accuracy, round(100 - accuracy,2)])
    
    print_results(results)
    

test()

Accuracy of the SVM on the test images with 100 labeled images: 65.09%
Accuracy of the SVM on the test images with 600 labeled images: 77.79%
Accuracy of the SVM on the test images with 1000 labeled images: 78.93%
Accuracy of the SVM on the test images with 3000 labeled images: 81.66%
Accuracy of the SVM on the test images with 100 labeled images: 76.66%
Accuracy of the SVM on the test images with 600 labeled images: 92.69%
Accuracy of the SVM on the test images with 1000 labeled images: 93.59%
Accuracy of the SVM on the test images with 3000 labeled images: 95.59%
Results table for M1 model + SVM
+--------------+------+----------+------------------+
|   Dataset    |  N   | Accuracy | Percentage Error |
+--------------+------+----------+------------------+
| FashionMNIST | 100  |  65.09   |      34.91       |
| FashionMNIST | 600  |  77.79   |      22.21       |
| FashionMNIST | 1000 |  78.93   |      21.07       |
| FashionMNIST | 3000 |  81.66   |      18.34       |
|    MNIST     | 