In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os.path as osp
from tqdm.autonotebook import tqdm
import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, GRU

import torch_geometric.transforms as T
from torch_geometric.data import DataLoader



In [3]:
import pandas as pd
import numpy as np

In [4]:
from sklearn.model_selection import train_test_split, GroupShuffleSplit

In [5]:
from torch.utils.data import Subset

In [6]:
from kaggle_champs import constants

# Load and preprocess data

## Load data

In [7]:
train = pd.read_csv('../data/train.csv')

In [8]:
y_mean = train.scalar_coupling_constant.mean()

In [9]:
y_std = train.scalar_coupling_constant.std()

In [10]:
np.log((train.scalar_coupling_constant - train.type.map(train.groupby('type').scalar_coupling_constant.mean())).abs().groupby(train.type).mean())

type
1JHC    2.548219
1JHN    2.275415
2JHC    0.999041
2JHH    0.983063
2JHN    1.086673
3JHC    0.911788
3JHH    1.122420
3JHN   -0.033818
dtype: float64

## Split molecules into train / validation sets

In [11]:
molecules = train.molecule_name.drop_duplicates().sort_values()

In [12]:
train_ind, valid_ind = train_test_split(np.arange(len(molecules)),
                                        test_size=5000,
                                        random_state=1234)

In [13]:
assert not set(train_ind).intersection(valid_ind)

In [14]:
len(train_ind), len(valid_ind)

(80003, 5000)

## Create train / validation subsets

In [15]:
# Check reproducibility
rs = np.random.RandomState(seed=1234)
print(rs.choice(train_ind, 10))
print(rs.choice(valid_ind, 10))

[19669 26783  1698 47278 33476 59113 40999 64242 25723 71229]
[55624 62327 36561 67391 19447 20288 70596 59541 32479 52121]


In [16]:
train_data = train.loc[train.molecule_name.isin(molecules.iloc[train_ind])]
val_data = train.loc[train.molecule_name.isin(molecules.iloc[valid_ind])]

## Create dataset

In [17]:
from kaggle_champs.dataset import ChampsDataset

In [18]:
import os
import numpy as np
import openbabel
import torch

from torch_geometric.data import Data
from torch.utils.data import Dataset
from tqdm.autonotebook import tqdm
from kaggle_champs.dataset import mol_to_data_v2

In [19]:
class MoleculeDataset(Dataset):
    """Map-style dataset yielding one graph ``Data`` object per molecule.

    Each item is built by reading ``<base_dir>/<molecule_name>.xyz`` with OpenBabel,
    converting it with ``mol_to_data_v2`` and attaching the scalar-coupling targets of that
    molecule from ``metadata`` (every couple is added in both directions). A networkx graph
    of ``edge_index`` is attached as ``data.graph`` while ``transform`` runs and is removed
    before the item is returned.

    Args:
        metadata: DataFrame with ``molecule_name``, ``atom_index_0``, ``atom_index_1`` and
            ``type`` columns; ``scalar_coupling_constant`` is optional (targets are
            zero-filled when it is absent).
        base_dir: directory containing the ``.xyz`` structure files.
        transform: optional callable applied to each ``Data`` object.
    """

    def __init__(self, metadata=None, base_dir=None, transform=None):
        self.molecules = metadata.molecule_name.unique()
        # Pre-split the couples per molecule so __getitem__ is a dict lookup.
        self.metadata = dict(tqdm(metadata.groupby('molecule_name')))
        self.base_dir = base_dir
        self.transform = transform
        self.conversion = openbabel.OBConversion()
        self.conversion.SetInAndOutFormats("xyz", "mdl")

    def __len__(self):
        # Count the same collection that __getitem__ indexes into.
        return len(self.molecules)

    def __getitem__(self, index):
        # Local import: in this notebook networkx is only imported in a later cell.
        import networkx as nx

        mol_name = self.molecules[index]
        xyz_file = os.path.join(self.base_dir, f'{mol_name}.xyz')
        if not os.path.exists(xyz_file):
            raise FileNotFoundError(f'Expecting file {xyz_file} not found')
        mol = openbabel.OBMol()
        self.conversion.ReadFile(mol, xyz_file)

        data = mol_to_data_v2(mol)
        data.mol_ind = torch.tensor([[index]], dtype=torch.long)

        data = self._add_targets(data, metadata=self.metadata[mol_name])

        # Temporary molecular graph for transforms that need it (e.g. AddBondPath).
        data.graph = nx.Graph()
        data.graph.add_edges_from(data.edge_index.transpose(1, 0).cpu().numpy())

        if self.transform:
            data = self.transform(data)

        if hasattr(data, 'graph'):
            del data.graph
        return data

    def _add_inverse_couple(self, couples):
        """Return ``couples`` plus a copy with the two atom indices swapped.

        The result is sorted by ``(atom_index_0, atom_index_1)``; the input is not modified.
        """
        inverse_direction = couples.rename(
            columns={'atom_index_1': 'atom_index_0',
                     'atom_index_0': 'atom_index_1'})
        # DataFrame.append was removed in pandas 2.0; pd.concat is the supported equivalent.
        both_directions = pd.concat([couples, inverse_direction], sort=False)
        return both_directions.sort_values(['atom_index_0', 'atom_index_1'])

    def _add_y(self, data, couples):
        """Set ``data.y`` to an (n_couples, 1) float tensor of coupling constants (zeros if absent)."""
        if 'scalar_coupling_constant' in couples.columns:
            data.y = torch.tensor(
                couples['scalar_coupling_constant'].values,
                dtype=torch.float).view(-1, 1)
        else:
            data.y = torch.zeros((len(couples), 1), dtype=torch.float)
        return data

    def _add_targets(self, data, metadata):
        """Attach couple indices, targets, type ids and per-type sample weights to ``data``."""
        couples = self._add_inverse_couple(metadata)

        data.couples_ind = torch.tensor(
            couples[['atom_index_0',
                     'atom_index_1']].values,
            dtype=torch.long)

        data = self._add_y(data, couples)

        data.type = torch.tensor(
            couples['type'].map(constants.TYPES_DICT).values,
            dtype=torch.long)

        data.sample_weight = torch.tensor(
            couples['type'].map(constants.TYPES_WEIGHTS).values,
            dtype=torch.float)

        return data

In [20]:
from kaggle_champs.preprocessing import RandomRotation, AddVirtualEdges, AddEdgeDistanceAndDirection, SortTarget

In [21]:
import networkx as nx

In [22]:
class AddEdgeDistanceAndDirection:
    """Append Gaussian-expanded edge lengths and unit edge directions to ``edge_attr``.

    For every edge (row -> col) the length ``|pos[col] - pos[row]|`` is expanded on
    ``gauss_base_steps`` Gaussians centred evenly on
    ``[gauss_base_max / gauss_base_steps, gauss_base_max]``. The expansion and the unit
    direction vector (3 values) are concatenated to the existing edge features.

    Args:
        dist_noise: std of multiplicative Gaussian noise ``1 + N(0, dist_noise)`` applied to
            the lengths (0 disables it). The direction is always taken from the clean length.
        gauss_base_max: largest Gaussian centre.
        gauss_base_steps: number of Gaussian centres.
        keep: if True, also store the (possibly noised) length as ``data.dist`` and the unit
            direction as ``data.direction``.
        gauss_width: width ``w`` of the kernel ``exp(-(d - mu)^2 / w^2)``.
    """

    def __init__(self, dist_noise=0., gauss_base_max=4, gauss_base_steps=20, keep=True,
                 gauss_width=0.5):
        self.dist_noise = dist_noise
        self.gauss_base_max = gauss_base_max
        self.gauss_base_steps = gauss_base_steps
        self.keep = keep
        self.gauss_width = gauss_width

    def __call__(self, data):
        (row, col), pos, edge_attr = data.edge_index, data.pos, data.edge_attr

        offset = pos[col] - pos[row]
        dist = torch.norm(offset, p=2, dim=-1).view(-1, 1)
        # Normalise with the clean length so noise only perturbs the distance features.
        direction = offset / dist

        if self.dist_noise > 0:
            noise = 1 + torch.randn_like(dist) * self.dist_noise
            dist = dist * noise

        if self.keep:
            data.dist = dist
            data.direction = direction

        base = torch.linspace(self.gauss_base_max / self.gauss_base_steps,
                              self.gauss_base_max,
                              self.gauss_base_steps,
                              dtype=torch.float,
                              device=dist.device).view(1, -1)  # shape 1xn for broadcasting

        rbf = torch.exp(-(dist - base) ** 2 / self.gauss_width ** 2)

        edge_attr = edge_attr.view(-1, 1) if edge_attr.dim() == 1 else edge_attr
        data.edge_attr = torch.cat(
            [edge_attr,
             rbf.type_as(edge_attr),
             direction.type_as(edge_attr)],
            dim=-1)

        return data

