In [None]:
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from common import *
from layers.layers_09ZI_distributionnal_loss import MEGNetList
from dataset.dataset_9ZB_117_edge_link import EdgeBasedDataset, DataLoader
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch
from tensorboardX import SummaryWriter
from scheduler_superconvergence_09J import *
from torch_geometric.data import DataListLoader
from torch_scatter import scatter_add
from importancer import get_tags, select_tags
from RAdam import RAdam
import torch.nn.utils

def init_dataset():
    """Create the module-level data loaders for the current ``action``.

    In ``'train'`` mode this sets ``train_indices`` / ``valid_indices``
    (restored from ``to_load`` when a checkpoint is present, otherwise a fresh
    seeded split holding out 5000 samples) and builds ``train_loader`` and
    ``valid_loader``. When ``parallel_gpu`` is false it also builds
    ``train_small_loader`` over a seeded 5000-sample subset of the training
    split.

    In any other mode it builds ``submit_loader`` over the whole dataset,
    unshuffled and without dropping the last batch.

    Raises:
        ValueError: in submit mode when ``parallel_gpu`` is enabled.
    """
    global train_loader, train_small_loader, valid_loader
    global train_indices, valid_indices
    global submit_loader

    if action != 'train':
        # Submission mode has no multi-GPU code path.
        if parallel_gpu:
            raise ValueError

        submit_loader = DataLoader(
            dataset,
            batch_size = batch_size * valid_batch_size_factor,
            drop_last = False,
            shuffle = False,
            follow_batch = ['node_embeddings', 'edge_embeddings', 'edge_triangle_embeddings', 'cycle_embeddings'],
            num_workers = num_workers,
        )
        # To benchmark raw loading speed, iterate submit_loader once here.
        return

    # Reuse the checkpointed split so a resumed run keeps the same hold-out set.
    if to_load:
        train_indices = to_load['train_indices']
        valid_indices = to_load['valid_indices']
    else:
        train_indices, valid_indices = train_test_split(
            list(range(len(dataset))), test_size = 5000, random_state = 1234)

    # Positions (within the training subset) of a fixed 5000-sample slice.
    _, train_small_indices = train_test_split(
        list(range(len(train_indices))), test_size = 5000, random_state = 1234)

    train = torch.utils.data.Subset(dataset, train_indices)
    train_small = torch.utils.data.Subset(train, train_small_indices)
    valid = torch.utils.data.Subset(dataset, valid_indices)

    eval_batch_size = batch_size * valid_batch_size_factor

    if parallel_gpu:
        # DataListLoader yields plain lists of graphs (no collated Batch).
        train_loader = DataListLoader(train, batch_size = batch_size, shuffle = True, num_workers = num_workers)
        valid_loader = DataListLoader(valid, batch_size = eval_batch_size, shuffle = True, num_workers = num_workers)
    else:
        def make_loader(subset, size):
            # follow_batch makes the loader emit `edge_attr_numeric_batch`.
            return DataLoader(
                subset,
                batch_size = size,
                drop_last = True,
                shuffle = True,
                follow_batch = ['edge_attr_numeric'],
                num_workers = num_workers,
            )

        train_loader = make_loader(train, batch_size)
        train_small_loader = make_loader(train_small, eval_batch_size)
        valid_loader = make_loader(valid, eval_batch_size)

    # To benchmark raw loading speed, iterate train_loader once here.

def init_model():
    """Build the global ``model`` and ``optimizer``, restoring a checkpoint if any.

    The model is a ``MEGNetList`` configured from the module-level
    hyper-parameters. When ``to_load`` is set, the model weights and the
    optimizer state are restored from it.

    The model is moved to its device *before* the optimizer is created and
    before the optimizer state is loaded, so that restored optimizer state is
    cast to the same device as the parameters it belongs to.
    """
    global model
    global optimizer

    model = MEGNetList(
        layer_count,
        atom_embedding_count, bond_ebedding_count, global_embedding_count,
        atom_input_size, bond_input_size, global_input_size,
        hidden,
        target_means, target_stds)

    if to_load:
        model.load_state_dict(to_load['model'])

    # In the multi-GPU setup the model lives on the first GPU.
    model = model.to('cuda:0' if parallel_gpu else device)

    # lr starts at 0.0; batch_train() overwrites it from lr_scheduler.
    # Alternatives tried previously: torch.optim.Adam, adabound.AdaBound
    # (needs `import adabound`), torch.optim.SGD with Nesterov momentum.
    optimizer = RAdam(model.parameters(), lr = 0.0)

    if to_load:
        optimizer.load_state_dict(to_load['optimizer'])

