In [1]:
import contextlib
import os
import sys
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from model.point_dataset import SCMDataset
from model.point_model import PointNetOneX
# import EarlyStopping
from model.pytorchtools import EarlyStopping
from tqdm import tqdm
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import numpy as np

In [2]:
def build_optimizer(network, optimizer, learning_rate):
    if optimizer == "sgd":
        optimizer = optim.SGD(network.parameters(),
                              lr=learning_rate, momentum=0.9)
    elif optimizer == "adam":
        optimizer = optim.Adam(network.parameters(),
                               lr=learning_rate)# weight_decay=0.001)
    
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    return optimizer, scheduler

In [3]:
def load_data(data_dir, concat): 
    train_set = SCMDataset(
        root=data_dir,
        concat = concat)
    
    val_set = SCMDataset(
        root=data_dir,
        split = 'val',
        concat = concat)

    test_set = SCMDataset(
        root=data_dir,
        split = 'test',
        concat = concat)
    
    return train_set, val_set, test_set 

In [4]:
def train(train_loader, model, optimizer, device, epoch):
    """ Train the model on num_steps batches 
    Args: 
        train_loader: a torch.utils.data.DataLoader object that fetches the data
        model: the neural network 
        optimizer: adams 
    """
    model.train()
    running_loss = 0.0
    num_batch = len(train_loader)

    for i, (points, target, _) in enumerate(train_loader): 
        points = points.transpose(2, 1) 
        points, target = points.to(device), target.to(device)

        # forward + backward + optimize 
        pred = model(points)
        loss = F.mse_loss(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # print statistics 
        running_loss += loss.item()
        # if i % 300 == 0: 
        #     print('[%d: %d/%d] train loss: %f lr = %f' % (epoch, i, num_batch, loss.item(), optimizer.param_groups[0]["lr"]))

    return running_loss / num_batch 

In [5]:
def validate(val_dataloader, model, device): 
    model.eval()
    val_running_loss = 0.0
    num_batch = len(val_dataloader)

    with torch.no_grad(): 
        for points, target, _ in val_dataloader:
            points = points.transpose(2, 1)
            points, target = points.to(device), target.to(device)
            # points, target = points.to(device), target.to(device)

            # outputs  = model(points, sys_scalar)
            outputs  = model(points)
            loss = F.mse_loss(outputs, target)

            val_running_loss += loss.item() 

    return val_running_loss / num_batch

In [6]:
def plot_pramas(test_y, test_pred, foldername, filename): 

    # print("R2 of training is: ", r2_score(train_y, train_pred))
    print("R2 of test is: ", r2_score(test_y, test_pred))

    np.savetxt(f'{foldername}/test_predict_{filename}.txt', test_pred)
    np.savetxt(f'{foldername}/test_target_{filename}.txt', test_y)

    test_mse = mean_squared_error(test_y, test_pred)
    test_mae = mean_absolute_error(test_y, test_pred)

    print('Test set results for %i samples:' % test_pred.shape[0])
    print('MSE:', test_mse)
    print('MAE:', test_mae)

In [7]:
def test_accuracy(net, testloader, foldername, filename, device): 
    test_pred = []
    test_y = [] 

    running_loss = 0
    with torch.no_grad():
        for data in testloader:
            points, target, _ = data
            points = points.transpose(2, 1) 
            points, target = points.to(device), target.to(device)
            outputs = net(points)
            loss = F.mse_loss(outputs, target)

            running_loss += loss.item() * points.size(0)

            pred_val_numpy = outputs.data.cpu().numpy()
            target_val_numpy = target.data.cpu().numpy()
            test_pred.append(pred_val_numpy)
            test_y.append(target_val_numpy)

    test_pred = np.concatenate(test_pred, axis=0)
    test_y = np.concatenate(test_y, axis=0)

    plot_pramas(test_y, test_pred, foldername, filename)
    
    print('MSE loss on test set is:', running_loss / len(testloader.dataset))    

In [8]:
def train_model(model, train_loader, val_loader, test_loader, optimizer, lr_scheduler, isSch, device, name, res_dir, patience = 20, n_epochs = 50): 
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = [] 

    blue = lambda x: '\033[94m' + x + '\033[0m'
    checkpoint_dir = os.path.join(res_dir, 'checkpoints')

    with contextlib.suppress(OSError):
        os.makedirs(res_dir)
        os.makedirs(checkpoint_dir)

    checkpoint_path = os.path.join(checkpoint_dir, f'{name}.pt')
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience, verbose=True, path = checkpoint_path)

    for epoch in tqdm(range(1, n_epochs + 1)):
        ###################
        # train the model #
        ###################
        train_epoch_loss = train(train_loader, model, optimizer, device, epoch)
        val_epoch_loss = validate(val_loader, model, device)

        if isSch:
            lr_scheduler.step(val_epoch_loss) # no step lr fixed LR = 0.001 

        # # print loss every epoch 
        # print('[%d] train loss: %f ' % (epoch, train_epoch_loss))
        # print('[%d] %s loss: %f' % (epoch, blue('validate'), val_epoch_loss))

        avg_train_losses.append(train_epoch_loss)
        avg_valid_losses.append(val_epoch_loss)

        # add early stopping 
        early_stopping(val_epoch_loss, model)
        if early_stopping.early_stop: 
            print("Early stopping")
            break 

    # load the last checkpoint with the best model
    model.load_state_dict(torch.load(checkpoint_path))

    np.savetxt(os.path.join(res_dir, f'train_loss_{name}.csv'), avg_train_losses)
    np.savetxt(os.path.join(res_dir, f'val_loss_{name}.csv'), avg_valid_losses)

    test_accuracy(model, test_loader, res_dir, name, device)

