In [7]:
%load_ext autoreload
%autoreload 2
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import numpy as np
import wandb

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import sys
import os

# Make the project's `dfs_transformer` package importable. The location can be
# overridden with the GRAPH_TRANSFORMER_SRC environment variable; the default is the
# original author's checkout.
SRC_DIR = os.environ.get('GRAPH_TRANSFORMER_SRC',
                         '/home/chrisw/Documents/projects/2021/graph-transformer/src')
# Put SRC_DIR first and drop any older copies, so re-running this cell does not keep
# prepending the same entry to sys.path.
sys.path[:] = [SRC_DIR] + [p for p in sys.path if p != SRC_DIR]

In [9]:
from torch_geometric.nn.models.schnet import GaussianSmearing
from dfs_transformer import EarlyStopping
import dfs_transformer
from dfs_transformer.nn.models.myschnet import MySchNet

In [10]:
# [0] reports MAE (in eV) divided by the chemical accuracy of the target variable U0.
# The chemical accuracy of U0 is 0.043, see [1, Table 5].

# MAE / chemical accuracy, reproduced from the table in [0]:
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR: HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD = Histogram of Distances, Angles and Dihedral angles

# [2] reports the average of MAE / chemical accuracy over all targets.
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9

# Description of the first 12 QM9 target columns: "symbol, unit, meaning".
target_dict = {0: 'mu, D, Dipole moment',
               1: 'alpha, {a_0}^3, Isotropic polarizability',
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy',
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K',
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',
               # The unit used to be written 'cal\(mol K)': an invalid escape sequence
               # that left a stray backslash in the string and raises a warning.
               11: 'c_{v}, cal/(mol K), Heat capacity at 298.15K'}

# Chemical accuracy per target index. 0.043 is the default for the eV targets
# (U0 value from [1, Table 5]); the remaining targets are overridden explicitly.
chemical_accuracy = {idx: 0.043 for idx in range(12)}
chemical_accuracy.update({
    0: 0.1,     # mu
    1: 0.1,     # alpha
    5: 1.2,     # < R^2 >
    6: 0.0012,  # ZPVE
    11: 0.050,  # c_v
})

In [12]:
# Start a Weights & Biases run; the hyper-parameters below are recorded in its config.
wandb.init(project='QM9-GAT', entity='chrisxx')
config = wandb.config
# Model size. NOTE(review): hidden_dim, nlayers and nhead are not passed to MySchNet in
# the cells shown here, so the logged values may not describe the trained model — confirm.
config.hidden_dim = 128
config.nlayers = 3
config.nhead = 1
# Optimisation: Adam learning rate, epoch budget and ReduceLROnPlateau patience/factor;
# the training loop stops once the learning rate falls below minimal_lr.
config.lr = 0.0003
config.n_epochs = 5000
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8
config.target_idx = 7  # QM9 target column; 7 is U_0 according to target_dict
config.batch_size = 256  # batch size of the train and validation loaders
# Early stopping on the validation MAE/CA: patience in epochs and minimal improvement.
config.valid_patience = 200
config.valid_minimal_improvement=0.005
# Split sizes; whatever remains of the shuffled dataset becomes the test split.
config.n_train = 110000
config.n_valid = 10000
config.model_dir = '../models/qm9/MySchNet/noDFS/1/'  # checkpoints and the saved split go here
config.num_workers = 4  # DataLoader worker processes
config.dfs_codes = None  # optional path to a JSON file of DFS codes; None disables them
config.use_pos = False  # NOTE(review): not read by any cell shown here
config.use_dist = True  # append Gaussian-smeared bond lengths to edge_attr in `transform`

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.11.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [13]:
# Optional DFS codes, looked up by molecule name inside `transform`; stays None unless
# a JSON file is configured.
dfs_codes = None
dfs_path = config.dfs_codes
if dfs_path is not None:
    import json
    with open(dfs_path, 'r') as fh:
        dfs_codes = json.load(fh)

In [14]:
def transform(data, dfs_codes = dfs_codes, edge_transform = GaussianSmearing(0, 10, 50), use_dist=config.use_dist):
    """Featurise one QM9 graph in place and return it.

    ``data.x`` is rebuilt from the raw node features:
      * columns 0-4 and 6..-2 are kept unchanged,
      * column 5 (atomic number) becomes a 100-class one-hot,
      * the last column (number of hydrogens) becomes a 9-class one-hot,
      * if ``dfs_codes`` is given, a 29-class one-hot of
        ``dfs_codes[data.name]['dfs_indices']`` is appended.
    If ``use_dist`` is true, the bond lengths ``|pos[row] - pos[col]|`` are expanded
    with ``edge_transform`` and appended as extra columns of ``data.edge_attr``.

    Note: the default arguments are evaluated once, when this cell runs.
    """
    raw = data.x
    blocks = [
        raw[:, :5],
        raw[:, 6:-1],
        nn.functional.one_hot(raw[:, 5].long(), 100),   # atomic number
        nn.functional.one_hot(raw[:, -1].long(), 9),    # number of hydrogens
    ]
    if dfs_codes is not None:
        dfs_idx = torch.LongTensor(dfs_codes[data.name]['dfs_indices'])
        blocks.append(nn.functional.one_hot(dfs_idx, 29))
    # torch.cat promotes the integer one-hot blocks to the float dtype of `raw`.
    data.x = torch.cat(blocks, dim=1)

    if use_dist:
        src, dst = data.edge_index
        bond_len = (data.pos[src] - data.pos[dst]).norm(dim=-1)
        data.edge_attr = torch.cat((data.edge_attr, edge_transform(bond_len)), dim=1)
    return data

In [15]:
# Shorthand for the configured regression target (a column of data.y). It is validated
# here, so a bad value fails immediately instead of raising a KeyError after the first
# training epoch.
target_idx = int(config.target_idx)
assert target_idx in chemical_accuracy, 'no chemical accuracy defined for target %d' % target_idx

In [16]:
# QM9 from PyTorch Geometric; `transform` (not a pre_transform) is applied when samples
# are accessed — presumably on every access, per PyG's Dataset API.
dataset = QM9(root='../datasets/qm9_geometric_work/', transform=transform)

In [17]:
# Random train/valid/test split. Note that re-running this cell shuffles the
# already-shuffled view again, so any saved split indices must be re-saved.
dataset = dataset.shuffle()
n_train, n_valid = config.n_train, config.n_valid
# Fail loudly instead of silently producing an empty test split.
assert n_train + n_valid < len(dataset), 'train+valid splits leave no molecules for testing'
train_dataset = dataset[:n_train]
valid_dataset = dataset[n_train:n_train + n_valid]
test_dataset = dataset[n_train + n_valid:]
config.n_test = len(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, pin_memory=True, num_workers=config.num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, pin_memory=True, num_workers=config.num_workers)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=config.num_workers)

In [18]:
import os

# Make sure the output directory exists; an already existing directory is not an error.
os.makedirs(name=config.model_dir, mode=0o777, exist_ok=True)

