In [1]:
import dgl
from data.data import LoadData
import numpy as np
from scipy import sparse as sp
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

Using backend: pytorch


In [2]:
num_atom_type = 28 # KNOWN META-INFO ABOUT ZINC DATASET
pos_emb_dim = 8 # the minimal number of nodes in the dataset is 9
batch_size = 256
nhead = 4

pad_val = num_atom_type
cls_val = num_atom_type+1

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
reduced_dataset = True

verbose = False
def vprint(txt):
    if verbose:
        print(txt)

## dataset

In [3]:
dataset = LoadData('ZINC')
dataset_train = dataset.train
dataset_val = dataset.val

[I] Loading dataset ZINC...
train, test, val sizes : 10000 1000 1000
[I] Finished loading.
[I] Data load time: 10.5564s


In [4]:
adj = dataset_train[0][0].adjacency_matrix()
tmp = adj.to_dense() != 0
tmp

tensor([[False,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [ True, False,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [False,  True, False,  True, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False,  True],
        [False, False,  True, False,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False],
        [False, False, False,  True, False,  True, False, False, False, Fals

## dataloader

In [5]:
def laplacian_positional_encoding(g, pos_enc_dim):
    """
        Graph positional encoding v/ Laplacian eigenvectors
    """

    # Laplacian
    A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
    N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
    L = sp.eye(g.number_of_nodes()) - N * A * N

    # Eigenvectors with numpy
    EigVal, EigVec = np.linalg.eig(L.toarray())
    idx = EigVal.argsort() # increasing order
    EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx])
    g.ndata['lap_pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float()

    return g

def dataset_with_lap(dataset, N=500, nhead=nhead):
    for i in range(len(dataset)):
        if i < N:
            g = dataset[i][0]

            result = dataset[i][1]
            g = laplacian_positional_encoding(g, pos_emb_dim)

            number_of_nodes = g.ndata['feat'].shape[0]

            # pad and add CLS

            feat = torch.full(size=[40], fill_value=pad_val, dtype=g.ndata['feat'].dtype)
            feat[0] = cls_val
            feat[0:number_of_nodes] = g.ndata['feat']

            lap_PE = torch.zeros(size=[40, pos_emb_dim], dtype=g.ndata['lap_pos_enc'].dtype)
            lap_PE[0:number_of_nodes, :] = g.ndata['lap_pos_enc']

            mask = torch.full(size=[40], fill_value=-np.inf)
            mask[0:number_of_nodes] = 0
            
            adj = g.adjacency_matrix().to_dense() != 0
            adj_large = torch.zeros(size=[40,40], dtype=np.bool)
            adj_large[0:number_of_nodes, 0:number_of_nodes] = adj
            connectivity_mask = torch.stack([adj_large]*nhead)

            yield (feat, lap_PE, result, mask, connectivity_mask)

In [6]:
train_loader = DataLoader([g for g in tqdm(dataset_with_lap(dataset_train, 500))],
                    batch_size=batch_size, shuffle=True)

val_loader = DataLoader([g for g in tqdm(dataset_with_lap(dataset_val, 500))],
                    batch_size=batch_size, shuffle=False)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


	DGLGraph.adjacency_matrix(transpose, scipy_fmt="csr").






HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




## model

In [7]:
class Model(nn.Module):
    def __init__(self, d, nlayers=12):
        super().__init__()
        self.d = d
        self.node_feature_embedding = nn.Embedding(num_embeddings=num_atom_type+2, embedding_dim=d) # +1 for the cls, which is the last one.
        self.pos_embedding_linear = nn.Linear(in_features=pos_emb_dim, out_features=d)

        tlayer = nn.TransformerEncoderLayer(d_model=d, nhead=nhead, dim_feedforward=d*4, dropout=0.1, activation='relu')
        self.transformer_encoder = nn.TransformerEncoder(tlayer, nlayers)
        self.transformer_encoder_layer = tlayer

        self.out = nn.Linear(in_features=d, out_features=1)



    def forward(self, node_feat, lap_PE, padd_mask, conn_mask):
        """
        node_feat is (batch x nodes x 28)
        lap_PE is (batch x node x 8)
        """
        batch_size = node_feat.shape[0]
        node_feat = self.node_feature_embedding(node_feat)

        # move the laplacien PE to the 'd' dimension
        lap_PE = self.pos_embedding_linear(lap_PE)

#         activations = lap_PE + node_feat
        activations = node_feat

        vprint('---------------------')
        vprint(f'activation size:{activations.shape}')
        vprint(f'mask size:{padd_mask.shape}')
        activations = activations.permute(1,0,2)
        vprint(f'permuted activation size:{activations.shape}')
        activations = self.transformer_encoder(activations, src_key_padding_mask=padd_mask, mask=conn_mask)
        vprint(f'transfrormed activation size:{activations.shape}')
        activations = activations.permute(1,0,2)
        vprint(f'permuted back activation size:{activations.shape}')
        activations = torch.sum(activations, dim=1)
        vprint(f'summed activation size:{activations.shape}')
        result = self.out(activations)
        vprint(f'result size:{result.shape}')
        return result


In [13]:
model = Model(d=32, nlayers=1).to(device)

# optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.000070)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                 factor=0.5,
                                                 patience=10,
                                                 verbose=True)

def train_epoch(model):
    model.train()
    losses = []
    for batch in train_loader:
        node_feat = batch[0].to(device)
#         print(node_feat)
        pos_emb = batch[1].to(device)
        targets = batch[2].to(device)
        mask = batch[3].to(device)
        connecticity_mask = batch[4].to(device).view(-1, 40, 40)
        optimizer.zero_grad()

        preds = model.forward(node_feat, pos_emb, mask, connecticity_mask)
        loss = nn.L1Loss()(preds, targets)
        loss.backward()
        optimizer.step()

        losses.append(loss.detach().cpu().numpy())

    return np.mean(losses)


def eval_epoch(model):
    model.eval()
    losses = []
    for batch in val_loader:
        node_feat = batch[0].to(device)
        pos_emb = batch[1].to(device)
        targets = batch[2].to(device)
        mask = batch[3].to(device)
        connecticity_mask = batch[4].to(device).view(-1, 40, 40)

        with torch.no_grad():
            preds = model.forward(node_feat, pos_emb, mask, connecticity_mask)
            loss = nn.L1Loss()(preds, targets)

        losses.append(loss.detach().cpu().numpy())

    return np.mean(losses)

## train!

In [14]:
for i in range(1000):
    train_loss = train_epoch(model)
    val_loss = eval_epoch(model)
    scheduler.step(val_loss)

#     if i % 20 == 19:
    print(f'epoch {i}, train loss {train_loss}, validation loss {val_loss}')
    


epoch 0, train loss 26.88532257080078, validation loss 26.27253532409668
epoch 1, train loss 25.87335777282715, validation loss 25.223161697387695
epoch 2, train loss 24.815196990966797, validation loss 24.170955657958984
epoch 3, train loss 23.76415252685547, validation loss 23.115629196166992
epoch 4, train loss 22.72788429260254, validation loss 22.057022094726562
epoch 5, train loss 21.67156410217285, validation loss 20.99496078491211
epoch 6, train loss 20.643638610839844, validation loss 19.929031372070312
epoch 7, train loss 19.551233291625977, validation loss 18.85898780822754
epoch 8, train loss 18.53038787841797, validation loss 17.78445053100586
epoch 9, train loss 17.450328826904297, validation loss 16.70499038696289
epoch 10, train loss 16.351680755615234, validation loss 15.62044906616211
epoch 11, train loss 15.304275512695312, validation loss 14.530267715454102
epoch 12, train loss 14.179101943969727, validation loss 13.434110641479492
epoch 13, train loss 13.1153774261

epoch 109, train loss 1.54024338722229, validation loss 1.5347416400909424
epoch 110, train loss 1.5987354516983032, validation loss 1.5307623147964478
epoch 111, train loss 1.5288629531860352, validation loss 1.5276471376419067
epoch 112, train loss 1.5452933311462402, validation loss 1.5256762504577637
epoch 113, train loss 1.505205512046814, validation loss 1.5209732055664062
epoch 114, train loss 1.5039116144180298, validation loss 1.5153065919876099
epoch 115, train loss 1.5313224792480469, validation loss 1.5106122493743896
epoch 116, train loss 1.5493954420089722, validation loss 1.5051977634429932
epoch 117, train loss 1.533901572227478, validation loss 1.5006155967712402
epoch 118, train loss 1.54325532913208, validation loss 1.496657371520996
epoch 119, train loss 1.5305795669555664, validation loss 1.4936954975128174
epoch 120, train loss 1.5037662982940674, validation loss 1.4900810718536377
epoch 121, train loss 1.5173838138580322, validation loss 1.4849679470062256
epoch 

KeyboardInterrupt: 