# Training code summary for datasets and classifiers

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append('/content/drive/MyDrive/DL')
# Save your nn_model module(s) in the path added above so they can be imported

# Hyper-parameters shared by every dataset section of this notebook.
batch_size = 256
classification_weight = 8  # multiplier on the cross-entropy term of the joint loss
lr = 1e-4  # learning rate handed to both Adam optimizers
# epoch_num is set per dataset, in the cell that builds its training loader.

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Training on Fashion-MNIST

In [None]:
from nn_model import VAEClassifier, StAEClassifier

epoch_num = 100

# Fashion-MNIST training split. Images are only converted to tensors
# (no normalisation) and served in shuffled mini-batches.
fmnist_train_set = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
train_loader = torch.utils.data.DataLoader(
    fmnist_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Build each classifier exactly once, with the intended hyper-parameters, and
# move it to the training device.
# NOTE(review): input_channels=3 is passed here, yet the evaluation cell plots
# only channel 0 in grayscale -- confirm against nn_model that 3 is really the
# value this dataset needs.
vae_classifier = VAEClassifier(num_feature_map=64, encoder_layer=3, decoder_layers=4, input_channels=3).to(device)
stae_classifier = StAEClassifier(num_feature_map=64, encoder_layer=3, decoder_layers=4, input_channels=3).to(device)

CE_Loss = nn.CrossEntropyLoss()  # mean-reduced classification loss
mseloss = torch.nn.MSELoss()  # not referenced by the training cells visible in this notebook
# One Adam optimizer per model, both using the shared learning rate.
optimizer1 = torch.optim.Adam(vae_classifier.parameters(), lr=lr)
optimizer2 = torch.optim.Adam(stae_classifier.parameters(), lr=lr)

In [None]:
# Jointly train each model on classification + reconstruction.
# CE_Loss is mean-reduced while the reconstruction / KL terms are summed over
# the batch, so the CE term is rescaled by the number of samples actually in
# the batch (the final batch of an epoch can be smaller than batch_size).

# VAE
for epoch in range(epoch_num):
    vae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y, mu, log_var = vae_classifier(data, deterministic=False, classification_only=False)
        # Summed squared reconstruction error.
        recons_loss = torch.sum((x_reconst - data) ** 2)
        # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
        # jointly training
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + recons_loss + kld_loss
        optimizer1.zero_grad()
        loss_val.backward()
        optimizer1.step()

vae_classifier.eval()
torch.save(vae_classifier.state_dict(), './fmnist_vae_clf.pth')
# St-AE
for epoch in range(epoch_num):
    stae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y = stae_classifier(data, classification_only=False)
        # Classification term (rescaled to a per-batch sum) + summed squared reconstruction error.
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + torch.sum((x_reconst - data) ** 2)
        optimizer2.zero_grad()
        loss_val.backward()
        optimizer2.step()

stae_classifier.eval()
torch.save(stae_classifier.state_dict(), './fmnist_stae_clf.pth')

In [None]:
# Switch both models to inference mode before evaluation.
vae_classifier = vae_classifier.eval()
stae_classifier = stae_classifier.eval()
# Test dataset
# download=True (as in the MNIST / SVHN test loaders) so this cell also works
# when the files are not already present under ./data.
test_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),])), batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Evaluate both classifiers on the test loader: print the accuracy, then plot
# the first 10 inputs of the last batch (top row) above their reconstructions
# (bottom row). Inference runs under torch.no_grad() so no autograd graph is
# built, and eval() is (re)asserted so the cell does not rely on earlier cells
# having left the models in inference mode.
stae_classifier.eval()
vae_classifier.eval()

# St-AE
pred_list = []
gt_list = [] # ground truth
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        x_reconst, z, y_test = stae_classifier(data, classification_only=False)
        pred_list += list(y_test.argmax(-1).cpu().numpy())
        gt_list += list(target.cpu().numpy())

acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)
plt.figure(figsize=(20,5))
# Channel 0 of each image, laid side by side along the width axis.
reconst_sample = np.concatenate(x_reconst[:10,0].cpu().numpy(), axis=1)
input_sample = np.concatenate(data[:10,0].cpu().numpy(), axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0), cmap='gray')
plt.show()
# VAE
pred_list = []
gt_list = []
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        x_reconst, z, y_test, mu, log_var = vae_classifier(data, deterministic=True, classification_only=False)
        pred_list += list(y_test.argmax(-1).cpu().numpy())
        gt_list += list(target.cpu().numpy())

acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)
plt.figure(figsize=(20,5))
reconst_sample = np.concatenate(x_reconst[:10,0].cpu().numpy(), axis=1)
input_sample = np.concatenate(data[:10,0].cpu().numpy(), axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0), cmap='gray')
plt.show()

## Training on MNIST

In [None]:
from nn_model import VAEClassifier, StAEClassifier

epoch_num = 100

# MNIST training split. Images are only converted to tensors
# (no normalisation) and served in shuffled mini-batches.
mnist_train_set = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
train_loader = torch.utils.data.DataLoader(
    mnist_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Build each classifier exactly once, with the intended hyper-parameters, and
# move it to the training device.
# NOTE(review): input_channels=3 is passed here, yet the evaluation cell plots
# only channel 0 in grayscale -- confirm against nn_model that 3 is really the
# value this dataset needs.
vae_classifier = VAEClassifier(num_feature_map=64, encoder_layer=3, decoder_layers=4, input_channels=3).to(device)
stae_classifier = StAEClassifier(num_feature_map=64, encoder_layer=3, decoder_layers=4, input_channels=3).to(device)

CE_Loss = nn.CrossEntropyLoss()  # mean-reduced classification loss
mseloss = torch.nn.MSELoss()  # not referenced by the training cells visible in this notebook
# One Adam optimizer per model, both using the shared learning rate.
optimizer1 = torch.optim.Adam(vae_classifier.parameters(), lr=lr)
optimizer2 = torch.optim.Adam(stae_classifier.parameters(), lr=lr)

In [None]:
# Jointly train each model on classification + reconstruction.
# CE_Loss is mean-reduced while the reconstruction / KL terms are summed over
# the batch, so the CE term is rescaled by the number of samples actually in
# the batch (the final batch of an epoch can be smaller than batch_size).

# VAE
for epoch in range(epoch_num):
    vae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y, mu, log_var = vae_classifier(data, deterministic=False, classification_only=False)
        # Summed squared reconstruction error.
        recons_loss = torch.sum((x_reconst - data) ** 2)
        # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
        # jointly training
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + recons_loss + kld_loss
        optimizer1.zero_grad()
        loss_val.backward()
        optimizer1.step()

vae_classifier.eval()
torch.save(vae_classifier.state_dict(), './mnist_vae_clf.pth')
# St-AE
for epoch in range(epoch_num):
    stae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y = stae_classifier(data, classification_only=False)
        # Classification term (rescaled to a per-batch sum) + summed squared reconstruction error.
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + torch.sum((x_reconst - data) ** 2)
        optimizer2.zero_grad()
        loss_val.backward()
        optimizer2.step()

stae_classifier.eval()
torch.save(stae_classifier.state_dict(), './mnist_stae_clf.pth')

In [None]:
# MNIST test split: tensors only, fixed order (no shuffling) for evaluation.
mnist_test_set = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_loader = torch.utils.data.DataLoader(
    mnist_test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Evaluate both classifiers on the test loader: print the accuracy, then plot
# the first 10 inputs of the last batch (top row) above their reconstructions
# (bottom row). Inference runs under torch.no_grad() so no autograd graph is
# built, and eval() is (re)asserted so the cell does not rely on earlier cells
# having left the models in inference mode.
stae_classifier.eval()
vae_classifier.eval()

# St-AE
pred_list = []
gt_list = [] # ground truth
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        x_reconst, z, y_test = stae_classifier(data, classification_only=False)
        pred_list += list(y_test.argmax(-1).cpu().numpy())
        gt_list += list(target.cpu().numpy())

acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)
plt.figure(figsize=(20,5))
# Channel 0 of each image, laid side by side along the width axis.
reconst_sample = np.concatenate(x_reconst[:10,0].cpu().numpy(), axis=1)
input_sample = np.concatenate(data[:10,0].cpu().numpy(), axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0), cmap='gray')
plt.show()
# VAE
pred_list = []
gt_list = []
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        x_reconst, z, y_test, mu, log_var = vae_classifier(data, deterministic=True, classification_only=False)
        pred_list += list(y_test.argmax(-1).cpu().numpy())
        gt_list += list(target.cpu().numpy())

acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)
plt.figure(figsize=(20,5))
reconst_sample = np.concatenate(x_reconst[:10,0].cpu().numpy(), axis=1)
input_sample = np.concatenate(data[:10,0].cpu().numpy(), axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0), cmap='gray')
plt.show()