In [9]:
def train_main(config): 
    data_dir = 'datasets/train'

    # get dataset 
    train_set, test_set, val_set = load_data(data_dir, config['addGlobal'])

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=0)

    val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=int(config["batch_size"]),
            shuffle=True,
            num_workers=0)
    
    test_loader = torch.utils.data.DataLoader(
            test_set, 
            batch_size=int(config["batch_size"]), 
            shuffle=False, num_workers=0)
    
    print("Creating model")
    model = PointNetOneX(config["l1"], config["l2"], config["l3"], config["l4"], config["l5"], 
                  config["dropout"], config["isBN"], config['isLN'],config['constraint'], config['pooling'])
    print(model)

    # device = torch.device('cuda')
    # if torch.cuda.device_count() > 1:
    #     model = nn.DataParallel(model)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer, lr_scheduler = build_optimizer(model, config['optimizer'], config['lr'])
    res_dir = 'complexDNN/res/'
    
    name = f"{config['model']}{config['l1']}{config['l2']}{config['l3']}{config['l4']}{config['l5']}lr{config['lr']}BS{config['batch_size']}isB{config['isBN']}ln{config['isLN']}cons{config['constraint']}Opt{config['optimizer']}sch{config['lr_scheduler']}pool{config['pooling']}addG{config['addGlobal']}"
    train_model(model, train_loader, val_loader, test_loader, optimizer, lr_scheduler, config['lr_scheduler'], device, name, res_dir, 40, 1000)

In [10]:
config = {"l1": 512,"l2": 512, "l3": 512,"l4": 256,"l5": 128,"dropout" : 0, "lr": 0.01,  "batch_size": 128, "model" : "PointNetOneX", "isBN": False,
        "isLN": True, "lr_scheduler": True, "addGlobal" : False, "pooling" : 'min', "constraint" : False,"optimizer": 'adam'}

In [11]:
import time 
start_time = time.perf_counter()
train_main(config) 
end_time = time.perf_counter() 
print('time used to train model with 40/1000 patience is: ', (end_time - start_time)/60, 'mins')

