## Imports

In [16]:
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

from fastprogress.fastprogress import master_bar, progress_bar

from jupyterthemes import jtplot

from torch.utils.data import DataLoader, Subset, Dataset
from torchsummary import summary
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
from torchvision import transforms

from PIL import Image

import os

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

jtplot.style(context="talk")

In [17]:
# #makes github ignore all the data

# with open(".gitignore", "w") as f:
#     f.write("""
# # Ignore image data folders
# tiny-imagenet-200/
# tiny-imagenet-200-grayscale/

# # Ignore any .DS_Store or similar files
# .DS_Store

# # Ignore Python cache files
# __pycache__/
# *.pyc
# """)


## Dataset Download

## Dataset Utility

In [18]:
def get_data_loaders(path, batch_size, valid_batch_size=0):

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

    tbs = len(train_dataset) if batch_size == 0 else batch_size
    train_loader = DataLoader(train_dataset, batch_size=tbs, shuffle=True)

    valid_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    vbs = len(valid_dataset) if valid_batch_size == 0 else valid_batch_size
    valid_loader = DataLoader(valid_dataset, batch_size=vbs, shuffle=True)

    return train_loader, valid_loader

In [19]:
def rgb_to_gray(img):
    return img.mean(dim=1, keepdim=True)

## Training Utility

In [20]:
def train_one_epoch(mb, loader, device, model, criterion, optimizer):
    
    model.train()

    losses = []

    num_batches = len(loader)
    dataiterator = iter(loader)

    for batch in progress_bar(range(num_batches), parent=mb):

        mb.child.comment = "Training"

        # Grab the batch of data and send it to the correct device
        X, _ = next(dataiterator)
        Y = X
        Y = rgb_to_gray(Y).to(device)
        X = X.to(device)
        print("X:", X)  # Should be [batch_size, 3, 64, 64]
        print("Y:", Y)


        # Compute the output
        output = model(Y)

        # Compute loss
        
        loss = criterion(output, X)
        losses.append(loss.item())

        # Update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses

## Validation Utility

In [21]:
def validate(mb, loader, device, model, criterion):

    model.eval()

    losses = []
    num_correct = 0

    num_classes = len(loader.dataset.classes)
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    N = len(loader.dataset)
    num_batches = len(loader)
    dataiterator = iter(loader)

    with torch.no_grad():

        batches = range(num_batches)
        batches = progress_bar(batches, parent=mb) if mb else batches
        for batch in batches:

            if mb:
                mb.child.comment = f"Validation"

            # Grab the batch of data and send it to the correct device
            X, _ = next(dataiterator)
            Y = X
            Y = rgb_to_gray(Y).to(device)
            X = X.to(device)

            output = model(Y)

            print(output.shape)
            print(Y.shape)
            loss = criterion(output, X)
            losses.append(loss.item())

            # Convert network output into predictions (one-hot -> number)
            predictions = output.argmax(dim=1)

            # Sum up total number that were correct
            comparisons = predictions == Y
            num_correct += comparisons.type(torch.float).sum().item()

            # Sum up number of correct per class
            for result, clss in zip(comparisons, Y):
                class_correct[clss] += result.item()
                class_total[clss] += 1

    accuracy = 100 * (num_correct / N)
    accuracies = {
        clss: 100 * class_correct[clss] / class_total[clss]
        for clss in range(num_classes)
    }

    return losses, accuracy, accuracies

## Loss Plotting Utility

In [22]:
def update_plots(mb, train_losses, valid_losses, epoch, num_epochs):

    # Update plot data
    max_loss = max(max(train_losses), max(valid_losses))
    min_loss = min(min(train_losses), min(valid_losses))

    x_margin = 0.2
    x_bounds = [0 - x_margin, num_epochs + x_margin]

    y_margin = 0.1
    y_bounds = [min_loss - y_margin, max_loss + y_margin]

    train_xaxis = torch.linspace(0, epoch + 1, len(train_losses))
    valid_xaxis = torch.linspace(0, epoch + 1, len(valid_losses))
    graph_data = [[train_xaxis, train_losses], [valid_xaxis, valid_losses]]

    mb.update_graph(graph_data, x_bounds, y_bounds)

## Data Loading

In [23]:
# TODO: tune the training batch size
train_batch_size = 128

# Let's use some shared space for the data (so that we don't have copies
# sitting around everywhere)
data_path = "~/data"

# Use the GPUs if they are available
# TODO: if you run into GPU memory errors you should set device to "cpu" and restart the notebook
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cuda"
print(f"Using '{device}' device.")

valid_batch_size = 32
train_loader, valid_loader = get_data_loaders(data_path, train_batch_size, valid_batch_size)


# Input and output sizes depend on data

# class_names = sorted(os.listdir(grayscale_path))
# num_classes = len(class_names)


# print(class_names)

Using 'cuda' device.
Files already downloaded and verified
Files already downloaded and verified


In [24]:
# # Grab a bunch of images and change the range to [0, 1]
# nprint = 64
# images = torch.tensor(train_loader.dataset.data[:nprint] / 255)
# targets = train_loader.dataset.targets[:nprint]
# labels = [f"{class_names[target]:>10}" for target in targets]

# # Create a grid of the images (make_grid expects (BxCxHxW))
# image_grid = make_grid(images.permute(0, 3, 1, 2))

# _, ax = plt.subplots(figsize=(16, 16))
# ax.imshow(image_grid.permute(1, 2, 0))
# ax.grid(None)

# images_per_row = int(nprint ** 0.5)
# for row in range(images_per_row):
#     start_index = row * images_per_row
#     print(" ".join(labels[start_index : start_index + images_per_row]))

