In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os.path as osp
from tqdm.autonotebook import tqdm
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, GRU

import torch_geometric.transforms as T
from torch_geometric.data import DataLoader



In [3]:
import pandas as pd
import numpy as np

In [4]:
from sklearn.model_selection import train_test_split, GroupShuffleSplit

In [5]:
from torch.utils.data import Subset

In [6]:
from kaggle_champs import constants

# Load and preprocessing data

## Load data

In [7]:
train = pd.read_csv('./data/train.csv')

In [8]:
y_mean = train.scalar_coupling_constant.mean()

In [9]:
y_std = train.scalar_coupling_constant.std()

In [10]:
np.log((train.scalar_coupling_constant - train.type.map(train.groupby('type').scalar_coupling_constant.mean())).abs().groupby(train.type).mean())

type
1JHC    2.548219
1JHN    2.275415
2JHC    0.999041
2JHH    0.983063
2JHN    1.086673
3JHC    0.911788
3JHH    1.122420
3JHN   -0.033818
dtype: float64

## Split train valid

In [11]:
molecules = train.molecule_name.drop_duplicates().sort_values()

In [12]:
train_ind, valid_ind = train_test_split(np.arange(len(molecules)),
                                        test_size=5000,
                                        random_state=1234)

In [13]:
assert not set(train_ind).intersection(valid_ind)

In [14]:
len(train_ind), len(valid_ind)

(80003, 5000)

## Create train valid subet

In [15]:
# Check reproducibility
rs = np.random.RandomState(seed=1234)
print(rs.choice(train_ind, 10))
print(rs.choice(valid_ind, 10))

[19669 26783  1698 47278 33476 59113 40999 64242 25723 71229]
[55624 62327 36561 67391 19447 20288 70596 59541 32479 52121]


In [16]:
train_data = train.loc[train.molecule_name.isin(molecules.iloc[train_ind])]
val_data = train.loc[train.molecule_name.isin(molecules.iloc[valid_ind])]

## Create dataset

In [17]:
from kaggle_champs.dataset import ChampsDataset

In [18]:
import os
import numpy as np
import openbabel
import torch

from torch_geometric.data import Data
from torch.utils.data import Dataset
from tqdm.autonotebook import tqdm
from kaggle_champs.dataset import mol_to_data_v2

In [19]:
class MoleculeDataset(Dataset):
    def __init__(self, metadata=None, base_dir=None, transform=None):
        self.molecules = metadata.molecule_name.unique()
        self.metadata = dict([
            (ind, df) for ind, df in tqdm(metadata.groupby('molecule_name'))
        ])
        self.base_dir = base_dir
        self.transform = transform
        self.conversion = openbabel.OBConversion()
        self.conversion.SetInAndOutFormats("xyz", "mdl")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        mol = openbabel.OBMol()
        mol_name = self.molecules[index]

        xyz_file = os.path.join(self.base_dir, f'{mol_name}.xyz')
        if not os.path.exists(xyz_file):
            raise FileNotFoundError(f'Expecting file {xyz_file} not found')
        self.conversion.ReadFile(mol, xyz_file)

        data = mol_to_data_v2(mol)
        data.mol_ind = torch.tensor([[index]], dtype=torch.long)
        
        data = self._add_targets(data, metadata=self.metadata[mol_name])
        
        data.graph = nx.Graph()
        data.graph.add_edges_from(data.edge_index.transpose(1,0).cpu().numpy())
        
        if self.transform:
            data = self.transform(data)
            
        if hasattr(data, 'graph'):
            del data.graph
        return data
    
    def _add_inverse_couple(self, couples):
        inverse_direction = couples.rename(
            {'atom_index_1': 'atom_index_0', 
             'atom_index_0': 'atom_index_1'}, 
            axis=1)
        
        couples = couples.append(
            inverse_direction,
            sort=False
        )
        couples = couples.sort_values(['atom_index_0',
                                       'atom_index_1'])
        
        return couples
    
    def _add_y(self, data, couples):
        if 'scalar_coupling_constant' in couples.columns:
            data.y = torch.tensor(
                couples['scalar_coupling_constant'].values,
                dtype=torch.float).view(-1,1)
        else:
            data.y = torch.zeros((len(couples), 1), dtype=torch.float)
        return data
    
    def _add_targets(self, data, metadata):
        couples = metadata.copy()        
        couples = self._add_inverse_couple(couples)
        
        
        data.couples_ind = torch.tensor(
            couples[['atom_index_0',
                     'atom_index_1']].values,
            dtype=torch.long)
        
        data = self._add_y(data, couples)
        
        data.type = torch.tensor(
            couples['type'].map(constants.TYPES_DICT).values,
            dtype=torch.long)
        
        data.sample_weight = torch.tensor(
            couples['type'].map(constants.TYPES_WEIGHTS).values,
            dtype=torch.float)
        
        return data

In [20]:
from kaggle_champs.preprocessing import RandomRotation, AddVirtualEdges, AddEdgeDistanceAndDirection, SortTarget

In [21]:
import networkx as nx