Creating model
PointNetOneX(
  (conv1): Conv1d(6, 512, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
  (conv3): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (ln1): LayerNorm((512, 56), eps=1e-05, elementwise_affine=True)
  (ln2): LayerNorm((512, 56), eps=1e-05, elementwise_affine=True)
  (ln3): LayerNorm((512, 56), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=5, bias=True)
  (dropout): Dropout(p=0, inplace=False)
  (bn4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): Batc

  0%|▏                                                                                                                                 | 1/1000 [00:05<1:29:20,  5.37s/it]

Validation loss decreased (inf --> 0.030042).  Saving model ...


  0%|▎                                                                                                                                 | 2/1000 [00:09<1:19:09,  4.76s/it]

Validation loss decreased (0.030042 --> 0.017884).  Saving model ...


  0%|▍                                                                                                                                 | 3/1000 [00:13<1:15:28,  4.54s/it]

Validation loss decreased (0.017884 --> 0.014469).  Saving model ...


  0%|▌                                                                                                                                 | 4/1000 [00:18<1:13:36,  4.43s/it]

Validation loss decreased (0.014469 --> 0.012662).  Saving model ...


  0%|▋                                                                                                                                 | 5/1000 [00:22<1:13:00,  4.40s/it]

Validation loss decreased (0.012662 --> 0.010050).  Saving model ...


  1%|▊                                                                                                                                 | 6/1000 [00:26<1:12:01,  4.35s/it]

Validation loss decreased (0.010050 --> 0.009627).  Saving model ...


  1%|▉                                                                                                                                 | 7/1000 [00:31<1:12:03,  4.35s/it]

Validation loss decreased (0.009627 --> 0.008740).  Saving model ...


  1%|█                                                                                                                                 | 8/1000 [00:35<1:11:52,  4.35s/it]

EarlyStopping counter: 1 out of 40


  1%|█▏                                                                                                                                | 9/1000 [00:39<1:11:52,  4.35s/it]

Validation loss decreased (0.008740 --> 0.008287).  Saving model ...


  1%|█▎                                                                                                                               | 10/1000 [00:44<1:12:03,  4.37s/it]

EarlyStopping counter: 1 out of 40


  1%|█▍                                                                                                                               | 11/1000 [00:48<1:12:02,  4.37s/it]

EarlyStopping counter: 2 out of 40


  1%|█▌                                                                                                                               | 12/1000 [00:53<1:11:55,  4.37s/it]

EarlyStopping counter: 3 out of 40


  1%|█▋                                                                                                                               | 13/1000 [00:57<1:11:42,  4.36s/it]

EarlyStopping counter: 4 out of 40


  1%|█▊                                                                                                                               | 14/1000 [01:01<1:11:46,  4.37s/it]

Validation loss decreased (0.008287 --> 0.008090).  Saving model ...


  2%|█▉                                                                                                                               | 15/1000 [01:06<1:11:16,  4.34s/it]

EarlyStopping counter: 1 out of 40


  2%|██                                                                                                                               | 16/1000 [01:10<1:10:55,  4.32s/it]

Validation loss decreased (0.008090 --> 0.007943).  Saving model ...


  2%|██▏                                                                                                                              | 17/1000 [01:14<1:11:02,  4.34s/it]

Validation loss decreased (0.007943 --> 0.007715).  Saving model ...


  2%|██▎                                                                                                                              | 18/1000 [01:18<1:10:37,  4.31s/it]

Validation loss decreased (0.007715 --> 0.007653).  Saving model ...


  2%|██▍                                                                                                                              | 19/1000 [01:23<1:10:26,  4.31s/it]

EarlyStopping counter: 1 out of 40


  2%|██▌                                                                                                                              | 20/1000 [01:27<1:10:46,  4.33s/it]

Validation loss decreased (0.007653 --> 0.007593).  Saving model ...


  2%|██▋                                                                                                                              | 21/1000 [01:32<1:11:03,  4.35s/it]

EarlyStopping counter: 1 out of 40


  2%|██▊                                                                                                                              | 22/1000 [01:36<1:11:17,  4.37s/it]

Validation loss decreased (0.007593 --> 0.007416).  Saving model ...


  2%|██▉                                                                                                                              | 23/1000 [01:40<1:11:09,  4.37s/it]

Validation loss decreased (0.007416 --> 0.007302).  Saving model ...


  2%|███                                                                                                                              | 24/1000 [01:45<1:11:33,  4.40s/it]

Validation loss decreased (0.007302 --> 0.007186).  Saving model ...


  2%|███▏                                                                                                                             | 25/1000 [01:49<1:11:55,  4.43s/it]

EarlyStopping counter: 1 out of 40


  3%|███▎                                                                                                                             | 26/1000 [01:54<1:11:58,  4.43s/it]

EarlyStopping counter: 2 out of 40


  3%|███▍                                                                                                                             | 27/1000 [01:58<1:11:40,  4.42s/it]

EarlyStopping counter: 3 out of 40


  3%|███▌                                                                                                                             | 28/1000 [02:03<1:11:30,  4.41s/it]

EarlyStopping counter: 4 out of 40


  3%|███▋                                                                                                                             | 29/1000 [02:07<1:11:21,  4.41s/it]

EarlyStopping counter: 5 out of 40


  3%|███▊                                                                                                                             | 30/1000 [02:11<1:11:14,  4.41s/it]

Validation loss decreased (0.007186 --> 0.006907).  Saving model ...


  3%|███▉                                                                                                                             | 31/1000 [02:16<1:11:13,  4.41s/it]

EarlyStopping counter: 1 out of 40


  3%|████▏                                                                                                                            | 32/1000 [02:20<1:11:39,  4.44s/it]

Validation loss decreased (0.006907 --> 0.006880).  Saving model ...


  3%|████▎                                                                                                                            | 33/1000 [02:25<1:11:52,  4.46s/it]

Validation loss decreased (0.006880 --> 0.006646).  Saving model ...


  3%|████▍                                                                                                                            | 34/1000 [02:29<1:11:36,  4.45s/it]

Validation loss decreased (0.006646 --> 0.006431).  Saving model ...


  4%|████▌                                                                                                                            | 35/1000 [02:34<1:11:05,  4.42s/it]

EarlyStopping counter: 1 out of 40


  4%|████▋                                                                                                                            | 36/1000 [02:38<1:11:07,  4.43s/it]

EarlyStopping counter: 2 out of 40


  4%|████▊                                                                                                                            | 37/1000 [02:42<1:11:11,  4.44s/it]

EarlyStopping counter: 3 out of 40


  4%|████▉                                                                                                                            | 38/1000 [02:47<1:11:07,  4.44s/it]

Validation loss decreased (0.006431 --> 0.006108).  Saving model ...


  4%|█████                                                                                                                            | 39/1000 [02:51<1:11:23,  4.46s/it]

EarlyStopping counter: 1 out of 40


  4%|█████▏                                                                                                                           | 40/1000 [02:56<1:11:11,  4.45s/it]

EarlyStopping counter: 2 out of 40


  4%|█████▎                                                                                                                           | 41/1000 [03:00<1:11:36,  4.48s/it]

EarlyStopping counter: 3 out of 40


  4%|█████▍                                                                                                                           | 42/1000 [03:05<1:11:24,  4.47s/it]

EarlyStopping counter: 4 out of 40


  4%|█████▌                                                                                                                           | 43/1000 [03:09<1:11:14,  4.47s/it]

EarlyStopping counter: 5 out of 40


  4%|█████▋                                                                                                                           | 44/1000 [03:14<1:10:49,  4.45s/it]

EarlyStopping counter: 6 out of 40


  4%|█████▊                                                                                                                           | 45/1000 [03:18<1:10:40,  4.44s/it]

EarlyStopping counter: 7 out of 40


  5%|█████▉                                                                                                                           | 46/1000 [03:22<1:10:10,  4.41s/it]

EarlyStopping counter: 8 out of 40


  5%|██████                                                                                                                           | 47/1000 [03:27<1:09:43,  4.39s/it]

EarlyStopping counter: 9 out of 40


  5%|██████▏                                                                                                                          | 48/1000 [03:31<1:09:37,  4.39s/it]

EarlyStopping counter: 10 out of 40


  5%|██████▎                                                                                                                          | 49/1000 [03:36<1:09:55,  4.41s/it]

EarlyStopping counter: 11 out of 40


  5%|██████▍                                                                                                                          | 50/1000 [03:40<1:10:33,  4.46s/it]

Validation loss decreased (0.006108 --> 0.004784).  Saving model ...


  5%|██████▌                                                                                                                          | 51/1000 [03:45<1:10:24,  4.45s/it]

Validation loss decreased (0.004784 --> 0.004775).  Saving model ...


  5%|██████▋                                                                                                                          | 52/1000 [03:49<1:10:29,  4.46s/it]

Validation loss decreased (0.004775 --> 0.004594).  Saving model ...


  5%|██████▊                                                                                                                          | 53/1000 [03:54<1:10:27,  4.46s/it]

EarlyStopping counter: 1 out of 40


  5%|██████▉                                                                                                                          | 54/1000 [03:58<1:10:11,  4.45s/it]

Validation loss decreased (0.004594 --> 0.004476).  Saving model ...


  6%|███████                                                                                                                          | 55/1000 [04:02<1:10:13,  4.46s/it]

EarlyStopping counter: 1 out of 40


  6%|███████▏                                                                                                                         | 56/1000 [04:07<1:10:06,  4.46s/it]

EarlyStopping counter: 2 out of 40


  6%|███████▎                                                                                                                         | 57/1000 [04:11<1:09:31,  4.42s/it]

EarlyStopping counter: 3 out of 40


  6%|███████▍                                                                                                                         | 58/1000 [04:16<1:09:09,  4.40s/it]

Validation loss decreased (0.004476 --> 0.004437).  Saving model ...


  6%|███████▌                                                                                                                         | 59/1000 [04:20<1:09:13,  4.41s/it]

Validation loss decreased (0.004437 --> 0.004380).  Saving model ...


  6%|███████▋                                                                                                                         | 60/1000 [04:25<1:09:18,  4.42s/it]

EarlyStopping counter: 1 out of 40


  6%|███████▊                                                                                                                         | 61/1000 [04:29<1:09:53,  4.47s/it]

EarlyStopping counter: 2 out of 40


  6%|███████▉                                                                                                                         | 62/1000 [04:34<1:10:17,  4.50s/it]

EarlyStopping counter: 3 out of 40


  6%|████████▏                                                                                                                        | 63/1000 [04:38<1:10:08,  4.49s/it]

EarlyStopping counter: 4 out of 40


  6%|████████▎                                                                                                                        | 64/1000 [04:43<1:10:00,  4.49s/it]

EarlyStopping counter: 5 out of 40


  6%|████████▍                                                                                                                        | 65/1000 [04:47<1:09:29,  4.46s/it]

Validation loss decreased (0.004380 --> 0.004364).  Saving model ...


  7%|████████▌                                                                                                                        | 66/1000 [04:51<1:09:07,  4.44s/it]

EarlyStopping counter: 1 out of 40


  7%|████████▋                                                                                                                        | 67/1000 [04:56<1:09:13,  4.45s/it]

EarlyStopping counter: 2 out of 40


  7%|████████▊                                                                                                                        | 68/1000 [05:00<1:09:12,  4.46s/it]

EarlyStopping counter: 3 out of 40


  7%|████████▉                                                                                                                        | 69/1000 [05:05<1:08:46,  4.43s/it]

EarlyStopping counter: 4 out of 40


  7%|█████████                                                                                                                        | 70/1000 [05:09<1:08:39,  4.43s/it]

Validation loss decreased (0.004364 --> 0.004255).  Saving model ...


  7%|█████████▏                                                                                                                       | 71/1000 [05:14<1:08:43,  4.44s/it]

EarlyStopping counter: 1 out of 40


  7%|█████████▎                                                                                                                       | 72/1000 [05:18<1:08:41,  4.44s/it]

EarlyStopping counter: 2 out of 40


  7%|█████████▍                                                                                                                       | 73/1000 [05:23<1:09:22,  4.49s/it]

EarlyStopping counter: 3 out of 40


  7%|█████████▌                                                                                                                       | 74/1000 [05:27<1:08:57,  4.47s/it]

EarlyStopping counter: 4 out of 40


  8%|█████████▋                                                                                                                       | 75/1000 [05:31<1:08:31,  4.45s/it]

EarlyStopping counter: 5 out of 40


  8%|█████████▊                                                                                                                       | 76/1000 [05:36<1:07:58,  4.41s/it]

EarlyStopping counter: 6 out of 40


  8%|█████████▉                                                                                                                       | 77/1000 [05:40<1:07:49,  4.41s/it]

EarlyStopping counter: 7 out of 40


  8%|██████████                                                                                                                       | 78/1000 [05:45<1:07:50,  4.41s/it]

EarlyStopping counter: 8 out of 40


  8%|██████████▏                                                                                                                      | 79/1000 [05:49<1:07:44,  4.41s/it]

Validation loss decreased (0.004255 --> 0.004201).  Saving model ...


  8%|██████████▎                                                                                                                      | 80/1000 [05:53<1:07:42,  4.42s/it]

EarlyStopping counter: 1 out of 40


  8%|██████████▍                                                                                                                      | 81/1000 [05:58<1:07:37,  4.42s/it]

EarlyStopping counter: 2 out of 40


  8%|██████████▌                                                                                                                      | 82/1000 [06:02<1:07:47,  4.43s/it]

EarlyStopping counter: 3 out of 40


  8%|██████████▋                                                                                                                      | 83/1000 [06:07<1:07:43,  4.43s/it]

Validation loss decreased (0.004201 --> 0.004183).  Saving model ...


  8%|██████████▊                                                                                                                      | 84/1000 [06:11<1:07:23,  4.41s/it]

EarlyStopping counter: 1 out of 40


  8%|██████████▉                                                                                                                      | 85/1000 [06:16<1:07:04,  4.40s/it]

EarlyStopping counter: 2 out of 40


  9%|███████████                                                                                                                      | 86/1000 [06:20<1:07:02,  4.40s/it]

EarlyStopping counter: 3 out of 40


  9%|███████████▏                                                                                                                     | 87/1000 [06:24<1:07:01,  4.40s/it]

Validation loss decreased (0.004183 --> 0.004180).  Saving model ...


  9%|███████████▎                                                                                                                     | 88/1000 [06:29<1:06:53,  4.40s/it]

EarlyStopping counter: 1 out of 40


  9%|███████████▍                                                                                                                     | 89/1000 [06:33<1:07:31,  4.45s/it]

EarlyStopping counter: 2 out of 40


  9%|███████████▌                                                                                                                     | 90/1000 [06:38<1:07:34,  4.46s/it]

EarlyStopping counter: 3 out of 40


  9%|███████████▋                                                                                                                     | 91/1000 [06:42<1:07:46,  4.47s/it]

EarlyStopping counter: 4 out of 40


  9%|███████████▊                                                                                                                     | 92/1000 [06:47<1:07:37,  4.47s/it]

EarlyStopping counter: 5 out of 40


  9%|███████████▉                                                                                                                     | 93/1000 [06:51<1:07:26,  4.46s/it]

EarlyStopping counter: 6 out of 40


  9%|████████████▏                                                                                                                    | 94/1000 [06:56<1:07:14,  4.45s/it]

EarlyStopping counter: 7 out of 40


 10%|████████████▎                                                                                                                    | 95/1000 [07:00<1:07:13,  4.46s/it]

EarlyStopping counter: 8 out of 40


 10%|████████████▍                                                                                                                    | 96/1000 [07:05<1:07:05,  4.45s/it]

EarlyStopping counter: 9 out of 40


 10%|████████████▌                                                                                                                    | 97/1000 [07:09<1:06:48,  4.44s/it]

EarlyStopping counter: 10 out of 40


 10%|████████████▋                                                                                                                    | 98/1000 [07:13<1:06:41,  4.44s/it]

EarlyStopping counter: 11 out of 40


 10%|████████████▊                                                                                                                    | 99/1000 [07:18<1:05:58,  4.39s/it]

Validation loss decreased (0.004180 --> 0.004007).  Saving model ...


 10%|████████████▊                                                                                                                   | 100/1000 [07:22<1:06:11,  4.41s/it]

EarlyStopping counter: 1 out of 40


 10%|████████████▉                                                                                                                   | 101/1000 [07:27<1:06:20,  4.43s/it]

EarlyStopping counter: 2 out of 40


 10%|█████████████                                                                                                                   | 102/1000 [07:31<1:06:30,  4.44s/it]

EarlyStopping counter: 3 out of 40


 10%|█████████████▏                                                                                                                  | 103/1000 [07:36<1:06:28,  4.45s/it]

Validation loss decreased (0.004007 --> 0.003974).  Saving model ...


 10%|█████████████▎                                                                                                                  | 104/1000 [07:40<1:06:16,  4.44s/it]

EarlyStopping counter: 1 out of 40


 10%|█████████████▍                                                                                                                  | 105/1000 [07:44<1:06:02,  4.43s/it]

EarlyStopping counter: 2 out of 40


 11%|█████████████▌                                                                                                                  | 106/1000 [07:49<1:06:20,  4.45s/it]

EarlyStopping counter: 3 out of 40


 11%|█████████████▋                                                                                                                  | 107/1000 [07:53<1:06:01,  4.44s/it]

EarlyStopping counter: 4 out of 40


 11%|█████████████▊                                                                                                                  | 108/1000 [07:58<1:05:48,  4.43s/it]

EarlyStopping counter: 5 out of 40


 11%|█████████████▉                                                                                                                  | 109/1000 [08:02<1:06:03,  4.45s/it]

EarlyStopping counter: 6 out of 40


 11%|██████████████                                                                                                                  | 110/1000 [08:07<1:06:09,  4.46s/it]

EarlyStopping counter: 7 out of 40


 11%|██████████████▏                                                                                                                 | 111/1000 [08:11<1:05:52,  4.45s/it]

EarlyStopping counter: 8 out of 40


 11%|██████████████▎                                                                                                                 | 112/1000 [08:15<1:05:15,  4.41s/it]

EarlyStopping counter: 9 out of 40


 11%|██████████████▍                                                                                                                 | 113/1000 [08:20<1:05:26,  4.43s/it]

EarlyStopping counter: 10 out of 40


 11%|██████████████▌                                                                                                                 | 114/1000 [08:24<1:05:27,  4.43s/it]

Validation loss decreased (0.003974 --> 0.003972).  Saving model ...


 12%|██████████████▋                                                                                                                 | 115/1000 [08:29<1:05:23,  4.43s/it]

EarlyStopping counter: 1 out of 40


 12%|██████████████▊                                                                                                                 | 116/1000 [08:33<1:05:14,  4.43s/it]

EarlyStopping counter: 2 out of 40


 12%|██████████████▉                                                                                                                 | 117/1000 [08:38<1:05:02,  4.42s/it]

EarlyStopping counter: 3 out of 40


 12%|███████████████                                                                                                                 | 118/1000 [08:42<1:04:57,  4.42s/it]

Validation loss decreased (0.003972 --> 0.003929).  Saving model ...


 12%|███████████████▏                                                                                                                | 119/1000 [08:46<1:04:19,  4.38s/it]

EarlyStopping counter: 1 out of 40


 12%|███████████████▎                                                                                                                | 120/1000 [08:51<1:04:33,  4.40s/it]

EarlyStopping counter: 2 out of 40


 12%|███████████████▍                                                                                                                | 121/1000 [08:55<1:04:45,  4.42s/it]

EarlyStopping counter: 3 out of 40


 12%|███████████████▌                                                                                                                | 122/1000 [09:00<1:05:03,  4.45s/it]

EarlyStopping counter: 4 out of 40


 12%|███████████████▋                                                                                                                | 123/1000 [09:04<1:04:48,  4.43s/it]

EarlyStopping counter: 5 out of 40


 12%|███████████████▊                                                                                                                | 124/1000 [09:09<1:04:50,  4.44s/it]

EarlyStopping counter: 6 out of 40


 12%|████████████████                                                                                                                | 125/1000 [09:13<1:04:52,  4.45s/it]

EarlyStopping counter: 7 out of 40


 13%|████████████████▏                                                                                                               | 126/1000 [09:17<1:04:10,  4.41s/it]

EarlyStopping counter: 8 out of 40


 13%|████████████████▎                                                                                                               | 127/1000 [09:22<1:04:15,  4.42s/it]

EarlyStopping counter: 9 out of 40


 13%|████████████████▍                                                                                                               | 128/1000 [09:26<1:04:09,  4.41s/it]

EarlyStopping counter: 10 out of 40


 13%|████████████████▌                                                                                                               | 129/1000 [09:31<1:04:15,  4.43s/it]

EarlyStopping counter: 11 out of 40


 13%|████████████████▋                                                                                                               | 130/1000 [09:35<1:04:04,  4.42s/it]

EarlyStopping counter: 12 out of 40


 13%|████████████████▊                                                                                                               | 131/1000 [09:39<1:03:47,  4.40s/it]

EarlyStopping counter: 13 out of 40


 13%|████████████████▉                                                                                                               | 132/1000 [09:44<1:03:52,  4.42s/it]

EarlyStopping counter: 14 out of 40


 13%|█████████████████                                                                                                               | 133/1000 [09:48<1:03:48,  4.42s/it]

EarlyStopping counter: 15 out of 40


 13%|█████████████████▏                                                                                                              | 134/1000 [09:53<1:03:17,  4.38s/it]

EarlyStopping counter: 16 out of 40


 14%|█████████████████▎                                                                                                              | 135/1000 [09:57<1:04:13,  4.45s/it]

EarlyStopping counter: 17 out of 40


 14%|█████████████████▍                                                                                                              | 136/1000 [10:02<1:04:10,  4.46s/it]

EarlyStopping counter: 18 out of 40


 14%|█████████████████▌                                                                                                              | 137/1000 [10:06<1:04:34,  4.49s/it]

EarlyStopping counter: 19 out of 40


 14%|█████████████████▋                                                                                                              | 138/1000 [10:11<1:04:30,  4.49s/it]

EarlyStopping counter: 20 out of 40


 14%|█████████████████▊                                                                                                              | 139/1000 [10:15<1:04:16,  4.48s/it]

EarlyStopping counter: 21 out of 40


 14%|█████████████████▉                                                                                                              | 140/1000 [10:20<1:04:11,  4.48s/it]

EarlyStopping counter: 22 out of 40


 14%|██████████████████                                                                                                              | 141/1000 [10:24<1:04:05,  4.48s/it]

EarlyStopping counter: 23 out of 40


 14%|██████████████████▏                                                                                                             | 142/1000 [10:28<1:03:45,  4.46s/it]

EarlyStopping counter: 24 out of 40


 14%|██████████████████▎                                                                                                             | 143/1000 [10:33<1:04:34,  4.52s/it]

EarlyStopping counter: 25 out of 40


 14%|██████████████████▍                                                                                                             | 144/1000 [10:38<1:04:10,  4.50s/it]

EarlyStopping counter: 26 out of 40


 14%|██████████████████▌                                                                                                             | 145/1000 [10:42<1:03:43,  4.47s/it]

Validation loss decreased (0.003929 --> 0.003927).  Saving model ...


 15%|██████████████████▋                                                                                                             | 146/1000 [10:47<1:03:54,  4.49s/it]

EarlyStopping counter: 1 out of 40


 15%|██████████████████▊                                                                                                             | 147/1000 [10:51<1:03:35,  4.47s/it]

EarlyStopping counter: 2 out of 40


 15%|██████████████████▉                                                                                                             | 148/1000 [10:55<1:03:22,  4.46s/it]

EarlyStopping counter: 3 out of 40


 15%|███████████████████                                                                                                             | 149/1000 [11:00<1:03:08,  4.45s/it]

EarlyStopping counter: 4 out of 40


 15%|███████████████████▏                                                                                                            | 150/1000 [11:04<1:03:04,  4.45s/it]

EarlyStopping counter: 5 out of 40


 15%|███████████████████▎                                                                                                            | 151/1000 [11:09<1:04:17,  4.54s/it]

EarlyStopping counter: 6 out of 40


 15%|███████████████████▍                                                                                                            | 152/1000 [11:14<1:03:59,  4.53s/it]

EarlyStopping counter: 7 out of 40


 15%|███████████████████▌                                                                                                            | 153/1000 [11:18<1:03:38,  4.51s/it]

EarlyStopping counter: 8 out of 40


 15%|███████████████████▋                                                                                                            | 154/1000 [11:22<1:03:25,  4.50s/it]

EarlyStopping counter: 9 out of 40


 16%|███████████████████▊                                                                                                            | 155/1000 [11:27<1:02:48,  4.46s/it]

EarlyStopping counter: 10 out of 40


 16%|███████████████████▉                                                                                                            | 156/1000 [11:31<1:02:34,  4.45s/it]

EarlyStopping counter: 11 out of 40


 16%|████████████████████                                                                                                            | 157/1000 [11:36<1:02:42,  4.46s/it]

EarlyStopping counter: 12 out of 40


 16%|████████████████████▏                                                                                                           | 158/1000 [11:40<1:02:25,  4.45s/it]

EarlyStopping counter: 13 out of 40


 16%|████████████████████▎                                                                                                           | 159/1000 [11:45<1:02:30,  4.46s/it]

EarlyStopping counter: 14 out of 40


 16%|████████████████████▍                                                                                                           | 160/1000 [11:49<1:02:37,  4.47s/it]

EarlyStopping counter: 15 out of 40


 16%|████████████████████▌                                                                                                           | 161/1000 [11:54<1:02:21,  4.46s/it]

EarlyStopping counter: 16 out of 40


 16%|████████████████████▋                                                                                                           | 162/1000 [11:58<1:02:33,  4.48s/it]

EarlyStopping counter: 17 out of 40


 16%|████████████████████▊                                                                                                           | 163/1000 [12:03<1:02:56,  4.51s/it]

EarlyStopping counter: 18 out of 40


 16%|████████████████████▉                                                                                                           | 164/1000 [12:07<1:02:47,  4.51s/it]

EarlyStopping counter: 19 out of 40


 16%|█████████████████████                                                                                                           | 165/1000 [12:12<1:02:18,  4.48s/it]

EarlyStopping counter: 20 out of 40


 17%|█████████████████████▏                                                                                                          | 166/1000 [12:16<1:01:55,  4.45s/it]

EarlyStopping counter: 21 out of 40


 17%|█████████████████████▍                                                                                                          | 167/1000 [12:20<1:01:26,  4.43s/it]

EarlyStopping counter: 22 out of 40


 17%|█████████████████████▌                                                                                                          | 168/1000 [12:25<1:01:26,  4.43s/it]

EarlyStopping counter: 23 out of 40


 17%|█████████████████████▋                                                                                                          | 169/1000 [12:29<1:01:23,  4.43s/it]

EarlyStopping counter: 24 out of 40


 17%|█████████████████████▊                                                                                                          | 170/1000 [12:34<1:01:18,  4.43s/it]

EarlyStopping counter: 25 out of 40


 17%|█████████████████████▉                                                                                                          | 171/1000 [12:38<1:01:25,  4.45s/it]

EarlyStopping counter: 26 out of 40


 17%|██████████████████████                                                                                                          | 172/1000 [12:43<1:02:40,  4.54s/it]

EarlyStopping counter: 27 out of 40


 17%|██████████████████████▏                                                                                                         | 173/1000 [12:47<1:02:11,  4.51s/it]

Validation loss decreased (0.003927 --> 0.003925).  Saving model ...


 17%|██████████████████████▎                                                                                                         | 174/1000 [12:52<1:01:44,  4.48s/it]

EarlyStopping counter: 1 out of 40


 18%|██████████████████████▍                                                                                                         | 175/1000 [12:56<1:01:27,  4.47s/it]

EarlyStopping counter: 2 out of 40


 18%|██████████████████████▌                                                                                                         | 176/1000 [13:01<1:02:06,  4.52s/it]

EarlyStopping counter: 3 out of 40


 18%|██████████████████████▋                                                                                                         | 177/1000 [13:05<1:01:20,  4.47s/it]

EarlyStopping counter: 4 out of 40


 18%|██████████████████████▊                                                                                                         | 178/1000 [13:10<1:01:01,  4.45s/it]

EarlyStopping counter: 5 out of 40


 18%|██████████████████████▉                                                                                                         | 179/1000 [13:14<1:01:00,  4.46s/it]

EarlyStopping counter: 6 out of 40


 18%|███████████████████████                                                                                                         | 180/1000 [13:19<1:01:08,  4.47s/it]

EarlyStopping counter: 7 out of 40


 18%|███████████████████████▏                                                                                                        | 181/1000 [13:23<1:01:06,  4.48s/it]

EarlyStopping counter: 8 out of 40


 18%|███████████████████████▎                                                                                                        | 182/1000 [13:28<1:00:50,  4.46s/it]

Validation loss decreased (0.003925 --> 0.003920).  Saving model ...


 18%|███████████████████████▍                                                                                                        | 183/1000 [13:32<1:00:42,  4.46s/it]

EarlyStopping counter: 1 out of 40


 18%|███████████████████████▌                                                                                                        | 184/1000 [13:36<1:00:03,  4.42s/it]

EarlyStopping counter: 2 out of 40


 18%|████████████████████████                                                                                                          | 185/1000 [13:41<59:49,  4.40s/it]

EarlyStopping counter: 3 out of 40


 19%|███████████████████████▊                                                                                                        | 186/1000 [13:45<1:00:14,  4.44s/it]

EarlyStopping counter: 4 out of 40


 19%|███████████████████████▉                                                                                                        | 187/1000 [13:50<1:00:21,  4.45s/it]

EarlyStopping counter: 5 out of 40


 19%|████████████████████████                                                                                                        | 188/1000 [13:54<1:00:08,  4.44s/it]

EarlyStopping counter: 6 out of 40


 19%|████████████████████████▌                                                                                                         | 189/1000 [13:58<59:26,  4.40s/it]

EarlyStopping counter: 7 out of 40


 19%|████████████████████████▋                                                                                                         | 190/1000 [14:03<59:35,  4.41s/it]

EarlyStopping counter: 8 out of 40


 19%|████████████████████████▊                                                                                                         | 191/1000 [14:07<59:29,  4.41s/it]

EarlyStopping counter: 9 out of 40


 19%|████████████████████████▉                                                                                                         | 192/1000 [14:12<59:28,  4.42s/it]

EarlyStopping counter: 10 out of 40


 19%|█████████████████████████                                                                                                         | 193/1000 [14:16<59:33,  4.43s/it]

EarlyStopping counter: 11 out of 40


 19%|█████████████████████████▏                                                                                                        | 194/1000 [14:21<59:30,  4.43s/it]

EarlyStopping counter: 12 out of 40


 20%|█████████████████████████▎                                                                                                        | 195/1000 [14:25<59:37,  4.44s/it]

EarlyStopping counter: 13 out of 40


 20%|█████████████████████████▍                                                                                                        | 196/1000 [14:30<59:40,  4.45s/it]

EarlyStopping counter: 14 out of 40


 20%|█████████████████████████▌                                                                                                        | 197/1000 [14:34<59:03,  4.41s/it]

EarlyStopping counter: 15 out of 40


 20%|█████████████████████████▋                                                                                                        | 198/1000 [14:38<58:47,  4.40s/it]

EarlyStopping counter: 16 out of 40


 20%|█████████████████████████▊                                                                                                        | 199/1000 [14:43<58:58,  4.42s/it]

EarlyStopping counter: 17 out of 40


 20%|██████████████████████████                                                                                                        | 200/1000 [14:47<58:50,  4.41s/it]

EarlyStopping counter: 18 out of 40


 20%|██████████████████████████▏                                                                                                       | 201/1000 [14:51<58:38,  4.40s/it]

EarlyStopping counter: 19 out of 40


 20%|██████████████████████████▎                                                                                                       | 202/1000 [14:56<58:34,  4.40s/it]

EarlyStopping counter: 20 out of 40


 20%|██████████████████████████▍                                                                                                       | 203/1000 [15:00<58:30,  4.41s/it]

EarlyStopping counter: 21 out of 40


 20%|██████████████████████████▌                                                                                                       | 204/1000 [15:05<58:26,  4.40s/it]

EarlyStopping counter: 22 out of 40


 20%|██████████████████████████▋                                                                                                       | 205/1000 [15:09<58:44,  4.43s/it]

EarlyStopping counter: 23 out of 40


 21%|██████████████████████████▊                                                                                                       | 206/1000 [15:14<58:20,  4.41s/it]

EarlyStopping counter: 24 out of 40


 21%|██████████████████████████▉                                                                                                       | 207/1000 [15:18<58:32,  4.43s/it]

EarlyStopping counter: 25 out of 40


 21%|███████████████████████████                                                                                                       | 208/1000 [15:22<58:24,  4.42s/it]

EarlyStopping counter: 26 out of 40


 21%|███████████████████████████▏                                                                                                      | 209/1000 [15:27<58:12,  4.42s/it]

EarlyStopping counter: 27 out of 40


 21%|███████████████████████████▎                                                                                                      | 210/1000 [15:31<57:59,  4.40s/it]

EarlyStopping counter: 28 out of 40


 21%|███████████████████████████▍                                                                                                      | 211/1000 [15:36<58:24,  4.44s/it]

EarlyStopping counter: 29 out of 40


 21%|███████████████████████████▌                                                                                                      | 212/1000 [15:40<58:11,  4.43s/it]

EarlyStopping counter: 30 out of 40


 21%|███████████████████████████▋                                                                                                      | 213/1000 [15:45<58:01,  4.42s/it]

EarlyStopping counter: 31 out of 40


 21%|███████████████████████████▊                                                                                                      | 214/1000 [15:49<57:49,  4.41s/it]

EarlyStopping counter: 32 out of 40


 22%|███████████████████████████▉                                                                                                      | 215/1000 [15:53<58:11,  4.45s/it]

EarlyStopping counter: 33 out of 40


 22%|████████████████████████████                                                                                                      | 216/1000 [15:58<58:00,  4.44s/it]

EarlyStopping counter: 34 out of 40


 22%|████████████████████████████▏                                                                                                     | 217/1000 [16:02<58:05,  4.45s/it]

EarlyStopping counter: 35 out of 40


 22%|████████████████████████████▎                                                                                                     | 218/1000 [16:07<58:21,  4.48s/it]

EarlyStopping counter: 36 out of 40


 22%|████████████████████████████▍                                                                                                     | 219/1000 [16:11<58:00,  4.46s/it]

EarlyStopping counter: 37 out of 40


 22%|████████████████████████████▌                                                                                                     | 220/1000 [16:16<57:36,  4.43s/it]

EarlyStopping counter: 38 out of 40


 22%|████████████████████████████▋                                                                                                     | 221/1000 [16:20<57:04,  4.40s/it]

EarlyStopping counter: 39 out of 40


 22%|████████████████████████████▋                                                                                                     | 221/1000 [16:24<57:51,  4.46s/it]

EarlyStopping counter: 40 out of 40
Early stopping





R2 of test is:  0.9375733529381579
Test set results for 11309 samples:
MSE: 0.0038909372
MAE: 0.03546797
MSE loss on test set is: 0.003890940064495336
time used to train model with 40/1000 patience is:  16.440725128368165 mins


## prediction on training or experiment data

In [14]:
def predict_on_test(config): 
    
    data_dir = 'datasets/train'
    exp_cscl_dir = 'datasets/exp/rutitle_cscl'
    exp_licl_dir = 'datasets/exp/rutitle_LiCl'

    # get training dataset 
    train_set, _, _ = load_data(exp_cscl_dir, config['addGlobal'])
    # train_set, _, _ = load_data(exp_licl_dir, config['addGlobal'])

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=int(config["batch_size"]),
        shuffle=True)
            
    print("Creating model")
    model = PointNetOneX(config["l1"], config["l2"], config["l3"], config["l4"], config["l5"], 
                  config["dropout"], config["isBN"], config['isLN'],config['constraint'], config['pooling'])
    print('created model is: ', model)
    
    name = f"{config['model']}{config['l1']}{config['l2']}{config['l3']}{config['l4']}{config['l5']}lr{config['lr']}BS{config['batch_size']}isB{config['isBN']}ln{config['isLN']}cons{config['constraint']}Opt{config['optimizer']}sch{config['lr_scheduler']}pool{config['pooling']}addG{config['addGlobal']}"
    name1 = f"cscl{config['model']}{config['l1']}{config['l2']}{config['l3']}{config['l4']}{config['l5']}lr{config['lr']}BS{config['batch_size']}isB{config['isBN']}ln{config['isLN']}cons{config['constraint']}Opt{config['optimizer']}sch{config['lr_scheduler']}pool{config['pooling']}addG{config['addGlobal']}"
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cuda")
    # if torch.cuda.device_count() > 1:
    #     model = nn.DataParallel(model)
    model.to(device)

    optimizer, lr_scheduler = build_optimizer(model, config['optimizer'], config['lr'])
    model.to(device)

    res_dir = 'complexDNN/res/'
    checkpoint_dir = os.path.join(res_dir, 'checkpoints')
    checkpoint_path = os.path.join(checkpoint_dir, f'{name}.pt')

    model.load_state_dict(torch.load(checkpoint_path))

    # test on test set 
    test_accuracy(model, train_loader, res_dir, name1, device)   

In [15]:
import time 
start_time = time.perf_counter()
predict_on_test(config)
end_time = time.perf_counter() 
print('time used to train model with 20/250 patience is: ', (end_time - start_time), 'secs')

Creating model
created model is:  PointNetOneX(
  (conv1): Conv1d(6, 512, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
  (conv3): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (ln1): LayerNorm((512, 56), eps=1e-05, elementwise_affine=True)
  (ln2): LayerNorm((512, 56), eps=1e-05, elementwise_affine=True)
  (ln3): LayerNorm((512, 56), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=5, bias=True)
  (dropout): Dropout(p=0, inplace=False)
  (bn4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

