### DO NOT MODIFY THIS FILE

* 01 - Variational Autoencoder

In [1]:
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.layers import Layer
from tensorflow.keras.metrics import binary_accuracy as ba
from tensorflow import shape as tf_shape, round as tf_round
from matplotlib import pyplot as plt
from numpy import float32

In [2]:
def init():
    """Return the default run configuration.

    A new dictionary is created on each call, so callers can modify the
    result without affecting later calls.

    Returns:
        dict with keys ``"EPOCH"`` (10) and ``"SEED"`` (42).
    """
    return dict(EPOCH=10, SEED=42)

In [3]:
def load_cifar_100_data():
    """Load CIFAR-100, rescale the pixels and hold out a validation split.

    Image arrays are cast to ``float32`` and divided by 255, which maps
    8-bit pixel values into [0, 1]. Labels are left untouched. The last 5000
    training examples become the validation set.

    Returns:
        ``(X_train, y_train), (X_valid, y_valid), (X_test, y_test)``
    """
    n_valid = 5000

    def _rescale(pixels):
        # Cast before dividing so the division happens in float32.
        return pixels.astype(float32) / 255

    (train_images, train_labels), (X_test, y_test) = cifar100.load_data()
    train_images = _rescale(train_images)
    X_test = _rescale(X_test)

    X_train, X_valid = train_images[:-n_valid], train_images[-n_valid:]
    y_train, y_valid = train_labels[:-n_valid], train_labels[-n_valid:]
    return (X_train, y_train), (X_valid, y_valid), (X_test, y_test)

In [4]:
def accuracy(y_true, y_pred):
    """Binary accuracy computed after rounding both inputs.

    Both ``y_true`` and ``y_pred`` go through ``tf_round`` (TensorFlow's
    ``round``) before being passed to Keras' ``binary_accuracy``.
    NOTE(review): presumably meant for comparing reconstructed pixels with
    the originals in the VAE exercise — confirm against callers.
    """
    rounded_true = tf_round(y_true)
    rounded_pred = tf_round(y_pred)
    return ba(rounded_true, rounded_pred)

In [5]:
def plot_image(image):
    """Draw ``image`` on the current pyplot axes and hide the axes decoration.

    ``cmap="binary"`` only affects single-channel (2-D) images; for RGB(A)
    input matplotlib ignores the colormap.
    """
    plt.imshow(image, cmap="binary")
    # Equivalent to plt.axis("off"): hide ticks, labels and the frame.
    plt.gca().set_axis_off()

In [6]:
def show_reconstructions(model, images, n_images=10):
    """Plot input images (top row) above the model's reconstructions (bottom row).

    Parameters:
        model: object with a ``predict`` method (presumably a Keras model —
            confirm). It is called once on the selected images, and its
            output is indexed per image.
        images: sized, sliceable collection of images.
        n_images: maximum number of columns to draw. It is clamped to
            ``len(images)``, so a batch smaller than ``n_images`` no longer
            raises ``IndexError``.

    Returns:
        None. If there are no images to show, nothing is drawn and ``model``
        is not called.
    """
    n_images = min(n_images, len(images))
    if n_images <= 0:
        return

    originals = images[:n_images]
    reconstructions = model.predict(originals)

    # 1.5 inches of width per column, two rows in a 3-inch-high figure.
    plt.figure(figsize=(n_images * 1.5, 3))
    for column, (original, reconstruction) in enumerate(zip(originals, reconstructions)):
        plt.subplot(2, n_images, 1 + column)
        plot_image(original)
        plt.subplot(2, n_images, 1 + n_images + column)
        plot_image(reconstruction)

* 02 - Graph Convolutional Neural Network

In [7]:
from torch_geometric.datasets import QM9
from plotly import graph_objects as go
from periodictable import elements

