In [None]:
import itertools

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn

# Question 1
## (a)

In [None]:
import pickle
from torch.utils.data import Dataset, DataLoader


class MnistDataset(Dataset):
    """
    Map-style dataset that pairs inputs with their labels.

    Subclassing ``torch.utils.data.Dataset`` makes the intent explicit to
    ``DataLoader`` and to readers; the indexing protocol is unchanged.

    Parameters
    ----------
    X:
        Inputs; must expose ``shape`` and support indexing along axis 0.
    y:
        Labels; ``y.shape[0]`` must equal ``X.shape[0]``.
    transform_X: callable
        Applied once to the whole ``X`` at construction time (identity by default).
    """
    def __init__(self, X, y, transform_X=lambda x: x):
        assert X.shape[0] == y.shape[0], (
            f"X and y disagree on the number of samples: {X.shape[0]} vs {y.shape[0]}"
        )
        # the transform is eager: it runs here, not on every __getitem__ call
        self.X = transform_X(X)
        self.y = y

    def __len__(self):
        """Number of samples (taken from the labels)."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at ``idx``."""
        return self.X[idx], self.y[idx]

    
def load_mnist(path):
    """
    Read a pickled ``(train, test)`` pair and convert both splits to tensors.

    Each split is indexed as ``split[0]`` (images) and ``split[1]`` (labels).
    Images become float tensors with an extra axis inserted at position 1;
    labels become long tensors.

    Parameters
    ----------
    path: str
        Location of the pickle file.

    Return
    ------
    X_train, y_train, X_test, y_test: torch.Tensor
    """
    def _to_tensors(split):
        images = torch.tensor(split[0], dtype=torch.float).unsqueeze(1)
        labels = torch.tensor(split[1], dtype=torch.long)
        return images, labels

    # NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
    with open(path, 'rb') as handle:
        train_split, test_split = pickle.load(handle)

    X_train, y_train = _to_tensors(train_split)
    X_test, y_test = _to_tensors(test_split)
    return X_train, y_train, X_test, y_test


In [None]:
# load data and normalize by dividing the maximum value
# TODO(template): every `...` below is a placeholder that must be replaced before
# this cell can run; as written, unpacking `...` (Ellipsis) into four names raises
# a TypeError.
# Presumably the first line calls load_mnist(<path to the pickle>) -- the data
# path is not given in this notebook, so confirm it before filling in.
X_train, y_train, X_test, y_test = ...
X_train = ...
X_test = ...

In [None]:
# further split the train to train/validation with 80/20 
# TODO(template): placeholder -- unpacking `...` into two names raises a TypeError.
# Presumably this uses the train_test_split imported at the top of the notebook
# with a 20% validation share; fix a random_state if the split must be reproducible.
train_index, val_index = ...

# use MnistDataset class to handle the data
# TODO(template): placeholders -- each should become a MnistDataset(X, y) instance.
train_data = ...
val_data = ...
test_data = ...

## (b)