## Training on SVHN

In [None]:
from nn_model_2 import VAEClassifier, StAEClassifier

epoch_num = 100

# SVHN preprocessing: resize to 32x32, convert to tensor, then normalise each
# of the three channels with mean 0.5 and std 0.5.
svhn_train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
svhn_train_set = datasets.SVHN(
    root='./data',
    split='train',
    download=True,
    transform=svhn_train_transform,
)
# Shuffled mini-batches for training.
train_loader = torch.utils.data.DataLoader(
    svhn_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Build each classifier exactly once, with the intended hyper-parameters, and
# move it to the training device.
vae_classifier = VAEClassifier(num_feature_map=64, encoder_layer=3, decoder_layers=4, input_channels=3).to(device)
stae_classifier = StAEClassifier(num_feature_map=64, encoder_layer=3, decoder_layers=4, input_channels=3).to(device)

CE_Loss = nn.CrossEntropyLoss()  # mean-reduced classification loss
mseloss = torch.nn.MSELoss()  # not referenced by the training cells visible in this notebook
# One Adam optimizer per model, both using the shared learning rate.
optimizer1 = torch.optim.Adam(vae_classifier.parameters(), lr=lr)
optimizer2 = torch.optim.Adam(stae_classifier.parameters(), lr=lr)

In [None]:
# Jointly train each model on classification + reconstruction.
# CE_Loss is mean-reduced while the reconstruction / KL terms are summed over
# the batch, so the CE term is rescaled by the number of samples actually in
# the batch (the final batch of an epoch can be smaller than batch_size).

# VAE
for epoch in range(epoch_num):
    vae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y, mu, log_var = vae_classifier(data, deterministic=False, classification_only=False)
        # Summed squared reconstruction error.
        recons_loss = torch.sum((x_reconst - data) ** 2)
        # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
        # jointly training
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + recons_loss + kld_loss
        optimizer1.zero_grad()
        loss_val.backward()
        optimizer1.step()

vae_classifier.eval()
torch.save(vae_classifier.state_dict(), './svhn_vae_clf.pth')
# St-AE
for epoch in range(epoch_num):
    stae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y = stae_classifier(data, classification_only=False)
        # Classification term (rescaled to a per-batch sum) + summed squared reconstruction error.
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + torch.sum((x_reconst - data) ** 2)
        optimizer2.zero_grad()
        loss_val.backward()
        optimizer2.step()

stae_classifier.eval()
torch.save(stae_classifier.state_dict(), './svhn_stae_clf.pth')


In [None]:
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.SVHN(root='./data', split='test', download=True,transform=transforms.Compose([
        transforms.ToTensor(),])), batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# St-AE
pred_list = []
gt_list = [] # ground truth
for batch_idx, (data, target) in enumerate(test_loader): 
    data, target = data.to(device), target.to(device)
    x_reconst, z, y_test = stae_classifier(data, classification_only=False)
    pred_list += list(y_test.argmax(-1).cpu().detach().numpy())
    gt_list += list(target.detach().cpu().numpy())
acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)

plt.figure(figsize=(20,5))
reconst_sample_np = np.transpose(x_reconst[:10].detach().cpu().numpy(), (0, 2, 3, 1))
input_sample_np = np.transpose(data[:10].detach().cpu().numpy(), (0, 2, 3, 1))
reconst_sample = np.concatenate(reconst_sample_np, axis=1)
input_sample = np.concatenate(input_sample_np, axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0))
plt.show()
# VAE
pred_list = []
gt_list = []
for batch_idx, (data, target) in enumerate(test_loader):
    data, target = data.to(device), target.to(device)
    x_reconst, z, y_test, mu, log_var = vae_classifier(data, deterministic=True, classification_only=False)
    pred_list += list(y_test.argmax(-1).cpu().detach().numpy())
    gt_list += list(target.detach().cpu().numpy())
acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)

plt.figure(figsize=(20,5))
reconst_sample_np = np.transpose(x_reconst[:10].detach().cpu().numpy(), (0, 2, 3, 1))
input_sample_np = np.transpose(data[:10].detach().cpu().numpy(), (0, 2, 3, 1))
reconst_sample = np.concatenate(reconst_sample_np, axis=1)
input_sample = np.concatenate(input_sample_np, axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0))
plt.show()

## Training on CIFAR-10

In [None]:
from nn_model_cifar10 import VAEClassifier, StAEClassifier