In [22]:
class AddEdgeDistanceAndDirection:
    def __init__(self, dist_noise=0., gauss_base_max=4, gauss_base_steps=20, keep=True):
        self.dist_noise = dist_noise
        self.gauss_base_max = gauss_base_max
        self.gauss_base_steps = gauss_base_steps
        self.keep = True
        
    def __call__(self, data):
        (row, col), pos, edge_attr = data.edge_index, data.pos, data.edge_attr

        dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
        
        if self.dist_noise > 0:
            noise = 1 + torch.randn_like(dist, dtype=dist.dtype) * self.dist_noise
            dist = dist * noise

        direction = (pos[col] - pos[row]) / dist
        if self.keep:
            data.dist = dist
            data.direction = direction
        
        base = torch.linspace(self.gauss_base_max/self.gauss_base_steps,
                              self.gauss_base_max, 
                              self.gauss_base_steps, 
                              dtype=torch.float).view(1, -1)    # shape 1xn for broadcasting
        
        dist = torch.exp(-(dist - base) ** 2 / 0.5 ** 2)
        
        edge_attr = edge_attr.view(-1, 1) if edge_attr.dim() == 1 else edge_attr
        data.edge_attr = torch.cat(
                [edge_attr,
                 dist.type_as(edge_attr),
                 direction.type_as(edge_attr)],
                dim=-1)      

        return data

In [23]:
class AddBondLinks:
    def __call__(self, data):
        bonds_ind = data.bonds_edge_ind
        
        bonds_from = bonds_ind.view(-1, 1).repeat(1, (len(bonds_ind))).view(-1)
        bonds_to = bonds_ind.view(-1).repeat(1, len(data.bonds_edge_ind)).view(-1)
        bonds_links = torch.stack([bonds_from, bonds_to], dim=1)  # all couples, will filter
        
        filter_correct_common_node = (data.edge_index[:, bonds_from][1] == data.edge_index[:, bonds_to][0])
        filter_remove_self_loop = (data.edge_index[:, bonds_from][0] != data.edge_index[:, bonds_to][1])
        
        data.bonds_links_edge_ind = bonds_links[filter_correct_common_node * filter_remove_self_loop]
        return data

In [24]:
class AddCounts:
    def __call__(self, data):
        data.count_nodes = torch.tensor([[data.num_nodes]], dtype=torch.long)
        data.count_edges = torch.tensor([[data.num_edges]], dtype=torch.long)
        return data

In [25]:
class AddGlobalAttr:
    def __init__(self):
        pass

    def __call__(self, data):
        data.global_attr = torch.zeros((1, 1), dtype=torch.float)
        return data

class AddCouplesInd:
    def __call__(self, data):
        node_from, node_to = data.couples_ind.transpose(1, 0)
        edge_ind = node_from * (data.num_nodes-1) + node_to
        edge_ind[node_from < node_to] = edge_ind[node_from < node_to] - 1
        
        data.couples_edge_ind = edge_ind.view(-1, 1)
        assert torch.equal(data.edge_attr[data.mask], data.edge_attr[data.couples_edge_ind.view(-1)])
        return data

In [26]:
class SortTarget:
    def _get_index(self, data, row, col):
        idx = row * (data.num_nodes-1) + col
        idx[row < col] = idx[row < col] - 1
        return idx
    
    def __call__(self, data):
        target = torch.zeros((data.num_edges, data.y.size()[1]), dtype=torch.float)        
        weights = torch.zeros((data.num_edges), dtype=torch.float)        
        mask = torch.zeros((data.num_edges), dtype=torch.bool)      
        types = torch.zeros((data.num_edges), dtype=torch.long)
        
        row, col = data.couples_ind.transpose(1,0)
        indexes = self._get_index(data, row, col)
        
        mask[indexes] = True
        weights[indexes] = data.sample_weight
        target[indexes] = data.y
        types[indexes] = data.type
        
        data.mask = mask
        data.y = target[mask]
        data.sample_weight = weights[mask]
        data.type = types[mask]
        
        assert torch.equal(data.couples_ind, data.edge_index[:, mask].transpose(1,0))
        data.couples_edge_ind = torch.arange(data.num_edges, dtype=torch.long)[mask].view(-1,1)
        return data        

In [27]:
class AddBondPath:
    def __call__(self, data):
        # suffix _index to get node index adjustment
        data.paths_index = self.find_paths(data).transpose(1,0)  
        data.paths_edge_ind = torch.cat(
            [self._nodes_to_edge_ind(data, data.paths_index[i], data.paths_index[i+1]) for i in range(3)], 
            dim=1)
        return data
    
    def _nodes_to_edge_ind(self, data, node_from, node_to):
        edge_ind = node_from * (data.num_nodes-1) + node_to
        edge_ind[node_from < node_to] = edge_ind[node_from < node_to] - 1
        return edge_ind.view(-1, 1)
    
    def find_paths(self, data):
        assert hasattr(data, 'couples_ind')
        assert hasattr(data, 'graph')

        all_paths = nx.shortest_path(data.graph)
        paths = []
        for (from_, to_) in data.couples_ind.numpy():
            path = torch.tensor(all_paths[from_][to_], dtype=torch.long).view(-1,1)
            paths.append(path)

        paths = torch.nn.utils.rnn.pad_sequence(paths, batch_first=True).squeeze()
        if paths.size(1) < 4:
            paths = torch.nn.functional.pad(paths, (0, 4 - paths.size(1)))
        return paths

In [28]:
def correct_batch_edge_ind(batch):
    offset_edge_ind = torch.zeros_like(batch.count_edges)
    offset_edge_ind[1:] = batch.count_edges[:-1].cumsum(dim=0)
    for k in ['bonds_edge_ind', 'bonds_links_edge_ind', 'paths_edge_ind', 'couples_edge_ind']:
        if hasattr(batch, k):
            batch[k] = batch[k] + offset_edge_ind[batch[k+'_batch']]
    return batch