In [None]:
class VAE(nn.Module):
    r"""
    Convolutional variational auto-encoder.

    The layer sizes assume square 32x32 inputs (the size the notebook's plotting
    helper reshapes images to). Every encoder convolution uses
    ``kernel_size=4, stride=2, padding=1`` and therefore halves the spatial size;
    every decoder transposed convolution uses the same settings and doubles it.

    Parameters
    ----------
    in_channels: int
        Number of image channels.
    z_dim: int
        Size of the latent vector z.
    """
    def __init__(self, in_channels=1, z_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 4, kernel_size=4, stride=2, padding=1), # conv1, input_channel -> 4
            nn.ReLU(), # relu
            nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1), # conv2, channel 4 -> 8
            nn.ReLU(), # relu
            nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1), # conv3, channel 8 -> 16
            nn.ReLU(), # relu
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1), # conv4, channel 16 -> 32
            nn.ReLU(), # relu
            nn.Flatten(), # flatten
        )
        
        # manually calculate the dimension after all convolutions:
        # four stride-2 convolutions shrink 32 -> 16 -> 8 -> 4 -> 2
        dim_after_conv = 2
        hidden_dim = 32 * dim_after_conv * dim_after_conv
        
        # Readout layer for mu
        self.readout_mu = nn.Linear(hidden_dim, z_dim)
        # Readout layer for sigma (unconstrained; the KL term only uses sigma^2)
        self.readout_sigma = nn.Linear(hidden_dim, z_dim)
        
        # nn.ConvTranspose2d mirrors the encoder: 2 -> 4 -> 8 -> 16 -> 32
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.Unflatten(1, (32, dim_after_conv, dim_after_conv)),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1), # transpose-conv, channel 32 -> 16
            nn.ReLU(), # relu
            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1), # transpose-conv, channel 16 -> 8
            nn.ReLU(), # relu
            nn.ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1), # transpose-conv, channel 8 -> 4
            nn.ReLU(), # relu
            nn.ConvTranspose2d(4, in_channels, kernel_size=4, stride=2, padding=1), # transpose-conv, channel 4 -> input_channel, which is 1
            nn.Sigmoid(), # use a sigmoid activation to squeeze the outputs between 0 and 1
        )
    
    def reparameterize(self, mu, sigma):
        r"""
        Reparameterize, i.e. generate a z ~ N(\mu, \sigma)

        Sampling is written as a deterministic function of (mu, sigma) plus
        independent noise so that gradients can flow to mu and sigma.
        """
        # generate epsilon ~ N(0, I) with the same shape/dtype/device as sigma
        epsilon = torch.randn_like(sigma)
        # z = \mu + \sigma * \epsilon
        z = mu + sigma * epsilon
        return z

    def encode(self, x):
        """Map a batch of images to the latent parameters ``(mu, sigma)``."""
        # call the encoder to map input to a hidden state vector
        h = self.encoder(x)
        # use the "readout" layers to get \mu and \sigma
        mu = self.readout_mu(h)
        sigma = self.readout_sigma(h)
        return mu, sigma

    def decode(self, z):
        """Map latent vectors back to images with values in [0, 1]."""
        # call the decoder to map z back to x
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample z, decode; returns ``(x_recon, mu, sigma)``."""
        mu, sigma = self.encode(x)
        z = self.reparameterize(mu, sigma)
        x_recon = self.decode(z)
        return x_recon, mu, sigma

## (c)

*For debugging*: once `kld_loss_func` is implemented, calling `test_kld_loss_func()` should print `tensor(1.3863)`.

In [None]:
def kld_loss_func(mu, sigma):
    r"""
    KL-Divergence: KLD = 0.5 * sum(\mu^2 + \sigma^2 - ln(\sigma^2) - 1)
    
    Parameters
    ----------
    mu: torch.Tensor
        Mean vector in the VAE bottleneck region
    sigma: torch.Tensor
        Standard Deviation vector in the VAE bottleneck region
    
    Return
    ------
    kld: torch.Tensor
        KL-Divergence loss (a scalar), summed over every element of the inputs
    """
    # only sigma^2 enters the formula, so the sign of sigma is irrelevant
    variance = sigma ** 2
    return 0.5 * torch.sum(mu ** 2 + variance - torch.log(variance) - 1)


def vae_loss_func(recon_x, x, mu, sigma):
    """
    Total VAE loss: summed binary cross-entropy between the reconstruction and
    the input, plus the KL-divergence term from ``kld_loss_func``.
    """
    reconstruction_criterion = nn.BCELoss(reduction='sum')
    reconstruction_term = reconstruction_criterion(recon_x, x)
    divergence_term = kld_loss_func(mu, sigma)
    total = reconstruction_term + divergence_term
    return total


def test_kld_loss_func():
    """Debugging aid: print ``kld_loss_func`` evaluated on a small fixed example."""
    sample_mu = torch.tensor([0.5, 0.5, 1.0])
    sample_sigma = torch.tensor([1.0, 0.5, 0.5])
    kld_value = kld_loss_func(sample_mu, sample_sigma)
    print(kld_value)

test_kld_loss_func()

In [None]:
import torch.nn.functional as F

class VAETrainer:
    """
    Training/evaluation helper for a VAE-style model.

    The model is called as ``model(X)`` and must return ``(X_recon, mu, sigma)``;
    losses are computed with ``vae_loss_func``.

    Parameters
    ----------
    model: nn.Module
        Model to optimise.
    learning_rate: float
        Adam learning rate.
    batch_size: int
        Batch size used for both training and evaluation loaders.
    epoch: int
        Number of passes over the training data.
    l2: float
        L2 penalty, passed to Adam as ``weight_decay``.
    """
    
    def __init__(self, model, learning_rate, batch_size, epoch, l2):
        self.model = model
        num_params = sum(item.numel() for item in model.parameters())
        print(f"{model.__class__.__name__} - Number of parameters: {num_params}")
        
        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=l2
        )
        
        self.epoch = epoch
        self.batch_size = batch_size

    def _snapshot_weights(self):
        """Return an independent copy of the model's current state dict."""
        # state_dict() hands back references to the live tensors, which the
        # optimizer keeps updating in place; without cloning, the "best"
        # snapshot would silently track the latest weights.
        return {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }

    def train(self, train_data, val_data, early_stop=True, verbose=True, draw_curve=True):
        """
        Fit the model and return the per-epoch loss history.

        Parameters
        ----------
        train_data, val_data:
            Datasets yielding ``(X, y)`` pairs; ``y`` is ignored.
        early_stop: bool
            If True, the weights with the lowest validation loss are restored
            after the final epoch (training itself always runs all epochs).
        verbose: bool
            Currently unused; kept for interface compatibility.
        draw_curve: bool
            If True, plot train/validation loss against the epoch index.

        Return
        ------
        dict with keys ``"train_loss_list"`` and ``"val_loss_list"``.
        """
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        
        train_loss_list = []
        val_loss_list = []
        
        weights = self._snapshot_weights()
        lowest_val_loss = np.inf
        n_train = len(train_data)

        for _ in tqdm(range(self.epoch), leave=False):
            self.model.train()
            epoch_loss = 0.0
            for X_batch, y_batch in train_loader:
                # call the model
                X_batch_recon, mu, sigma = self.model(X_batch)
                batch_loss = vae_loss_func(X_batch_recon, X_batch, mu, sigma)

                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
                
                # the loss is summed over the batch, so dividing by the dataset
                # size accumulates a per-sample average over the epoch
                epoch_loss += batch_loss.detach().cpu().item() / n_train
            
            train_loss_list.append(epoch_loss)
            
            val_loss = self.evaluate(val_data, print_loss=False)
            val_loss_list.append(val_loss)
            
            if early_stop and val_loss < lowest_val_loss:
                lowest_val_loss = val_loss
                weights = self._snapshot_weights()
            
        if draw_curve:
            x_axis = np.arange(self.epoch)
            fig, ax = plt.subplots(1, 1, figsize=(5, 4))
            ax.plot(x_axis, train_loss_list, label="Train")
            ax.plot(x_axis, val_loss_list, label="Validation")
            ax.set_title("Total Loss")
            ax.set_xlabel("# Epoch")
            ax.set_ylabel("Loss")
            # the curves carry labels, so show them
            ax.legend()
        
        if early_stop:
            self.model.load_state_dict(weights)
        
        return {
            "train_loss_list": train_loss_list,
            "val_loss_list": val_loss_list,
        }
    
    def evaluate(self, data, print_loss=True):
        """
        Return the per-sample average of ``vae_loss_func`` over ``data``.

        The model is switched to eval mode and no gradients are recorded.
        """
        self.model.eval()
        loader = DataLoader(data, batch_size=self.batch_size)
        n_samples = len(data)
        total_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in loader:
                X_batch_recon, mu, sigma = self.model(X_batch)
                batch_loss = vae_loss_func(X_batch_recon, X_batch, mu, sigma)
                total_loss += batch_loss.detach().cpu().item() / n_samples
        if print_loss:
            print(f"Total Loss: {total_loss}")
        return total_loss