In [23]:
class AddBondLinks:
    """Enumerate ordered pairs of bond edges that chain through a shared atom.

    A pair (e1, e2) of edges from ``data.bonds_edge_ind`` is kept when e1 ends where e2
    starts and e2 does not lead straight back to e1's start (no immediate backtracking).
    The result, an (n_links, 2) tensor of edge positions, is stored as
    ``data.bonds_links_edge_ind``.
    """

    def __call__(self, data):
        flat_bonds = data.bonds_edge_ind.view(-1)
        n = len(data.bonds_edge_ind)

        # Cartesian product of bond edges: every "first" paired with every "second".
        first = flat_bonds.repeat_interleave(n)
        second = flat_bonds.repeat(n)
        candidates = torch.stack([first, second], dim=1)

        src, dst = data.edge_index
        shares_middle_atom = dst[first] == src[second]
        not_backtracking = src[first] != dst[second]

        data.bonds_links_edge_ind = candidates[shares_middle_atom & not_backtracking]
        return data

In [24]:
class AddCounts:
    """Record the graph's node and edge counts as ``(1, 1)`` long tensors.

    After batching these become ``(num_graphs, 1)`` columns. ``count_edges`` is what
    ``correct_batch_edge_ind`` uses to offset per-graph edge positions.
    """

    def __call__(self, data):
        counts = {'count_nodes': data.num_nodes, 'count_edges': data.num_edges}
        for attr_name, value in counts.items():
            setattr(data, attr_name, torch.tensor(value, dtype=torch.long).view(1, 1))
        return data

In [25]:
class AddGlobalAttr:
    """Attach a single zero-valued global feature: ``data.global_attr`` of shape ``(1, 1)``."""

    def __call__(self, data):
        placeholder = torch.zeros(1, 1, dtype=torch.float)
        data.global_attr = placeholder
        return data

class AddCouplesInd:
    """Store, for every couple, the position of its (node_from -> node_to) edge.

    Positions assume ``edge_index`` enumerates every ordered node pair without self-loops,
    in row-major order: ``from * (num_nodes - 1) + to``, minus one when ``from < to``
    (to skip the missing self-loop). Result: ``data.couples_edge_ind`` of shape (n_couples, 1).
    """

    def __call__(self, data):
        node_from, node_to = data.couples_ind.transpose(1, 0)
        edge_ind = node_from * (data.num_nodes - 1) + node_to
        edge_ind[node_from < node_to] -= 1

        data.couples_edge_ind = edge_ind.view(-1, 1)

        # Consistency check against an edge mask marking the couples. ``data.mask`` is not
        # always set (e.g. SortTarget does not set it), so only check when it exists.
        mask = getattr(data, 'mask', None)
        if mask is not None:
            assert torch.equal(data.edge_attr[mask], data.edge_attr[edge_ind])
        return data

In [54]:
class SortTarget:
    """Reorder per-couple targets into edge order and record each couple's edge position.

    Assumes ``edge_index`` lists every ordered node pair without self-loops in row-major
    order (the numbering used by ``_get_index``) — TODO confirm against the edge-building
    transform. ``y``, ``sample_weight`` and ``type`` are scattered onto their edge positions
    and read back in edge order. The method asserts that ``couples_ind`` already matches
    that order, and sets ``data.couples_edge_ind`` to the (n_couples, 1) edge positions.
    """

    def _get_index(self, data, row, col):
        """Flat position of edge (row -> col) in the dense, self-loop-free edge list."""
        flat = row * (data.num_nodes - 1) + col
        flat[row < col] -= 1
        return flat

    def __call__(self, data):
        n_edges = data.num_edges
        rows, cols = data.couples_ind.transpose(1, 0)
        positions = self._get_index(data, rows, cols)

        present = torch.zeros(n_edges, dtype=torch.bool)
        present[positions] = True

        def reorder(values, shape, dtype):
            # Scatter onto an edge-sized buffer, then read the occupied slots in edge order.
            buffer = torch.zeros(shape, dtype=dtype)
            buffer[positions] = values
            return buffer[present]

        data.y = reorder(data.y, (n_edges, data.y.size()[1]), torch.float)
        data.sample_weight = reorder(data.sample_weight, n_edges, torch.float)
        data.type = reorder(data.type, n_edges, torch.long)

        assert torch.equal(data.couples_ind, data.edge_index[:, present].transpose(1, 0))
        data.couples_edge_ind = torch.arange(n_edges, dtype=torch.long)[present].view(-1, 1)
        return data

In [55]:
class AddBondPath:
    """Attach the shortest bond path between the two atoms of every couple.

    Sets:
      * ``data.paths_index``: (max(4, L), n_couples) long tensor. Column j lists the node ids
        along the shortest path of couple j in ``data.graph``, zero-padded at the end. Padded
        entries are placeholders, not real path nodes.
      * ``data.paths_edge_ind``: (n_couples, 3) long tensor with the edge position of each of
        the first three consecutive node pairs on the path (see ``_nodes_to_edge_ind``).

    Requires ``data.couples_ind`` and ``data.graph`` (a networkx graph).
    """

    def __call__(self, data):
        # The "_index" suffix makes PyG offset these node ids by the node count when batching.
        data.paths_index = self.find_paths(data).transpose(1, 0)
        data.paths_edge_ind = torch.cat(
            [self._nodes_to_edge_ind(data, data.paths_index[i], data.paths_index[i + 1])
             for i in range(3)],
            dim=1)
        return data

    def _nodes_to_edge_ind(self, data, node_from, node_to):
        """(n, 1) positions of edges (node_from -> node_to) in a dense, self-loop-free,
        row-major edge list."""
        edge_ind = node_from * (data.num_nodes - 1) + node_to
        edge_ind[node_from < node_to] -= 1
        return edge_ind.view(-1, 1)

    def find_paths(self, data):
        """Return an (n_couples, max(4, L)) long tensor of shortest-path node ids, zero-padded."""
        assert hasattr(data, 'couples_ind')
        assert hasattr(data, 'graph')

        # dict() accepts both a dict-of-dicts and an iterable of (source, dict) pairs, so
        # this works whichever form networkx returns for all-pairs shortest paths.
        all_paths = dict(nx.shortest_path(data.graph))
        paths = [torch.tensor(all_paths[from_][to_], dtype=torch.long)
                 for from_, to_ in data.couples_ind.tolist()]
        if not paths:
            return torch.zeros((0, 4), dtype=torch.long)

        # 1-D sequences pad straight to (n_couples, L). A blanket .squeeze() on a trailing
        # singleton dim would also collapse the batch dim when there is a single couple.
        paths = torch.nn.utils.rnn.pad_sequence(paths, batch_first=True)
        if paths.size(1) < 4:
            paths = torch.nn.functional.pad(paths, (0, 4 - paths.size(1)))
        return paths

In [56]:
def correct_batch_edge_ind(batch):
    """Turn per-graph edge positions into positions in the batched edge list.

    For each tracked ``*_edge_ind`` attribute present on ``batch``, add the number of edges
    in all preceding graphs. The owning graph of each entry is looked up through the matching
    ``<key>_batch`` vector (produced by the DataLoader's ``follow_batch``).

    The attributes are replaced on ``batch`` itself, and ``batch`` is returned. Calling this
    twice on the same batch therefore applies the offsets twice.
    """
    edge_counts = batch.count_edges
    graph_start = torch.zeros_like(edge_counts)
    graph_start[1:] = torch.cumsum(edge_counts[:-1], dim=0)

    tracked = ('bonds_edge_ind', 'bonds_links_edge_ind', 'paths_edge_ind', 'couples_edge_ind')
    for key in (name for name in tracked if hasattr(batch, name)):
        owner_graph = batch[key + '_batch']
        batch[key] = batch[key] + graph_start[owner_graph]
    return batch

