# Training Basic (Non-Graph) Autoencoder

In [1]:
# Automatically reload edited local modules (e.g. basic_autoencoder) before each
# cell executes, so changes to the model code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torchvision
import torch
from torchvision.transforms import Compose, ToTensor, Resize, Scale
from torch.utils.data.dataloader import DataLoader
from torch.nn.functional import one_hot
from basic_autoencoder import BasicAE
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import save_image
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# GPU index 3 is the intended training device on the original machine. Fall back to
# the CPU when that device index does not exist (fewer than 4 GPUs, or no CUDA at all)
# instead of failing later with an invalid-device error.
device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cpu'

In [4]:
# Downsample the CLEVR frames by 2x. Resize takes (height, width); the source images
# are presumably 320 (H) x 480 (W) -- TODO confirm against the dataset.
# Resize runs on the PIL image *before* ToTensor. PIL always antialiases when
# downsampling, whereas resizing a tensor only antialiases when torchvision's default
# says so (the default changed in v0.17), so this gives consistent quality.
transforms = Compose([
    Resize((320//2, 480//2)),
    ToTensor(),
    ])
# ImageFolder expects one sub-directory per class under the root; the class labels
# it yields are ignored by the training loop.
dataset = torchvision.datasets.ImageFolder('data/CLEVR_v1.0/images/train', transform=transforms)


In [5]:
def to_np(tnsr):
    """Convert a channels-first torch tensor to a channels-last numpy array (e.g. for imshow).

    Accepts a single image ``(C, H, W)`` -> ``(H, W, C)`` or a batch
    ``(N, C, H, W)`` -> ``(N, H, W, C)``; the channel axis is taken to be the
    third-from-last dimension. The tensor is detached and moved to the CPU first,
    so it is safe to call on GPU model outputs that require grad.

    Args:
        tnsr: torch tensor with at least 3 dimensions.

    Returns:
        numpy.ndarray with the channel axis moved to the end. When the input is
        already on the CPU, the result may share memory with it.
    """
    return np.moveaxis(tnsr.detach().cpu().numpy(), -3, -1)

In [6]:
# NOTE(review): the Resize transform yields images of height 320//2 and width 480//2,
# while BasicAE receives w=320//2, h=480//2 -- confirm BasicAE's w/h convention.
ae = BasicAE(n_channels=3, w=320//2, h=480//2, device=device).to(device)
optim = torch.optim.Adam(params=ae.parameters())

# The model is fed one image at a time (the loader's batch dim is squeezed off below),
# so a "batch" is formed by accumulating gradients over `batch_size` single images.
# Shuffle so every epoch visits the images in a different order.
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

batch_size = 100    # images per optimizer step (gradient accumulation)
n_epochs = 5
checkpoint = 1000   # save weights + a sample reconstruction every this many images

tmstp = datetime.strftime(datetime.now(), '%Y%m%d-%H%M')

optim.zero_grad()
for epoch in range(n_epochs):
    i = 0
    batch_loss = 0.0
    for image, _ in dataloader:
        image = image.squeeze(0).to(device)
        recon = ae(image)
        loss = torch.mean((recon - image)**2)
        # Scale each per-image loss so the accumulated gradient is the batch *mean*,
        # consistent with the logged batch_loss (an unscaled backward sums them).
        (loss / batch_size).backward()
        batch_loss += loss.item() / batch_size

        i += 1
        if i % batch_size == 0:
            optim.step()
            optim.zero_grad()
            print(f"epoch={epoch:4d} n={i:8d} loss={batch_loss:8.4f} ", flush=True)
            batch_loss = 0.0
        if i % checkpoint == 0:
            # NOTE(review): the weights file uses an 'lgvae_' prefix while the images
            # use 'basic_ae_' -- kept as-is in case other code loads this path.
            torch.save(ae.state_dict(), f'models/lgvae_{tmstp}.torch')
            save_image(recon.detach(), f'outputs/basic/basic_ae_{epoch}-{i}-_{tmstp}.png')

    # Apply a trailing partial batch, so its gradient does not leak into the first
    # step of the next epoch (the dataset size need not be a multiple of batch_size).
    if i % batch_size != 0:
        optim.step()
        optim.zero_grad()
    # Save at the end of every epoch, so training since the last checkpoint is kept.
    torch.save(ae.state_dict(), f'models/lgvae_{tmstp}.torch')


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch=   0 n=     100 loss=  0.2137 
epoch=   0 n=     200 loss=  0.2110 
epoch=   0 n=     300 loss=  0.1929 
epoch=   0 n=     400 loss=  0.1134 
epoch=   0 n=     500 loss=  0.1567 
epoch=   0 n=     600 loss=  0.1741 
epoch=   0 n=     700 loss=  0.1565 
epoch=   0 n=     800 loss=  0.0674 
epoch=   0 n=     900 loss=  0.5439 
epoch=   0 n=    1000 loss=  0.0568 
epoch=   0 n=    1100 loss=  0.1654 
epoch=   0 n=    1200 loss=  0.1920 
epoch=   0 n=    1300 loss=  0.2022 
epoch=   0 n=    1400 loss=  0.2051 
epoch=   0 n=    1500 loss=  0.2104 
epoch=   0 n=    1600 loss=  0.2084 
epoch=   0 n=    1700 loss=  0.2088 
epoch=   0 n=    1800 loss=  0.2041 
epoch=   0 n=    1900 loss=  0.2039 
epoch=   0 n=    2000 loss=  0.2025 
epoch=   0 n=    2100 loss=  0.1970 
epoch=   0 n=    2200 loss=  0.1854 
epoch=   0 n=    2300 loss=  0.1618 
epoch=   0 n=    2400 loss=  0.1064 
epoch=   0 n=    2500 loss=  0.0862 
epoch=   0 n=    2600 loss=  0.0353 
epoch=   0 n=    2700 loss=  0.0228 
e

KeyboardInterrupt: 