In [8]:
def load_qm9_data(root=".", n_samples=80000):
    """Load the first graphs of QM9 together with label lookup tables.

    Parameters:
        root: directory passed to ``QM9`` as the dataset root (defaults to
            the current directory, as before). NOTE(review): QM9 presumably
            downloads/processes the raw data here on first use — confirm.
        n_samples: number of leading graphs to keep (default 80000).

    Returns:
        ``(dataset, target_dict, chemical_elements)``:
        - ``dataset``: ``QM9(root)[:n_samples]``.
        - ``target_dict``: maps a regression-target column index (0-18) to a
          human-readable name.
        - ``chemical_elements``: maps an atomic number to the upper-cased
          element name from ``periodictable.elements``.
    """
    # Indices follow the target columns documented for
    # torch_geometric.datasets.QM9. Index 3 is the LUMO (lowest *unoccupied*
    # orbital), and index 4 is the HOMO-LUMO gap.
    target_dict = {
        0 : "Dipole moment",
        1 : "Isotropic polarizability",
        2 : "Highest occupied molecular orbital energy",
        3 : "Lowest unoccupied molecular orbital energy",
        4 : "Gap between the highest occupied and lowest unoccupied molecular orbital energy",
        5 : "Electronic spatial extent",
        6 : "Zero point vibrational energy",
        7 : "Internal energy at 0K",
        8 : "Internal energy at 298.15K",
        9 : "Enthalpy at 298.15K",
        10 : "Free energy at 298.15K",
        11 : "Heat capacity at 298.15K",
        12 : "Atomization energy at 0K",
        13 : "Atomization energy at 298.15K",
        14 : "Atomization enthalpy at 298.15K",
        15 : "Atomization free energy at 298.15K",
        16 : "Rotational constant A",
        17 : "Rotational constant B",
        18 : "Rotational constant C"
    }
    chemical_elements = {el.number : el.name.upper() for el in elements}
    return QM9(root)[:n_samples], target_dict, chemical_elements

In [9]:
def layout(*args, **kwargs):
    """Return a Plotly layout dict for an undecorated 3-D molecule view.

    Any positional or keyword arguments are accepted for call-site
    compatibility and ignored.

    Layout settings:
    - All three scene axes share one style: no background, tick labels,
      grid, zero line or title.
    - ``aspectmode="data"`` keeps the axes proportional to the data ranges.
    - The legend is hidden and all margins are zero.
    - The plot area is transparent and the paper background is white.
    """
    axis_style = dict(showbackground=False, showticklabels=False, showgrid=False, zeroline=False, title='')
    return dict(
        showlegend=False,
        # dict(**axis_style) gives each axis its own copy of the shared style.
        scene=dict(aspectmode="data", xaxis=dict(**axis_style), yaxis=dict(**axis_style), zaxis=dict(**axis_style)),
        # The old value "rgba(225,225,225)" had only three channels inside
        # rgba(...), which plotly.js's color parser rejects and replaces with
        # its default. Use an explicit, valid white, as the original comment
        # intended.
        paper_bgcolor="rgb(255,255,255)",
        plot_bgcolor="rgba(0,0,0,0)",  # fully transparent
        margin=dict(l=0, r=0, t=0, b=0)
    )

In [10]:
def visualize_molecule(graph, *args, **kwargs):
    """Build a 3-D Plotly figure of a molecular graph.

    Any extra positional or keyword arguments are accepted and ignored.

    Parameters:
        graph: object exposing
            - ``pos``: per-node coordinates; columns 0-2 are used as x/y/z.
            - ``edge_index``: indexed as 2 x E node indices.
            - ``x``: node features; the first five columns are arg-maxed into
              an atom type.
            The tensor-method calls (``clone``, ``size``, ``argmax``,
            ``masked_fill``) presumably mean these are torch tensors — confirm.

    Returns:
        ``go.Figure`` with one marker trace for the atoms, followed by one
        line trace per edge, using ``layout()``.
    """
    coords = graph.pos.clone()
    edge_index = graph.edge_index

    # Center each axis and scale it to unit standard deviation. Without a
    # guard, two cases produce NaN/inf coordinates:
    # - an axis with zero spread (e.g. a linear molecule lying along a
    #   coordinate axis);
    # - a single-atom graph, whose std is NaN.
    # Such axes are divided by 1 instead.
    spread = coords.std(0)
    spread = spread.masked_fill(~(spread > 0), 1.0)
    coords = (coords - coords.mean(0)) / spread
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

    # Marker size and color: (1 + arg-max over the first five feature columns) * 10.
    atom_type = (1 + graph.x[:, :5].argmax(-1))*10
    data = [go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=atom_type, color=atom_type))]

    # endpoints[:, i, :] holds the coordinates of the two nodes joined by edge i.
    endpoints = coords[edge_index]
    for i in range(edge_index.size(-1)):
        segment = endpoints[:, i, :]
        data.append(go.Scatter3d(x=segment[:, 0], y=segment[:, 1], z=segment[:, 2],
                mode="lines", line=dict(color='black', width=3)))

    return go.Figure(data=data, layout=layout())

In [11]:
def show_prediction(attr, target, pred, idx=500, *args, **kwargs):
    """Plot the first ``idx`` targets and predictions on one figure and show it.

    ``attr`` becomes the plot title. ``target`` and ``pred`` are each
    truncated with ``[:idx]``. Any extra positional or keyword arguments are
    accepted and ignored.
    """
    plt.figure(figsize=(20, 7))
    plt.title(attr)
    for series, series_label in ((target, "Target"), (pred, "Prediction")):
        plt.plot(series[:idx], label=series_label)
    plt.legend()
    plt.show()