In [None]:
# TODO(template): `trainer = ...` and the bare `...` below are placeholders.
# Presumably: build a VAETrainer(vae, learning_rate, batch_size, epoch, l2) and
# call trainer.train(train_data, val_data) -- hyperparameters are not specified
# in this notebook and must be chosen.
vae = VAE()
trainer = ...
# train
...

In [None]:
# evaluate the quality of reconstruction
def plot_digits(data, title):
    """
    Draw ``data[0]`` .. ``data[24]`` as 32x32 grayscale images on a 5x5 grid
    with tick marks hidden, under the figure title ``title``.
    """
    fig, axes = plt.subplots(5, 5, figsize=(6, 6))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    fig.suptitle(title)
    for position, panel in enumerate(axes.flat):
        image = data[position].reshape(32, 32)
        panel.imshow(image, cmap='gray')
        panel.set_xticks([])
        panel.set_yticks([])

def compare_reconstruct(model, X):
    """
    Plot the inputs ``X`` and, in a second figure, the first output of
    ``model(X)`` computed without gradient tracking.
    """
    plot_digits(X, "Original Data")
    with torch.no_grad():
        X_recon, _mu, _sigma = model(X)
    plot_digits(X_recon, "Reconstructed Data")


compare_reconstruct(trainer.model, X_test[:100])

