In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from tqdm import tqdm

learning_rate = 1e-3
learning_rate = 0.002
mnist_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
mnist_data = list(mnist_data)[:4096]
device = torch.device("mps")


other_data = datasets.Flowers102('data2', split="train", download=True,
                               transform=transforms.Compose([
                               transforms.Resize(128),
                               transforms.CenterCrop(128),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               transforms.Grayscale(num_output_channels=1),
                           ]))
other_data = list(other_data)[:4096]

In [33]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential( # like the Composition layer you built
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [34]:
def train(model, num_epochs=5, batch_size=64, learning_rate=learning_rate):
    torch.manual_seed(42)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    train_loader = torch.utils.data.DataLoader(mnist_data, batch_size=batch_size, shuffle=True, pin_memory=True)
    outputs = []
    for epoch in range(num_epochs):
        for data in tqdm(train_loader):
            img, label = data
            img = img.to(device)

            recon = model(img)
            loss = criterion(recon, img)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))
        outputs.append((epoch, img, recon),)
    return outputs


model = Autoencoder().to(device)
max_epochs = 50
outputs = train(model, num_epochs=max_epochs)

  0%|          | 0/64 [00:00<?, ?it/s]

100%|██████████| 64/64 [00:05<00:00, 12.78it/s]


Epoch:1, Loss:0.0696


100%|██████████| 64/64 [00:04<00:00, 13.91it/s]


Epoch:2, Loss:0.0647


100%|██████████| 64/64 [00:04<00:00, 13.75it/s]


Epoch:3, Loss:0.0518


100%|██████████| 64/64 [00:04<00:00, 13.82it/s]


Epoch:4, Loss:0.0358


100%|██████████| 64/64 [00:04<00:00, 13.60it/s]


Epoch:5, Loss:0.0290


100%|██████████| 64/64 [00:04<00:00, 13.63it/s]


Epoch:6, Loss:0.0202


100%|██████████| 64/64 [00:04<00:00, 13.99it/s]


Epoch:7, Loss:0.0164


100%|██████████| 64/64 [00:04<00:00, 14.12it/s]


Epoch:8, Loss:0.0149


100%|██████████| 64/64 [00:04<00:00, 13.76it/s]


Epoch:9, Loss:0.0139


100%|██████████| 64/64 [00:04<00:00, 13.94it/s]


Epoch:10, Loss:0.0103


100%|██████████| 64/64 [00:04<00:00, 14.58it/s]


Epoch:11, Loss:0.0083


100%|██████████| 64/64 [00:04<00:00, 13.79it/s]


Epoch:12, Loss:0.0088


100%|██████████| 64/64 [00:04<00:00, 14.09it/s]


Epoch:13, Loss:0.0097


100%|██████████| 64/64 [00:04<00:00, 14.50it/s]


Epoch:14, Loss:0.0080


100%|██████████| 64/64 [00:04<00:00, 14.38it/s]


Epoch:15, Loss:0.0068


  0%|          | 0/64 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
for k in range(0, max_epochs, 5):
    plt.figure(figsize=(9, 2))
    imgs = outputs[k][1].detach().numpy()
    recon = outputs[k][2].detach().numpy()
    for i, item in enumerate(imgs):
        if i >= 9: break
        plt.subplot(2, 9, i+1)
        plt.imshow(item[0])
        
    for i, item in enumerate(recon):
        if i >= 9: break
        plt.subplot(2, 9, 9+i+1)
        plt.imshow(item[0])

TypeError: can't convert mps:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

<Figure size 900x200 with 0 Axes>