In [1]:
import argparse
import os
import shutil
import sys
import time
import warnings
from random import sample

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR

# from cgcnn.data import CIFData
# from cgcnn.data import collate_pool, get_train_val_test_loader
from cgcnn.featureModel import CrystalGraphConvNet
from property_prediction_ofm.model import Net

from dataloader import CIFOFMData
from dataloader import collate_pool, get_train_val_test_loader

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
# Model
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class MixtureNet(nn.Module):
    def __init__(self, input_size=64):
        super(MixtureNet, self).__init__()
        
        self.fc1 = nn.Linear(input_size, 48)
        self.fc2 = nn.Linear(48, 32)
        self.fc3 = nn.Linear(32, 1)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
mixnet = MixtureNet()
mixnet = mixnet.float()
mixnet.to(device)

MixtureNet(
  (fc1): Linear(in_features=64, out_features=48, bias=True)
  (fc2): Linear(in_features=48, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=1, bias=True)
)

In [17]:
data_options = '../project_mlns/dataset/'
modelpathCG = 'trained_nets/formation-energy-per-atom.pth.tar'
modelpathOFM = 'trained_nets/propertyPredictionUsingOFM_net.pth'

workers = 0
epochs = 30
start_epoch = 0
batch_size = 32
lr = 0.01
lr_milestones = [100]
disable_cuda = False

momentum = 0.9
weight_decay = 0
print_freq = 10

val_ratio = 0.1
test_ratio = 0.1
optim_type ='SGD'

cuda = not disable_cuda and torch.cuda.is_available()

best_mae_error = 1e10

In [18]:
# load data
dataset = CIFOFMData(data_options)
collate_fn = collate_pool
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset=dataset,
    collate_fn=collate_fn,
    pin_memory=cuda,
    batch_size=batch_size,
    num_workers=workers,
    train_size = None,
    test_size = None,
    val_size = None,
    val_ratio=val_ratio,
    test_ratio=test_ratio,
    return_test=True, shuffle=True)



In [19]:
class Normalizer(object):
    """Normalize a Tensor and restore it later. """

    def __init__(self, tensor):
        """tensor is taken as a sample to calculate the mean and std"""
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {'mean': self.mean,
                'std': self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict['mean']
        self.std = state_dict['std']

In [21]:
if len(dataset) < 500:
    warnings.warn('Dataset has less than 500 data points. '
                  'Lower accuracy is expected. ')
    sample_data_list = [dataset[i] for i in range(len(dataset))]
else:
    sample_data_list = [dataset[i] for i in
                        sample(range(len(dataset)), 500)]
_, sample_target, _, _ = collate_pool(sample_data_list)
normalizer = Normalizer(sample_target)

# build models
structures, _, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

model_checkpoint = torch.load(modelpathCG,
                              map_location=lambda storage, loc: storage)
model_args = argparse.Namespace(**model_checkpoint['args'])
modelCG = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                            atom_fea_len=model_args.atom_fea_len,
                            n_conv=model_args.n_conv,
                            h_fea_len=model_args.h_fea_len,
                            n_h=model_args.n_h,
                            classification=False)

modelOFM = Net()
modelOFM.load_state_dict(torch.load(modelpathOFM))

if cuda:
    modelOFM.cuda()
    modelCG.cuda()

In [22]:
# define loss func and optimizer
criterion = nn.MSELoss()
if optim_type == 'SGD':
    optimizer = optim.SGD(mixnet.parameters(), lr,
                          momentum=momentum,
                          weight_decay=weight_decay)
elif optim_type == 'Adam':
    optimizer = optim.Adam(mixnet.parameters(), lr,
                           weight_decay=weight_decay)
else:
    raise NameError('Only SGD or Adam is allowed as --optim')

In [23]:
# resume from a checkpoint
checkpointCG = torch.load(modelpathCG)
modelCG.load_state_dict(checkpointCG['state_dict'])
normalizer.load_state_dict(checkpointCG['normalizer'])

checkpointOFM = torch.load(modelpathOFM)
modelOFM.load_state_dict(checkpointOFM)

<All keys matched successfully>

In [24]:
def mae(prediction, target):
    """
    Computes the mean absolute error between prediction and target

    Parameters
    ----------

    prediction: torch.Tensor (N, 1)
    target: torch.Tensor (N, 1)
    """
    return torch.mean(torch.abs(target - prediction))