# Question 2

In [None]:
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.utils import scatter

## (a)

In [None]:
def load_qm9(path="./QM9"):
    """
    Build the QM9 dataset rooted at ``path`` with a per-sample transform that

    * connects every ordered pair of distinct nodes with an edge,
    * sets the single edge feature to the inverse Euclidean distance between
      the two node positions (``data.pos``),
    * keeps only the target column at index ``-7`` of ``data.y``.
    """
    def transform(data):
        n_nodes = data.x.shape[0]
        # all ordered pairs (i, j) with i != j, transposed so that rows are endpoints
        ordered_pairs = list(itertools.permutations(range(n_nodes), 2))
        edge_index = torch.tensor(ordered_pairs, dtype=torch.long).T

        displacement = data.pos[edge_index[0]] - data.pos[edge_index[1]]
        squared_distance = torch.sum(displacement ** 2, dim=1, keepdim=True)
        inverse_distance = 1 / torch.sqrt(squared_distance)

        data.edge_index = edge_index
        data.edge_attr = inverse_distance
        # the list index keeps y two-dimensional with a single column
        data.y = data.y[:, [-7]]
        return data

    return QM9(path, transform=transform)

# TODO(template): the `...` values are placeholders. load_qm9(...) forwards
# Ellipsis as the dataset path, and unpacking `...` into two names raises a
# TypeError; the split ratio is not specified in this notebook.
qm9 = load_qm9(...)
train_index, test_index = ... # use train_test_split to do the index
train_data = ...
test_data = ...

In [None]:
# find out the dimension of node input features
# TODO(template): placeholder -- presumably the size of the last axis of a
# sample's `x` tensor; confirm against the loaded dataset.
node_input_dim = ...
# load_qm9's transform stores one value per edge (the inverse distance)
edge_input_dim = 1

## (b)