In [29]:
train_dataset = MoleculeDataset(metadata=train_data,
                              base_dir=constants.STRUCT_DATA_PATH,
                              transform=T.Compose([
                                  AddBondPath(),
                                  AddVirtualEdges(),
                                  RandomRotation(),
                                  AddEdgeDistanceAndDirection(
                                      dist_noise=0.),
                                  AddGlobalAttr(),
                                  SortTarget(),
                                  AddBondLinks(),
                                  AddCounts(),     
                              ]))

val_dataset = MoleculeDataset(metadata=val_data,
                            base_dir=constants.STRUCT_DATA_PATH,
                            transform=T.Compose([
                                AddBondPath(),
                                AddVirtualEdges(),
                                AddEdgeDistanceAndDirection(
                                      dist_noise=0.),
                                AddGlobalAttr(),
                                SortTarget(),
                                AddBondLinks(),
                                AddCounts(),     
                            ]))

HBox(children=(IntProgress(value=0, max=80003), HTML(value='')))




HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




In [30]:
data = train_dataset[10]
data

FP16_Data(bonds_edge_ind=[20, 1], bonds_links_edge_ind=[36, 2], count_edges=[1, 1], count_nodes=[1, 1], couples_edge_ind=[86, 1], couples_ind=[86, 2], direction=[110, 3], dist=[110, 1], edge_attr=[110, 34], edge_index=[2, 110], global_attr=[1, 1], mask=[110], mol_ind=[1, 1], paths_edge_ind=[86, 3], paths_index=[4, 86], pos=[11, 3], sample_weight=[86], type=[86], x=[11, 28], y=[86, 1])

# Model

In [31]:
from kaggle_champs.modelling import MegNetBlock, create_mlp_v2, MegNetBlock_v2, MegNetBlock_v3

In [32]:
from torch import nn

In [33]:
from torch_scatter import scatter_add

In [34]:
def gather_embedding(data, x_out, edge_out, u_out, couple_type):
    n_bonds = int(couple_type[0])
    couple_filter = (data.type == constants.TYPES_DICT[couple_type])
    couples_edge_ind = data.couples_edge_ind.view(-1)
    
    merged = [
        u_out[data.batch[data.edge_index[0][couples_edge_ind][couple_filter]]],
    ]
    if n_bonds > 1:
        merged.append(edge_out[couples_edge_ind][couple_filter])
        
    node_ind = data.paths_index.transpose(1,0)[:, :n_bonds+1][couple_filter] # convert_node_ind(data, 'paths')[:, :n_bonds+1]
    for i in range(n_bonds+1):
        merged.append(x_out[node_ind[:,i]])
        
    for i in range(n_bonds):
        edge_ind = data.paths_edge_ind[:,i] # convert_couple_to_edge_ind(data, data.paths_index[i], data.paths_index[i+1], data.paths_edge_ind_batch)
        merged.append(edge_out[edge_ind[couple_filter]])
    return torch.cat(merged, dim=1)