epoch_num = 200

# CIFAR-10 training split. Images are only converted to tensors
# (no normalisation) and served in shuffled mini-batches.
cifar10_train_set = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
train_loader = torch.utils.data.DataLoader(
    cifar10_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Default-configured classifiers from nn_model_cifar10, placed on the training device.
vae_classifier = VAEClassifier()
vae_classifier = vae_classifier.to(device)
stae_classifier = StAEClassifier()
stae_classifier = stae_classifier.to(device)

# Loss modules: mean-reduced cross-entropy and an MSE loss.
CE_Loss = torch.nn.CrossEntropyLoss()
mseloss = nn.MSELoss()

# One Adam optimizer per model, both using the shared learning rate.
optimizer1 = torch.optim.Adam(params=vae_classifier.parameters(), lr=lr)
optimizer2 = torch.optim.Adam(params=stae_classifier.parameters(), lr=lr)

In [None]:
# Jointly train each model on classification + reconstruction.
# CE_Loss is mean-reduced while the reconstruction / KL terms are summed over
# the batch, so the CE term is rescaled by the number of samples actually in
# the batch (the final batch of an epoch can be smaller than batch_size).

# VAE
for epoch in range(epoch_num):
    vae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y, mu, log_var = vae_classifier(data, deterministic=False, classification_only=False)
        # Summed squared reconstruction error.
        recons_loss = torch.sum((x_reconst - data) ** 2)
        # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
        # jointly training
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + recons_loss + kld_loss
        optimizer1.zero_grad()
        loss_val.backward()
        optimizer1.step()

vae_classifier.eval()
torch.save(vae_classifier.state_dict(), './cifar10_vae_clf.pth')
# St-AE
for epoch in range(epoch_num):
    stae_classifier.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        num_samples = data.size(0)
        x_reconst, z, y = stae_classifier(data, classification_only=False)
        # Classification term (rescaled to a per-batch sum) + summed squared reconstruction error.
        loss_val = CE_Loss(y, target) * classification_weight * num_samples + torch.sum((x_reconst - data) ** 2)
        optimizer2.zero_grad()
        loss_val.backward()
        optimizer2.step()

stae_classifier.eval()
torch.save(stae_classifier.state_dict(), './cifar10_stae_clf.pth')

In [None]:
# Test dataset
# download=True (as in the MNIST / SVHN test loaders) so this cell also works
# when the files are not already present under ./data.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),])), batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Evaluate both classifiers on the test loader: print the accuracy, then plot
# the first 10 inputs of the last batch (top row) above their reconstructions
# (bottom row). Inference runs under torch.no_grad() so no autograd graph is
# built, and eval() is (re)asserted so the cell does not rely on earlier cells
# having left the models in inference mode.
stae_classifier.eval()
vae_classifier.eval()

# St-AE
pred_list = []
gt_list = [] # ground truth
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        x_reconst, z, y_test = stae_classifier(data, classification_only=False)
        pred_list += list(y_test.argmax(-1).cpu().numpy())
        gt_list += list(target.cpu().numpy())
acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)

plt.figure(figsize=(20,5))
# Move channels last (N, H, W, C) for matplotlib, then tile the images along the width.
reconst_sample_np = np.transpose(x_reconst[:10].cpu().numpy(), (0, 2, 3, 1))
input_sample_np = np.transpose(data[:10].cpu().numpy(), (0, 2, 3, 1))
reconst_sample = np.concatenate(reconst_sample_np, axis=1)
input_sample = np.concatenate(input_sample_np, axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0))
plt.show()
# VAE
pred_list = []
gt_list = []
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        x_reconst, z, y_test, mu, log_var = vae_classifier(data, deterministic=True, classification_only=False)
        pred_list += list(y_test.argmax(-1).cpu().numpy())
        gt_list += list(target.cpu().numpy())
acc = np.sum(np.array(gt_list) == np.array(pred_list)) / len(gt_list)
print(acc)

plt.figure(figsize=(20,5))
reconst_sample_np = np.transpose(x_reconst[:10].cpu().numpy(), (0, 2, 3, 1))
input_sample_np = np.transpose(data[:10].cpu().numpy(), (0, 2, 3, 1))
reconst_sample = np.concatenate(reconst_sample_np, axis=1)
input_sample = np.concatenate(input_sample_np, axis=1)
plt.imshow(np.concatenate([input_sample, reconst_sample], axis=0))
plt.show()