def class_eval(prediction, target):
    prediction = np.exp(prediction.numpy())
    target = target.numpy()
    pred_label = np.argmax(prediction, axis=1)
    target_label = np.squeeze(target)
    if not target_label.shape:
        target_label = np.asarray([target_label])
    if prediction.shape[1] == 2:
        precision, recall, fscore, _ = metrics.precision_recall_fscore_support(
            target_label, pred_label, average='binary')
        auc_score = metrics.roc_auc_score(target_label, prediction[:, 1])
        accuracy = metrics.accuracy_score(target_label, pred_label)
    else:
        raise NotImplementedError
    return accuracy, precision, recall, fscore, auc_score


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch, k):
    """Sets the learning rate to the initial LR decayed by 10 every k epochs"""
    assert type(k) is int
    lr = args.lr * (0.1 ** (epoch // k))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [25]:
def train(train_loader, mixnet, modelCG, modelOFM, criterion, optimizer, epoch, normalizer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    mae_errors = AverageMeter()

    # switch to train mode
    mixnet.train()

    end = time.time()
    for i, (input, target, _, ofmMat) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if cuda:
            input_var = (Variable(input[0].cuda(non_blocking=True)),
                         Variable(input[1].cuda(non_blocking=True)),
                         input[2].cuda(non_blocking=True),
                         [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]])
            inOfmMat = ofmMat.float().cuda(non_blocking=True)
        else:
            input_var = (Variable(input[0]),
                         Variable(input[1]),
                         input[2],
                         input[3])
            inOfmMat = ofmMat.float()
            
        # normalize target
        target_normed = normalizer.norm(target)
        
        if cuda:
            target_var = Variable(target_normed.cuda(non_blocking=True))
        else:
            target_var = Variable(target_normed)

        # compute feature from CG and OFM models
        featureCG = modelCG(*input_var)
        featureOFM = modelOFM(inOfmMat)
        # print(featureCG.size(), featureOFM.size())   
        # final feature after concatenation of features from CG and OFM models
        feature = torch.cat((featureCG, featureOFM), 1)
        # print(feature.size())
        output = mixnet(feature)
        
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        mae_error = mae(normalizer.denorm(output.data.cpu()), target)
        losses.update(loss.data.cpu(), target.size(0))
        mae_errors.update(mae_error, target.size(0))
       

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, mae_errors=mae_errors)
            )


In [26]:
scheduler = MultiStepLR(optimizer, milestones=lr_milestones,
                        gamma=0.1)

for epoch in range(start_epoch, epochs):
    # train for one epoch
    train(train_loader, mixnet, modelCG, modelOFM, criterion, optimizer, epoch, normalizer)

    scheduler.step()