In [19]:
# Persist the shuffled dataset permutation so the exact train/valid/test split can be
# rebuilt later. os.path.join also works when model_dir has no trailing slash.
torch.save(dataset.indices(), os.path.join(config.model_dir, 'dataset_indices.pt'))

In [20]:
# Number of GPUs to use; 0 forces the CPU.
ngpu = 1
# Run on the first CUDA device only when CUDA is available and a GPU was requested.
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')

In [27]:
def score(loader, model, device=device):
    """Return the mean absolute error on ``loader`` divided by the chemical accuracy.

    Every batch is scored with ``model(data.z, data.pos, data.batch)`` against column
    ``target_idx`` of ``data.y``. The per-sample absolute errors are averaged over the
    whole loader and divided by ``chemical_accuracy[target_idx]``.

    Evaluation runs under ``torch.no_grad()`` with the model in eval mode. The model's
    previous train/eval mode is restored afterwards, even if an exception occurs.

    Args:
        loader: iterable of PyG batches exposing ``z``, ``pos``, ``batch`` and ``y``.
        model: module mapping (z, pos, batch) to one prediction per graph.
        device: device that the model and each batch are moved to.

    Returns:
        float: MAE / chemical accuracy (lower is better).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model = model.to(device)
    was_training = model.training
    model.eval()
    abs_errors = []
    try:
        # No autograd graph is needed for evaluation; this saves memory and time.
        with torch.no_grad():
            # tqdm picks up len(loader) for its total when the loader has one.
            for data in tqdm.tqdm(loader):
                data = data.to(device)
                prediction = model(data.z, data.pos, data.batch)
                abs_errors.append((prediction.view(-1) - data.y[:, target_idx]).abs().cpu())
    finally:
        model.train(was_training)
    if not abs_errors:
        raise ValueError('score() received an empty loader')
    mae = torch.cat(abs_errors, dim=0).mean().item()
    return mae / chemical_accuracy[target_idx]

# Model

In [22]:
# Accumulator for the atomref-corrected training targets (one array per batch).
target_vec = list()

In [23]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
#
# Collect the training targets minus the sum of per-atom reference values. The
# statistics of this residual are later passed to the model as mean/std.
assert target_idx in qm9.atomrefs, 'qm9.atomrefs has no reference values for target %d' % target_idx
# One reference value per element, in the order of the one-hot columns data.x[:, :5]
# (presumably H, C, N, O, F as in PyG's QM9). Built once instead of once per batch.
atomref_vec = torch.tensor(qm9.atomrefs[target_idx], device=device)
# Reset here so re-running this cell does not append a second copy of the targets.
target_vec = []
with torch.no_grad():
    for data in train_loader:
        data = data.to(device)
        per_atom_ref = atomref_vec[torch.argmax(data.x[:, :5], dim=1)]
        ref_per_molecule = scatter(per_atom_ref, data.batch, dim=-1, reduce='sum')
        target_vec.append((data.y[:, target_idx] - ref_per_molecule).cpu().numpy())
target_vec = np.concatenate(target_vec, axis=0)
# NOTE(review): the statistics derived from target_vec are per-molecule. PyG's SchNet
# applies mean/std per atom before its sum readout; if MySchNet does the same, per-atom
# statistics would be needed — confirm.

In [24]:
# Mean and (population, ddof=0) standard deviation of the atomref-corrected training
# targets; passed to MySchNet as `mean` / `std`.
target_mean, target_std = target_vec.mean(), target_vec.std()

In [25]:
# Model, loss, optimiser, learning-rate schedule and early stopping.
# (A leftover `data = next(iter(train_loader))` was removed: its result was never used,
# and it only started the loader workers.)
# NOTE(review): config.hidden_dim / nlayers / nhead are not passed here, so MySchNet
# presumably runs with its own defaults for those — confirm this is intended.
model = MySchNet(readout='attention', atomref=dataset.atomref(target_idx), mean=target_mean, std=target_std)
loss = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=config.lr)
# Multiplies the lr by `factor` after `patience` epochs without improvement of the metric passed to step().
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=config.patience, factor=config.factor)
# Project early stopping on the validation score; presumably checkpoints the best model to `path`.
early_stopping = EarlyStopping(patience=config.valid_patience, delta=config.valid_minimal_improvement,
                               path=os.path.join(config.model_dir, 'checkpoint.pt'))



# Training

In [26]:
# Move the model's parameters and buffers to the selected device before training.
model = model.to(device=device)

In [None]:
# Train with MSE on the raw target, report MAE / chemical accuracy, validate once per
# epoch, and stop on early stopping, on the minimal learning rate, or on Ctrl-C
# (which saves the current weights).
loss_hist = []
# Defined up front so the KeyboardInterrupt handler can always build a file name.
epoch = 0
try:
    for epoch in range(config.n_epochs):
        # score() may leave the model in a different mode; always train in train mode.
        model.train()
        pbar = tqdm.tqdm(train_loader)  # tqdm takes its total from len(train_loader)
        epoch_loss = 0  # running mean of the per-batch training MAE in this epoch
        for i, data in enumerate(pbar):
            optimizer.zero_grad()
            data = data.to(device)
            target = data.y[:, target_idx]
            prediction = model(data.z, data.pos, data.batch).view(-1)
            mse = loss(prediction, target)
            # The MAE is only reported, so it does not need to extend the autograd graph.
            mae = (prediction.detach() - target).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)

            pbar.set_description('Epoch %d: MAE/CA %2.6f'%(epoch+1, epoch_loss/chemical_accuracy[target_idx]))
            mse.backward()
            optimizer.step()
            wandb.log({'MSE': mse.item()})

        valid_loss = score(valid_loader, model)
        curr_lr = optimizer.param_groups[0]['lr']
        wandb.log({'MAE':epoch_loss,
                   'MAE/CA':epoch_loss/chemical_accuracy[target_idx],
                   'learning rate':curr_lr,
                   'MAE/CA valid':valid_loss})

        # The plateau scheduler follows the training MAE; early stopping follows the
        # validation MAE/CA.
        lr_scheduler.step(epoch_loss)
        early_stopping(valid_loss, model)
        loss_hist.append(epoch_loss)

        if early_stopping.early_stop:
            break

        if curr_lr < config.minimal_lr:
            break


except KeyboardInterrupt:
    print('keyboard interrupt caught')
    torch.save(model.state_dict(), os.path.join(config.model_dir, 'gat_epoch%d.pt'%(epoch+1)))

Epoch 1: MAE/CA 29.934328: : 430it [00:58,  7.37it/s]
40it [00:02, 16.43it/s]
Epoch 2: MAE/CA 12.170259: : 430it [00:46,  9.19it/s]
40it [00:02, 18.86it/s]

EarlyStopping counter: 1 out of 200



Epoch 3: MAE/CA 9.703605: : 430it [00:45,  9.43it/s]
40it [00:02, 18.08it/s]

EarlyStopping counter: 2 out of 200



Epoch 4: MAE/CA 8.937248: : 430it [00:45,  9.39it/s]
40it [00:02, 18.56it/s]
Epoch 5: MAE/CA 8.775295: : 430it [00:46,  9.31it/s]
40it [00:02, 18.56it/s]

EarlyStopping counter: 1 out of 200



Epoch 6: MAE/CA 7.936426: : 430it [00:46,  9.25it/s]
40it [00:02, 17.85it/s]

EarlyStopping counter: 2 out of 200



Epoch 7: MAE/CA 7.614728: : 430it [00:46,  9.29it/s]
40it [00:02, 18.35it/s]
Epoch 8: MAE/CA 7.297351: : 430it [00:46,  9.20it/s]
40it [00:02, 17.64it/s]

EarlyStopping counter: 1 out of 200



Epoch 9: MAE/CA 7.596041: : 430it [00:46,  9.31it/s]
40it [00:02, 17.64it/s]

EarlyStopping counter: 2 out of 200



Epoch 10: MAE/CA 6.701267: : 430it [00:45,  9.36it/s]
40it [00:02, 17.32it/s]

EarlyStopping counter: 3 out of 200



Epoch 11: MAE/CA 6.489968: : 430it [00:45,  9.37it/s]
40it [00:02, 18.27it/s]

EarlyStopping counter: 4 out of 200



Epoch 12: MAE/CA 6.850215: : 430it [00:45,  9.40it/s]
40it [00:02, 17.76it/s]

EarlyStopping counter: 5 out of 200



Epoch 13: MAE/CA 6.011651: : 430it [00:45,  9.41it/s]
40it [00:02, 18.65it/s]
Epoch 14: MAE/CA 6.357209: : 430it [00:45,  9.39it/s]
40it [00:02, 16.89it/s]

EarlyStopping counter: 1 out of 200



Epoch 15: MAE/CA 5.816940: : 430it [00:45,  9.42it/s]
40it [00:02, 15.60it/s]

EarlyStopping counter: 2 out of 200



Epoch 16: MAE/CA 5.323445: : 430it [00:45,  9.47it/s]
40it [00:02, 16.90it/s]

EarlyStopping counter: 3 out of 200



Epoch 17: MAE/CA 6.088711: : 430it [00:45,  9.40it/s]
40it [00:02, 15.59it/s]

EarlyStopping counter: 4 out of 200



Epoch 18: MAE/CA 5.310138: : 430it [00:45,  9.39it/s]
40it [00:02, 17.22it/s]
Epoch 19: MAE/CA 5.422420: : 430it [00:45,  9.42it/s]
40it [00:02, 18.19it/s]

EarlyStopping counter: 1 out of 200



Epoch 20: MAE/CA 5.295166: : 430it [00:46,  9.35it/s]
40it [00:02, 18.11it/s]
Epoch 21: MAE/CA 5.127909: : 430it [00:45,  9.40it/s]
40it [00:02, 18.94it/s]


EarlyStopping counter: 1 out of 200


Epoch 22: MAE/CA 5.105217: : 430it [00:46,  9.31it/s]
40it [00:02, 17.43it/s]

EarlyStopping counter: 2 out of 200



Epoch 23: MAE/CA 5.035912: : 430it [00:45,  9.37it/s]
40it [00:02, 18.40it/s]

EarlyStopping counter: 3 out of 200



Epoch 24: MAE/CA 5.137117: : 430it [00:45,  9.55it/s]
40it [00:02, 17.49it/s]

EarlyStopping counter: 4 out of 200



Epoch 25: MAE/CA 4.786964: : 430it [00:45,  9.55it/s]
40it [00:02, 17.40it/s]

EarlyStopping counter: 5 out of 200



Epoch 26: MAE/CA 4.763390: : 430it [00:45,  9.52it/s]
40it [00:02, 17.20it/s]

EarlyStopping counter: 6 out of 200



Epoch 27: MAE/CA 4.626888: : 430it [00:44,  9.56it/s]
40it [00:02, 16.85it/s]

EarlyStopping counter: 7 out of 200



Epoch 28: MAE/CA 4.543068: : 430it [00:45,  9.53it/s]
40it [00:02, 17.46it/s]

EarlyStopping counter: 8 out of 200



Epoch 29: MAE/CA 4.282541: : 430it [00:44,  9.57it/s]
40it [00:02, 16.75it/s]
Epoch 30: MAE/CA 4.565819: : 430it [00:44,  9.57it/s]
40it [00:02, 15.21it/s]

EarlyStopping counter: 1 out of 200



Epoch 31: MAE/CA 4.114196: : 430it [00:45,  9.52it/s]
40it [00:02, 16.53it/s]

EarlyStopping counter: 2 out of 200



Epoch 32: MAE/CA 4.125172: : 430it [00:44,  9.59it/s]
40it [00:02, 15.49it/s]

EarlyStopping counter: 3 out of 200



Epoch 33: MAE/CA 4.066486: : 430it [00:44,  9.64it/s]
40it [00:02, 18.27it/s]
Epoch 34: MAE/CA 11.312056: : 430it [00:44,  9.58it/s]
40it [00:02, 17.95it/s]

EarlyStopping counter: 1 out of 200



Epoch 35: MAE/CA 9.876550: : 430it [00:45,  9.45it/s] 
40it [00:02, 19.02it/s]

EarlyStopping counter: 2 out of 200



Epoch 36: MAE/CA 4.477003: : 430it [00:45,  9.51it/s]
40it [00:02, 17.22it/s]

EarlyStopping counter: 3 out of 200



Epoch 37: MAE/CA 4.173122: : 430it [00:45,  9.47it/s]
40it [00:02, 17.40it/s]

EarlyStopping counter: 4 out of 200



Epoch 38: MAE/CA 3.901285: : 430it [00:45,  9.45it/s]
40it [00:01, 20.67it/s]

EarlyStopping counter: 5 out of 200



Epoch 39: MAE/CA 4.078614: : 430it [00:45,  9.53it/s]
40it [00:02, 18.12it/s]

EarlyStopping counter: 6 out of 200



Epoch 40: MAE/CA 4.155838: : 430it [00:45,  9.49it/s]
40it [00:01, 20.10it/s]

EarlyStopping counter: 7 out of 200



Epoch 41: MAE/CA 4.390588: : 430it [00:45,  9.47it/s]
40it [00:01, 21.52it/s]

EarlyStopping counter: 8 out of 200



Epoch 42: MAE/CA 4.288082: : 430it [00:44,  9.57it/s]
40it [00:01, 20.02it/s]

EarlyStopping counter: 9 out of 200



Epoch 43: MAE/CA 4.119885: : 430it [00:44,  9.60it/s]
40it [00:02, 19.71it/s]
Epoch 44: MAE/CA 4.624084: : 430it [00:44,  9.56it/s]
40it [00:02, 19.16it/s]

Epoch    44: reducing learning rate of group 0 to 2.8500e-04.
EarlyStopping counter: 1 out of 200



Epoch 45: MAE/CA 3.518423: : 430it [00:44,  9.57it/s]
40it [00:02, 17.71it/s]

EarlyStopping counter: 2 out of 200



Epoch 46: MAE/CA 4.327914: : 430it [00:45,  9.53it/s]
40it [00:02, 18.60it/s]

EarlyStopping counter: 3 out of 200



Epoch 47: MAE/CA 4.020851: : 430it [00:44,  9.56it/s]
40it [00:02, 18.38it/s]

EarlyStopping counter: 4 out of 200



Epoch 48: MAE/CA 4.078987: : 430it [00:44,  9.62it/s]
40it [00:02, 18.61it/s]

EarlyStopping counter: 5 out of 200



Epoch 49: MAE/CA 4.009842: : 430it [00:44,  9.70it/s]
40it [00:01, 20.29it/s]

EarlyStopping counter: 6 out of 200



Epoch 50: MAE/CA 3.855461: : 430it [00:44,  9.64it/s]
40it [00:02, 18.61it/s]

EarlyStopping counter: 7 out of 200



Epoch 51: MAE/CA 3.825460: : 430it [00:44,  9.64it/s]
40it [00:02, 18.20it/s]

Epoch    51: reducing learning rate of group 0 to 2.7075e-04.
EarlyStopping counter: 8 out of 200



Epoch 52: MAE/CA 3.689156: : 430it [00:44,  9.67it/s]
40it [00:02, 18.10it/s]
Epoch 53: MAE/CA 3.918847: : 430it [00:44,  9.60it/s]
40it [00:02, 18.86it/s]
Epoch 54: MAE/CA 3.766064: : 430it [00:44,  9.60it/s]
40it [00:02, 18.37it/s]
Epoch 55: MAE/CA 3.476376: : 430it [00:44,  9.59it/s]
40it [00:02, 19.38it/s]

EarlyStopping counter: 1 out of 200



Epoch 56: MAE/CA 3.978765: : 430it [00:44,  9.59it/s]
40it [00:02, 18.13it/s]

EarlyStopping counter: 2 out of 200



Epoch 57: MAE/CA 3.231224: : 430it [00:44,  9.60it/s]
40it [00:02, 18.08it/s]
Epoch 58: MAE/CA 3.541828: : 430it [00:44,  9.60it/s]
40it [00:02, 18.47it/s]

EarlyStopping counter: 1 out of 200



Epoch 59: MAE/CA 3.562086: : 430it [00:45,  9.55it/s]
40it [00:02, 18.46it/s]

EarlyStopping counter: 2 out of 200



Epoch 60: MAE/CA 3.482410: : 430it [00:45,  9.55it/s]
40it [00:02, 18.91it/s]

EarlyStopping counter: 3 out of 200



Epoch 61: MAE/CA 3.511618: : 430it [00:45,  9.52it/s]
40it [00:02, 18.90it/s]

EarlyStopping counter: 4 out of 200



Epoch 62: MAE/CA 3.510545: : 430it [00:44,  9.60it/s]
40it [00:02, 19.59it/s]

EarlyStopping counter: 5 out of 200



Epoch 63: MAE/CA 3.318553: : 430it [00:44,  9.57it/s]
40it [00:02, 18.70it/s]

Epoch    63: reducing learning rate of group 0 to 2.5721e-04.
EarlyStopping counter: 6 out of 200



Epoch 64: MAE/CA 3.194626: : 430it [00:44,  9.59it/s]
40it [00:02, 19.77it/s]

EarlyStopping counter: 7 out of 200



Epoch 65: MAE/CA 3.168909: : 430it [00:44,  9.60it/s]
40it [00:02, 17.28it/s]

EarlyStopping counter: 8 out of 200



Epoch 66: MAE/CA 3.159691: : 430it [00:44,  9.59it/s]
40it [00:02, 18.51it/s]

EarlyStopping counter: 9 out of 200



Epoch 67: MAE/CA 3.163400: : 430it [00:45,  9.46it/s]
40it [00:02, 18.55it/s]

EarlyStopping counter: 10 out of 200



Epoch 68: MAE/CA 3.178649: : 430it [00:44,  9.61it/s]
40it [00:02, 17.81it/s]
Epoch 69: MAE/CA 2.900480: : 430it [00:44,  9.57it/s]
40it [00:02, 18.65it/s]

EarlyStopping counter: 1 out of 200



Epoch 70: MAE/CA 3.241954: : 430it [00:45,  9.51it/s]
40it [00:02, 18.80it/s]

EarlyStopping counter: 2 out of 200



Epoch 71: MAE/CA 3.032497: : 430it [00:45,  9.51it/s]
40it [00:02, 18.95it/s]

EarlyStopping counter: 3 out of 200



Epoch 72: MAE/CA 3.169503: : 430it [00:44,  9.58it/s]
40it [00:02, 18.96it/s]

EarlyStopping counter: 4 out of 200



Epoch 73: MAE/CA 2.916542: : 430it [00:45,  9.52it/s]
40it [00:02, 18.75it/s]

EarlyStopping counter: 5 out of 200



Epoch 74: MAE/CA 2.824342: : 430it [00:44,  9.56it/s]
40it [00:02, 19.53it/s]

EarlyStopping counter: 6 out of 200



Epoch 75: MAE/CA 3.110747: : 430it [00:44,  9.59it/s]
40it [00:02, 18.77it/s]

EarlyStopping counter: 7 out of 200



Epoch 76: MAE/CA 3.096862: : 430it [00:45,  9.52it/s]
40it [00:02, 18.63it/s]

EarlyStopping counter: 8 out of 200



Epoch 77: MAE/CA 2.676674: : 430it [00:45,  9.53it/s]
40it [00:02, 19.29it/s]

EarlyStopping counter: 9 out of 200



Epoch 78: MAE/CA 2.846228: : 430it [00:44,  9.58it/s]
40it [00:02, 18.91it/s]

EarlyStopping counter: 10 out of 200



Epoch 79: MAE/CA 2.973900: : 430it [00:45,  9.53it/s]
40it [00:02, 19.07it/s]
Epoch 80: MAE/CA 2.702385: : 430it [00:45,  9.54it/s]
40it [00:02, 19.76it/s]


EarlyStopping counter: 1 out of 200


Epoch 81: MAE/CA 2.953188: : 430it [00:44,  9.56it/s]
40it [00:02, 19.78it/s]

EarlyStopping counter: 2 out of 200



Epoch 82: MAE/CA 2.715205: : 430it [00:45,  9.50it/s]
40it [00:01, 21.86it/s]

EarlyStopping counter: 3 out of 200



Epoch 83: MAE/CA 2.555898: : 430it [00:45,  9.47it/s]
40it [00:01, 21.61it/s]

EarlyStopping counter: 4 out of 200



Epoch 84: MAE/CA 2.704806: : 430it [00:45,  9.50it/s]
40it [00:01, 21.69it/s]

EarlyStopping counter: 5 out of 200



Epoch 85: MAE/CA 2.928643: : 430it [00:45,  9.53it/s]
40it [00:02, 18.52it/s]

EarlyStopping counter: 6 out of 200



Epoch 86: MAE/CA 2.680694: : 430it [00:45,  9.50it/s]
40it [00:01, 20.09it/s]

EarlyStopping counter: 7 out of 200



Epoch 87: MAE/CA 2.574741: : 430it [00:44,  9.58it/s]
40it [00:02, 18.75it/s]

EarlyStopping counter: 8 out of 200



Epoch 88: MAE/CA 2.554235: : 430it [00:44,  9.61it/s]
40it [00:02, 18.86it/s]

EarlyStopping counter: 9 out of 200



Epoch 89: MAE/CA 2.507664: : 430it [00:45,  9.51it/s]
40it [00:02, 18.30it/s]

EarlyStopping counter: 10 out of 200



Epoch 90: MAE/CA 6.710966: : 430it [00:44,  9.59it/s] 
40it [00:02, 18.76it/s]

EarlyStopping counter: 11 out of 200



Epoch 91: MAE/CA 2.943941: : 430it [00:44,  9.62it/s]
40it [00:02, 18.99it/s]

EarlyStopping counter: 12 out of 200



Epoch 92: MAE/CA 2.689398: : 430it [00:44,  9.62it/s]
40it [00:02, 19.23it/s]

EarlyStopping counter: 13 out of 200



Epoch 93: MAE/CA 2.888681: : 430it [00:44,  9.61it/s]
40it [00:02, 19.30it/s]

EarlyStopping counter: 14 out of 200



Epoch 94: MAE/CA 2.746367: : 430it [00:44,  9.64it/s]
40it [00:02, 19.11it/s]

EarlyStopping counter: 15 out of 200



Epoch 95: MAE/CA 2.772153: : 430it [00:44,  9.62it/s]
40it [00:02, 18.32it/s]

Epoch    95: reducing learning rate of group 0 to 2.4435e-04.
EarlyStopping counter: 16 out of 200



Epoch 96: MAE/CA 2.602784: : 430it [00:44,  9.63it/s]
40it [00:02, 19.34it/s]

EarlyStopping counter: 17 out of 200



Epoch 97: MAE/CA 2.772373: : 430it [00:44,  9.57it/s]
40it [00:01, 20.17it/s]

EarlyStopping counter: 18 out of 200



Epoch 98: MAE/CA 2.824919: : 430it [00:45,  9.55it/s]
40it [00:02, 18.96it/s]

EarlyStopping counter: 19 out of 200



Epoch 99: MAE/CA 2.690957: : 430it [00:44,  9.65it/s]
40it [00:02, 19.26it/s]

EarlyStopping counter: 20 out of 200



Epoch 100: MAE/CA 2.881671: : 430it [00:44,  9.58it/s]
40it [00:02, 18.31it/s]

EarlyStopping counter: 21 out of 200



Epoch 101: MAE/CA 2.669029: : 430it [00:44,  9.60it/s]
40it [00:02, 19.46it/s]

Epoch   101: reducing learning rate of group 0 to 2.3213e-04.
EarlyStopping counter: 22 out of 200



Epoch 102: MAE/CA 2.338719: : 430it [00:44,  9.56it/s]
40it [00:02, 18.87it/s]
Epoch 103: MAE/CA 2.393407: : 430it [00:45,  9.54it/s]
40it [00:02, 18.60it/s]

EarlyStopping counter: 1 out of 200



Epoch 104: MAE/CA 2.655083: : 430it [00:44,  9.59it/s]
40it [00:02, 19.13it/s]

EarlyStopping counter: 2 out of 200



Epoch 105: MAE/CA 2.715719: : 430it [00:45,  9.55it/s]
40it [00:02, 18.41it/s]
Epoch 106: MAE/CA 2.576238: : 430it [00:44,  9.61it/s]
40it [00:02, 18.80it/s]

EarlyStopping counter: 1 out of 200



Epoch 107: MAE/CA 2.350271: : 430it [00:44,  9.57it/s]
40it [00:02, 18.10it/s]

EarlyStopping counter: 2 out of 200



Epoch 108: MAE/CA 2.532448: : 430it [00:44,  9.60it/s]
40it [00:02, 17.90it/s]

Epoch   108: reducing learning rate of group 0 to 2.2053e-04.
EarlyStopping counter: 3 out of 200



Epoch 109: MAE/CA 2.169450: : 430it [00:45,  9.55it/s]
40it [00:02, 19.57it/s]

EarlyStopping counter: 4 out of 200



Epoch 110: MAE/CA 2.462011: : 430it [00:44,  9.58it/s]
40it [00:02, 18.38it/s]

EarlyStopping counter: 5 out of 200



Epoch 111: MAE/CA 2.393607: : 430it [00:44,  9.63it/s]
40it [00:02, 19.55it/s]

EarlyStopping counter: 6 out of 200



Epoch 112: MAE/CA 2.208981: : 430it [00:44,  9.65it/s]
40it [00:02, 19.87it/s]

EarlyStopping counter: 7 out of 200



Epoch 113: MAE/CA 2.482173: : 430it [00:44,  9.58it/s]
40it [00:02, 19.78it/s]

EarlyStopping counter: 8 out of 200



Epoch 114: MAE/CA 2.161651: : 430it [00:44,  9.60it/s]
40it [00:02, 19.68it/s]

EarlyStopping counter: 9 out of 200



Epoch 115: MAE/CA 2.440566: : 430it [00:44,  9.62it/s]
40it [00:02, 18.19it/s]

EarlyStopping counter: 10 out of 200



Epoch 116: MAE/CA 2.379232: : 430it [00:44,  9.60it/s]
40it [00:02, 17.08it/s]

EarlyStopping counter: 11 out of 200



Epoch 117: MAE/CA 2.190129: : 430it [00:45,  9.51it/s]
40it [00:02, 17.32it/s]

EarlyStopping counter: 12 out of 200



Epoch 118: MAE/CA 2.375621: : 430it [00:44,  9.59it/s]
40it [00:02, 19.60it/s]

EarlyStopping counter: 13 out of 200



Epoch 119: MAE/CA 2.338038: : 430it [00:44,  9.60it/s]
40it [00:02, 19.60it/s]

EarlyStopping counter: 14 out of 200



Epoch 120: MAE/CA 2.148869: : 430it [00:44,  9.60it/s]
40it [00:02, 17.48it/s]

EarlyStopping counter: 15 out of 200



Epoch 121: MAE/CA 2.336509: : 430it [00:45,  9.55it/s]
40it [00:02, 18.06it/s]

EarlyStopping counter: 16 out of 200



Epoch 122: MAE/CA 2.231743: : 430it [00:45,  9.53it/s]
40it [00:02, 17.86it/s]

EarlyStopping counter: 17 out of 200



Epoch 123: MAE/CA 2.303954: : 430it [00:44,  9.56it/s]
40it [00:02, 18.52it/s]
Epoch 124: MAE/CA 1.955291: : 430it [00:44,  9.61it/s]
40it [00:02, 18.54it/s]

EarlyStopping counter: 1 out of 200



Epoch 125: MAE/CA 2.338854: : 430it [00:45,  9.55it/s]
40it [00:02, 18.43it/s]

EarlyStopping counter: 2 out of 200



Epoch 126: MAE/CA 2.146120: : 430it [00:44,  9.59it/s]
40it [00:02, 15.86it/s]

EarlyStopping counter: 3 out of 200



Epoch 127: MAE/CA 2.250287: : 430it [00:45,  9.55it/s]
40it [00:02, 16.94it/s]

EarlyStopping counter: 4 out of 200



Epoch 128: MAE/CA 2.204921: : 430it [00:45,  9.54it/s]
40it [00:02, 17.88it/s]


EarlyStopping counter: 5 out of 200


Epoch 129: MAE/CA 2.053465: : 430it [00:44,  9.60it/s]
40it [00:02, 16.72it/s]

EarlyStopping counter: 6 out of 200



Epoch 130: MAE/CA 2.218784: : 430it [00:45,  9.53it/s]
40it [00:02, 17.49it/s]

Epoch   130: reducing learning rate of group 0 to 2.0950e-04.
EarlyStopping counter: 7 out of 200



Epoch 131: MAE/CA 1.891236: : 430it [00:44,  9.56it/s]
40it [00:02, 17.16it/s]

EarlyStopping counter: 8 out of 200



Epoch 132: MAE/CA 2.388252: : 430it [00:45,  9.52it/s]
40it [00:02, 16.82it/s]

EarlyStopping counter: 9 out of 200



Epoch 133: MAE/CA 1.935534: : 430it [00:45,  9.55it/s]
40it [00:02, 17.39it/s]

EarlyStopping counter: 10 out of 200



Epoch 134: MAE/CA 2.018490: : 430it [00:44,  9.58it/s]
40it [00:02, 18.30it/s]

EarlyStopping counter: 11 out of 200



Epoch 135: MAE/CA 2.093133: : 430it [00:44,  9.57it/s]
40it [00:02, 16.30it/s]

EarlyStopping counter: 12 out of 200



Epoch 136: MAE/CA 1.995772: : 430it [00:44,  9.58it/s]
40it [00:02, 17.15it/s]

EarlyStopping counter: 13 out of 200



Epoch 137: MAE/CA 2.065522: : 430it [00:45,  9.56it/s]
40it [00:02, 17.12it/s]

Epoch   137: reducing learning rate of group 0 to 1.9903e-04.



Epoch 138: MAE/CA 1.723058: : 430it [00:44,  9.60it/s]
40it [00:02, 16.29it/s]
Epoch 139: MAE/CA 1.935906: : 430it [00:44,  9.57it/s]
40it [00:02, 17.24it/s]

EarlyStopping counter: 1 out of 200



Epoch 140: MAE/CA 1.945352: : 430it [00:44,  9.62it/s]
40it [00:02, 17.75it/s]

EarlyStopping counter: 2 out of 200



Epoch 141: MAE/CA 1.879144: : 430it [00:44,  9.57it/s]
40it [00:02, 17.63it/s]

EarlyStopping counter: 3 out of 200



Epoch 142: MAE/CA 1.901473: : 430it [00:44,  9.59it/s]
40it [00:02, 16.18it/s]

EarlyStopping counter: 4 out of 200



Epoch 143: MAE/CA 1.993889: : 430it [00:44,  9.59it/s]
40it [00:02, 17.26it/s]

EarlyStopping counter: 5 out of 200



Epoch 144: MAE/CA 1.813730: : 430it [00:44,  9.61it/s]
40it [00:02, 16.86it/s]

Epoch   144: reducing learning rate of group 0 to 1.8907e-04.
EarlyStopping counter: 6 out of 200



Epoch 145: MAE/CA 1.799694: : 430it [00:44,  9.57it/s]
40it [00:02, 18.34it/s]

EarlyStopping counter: 7 out of 200



Epoch 146: MAE/CA 1.840877: : 430it [00:45,  9.53it/s]
40it [00:02, 17.93it/s]

EarlyStopping counter: 8 out of 200



Epoch 147: MAE/CA 2.007098: : 430it [00:45,  9.51it/s]
40it [00:02, 16.86it/s]

EarlyStopping counter: 9 out of 200



Epoch 148: MAE/CA 1.806773: : 430it [00:45,  9.46it/s]
40it [00:02, 17.22it/s]

EarlyStopping counter: 10 out of 200



Epoch 149: MAE/CA 1.840746: : 430it [00:45,  9.48it/s]
40it [00:02, 18.28it/s]

EarlyStopping counter: 11 out of 200



Epoch 150: MAE/CA 1.895064: : 430it [00:45,  9.52it/s]
40it [00:02, 17.52it/s]

Epoch   150: reducing learning rate of group 0 to 1.7962e-04.
EarlyStopping counter: 12 out of 200



Epoch 151: MAE/CA 1.539122: : 430it [00:45,  9.46it/s]
40it [00:02, 18.22it/s]

EarlyStopping counter: 13 out of 200



Epoch 152: MAE/CA 1.679812: : 430it [00:45,  9.51it/s]
40it [00:02, 17.80it/s]

EarlyStopping counter: 14 out of 200



Epoch 153: MAE/CA 1.789134: : 430it [00:45,  9.50it/s]
40it [00:02, 19.60it/s]

EarlyStopping counter: 15 out of 200



Epoch 154: MAE/CA 1.698622: : 430it [00:45,  9.47it/s]
40it [00:02, 18.96it/s]

EarlyStopping counter: 16 out of 200



Epoch 155: MAE/CA 1.597455: : 430it [00:45,  9.40it/s]
40it [00:02, 17.24it/s]

EarlyStopping counter: 17 out of 200



Epoch 156: MAE/CA 1.799557: : 430it [00:45,  9.44it/s]
40it [00:02, 17.75it/s]

EarlyStopping counter: 18 out of 200



Epoch 157: MAE/CA 1.642445: : 430it [00:45,  9.49it/s]
40it [00:02, 18.64it/s]

Epoch   157: reducing learning rate of group 0 to 1.7064e-04.



Epoch 158: MAE/CA 1.540327: : 430it [00:45,  9.49it/s]
40it [00:02, 19.06it/s]

EarlyStopping counter: 1 out of 200



Epoch 159: MAE/CA 1.735563: : 430it [00:45,  9.50it/s]
40it [00:02, 18.11it/s]

EarlyStopping counter: 2 out of 200



Epoch 160: MAE/CA 1.711694: : 430it [00:45,  9.43it/s]
40it [00:02, 19.40it/s]

EarlyStopping counter: 3 out of 200



Epoch 161: MAE/CA 1.457827: : 430it [00:45,  9.39it/s]
40it [00:02, 18.65it/s]

EarlyStopping counter: 4 out of 200



Epoch 162: MAE/CA 1.568447: : 430it [00:45,  9.51it/s]
40it [00:02, 16.56it/s]

EarlyStopping counter: 5 out of 200



Epoch 163: MAE/CA 1.722109: : 430it [00:45,  9.48it/s]
40it [00:02, 18.00it/s]

EarlyStopping counter: 6 out of 200



Epoch 164: MAE/CA 1.681761: : 430it [00:45,  9.50it/s]
40it [00:02, 17.68it/s]

EarlyStopping counter: 7 out of 200



Epoch 165: MAE/CA 1.562074: : 430it [00:45,  9.49it/s]
40it [00:02, 17.81it/s]

EarlyStopping counter: 8 out of 200



Epoch 166: MAE/CA 1.672449: : 430it [00:45,  9.50it/s]
40it [00:02, 17.38it/s]

EarlyStopping counter: 9 out of 200



Epoch 167: MAE/CA 1.503471: : 430it [00:45,  9.52it/s]
40it [00:02, 16.48it/s]

Epoch   167: reducing learning rate of group 0 to 1.6211e-04.
EarlyStopping counter: 10 out of 200



Epoch 168: MAE/CA 1.594804: : 430it [00:45,  9.51it/s]
40it [00:02, 17.48it/s]

EarlyStopping counter: 11 out of 200



Epoch 169: MAE/CA 1.372677: : 430it [00:45,  9.50it/s]
40it [00:02, 16.74it/s]

EarlyStopping counter: 12 out of 200



Epoch 170: MAE/CA 1.605789: : 430it [00:44,  9.56it/s]
40it [00:02, 17.98it/s]

EarlyStopping counter: 13 out of 200



Epoch 171: MAE/CA 1.609718: : 430it [00:44,  9.66it/s]
40it [00:02, 18.46it/s]

EarlyStopping counter: 14 out of 200



Epoch 172: MAE/CA 1.508801: : 430it [00:45,  9.54it/s]
40it [00:02, 18.95it/s]

EarlyStopping counter: 15 out of 200



Epoch 173: MAE/CA 1.571193: : 430it [00:44,  9.59it/s]
40it [00:02, 18.78it/s]

EarlyStopping counter: 16 out of 200



Epoch 174: MAE/CA 1.388881: : 430it [00:44,  9.56it/s]
40it [00:02, 19.67it/s]
Epoch 175: MAE/CA 1.659428: : 430it [00:44,  9.57it/s]
40it [00:02, 18.26it/s]

Epoch   175: reducing learning rate of group 0 to 1.5400e-04.
EarlyStopping counter: 1 out of 200



Epoch 176: MAE/CA 1.301693: : 430it [00:45,  9.52it/s]
40it [00:02, 19.33it/s]

EarlyStopping counter: 2 out of 200



Epoch 177: MAE/CA 1.505279: : 430it [00:45,  9.52it/s]
40it [00:01, 20.11it/s]

EarlyStopping counter: 3 out of 200



Epoch 178: MAE/CA 1.425706: : 430it [00:45,  9.55it/s]
40it [00:01, 20.69it/s]

EarlyStopping counter: 4 out of 200



Epoch 179: MAE/CA 1.472456: : 430it [00:44,  9.59it/s]
40it [00:02, 17.96it/s]

EarlyStopping counter: 5 out of 200



Epoch 180: MAE/CA 1.434314: : 430it [00:44,  9.59it/s]
40it [00:02, 18.60it/s]

EarlyStopping counter: 6 out of 200



Epoch 181: MAE/CA 1.471472: : 430it [00:44,  9.58it/s]
40it [00:02, 19.87it/s]
Epoch 182: MAE/CA 1.455999: : 430it [00:44,  9.59it/s]
40it [00:02, 18.96it/s]

Epoch   182: reducing learning rate of group 0 to 1.4630e-04.
EarlyStopping counter: 1 out of 200



Epoch 183: MAE/CA 1.248142: : 430it [00:44,  9.64it/s]
40it [00:02, 19.07it/s]
Epoch 184: MAE/CA 1.366376: : 430it [00:44,  9.61it/s]
40it [00:02, 18.91it/s]

EarlyStopping counter: 1 out of 200



Epoch 185: MAE/CA 1.405267: : 430it [00:45,  9.54it/s]
40it [00:02, 18.44it/s]

EarlyStopping counter: 2 out of 200



Epoch 186: MAE/CA 1.334226: : 430it [00:44,  9.57it/s]
40it [00:02, 18.96it/s]

EarlyStopping counter: 3 out of 200



Epoch 187: MAE/CA 1.409210: : 430it [00:44,  9.58it/s]
40it [00:02, 19.12it/s]

EarlyStopping counter: 4 out of 200



Epoch 188: MAE/CA 1.473925: : 430it [00:44,  9.63it/s]
40it [00:02, 18.05it/s]

EarlyStopping counter: 5 out of 200



Epoch 189: MAE/CA 1.339227: : 430it [00:44,  9.61it/s]
40it [00:02, 19.49it/s]

Epoch   189: reducing learning rate of group 0 to 1.3899e-04.
EarlyStopping counter: 6 out of 200



Epoch 190: MAE/CA 1.409608: : 430it [00:44,  9.61it/s]
40it [00:02, 18.82it/s]

EarlyStopping counter: 7 out of 200



Epoch 191: MAE/CA 1.283603: : 430it [00:45,  9.51it/s]
40it [00:02, 19.81it/s]

EarlyStopping counter: 8 out of 200



Epoch 192: MAE/CA 1.208039: : 430it [00:45,  9.51it/s]
40it [00:02, 19.96it/s]

EarlyStopping counter: 9 out of 200



Epoch 193: MAE/CA 1.397278: : 430it [00:44,  9.58it/s]
40it [00:01, 20.11it/s]

EarlyStopping counter: 10 out of 200



Epoch 194: MAE/CA 1.293599: : 430it [00:45,  9.52it/s]
40it [00:02, 19.14it/s]

EarlyStopping counter: 11 out of 200



Epoch 195: MAE/CA 1.285786: : 430it [00:46,  9.28it/s]
40it [00:02, 19.50it/s]

EarlyStopping counter: 12 out of 200



Epoch 196: MAE/CA 1.294328: : 430it [00:46,  9.28it/s]
40it [00:02, 18.79it/s]

EarlyStopping counter: 13 out of 200



Epoch 197: MAE/CA 1.272610: : 430it [00:46,  9.32it/s]
40it [00:02, 18.69it/s]

EarlyStopping counter: 14 out of 200



Epoch 198: MAE/CA 1.355452: : 430it [00:45,  9.38it/s]
40it [00:02, 19.26it/s]

Epoch   198: reducing learning rate of group 0 to 1.3204e-04.
EarlyStopping counter: 15 out of 200



Epoch 199: MAE/CA 1.133347: : 430it [00:45,  9.37it/s]
40it [00:02, 19.10it/s]
Epoch 200: MAE/CA 1.394041: : 430it [00:45,  9.42it/s]
40it [00:02, 19.93it/s]

EarlyStopping counter: 1 out of 200



Epoch 201: MAE/CA 1.256262: : 430it [00:46,  9.32it/s]
40it [00:01, 21.02it/s]

EarlyStopping counter: 2 out of 200



Epoch 202: MAE/CA 1.393874: : 430it [00:46,  9.32it/s]
40it [00:02, 19.53it/s]

EarlyStopping counter: 3 out of 200



Epoch 203: MAE/CA 1.126729: : 430it [00:45,  9.36it/s]
40it [00:02, 18.79it/s]

EarlyStopping counter: 4 out of 200



Epoch 204: MAE/CA 1.399860: : 430it [00:46,  9.35it/s]
40it [00:01, 20.98it/s]

EarlyStopping counter: 5 out of 200



Epoch 205: MAE/CA 1.161317: : 430it [00:46,  9.33it/s]
40it [00:01, 20.36it/s]

EarlyStopping counter: 6 out of 200



Epoch 206: MAE/CA 1.210453: : 430it [00:46,  9.34it/s]
40it [00:02, 19.25it/s]

EarlyStopping counter: 7 out of 200



Epoch 207: MAE/CA 1.240212: : 430it [00:46,  9.29it/s]
40it [00:02, 19.87it/s]
Epoch 208: MAE/CA 1.356903: : 430it [00:45,  9.43it/s]
40it [00:02, 19.47it/s]

EarlyStopping counter: 1 out of 200



Epoch 209: MAE/CA 1.180159: : 430it [00:45,  9.40it/s]
40it [00:02, 19.04it/s]

Epoch   209: reducing learning rate of group 0 to 1.2544e-04.
EarlyStopping counter: 2 out of 200



Epoch 210: MAE/CA 1.058255: : 430it [00:45,  9.40it/s]
40it [00:02, 19.69it/s]

EarlyStopping counter: 3 out of 200



Epoch 211: MAE/CA 1.141737: : 430it [00:45,  9.53it/s]
40it [00:02, 18.86it/s]

EarlyStopping counter: 4 out of 200



Epoch 212: MAE/CA 1.207707: : 430it [00:45,  9.51it/s]
40it [00:02, 19.50it/s]

EarlyStopping counter: 5 out of 200



Epoch 213: MAE/CA 1.166183: : 430it [00:45,  9.53it/s]
40it [00:02, 19.28it/s]


EarlyStopping counter: 6 out of 200


Epoch 214: MAE/CA 1.257228: : 430it [00:44,  9.63it/s]
40it [00:02, 18.04it/s]

EarlyStopping counter: 7 out of 200



Epoch 215: MAE/CA 1.125448: : 430it [00:44,  9.58it/s]
40it [00:02, 19.38it/s]

EarlyStopping counter: 8 out of 200



Epoch 216: MAE/CA 1.151657: : 430it [00:44,  9.57it/s]
40it [00:02, 18.13it/s]

Epoch   216: reducing learning rate of group 0 to 1.1916e-04.



Epoch 217: MAE/CA 1.087619: : 430it [00:45,  9.42it/s]
40it [00:01, 22.26it/s]

EarlyStopping counter: 1 out of 200



Epoch 218: MAE/CA 1.103813: : 430it [00:38, 11.22it/s]
40it [00:01, 23.18it/s]

EarlyStopping counter: 2 out of 200



Epoch 219: MAE/CA 1.185308: : 430it [00:38, 11.30it/s]
40it [00:01, 23.99it/s]

EarlyStopping counter: 3 out of 200



Epoch 220: MAE/CA 1.255666: : 430it [00:37, 11.34it/s]
40it [00:01, 23.60it/s]
Epoch 221: MAE/CA 1.064632: : 430it [00:37, 11.32it/s]
40it [00:01, 24.18it/s]

EarlyStopping counter: 1 out of 200



Epoch 222: MAE/CA 1.133350: : 430it [00:38, 11.22it/s]
40it [00:01, 23.16it/s]

Epoch   222: reducing learning rate of group 0 to 1.1321e-04.
EarlyStopping counter: 2 out of 200



Epoch 223: MAE/CA 0.949696: : 430it [00:38, 11.27it/s]
40it [00:01, 24.61it/s]

EarlyStopping counter: 3 out of 200



Epoch 224: MAE/CA 1.098083: : 430it [00:38, 11.21it/s]
40it [00:01, 24.73it/s]

EarlyStopping counter: 4 out of 200



Epoch 225: MAE/CA 1.091111: : 430it [00:38, 11.22it/s]
40it [00:01, 23.55it/s]

EarlyStopping counter: 5 out of 200



Epoch 226: MAE/CA 1.159926: : 430it [00:39, 10.99it/s]
40it [00:01, 25.66it/s]

EarlyStopping counter: 6 out of 200



Epoch 227: MAE/CA 1.134318: : 430it [00:38, 11.21it/s]
40it [00:01, 23.57it/s]

EarlyStopping counter: 7 out of 200



Epoch 228: MAE/CA 0.976928: : 430it [00:38, 11.26it/s]
40it [00:01, 24.06it/s]

EarlyStopping counter: 8 out of 200



Epoch 229: MAE/CA 1.105172: : 430it [00:38, 11.31it/s]
40it [00:01, 23.61it/s]

Epoch   229: reducing learning rate of group 0 to 1.0755e-04.
EarlyStopping counter: 9 out of 200



Epoch 230: MAE/CA 0.990061: : 430it [00:38, 11.30it/s]
40it [00:01, 24.44it/s]

EarlyStopping counter: 10 out of 200



Epoch 231: MAE/CA 1.076019: : 430it [00:38, 11.24it/s]
40it [00:02, 18.87it/s]

EarlyStopping counter: 11 out of 200



Epoch 232: MAE/CA 1.155375: : 353it [00:32, 11.45it/s]

In [None]:
# Evaluate the model on the held-out test split, report MAE and MAE / chemical
# accuracy for the selected target, and log both to wandb.
# NOTE(review): this evaluates the weights currently held by `model` (as left by
# the interrupted training cell above), not necessarily the best EarlyStopping
# checkpoint — reload that checkpoint first if it is what should be reported.
was_training = model.training
model.eval()  # switch off train-only behaviour (e.g. dropout) for evaluation
maes = []
try:
    with torch.no_grad():  # inference only: no autograd graph needed
        pbar = tqdm.tqdm(test_loader, desc='test')
        for data in pbar:
            data = data.to(device)
            prediction = model(data.z, data.pos, data.batch)
            # per-graph absolute error on the selected target column
            mae = (prediction.view(-1) - data.y[:, target_idx]).abs()
            maes.append(mae.cpu())
finally:
    # restore the previous mode so a later training cell is unaffected
    model.train(was_training)
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
wandb.log({'TEST MAE':mae, 'TEST MAE/CA':mae/chemical_accuracy[target_idx]})