In [35]:
class OutputLayer_new(torch.nn.Module):
    def __init__(self, rep_dim, dim, y_mean, y_std, couple_type):
        super(OutputLayer_new, self).__init__()
        self.scaling = torch.nn.Linear(1, 1)
        self.scaling.bias = torch.nn.Parameter(torch.tensor(y_mean,
                                                            dtype=torch.float),
                                               requires_grad=False)
        self.scaling.weight = torch.nn.Parameter(torch.tensor(
            [[y_std]], dtype=torch.float),
                                                 requires_grad=False)
        self.couple_type = couple_type
        n_bonds = int(couple_type[0])
        
        if n_bonds == 1:
            input_dim = dim * (n_bonds + (n_bonds + 1) + 1)  # edges + nodes + u
        else:
            input_dim = dim * (n_bonds + (n_bonds + 1) + 2)  # edges + nodes + u + direct edge
        
        self.mlp = create_mlp_v2(
            input_dim=input_dim,
            output_dim=1,
            hidden_dims=[input_dim//2, input_dim//4, input_dim//4],
            normalization_cls=torch.nn.LayerNorm,
            activation_cls=torch.nn.Softplus,
            dropout_cls=torch.nn.Dropout,
            dropout_prob=0.
        )

    def forward(self, data, x_out, edge_out, u_out):
        in_ = gather_embedding(data, x_out, edge_out, u_out, self.couple_type)
        out = self.mlp(in_)
        out = self.scaling(out)
        return out

In [36]:
from torch_scatter import scatter_mean

In [37]:
class EdgeAgg(torch.nn.Module):
    def __init__(self, dim=32):
        super(EdgeAgg, self).__init__()
        self.body_mlp = nn.Sequential(
            create_mlp_v2(
                input_dim=dim,
                output_dim=dim*2,
                hidden_dims=[dim*2],
                normalization_cls=torch.nn.LayerNorm,
                activation_cls=torch.nn.Softplus,
                dropout_cls=torch.nn.Dropout,
                dropout_prob=0.),
            nn.LayerNorm(dim*2)
        )
        
        self.value_out = nn.Linear(dim*2, dim)
        
        self.gating = nn.Sequential(
            nn.Linear(dim*2, 1),
            nn.Tanh()
        )
    
    def forward(self, edge_out, edges_ind):
        out = self.body_mlp(edge_out)              
        out = self.value_out(out) * self.gating(out)
        result = scatter_add(out, edges_ind, dim=0)        
        return result

In [38]:
class MegNetBlock(torch.nn.Module):
    def __init__(self, edge_dim, x_dim, u_dim, dim=32, layer_norm=False,
                 normalization_cls=None, activation_cls=nn.ReLU,
                 dropout_cls=nn.Dropout, dropout_prob=0., residual=True, pooling='mean'):
        super(MegNetBlock, self).__init__()
        self.dim = dim
        self.residual = residual
        self.pooling = pooling

        if layer_norm:
            normalization_cls = nn.LayerNorm
        kwargs = dict(
            normalization_cls=normalization_cls,
            activation_cls=activation_cls,
            dropout_cls=dropout_cls,
            dropout_prob=dropout_prob)
        
        self.edge_dense = create_mlp_v2(
            input_dim=edge_dim, output_dim=dim, hidden_dims=[dim * 2], **kwargs)
        
        self.edge_agg = EdgeAgg(dim=dim)
        
        self.node_dense = create_mlp_v2(
            input_dim=x_dim, output_dim=dim, hidden_dims=[dim * 2], **kwargs)
        self.global_dense = create_mlp_v2(
            input_dim=u_dim, output_dim=dim, hidden_dims=[dim * 2], **kwargs)

        self.edge_msg = create_mlp_v2(
            input_dim=dim * 4, output_dim=dim, hidden_dims=[dim*2, dim*2], **kwargs)
        self.node_msg = create_mlp_v2(
            input_dim=dim * 3, output_dim=dim, hidden_dims=[dim*2, dim*2], **kwargs)
        self.global_msg = create_mlp_v2(
            input_dim=dim * 3, output_dim=dim, hidden_dims=[dim*2, dim*2], **kwargs)
        

    def edge_model(self, src, dest, edge_attr, u, batch):
        # source, target: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr, u[batch]], 1)
        out = self.edge_msg(out)
        return out

    def node_model(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, _ = edge_index
        out = self.edge_agg(edge_attr, row)
        out = torch.cat([out, x, u[batch]], dim=1)
        out = self.node_msg(out)
        return out

    def global_model(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, _ = edge_index
        edge_mean = scatter_mean(edge_attr, batch[row], dim=0)
        out = torch.cat(
            [u, scatter_mean(x, batch, dim=0), edge_mean], dim=1)
        out = self.global_msg(out)
        return out
    
    def forward(self, x, edge_index, edge_attr, u, batch, first_block=False):

        # first block
        edge_out = self.edge_dense(edge_attr)
        x_out = self.node_dense(x)
        u_out = self.global_dense(u)

        x_res_base = x_out if first_block else x
        edge_res_base = edge_out if first_block else edge_attr
        u_res_base = u_out if first_block else u

        row, col = edge_index        

        edge_out = self.edge_model(x_out[row], x_out[col], edge_out, u_out,
                                   batch[row])
        if self.residual:
            edge_out = edge_res_base + edge_out

        x_out = self.node_model(x_out, edge_index, edge_out, u_out, batch)
        if self.residual:
            x_out = x_res_base + x_out

        u_out = self.global_model(x_out, edge_index, edge_out, u_out, batch)
        if self.residual:
            u_out = u_res_base + u_out

        return x_out, edge_out, u_out

In [39]:
class MegNetModel_new(torch.nn.Module):
    def __init__(self,
                 edge_dim,
                 x_dim,
                 u_dim,
                 dim=32,
                 head_dim=32,
                 n_megnet_blocks=3,
                 y_mean=0,
                 y_std=1,
                 layer_norm=False):
        super(MegNetModel_new, self).__init__()
        self.dim = dim
        self.n_megnet_blocks = n_megnet_blocks
        self.megnet_blocks = torch.nn.ModuleList([
            MegNetBlock(edge_dim,
                        x_dim,
                        u_dim,
                        dim,
                        normalization_cls=torch.nn.LayerNorm,
                        activation_cls=torch.nn.Softplus,
                        dropout_cls=torch.nn.Dropout,
                        dropout_prob=0.,
                        residual=True)
        ] + [
            MegNetBlock(dim,
                        dim,
                        dim,
                        dim,
                        normalization_cls=torch.nn.LayerNorm,
                        activation_cls=torch.nn.Softplus,
                        dropout_cls=torch.nn.Dropout,
                        dropout_prob=0.,
                        residual=True) for i in range(n_megnet_blocks - 1)
        ])

        self.out_mlp = torch.nn.ModuleList([
            OutputLayer_new(
                dim,
                head_dim,
                y_mean=y_mean[i],
                y_std=y_std[i],
                couple_type=type_,
            ) for i, type_ in enumerate(constants.TYPES_LIST)
        ])

        self.edge_dense = create_mlp_v2(
            input_dim=edge_dim,
            output_dim=dim,
            hidden_dims=[dim * 2],
            normalization_cls=None,
            activation_cls=None,
            dropout_cls=None,
            dropout_prob=0.0,
        )
        self.node_dense = create_mlp_v2(
            input_dim=x_dim,
            output_dim=dim,
            hidden_dims=[dim * 2],
            normalization_cls=None,
            activation_cls=None,
            dropout_cls=None,
            dropout_prob=0.0,
        )
        self.global_dense = create_mlp_v2(
            input_dim=u_dim,
            output_dim=dim,
            hidden_dims=[dim * 2],
            normalization_cls=None,
            activation_cls=None,
            dropout_cls=None,
            dropout_prob=0.0,
        )

    def forward(self, data):
        data = correct_batch_edge_ind(data)

        if not hasattr(data, 'global_attr'):
            data.global_attr = torch.zeros((data.num_graphs, 1),
                                           dtype=torch.float,
                                           device=data.x.device)
        x_out, edge_out, u_out = data.x, data.edge_attr, data.global_attr
        
        #x_out = self.node_dense(x_out)
        #edge_out = self.edge_dense(edge_out)
        #u_out = self.global_dense(u_out)        
        
        for i in range(self.n_megnet_blocks):
            x_out, edge_out, u_out = self.megnet_blocks[i](
                x_out,
                data.edge_index,
                edge_out,
                u_out,
                data.batch,
                first_block=(i==0))

        pred = torch.zeros_like(data.type,
                                dtype=torch.float,
                                device=x_out.device)
        for type_ in range(8):
            if (data.type == type_).any():
                pred[data.type == type_] = self.out_mlp[type_](data, x_out,
                                                               edge_out,
                                                               u_out).view(-1)
        return pred

# Training

In [40]:
from kaggle_champs.metrics import MeanLogGroupMAE, AverageMetric

In [41]:
from kaggle_champs.training import train_epoch

In [42]:
import torch
from tqdm.autonotebook import tqdm


def train_epoch(global_iteration, epoch, model, device, optimizer, 
                train_loader, tb_logger, gradient_accumulation_steps=1, swa=False):
    model.train()
    avg_loss = AverageMetric()
    log_mae = MeanLogGroupMAE()
    
    pbar = tqdm(train_loader)
    for step, data in enumerate(pbar):
        data = data.to(device)
        
        pred = model(data)

        loss = torch.nn.L1Loss(reduction='mean')(pred.view(-1),
                                                     data.y.view(-1))
        loss.backward()

        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            
            global_iteration += 1
            if swa:
                optimizer.update_swa()

        tb_logger.add_scalar('loss', loss.item(), global_iteration)

        avg_loss.update(loss.item() * data.num_graphs, data.num_graphs)
        log_mae.update(pred.view(-1), data.y.view(-1), data.type)

        pbar.set_postfix_str(f'loss: {avg_loss.compute():.4f}')
    return avg_loss.compute(), log_mae, global_iteration

In [43]:
def test_model(model, loader):
    model.eval()
    log_mae = MeanLogGroupMAE()
    avg_loss = AverageMetric()
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            pred = model(data)
            
            loss = torch.nn.L1Loss(reduction='mean')(pred.view(-1),
                                                     data.y.view(-1))
            avg_loss.update(loss.item() * data.num_graphs, data.num_graphs)
            
            log_mae.update(pred.view(-1), data.y.view(-1), data.type.view(-1))
            
        return avg_loss.compute(), log_mae


def make_log(epoch, lr, loss, tr_logmae, val_logmae):
    results = {
        'epoch': epoch,
        'lr': lr,
        'loss': loss,
        'tr_logmae': tr_logmae.compute(),
        'val_logmae': val_logmae.compute(),
    }
    for k, v in tr_logmae.compute_individuals().items():
        results.update({'tr_' + k: v})
    for k, v in val_logmae.compute_individuals().items():
        results.update({'val_' + k: v})
    return results


def save_checkpoint(dir_path, model, optimizer, scheduler, epoch):
    torch.save(model.state_dict(), dir_path + f'model_epoch_{epoch}.pth')
    torch.save(optimizer.state_dict(),
               dir_path + f'optimizer_epoch_{epoch}.pth')
    torch.save(scheduler.state_dict(),
               dir_path + f'scheduler_epoch_{epoch}.pth')

# Run

In [44]:
from tensorboardX import SummaryWriter

In [45]:
import shutil

In [46]:
OUTPUT_DIR = './models/megnet_256x10_softplus_edgeagg/finetune/'

!mkdir -p {OUTPUT_DIR}
!rm -rf {OUTPUT_DIR}
!mkdir -p {OUTPUT_DIR}

In [47]:
tb_logger = SummaryWriter(OUTPUT_DIR+'tb_log/')
global_iteration = 0

In [48]:
SAVE_INTERVAL = 10

In [49]:
MAX_EPOCH = 150

In [50]:
val_loader = DataLoader(val_dataset,
                        batch_size=64,
                        shuffle=False,
                        num_workers=8, 
                        follow_batch=['bonds_edge_ind', 'bonds_links_edge_ind', 'paths_edge_ind', 'couples_edge_ind']
                       )
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          num_workers=8,
                          shuffle=True,
                          follow_batch=['bonds_edge_ind', 'bonds_links_edge_ind', 'paths_edge_ind', 'couples_edge_ind']
                         )

In [51]:
batch = next(iter(val_loader))

In [52]:
batch

Batch(batch=[895], bonds_edge_ind=[1758, 1], bonds_edge_ind_batch=[1758], bonds_links_edge_ind=[3076, 2], bonds_links_edge_ind_batch=[3076], count_edges=[64, 1], count_nodes=[64, 1], couples_edge_ind=[5146, 1], couples_edge_ind_batch=[5146], couples_ind=[5146, 2], direction=[12162, 3], dist=[12162, 1], edge_attr=[12162, 34], edge_index=[2, 12162], global_attr=[64, 1], mask=[12162], mol_ind=[64, 1], paths_edge_ind=[5146, 3], paths_edge_ind_batch=[5146], paths_index=[4, 5146], pos=[895, 3], sample_weight=[5146], type=[5146], x=[895, 28], y=[5146, 1])

In [53]:
y_mean = train.groupby(train.type.map(
    constants.TYPES_DICT)).scalar_coupling_constant.mean().sort_index().values
y_std = train.groupby(train.type.map(
    constants.TYPES_DICT)).scalar_coupling_constant.std().sort_index().values

In [54]:
device = torch.device('cuda')
model = MegNetModel_new(edge_dim=data.edge_attr.size()[1],
                    x_dim=data.x.size()[1],
                    u_dim=1,
                    dim=256,
                    head_dim=256,
                    n_megnet_blocks=10,
                    y_mean=y_mean,
                    y_std=y_std,
                    layer_norm=False).to(device)

In [55]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=1.)

In [59]:
# train loop
logs = []
for epoch in range(1, 105):
    lr = scheduler.optimizer.param_groups[0]['lr']
    tr_loss, tr_logmae, global_iteration = train_epoch(global_iteration,
                                     epoch,
                                     model,
                                     device,
                                     optimizer,
                                     train_loader,
                                     tb_logger,
                                     gradient_accumulation_steps=1, swa=True)
    scheduler.step()
    if epoch % 10 == 0:
        optimizer.param_groups[0]['lr'] = 1e-4

    val_loss, val_logmae = test_model(model, val_loader)
    
    optimizer.swap_swa_sgd()
    val_loss_swa, val_logmae_swa = test_model(model, val_loader)
    optimizer.swap_swa_sgd()

    epoch_log = make_log(epoch, lr, tr_loss, tr_logmae, val_logmae)
    logs.append(epoch_log)
    pd.DataFrame(logs).to_csv(OUTPUT_DIR + 'log.csv')
    print('Epoch: {epoch:03d}, LR: {lr:7f}, Loss: {loss:.7f}, \
         Train LogMAE: {tr_logmae:.7f}, Val LogMAE: {val_logmae:.7f}'.format(
        **epoch_log))
    print(f'Val LogMAE SWA: {val_logmae_swa.compute():.7f}')

    if epoch % SAVE_INTERVAL == 0:
        save_checkpoint(OUTPUT_DIR, model, optimizer, scheduler, epoch)

    tb_logger.add_scalar('lr', lr, global_iteration)
    tb_logger.add_scalar('val_loss', val_loss, global_iteration)
    tb_logger.add_scalars('global_logmae', {
        'tr_logmae': epoch_log['tr_logmae'],
        'val_logmae': epoch_log['val_logmae']
    }, global_iteration)

    for type_ in constants.TYPES_LIST:
        tb_logger.add_scalars(
            type_, {
                'tr_' + type_: epoch_log['tr_' + type_],
                'val_' + type_: epoch_log['val_' + type_]
            }, global_iteration)

HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 001, LR: 0.000100, Loss: 0.0525858,          Train LogMAE: -3.1921357, Val LogMAE: -2.5308201
Val LogMAE SWA: -2.6923070


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 002, LR: 0.000100, Loss: 0.0523206,          Train LogMAE: -3.1975689, Val LogMAE: -2.4789515
Val LogMAE SWA: -2.7023101


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 003, LR: 0.000050, Loss: 0.0380843,          Train LogMAE: -3.5607711, Val LogMAE: -2.6362811
Val LogMAE SWA: -2.7081334


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 004, LR: 0.000050, Loss: 0.0345512,          Train LogMAE: -3.6766727, Val LogMAE: -2.6497833
Val LogMAE SWA: -2.7140730


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 005, LR: 0.000025, Loss: 0.0273012,          Train LogMAE: -3.9652012, Val LogMAE: -2.6927343
Val LogMAE SWA: -2.7166343


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 006, LR: 0.000025, Loss: 0.0255557,          Train LogMAE: -4.0559917, Val LogMAE: -2.6996353
Val LogMAE SWA: -2.7188434


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 007, LR: 0.000013, Loss: 0.0218847,          Train LogMAE: -4.3050418, Val LogMAE: -2.7178966
Val LogMAE SWA: -2.7212128


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 008, LR: 0.000013, Loss: 0.0209527,          Train LogMAE: -4.3997866, Val LogMAE: -2.7188505
Val LogMAE SWA: -2.7220674


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 009, LR: 0.000006, Loss: 0.0191141,          Train LogMAE: -4.6080639, Val LogMAE: -2.7258790
Val LogMAE SWA: -2.7236790


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 010, LR: 0.000006, Loss: 0.0186579,          Train LogMAE: -4.6773264, Val LogMAE: -2.7273596
Val LogMAE SWA: -2.7239565


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 011, LR: 0.000100, Loss: 0.0490417,          Train LogMAE: -3.2757647, Val LogMAE: -2.5500277
Val LogMAE SWA: -2.7240088


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 012, LR: 0.000100, Loss: 0.0505134,          Train LogMAE: -3.2350820, Val LogMAE: -2.5113836
Val LogMAE SWA: -2.7243656


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 013, LR: 0.000050, Loss: 0.0361824,          Train LogMAE: -3.6187571, Val LogMAE: -2.6531555
Val LogMAE SWA: -2.7248700


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 014, LR: 0.000050, Loss: 0.0328798,          Train LogMAE: -3.7325533, Val LogMAE: -2.6625951
Val LogMAE SWA: -2.7251684


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 015, LR: 0.000025, Loss: 0.0255863,          Train LogMAE: -4.0466991, Val LogMAE: -2.7037813
Val LogMAE SWA: -2.7260801


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 016, LR: 0.000025, Loss: 0.0237965,          Train LogMAE: -4.1552782, Val LogMAE: -2.7075988
Val LogMAE SWA: -2.7267387


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 017, LR: 0.000013, Loss: 0.0202306,          Train LogMAE: -4.4478234, Val LogMAE: -2.7231860
Val LogMAE SWA: -2.7271498


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 018, LR: 0.000013, Loss: 0.0193260,          Train LogMAE: -4.5690665, Val LogMAE: -2.7250853
Val LogMAE SWA: -2.7274944


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 019, LR: 0.000006, Loss: 0.0175233,          Train LogMAE: -4.8302164, Val LogMAE: -2.7301308
Val LogMAE SWA: -2.7279378


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 020, LR: 0.000006, Loss: 0.0170651,          Train LogMAE: -4.9110804, Val LogMAE: -2.7303108
Val LogMAE SWA: -2.7289280


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))

