In [None]:
# Notebook-wide IPython setup.
# Load the lab_black extension (presumably formats cells with black — confirm it is installed).
%load_ext lab_black
# Draw matplotlib figures inline in the cell output.
%matplotlib inline
# Switch IPython's completer to its "greedy" mode.
%config IPCompleter.greedy=True

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# torch.manual_seed(0)
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torchvision

from torchsummary import summary

from ray import tune
from ray.tune.schedulers import ASHAScheduler

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import time
from pathlib import Path

# face dataset

The following class reads the face dataset and wraps it in a torch `Dataset` object, so that you can
feed it to a `DataLoader` when training your model.

**1** Make sure that the file "hw2_Q1.npy" is located properly (in this example, it should be in a `data/` folder next to this notebook, i.e. `./data/hw2_Q1.npy`).

**2** Note that the "hw2_Q1.npy" stores images in uint8 format. To use it for our purpose, we convert it to float32. You need to do the same for Q2 and Q3 of the assignment  



In [None]:
class FaceData(Dataset):
    """Torch dataset over the face images stored in ``data/hw2_Q1.npy``.

    The array is loaded once, converted to float32, divided by 255 and
    reshaped to ``(num_images, 64 * 64)``.

    Args:
        ray_tune: when True, ``data/`` is resolved against the directory two
            levels above the current working directory; otherwise against
            ``"."``.
    """

    def __init__(self, ray_tune=False):
        if ray_tune:
            # Ray Tune requires an absolute path
            # go back 2 folders since ray goes 2 deeper
            base_dir = str(Path.cwd().parents[1])
        else:
            # parents[1] is only evaluated when needed: a cwd with fewer than
            # two ancestors would otherwise raise IndexError for a value that
            # is thrown away
            base_dir = "."

        raw_images = np.load(f"{base_dir}/data/hw2_Q1.npy")
        self.images = (np.float32(raw_images) / 255.0).reshape(-1, 64 * 64)

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the flattened image(s) selected by ``idx``.

        Tensor indices are converted to plain Python ints/lists first.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # single index, shape: (4096,)
        # multi-index, shape: (len(idx), 4096)
        return self.images[idx, :]

Here is an example of how you can create a dataloader for the face data

In [None]:
# wrap the face images in a dataset and serve them in shuffled mini-batches of 8
tmp_dataset = FaceData()
tmp_loader = DataLoader(dataset=tmp_dataset, shuffle=True, batch_size=8)

Let's visualize some of the samples

In [None]:
# pull one batch and lay its first 8 images out on a 2x4 grid
image_batch = next(iter(tmp_loader))
fig, ax_arr = plt.subplots(2, 4)
# ax_arr.flat walks the grid row by row, matching (i // 4, i % 4)
for i, axis in enumerate(ax_arr.flat):
    face = image_batch[i].numpy()
    axis.imshow(face.reshape([64, 64]), cmap="gray")
    axis.axis("off")
fig.set_size_inches(20, 10)
plt.show()

# Defining the model and training function

