In [33]:
import torch
import numpy as np
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import matplotlib.pyplot as plt

In [34]:
# Run-wide hyperparameters.
Batch_size = 256        # images per mini-batch for both train and test loaders
Learning_rate = 2e-4    # Adam learning rate used for autoencoder pretraining
seed = 42

# Seed torch's global RNG so weight initialisation and shuffling repeat across runs.
torch.manual_seed(seed)

<torch._C.Generator at 0x7f6ab1b8a9d0>

In [35]:
# Data preprocessing: PIL image -> float tensor, then standardise with the
# MNIST mean/std constants given below.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Training split, reshuffled by the DataLoader on every pass.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=Batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Test split in fixed order, so the first batch is the same on every run.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=Batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

In [36]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel square images.

    Encoder: two unpadded stride-1 convolutions (1 -> 6 -> 1 channels), then a
    linear bottleneck of width 128. ``mid`` maps the bottleneck back to the
    flattened conv feature map, and the decoder's two transposed convolutions
    restore the input's spatial size.

    Args:
        k_size: kernel size of every convolution / transposed convolution.
        img_size: height and width of the (square) input images; 28 for MNIST.

    Raises:
        ValueError: if ``k_size`` is too large for ``img_size``.

    The submodule layout, and therefore the state_dict keys (``encoder.0``,
    ``encoder.2``, ``encoder.5``, ``mid.0``, ``decoder.0``, ``decoder.2``), does
    not depend on ``k_size`` or ``img_size``.
    """

    def __init__(self, k_size=5, img_size=28):
        super(Autoencoder, self).__init__()

        # Each unpadded stride-1 conv shrinks every side by (k_size - 1);
        # for the defaults this gives a 20x20 map, i.e. 400 features.
        self.feat_side = img_size - 2 * (k_size - 1)
        if self.feat_side <= 0:
            raise ValueError(
                f"k_size={k_size} is too large for img_size={img_size}")
        feat_dim = self.feat_side * self.feat_side

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=k_size),
            nn.ReLU(True),
            nn.Conv2d(6, 1, kernel_size=k_size),
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(feat_dim, 128),
            nn.ReLU(True))

        self.mid = nn.Sequential(
            nn.Linear(128, feat_dim),
            nn.ReLU(True)
        )
        # Stride-1 transposed convs grow every side by (k_size - 1), undoing the encoder.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1, 6, kernel_size=k_size),
            nn.ReLU(True),
            nn.ConvTranspose2d(6, 1, kernel_size=k_size))

    def forward(self, x):
        """Reconstruct ``x``; assumes shape (N, 1, img_size, img_size)."""
        x = self.encoder(x)
        x = self.mid(x)
        # Unflatten to a one-channel feature map for the transposed convolutions.
        x = x.view(x.size(0), 1, self.feat_side, self.feat_side)
        x = self.decoder(x)
        return x

In [37]:
def try_gpu(i=0):
    """Pick CUDA device number ``i`` when the machine has it, else fall back to the CPU.

    Returns:
        torch.device: ``cuda:<i>`` if at least ``i + 1`` CUDA devices are
        visible, otherwise ``cpu``.
    """
    enough_gpus = torch.cuda.device_count() >= i + 1
    return torch.device(f'cuda:{i}' if enough_gpus else 'cpu')

def save_model(model, path, optimizer=None):
    """
    Write a checkpoint dict to 'path' with torch.save.

    The dict always has a "model" entry (model.state_dict()) and, only when an
    optimizer is given, an "optimizer" entry (optimizer.state_dict()).
    Returns: None
    """
    sources = [("model", model)]
    if optimizer is not None:
        sources.append(("optimizer", optimizer))
    torch.save({key: obj.state_dict() for key, obj in sources}, path)


def load_model(model, path, optimizer=None, map_location="cpu"):
    """
    Load a checkpoint written by save_model from 'path' into 'model' (and 'optimizer').

    Args:
        model: module that receives state["model"] via load_state_dict.
        path: checkpoint file path.
        optimizer: optional optimizer that receives state["optimizer"].
        map_location: forwarded to torch.load. Defaults to "cpu" so a
            checkpoint written from a GPU can be opened on a CPU-only machine;
            load_state_dict then copies the tensors onto whatever device the
            model/optimizer parameters already live on. Pass None to keep the
            devices recorded in the file.
    Returns: None
    Raises:
        KeyError: if 'optimizer' is given but the checkpoint has no
            "optimizer" entry (save_model only writes it when given one).
    """
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(state["model"])
    if optimizer is not None:
        if "optimizer" not in state:
            raise KeyError(
                f"checkpoint {path!r} has no 'optimizer' entry; "
                "it was saved without an optimizer")
        optimizer.load_state_dict(state["optimizer"])


In [38]:
def pretrain(net, train_iter, criterion, optimizer, device=None):
    """Run one epoch of reconstruction (autoencoder) training.

    The images of each batch are both the network input and the target; the
    labels are ignored.

    Args:
        net: model mapping a batch of images to a reconstruction of the same shape.
        train_iter: iterable of (images, labels) batches, e.g. a DataLoader.
        criterion: reconstruction loss, called as criterion(output, images).
        optimizer: optimizer over net's parameters.
        device: where each batch is moved. Defaults to the device of net's
            first parameter (CPU if it has none).

    Returns:
        0-dim tensor: mean of the per-batch losses (NaN if train_iter yields no
        batches). It is detached, so callers can call .item() on it without
        keeping any autograd graph alive.
    """
    if isinstance(net, nn.Module):
        net.train()
    if device is None:
        first_param = next(iter(net.parameters()), None) if isinstance(net, nn.Module) else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    loss_record = 0.0
    num_batches = 0
    for img, _ in train_iter:
        img = img.to(device)

        output = net(img)
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # detach(): a running sum of live losses would keep every batch's graph in memory.
        loss_record += loss.detach()
        # Count batches (len() of the (img, label) tuple would always be 2).
        num_batches += 1

    if num_batches == 0:
        return torch.tensor(float("nan"))
    return loss_record / num_batches

In [40]:
num_epochs = 15
ae = Autoencoder().to(device=try_gpu())
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    ae.parameters(),
    lr=Learning_rate,
    weight_decay=1e-5,
)

# One mean training loss per epoch; written to disk once training finishes.
pre_loss = []
for epoch in range(num_epochs):
    train_loss = pretrain(ae, train_loader, criterion, optimizer)
    loss_value = train_loss.item()
    print(f'epoch [{epoch + 1}/{num_epochs}], train_loss:{loss_value:.4f}')
    pre_loss.append(loss_value)
    # Only the final epoch's weights are checkpointed.
    is_last_epoch = epoch + 1 == num_epochs
    if is_last_epoch:
        save_model(model=ae, path=f"./ae_epoch_{epoch+1}_test.ckpt")

np.save("aetrainloss.npy", pre_loss)


In [41]:
# Reconstruct the first test batch with the trained autoencoder.
test_examples = None

with torch.no_grad():
    for data in test_loader:
        img, label = data
        img = img.to(device=try_gpu())
        test_examples = img
        reconstruction = ae(test_examples)
        break

# Grid of 10 columns: originals on the top row, reconstructions underneath.
with torch.no_grad():
    number = 10
    plt.figure(figsize=(20, 4))
    for index in range(number):
        for row, images in enumerate((test_examples, reconstruction)):
            ax = plt.subplot(2, number, row * number + index + 1)
            plt.imshow(images[index].cpu().numpy().reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()

In [43]:
class AE_Classifier(nn.Module):
    """10-class image classifier whose conv encoder mirrors Autoencoder's.

    ``encoder`` uses the same indices as the first four layers of
    ``Autoencoder.encoder``, so the ``encoder.0.*`` / ``encoder.2.*`` tensors of
    a pretrained autoencoder state_dict can be copied in by key. ``fc`` is a
    two-layer head producing one score per class.

    Args:
        k_size: kernel size of both convolutions.
        img_size: height and width of the (square) input images; 28 for MNIST.

    Raises:
        ValueError: if ``k_size`` is too large for ``img_size``.
    """

    def __init__(self, k_size=5, img_size=28):
        super(AE_Classifier, self).__init__()

        # Two unpadded stride-1 convs shrink each side by 2 * (k_size - 1);
        # the defaults give a 20x20 map, i.e. 400 features.
        feat_side = img_size - 2 * (k_size - 1)
        if feat_side <= 0:
            raise ValueError(
                f"k_size={k_size} is too large for img_size={img_size}")

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=k_size),
            nn.ReLU(True),
            nn.Conv2d(6, 1, kernel_size=k_size),
            nn.ReLU(True)
            )

        self.fc = nn.Sequential(
            nn.Linear(feat_side * feat_side, 128, bias=True),
            nn.ReLU(True),
            nn.Linear(128, 10, bias=True)
        )

    def forward(self, x):
        """Map a batch of images to class scores of shape (N, 10)."""
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [44]:
# Fresh classifier, plus an autoencoder shell that receives the pretrained checkpoint.
classifier = AE_Classifier()
print(classifier)
pre_ae = Autoencoder().to(device=try_gpu())
classifier.to(device=try_gpu())
load_model(pre_ae, "./ae_epoch_15_test.ckpt")

# Overwrite every classifier tensor whose key also exists in the autoencoder's
# state_dict. With the current layouts only encoder.0.* and encoder.2.* match.
# NOTE(review): the autoencoder's encoder.5 (Linear 400->128) has the same shape as
# the classifier's fc.0 but a different key, so it is NOT transferred — confirm intended.
classifier_dict = classifier.state_dict()
pre = dict(
    (key, tensor)
    for key, tensor in pre_ae.state_dict().items()
    if key in classifier_dict
)
classifier_dict.update(pre)
classifier.load_state_dict(classifier_dict)

AE_Classifier(
  (encoder): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(6, 1, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (fc): Sequential(
    (0): Linear(in_features=400, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)


<All keys matched successfully>

In [45]:
def train_classifier(net, train_iter, criterion, optimizer, device=None):
    """Run one supervised training epoch over (images, labels) batches.

    Args:
        net: classifier returning per-class scores for a batch of images.
        train_iter: iterable of (images, labels) batches, e.g. a DataLoader.
        criterion: loss called as criterion(output, labels).
        optimizer: optimizer over net's (trainable) parameters.
        device: where each batch is moved. Defaults to the device of net's
            first parameter (CPU if it has none).

    Returns:
        0-dim tensor: mean of the per-batch losses (NaN if train_iter yields no
        batches). It is detached, so .item() on it does not keep any graph alive.
    """
    if isinstance(net, nn.Module):
        net.train()
    if device is None:
        first_param = next(iter(net.parameters()), None) if isinstance(net, nn.Module) else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    loss_record = 0.0
    total = 0
    for img, labels in train_iter:
        img = img.to(device)
        label = labels.to(device)
        # ===================forward=====================
        output = net(img)
        loss = criterion(output, label)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # detach(): summing live losses would retain every batch's graph for the epoch.
        loss_record += loss.detach()
        total += 1
    # ===================log========================
    if total == 0:
        return torch.tensor(float("nan"))
    return loss_record / total

def eval_accuracy(net, data_iter, criterion, device=None):
    """Return top-1 accuracy of net over data_iter, computed without gradients.

    Args:
        net: classifier whose output holds per-class scores along dim 1.
        data_iter: iterable of (images, targets) batches, e.g. a DataLoader.
        criterion: unused; kept so existing call sites keep working.
        device: where each batch is moved. Defaults to the device of net's
            first parameter (CPU if it has none).

    Returns:
        float: fraction of samples whose highest-scoring class equals the
        target, or NaN if data_iter yields no samples.

    Side effect: leaves net in eval mode.
    """
    if isinstance(net, nn.Module):
        net.eval()
    if device is None:
        first_param = next(iter(net.parameters()), None) if isinstance(net, nn.Module) else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_iter:
            output = net(data.to(device))
            target = target.to(device)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)

    if total == 0:
        return float("nan")
    return correct / total

In [46]:
# Fine-tune the classifier initialised from the pretrained encoder.
# Set FREEZE_ENCODER = True to keep the pretrained conv layers fixed and train only `fc`.
# NOTE(review): this loop used to re-set requires_grad=True only on parameters that
# already required grad (a no-op); FREEZE_ENCODER = False keeps that behaviour.
FREEZE_ENCODER = False
for name, param in classifier.named_parameters():
    if "encoder" in name:
        param.requires_grad = not FREEZE_ENCODER

classifier.to(device=try_gpu())
num_epochs = 10
classifier_criterion = nn.CrossEntropyLoss().to(device=try_gpu())
# Give Adam only the trainable parameters, so frozen ones are not tracked at all.
trainable_params = [p for p in classifier.parameters() if p.requires_grad]
classifier_optimizer = torch.optim.Adam(trainable_params, lr=0.001, weight_decay=1e-5)

train_loss = []
# NOTE(review): "val_acc" is measured on test_loader (the MNIST test split);
# there is no separate validation set.
val_acc = []
for epoch in range(num_epochs):
    epoch_train_loss = train_classifier(classifier, train_loader, classifier_criterion, classifier_optimizer)
    epoch_val_acc = eval_accuracy(classifier, test_loader, classifier_criterion)
    print('epoch [{}/{}], train_loss:{:.4f}'.format(epoch+1, num_epochs, epoch_train_loss.item()))
    print('epoch [{}/{}], val_acc:{:.4f}'.format(epoch+1, num_epochs, epoch_val_acc))
    train_loss.append(epoch_train_loss.item())
    val_acc.append(epoch_val_acc)

np.save("aeclassvalacc.npy", val_acc)

epoch [1/10], train_loss:0.3534
epoch [1/10], val_acc:0.9524
epoch [2/10], train_loss:0.1459
epoch [2/10], val_acc:0.9649
epoch [3/10], train_loss:0.1037
epoch [3/10], val_acc:0.9724
epoch [4/10], train_loss:0.0835
epoch [4/10], val_acc:0.9739
epoch [5/10], train_loss:0.0665
epoch [5/10], val_acc:0.9726
epoch [6/10], train_loss:0.0579
epoch [6/10], val_acc:0.9780
epoch [7/10], train_loss:0.0503
epoch [7/10], val_acc:0.9751
epoch [8/10], train_loss:0.0434
epoch [8/10], val_acc:0.9777
epoch [9/10], train_loss:0.0390
epoch [9/10], val_acc:0.9758
epoch [10/10], train_loss:0.0354
epoch [10/10], val_acc:0.9817