In [57]:
train_dataset = MoleculeDataset(metadata=train_data,
                              base_dir=constants.STRUCT_DATA_PATH,
                              transform=T.Compose([
                                  AddBondPath(),
                                  AddVirtualEdges(),
                                  RandomRotation(),
                                  AddEdgeDistanceAndDirection(
                                      dist_noise=0.),
                                  AddGlobalAttr(),
                                  SortTarget(),
                                  AddBondLinks(),
                                  AddCounts(),     
                              ]))

val_dataset = MoleculeDataset(metadata=val_data,
                            base_dir=constants.STRUCT_DATA_PATH,
                            transform=T.Compose([
                                AddBondPath(),
                                AddVirtualEdges(),
                                AddEdgeDistanceAndDirection(
                                      dist_noise=0.),
                                AddGlobalAttr(),
                                SortTarget(),
                                AddBondLinks(),
                                AddCounts(),     
                            ]))

HBox(children=(IntProgress(value=0, max=80003), HTML(value='')))




HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




In [58]:
data = train_dataset[10]
data

FP16_Data(bonds_edge_ind=[20, 1], bonds_links_edge_ind=[36, 2], count_edges=[1, 1], count_nodes=[1, 1], couples_edge_ind=[86, 1], couples_ind=[86, 2], direction=[110, 3], dist=[110, 1], edge_attr=[110, 34], edge_index=[2, 110], global_attr=[1, 1], mol_ind=[1, 1], paths_edge_ind=[86, 3], paths_index=[4, 86], pos=[11, 3], sample_weight=[86], type=[86], x=[11, 28], y=[86, 1])

# Model

In [59]:
from kaggle_champs.modelling import MegNetBlock, create_mlp_v2, MegNetBlock_v2, MegNetBlock_v3

In [60]:
from torch import nn

In [61]:
from torch_scatter import scatter_add

In [62]:
def gather_embedding(data, x_out, edge_out, u_out, couple_type):
    """Build the feature matrix fed to the output head of one coupling type.

    Only couples with ``data.type == constants.TYPES_DICT[couple_type]`` are kept. Their
    row is the concatenation, in this order, of:
      * the global state of the graph owning the couple's source node,
      * the embedding of the direct couple edge (only when n_bonds > 1),
      * the embeddings of the n_bonds + 1 nodes on the bond path (``paths_index``),
      * the embeddings of the n_bonds edges on the bond path (``paths_edge_ind``),
    where n_bonds is the leading digit of ``couple_type`` (e.g. '3JHC' -> 3).
    """
    n_bonds = int(couple_type[0])
    wanted = data.type == constants.TYPES_DICT[couple_type]
    direct_edges = data.couples_edge_ind.view(-1)[wanted]

    source_nodes = data.edge_index[0][direct_edges]
    pieces = [u_out[data.batch[source_nodes]]]
    if n_bonds > 1:
        pieces.append(edge_out[direct_edges])

    path_nodes = data.paths_index.transpose(1, 0)[wanted][:, :n_bonds + 1]
    pieces.extend(x_out[path_nodes[:, k]] for k in range(n_bonds + 1))

    path_edges = data.paths_edge_ind[wanted]
    pieces.extend(edge_out[path_edges[:, k]] for k in range(n_bonds))
    return torch.cat(pieces, dim=1)