# Element-wise absolute error (no reduction); callers multiply by the target
# mask and normalise themselves.
loss_fn = nn.L1Loss(reduction = 'none')

def init_experiment():
    """Open the TensorBoard writer for ``experiment`` and initialise ``step``.

    ``step`` resumes from the checkpoint in ``to_load`` when one is present,
    otherwise it starts at 0.
    """
    global writer
    global step

    writer = SummaryWriter(f'runs/{experiment}')
    step = to_load['step'] if to_load else 0


def batch_train():
    """Run one training step on the global ``batch`` and advance ``step``.

    The mini-batch is read from the module-level ``batch`` (set by the caller's
    loop). If the batch has at least one unmasked target, a forward/backward
    pass and an optimizer step are performed and metrics are written to
    ``writer``. Whether or not the batch was used, the function then:

    * applies ``lr_scheduler`` / ``momentum_scheduler`` for the current step,
    * calls ``save()`` whenever ``step`` is a non-zero multiple of
      ``len(train_loader)``,
    * increments ``step``,
    * calls ``valid()`` every ``valid_each`` steps.
    """
    global step
    global batch

    model.train()
    if finetune_type is not None:
        # Fine-tuning on a single target type: mask out every other type.
        batch.y_mask[batch.y_types != finetune_type] = 0.0

    if batch.y_mask.sum() > 0:

        # BATCH
        batch = batch.to(device)

        # OPTIM
        optimizer.zero_grad()

        out = model.forward(
            [batch.x_numeric],
            batch.x_embeddings,

            [batch.edge_attr_numeric],
            batch.edge_attr_embeddings,

            [batch.u_numeric],
            batch.u_embeddings,

            batch.edge_index,

            batch.batch,
            batch.edge_attr_numeric_batch,

            batch.y_types,

            batch.cycles_edge_index,
            batch.cycles_id,

            batch.edges_connectivity_ids,
            batch.edges_connectivity_features,
        )

        # Weight each output column by its bin centre and sum over bins.
        # NOTE(review): presumably `out` is a distribution over bins, making
        # this its expected value -- confirm against MEGNetList.
        out = out * centers
        out = out.sum(dim = 1).unsqueeze(1)

        # Masked mean absolute error over the targets present in this batch.
        loss_matrix = loss_fn(batch.y, out) * batch.y_mask
        loss = loss_matrix.sum() / batch.y_mask.sum()

        loss.backward()
        # clip_grad_norm_ is the in-place replacement for the deprecated
        # clip_grad_norm; it returns the total gradient norm.
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
        optimizer.step()

        # LOG
        writer.add_scalar('030-other/lr', optimizer.param_groups[0]['lr'], step)
        writer.add_scalar('000-train/mae', loss, step)

        writer.add_scalar('040-norm/norm', norm, step)
        writer.add_scalar('040-norm/max-norm', clip_gradient_norm, step)

        if log_detail:
            # Per-target-type mean loss over the unmasked rows of this batch.
            losses = pd.DataFrame(
                np.concatenate(
                    [
                        loss_matrix.detach().cpu().numpy(),
                        batch.y_mask.detach().cpu().numpy(),
                        batch.y_types.detach().cpu().numpy()
                    ], axis = 1),
                columns = ['loss', 'y_mask', 'y_types']
            )
            losses['y_mask'] = losses['y_mask'].astype(np.int32)
            losses['y_types'] = losses['y_types'].astype(np.int32)
            losses = losses.loc[losses['y_mask'] == 1]
            losses = losses.groupby('y_types')['loss'].mean()

            losses_detail = {}
            for i in losses.index:
                losses_detail[f"type-{i}"] = losses.loc[i]

            if losses_detail:
                writer.add_scalars('train/mae-detail', losses_detail, step)

    # LR SCHEDULING
    # Applied after the optimizer step, so the new value is used from the
    # next batch onwards.
    if lr_scheduler is not None:
        lr = lr_scheduler.get(step)
        if lr is not None:
            for pg in optimizer.param_groups:
                pg['lr'] = lr

    # MOMENTUM SCHEDULING
    # NOTE(review): this writes a 'momentum' key; whether the active optimizer
    # reads that key depends on the optimizer class -- confirm for RAdam.
    if momentum_scheduler is not None:
        momentum = momentum_scheduler.get(step)
        if momentum is not None:
            for pg in optimizer.param_groups:
                pg['momentum'] = momentum

    # SAVE
    # Checkpoint once per len(train_loader) steps.
    if step != 0 and step % len(train_loader) == 0:
        save()

    step += 1

    # VALID
    if step != 0 and valid_each is not None and step % valid_each == 0:
        valid()
        # valid() switches the model to eval mode; switch back.
        model.train()

