In [None]:
from gans import Generator, Discriminator
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms.functional import resize
from tqdm import tqdm
from skimage import io, transform
import os
import sys

LR = 1e-3
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 8
EPOCHS = 10
OUTPUT_LOSS_INTERVAL = 1  # print average losses every N epochs
DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNGs so weight initialisation, DataLoader shuffling and the
# generator's noise samples are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
def get_files_in_sub_dirs(path):
    """List every file found anywhere beneath ``path``.

    Directories are traversed with ``os.walk``; each returned entry is the
    directory it was found in joined with the file name. A path that does not
    exist yields an empty list.
    """
    return [
        os.path.join(dir_path, file_name)
        for dir_path, _sub_dirs, dir_files in os.walk(path)
        for file_name in dir_files
    ]

def get_img_paths(path, extensions=('.jpg',)):
    """Return the paths of image files found recursively under ``path``.

    Args:
        path: Root directory to search.
        extensions: A single filename suffix or an iterable of suffixes to
            keep. Matching is case-sensitive, so the default keeps ``.jpg``
            but not ``.JPG``.

    Returns:
        List of matching paths, in the order ``get_files_in_sub_dirs``
        returns them.
    """
    # str.endswith only accepts a str or a tuple, so normalise lists/sets.
    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
    return [file for file in get_files_in_sub_dirs(path) if file.endswith(suffixes)]

In [None]:
# Collect every .jpg under ../data (relative to the notebook's working directory).
img_paths = get_img_paths('../data')
# Fail here with a clear message instead of later inside the DataLoader's sampler.
if not img_paths:
    raise FileNotFoundError("No .jpg images found under '../data'")

