In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torchvision
import tqdm

from torchvision.transforms import transforms
from torch.utils.data import Dataset
from torchvision.io import read_image
from skimage.color import rgb2lab, lab2rgb
from colorizationTools import EncoderDataset, ColorizationAutoencoder
from torch.utils.data import DataLoader

In [2]:
import os
import numpy as np
import random

In [3]:
base_dir = "/Users/kennethzhang/Desktop/landscape Images/"

ok = read_image(base_dir + "color/" 
                + "0.jpg")

batch_size = 16
epochs = 20
num_workers = 0

total_images = len(os.listdir(base_dir + 'color'))
random_indices = random.sample(list(range(total_images)), total_images)

In [4]:
split_idx = round(total_images * 0.8)
train_idx = random_indices[:split_idx]
test_idx = random_indices[split_idx:]

train_transforms = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

train_dataset = EncoderDataset(train_idx, base_dir)
test_dataset = EncoderDataset(test_idx, base_dir)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size, batch_size)

In [7]:
model = ColorizationAutoencoder()
model

ColorizationAutoencoder(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (pooling2d): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (t_conv1): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  (t_conv2): ConvTranspose2d(256, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  (t_conv3): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  (t_conv4): ConvTranspose2d(192, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (dropout): Dropout(p=0.2, inplace=False)
  (converge): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [12]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

n_epochs = 30
train_losses = []
test_losses = []

for epoch in range(1, n_epochs + 1):

    train_loss = 0.0

    for data in train_loader:
        images, labels = data

        images = images.float()
        labels = labels.float()

        optimizer.zero_grad()
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        train_loss += loss.item() * images.size(0)
    
    print('Epoch: {}\Training Loss: {:.6f}'.format(
        epochs, train_loss
    ))  

    loss = train_loss / len(train_loader)
    train_losses.append(loss)      

    test_loss = 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            output = model(images)
            loss = criterion(output, labels)
            test_loss += loss.item() * images.size(0)

        print("Test Loss: {:.3f}.. ".format(
        test_loss
        ))

    test_loss = test_loss / len(test_loader)
    test_losses.append(test_loss)

Epoch: 20\Training Loss: 37.483808
Test Loss: 4.655.. 