KeyboardInterrupt: 

In [57]:
from torchcontrib.optim import SWA

In [58]:
optimizer = SWA(optimizer, swa_start=None, swa_freq=None, swa_lr=1e-4)

  "Some of swa_start, swa_freq is None, ignoring swa_lr")


In [59]:
# train loop
logs = []
for epoch in range(106, 127):
    lr = scheduler.optimizer.param_groups[0]['lr']
    tr_loss, tr_logmae, global_iteration = train_epoch(global_iteration,
                                     epoch,
                                     model,
                                     device,
                                     optimizer,
                                     train_loader,
                                     tb_logger,
                                     gradient_accumulation_steps=1, swa=True)
    scheduler.step()
    if epoch % 10 == 0:
        optimizer.param_groups[0]['lr'] = 1e-4

    val_loss, val_logmae = test_model(model, val_loader)
    
    optimizer.swap_swa_sgd()
    val_loss_swa, val_logmae_swa = test_model(model, val_loader)
    optimizer.swap_swa_sgd()

    epoch_log = make_log(epoch, lr, tr_loss, tr_logmae, val_logmae)
    logs.append(epoch_log)
    pd.DataFrame(logs).to_csv(OUTPUT_DIR + 'log.csv')
    print('Epoch: {epoch:03d}, LR: {lr:7f}, Loss: {loss:.7f}, \
         Train LogMAE: {tr_logmae:.7f}, Val LogMAE: {val_logmae:.7f}'.format(
        **epoch_log))
    print(f'Val LogMAE SWA: {val_logmae_swa.compute():.7f}')

    if epoch % SAVE_INTERVAL == 0:
        save_checkpoint(OUTPUT_DIR, model, optimizer, scheduler, epoch)

    tb_logger.add_scalar('lr', lr, global_iteration)
    tb_logger.add_scalar('val_loss', val_loss, global_iteration)
    tb_logger.add_scalars('global_logmae', {
        'tr_logmae': epoch_log['tr_logmae'],
        'val_logmae': epoch_log['val_logmae']
    }, global_iteration)

    for type_ in constants.TYPES_LIST:
        tb_logger.add_scalars(
            type_, {
                'tr_' + type_: epoch_log['tr_' + type_],
                'val_' + type_: epoch_log['val_' + type_]
            }, global_iteration)

HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 001, LR: 0.000100, Loss: 0.0525858,          Train LogMAE: -3.1921357, Val LogMAE: -2.5308201
Val LogMAE SWA: -2.6923070


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 002, LR: 0.000100, Loss: 0.0523206,          Train LogMAE: -3.1975689, Val LogMAE: -2.4789515
Val LogMAE SWA: -2.7023101


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 003, LR: 0.000050, Loss: 0.0380843,          Train LogMAE: -3.5607711, Val LogMAE: -2.6362811
Val LogMAE SWA: -2.7081334


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 004, LR: 0.000050, Loss: 0.0345512,          Train LogMAE: -3.6766727, Val LogMAE: -2.6497833
Val LogMAE SWA: -2.7140730


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 005, LR: 0.000025, Loss: 0.0273012,          Train LogMAE: -3.9652012, Val LogMAE: -2.6927343
Val LogMAE SWA: -2.7166343


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 006, LR: 0.000025, Loss: 0.0255557,          Train LogMAE: -4.0559917, Val LogMAE: -2.6996353
Val LogMAE SWA: -2.7188434


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 007, LR: 0.000013, Loss: 0.0218847,          Train LogMAE: -4.3050418, Val LogMAE: -2.7178966
Val LogMAE SWA: -2.7212128


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 008, LR: 0.000013, Loss: 0.0209527,          Train LogMAE: -4.3997866, Val LogMAE: -2.7188505
Val LogMAE SWA: -2.7220674


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 009, LR: 0.000006, Loss: 0.0191141,          Train LogMAE: -4.6080639, Val LogMAE: -2.7258790
Val LogMAE SWA: -2.7236790


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 010, LR: 0.000006, Loss: 0.0186579,          Train LogMAE: -4.6773264, Val LogMAE: -2.7273596
Val LogMAE SWA: -2.7239565


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 011, LR: 0.000100, Loss: 0.0490417,          Train LogMAE: -3.2757647, Val LogMAE: -2.5500277
Val LogMAE SWA: -2.7240088


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 012, LR: 0.000100, Loss: 0.0505134,          Train LogMAE: -3.2350820, Val LogMAE: -2.5113836
Val LogMAE SWA: -2.7243656


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 013, LR: 0.000050, Loss: 0.0361824,          Train LogMAE: -3.6187571, Val LogMAE: -2.6531555
Val LogMAE SWA: -2.7248700


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 014, LR: 0.000050, Loss: 0.0328798,          Train LogMAE: -3.7325533, Val LogMAE: -2.6625951
Val LogMAE SWA: -2.7251684


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 015, LR: 0.000025, Loss: 0.0255863,          Train LogMAE: -4.0466991, Val LogMAE: -2.7037813
Val LogMAE SWA: -2.7260801


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 016, LR: 0.000025, Loss: 0.0237965,          Train LogMAE: -4.1552782, Val LogMAE: -2.7075988
Val LogMAE SWA: -2.7267387


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 017, LR: 0.000013, Loss: 0.0202306,          Train LogMAE: -4.4478234, Val LogMAE: -2.7231860
Val LogMAE SWA: -2.7271498


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 018, LR: 0.000013, Loss: 0.0193260,          Train LogMAE: -4.5690665, Val LogMAE: -2.7250853
Val LogMAE SWA: -2.7274944


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 019, LR: 0.000006, Loss: 0.0175233,          Train LogMAE: -4.8302164, Val LogMAE: -2.7301308
Val LogMAE SWA: -2.7279378


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 020, LR: 0.000006, Loss: 0.0170651,          Train LogMAE: -4.9110804, Val LogMAE: -2.7303108
Val LogMAE SWA: -2.7289280


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))