In [None]:
# Code heavily inspired by: "https://pytorch.org/tutorials/beginner/basics/data_tutorial.html"
class SimpsonsDataset(Dataset):
    """Map-style dataset that loads image files and resizes them to a fixed size.

    Args:
        img_paths: Sequence of image file paths. The dataset's length and
            indexing are defined entirely by this sequence.
        new_img_size: Target size forwarded to torchvision's ``resize``
            (e.g. ``(64, 64)``).
        transform: Optional callable applied to each resized image tensor.
    """

    def __init__(self, img_paths, new_img_size, transform=None) -> None:
        self.image_paths = img_paths
        self.new_img_size = new_img_size
        self.transform = transform

    def __len__(self):
        # Use the paths this instance was built with, not a notebook global.
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load, resize and optionally transform the image at ``index``.

        Raises:
            IndexError: if ``index`` is at or past the end of ``img_paths``.
        """
        # ``>=`` rather than ``>``: index == len is already out of range.
        if index >= len(self.image_paths):
            raise IndexError('Index out of bounds!')

        img_path = self.image_paths[index]
        img = read_image(img_path).float()
        # NOTE(review): no normalisation is applied, so pixel values stay in
        # read_image's native range (presumably 0-255). Confirm this matches the
        # range the generator produces before training.
        img = resize(img, self.new_img_size, antialias=True)

        if self.transform:
            img = self.transform(img)

        return img

In [None]:
# Resize every image to 64x64, the spatial size the generator is meant to produce (3x64x64).
simp_dataset = SimpsonsDataset(
    img_paths=img_paths,
    new_img_size=(64, 64),
)

In [None]:
# Keyword arguments make explicit that the positional ``True`` meant shuffle=True.
train_dataloader = DataLoader(simp_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# The generated images are 3x64x64
gen_model = Generator(layer_dims=[100, 128, 256, 512, 1024])
disc_model = Discriminator(layer_dims=[1024, 512, 256, 128, 1])

gen_optimizer = torch.optim.AdamW(gen_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
disc_optimizer = torch.optim.AdamW(disc_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

In [None]:
def _gan_step(gen_model, disc_model, gen_optimizer, disc_optimizer, real, latent_dim):
    """Run one discriminator update then one generator update on a real batch.

    Returns:
        ``(disc_loss, gen_loss)`` as Python floats.
    """
    # NOTE(review): assumes gen_model maps (batch, latent_dim) noise to tensors
    # shaped like ``real``, and disc_model returns raw logits (no sigmoid).
    # Confirm against gans.Generator / gans.Discriminator.
    noise = torch.randn(real.size(0), latent_dim, device=real.device)
    fake = gen_model(noise)

    # Discriminator: label real as 1 and fake as 0. ``detach`` keeps this
    # update from back-propagating into the generator.
    real_logits = disc_model(real)
    fake_logits = disc_model(fake.detach())
    disc_loss = (
        F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
        + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    )
    disc_optimizer.zero_grad()
    disc_loss.backward()
    disc_optimizer.step()

    # Generator (non-saturating loss): make the updated discriminator call fakes real.
    # Gradients this leaves on disc_model are cleared by the next disc zero_grad().
    gen_logits = disc_model(fake)
    gen_loss = F.binary_cross_entropy_with_logits(gen_logits, torch.ones_like(gen_logits))
    gen_optimizer.zero_grad()
    gen_loss.backward()
    gen_optimizer.step()

    return disc_loss.item(), gen_loss.item()


def train(gen_model, disc_model, device, optimizer, train_dataloader, epochs, loss_output_interval,
          latent_dim=100):
    """Adversarially train ``gen_model`` against ``disc_model``.

    Args:
        gen_model: Generator network, fed noise of shape (batch, latent_dim).
        disc_model: Discriminator network scoring real and generated batches.
        device: torch device both models and every batch are moved to.
        optimizer: ``(gen_optimizer, disc_optimizer)`` pair. A GAN needs a
            separate optimizer for each network.
        train_dataloader: Iterable of real-image batches that supports ``len``.
        epochs: Number of passes over ``train_dataloader``.
        loss_output_interval: Print average losses every this many epochs;
            0 or ``None`` disables printing.
        latent_dim: Length of the generator's noise vector. The default of 100
            presumably matches Generator's first layer dim; confirm.

    Raises:
        TypeError: if ``optimizer`` is not a two-element pair.

    Both models are left in eval mode when training finishes.
    """
    try:
        gen_optimizer, disc_optimizer = optimizer
    except (TypeError, ValueError) as exc:
        raise TypeError('optimizer must be a (gen_optimizer, disc_optimizer) pair') from exc

    for net in (gen_model, disc_model):
        net.to(device)
        net.train()

    for epoch in tqdm(range(epochs), desc="Epoch"):
        disc_loss_sum = 0.0
        gen_loss_sum = 0.0
        num_batches = 0

        for X in tqdm(train_dataloader, desc="Batch", total=len(train_dataloader)):
            disc_loss, gen_loss = _gan_step(
                gen_model, disc_model, gen_optimizer, disc_optimizer, X.to(device), latent_dim,
            )
            disc_loss_sum += disc_loss
            gen_loss_sum += gen_loss
            num_batches += 1

        # Guard against a zero interval (modulo) and an empty loader (division).
        if loss_output_interval and num_batches and (epoch + 1) % loss_output_interval == 0:
            print(f' Epoch {epoch+1} Average Batch Loss: '
                  f'discriminator={disc_loss_sum / num_batches:.4f}, '
                  f'generator={gen_loss_sum / num_batches:.4f}')

    gen_model.eval()
    disc_model.eval()
    

In [None]:
# No single ``optimizer`` exists in this notebook; the generator and the
# discriminator each have their own, so pass both as a pair.
train(
    gen_model,
    disc_model,
    device=torch.device(DEVICE_NAME),
    optimizer=(gen_optimizer, disc_optimizer),
    train_dataloader=train_dataloader,
    epochs=EPOCHS,
    loss_output_interval=OUTPUT_LOSS_INTERVAL,
)