In [31]:
# import packages
import os
import torch 
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from nets import Autoencoder_linear,VAE

In [32]:
# Training hyper-parameters.
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
SEED = 42
# Which model to train: 'VAE' selects the VAE, any other value the linear autoencoder.
ae='VAE'
# Seed torch's global RNG (all devices) so weight initialisation and DataLoader
# shuffling are reproducible across kernel restarts.
torch.manual_seed(SEED)
# Keep pixels in [0, 1]: the VAE is trained with nn.BCELoss, whose targets must lie
# in [0, 1], so no Normalize((0.5,), (0.5,)) here (that would map pixels to [-1, 1]).
transform = transforms.Compose([transforms.ToTensor()])

In [33]:
# FashionMNIST splits, downloaded into ./data if missing. Only ToTensor() is applied
# so pixel values stay in [0, 1] -- required by the nn.BCELoss targets used when
# training the VAE.
trainset = datasets.FashionMNIST(root='./data', train=True, download=True,
                                 transform=transforms.ToTensor())
trainLoader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True)

# The test split is not shuffled, so evaluation on it (e.g. reconstructing the first
# test batch) sees the same images on every run.
testset = datasets.FashionMNIST(root='./data', train=False, download=True,
                                transform=transforms.ToTensor())
testLoader = DataLoader(dataset=testset, batch_size=BATCH_SIZE, shuffle=False)

In [34]:
# utility functions
def get_device():
    """Return the torch device string to use.

    Returns:
        str: ``'cuda:0'`` when a CUDA GPU is available, otherwise ``'cpu'``.
    """
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'
def make_dir(ae):
    """Create directory ``ae`` (and any missing parents) if it does not exist yet.

    Uses ``exist_ok=True`` instead of a separate existence check, so there is no
    window between checking and creating in which another process could create it.

    Args:
        ae: path of the directory to create (here the model name, e.g. 'AE').
    """
    os.makedirs(ae, exist_ok=True)
def save_decoded_image(img, epoch, ae):
    """Save a batch of flattened images as one PNG grid inside directory ``ae``.

    Each row of ``img`` is reshaped to a single-channel 28x28 image; the output file
    is ``./<ae>/decode_image<epoch>.png``.
    """
    out_path = './{}/decode_image{}.png'.format(ae, epoch)
    grid = img.view(img.size(0), 1, 28, 28)
    save_image(grid, out_path)

# Device string ('cuda:0' or 'cpu') that the training and test helpers move batches to.
device = get_device()

