In [16]:
import sys
import logging

# Make modules that live in the current working directory importable.
sys.path.append(".")  

# Experiment tag; it names the log file opened below.
EXP_NAME = "BN"

# Log file opened in append mode. The handle is intentionally left open because
# the stream hooks below keep referring to it for the life of the kernel.
nblog = open("{}.log".format(EXP_NAME), "a+")
# NOTE(review): presumably the kernel's stdout/stderr stream objects mirror
# their output to whatever file object is stored in `echo` -- confirm against
# the installed ipykernel version.
sys.stdout.echo = nblog
sys.stderr.echo = nblog

# Point the first handler of the IPython application logger at the same file
# and lower the logger threshold to INFO.
get_ipython().log.handlers[0].stream = nblog
get_ipython().log.setLevel(logging.INFO)

# IPython magic (not plain Python): sets the notebook autosave interval to 5 seconds.
%autosave 5

Autosaving every 5 seconds


In [10]:
import datetime
import json
import os
import pickle
import time

import matplotlib.pyplot as plt
import torch
from torch import nn, optim

from models.gcl import unsorted_segment_sum, E_GCL
from qm9 import dataset
from qm9 import utils as qm9_utils

In [14]:
class EGNN_BN(nn.Module):
    """Batch normalisation of the node features ``h`` with a learnable affine map.

    In training mode the features are normalised with the statistics of the
    current batch (taken over dim 0) and exponential moving averages of those
    statistics are updated. In eval mode the stored moving averages are used.

    Args:
        num_features: Size of the feature dimension of ``h``.
        x_dim: Accepted for interface compatibility; currently unused because
            coordinate normalisation is disabled.
        momentum: Weight given to the *previous* moving average on each update
            (``new = momentum * old + (1 - momentum) * batch``). Note this is
            the opposite convention to ``torch.nn.BatchNorm1d``.
        eps: Constant added to the variance before the reciprocal square root.
        pos: Stored on the instance; not read by this class.
    """

    def __init__(self, num_features, x_dim, momentum=0.9, eps=1e-5, pos=False):
        super().__init__()

        h_shape = (1, num_features)

        # Learnable scale and shift, initialised to the identity transform.
        self.gamma_h = nn.Parameter(torch.ones(h_shape))
        self.beta_h = nn.Parameter(torch.zeros(h_shape))

        self.pos = pos

        self.momentum = momentum
        self.eps = eps

        # Running mean starts at zero and running variance at one, so that an
        # untrained module in eval mode is (up to eps) the identity. The mean
        # used to start at one, which shifted eval-mode inputs by -1.
        self.register_buffer('moving_mean_h', torch.zeros(h_shape))
        self.register_buffer('moving_var_h', torch.ones(h_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the running statistics to mean 0 and variance 1."""
        self.moving_mean_h.zero_()
        self.moving_var_h.fill_(1)

    def forward(self, h, x):
        """Normalise ``h`` and apply the learnable scale and shift.

        Args:
            h: Node features; statistics are reduced over dim 0 and the last
                dimension must broadcast against ``(1, num_features)``.
            x: Ignored (coordinate normalisation is disabled).

        Returns:
            Tensor with the same shape as ``h``.
        """
        if self.training:
            var_h, mean_h = torch.var_mean(h, dim=0, keepdim=True, unbiased=False)

            # Update the running statistics outside autograd. Without this the
            # in-place updates would attach the buffers to the batch's graph,
            # chaining every iteration's graph to the previous one.
            with torch.no_grad():
                self.moving_mean_h.mul_(self.momentum).add_((1 - self.momentum) * mean_h)
                self.moving_var_h.mul_(self.momentum).add_((1 - self.momentum) * var_h)
        else:
            var_h = self.moving_var_h
            mean_h = self.moving_mean_h

        h = (h - mean_h) * torch.rsqrt(var_h + self.eps)

        out_h = h * self.gamma_h + self.beta_h
        return out_h

class E_GCL_mask(E_GCL):
    """``E_GCL`` layer that masks edge messages and leaves coordinates unchanged.

    The constructor forwards its arguments to ``E_GCL`` (mapping
    ``nodes_attr_dim`` onto the parent's ``nodes_att_dim`` keyword) and then
    deletes the ``coord_mlp`` attribute, so this layer only updates the node
    features ``h``.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_attr_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, norm_diff=False, attention=False):
        super().__init__(
            input_nf,
            output_nf,
            hidden_nf,
            edges_in_d=edges_in_d,
            nodes_att_dim=nodes_attr_dim,
            act_fn=act_fn,
            recurrent=recurrent,
            coords_weight=coords_weight,
            norm_diff=norm_diff,
            attention=attention,
        )

        # Coordinates are never updated by this layer, so drop the unused MLP.
        del self.coord_mlp
        self.act_fn = act_fn

    def coord_model(self, coord, edge_index, coord_diff, edge_feat, edge_mask):
        """Masked coordinate update; not called by ``forward``.

        NOTE(review): this reads ``self.coord_mlp``, which ``__init__`` deletes,
        so calling it raises ``AttributeError`` unless that attribute has been
        assigned again. ``coord`` is modified in place.
        """
        senders, _ = edge_index
        edge_shift = coord_diff * self.coord_mlp(edge_feat) * edge_mask
        node_shift = unsorted_segment_sum(edge_shift, senders, num_segments=coord.size(0))
        coord += node_shift * self.coords_weight
        return coord

    def forward(self, h, edge_index, coord, node_mask, edge_mask, edge_attr=None, node_attr=None, n_nodes=None):
        """Update node features from masked edge messages.

        ``node_mask`` and ``n_nodes`` are accepted but not used here.

        Returns:
            ``(h, coord, edge_attr)`` where ``h`` is the updated node features
            and ``coord`` / ``edge_attr`` are the inputs passed through as-is.
        """
        src, dst = edge_index
        radial, _ = self.coord2radial(edge_index, coord)

        # Multiply messages by the edge mask before they are aggregated.
        messages = self.edge_model(h[src], h[dst], radial, edge_attr) * edge_mask

        h, _ = self.node_model(h, edge_index, messages, node_attr)
        return h, coord, edge_attr

class EGNN(nn.Module):
    """Stack of ``E_GCL_mask`` layers followed by node and graph decoders.

    ``normalize`` selects the normalisation scheme:
      * ``"egnn"``  -- layers are built with ``norm_diff=True``;
      * ``"batch"`` -- an ``EGNN_BN`` module follows every layer;
      * ``"both"``  -- both of the above;
      * any other value -- neither.

    When ``node_attr`` is truthy the raw input features ``h0`` are also fed to
    every layer as node attributes.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0, attention=False, node_attr=1, normalize="None"):
        super().__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.normalize = normalize

        # Encoder
        self.embedding = nn.Linear(in_node_nf, hidden_nf)
        self.node_attr = node_attr
        n_node_attr = in_node_nf if node_attr else 0

        norm_diff = normalize in ("egnn", "both")
        with_batch_norm = normalize in ("batch", "both")

        # Layers are registered by name so that forward() can look them up by
        # index; registration order (gcl_i, then bn_i) is kept stable.
        for layer_idx in range(n_layers):
            layer = E_GCL_mask(
                self.hidden_nf,
                self.hidden_nf,
                self.hidden_nf,
                edges_in_d=in_edge_nf,
                nodes_attr_dim=n_node_attr,
                act_fn=act_fn,
                recurrent=True,
                coords_weight=coords_weight,
                norm_diff=norm_diff,
                attention=attention,
            )
            self.add_module("gcl_%d" % layer_idx, layer)
            if with_batch_norm:
                self.add_module("bn_%d" % layer_idx, EGNN_BN(self.hidden_nf, 3, pos=True))

        # Per-node decoder.
        self.node_dec = nn.Sequential(
            nn.Linear(self.hidden_nf, self.hidden_nf),
            act_fn,
            nn.Linear(self.hidden_nf, self.hidden_nf),
        )

        # Per-graph decoder producing a single scalar.
        self.graph_dec = nn.Sequential(
            nn.Linear(self.hidden_nf, self.hidden_nf),
            act_fn,
            nn.Linear(self.hidden_nf, 1),
        )
        self.to(self.device)

    def forward(self, h0, x, edges, edge_attr, node_mask, edge_mask, n_nodes):
        """Predict one scalar per graph.

        The node features are embedded, passed through every layer (and its
        batch-norm module when enabled), decoded per node, masked with
        ``node_mask``, reshaped to ``(-1, n_nodes, hidden_nf)`` and summed over
        the node dimension before the graph decoder is applied. Coordinates
        returned by the layers are discarded.
        """
        with_batch_norm = self.normalize in ("batch", "both")
        layer_node_attr = h0 if self.node_attr else None

        h = self.embedding(h0)
        for layer_idx in range(self.n_layers):
            gcl = self._modules["gcl_%d" % layer_idx]
            h, _, _ = gcl(h, edges, x, node_mask, edge_mask, edge_attr=edge_attr,
                          node_attr=layer_node_attr, n_nodes=n_nodes)
            if with_batch_norm:
                h = self._modules["bn_%d" % layer_idx](h, x)

        node_out = self.node_dec(h) * node_mask
        graph_feat = node_out.view(-1, n_nodes, self.hidden_nf).sum(dim=1)
        return self.graph_dec(graph_feat).squeeze(1)


def train(model, optimizer, lr_scheduler, epoch, loader, Property="homo", charge_power=2, partition="train", dtype=torch.float32, log_interval=100):
    """Run one pass over ``loader`` and return the sample-weighted mean L1 loss.

    When ``partition == 'train'`` the model is optimised and the learning-rate
    scheduler is advanced once at the end of the pass. For any other value the
    model is only evaluated: no gradients are tracked and neither the
    optimiser nor the scheduler is touched.

    Reads the notebook globals ``device``, ``charge_scale``, ``loss_l1``,
    ``meann`` and ``mad``.

    Args:
        model: Network called with ``h0, x, edges, edge_attr, node_mask,
            edge_mask, n_nodes`` keyword arguments.
        optimizer: Optimiser stepped after every training batch.
        lr_scheduler: Scheduler stepped once per training pass.
        epoch: Unused; kept for interface compatibility.
        loader: Iterable of batches (dicts with ``'positions'``,
            ``'atom_mask'``, ``'edge_mask'``, ``'one_hot'``, ``'charges'`` and
            the ``Property`` key).
        Property: Key of the regression target inside each batch.
        charge_power: Forwarded to ``qm9_utils.preprocess_input``.
        partition: ``'train'`` to optimise, anything else to evaluate.
        dtype: Floating dtype the batch tensors are cast to.
        log_interval: Unused; kept for interface compatibility.

    Returns:
        Sum of ``loss * batch_size`` divided by the number of samples seen.
        In training this is the loss on normalised targets; in evaluation it
        is the loss on de-normalised predictions, so the two are not in the
        same units.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    is_train = partition == 'train'

    # The mode is constant for the whole pass, so set it once up front.
    model.train(is_train)

    total_loss = 0.0
    n_samples = 0

    # Skip autograd bookkeeping entirely while evaluating.
    with torch.set_grad_enabled(is_train):
        for data in loader:
            if is_train:
                optimizer.zero_grad()

            batch_size, n_nodes, _ = data['positions'].size()
            atom_positions = data['positions'].view(batch_size * n_nodes, -1).to(device, dtype)
            atom_mask = data['atom_mask'].view(batch_size * n_nodes, -1).to(device, dtype)
            edge_mask = data['edge_mask'].to(device, dtype)
            one_hot = data['one_hot'].to(device, dtype)
            charges = data['charges'].to(device, dtype)

            # Build the input node features and flatten them to one row per node.
            nodes = qm9_utils.preprocess_input(one_hot, charges, charge_power, charge_scale, device)
            nodes = nodes.view(batch_size * n_nodes, -1)

            edges = qm9_utils.get_adj_matrix(n_nodes, batch_size, device)
            label = data[Property].to(device, dtype)

            pred = model(h0=nodes, x=atom_positions, edges=edges, edge_attr=None, node_mask=atom_mask,
                         edge_mask=edge_mask, n_nodes=n_nodes)

            if is_train:
                # Optimise against targets standardised with (meann, mad).
                loss = loss_l1(pred, (label - meann) / mad)
                loss.backward()
                optimizer.step()
            else:
                # Report the error in the target's original units.
                loss = loss_l1(mad * pred + meann, label)

            total_loss += loss.item() * batch_size
            n_samples += batch_size

    # Advance the schedule once per *training* pass and only after the
    # optimiser has stepped. Stepping at the top of every call also advanced
    # it during evaluation passes and before the first optimiser step.
    if is_train:
        lr_scheduler.step()

    return total_loss / n_samples


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loss_l1 = nn.L1Loss()

# Arguments

batch_size = 72
epochs = 500
lr = 1e-3
nf = 80 # hidden node features
attention=1
# n_layers = 7
Property = 'homo'
charge_power = 2
dataset_paper = "cormorant"
node_attr = 0
weight_decay = 1e-16
test_interval = 50  # a checkpoint is written every `test_interval` epochs

# Data Loading
# NOTE(review): the second positional argument is presumably the number of
# data-loading workers -- confirm against qm9.dataset.retrieve_dataloaders.
dataloaders, charge_scale = dataset.retrieve_dataloaders(batch_size, 2)
meann, mad = qm9_utils.compute_mean_mad(dataloaders, Property)

# Train
# NOTE(review): histories are keyed by normalisation scheme only, so every
# depth in the loop below appends to the same list and each per-depth pickle
# also contains the curves of the depths trained before it.
train_losses = {'egnn': [], 'batch': [], 'none': [], 'both': []}
test_losses = {'egnn': [], 'batch': [], 'none': [], 'both': []}

# Create the output directory if needed; a no-op when it already exists.
os.makedirs("results", exist_ok=True)

for norm in ['batch']:
    for n_layers in [3, 7, 12, 16]:
        model = EGNN(in_node_nf=15, in_edge_nf=0, hidden_nf=nf, device=device, n_layers=n_layers, coords_weight=1.0,
                 attention=attention, node_attr=node_attr, normalize=norm)

        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

        training_start = datetime.datetime.now()

        for epoch in range(0, epochs):
            epoch_start = datetime.datetime.now()

            train_loss = train(model, optimizer, scheduler, epoch, dataloaders['train'], Property, charge_power, partition="train")
            # The validation split is evaluated here; train() treats any
            # partition other than "train" as evaluation.
            test_loss = train(model, optimizer, scheduler, epoch, dataloaders['valid'], Property, charge_power, partition="test")

            train_losses[norm].append(train_loss)
            test_losses[norm].append(test_loss)

            # Durations truncated to whole seconds for readable log lines.
            total_time = datetime.datetime.now() - training_start
            total_time = total_time - datetime.timedelta(microseconds=total_time.microseconds)

            epoch_time = datetime.datetime.now() - epoch_start
            epoch_time = epoch_time - datetime.timedelta(microseconds=epoch_time.microseconds)

            print ("{}: epoch {} \t avg test loss: {} \t epoch time: {} \t total time: {}".format(
                EXP_NAME, epoch, test_loss, str(epoch_time), str(total_time))
                  )

            # Checkpoint periodically and always after the final epoch, so the
            # fully trained weights and complete loss curves are on disk.
            is_last_epoch = epoch == epochs - 1
            if epoch % test_interval == 0 or is_last_epoch:
                torch.save(model.state_dict(), 'results/egnn{}_depth{}.pth'.format(norm, n_layers))
                with open('results/egnn{}_depth{}_trainloss.pickle'.format(norm, n_layers), 'wb') as handle:
                    pickle.dump(train_losses, handle, protocol=pickle.HIGHEST_PROTOCOL)
                with open('results/egnn{}_depth{}_testloss.pickle'.format(norm, n_layers), 'wb') as handle:
                    pickle.dump(test_losses, handle, protocol=pickle.HIGHEST_PROTOCOL)

            