In [63]:
class OutputLayer_new(torch.nn.Module):
    """Prediction head for a single coupling type.

    Gathers each couple's features with ``gather_embedding``, maps them to one value with an
    MLP, and de-standardises the result as ``out * y_std + y_mean`` through a frozen Linear.

    Args:
        rep_dim: width of the node/edge/global embeddings produced by the backbone, i.e. the
            width of every piece ``gather_embedding`` concatenates.
        dim: width scale of the MLP head; hidden widths are derived from ``dim`` times the
            number of concatenated pieces.
        y_mean, y_std: target statistics of this coupling type (scalars).
        couple_type: type string such as '2JHC'; its first character is the bond count.
    """

    def __init__(self, rep_dim, dim, y_mean, y_std, couple_type):
        super(OutputLayer_new, self).__init__()
        # Frozen affine de-standardisation: out * y_std + y_mean.
        self.scaling = torch.nn.Linear(1, 1)
        self.scaling.bias = torch.nn.Parameter(torch.tensor(y_mean,
                                                            dtype=torch.float),
                                               requires_grad=False)
        self.scaling.weight = torch.nn.Parameter(torch.tensor(
            [[y_std]], dtype=torch.float),
                                                 requires_grad=False)
        self.couple_type = couple_type
        n_bonds = int(couple_type[0])

        # Pieces per couple: global state + (n_bonds + 1) path nodes + n_bonds path edges,
        # plus the direct couple edge when it is not itself the single path edge.
        n_pieces = n_bonds + (n_bonds + 1) + (1 if n_bonds == 1 else 2)
        # The input width must follow the embedding width (rep_dim), not the head width.
        input_dim = rep_dim * n_pieces
        head_width = dim * n_pieces

        self.mlp = create_mlp_v2(
            input_dim=input_dim,
            output_dim=1,
            hidden_dims=[head_width // 2, head_width // 4, head_width // 4],
            normalization_cls=torch.nn.LayerNorm,
            activation_cls=torch.nn.Softplus,
            dropout_cls=torch.nn.Dropout,
            dropout_prob=0.
        )

    def forward(self, data, x_out, edge_out, u_out):
        """Return an (n_selected_couples, 1) tensor of de-standardised predictions."""
        in_ = gather_embedding(data, x_out, edge_out, u_out, self.couple_type)
        out = self.mlp(in_)
        out = self.scaling(out)
        return out

In [64]:
class EdgeConv(torch.nn.Module):
    """Edge-to-edge message passing over linked edge pairs.

    For every link (row, col) in ``bonds_links_edge_ind`` a gated message is computed from
    the two edge embeddings and the dot product of their ``direction`` vectors. Messages
    are summed per ``row`` edge and per ``col`` edge. An MLP then maps
    [edge, summed-as-row, summed-as-col] to a residual update of the edge embedding.

    Args:
        dim: edge embedding width.
        update_bonds_only: if True only edges listed in ``bonds_edge_ind`` are updated;
            otherwise every edge is.
    """

    def __init__(self, dim=32, update_bonds_only=True):
        super(EdgeConv, self).__init__()
        self.update_bonds_only = update_bonds_only
        self.msg_mlp = create_mlp_v2(
            input_dim=dim * 2 + 1,
            output_dim=dim,
            hidden_dims=[dim, dim],
            normalization_cls=torch.nn.LayerNorm,
            activation_cls=torch.nn.Softplus,
            dropout_cls=torch.nn.Dropout,
            dropout_prob=0.
        )

        self.gate_mlp = create_mlp_v2(
            input_dim=dim * 2 + 1,
            output_dim=dim,
            hidden_dims=[dim, dim],
            normalization_cls=torch.nn.LayerNorm,
            activation_cls=torch.nn.Softplus,
            dropout_cls=torch.nn.Dropout,
            dropout_prob=0.
        )

        self.update_mlp = create_mlp_v2(
            input_dim=dim * 3,
            output_dim=dim,
            hidden_dims=[dim, dim],
            normalization_cls=torch.nn.LayerNorm,
            activation_cls=torch.nn.Softplus,
            dropout_cls=torch.nn.Dropout,
            dropout_prob=0.
        )

    def msg(self, *args):
        """Gated message ``msg_mlp(x) * tanh(gate_mlp(x))`` where x concatenates ``args``."""
        in_ = torch.cat(args, dim=1)
        return self.msg_mlp(in_) * torch.tanh(self.gate_mlp(in_))

    def aggregate_msg(self, msg, groupby, n_edges):
        """Sum rows of ``msg`` by ``groupby`` into an ``(n_edges, msg.size(1))`` tensor.

        Rows no message maps to stay zero. The result is allocated on ``msg``'s device, so
        this works on CPU as well as GPU.
        """
        res = torch.zeros((n_edges, msg.size(1)), dtype=msg.dtype, device=msg.device)
        # index_add sums in O(len(groupby)) without a dense (groups x messages) matrix.
        return res.index_add(0, groupby, msg)

    def forward(self, edge_out, bonds_links_edge_ind, bonds_edge_ind, direction):
        """Return updated edge embeddings; ``edge_out`` itself is left unmodified."""
        row, col = bonds_links_edge_ind[:, 0], bonds_links_edge_ind[:, 1]
        # Dot product of the two edges' directions (the angle cosine if they are unit vectors).
        angle_feats = (direction[row] * direction[col]).sum(dim=1).view(-1, 1)

        msg = self.msg(edge_out[row], edge_out[col], angle_feats)
        msg_row = self.aggregate_msg(msg, row, edge_out.size(0))
        msg_col = self.aggregate_msg(msg, col, edge_out.size(0))
        if self.update_bonds_only:
            select = bonds_edge_ind.view(-1)
        else:
            select = torch.arange(edge_out.size(0), dtype=torch.long, device=edge_out.device)
        out = self.update_mlp(torch.cat([edge_out[select], msg_row[select], msg_col[select]], dim=1))
        assert out.size(0) == select.size(0)
        # Clone so the caller's tensor (possibly needed by autograd) is not mutated in place.
        result = edge_out.clone()
        result[select] = result[select] + out
        return result

In [65]:
class MegNetModel_new(torch.nn.Module):
    """Stack of MegNet blocks followed by one prediction head per coupling type.

    Args:
        edge_dim, x_dim, u_dim: input widths of edge, node and global features.
        dim: embedding width produced by every MegNet block.
        head_dim: width scale of each per-type output head.
        n_megnet_blocks: number of MegNet blocks applied in ``forward``. The first maps the
            raw feature widths to ``dim``; the rest are ``dim -> dim``.
        y_mean, y_std: per-type target statistics, indexed in ``constants.TYPES_LIST``
            order. A scalar is broadcast to every type.
        layer_norm: unused; kept for backward compatibility.
    """

    def __init__(self,
                 edge_dim,
                 x_dim,
                 u_dim,
                 dim=32,
                 head_dim=32,
                 n_megnet_blocks=3,
                 y_mean=0,
                 y_std=1,
                 layer_norm=False):
        super(MegNetModel_new, self).__init__()
        self.dim = dim
        self.n_megnet_blocks = n_megnet_blocks
        block_kwargs = dict(normalization_cls=torch.nn.LayerNorm,
                            activation_cls=torch.nn.Softplus,
                            dropout_cls=torch.nn.Dropout,
                            dropout_prob=0.,
                            residual=True)
        self.megnet_blocks = torch.nn.ModuleList(
            [MegNetBlock_v3(edge_dim, x_dim, u_dim, dim, **block_kwargs)]
            + [MegNetBlock_v3(dim, dim, dim, dim, **block_kwargs)
               for _ in range(n_megnet_blocks - 1)])

        n_types = len(constants.TYPES_LIST)
        y_mean = self._per_type(y_mean, n_types)
        y_std = self._per_type(y_std, n_types)
        self.out_mlp = torch.nn.ModuleList([
            OutputLayer_new(
                dim,
                head_dim,
                y_mean=y_mean[i],
                y_std=y_std[i],
                couple_type=type_,
            ) for i, type_ in enumerate(constants.TYPES_LIST)
        ])

    @staticmethod
    def _per_type(value, n_types):
        """Return ``value`` as an indexable per-type sequence; scalars are repeated."""
        if np.ndim(value) == 0:
            return [value] * n_types
        return value

    def forward(self, data):
        """Predict one scalar coupling constant per couple in the batch ``data``.

        NOTE: ``correct_batch_edge_ind`` rewrites the ``*_edge_ind`` attributes of ``data``
        in place, so the same batch object must not be passed through ``forward`` twice.
        """
        data = correct_batch_edge_ind(data)

        if not hasattr(data, 'global_attr'):
            data.global_attr = torch.zeros((data.num_graphs, 1),
                                           dtype=torch.float,
                                           device=data.x.device)
        x_out, edge_out, u_out = data.x, data.edge_attr, data.global_attr
        for i in range(self.n_megnet_blocks):
            x_out, edge_out, u_out = self.megnet_blocks[i](x_out,
                                                           data.edge_index,
                                                           edge_out,
                                                           u_out,
                                                           data.batch,
                                                           first_block=(i == 0))

        pred = torch.zeros_like(data.type,
                                dtype=torch.float,
                                device=x_out.device)
        # Head i writes couples whose type id is i. This assumes constants.TYPES_DICT maps
        # TYPES_LIST[i] -> i (TODO confirm), because each head filters via TYPES_DICT.
        for type_id, head in enumerate(self.out_mlp):
            selected = data.type == type_id
            if selected.any():
                pred[selected] = head(data, x_out, edge_out, u_out).view(-1)
        return pred

# Training

In [66]:
from kaggle_champs.metrics import MeanLogGroupMAE, AverageMetric

In [67]:
from kaggle_champs.training import train_epoch

In [68]:
import torch
from tqdm.autonotebook import tqdm


def train_epoch(global_iteration, epoch, model, device, optimizer, train_loader, tb_logger, gradient_accumulation_steps=1):
    """Run one training epoch with an L1 loss.

    Args:
        global_iteration: optimizer-step counter used as the TensorBoard x-axis.
        epoch: unused; kept for call compatibility.
        model: module called as ``model(batch)``; it is switched to train mode.
        device: where each batch is moved.
        optimizer: stepped once every ``gradient_accumulation_steps`` batches, and once more
            at the end of the epoch if a partial accumulation window remains.
        train_loader: iterable of batches exposing ``.to``, ``.y``, ``.type``, ``.num_graphs``.
        tb_logger: object with ``add_scalar(tag, value, step)``.
        gradient_accumulation_steps: batches per optimizer step. Each batch's gradient is
            scaled by ``1 / gradient_accumulation_steps`` so one step sees the window average.

    Returns:
        ``(mean_loss, log_mae, global_iteration)``. The loss is averaged with per-batch weight
        ``num_graphs``; ``log_mae`` is the accumulated metric; the counter is updated.
    """
    model.train()
    loss_fn = torch.nn.L1Loss(reduction='mean')

    avg_loss = AverageMetric()
    log_mae = MeanLogGroupMAE()

    # Start from clean gradients (e.g. after an interrupted previous run).
    optimizer.zero_grad()
    has_pending_grads = False

    pbar = tqdm(train_loader)
    for step, data in enumerate(pbar):
        data = data.to(device)

        pred = model(data)

        loss = loss_fn(pred.view(-1), data.y.view(-1))
        (loss / gradient_accumulation_steps).backward()
        has_pending_grads = True

        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            has_pending_grads = False
            global_iteration += 1

        loss_value = loss.item()
        tb_logger.add_scalar('loss', loss_value, global_iteration)

        avg_loss.update(loss_value * data.num_graphs, data.num_graphs)
        # Detach so the metric never holds on to the autograd graph.
        log_mae.update(pred.detach().view(-1), data.y.view(-1), data.type)

        pbar.set_postfix_str(f'loss: {avg_loss.compute():.4f}')

    # Apply a trailing partial window instead of leaking it into the next epoch.
    if has_pending_grads:
        optimizer.step()
        optimizer.zero_grad()
        global_iteration += 1
    return avg_loss.compute(), log_mae, global_iteration

In [69]:
def test_model(model, loader):
    model.eval()
    log_mae = MeanLogGroupMAE()
    avg_loss = AverageMetric()
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            pred = model(data)
            
            loss = torch.nn.L1Loss(reduction='mean')(pred.view(-1),
                                                     data.y.view(-1))
            avg_loss.update(loss.item() * data.num_graphs, data.num_graphs)
            
            log_mae.update(pred.view(-1), data.y.view(-1), data.type.view(-1))
            
        return avg_loss.compute(), log_mae


def make_log(epoch, lr, loss, tr_logmae, val_logmae):
    """Flatten one epoch's metrics into a single dict row.

    Keys: ``epoch``, ``lr``, ``loss``, ``tr_logmae``, ``val_logmae``, then one
    ``tr_<name>`` / ``val_<name>`` entry per item of each metric's ``compute_individuals()``.
    """
    row = dict(epoch=epoch,
               lr=lr,
               loss=loss,
               tr_logmae=tr_logmae.compute(),
               val_logmae=val_logmae.compute())
    for prefix, metric in (('tr_', tr_logmae), ('val_', val_logmae)):
        row.update({prefix + name: value
                    for name, value in metric.compute_individuals().items()})
    return row


def save_checkpoint(dir_path, model, optimizer, scheduler, epoch):
    """Save the model, optimizer and scheduler ``state_dict``s for ``epoch`` into ``dir_path``.

    Files are named ``model_epoch_<epoch>.pth``, ``optimizer_epoch_<epoch>.pth`` and
    ``scheduler_epoch_<epoch>.pth``. ``dir_path`` may be given with or without a trailing
    separator (``os.path.join``); the directory must already exist.
    """
    for prefix, obj in (('model', model), ('optimizer', optimizer), ('scheduler', scheduler)):
        torch.save(obj.state_dict(), os.path.join(dir_path, f'{prefix}_epoch_{epoch}.pth'))

# Run

In [70]:
from tensorboardX import SummaryWriter

In [71]:
import shutil

In [72]:
OUTPUT_DIR = './models/megnet_256x10_unweightedMAE_softplus/'

!mkdir -p {OUTPUT_DIR}
!rm -rf {OUTPUT_DIR}
!mkdir -p {OUTPUT_DIR}

In [73]:
tb_logger = SummaryWriter(OUTPUT_DIR+'tb_log/')
global_iteration = 0

In [74]:
SAVE_INTERVAL = 10

In [75]:
MAX_EPOCH = 150

In [76]:
val_loader = DataLoader(val_dataset,
                        batch_size=64,
                        shuffle=False,
                        num_workers=8, 
                        follow_batch=['bonds_edge_ind', 'bonds_links_edge_ind', 'paths_edge_ind', 'couples_edge_ind']
                       )
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          num_workers=8,
                          shuffle=True,
                          follow_batch=['bonds_edge_ind', 'bonds_links_edge_ind', 'paths_edge_ind', 'couples_edge_ind']
                         )

