In [1]:
from autoencoder import Autoencoder
import torch.nn.functional as F
import torch
import os
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms.functional import resize
from tqdm import tqdm

# Optimizer settings, passed to AdamW when the optimizer is created.
LR = 1e-3
WEIGHT_DECAY = 1e-5
# Number of images per batch produced by the training DataLoader.
BATCH_SIZE = 8
# Number of full passes over the training data.
EPOCHS = 10
# The average batch loss is printed once every OUTPUT_LOSS_INTERVAL epochs.
OUTPUT_LOSS_INTERVAL = 1

In [2]:
def get_files_in_sub_dirs(path):
    """Collect the path of every file found under ``path``, recursively.

    Args:
        path: Directory at which the ``os.walk`` traversal starts.

    Returns:
        A list of file paths in ``os.walk`` order. Each entry is the
        containing directory joined with the file name, so entries carry
        ``path`` as their prefix. The list is empty when nothing is found.
    """
    return [
        os.path.join(folder, name)
        for folder, _subfolders, names in os.walk(path)
        for name in names
    ]

def get_img_paths(path):
    """Return the paths of all ``.jpg`` files located anywhere under ``path``.

    Args:
        path: Root directory, searched recursively via
            ``get_files_in_sub_dirs``.

    Returns:
        A list of file paths whose names end with ``.jpg``. The suffix
        comparison is case-sensitive, so e.g. ``.JPG`` files are excluded.
    """
    return [
        candidate
        for candidate in get_files_in_sub_dirs(path)
        if candidate.endswith('.jpg')
    ]

In [3]:
# Gather the path of every .jpg file found under ../data (searched recursively).
img_paths = get_img_paths('../data')

In [4]:
# Code heavily inspired by: "https://pytorch.org/tutorials/beginner/basics/data_tutorial.html"
class SimpsonsDataset(Dataset):
    """Map-style dataset that loads images from disk and resizes them.

    Each item is read with ``read_image``, cast to float, and resized to
    ``new_img_size``. If a ``transform`` was supplied it is applied last.
    """

    def __init__(self, img_paths, new_img_size, transform=None) -> None:
        """Store the dataset configuration; no image is read here.

        Args:
            img_paths: Sequence of image file paths, one per dataset item.
            new_img_size: Target size handed to ``resize`` for every image.
            transform: Optional callable applied to the resized image.
        """
        self.image_paths = img_paths
        self.new_img_size = new_img_size
        self.transform = transform

    def __len__(self):
        """Return the number of images in this dataset."""
        # Use the instance's own paths, not a notebook-level global, so the
        # dataset stays correct when built from any list of paths.
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load, resize and optionally transform the image at ``index``.

        Args:
            index: Position of the image. Negative values count from the
                end, as with a regular sequence.

        Returns:
            The float image produced by ``read_image`` and ``resize``,
            passed through ``transform`` when one was given.

        Raises:
            IndexError: If ``index`` lies outside the dataset.
        """
        num_images = len(self.image_paths)
        # ``index == num_images`` is already out of range, so the upper
        # bound is exclusive.
        if not -num_images <= index < num_images:
            raise IndexError('Index out of bounds!')

        img = read_image(self.image_paths[index]).float()
        img = resize(img, self.new_img_size)

        if self.transform is not None:
            # Transform the loaded image and return the transformed result.
            img = self.transform(img)

        return img

In [None]:
# Dataset over the collected image paths; each image is resized to 64x64 when it is loaded.
simp_dataset = SimpsonsDataset(img_paths=img_paths, new_img_size=(64, 64))

In [6]:
# Serve the dataset in shuffled batches of BATCH_SIZE images.
train_dataloader = DataLoader(
    simp_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [7]:
# NOTE(review): the first tuple presumably is the (channels, height, width)
# input shape and the second the layer widths -- confirm against Autoencoder.
model = Autoencoder(
    (3, 64, 64),
    (8192, 4096, 1024, 256),
)
# AdamW over all model parameters, configured from the constants defined at
# the top of the notebook.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

In [None]:
# Echo the run configuration so it is recorded next to the training output.
# The trailing empty string yields the blank line that closes the block.
print(
    "Hyperparameters:",
    f"Batch size: {BATCH_SIZE}",
    f"Learning rate: {LR}",
    f"Weight decay: {WEIGHT_DECAY}",
    f"Epochs: {EPOCHS}",
    f"Output loss interval: {OUTPUT_LOSS_INTERVAL}",
    "",
    sep="\n",
)

In [8]:
def train(model, device, optimizer, train_dataloader, epochs, loss_output_interval):
    """Train ``model`` in place and leave it in evaluation mode.

    Every batch is moved to ``device`` and passed to the model as
    ``model(X, X)``; the model must return an ``(output, loss)`` pair.
    NOTE(review): the second argument is presumably the reconstruction
    target -- confirm against ``Autoencoder.forward``.

    Args:
        model: Module to train; moved to ``device`` before training.
        device: Device on which the model and the batches are placed.
        optimizer: Optimizer that updates the parameters of ``model``.
        train_dataloader: Sized iterable yielding batches of input tensors.
        epochs: Number of passes over ``train_dataloader``.
        loss_output_interval: The average batch loss is printed every this
            many epochs.
    """
    model.to(device)
    model.train()

    # The batch count does not change between epochs, so look it up once.
    num_batches = len(train_dataloader)

    for step in tqdm(range(epochs), desc="Epoch"):
        total_running_loss = 0

        for X in tqdm(train_dataloader, desc="Batch", total=num_batches):
            X = X.to(device)
            # Only the loss is needed for the update; the output is unused.
            _, loss = model(X, X)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_running_loss += loss.item()

        if (step+1) % loss_output_interval == 0:
            # An empty dataloader has no batches to average over; report NaN
            # instead of failing with ZeroDivisionError.
            average_loss = total_running_loss / num_batches if num_batches else float('nan')
            print(f' Epoch {step+1} Average Batch Loss: {average_loss}')

    model.eval()
    

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA.
train(
    model,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    epochs=EPOCHS,
    loss_output_interval=OUTPUT_LOSS_INTERVAL,
)