In [None]:
# Defining our neural network
class AE(nn.Module):
    """Two-layer fully connected autoencoder for flattened 64x64 images.

    Data flow: input (4096) -> fc1 -> ReLU -> fc2 -> output (4096).
    """

    def __init__(self, n):
        """
        Build the encoder/decoder pair.

        n: number of units in the hidden layer between fc1 and fc2
        """
        super().__init__()
        # encoder: 4096 -> n, decoder: n -> 4096
        self.fc1 = nn.Linear(in_features=4096, out_features=n)
        self.fc2 = nn.Linear(in_features=n, out_features=4096)

        # identifier made of the class name and hidden size, e.g. "AE_n16"
        self._name = f"{type(self).__name__}_n{n}"

    def forward(self, x):
        """Encode ``x`` with fc1 + ReLU, then decode it with fc2."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

## Encapsulating the training function to use with Ray Tune to identify the learning rate

In [None]:
def train_model(config):
    """Train an ``AE`` on ``FaceData`` using the settings in ``config``.

    The original training function has been modified in order to use Ray's Tune.

    Args:
        config (dict): settings read by this function:
            ``n`` (passed to ``AE``), ``lr`` (Adam learning rate),
            ``batch_size``, ``num_epochs``, ``ray_tune_enabled`` (forwarded to
            ``FaceData`` and turns on ``tune.report``), ``log_training`` and
            ``log_interval`` (periodic loss printing), ``save_model`` and
            ``save_interval`` (periodic ``state_dict`` checkpoints under
            ``./models``).

    Returns:
        dict or None: when ``config["log_training"]`` is truthy, a dict with
        ``"train"`` (mean batch loss per epoch) and ``"test"`` (left at zeros,
        there is no test set); otherwise ``None``.
    """
    num_epochs = config["num_epochs"]
    logger = {
        "train": np.zeros(num_epochs),
        "test": np.zeros(num_epochs),
    }

    #### LOAD DATA ####
    # no test/validation set
    train_data = FaceData(config["ray_tune_enabled"])
    train_dataloader = DataLoader(
        train_data,
        batch_size=config["batch_size"],
        # 4 loader workers per visible GPU (evaluates to 0 when no GPU is visible)
        num_workers=4 * torch.cuda.device_count(),
        shuffle=True,
        pin_memory=False,
    )

    #### INSTANTIATE MODEL ####
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = AE(config["n"]).to(device)

    loss_function = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=config["lr"])

    #### BEGIN TRAINING ####
    start_time = time.time()
    for epoch in range(num_epochs):
        ## START OF BATCH ##
        train_loss = 0.0
        train_steps = 0
        for data in train_dataloader:
            data = data.to(device)
            prediction = net(data)

            # there are no labels: the input is its own reconstruction target
            loss = loss_function(prediction, data)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() gives a plain Python float, so the running sum does not
            # depend on NumPy scalar promotion rules
            train_loss += loss.item()
            train_steps += 1
        ## END OF BATCH ##

        epoch_loss = train_loss / train_steps
        # send current training result back to Tune
        if config["ray_tune_enabled"]:
            tune.report(loss=epoch_loss)

        logger["train"][epoch] = epoch_loss

        if config["log_training"] and (epoch + 1) % config["log_interval"] == 0:
            # single-line f-string: the former backslash continuation leaked
            # the next line's indentation into the printed message
            print(f"Epoch:{epoch + 1}/{num_epochs} Train_Loss: {epoch_loss:.6f}")

        if config["save_model"] and (epoch + 1) % config["save_interval"] == 0:
            # zero-pad the epoch so checkpoint names sort in training order
            checkpoint_num = str(epoch + 1).zfill(len(str(num_epochs)))
            model_path = Path("./models") / f"{net._name}_{checkpoint_num}.pt"
            # make sure the target folder exists before writing the checkpoint
            model_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(model_path))

    print(f"{num_epochs} epochs took {time.time() - start_time:.2f}s")

    if config["log_training"]:
        return logger

In [None]:
# create separate folder to store our models
# (pathlib instead of `!mkdir`: portable, and re-running the cell is a no-op)
Path("models").mkdir(exist_ok=True)

# Run Tune using ASHA Scheduler to find ideal parameters

## run Tune for n = 16

In [None]:
# The search is expensive and requests one GPU per trial, so it is opt-in:
# set the flag to True to launch it. A flag (rather than `assert False`)
# keeps "Restart & Run All" from aborting at this cell.
RUN_TUNE_N16 = False

if RUN_TUNE_N16:
    search_space = {
        "n": 16,
        "lr": tune.loguniform(1e-6, 1e-1),
        "batch_size": tune.choice([8, 16, 32, 64]),
        "log_training": False,
        "log_interval": 10,
        "save_model": False,
        "save_interval": 10,
        "num_epochs": 200,
        "ray_tune_enabled": True,
    }
    # enable early stopping
    asha_scheduler = ASHAScheduler(max_t=200, grace_period=25)
    # number of samples to run
    n_samples = 20
    # run training with Tune
    analysis = tune.run(
        train_model,
        num_samples=n_samples,
        config=search_space,
        resources_per_trial={"gpu": 1},
        scheduler=asha_scheduler,
        metric="loss",
        mode="min",
        local_dir="./",
    )

## run Tune for n = 64

In [None]:
# The search is expensive and requests one GPU per trial, so it is opt-in:
# set the flag to True to launch it. A flag (rather than `assert False`)
# keeps "Restart & Run All" from aborting at this cell.
RUN_TUNE_N64 = False

if RUN_TUNE_N64:
    search_space = {
        "n": 64,
        "lr": tune.loguniform(1e-6, 1e-1),
        "batch_size": tune.choice([8, 16, 32, 64]),
        "log_training": False,
        "log_interval": 10,
        "save_model": False,
        "save_interval": 10,
        "num_epochs": 200,
        "ray_tune_enabled": True,
    }
    # enable early stopping
    asha_scheduler = ASHAScheduler(max_t=200, grace_period=25)
    # number of samples to run
    n_samples = 20
    # run training with Tune
    analysis = tune.run(
        train_model,
        num_samples=n_samples,
        config=search_space,
        resources_per_trial={"gpu": 1},
        scheduler=asha_scheduler,
        metric="loss",
        mode="min",
        local_dir="./",
    )

# Use ideal parameters from Tune runs

## Train model for n = 16

In [None]:
# hyper-parameters for the n = 16 model
model_config = {
    "n": 16,
    "lr": 9.33e-5,
    "batch_size": 16,
    "log_training": True,
    "log_interval": 10,
    # checkpoints must be written: the reconstruction cell further down
    # loads models/AE_n16_200.pt
    "save_model": True,
    "save_interval": 10,
    "num_epochs": 200,
    "ray_tune_enabled": False,
}

# run model training
training_log_n16 = train_model(model_config)

## Train model for n = 64

In [None]:
# hyper-parameters for the n = 64 model
model_config = {
    "n": 64,
    "lr": 1.67e-4,
    "batch_size": 64,
    "log_training": True,
    "log_interval": 10,
    # checkpoints must be written: the reconstruction cell further down
    # loads models/AE_n64_200.pt
    "save_model": True,
    "save_interval": 10,
    "num_epochs": 200,
    "ray_tune_enabled": False,
}

# run model training
training_log_n64 = train_model(model_config)

# Visualizing the training progress

In [None]:
# plot the results
all_logs = [training_log_n16, training_log_n64]
# one legend entry per log, in the same order as all_logs
all_labels = [
    r"n = 16, $\eta$ = 9.33e-5, batch_size = 16",
    r"n = 64, $\eta$ = 1.67e-4, batch_size = 64",
]

fig, ax = plt.subplots()
for log, label_str in zip(all_logs, all_labels):
    train_curve = log.get("train")
    # take the epoch axis from the curve itself, so it cannot drift from
    # whichever model_config happens to be live when this cell runs
    x_axis = np.arange(1, len(train_curve) + 1)
    ax.plot(x_axis, train_curve, label=label_str)
ax.set_ylabel("Loss")
ax.set_xlabel("Epochs")
# descriptive title in place of the "ASD" placeholder
ax.set_title("Training loss (MSE) per epoch")
fig.set_figheight(10)
fig.set_figwidth(16)
ax.legend(loc="best", prop={"size": 20})
plt.show()

# Test our model on some outputs

In [None]:
n_faces = 6
# build the loader first and only then draw from it, so `sample` holds
# exactly n_faces images (it used to come from the older batch-of-8 loader)
tmp_loader = DataLoader(FaceData(), batch_size=n_faces, shuffle=True)
sample = next(iter(tmp_loader))

# instantiate models on CPU
ae_n16 = AE(n=16)
ae_n64 = AE(n=64)
# load saved checkpoints for both models, taking model @ 200 epochs
# map_location="cpu" lets GPU-trained checkpoints load on a CPU-only machine
ae_n16.load_state_dict(torch.load("models/AE_n16_200.pt", map_location="cpu"))
ae_n64.load_state_dict(torch.load("models/AE_n64_200.pt", map_location="cpu"))

ae_n16.eval()
ae_n64.eval()
# run inference without building an autograd graph
with torch.no_grad():
    out_n16 = ae_n16(sample)
    out_n64 = ae_n64(sample)

# visualization: row 0 = inputs, row 1 = n16 output, row 2 = n64 output
fig, ax = plt.subplots(3, n_faces)
for k in range(n_faces):
    ax[0, k].imshow(sample[k].reshape(64, 64), cmap="gray")
    ax[1, k].imshow(out_n16[k].reshape(64, 64), cmap="gray")
    ax[2, k].imshow(out_n64[k].reshape(64, 64), cmap="gray")

for i in range(3):
    for j in range(n_faces):
        ax[i, j].axis("off")

fig.set_figheight(10)
fig.set_figwidth(20)
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

# Visualizing the fc1 weights
**Note:** this section loads the saved checkpoints from `models/`, so the model must first be trained with `"save_model": True`.

In [None]:
# tuned values:
# n = 16, lr = 9.33e-5
# n = 64, lr = 1.67e-4
model_config = {
    "n": 16,
    # use the tuned rate listed above for n = 16; the previous 0.9 contradicted it
    "lr": 9.33e-5,
    "batch_size": 16,
    "log_training": True,
    "log_interval": 10,
    # checkpoints are needed by the weight-visualization cells that follow
    "save_model": True,
    "save_interval": 10,
    "num_epochs": 200,
    "ray_tune_enabled": False,
}
log = train_model(model_config)

In [None]:
# overwrite previous sample
sample = next(iter(tmp_loader))

# parameter n for different models
N = [16, 64, 256]
save_interval = model_config.get("save_interval")
n_epochs = model_config.get("num_epochs")
n_models = n_epochs // save_interval

# only the last saved checkpoint is visualized; derive its epoch from the
# config instead of hard-coding index 19 and an interval of 10
final_epoch = n_models * save_interval
model = str(final_epoch).zfill(len(str(n_epochs)))

for fc1_n in N:
    model_path = f"models/AE_n{fc1_n}_{model}.pt"
    # not every n in N has necessarily been trained with save_model=True
    if not Path(model_path).exists():
        print(f"Skipping n = {fc1_n}: no checkpoint at {model_path}")
        continue

    # instantiate model and load the saved checkpoint
    net = AE(fc1_n)
    net.load_state_dict(torch.load(model_path, map_location="cpu"))

    print(f"Loaded model from {model_path}")

    net.eval()

    with torch.no_grad():
        # fc1 weights with each unit's bias added to every entry of its row,
        # shape=(n, 4096)
        fc1_weights = net.fc1.weight + net.fc1.bias.unsqueeze(1)
        # reshape for plotting: one 64x64 image per hidden unit
        fc1_weights = fc1_weights.reshape(-1, 64, 64)

    # autodetect grid side based on n parameter in the network
    # MUST BE SQUARE!!!!
    L = int(np.sqrt(fc1_weights.shape[0]))
    fig, ax = plt.subplots(L, L)
    for idx in range(fc1_weights.shape[0]):
        ax[idx // L, idx % L].imshow(fc1_weights[idx], cmap="gray")
        ax[idx // L, idx % L].axis("off")
    fig.set_figheight(10)
    fig.set_figwidth(20)
    plt.show()

In [None]:
# show the input faces from the current `sample` batch
# (`tmp_faces` was never defined in this notebook; `sample` holds the n_faces images)
fig, ax = plt.subplots(1, n_faces)
for k in range(n_faces):
    ax[k].imshow(sample[k].reshape(64, 64), cmap="gray")
    ax[k].axis("off")
fig.set_figheight(10)
fig.set_figwidth(20)
plt.show()

In [None]:

# Extract the weights of fc1 layer for net_n16
# (each unit's bias is added to every entry of its weight row before plotting)
fc1_n16 = (ae_n16.fc1.weight + ae_n16.fc1.bias.unsqueeze(1)).reshape(-1, 64, 64)
fc1_n16 = fc1_n16.cpu().detach().numpy()
# side length of the square subplot grid
L = np.sqrt(fc1_n16.shape[0]).astype(int)
fig, ax = plt.subplots(L, L)
for k, tile in enumerate(fc1_n16):
    row, col = divmod(k, L)
    ax[row, col].imshow(tile, cmap="gray")
    ax[row, col].axis("off")
fig.set_size_inches(10, 10)
plt.show()

In [None]:
# Load models for inference
# Uses names that exist in this notebook: AE (not SimpleAE), save_interval
# (not checkpoint_interval) and the `sample` batch (not tmp_faces).
N = 16
net = AE(N)

for ckpt_idx in range(n_epochs // save_interval):
    model = str((ckpt_idx + 1) * save_interval).zfill(len(str(n_epochs)))
    model_path = f"models/{net._name}_{model}.pt"
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
    print(f"Loaded {model_path}")
    net.eval()

    # run inference on CPU, where both the model and `sample` live
    with torch.no_grad():
        output = net(sample)

    fig, ax = plt.subplots(1, 4)
    # separate loop variable: the inner loop used to reuse the checkpoint index `k`
    for face_idx in range(4):
        img = output[face_idx].reshape(64, 64).cpu().detach().numpy()
        ax[face_idx].imshow(img, cmap="gray")
        ax[face_idx].axis("off")
    fig.set_figheight(10)
    fig.set_figwidth(20)
    plt.show()
    plt.pause(1)

In [None]:
# run inference with the n = 64 model on the current `sample` batch
# (`tmp_faces` and `device` were never defined; model and batch are both on CPU)
with torch.no_grad():
    output = ae_n64(sample)

# one row with n_faces panels; a single-row grid is a 1-D array of axes,
# so it is indexed with one subscript
fig, ax = plt.subplots(1, n_faces)
for k in range(n_faces):
    img = output[k].reshape(64, 64).cpu().detach().numpy()
    ax[k].imshow(img, cmap="gray")
    ax[k].axis("off")
fig.set_figheight(10)
fig.set_figwidth(20)
plt.show()

In [None]:
# Print model's state_dict: parameter name and tensor size
# (`net_n16` was never defined; the n = 16 model in this notebook is `ae_n16`)
print("Model's state_dict:")
for param_name, param in ae_n16.state_dict().items():
    print(param_name, "\t", param.size())

In [None]:
# Print the n = 16 autoencoder module (its textual representation)
print(ae_n16)

In [None]:
# input_size excludes the batch dimension, so a flattened image is (4096,);
# device="cpu" because ae_n16 lives on the CPU and torchsummary defaults to "cuda"
summary(ae_n16, (4096,), device="cpu")