In [77]:
batch = next(iter(val_loader))

In [78]:
batch

Batch(batch=[895], bonds_edge_ind=[1758, 1], bonds_edge_ind_batch=[1758], bonds_links_edge_ind=[3076, 2], bonds_links_edge_ind_batch=[3076], count_edges=[64, 1], count_nodes=[64, 1], couples_edge_ind=[5146, 1], couples_edge_ind_batch=[5146], couples_ind=[5146, 2], direction=[12162, 3], dist=[12162, 1], edge_attr=[12162, 34], edge_index=[2, 12162], global_attr=[64, 1], mol_ind=[64, 1], paths_edge_ind=[5146, 3], paths_edge_ind_batch=[5146], paths_index=[4, 5146], pos=[895, 3], sample_weight=[5146], type=[5146], x=[895, 28], y=[5146, 1])

In [79]:
y_mean = train.groupby(train.type.map(
    constants.TYPES_DICT)).scalar_coupling_constant.mean().sort_index().values
y_std = train.groupby(train.type.map(
    constants.TYPES_DICT)).scalar_coupling_constant.std().sort_index().values

In [80]:
device = torch.device('cuda')
model = MegNetModel_new(edge_dim=data.edge_attr.size()[1],
                    x_dim=data.x.size()[1],
                    u_dim=1,
                    dim=256,
                    head_dim=256,
                    n_megnet_blocks=10,
                    y_mean=y_mean,
                    y_std=y_std,
                    layer_norm=False).to(device)

In [81]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=1.)

In [55]:
# train loop
logs = []
for epoch in range(1, MAX_EPOCH + 1):
    lr = scheduler.optimizer.param_groups[0]['lr']
    tr_loss, tr_logmae, global_iteration = train_epoch(global_iteration,
                                                       epoch,
                                                       model,
                                                       device,
                                                       optimizer,
                                                       train_loader,
                                                       tb_logger,
                                                       gradient_accumulation_steps=1)
    scheduler.step()

    val_loss, val_logmae = test_model(model, val_loader)

    epoch_log = make_log(epoch, lr, tr_loss, tr_logmae, val_logmae)
    logs.append(epoch_log)
    # Rewrite the whole CSV every epoch so an interrupted run still leaves a complete log.
    pd.DataFrame(logs).to_csv(OUTPUT_DIR + 'log.csv')
    # Adjacent literals (not a backslash continuation) keep stray indentation out of the text.
    print('Epoch: {epoch:03d}, LR: {lr:7f}, Loss: {loss:.7f}, '
          'Train LogMAE: {tr_logmae:.7f}, Val LogMAE: {val_logmae:.7f}'.format(
              **epoch_log))

    if epoch % SAVE_INTERVAL == 0:
        save_checkpoint(OUTPUT_DIR, model, optimizer, scheduler, epoch)

    tb_logger.add_scalar('lr', lr, global_iteration)
    tb_logger.add_scalar('val_loss', val_loss, global_iteration)
    tb_logger.add_scalars('global_logmae', {
        'tr_logmae': epoch_log['tr_logmae'],
        'val_logmae': epoch_log['val_logmae']
    }, global_iteration)

    for type_ in constants.TYPES_LIST:
        tb_logger.add_scalars(
            type_, {
                'tr_' + type_: epoch_log['tr_' + type_],
                'val_' + type_: epoch_log['val_' + type_]
            }, global_iteration)

HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 001, LR: 0.000100, Loss: 1.3435959,          Train LogMAE: 0.2111909, Val LogMAE: -0.3088027


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 002, LR: 0.000100, Loss: 0.6723769,          Train LogMAE: -0.5152027, Val LogMAE: -0.5917197


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 003, LR: 0.000100, Loss: 0.5278573,          Train LogMAE: -0.7751854, Val LogMAE: -0.7835144


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 004, LR: 0.000100, Loss: 0.4445446,          Train LogMAE: -0.9543021, Val LogMAE: -1.0167685


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 005, LR: 0.000100, Loss: 0.3872927,          Train LogMAE: -1.0986850, Val LogMAE: -1.0699083


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 006, LR: 0.000100, Loss: 0.3461067,          Train LogMAE: -1.2149288, Val LogMAE: -1.1778488


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 007, LR: 0.000100, Loss: 0.3151478,          Train LogMAE: -1.3132377, Val LogMAE: -1.2939807


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 008, LR: 0.000100, Loss: 0.2904124,          Train LogMAE: -1.3992750, Val LogMAE: -1.3745773


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 009, LR: 0.000100, Loss: 0.2691236,          Train LogMAE: -1.4781317, Val LogMAE: -1.4764200


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 010, LR: 0.000100, Loss: 0.2520460,          Train LogMAE: -1.5465860, Val LogMAE: -1.4749096


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 011, LR: 0.000100, Loss: 0.2372075,          Train LogMAE: -1.6110781, Val LogMAE: -1.5394271


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 012, LR: 0.000100, Loss: 0.2244418,          Train LogMAE: -1.6668765, Val LogMAE: -1.5608039


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 013, LR: 0.000100, Loss: 0.2133004,          Train LogMAE: -1.7174368, Val LogMAE: -1.6569852


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 014, LR: 0.000100, Loss: 0.2040652,          Train LogMAE: -1.7670703, Val LogMAE: -1.6678236


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 015, LR: 0.000100, Loss: 0.1942851,          Train LogMAE: -1.8149763, Val LogMAE: -1.6873447


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 016, LR: 0.000100, Loss: 0.1861230,          Train LogMAE: -1.8610811, Val LogMAE: -1.7307230


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 017, LR: 0.000100, Loss: 0.1794937,          Train LogMAE: -1.8961263, Val LogMAE: -1.7096535


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 018, LR: 0.000100, Loss: 0.1726336,          Train LogMAE: -1.9345976, Val LogMAE: -1.7762152


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 019, LR: 0.000100, Loss: 0.1666096,          Train LogMAE: -1.9740012, Val LogMAE: -1.8340820


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 020, LR: 0.000100, Loss: 0.1611310,          Train LogMAE: -2.0106545, Val LogMAE: -1.8599246


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 021, LR: 0.000100, Loss: 0.1559629,          Train LogMAE: -2.0429619, Val LogMAE: -1.8333535


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 022, LR: 0.000100, Loss: 0.1513927,          Train LogMAE: -2.0736567, Val LogMAE: -1.8751195


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 023, LR: 0.000100, Loss: 0.1470898,          Train LogMAE: -2.1016942, Val LogMAE: -1.9168091


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 024, LR: 0.000100, Loss: 0.1432336,          Train LogMAE: -2.1315325, Val LogMAE: -1.9097480


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 025, LR: 0.000100, Loss: 0.1392821,          Train LogMAE: -2.1632042, Val LogMAE: -2.0079301


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 026, LR: 0.000100, Loss: 0.1353705,          Train LogMAE: -2.1907566, Val LogMAE: -1.9867036


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 027, LR: 0.000100, Loss: 0.1323499,          Train LogMAE: -2.2137023, Val LogMAE: -1.9321187


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 028, LR: 0.000100, Loss: 0.1287467,          Train LogMAE: -2.2404443, Val LogMAE: -2.0395623


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 029, LR: 0.000100, Loss: 0.1261419,          Train LogMAE: -2.2653775, Val LogMAE: -1.9788178


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 030, LR: 0.000100, Loss: 0.1231468,          Train LogMAE: -2.2873921, Val LogMAE: -2.0632467


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 031, LR: 0.000100, Loss: 0.1203224,          Train LogMAE: -2.3118772, Val LogMAE: -2.0864137


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 032, LR: 0.000100, Loss: 0.1179538,          Train LogMAE: -2.3339268, Val LogMAE: -2.1001422


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 033, LR: 0.000100, Loss: 0.1156523,          Train LogMAE: -2.3537453, Val LogMAE: -2.1113081


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 034, LR: 0.000100, Loss: 0.1132490,          Train LogMAE: -2.3738982, Val LogMAE: -2.0994170


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 035, LR: 0.000100, Loss: 0.1108633,          Train LogMAE: -2.3960181, Val LogMAE: -2.1397242


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 036, LR: 0.000100, Loss: 0.1089906,          Train LogMAE: -2.4173691, Val LogMAE: -2.1373894


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 037, LR: 0.000100, Loss: 0.1070795,          Train LogMAE: -2.4336989, Val LogMAE: -2.1382804


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 038, LR: 0.000100, Loss: 0.1050562,          Train LogMAE: -2.4537480, Val LogMAE: -2.1771618


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 039, LR: 0.000100, Loss: 0.1031882,          Train LogMAE: -2.4709058, Val LogMAE: -2.1596505


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 040, LR: 0.000100, Loss: 0.1017899,          Train LogMAE: -2.4865844, Val LogMAE: -2.1497889


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 041, LR: 0.000100, Loss: 0.0998834,          Train LogMAE: -2.5051056, Val LogMAE: -2.1667906


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 042, LR: 0.000100, Loss: 0.0983320,          Train LogMAE: -2.5244133, Val LogMAE: -2.1512037


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 043, LR: 0.000100, Loss: 0.0967665,          Train LogMAE: -2.5393165, Val LogMAE: -2.1846755


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 044, LR: 0.000100, Loss: 0.0952208,          Train LogMAE: -2.5569538, Val LogMAE: -2.1934559


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 045, LR: 0.000100, Loss: 0.0939571,          Train LogMAE: -2.5720371, Val LogMAE: -2.2302932


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 046, LR: 0.000100, Loss: 0.0924166,          Train LogMAE: -2.5888565, Val LogMAE: -2.2510141


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 047, LR: 0.000100, Loss: 0.0913434,          Train LogMAE: -2.5997072, Val LogMAE: -2.2241631


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 048, LR: 0.000100, Loss: 0.0899852,          Train LogMAE: -2.6159273, Val LogMAE: -2.2676876


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 049, LR: 0.000100, Loss: 0.0885067,          Train LogMAE: -2.6317925, Val LogMAE: -2.2527345


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 050, LR: 0.000100, Loss: 0.0875231,          Train LogMAE: -2.6444398, Val LogMAE: -2.2743505


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 051, LR: 0.000100, Loss: 0.0861785,          Train LogMAE: -2.6597673, Val LogMAE: -2.2903322


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 052, LR: 0.000100, Loss: 0.0851059,          Train LogMAE: -2.6756986, Val LogMAE: -2.2871558


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 053, LR: 0.000100, Loss: 0.0840291,          Train LogMAE: -2.6878089, Val LogMAE: -2.2835560


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 054, LR: 0.000100, Loss: 0.0831178,          Train LogMAE: -2.6969873, Val LogMAE: -2.2204554


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 055, LR: 0.000100, Loss: 0.0825296,          Train LogMAE: -2.7090053, Val LogMAE: -2.2909268


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 056, LR: 0.000100, Loss: 0.0811594,          Train LogMAE: -2.7235030, Val LogMAE: -2.2875274


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 057, LR: 0.000100, Loss: 0.0799975,          Train LogMAE: -2.7375959, Val LogMAE: -2.3016799


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 058, LR: 0.000100, Loss: 0.0792555,          Train LogMAE: -2.7463877, Val LogMAE: -2.3258586


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 059, LR: 0.000100, Loss: 0.0782799,          Train LogMAE: -2.7640814, Val LogMAE: -2.3274752


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 060, LR: 0.000100, Loss: 0.0774060,          Train LogMAE: -2.7713895, Val LogMAE: -2.3515956


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 061, LR: 0.000100, Loss: 0.0764204,          Train LogMAE: -2.7861903, Val LogMAE: -2.3310510


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 062, LR: 0.000100, Loss: 0.0757485,          Train LogMAE: -2.7969345, Val LogMAE: -2.3012933


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 063, LR: 0.000100, Loss: 0.0749888,          Train LogMAE: -2.8097506, Val LogMAE: -2.3309731


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 064, LR: 0.000100, Loss: 0.0742326,          Train LogMAE: -2.8170393, Val LogMAE: -2.3247403


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 065, LR: 0.000100, Loss: 0.0733537,          Train LogMAE: -2.8290316, Val LogMAE: -2.3338998


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 066, LR: 0.000100, Loss: 0.0727475,          Train LogMAE: -2.8392880, Val LogMAE: -2.3784732


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 067, LR: 0.000100, Loss: 0.0718518,          Train LogMAE: -2.8527484, Val LogMAE: -2.3489957


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 068, LR: 0.000100, Loss: 0.0711912,          Train LogMAE: -2.8607021, Val LogMAE: -2.3642396


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 069, LR: 0.000100, Loss: 0.0706307,          Train LogMAE: -2.8722728, Val LogMAE: -2.3712369


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 070, LR: 0.000100, Loss: 0.0697872,          Train LogMAE: -2.8825578, Val LogMAE: -2.3611532


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 071, LR: 0.000100, Loss: 0.0692617,          Train LogMAE: -2.8924028, Val LogMAE: -2.3853811


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 072, LR: 0.000100, Loss: 0.0684062,          Train LogMAE: -2.9064273, Val LogMAE: -2.3807053


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 073, LR: 0.000100, Loss: 0.0679373,          Train LogMAE: -2.9121835, Val LogMAE: -2.3552043


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 074, LR: 0.000100, Loss: 0.0674116,          Train LogMAE: -2.9196074, Val LogMAE: -2.3950762


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 075, LR: 0.000100, Loss: 0.0667989,          Train LogMAE: -2.9314310, Val LogMAE: -2.3909916


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 076, LR: 0.000100, Loss: 0.0661321,          Train LogMAE: -2.9413580, Val LogMAE: -2.3906081


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 077, LR: 0.000100, Loss: 0.0655262,          Train LogMAE: -2.9503082, Val LogMAE: -2.3727006


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 078, LR: 0.000100, Loss: 0.0650874,          Train LogMAE: -2.9576051, Val LogMAE: -2.3654092


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 079, LR: 0.000100, Loss: 0.0644871,          Train LogMAE: -2.9667808, Val LogMAE: -2.3727336


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 080, LR: 0.000100, Loss: 0.0640414,          Train LogMAE: -2.9739906, Val LogMAE: -2.4069457


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 081, LR: 0.000100, Loss: 0.0635479,          Train LogMAE: -2.9820401, Val LogMAE: -2.4012260


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 082, LR: 0.000100, Loss: 0.0629695,          Train LogMAE: -2.9950545, Val LogMAE: -2.4347384


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 083, LR: 0.000100, Loss: 0.0625287,          Train LogMAE: -2.9979529, Val LogMAE: -2.4013929


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 084, LR: 0.000100, Loss: 0.0619903,          Train LogMAE: -3.0108624, Val LogMAE: -2.4065595


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 085, LR: 0.000100, Loss: 0.0616356,          Train LogMAE: -3.0174936, Val LogMAE: -2.3935418


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 086, LR: 0.000100, Loss: 0.0609967,          Train LogMAE: -3.0250512, Val LogMAE: -2.4167749


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 087, LR: 0.000100, Loss: 0.0604888,          Train LogMAE: -3.0367627, Val LogMAE: -2.3633959


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 088, LR: 0.000100, Loss: 0.0605314,          Train LogMAE: -3.0294948, Val LogMAE: -2.4210152


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 089, LR: 0.000100, Loss: 0.0596197,          Train LogMAE: -3.0526098, Val LogMAE: -2.4355787


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 090, LR: 0.000100, Loss: 0.0592085,          Train LogMAE: -3.0591710, Val LogMAE: -2.4114372


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 091, LR: 0.000100, Loss: 0.0587621,          Train LogMAE: -3.0650263, Val LogMAE: -2.4305834


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 092, LR: 0.000100, Loss: 0.0585062,          Train LogMAE: -3.0761507, Val LogMAE: -2.4246831


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 093, LR: 0.000100, Loss: 0.0579438,          Train LogMAE: -3.0781994, Val LogMAE: -2.4401571


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 094, LR: 0.000100, Loss: 0.0576027,          Train LogMAE: -3.0871726, Val LogMAE: -2.4475933


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 095, LR: 0.000100, Loss: 0.0572166,          Train LogMAE: -3.0975987, Val LogMAE: -2.4343575


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 096, LR: 0.000100, Loss: 0.0568428,          Train LogMAE: -3.1018044, Val LogMAE: -2.4554055


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 097, LR: 0.000100, Loss: 0.0563270,          Train LogMAE: -3.1117552, Val LogMAE: -2.4402157


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 098, LR: 0.000100, Loss: 0.0561052,          Train LogMAE: -3.1178187, Val LogMAE: -2.4354225


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 099, LR: 0.000100, Loss: 0.0558678,          Train LogMAE: -3.1216998, Val LogMAE: -2.4618804


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))