def epoch():
    """Train on every batch of ``train_loader`` once.

    Each batch is published through the module-level ``batch`` variable and
    processed by ``batch_train()``.

    Returns:
        ``"escape"`` if interrupted with Ctrl-C / kernel interrupt, otherwise
        ``None``.
    """
    global batch

    model.train()

    for batch in tqdm.tqdm_notebook(train_loader):
        try:
            # The original called `batch(batch)`, i.e. tried to call the
            # batch object itself; the training step is batch_train().
            batch_train()

        except KeyboardInterrupt:
            print("Escaping")
            return "escape"

def train(until_step):
    """Train until the global ``step`` reaches ``until_step``.

    Loops over ``train_loader`` as many times as needed, calling
    ``batch_train()`` for each batch (which increments ``step``) and running
    ``valid()`` after every complete pass over the loader.

    Args:
        until_step: global step count at which training stops. Exactly
            ``until_step - step`` batches are processed; nothing is done if
            ``step`` is already at or beyond it.

    Returns:
        ``"escape"`` if interrupted with Ctrl-C / kernel interrupt, otherwise
        ``None``.
    """
    global batch

    model.train()

    if step >= until_step:
        return

    # One tick per processed batch; closed in `finally` on every exit path.
    progress = tqdm.tqdm_notebook(total = until_step - step)
    try:
        while True:
            for batch in train_loader:
                # `>=` (not `>`): stopping one batch later would run past the
                # requested number of steps (the original exhausted its
                # progress iterator there and raised StopIteration).
                if step >= until_step:
                    return

                try:
                    batch_train()
                    progress.update(1)

                except KeyboardInterrupt:
                    print("Escaping")
                    return "escape"

            valid()
    finally:
        progress.close()

def batch_valid():
    """Evaluate the global ``batch`` and accumulate per-type absolute error.

    Runs the model without gradients and adds, for every target type present
    among the unmasked rows, the summed absolute error to
    ``losses_detail['type-<id>']`` and the number of rows to
    ``losses_detail['count-<id>']``. Batches with no unmasked target are
    skipped.
    """
    global batch
    global losses_detail

    with torch.no_grad():
        # When fine-tuning on one target type, ignore all other types.
        if finetune_type is not None:
            batch.y_mask[batch.y_types != finetune_type] = 0.0

        if not (batch.y_mask.sum() > 0):
            return

        batch = batch.to(device)

        # PREDICT
        prediction = model.forward(
            [batch.x_numeric],
            batch.x_embeddings,

            [batch.edge_attr_numeric],
            batch.edge_attr_embeddings,

            [batch.u_numeric],
            batch.u_embeddings,

            batch.edge_index,

            batch.batch,
            batch.edge_attr_numeric_batch,

            batch.y_types,

            batch.cycles_edge_index,
            batch.cycles_id,

            batch.edges_connectivity_ids,
            batch.edges_connectivity_features,
        )

        # Weight every output column by its bin centre and sum over bins.
        prediction = prediction * centers
        prediction = prediction.sum(dim = 1).unsqueeze(1)

        masked_error = loss_fn(batch.y, prediction) * batch.y_mask

        # COMPUTE
        # One row per target: [absolute error, mask, type id].
        as_numpy = [
            tensor.detach().cpu().numpy()
            for tensor in (masked_error, batch.y_mask, batch.y_types)
        ]
        table = pd.DataFrame(
            np.concatenate(as_numpy, axis = 1),
            columns = ['loss', 'y_mask', 'y_types'],
        )
        table['y_mask'] = table['y_mask'].astype(np.int32)
        table['y_types'] = table['y_types'].astype(np.int32)
        table = table.loc[table['y_mask'] == 1]

        by_type = table.groupby('y_types')
        error_sum = by_type['loss'].sum()
        row_count = by_type['y_mask'].sum()

        for type_id in error_sum.index:
            losses_detail[f"type-{type_id}"] += error_sum.loc[type_id]
            losses_detail[f"count-{type_id}"] += row_count.loc[type_id]