KeyboardInterrupt: 

In [62]:
torch.save(model.state_dict(), OUTPUT_DIR + f'model_epoch_126.pth')

# Make submission

In [56]:
def merge_direction(df):
    inverse_direction = df.rename(
        {
            'atom_index_1': 'atom_index_0',
            'atom_index_0': 'atom_index_1'
        },
        axis=1)
    merged = pd.merge(df,
                      inverse_direction,
                      on=['molecule_name', 'atom_index_0', 'atom_index_1'],
                      suffixes=('', '_bis'))
    merged['scalar_coupling_constant'] = (merged['scalar_coupling_constant'] + merged['scalar_coupling_constant_bis']) / 2
    return merged.drop('scalar_coupling_constant_bis', axis=1)

In [57]:
def predict(model, input_data, checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path))
    pred_dataset = MoleculeDataset(
        metadata=input_data,
        base_dir=constants.STRUCT_DATA_PATH,
        transform=T.Compose([
            AddBondPath(),
            AddVirtualEdges(),
            AddEdgeDistanceAndDirection(dist_noise=0.),
            AddGlobalAttr(),
            SortTarget(),
            AddBondLinks(),
            AddCounts(),
        ]))
    pred_loader = DataLoader(pred_dataset,
                             batch_size=64,
                             shuffle=False,
                             num_workers=8,
                             follow_batch=[
                                 'bonds_edge_ind', 'bonds_links_edge_ind',
                                 'paths_edge_ind', 'couples_edge_ind'
                             ])
    model.eval()
    preds = []
    inds = []
    couples = []
    for data in tqdm(pred_loader):
        with torch.no_grad():
            data = data.to(device)
            pred = model(data).detach().cpu().numpy()
            ind = data.mol_ind[data.batch[data.edge_index[0][
                data.mask]]].detach().cpu().numpy()

            couple_ind = data.couples_ind.cpu().numpy()
            df = pd.DataFrame({
                'molecule_name' : pred_dataset.molecules[ind].ravel(),
                'molecule_ind': ind.ravel(),
                'atom_index_0': couple_ind[:,0].ravel(), 
                'atom_index_1': couple_ind[:,1].ravel(),
            })
            df.sort_values(['molecule_ind', 'atom_index_0', 'atom_index_1'], ascending=True, inplace=True)
            np.testing.assert_array_equal(df.molecule_ind, ind.ravel())
            df['scalar_coupling_constant'] = pred
            preds.append(df.drop('molecule_ind', axis=1))
            
    pred = pd.concat(preds)
    pred = merge_direction(pred)
    merged = pd.merge(input_data,
                  pred,
                  on=['molecule_name', 'atom_index_0', 'atom_index_1'],
                  how='left', suffixes=('_truth', ''))
    assert merged.dropna().shape[0] == input_data.shape[0]
    return merged.loc[:, ['id', 'scalar_coupling_constant']].set_index('id'), pred