## Model Creation

In [25]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda")


# Simple CNN
class ColorizationCNN(nn.Module):
    def __init__(self):
        super(ColorizationCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=4, dilation=2)
        self.relu1 = nn.ReLU(True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=4, dilation=2)
        self.relu2 = nn.ReLU(True)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=4, dilation=2)
        self.relu3 = nn.ReLU(True)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=5, stride=1, padding=4, dilation=2)
        self.relu4 = nn.ReLU(True)
        self.conv5 = nn.Conv2d(256, 128, kernel_size=5, stride=1, padding=4, dilation=2)
        self.relu5 = nn.ReLU(True)
        self.conv6 = nn.Conv2d(128, 3, kernel_size=5, stride=1, padding=4, dilation=2)
        self.relu6 = nn.ReLU(True)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.relu(self.conv3(x))
        x = nn.functional.relu(self.conv4(x))
        x = nn.functional.relu(self.conv5(x))
        x = torch.sigmoid(self.conv6(x))
        return x



    
    # def __init__(self):
    #     super().__init__()
    #     self.encoder = nn.Sequential(
    #         nn.Conv2d(1, 64, 3, padding=1),  # grayscale input
    #         nn.ReLU(),
    #         nn.MaxPool2d(2, 2),
    #         nn.Conv2d(64, 128, 3, padding=1),
    #         nn.ReLU(),
    #         nn.MaxPool2d(2, 2)
    #     )
    #     self.decoder = nn.Sequential(
    #         nn.ConvTranspose2d(128, 64, 2, stride=2),
    #         nn.ReLU(),
    #         nn.ConvTranspose2d(64, 3, 2, stride=2),  # 3-channel output
    #         nn.Sigmoid()  # values in [0, 1]
    #     )

    # def forward(self, x):
    #     x = self.encoder(x)
    #     x = self.decoder(x)
    #     return x


In [26]:
# TODO: try out different network widths and depths
# neurons_per_hidden_layer = [1024, 512, 256]
# layer_sizes = [num_features, *neurons_per_hidden_layer, num_classes]
# model = NeuralNetwork(layer_sizes).to(device)

# TODO: complete the CNN class in the cell above this one and then uncomment this line
# model = CNN().to(device)

# TODO: use an off-the-shell model from PyTorch
# from torchvision.models import ...
# model = ...

# TINT TODO: make the output have 3 nodes.
# from torchvision.models import resnet18
# model = resnet18(num_classes=num_classes).to(device)

# summary(model)

#TINT
# Instantiate the model
model = ColorizationCNN().to(device)
summary(model)
# pixel-wise loss
criterion = nn.MSELoss()

# TODO: try out different Adam hyperparameters
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()

Layer (type:depth-idx)                   Param #
├─Conv2d: 1-1                            1,664
├─ReLU: 1-2                              --
├─Conv2d: 1-3                            102,464
├─ReLU: 1-4                              --
├─Conv2d: 1-5                            204,928
├─ReLU: 1-6                              --
├─Conv2d: 1-7                            819,456
├─ReLU: 1-8                              --
├─Conv2d: 1-9                            819,328
├─ReLU: 1-10                             --
├─Conv2d: 1-11                           9,603
├─ReLU: 1-12                             --
Total params: 1,957,443
Trainable params: 1,957,443
Non-trainable params: 0


In [27]:
model = ColorizationCNN().to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [28]:
# Test single batch

# gray_batch, color_batch = next(iter(train_loader))
# gray_batch, color_batch = gray_batch.to(device), color_batch.to(device)
# output = model(gray_batch)
# print("Output shape:", output.shape)  # Should be [batch_size, 3, 64, 64]

        


## Training and Analysis

In [29]:
# TODO: tune the number of epochs
num_epochs = 3

train_losses = []
valid_losses = []
accuracies = []

# A master bar for fancy output progress
mb = master_bar(range(num_epochs))
mb.names = ["Train Loss", "Valid Loss"]
mb.main_bar.comment = f"Epochs"

# Loss and accuracy prior to training
vl, accuracy, _ = validate(None, valid_loader, device, model, criterion)
valid_losses.extend(vl)
accuracies.append(accuracy)

for epoch in mb:

    tl = train_one_epoch(mb, train_loader, device, model, criterion, optimizer)
    train_losses.extend(tl)

    vl, accuracy, acc_by_class = validate(mb, valid_loader, device, model, criterion)
    valid_losses.extend(vl)
    accuracies.append(accuracy)

    update_plots(mb, train_losses, valid_losses, epoch, num_epochs)

torch.Size([32, 3, 32, 32])
torch.Size([32, 1, 32, 32])


TypeError: only integer tensors of a single element can be converted to an index

In [None]:
plt.plot(accuracies, '--o')
plt.title("Accuracy")
plt.xlabel("Epoch")
plt.xticks(range(num_epochs+1))
plt.ylim([0, 100])

max_name_len = max(len(name) for name in class_names)

print("Accuracy per class")
for clss in acc_by_class:
    class_name = class_names[clss]
    class_accuracy = acc_by_class[clss]
    print(f"  {class_name:>{max_name_len+2}}: {class_accuracy:.1f}%")

In [None]:
y_trues = []
y_preds = []
model.to(device)
for x, y in valid_loader:
    y_trues.append(y.cpu())
    y_preds.append(model(x.to(device)).argmax(dim=1).cpu())

y_true = torch.hstack(y_trues)
y_pred = torch.hstack(y_preds)

In [31]:
cm = confusion_matrix(y_true, y_pred)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names).plot();
plt.grid(False)

NameError: name 'y_true' is not defined

In [None]:
# TODO: Take the three outputs and reconstruct an image