def _run_validation(loader):
    """Evaluate the model on ``loader`` and log the aggregated metrics.

    Resets the global ``losses_detail`` accumulators (8 target types), feeds
    every batch of ``loader`` through ``batch_valid()``, then writes to
    ``writer``:

    * ``'010-valid/loss'``: mean over types of log(per-type MAE),
    * ``'020-valid/mae'``: MAE over all counted targets,
    * ``'valid/mae-detail'``: per-type MAE, only when ``log_detail`` is set.

    Returns:
        ``(loss, losses)`` where ``losses`` maps ``'type-<id>'`` to its MAE;
        ``(nan, {})`` without logging if no target was counted;
        ``"escape"`` if interrupted with Ctrl-C / kernel interrupt.
    """
    global batch
    global losses_detail

    model.eval()

    losses_detail = {}
    for i in range(8):
        losses_detail[f'type-{i}'] = 0
        losses_detail[f'count-{i}'] = 0

    for batch in tqdm.tqdm_notebook(loader):
        try:
            batch_valid()

        except KeyboardInterrupt:
            print("Escaping")
            return "escape"

    losses = {}
    total = 0
    total_count = 0
    for i in range(8):
        count = losses_detail[f'count-{i}']
        if count != 0:
            total += losses_detail[f'type-{i}']
            total_count += count
            losses[f'type-{i}'] = losses_detail[f'type-{i}'] / count

    if total_count == 0:
        # Empty loader or every target masked: there is nothing to average,
        # and `total / total_count` below would raise ZeroDivisionError.
        print("Validation skipped: no unmasked targets were evaluated")
        return np.nan, losses

    if log_detail:
        writer.add_scalars('valid/mae-detail', losses, step)

    loss = np.log(np.array(list(losses.values()))).mean()

    writer.add_scalar('010-valid/loss', loss, step)
    writer.add_scalar('020-valid/mae', total / total_count, step)

    return loss, losses


def valid():
    """Evaluate on ``valid_loader``; see ``_run_validation`` for the outputs."""
    return _run_validation(valid_loader)


def valid_on_train():
    """Evaluate on ``train_small_loader``; see ``_run_validation``.

    NOTE(review): metrics are written under the same ``valid`` TensorBoard
    tags as ``valid()``, so calling both at the same step mixes the two
    curves -- confirm whether separate tags were intended.
    """
    return _run_validation(train_small_loader)


def save():
    """Write a checkpoint to ``model_data/model.<experiment>.<step>.bin``.

    The checkpoint holds the model and optimizer state dicts, the train/valid
    index split, the current ``step``, and the experiment / writer names, i.e.
    the keys read back through ``to_load`` when resuming.
    """
    to_save = {
        'model' : model.state_dict(),
        'optimizer' : optimizer.state_dict(),
        'train_indices' : train_indices,
        'valid_indices' : valid_indices,
        'step' : step,
        'writer' : f'runs/{experiment}',
        'experiment' : experiment,
    }
    # Create the target directory first so a missing folder cannot cost us
    # the checkpoint.
    os.makedirs('model_data', exist_ok = True)
    torch.save(to_save, f'model_data/model.{experiment}.{step}.bin')

In [None]:
# 'train' builds the training/validation loaders and runs training; any other
# value builds only the submission loader.
action = 'train'
# Dataset name passed to EdgeBasedDataset when action is not 'train'.
submit_dataset_name = 'test'

In [None]:
# Load the dataset matching the current action.
dataset = EdgeBasedDataset(name = 'train' if action == 'train' else submit_dataset_name)

# Per-type target statistics, computed only from rows that are not of type
# 'VOID' and whose index belongs to the 'train' part of the dataset.
_train_molecule_ids = dataset.dataset.loc[dataset.dataset['dataset'] == 'train', 'molecule_id']
_is_not_void = dataset.bond_descriptors['type'] != 'VOID'
_is_train = dataset.bond_descriptors.index.isin(_train_molecule_ids)

target_stats = (
    dataset.bond_descriptors
    .loc[_is_not_void & _is_train]
    .groupby('type_id')['scalar_coupling_constant']
    .agg(['std', 'median'])
)

# The median is used as the per-type "mean" passed to the model.
target_means = target_stats['median'].values
target_stds = target_stats['std'].values
target_stats

In [None]:
from sklearn.preprocessing import KBinsDiscretizer

In [None]:
# Bin centres for the distributional output: 4 bins per unit over [-40, 220).
bin_count = 4 * 260
# Left edges of the bins, evenly spaced from -40 to 219.75.
centers = np.linspace(-40, 220 - 1 / 4, bin_count)
# Half a bin width: shifts each left edge to the middle of its bin.
delta = (centers[1] - centers[0]) / 2
centers = centers + delta
delta