In [42]:
def train(net, trainloader, NUM_EPOCHS, save_every=5):
    """Train a plain autoencoder to reproduce its flattened input using MSE loss.

    Args:
        net: model mapping a flattened batch to a reconstruction of the same shape.
        trainloader: iterable of ``(images, labels)`` batches; labels are ignored.
        NUM_EPOCHS: number of passes over ``trainloader``.
        save_every: every ``save_every`` epochs (epoch indices 0, save_every, ...)
            the last batch's reconstruction is written with ``save_decoded_image``
            into the directory named by the global ``ae``. 0/None disables saving.

    Returns:
        list[float]: the mean per-batch MSE loss of each epoch.

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    net.train()  # ensure training mode, e.g. if the model was left in eval mode
    train_loss = []
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        num_batches = 0
        for img, _ in trainloader:
            img = img.to(device)
            img = img.view(img.size(0), -1)  # flatten each image to a vector
            optimizer.zero_grad()
            outputs = net(img)
            loss = criterion(outputs, img)  # the input is its own target
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('trainloader yielded no batches')
        epoch_loss = running_loss / num_batches
        train_loss.append(epoch_loss)
        print('Epoch {} of {}, Train Loss: {:.3f}'.format(
            epoch+1, NUM_EPOCHS, epoch_loss))

        if save_every and epoch % save_every == 0:
            save_decoded_image(outputs.detach().cpu(), epoch, ae)

    return train_loss

def vae_train(net, trainloader, NUM_EPOCHS, save_every=5):
    """Train a VAE with summed binary cross-entropy combined via ``VAE.final_loss``.

    Args:
        net: model returning ``(reconstruction, mu, logvar)`` for a flattened batch;
            ``reconstruction`` must be in [0, 1] for ``nn.BCELoss``.
        trainloader: iterable of ``(images, labels)`` batches with pixels in [0, 1];
            labels are ignored.
        NUM_EPOCHS: number of passes over ``trainloader``.
        save_every: every ``save_every`` epochs (epoch indices 0, save_every, ...)
            the last batch's reconstruction is written with ``save_decoded_image``
            into the directory named by the global ``ae``. 0/None disables saving.

    Returns:
        list[float]: loss per training sample for each epoch (total epoch loss
        divided by the number of samples seen).

    Raises:
        ValueError: if ``trainloader`` yields no batches.
    """
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    # reduction='sum': BCE is summed over every pixel of every image in the batch.
    criterion = nn.BCELoss(reduction='sum')
    net.train()  # ensure training mode, e.g. if the model was left in eval mode
    train_loss = []
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        num_samples = 0
        for inputs, _ in trainloader:
            inputs = inputs.to(device)
            inputs = inputs.view(inputs.size(0), -1)  # flatten each image to a vector
            optimizer.zero_grad()
            reconstruction, mu, logvar = net(inputs)
            bce_loss = criterion(reconstruction, inputs)
            # NOTE(review): VAE.final_loss lives in `nets` (not visible here);
            # presumably it adds the KL term to the summed BCE -- confirm.
            loss = VAE.final_loss(bce_loss, mu, logvar)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_samples += inputs.size(0)

        if num_samples == 0:
            raise ValueError('trainloader yielded no batches')
        # The batch loss is a sum (not a mean), so normalise by the number of
        # samples rather than batches; the value then does not depend on BATCH_SIZE.
        epoch_loss = running_loss / num_samples
        train_loss.append(epoch_loss)
        print('Epoch {} of {}, Train Loss: {:.3f}'.format(
            epoch+1, NUM_EPOCHS, epoch_loss))

        if save_every and epoch % save_every == 0:
            save_decoded_image(reconstruction.detach().cpu(), epoch, ae)
    return train_loss

def test_image_reconstruction(net, testloader):
    """Reconstruct the first batch of ``testloader`` and save it as a 28x28 image grid.

    The result is written to ``images_reconstruction.png`` in the current working
    directory. Works for both model types used here: the plain autoencoder, which
    returns the reconstruction, and the VAE, which returns a tuple whose first
    element is the reconstruction. Does nothing if ``testloader`` is empty.

    The model is switched to eval mode for the forward pass, and its previous
    train/eval mode is restored afterwards.
    """
    batch = next(iter(testloader), None)
    if batch is None:
        return
    img, _ = batch
    img = img.to(device)
    img = img.view(img.size(0), -1)  # flatten each image to a vector

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():  # inference only: no autograd graph is needed
            outputs = net(img)
    finally:
        net.train(was_training)  # restore the caller's mode

    # The VAE returns (reconstruction, mu, logvar); keep only the reconstruction.
    if isinstance(outputs, (tuple, list)):
        outputs = outputs[0]
    outputs = outputs.view(outputs.size(0), 1, 28, 28).cpu()
    save_image(outputs, 'images_reconstruction.png')

In [44]:
# Select the model to train; this overrides the `ae` value from the config cell.
ae = 'AE'
make_dir(ae)

# Build the selected model together with its matching training loop.
if ae == 'VAE':
    net, train_fn = VAE(), vae_train
else:
    net, train_fn = Autoencoder_linear(), train
print(net)
net.to(device)

# train the network
train_loss = train_fn(net, trainLoader, NUM_EPOCHS)

# Plot the per-epoch training loss; the x-axis is 1-based to match the
# "Epoch k of N" log lines printed during training.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss) + 1), train_loss)
ax.set_title('Train Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
fig.savefig('loss.png')

# test the network: save a reconstruction of one test batch
test_image_reconstruction(net, testLoader)

Autoencoder_linear(
  (enc1): Linear(in_features=784, out_features=256, bias=True)
  (enc2): Linear(in_features=256, out_features=128, bias=True)
  (enc3): Linear(in_features=128, out_features=64, bias=True)
  (enc4): Linear(in_features=64, out_features=32, bias=True)
  (enc5): Linear(in_features=32, out_features=16, bias=True)
  (dec1): Linear(in_features=16, out_features=32, bias=True)
  (dec2): Linear(in_features=32, out_features=64, bias=True)
  (dec3): Linear(in_features=64, out_features=128, bias=True)
  (dec4): Linear(in_features=128, out_features=256, bias=True)
  (dec5): Linear(in_features=256, out_features=784, bias=True)
)


AttributeError: 'list' object has no attribute 'to'