Epoch: 100, LR: 0.000100, Loss: 0.0553395,          Train LogMAE: -3.1342966, Val LogMAE: -2.4613397


HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))

KeyboardInterrupt: 

In [82]:
# Resume the interrupted run from its epoch-100 model and optimizer state.
resume_dir = './models/megnet_256x10_unweightedMAE_softplus'
resume_epoch = 100
model.load_state_dict(torch.load(f'{resume_dir}/model_epoch_{resume_epoch}.pth'))
optimizer.load_state_dict(torch.load(f'{resume_dir}/optimizer_epoch_{resume_epoch}.pth'))
# Continue the TensorBoard step counter where epoch 100 ended.
global_iteration = resume_epoch * len(train_loader)

In [83]:
# Halve the learning rate every 3 `scheduler.step()` calls
# (the training loop steps the scheduler once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.5,
)

In [None]:
# Resumed training loop.
START_EPOCH, END_EPOCH = 101, 120  # inclusive range of epochs run by this cell
LOG_PATH = OUTPUT_DIR + 'log.csv'

# log.csv is rewritten from `logs` after every epoch, so starting from an empty
# list would overwrite the history of the earlier epochs. Seed `logs` with the
# rows already on disk, dropping any from START_EPOCH on so that re-running
# this cell does not duplicate them.
if osp.exists(LOG_PATH):
    previous_log = pd.read_csv(LOG_PATH, index_col=0)
    logs = previous_log[previous_log['epoch'] < START_EPOCH].to_dict('records')
else:
    logs = []

for epoch in range(START_EPOCH, END_EPOCH + 1):
    # LR used during this epoch (read before scheduler.step() changes it).
    lr = scheduler.optimizer.param_groups[0]['lr']
    tr_loss, tr_logmae, global_iteration = train_epoch(global_iteration,
                                                       epoch,
                                                       model,
                                                       device,
                                                       optimizer,
                                                       train_loader,
                                                       tb_logger,
                                                       gradient_accumulation_steps=1)
    scheduler.step()

    val_loss, val_logmae = test_model(model, val_loader)

    epoch_log = make_log(epoch, lr, tr_loss, tr_logmae, val_logmae)
    logs.append(epoch_log)
    pd.DataFrame(logs).to_csv(LOG_PATH)
    # Implicit string concatenation: a backslash continuation inside the literal
    # would embed the next line's indentation into every log line.
    print('Epoch: {epoch:03d}, LR: {lr:7f}, Loss: {loss:.7f}, '
          'Train LogMAE: {tr_logmae:.7f}, Val LogMAE: {val_logmae:.7f}'.format(
              **epoch_log))

    if epoch % SAVE_INTERVAL == 0:
        save_checkpoint(OUTPUT_DIR, model, optimizer, scheduler, epoch)

    # TensorBoard: global curves, then one train/val pair per coupling type.
    tb_logger.add_scalar('lr', lr, global_iteration)
    tb_logger.add_scalar('val_loss', val_loss, global_iteration)
    tb_logger.add_scalars('global_logmae', {
        'tr_logmae': epoch_log['tr_logmae'],
        'val_logmae': epoch_log['val_logmae']
    }, global_iteration)

    for type_ in constants.TYPES_LIST:
        tb_logger.add_scalars(
            type_, {
                'tr_' + type_: epoch_log['tr_' + type_],
                'val_' + type_: epoch_log['val_' + type_]
            }, global_iteration)

HBox(children=(IntProgress(value=0, max=2501), HTML(value='')))

In [85]:
# Last epoch reached by the training loop (its loop variable leaks into the notebook namespace).
epoch

120

In [87]:
# List the checkpoints, log.csv and TensorBoard directory written to OUTPUT_DIR.
!ls {OUTPUT_DIR}

log.csv		     optimizer_epoch_100.pth  scheduler_epoch_10.pth
model_epoch_100.pth  optimizer_epoch_10.pth   scheduler_epoch_110.pth
model_epoch_10.pth   optimizer_epoch_110.pth  scheduler_epoch_120.pth
model_epoch_110.pth  optimizer_epoch_120.pth  scheduler_epoch_20.pth
model_epoch_120.pth  optimizer_epoch_20.pth   scheduler_epoch_30.pth
model_epoch_20.pth   optimizer_epoch_30.pth   scheduler_epoch_40.pth
model_epoch_30.pth   optimizer_epoch_40.pth   scheduler_epoch_50.pth
model_epoch_40.pth   optimizer_epoch_50.pth   scheduler_epoch_60.pth
model_epoch_50.pth   optimizer_epoch_60.pth   scheduler_epoch_70.pth
model_epoch_60.pth   optimizer_epoch_70.pth   scheduler_epoch_80.pth
model_epoch_70.pth   optimizer_epoch_80.pth   scheduler_epoch_90.pth
model_epoch_80.pth   optimizer_epoch_90.pth   tb_log
model_epoch_90.pth   scheduler_epoch_100.pth


# Make submission

In [88]:
def merge_direction(df):
    """Average each directed atom-pair prediction with that of the reversed pair.

    ``df`` has one row per directed pair, keyed by (``molecule_name``,
    ``atom_index_0``, ``atom_index_1``), with a ``scalar_coupling_constant``
    column. When the reversed pair (indices swapped) is also present, the
    value becomes the mean of both directions. A row without a reverse
    partner keeps its own value instead of being silently dropped.

    Other non-key columns of the reversed rows are carried along with a
    ``_bis`` suffix; only ``scalar_coupling_constant_bis`` is removed.

    Returns a new DataFrame; ``df`` is not modified.
    """
    keys = ['molecule_name', 'atom_index_0', 'atom_index_1']
    # Swapping the two index columns turns row (i, j) into the lookup key for (j, i).
    inverse_direction = df.rename(
        columns={
            'atom_index_1': 'atom_index_0',
            'atom_index_0': 'atom_index_1'
        })
    merged = pd.merge(df,
                      inverse_direction,
                      on=keys,
                      how='left',
                      suffixes=('', '_bis'))
    # Left join: a pair present in only one direction averages with itself.
    reverse_value = merged['scalar_coupling_constant_bis'].fillna(
        merged['scalar_coupling_constant'])
    merged['scalar_coupling_constant'] = (merged['scalar_coupling_constant'] + reverse_value) / 2
    return merged.drop('scalar_coupling_constant_bis', axis=1)