In [58]:
pred_val, p = predict(model, val_data, './models/megnet_256x10_softplus_edgeagg/model_epoch_126.pth')

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=79), HTML(value='')))




In [59]:
pred_val.head()

Unnamed: 0_level_0,scalar_coupling_constant
id,Unnamed: 1_level_1
1582,-1.053646
1583,0.496505
1584,0.496774
1585,14.346548
1586,92.781601


In [60]:
def score(pred, ref_data):
    merged = pd.merge(ref_data, pred, how='left', left_on='id', right_index=True, suffixes=('', '_pred'))
    merged['abs_error'] = (merged['scalar_coupling_constant'] - merged['scalar_coupling_constant_pred']).abs()
    result = merged.groupby('type')['abs_error'].mean()
    result.iloc[:] = np.log(np.maximum(result.values, 1e-9))
    return result.mean(), result.to_dict()    

In [61]:
score(pred_val, val_data)

(-2.7415171948044534,
 {'1JHC': -1.797125641902922,
  '1JHN': -1.775391270999278,
  '2JHC': -2.7413817666351967,
  '2JHH': -3.269456397283846,
  '2JHN': -3.045777238183105,
  '3JHC': -2.7072087982777666,
  '3JHH': -3.2822666043082807,
  '3JHN': -3.3135298408452325})

In [None]:
test = pd.read_csv('./data/test.csv')

In [78]:
sub, _ = predict(model, test, './models/megnet_256x10_softplus_edgeagg/model_epoch_126.pth')

HBox(children=(IntProgress(value=0, max=45772), HTML(value='')))




HBox(children=(IntProgress(value=0, max=716), HTML(value='')))




In [81]:
sub.head()

Unnamed: 0_level_0,scalar_coupling_constant
id,Unnamed: 1_level_1
4658147,18.438942
4658148,194.546509
4658149,10.211948
4658150,194.50351
4658151,18.439617


In [83]:
!mkdir -p subs/lam_01_v1/

In [84]:
sub.to_csv('./subs/lam_01_v1/sub.csv', index=True)
pred_val.to_csv('./subs/lam_01_v1/pred_val.csv', index=True)