In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# from torchvision import transforms

import os
import numpy as np
import PIL as pillow

In [6]:
######################################################################
# Setup working directory
######################################################################
%mkdir -p ./content/A2/
%cd ./content/A2

######################################################################
# Helper functions for loading data
######################################################################
# adapted from
# https://github.com/fchollet/keras/blob/master/keras/datasets/cifar10.py

import os
import pickle
import sys
import tarfile

import numpy as np
from PIL import Image
from six.moves.urllib.request import urlretrieve


def get_file(fname, origin, untar=False, extract=False, archive_format="auto", cache_dir="data"):
    datadir = os.path.join(cache_dir)
    if not os.path.exists(datadir):
        os.makedirs(datadir)

    if untar:
        untar_fpath = os.path.join(datadir, fname)
        fpath = untar_fpath + ".tar.gz"
    else:
        fpath = os.path.join(datadir, fname)

    print("File path: %s" % fpath)
    if not os.path.exists(fpath):
        print("Downloading data from", origin)

        error_msg = "URL fetch failure on {}: {} -- {}"
        try:
            try:
                urlretrieve(origin, fpath)
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
        except (Exception, KeyboardInterrupt) as e:
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

    if untar:
        if not os.path.exists(untar_fpath):
            print("Extracting file.")
            with tarfile.open(fpath) as archive:
                archive.extractall(datadir)
        return untar_fpath

    if extract:
        _extract_archive(fpath, datadir, archive_format)

    return fpath


def load_batch(fpath, label_key="labels"):
    """Internal utility for parsing CIFAR data.
    # Arguments
        fpath: path the file to parse.
        label_key: key for label data in the retrieve
            dictionary.
    # Returns
        A tuple `(data, labels)`.
    """
    f = open(fpath, "rb")
    if sys.version_info < (3,):
        d = pickle.load(f)
    else:
        d = pickle.load(f, encoding="bytes")
        # decode utf8
        d_decoded = {}
        for k, v in d.items():
            d_decoded[k.decode("utf8")] = v
        d = d_decoded
    f.close()
    data = d["data"]
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels


def load_cifar10(transpose=False):
    """Loads CIFAR10 dataset.
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    dirname = "cifar-10-batches-py"
    origin = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    path = get_file(dirname, origin=origin, untar=True)

    num_train_samples = 50000

    x_train = np.zeros((num_train_samples, 3, 32, 32), dtype="uint8")
    y_train = np.zeros((num_train_samples,), dtype="uint8")

    for i in range(1, 6):
        fpath = os.path.join(path, "data_batch_" + str(i))
        data, labels = load_batch(fpath)
        x_train[(i - 1) * 10000 : i * 10000, :, :, :] = data
        y_train[(i - 1) * 10000 : i * 10000] = labels

    fpath = os.path.join(path, "test_batch")
    x_test, y_test = load_batch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if transpose:
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)
    return (x_train, y_train), (x_test, y_test)

/Users/nikh/Columbia/vision2_4732/final_project/content/A2


In [7]:
# Download cluster centers for k-means over colours
colours_fpath = get_file(
    fname="colours", origin="http://www.cs.toronto.edu/~jba/kmeans_colour_a2.tar.gz", untar=True
)
# Download CIFAR dataset
m = load_cifar10()

File path: data/colours.tar.gz
Downloading data from http://www.cs.toronto.edu/~jba/kmeans_colour_a2.tar.gz
Extracting file.
File path: data/cifar-10-batches-py.tar.gz
Downloading data from http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


NameError: name 'URLError' is not defined

In [3]:
class ConvBlock(nn.Module):
    """
    A convolutional block that applies two convolutional operations, each followed by
    a ReLU activation and batch normalization.

    Attributes:
    - in_channels (int): Number of channels in the input.
    - out_channels (int): Number of channels produced by the convolution.
    - kernel_size (int, optional): Size of the convolving kernel. Default: 3.
    - padding (int, optional): Zero-padding added to both sides of the input. Default: 1.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)

    def forward(self, x):
        x = self.relu(self.bn(self.conv1(x)))
        x = self.relu(self.bn(self.conv2(x)))
        return x


class UpConvBlock(nn.Module):
    """
    An upsampling block that first applies a transposed convolution to upsample the input,
    and then performs a convolutional operation to refine the features.

    Attributes:
    - in_channels (int): Number of channels in the input.
    - out_channels (int): Number of channels produced by the convolution.
    - kernel_size (int, optional): Size of the convolving kernel. Default: 3.
    - padding (int, optional): Zero-padding added to both sides of the input. Default: 1.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(UpConvBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

    def forward(self, x, skip):
        x = self.upconv(x)
        x = torch.cat([x, skip], dim=1)
        x = self.conv_block(x)
        return x


In [5]:
class ModularUNet(nn.Module):
    """
    A modular U-Net architecture that can be dynamically adjusted in terms of depth.

    Attributes:
    - in_channels (int): Number of channels in the input image.
    - out_channels (int): Number of channels in the output image.
    - num_layers (int): Total number of layers in the encoder (excluding bottleneck).
    """
    def __init__(self, in_channels=3, out_channels=3, num_layers=6):
        super(ModularUNet, self).__init__()
        self.num_layers = num_layers
        filters = [64 * 2 ** i for i in range(num_layers)]  # Define the number of filters in each layer

        # Initial convolution block
        self.initial_conv = ConvBlock(in_channels, filters[0])

        # Encoder - dynamically create down-sampling layers
        self.down_blocks = nn.ModuleList([ConvBlock(filters[i], filters[i+1]) for i in range(num_layers-1)])
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = ConvBlock(filters[-2], filters[-1])

        # Decoder - dynamically create up-sampling layers
        self.up_blocks = nn.ModuleList([UpConvBlock(filters[i], filters[i-1]) for i in reversed(range(1, num_layers))])

        # Final convolution
        self.final_conv = nn.Conv2d(filters[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Step-by-step through encoder
        skip_connections = []
        x = self.initial_conv(x)
        skip_connections.append(x)

        for down in self.down_blocks:
            x = self.pool(x)
            x = down(x)
            skip_connections.append(x)

        x = self.bottleneck(x)

        # Reverse the skip connections list for the decoder
        skip_connections = skip_connections[::-1]

        # Step-by-step through decoder, using skip connections
        for i, up in enumerate(self.up_blocks):
            x = up(x, skip_connections[i])

        # Apply final convolution to get to the target output size
        x = self.final_conv(x)
        return x