In [91]:
# NOTE(review): scratch inspection of the per-couple batch-assignment vector.
# `batch` is not defined in the visible cells; presumably left over from an
# earlier debugging session. Remove it, or define `batch` first, so that
# Restart & Run All does not fail here.
batch.couples_edge_ind_batch

tensor([ 0,  0,  0,  ..., 63, 63, 63])

In [92]:
def predict(model, input_data, checkpoint_path):
    """Predict scalar coupling constants for every row of ``input_data``.

    Loads the weights at ``checkpoint_path`` into ``model`` (mutating it in
    place) and builds a ``MoleculeDataset`` over ``input_data``. The dataset
    uses the graph transforms listed below, with no distance noise. The model
    runs in eval mode without gradients, the predictions of the two directions
    of each atom pair are averaged with ``merge_direction``, and the result is
    joined back onto ``input_data``.

    Args:
        model: the network; its state is replaced by the checkpoint.
        input_data: DataFrame with at least ``id``, ``molecule_name``,
            ``atom_index_0`` and ``atom_index_1`` columns.
        checkpoint_path: path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        ``(submission, pair_predictions)``. ``submission`` is indexed by ``id``
        and has a single ``scalar_coupling_constant`` column.
        ``pair_predictions`` holds the direction-averaged prediction for each
        (molecule, atom pair).

    Raises:
        AssertionError: if a row of ``input_data`` matched more than one
            prediction or received none.
    """
    # map_location lets a checkpoint saved on GPU be restored on `device`.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    pred_dataset = MoleculeDataset(
        metadata=input_data,
        base_dir=constants.STRUCT_DATA_PATH,
        transform=T.Compose([
            AddBondPath(),
            AddVirtualEdges(),
            AddEdgeDistanceAndDirection(dist_noise=0.),
            AddGlobalAttr(),
            SortTarget(),
            AddBondLinks(),
            AddCounts(),
        ]))
    pred_loader = DataLoader(pred_dataset,
                             batch_size=64,
                             shuffle=False,
                             num_workers=8,
                             follow_batch=[
                                 'bonds_edge_ind', 'bonds_links_edge_ind',
                                 'paths_edge_ind', 'couples_edge_ind'
                             ])
    model.eval()
    batch_frames = []
    with torch.no_grad():
        for data in tqdm(pred_loader):
            data = data.to(device)
            batch_pred = model(data).cpu().numpy()
            # Molecule index of each couple in the batch.
            ind = data.mol_ind[data.couples_edge_ind_batch].cpu().numpy()
            couple_ind = data.couples_ind.cpu().numpy()
            df = pd.DataFrame({
                'molecule_name': pred_dataset.molecules[ind].ravel(),
                'molecule_ind': ind.ravel(),
                'atom_index_0': couple_ind[:, 0].ravel(),
                'atom_index_1': couple_ind[:, 1].ravel(),
            })
            # Predictions are assigned positionally below, so the rows must be
            # in the order the model emits them.
            # NOTE(review): this assumes model output is ordered by
            # (molecule, atom_index_0, atom_index_1), presumably via SortTarget.
            # The assertion only verifies the molecule order.
            df = df.sort_values(['molecule_ind', 'atom_index_0', 'atom_index_1'],
                                ascending=True)
            np.testing.assert_array_equal(df.molecule_ind, ind.ravel())
            df['scalar_coupling_constant'] = batch_pred
            batch_frames.append(df.drop('molecule_ind', axis=1))

    pred = merge_direction(pd.concat(batch_frames))
    merged = pd.merge(input_data,
                      pred,
                      on=['molecule_name', 'atom_index_0', 'atom_index_1'],
                      how='left', suffixes=('_truth', ''))
    # Check only the prediction column: NaN in unrelated input columns
    # must not be mistaken for missing predictions.
    assert len(merged) == len(input_data), 'some rows matched several predictions'
    assert merged['scalar_coupling_constant'].notna().all(), \
        'some rows of input_data received no prediction'
    return merged.loc[:, ['id', 'scalar_coupling_constant']].set_index('id'), pred

In [93]:
# Validation predictions with the final (epoch 120) weights.
val_checkpoint = f'{OUTPUT_DIR}/model_epoch_120.pth'
pred_val, p = predict(model, val_data, val_checkpoint)

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=79), HTML(value='')))




In [59]:
# Peek at the first validation predictions (indexed by id).
pred_val.head()

Unnamed: 0_level_0,scalar_coupling_constant
id,Unnamed: 1_level_1
1582,-1.053646
1583,0.496505
1584,0.496774
1585,14.346548
1586,92.781601


In [94]:
def score(pred, ref_data, eps=1e-9):
    """Compute the competition metric: the mean over coupling types of log(MAE).

    Args:
        pred: DataFrame indexed by ``id`` with a ``scalar_coupling_constant``
            column (as returned by ``predict``).
        ref_data: DataFrame with ``id``, ``type`` and ground-truth
            ``scalar_coupling_constant`` columns.
        eps: floor applied to each per-type MAE before taking the log, so a
            perfectly predicted type contributes ``log(eps)`` rather than ``-inf``.

    Returns:
        ``(overall, per_type)``. ``overall`` is the mean of the per-type log-MAEs,
        and ``per_type`` maps each coupling type to its log-MAE.

    Raises:
        ValueError: if some rows of ``ref_data`` have no (or a NaN) prediction.
            Otherwise the per-type mean would silently skip them and
            overstate the score.
    """
    merged = pd.merge(ref_data, pred, how='left', left_on='id', right_index=True,
                      suffixes=('', '_pred'))
    missing = merged['scalar_coupling_constant_pred'].isna()
    if missing.any():
        raise ValueError(f'{int(missing.sum())} rows of ref_data have no prediction')
    abs_error = (merged['scalar_coupling_constant'] - merged['scalar_coupling_constant_pred']).abs()
    log_mae = np.log(abs_error.groupby(merged['type']).mean().clip(lower=eps))
    return log_mae.mean(), log_mae.to_dict()

In [95]:
# Validation metric: mean over coupling types of log(MAE), plus the per-type values.
score(pred_val, val_data)

(-2.669540019699434,
 {'1JHC': -1.7277696485261647,
  '1JHN': -1.688054814453369,
  '2JHC': -2.654279952396486,
  '2JHH': -3.205307811357667,
  '2JHN': -2.9765596043212295,
  '3JHC': -2.6248218157210634,
  '3JHH': -3.218794817699955,
  '3JHN': -3.2607316931195385})

In [97]:
# Load the competition test pairs.
test = pd.read_csv(osp.join('..', 'data', 'test.csv'))

In [98]:
# Test-set predictions with the final (epoch 120) weights; the per-pair frame is discarded.
test_checkpoint = f'{OUTPUT_DIR}/model_epoch_120.pth'
sub, _ = predict(model, test, test_checkpoint)

HBox(children=(IntProgress(value=0, max=45772), HTML(value='')))




HBox(children=(IntProgress(value=0, max=716), HTML(value='')))




In [99]:
# Sanity-check the test predictions (indexed by id) before writing them out.
sub.head()

Unnamed: 0_level_0,scalar_coupling_constant
id,Unnamed: 1_level_1
4658147,15.185948
4658148,118.996857
4658149,11.215303
4658150,119.213646
4658151,15.185041


In [100]:
# Output directory for this submission's CSV files (created if missing).
!mkdir -p subs/lam_02_v1/

In [101]:
# Persist the test submission and the validation predictions side by side.
sub_dir = './subs/lam_02_v1'
for frame, file_name in ((sub, 'sub.csv'), (pred_val, 'pred_val.csv')):
    frame.to_csv(osp.join(sub_dir, file_name), index=True)