In [None]:
# Runtime settings
num_workers = 7           # worker count passed to every data loader
device = 'cuda'
parallel_gpu = False      # True switches init_dataset() to DataListLoader
valid_each = 1400         # batch_train() runs valid() every `valid_each` steps
log_detail = False        # also log per-type MAE to TensorBoard
finetune_type = None      # if set, targets of every other type are masked out

# Config

hidden = 300
layer_count = 6
batch_size = 20
valid_batch_size_factor = 5   # evaluation batches are this many times larger
clip_gradient_norm = 100      # max gradient norm used in batch_train()

n = 'distributionnal-loss'
experiment = f'9ZI-001-{n}'   # overwritten below once the run counter is known
to_load = None

# Bin centres as a (1, bin_count) float tensor so they multiply the model
# output in batch_train() / batch_valid().
centers = torch.tensor(centers.reshape(1, -1), dtype = torch.float32).to(device)


# Glob pattern for this family of experiments; '*' stands for the run counter.
name = f'9ZI-*-distributionnal-loss'
name_to_load = name


import glob
# Checkpoints are written by save() as model_data/model.<experiment>.<step>.bin.
# NOTE(review): the split('.') parsing below assumes the path contains no other
# dots than those three separators.
candidates = glob.glob(f'model_data/model.{name}.*.bin')
# Next run counter = highest existing counter + 1 (0 when there is none),
# zero-padded to three characters.
next_iter = [int(e.split('.')[1].split('-')[1]) + 1 for e in candidates]
next_iter = '{:0>3}'.format(max(next_iter) if next_iter else 0)
# Resume from the checkpoint with the highest step among all matching files.
candidates = glob.glob(f'model_data/model.{name_to_load}.*.bin')
last_checkpoint = [(e, int(e.split('.')[2])) for e in candidates]
last_checkpoint = sorted(last_checkpoint, key = lambda x : x[1], reverse = True)

# Loaded onto the CPU; init_model() moves the model to its device.
to_load = torch.load(last_checkpoint[0][0], map_location = 'cpu') if last_checkpoint else None
experiment = name.replace('*', next_iter)

print(experiment)
print(last_checkpoint)

## TEMPLATE
# Uncomment to resume from one specific checkpoint instead of the latest one.
#to_load = torch.load(f'model_data/model.9ZE-010-megnet-like-edge-triangle.40000.bin')


# Inputs

# One sample is enough to read the numeric feature widths.
sample = dataset[0]
print(sample)

global_embedding_count = dataset.global_embedding_count
atom_embedding_count = dataset.atom_embedding_count
bond_ebedding_count = dataset.bond_ebedding_count

global_numeric_count = sample.u_numeric.size(1)
bond_numeric_count = sample.edge_attr_numeric.size(1)
atom_numeric_count = sample.x_numeric.size(1)

# One (numeric feature width, hidden) pair per numeric input passed to the model.
atom_input_size = [(atom_numeric_count, hidden)]
bond_input_size = [(bond_numeric_count, hidden)]
global_input_size = [(global_numeric_count, hidden)]

# Load

init_model()
init_dataset()
if action == 'train':
    init_experiment()

    print(f'train_indices count : {len(train_indices)}')
    print(f'valid_indices count : {len(valid_indices)}')

    # Optimizer

    # Phase 1: constant learning rate and momentum.
    OPTION = 'static'

    if OPTION == 'static':
        base_lr = 7.5e-5 
        lr_scheduler = LinearScheduler(0, 1000e3, base_lr, base_lr)
        momentum_scheduler = LinearScheduler(0, 1000e3, 0.9, 0.9)

    # Train for 150 epochs.
    # NOTE(review): 4000 is presumably the number of steps per epoch -- confirm.
    train(150 * 4000)

    # Phase 2: halve the learning rate every 3 epochs. Ten stages are built,
    # but the train() call below only runs 15 more epochs, i.e. five of them.
    OPTION = 'droplr'

    if OPTION == 'droplr':
        base_step = 150 * 4000
        base_lr = 7.5e-5
        drop_factor = 2
        drop_after = 4000 * 3
    
        curent_step = base_step
        curent_lr = base_lr / drop_factor
    
        # One constant-lr segment per stage, each `drop_after` steps long.
        lr_schedulers = []
        for drop_i in range(10):
            scheduler = LinearScheduler(curent_step, curent_step + drop_after, curent_lr, curent_lr)
            lr_schedulers.append(scheduler)
        
            curent_step += drop_after
            curent_lr /= drop_factor
    
        lr_scheduler = MixedScheduler(lr_schedulers)
        momentum_scheduler = LinearScheduler(0, 1000e3, 0.9, 0.9)
    
    train(150 * 4000 + 15 * 4000)

    # Save model
    save()