Epoch: [0][0/81]	Time 0.543 (0.543)	Data 0.520 (0.520)	Loss 61.2409 (61.2409)	MAE 5.960 (5.960)
Epoch: [0][10/81]	Time 0.558 (0.603)	Data 0.513 (0.575)	Loss 42.9800 (176.1419)	MAE 5.842 (8.076)
Epoch: [0][20/81]	Time 0.440 (0.640)	Data 0.414 (0.608)	Loss 77.0584 (203.3968)	MAE 8.430 (8.841)
Epoch: [0][30/81]	Time 0.735 (0.618)	Data 0.710 (0.587)	Loss 139.5190 (181.5776)	MAE 8.037 (8.578)
Epoch: [0][40/81]	Time 0.654 (0.627)	Data 0.626 (0.597)	Loss 310.3209 (180.6172)	MAE 10.625 (8.446)
Epoch: [0][50/81]	Time 0.509 (0.626)	Data 0.482 (0.597)	Loss 82.6150 (189.8645)	MAE 7.905 (8.439)
Epoch: [0][60/81]	Time 0.424 (0.619)	Data 0.399 (0.590)	Loss 88.0301 (196.7016)	MAE 7.729 (8.512)
Epoch: [0][70/81]	Time 0.547 (0.619)	Data 0.523 (0.590)	Loss 254.7070 (211.8556)	MAE 11.607 (8.611)
Epoch: [0][80/81]	Time 0.110 (0.613)	Data 0.082 (0.585)	Loss 18.6608 (208.8024)	MAE 3.965 (8.573)
Epoch: [1][0/81]	Time 0.041 (0.041)	Data 0.009 (0.009)	Loss 1078.2468 (1078.2468)	MAE 17.152 (17.152)
Epoch: [1][10

Epoch: [9][30/81]	Time 0.024 (0.028)	Data 0.002 (0.003)	Loss 450.8015 (201.3913)	MAE 10.004 (8.251)
Epoch: [9][40/81]	Time 0.029 (0.028)	Data 0.004 (0.003)	Loss 73.9308 (200.6547)	MAE 6.745 (8.236)
Epoch: [9][50/81]	Time 0.029 (0.028)	Data 0.004 (0.003)	Loss 168.8675 (196.9488)	MAE 7.718 (8.334)
Epoch: [9][60/81]	Time 0.028 (0.028)	Data 0.004 (0.003)	Loss 45.4854 (217.1442)	MAE 6.119 (8.570)
Epoch: [9][70/81]	Time 0.029 (0.028)	Data 0.004 (0.003)	Loss 69.1759 (215.2389)	MAE 6.701 (8.581)
Epoch: [9][80/81]	Time 0.015 (0.028)	Data 0.001 (0.003)	Loss 14.6570 (208.9617)	MAE 2.951 (8.572)
Epoch: [10][0/81]	Time 0.032 (0.032)	Data 0.007 (0.007)	Loss 61.1069 (61.1069)	MAE 7.207 (7.207)
Epoch: [10][10/81]	Time 0.028 (0.028)	Data 0.004 (0.004)	Loss 975.6170 (272.8632)	MAE 11.226 (8.098)
Epoch: [10][20/81]	Time 0.027 (0.028)	Data 0.003 (0.003)	Loss 99.6605 (243.1722)	MAE 8.033 (8.383)
Epoch: [10][30/81]	Time 0.027 (0.028)	Data 0.004 (0.003)	Loss 123.0566 (240.4170)	MAE 8.779 (8.710)
Epoch: [10][

Epoch: [18][50/81]	Time 0.046 (0.032)	Data 0.007 (0.004)	Loss 77.5545 (181.7351)	MAE 7.183 (7.850)
Epoch: [18][60/81]	Time 0.046 (0.034)	Data 0.007 (0.004)	Loss 370.9203 (194.4050)	MAE 10.205 (8.089)
Epoch: [18][70/81]	Time 0.044 (0.036)	Data 0.006 (0.004)	Loss 180.5419 (202.3884)	MAE 10.103 (8.348)
Epoch: [18][80/81]	Time 0.021 (0.037)	Data 0.001 (0.005)	Loss 45.7001 (209.0285)	MAE 6.865 (8.512)
Epoch: [19][0/81]	Time 0.043 (0.043)	Data 0.004 (0.004)	Loss 201.5094 (201.5094)	MAE 10.990 (10.990)
Epoch: [19][10/81]	Time 0.043 (0.044)	Data 0.005 (0.006)	Loss 82.0765 (220.7341)	MAE 6.813 (8.951)
Epoch: [19][20/81]	Time 0.028 (0.041)	Data 0.003 (0.005)	Loss 130.5147 (169.7037)	MAE 8.739 (8.381)
Epoch: [19][30/81]	Time 0.028 (0.037)	Data 0.003 (0.005)	Loss 68.5099 (178.9761)	MAE 5.989 (8.157)
Epoch: [19][40/81]	Time 0.028 (0.035)	Data 0.003 (0.004)	Loss 795.3002 (241.6773)	MAE 13.948 (8.678)
Epoch: [19][50/81]	Time 0.028 (0.034)	Data 0.004 (0.004)	Loss 39.9080 (217.1654)	MAE 5.843 (8.626)
E

Epoch: [27][70/81]	Time 0.028 (0.028)	Data 0.005 (0.003)	Loss 105.5180 (200.1400)	MAE 8.583 (8.265)
Epoch: [27][80/81]	Time 0.012 (0.028)	Data 0.001 (0.003)	Loss 80.2305 (209.0289)	MAE 8.768 (8.451)
Epoch: [28][0/81]	Time 0.028 (0.028)	Data 0.003 (0.003)	Loss 45.1594 (45.1594)	MAE 6.211 (6.211)
Epoch: [28][10/81]	Time 0.027 (0.028)	Data 0.003 (0.003)	Loss 64.2323 (202.4523)	MAE 7.341 (8.424)
Epoch: [28][20/81]	Time 0.028 (0.028)	Data 0.004 (0.003)	Loss 45.6930 (165.2651)	MAE 5.133 (7.828)
Epoch: [28][30/81]	Time 0.029 (0.028)	Data 0.004 (0.003)	Loss 59.4269 (214.3228)	MAE 5.373 (8.185)
Epoch: [28][40/81]	Time 0.028 (0.028)	Data 0.004 (0.003)	Loss 64.4654 (222.7620)	MAE 7.330 (8.508)
Epoch: [28][50/81]	Time 0.028 (0.028)	Data 0.004 (0.003)	Loss 68.5428 (223.6215)	MAE 7.604 (8.675)
Epoch: [28][60/81]	Time 0.026 (0.028)	Data 0.002 (0.003)	Loss 42.6693 (222.1105)	MAE 5.877 (8.717)
Epoch: [28][70/81]	Time 0.026 (0.028)	Data 0.002 (0.003)	Loss 113.0743 (211.2635)	MAE 8.278 (8.631)
Epoch: [28