In [None]:
class Layer(nn.Module):
    """
    Basic layer, a linear layer with a ReLU activation 

    Parameters
    ----------
    in_dim: int
        Size of the last axis of the input.
    out_dim: int
        Size of the last axis of the output.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim), # linear layer
            nn.ReLU(), # relu
        )
    
    def forward(self, x):
        """Apply the linear map followed by ReLU (outputs are non-negative)."""
        return self.layers(x)
    
    
class MessagePassingLayer(nn.Module):
    """
    A message passing layer that updates nodes/edge features

    Parameters
    ----------
    node_hidden_dim: int
        Width of the node features going in and coming out.
    edge_hidden_dim: int
        Width of the edge features going in and coming out.
    """
    def __init__(self, node_hidden_dim, edge_hidden_dim):
        super().__init__()
        # input: both endpoint node vectors + the edge vector -> edge width
        self.edge_net = Layer(2 * node_hidden_dim + edge_hidden_dim, edge_hidden_dim)
        # input: node vector + aggregated edge vector -> node width
        self.node_net = Layer(node_hidden_dim + edge_hidden_dim, node_hidden_dim)
    
    def forward(self, node_features, edge_features, edge_index):
        """
        Update node and edge features
        
        Parameters
        ----------
        node_features: torch.Tensor
            Node features from the previous layer, one row per node
        edge_features: torch.Tensor
            Edge features from the previous layer, one row per edge
        edge_index: torch.Tensor
            Index tensor of shape (2, n_edges); ``edge_index[0]`` and
            ``edge_index[1]`` hold the two node indices forming each edge

        Return
        ------
        updated_node_features, updated_edge_features: torch.Tensor
            Same shapes as the corresponding inputs
        """
        first_node, second_node = edge_index[0], edge_index[1]

        # concatenate previous edge features with node features forming the edge
        concate_edge_features = torch.cat([
            node_features[first_node], # features of one node
            node_features[second_node], # features of the other node
            edge_features, # previous edge features
        ], dim=1)
        
        # pass through the "edge_net" to map it back to the original dimension
        updated_edge_features = self.edge_net(concate_edge_features)
        
        # use scatter to sum the edge features onto the node each edge points to;
        # dim_size keeps one row per node even if a node has no incoming edge
        aggr_edge_features = scatter(
            updated_edge_features,
            second_node,
            dim=0,
            dim_size=node_features.shape[0],
            reduce="sum",
        )
        # concatenate it with previous node features
        concate_node_features = torch.cat([node_features, aggr_edge_features], dim=1)
        # pass through the "node_net" to map it back to the original dimension
        updated_node_features = self.node_net(concate_node_features)
        
        return updated_node_features, updated_edge_features

        
class GraphNet(nn.Module):
    """
    Graph network: embed node/edge inputs, run one message passing step, read
    out one scalar per node and sum those scalars within each graph.

    Parameters
    ----------
    node_input_dim, edge_input_dim: int
        Widths of the raw node and edge features.
    node_hidden_dim, edge_hidden_dim: int
        Widths of the embedded node and edge features.
    """
    def __init__(self, node_input_dim, edge_input_dim, node_hidden_dim, edge_hidden_dim):
        super().__init__()
        # embed the input node features
        self.node_embed = Layer(node_input_dim, node_hidden_dim)
        # embed the input edge features
        self.edge_embed = Layer(edge_input_dim, edge_hidden_dim)
        # use a linear layer as readout to get the "atomic" energy contribution
        self.readout = nn.Linear(node_hidden_dim, 1)
        # message passing layer
        self.message_passing = MessagePassingLayer(node_hidden_dim, edge_hidden_dim)
    
    def forward(self, node_features, edge_features, edge_index, batch):
        """
        Predict one value per graph
        
        Parameters
        ----------
        node_features: torch.Tensor
            Raw node features, one row per node
        edge_features: torch.Tensor
            Raw edge features, one row per edge
        edge_index: torch.Tensor
            Index tensor of shape (2, n_edges); each column holds the node
            indices forming an edge
        batch: torch.Tensor
            A 1-D tensor (n_nodes,) that tells you each node belongs to which graph

        Return
        ------
        out: torch.Tensor
            One row per graph, one column
        """
        node_hidden = self.node_embed(node_features) # call the node embedding
        edge_hidden = self.edge_embed(edge_features) # call the edge embedding
        # call the message passing layer
        updated_node_hidden, updated_edge_hidden = self.message_passing(
            node_hidden, edge_hidden, edge_index
        )
        # use the readout layer to output "atomic" contributions
        readout = self.readout(updated_node_hidden)
        # use the scatter function to sum atomic readouts within each graph
        out = scatter(readout, batch, dim=0, reduce="sum")
        return out
        

In [None]:
class GNNTrainer:
    """
    Training/evaluation helper for a graph regression model.

    The model is called as
    ``model(batch.x, batch.edge_attr, batch.edge_index, batch.batch)`` and is
    trained with mean-squared error against ``batch.y``.

    Parameters
    ----------
    model: nn.Module
        Model to optimise.
    batch_size: int
        Number of graphs per batch for both training and evaluation.
    learning_rate: float
        Adam learning rate.
    epoch: int
        Number of passes over the training data.
    l2: float
        L2 penalty, passed to Adam as ``weight_decay``.
    """
    def __init__(self, model, batch_size, learning_rate, epoch, l2):
        self.model = model
        
        num_params = sum(item.numel() for item in model.parameters())
        print(f"{model.__class__.__name__} - Number of parameters: {num_params}")
        
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=l2
        )
        self.epoch = epoch

    def _predict(self, batch_data):
        """Run the model on one mini-batch of graphs."""
        return self.model(
            batch_data.x, batch_data.edge_attr, batch_data.edge_index, batch_data.batch
        )
    
    def train(self, dataset, draw_curve=True):
        """
        Fit the model on ``dataset``.

        Return
        ------
        batch_loss_list: list of float
            MSE of every mini-batch, in the order they were processed.
        """
        self.model.train()
        loader = GraphDataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        
        loss_func = nn.MSELoss()
        batch_loss_list = []
        for i in range(self.epoch):
            print(f"Epoch: {i}")
            for batch_data in tqdm(loader, leave=False):
                batch_pred = self._predict(batch_data)
                batch_loss = loss_func(batch_pred, batch_data.y)

                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                # store a plain float; .cpu() keeps this working for any device
                batch_loss_list.append(batch_loss.detach().cpu().item())
        
        if draw_curve:
            fig, ax = plt.subplots(1, 1, figsize=(5, 4), constrained_layout=True)
            ax.set_yscale("log")
            ax.plot(np.arange(len(batch_loss_list)), batch_loss_list)
            ax.set_xlabel("# Batch")
            ax.set_ylabel("Loss")
        
        return batch_loss_list
    
    def evaluate(self, dataset, draw_curve=True):
        """
        Return the mean-squared error of the model on ``dataset``; optionally
        draw a predicted-vs-truth scatter plot with the identity line in red.
        """
        self.model.eval()
        loader = GraphDataLoader(dataset, batch_size=self.batch_size)
        y_true, y_pred = [], []
        with torch.no_grad():
            for batch_data in tqdm(loader, leave=False):
                batch_pred = self._predict(batch_data)
                y_pred.append(batch_pred.detach().cpu().numpy().flatten())
                y_true.append(batch_data.y.detach().cpu().numpy().flatten())
        
        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)
        mse = np.mean((y_true - y_pred) ** 2)
        
        if draw_curve:
            fig, ax = plt.subplots(1, 1, figsize=(5, 4), constrained_layout=True)
            ax.scatter(y_true, y_pred, label=f"MSE: {mse:.2f}", s=2)
            ax.set_xlabel("Ground Truth")
            ax.set_ylabel("Predicted")
            # use one common range on both axes so the identity line is a diagonal
            xmin, xmax = ax.get_xlim()
            ymin, ymax = ax.get_ylim()
            vmin, vmax = min(xmin, ymin), max(xmax, ymax)
            ax.set_xlim(vmin, vmax)
            ax.set_ylim(vmin, vmax)
            ax.plot([vmin, vmax], [vmin, vmax], color='red')
            ax.legend()
            
        return mse
        

## (c)

In [None]:
# width of the embedded node and edge features
node_hidden_dim = 64
edge_hidden_dim = 64

# TODO(template): placeholder -- GraphNet takes
# (node_input_dim, edge_input_dim, node_hidden_dim, edge_hidden_dim)
net = GraphNet(...)

In [None]:
# train
# TODO(template): every `...` below is a placeholder; the hyperparameter values
# are not specified in this notebook and must be chosen.
learning_rate = ...
n_epoch = ...
batch_size = ...

# GNNTrainer takes (model, batch_size, learning_rate, epoch, l2);
# train takes the training dataset
trainer = GNNTrainer(...)
trainer.train(...)

In [None]:
# evaluate: returns the test-set MSE and, by default, draws the predicted-vs-truth plot
trainer.evaluate(test_data)