In [1]:
import copy
import shutil
import time

import torch
from torch.autograd import Variable

In [None]:
class Normalizer(object):
    """Standardise tensors with fixed statistics and map them back again."""

    def __init__(self, tensor):
        """Record the mean and standard deviation of the sample ``tensor``."""
        sample_mean = torch.mean(tensor)
        sample_std = torch.std(tensor)
        self.mean = sample_mean
        self.std = sample_std

    def norm(self, tensor):
        """Return ``tensor`` shifted by the stored mean and divided by the stored std."""
        centred = tensor - self.mean
        return centred / self.std

    def denorm(self, normed_tensor):
        """Invert ``norm``: multiply by the stored std, then add the stored mean."""
        rescaled = normed_tensor * self.std
        return rescaled + self.mean

    def state_dict(self):
        """Return the stored statistics as a dict with keys ``'mean'`` and ``'std'``."""
        return dict(mean=self.mean, std=self.std)

    def load_state_dict(self, state_dict):
        """Overwrite the stored statistics from ``state_dict['mean']`` and ``state_dict['std']``."""
        # Same assignment order as the keys are listed: mean first, then std.
        for key in ('mean', 'std'):
            setattr(self, key, state_dict[key])

In [None]:
class AverageMeter(object):
    """Keeps the most recent value together with a running weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, running average, weighted sum and count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n``, and refresh the average."""
        self.val = val
        contribution = val * n
        self.sum += contribution
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def mae(prediction, target):
    """
    Mean absolute error between ``prediction`` and ``target``.

    Parameters
    ----------
    prediction: torch.Tensor (N, 1)
    target: torch.Tensor (N, 1)

    Returns
    -------
    torch.Tensor
        The mean of ``|target - prediction|`` over all elements.
    """
    residual = target - prediction
    absolute_residual = torch.abs(residual)
    return torch.mean(absolute_residual)

In [None]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):
    """Serialize ``state`` to disk and optionally keep a copy as the best model.

    Parameters
    ----------
    state: object
        Anything ``torch.save`` can serialize.
    is_best: bool
        When true, the file just written is also copied to ``best_filename``.
    filename: str
        Path the checkpoint is written to.
    best_filename: str
        Destination of the "best so far" copy. Defaults to the previously
        hard-coded ``'model_best.pth.tar'`` so existing callers are unaffected.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, cuda = True):
    """Run one full training pass over ``train_loader``.

    Parameters
    ----------
    train_loader: iterable
        Yields ``(input, target, _)`` triples. ``input[0]``, ``input[1]`` and
        ``input[2]`` are passed to the model as-is, and ``input[3]`` is an
        iterable that is passed as a list. When ``cuda`` is true each of these
        must support ``.cuda(non_blocking=True)``.
    model: callable
        Invoked as ``model(*input_var)``; must provide ``train()``.
    criterion: callable
        Invoked as ``criterion(output, target)``; the result must support
        ``backward()`` and ``item()``.
    optimizer: object
        Must provide ``zero_grad()`` and ``step()``.
    epoch: int
        Not used by this function; kept so existing callers keep working.
    cuda: bool
        When true, inputs and targets are moved to the GPU before the
        forward pass.

    Returns
    -------
    tuple
        ``(average loss, average MAE)`` over all batches, each weighted by
        ``target.size(0)``. Both are ``0`` when the loader yields no batch.
    """
    # Timing meters are updated each batch but not reported by this function.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    mae_errors = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # Targets are used as provided; no normalisation is applied here.
        if cuda:
            # `non_blocking` replaces the former `async` keyword argument,
            # which is a reserved word (SyntaxError) on Python >= 3.7.
            input_var = (input[0].cuda(non_blocking=True),
                         input[1].cuda(non_blocking=True),
                         input[2].cuda(non_blocking=True),
                         [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]])
            target_var = target.cuda(non_blocking=True)
        else:
            input_var = (input[0],
                         input[1],
                         input[2],
                         [crys_idx for crys_idx in input[3]])
            target_var = target

        # compute output
        output = model(*input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss (MAE is computed against the
        # original `target`, so the output is brought back to the CPU first)
        mae_error = mae(output.data.cpu(), target)
        losses.update(loss.data.cpu().item(), target.size(0))
        mae_errors.update(mae_error, target.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    # Return only after every batch has been processed; returning from inside
    # the loop would stop the epoch after its first batch.
    return losses.avg, mae_errors.avg

In [1]:
def validate(val_loader, model, criterion, test=False, cuda = True, normalizer=None):
    """Evaluate ``model`` on every batch of ``val_loader`` without updating it.

    Parameters
    ----------
    val_loader: iterable
        Yields ``(input, target, batch_cif_ids)`` triples. ``input[0]``,
        ``input[1]`` and ``input[2]`` are passed to the model as-is, and
        ``input[3]`` is an iterable that is passed as a list. When ``cuda`` is
        true each of these must support ``.cuda(non_blocking=True)``.
    model: callable
        Invoked as ``model(*input_var)``; must provide ``eval()``.
    criterion: callable
        Invoked as ``criterion(output, target)``; the result must support
        ``item()``.
    test: bool
        When true, flattened predictions, targets and ``batch_cif_ids`` are
        additionally accumulated per batch.
    cuda: bool
        When true, inputs and targets are moved to the GPU before the
        forward pass.
    normalizer: object, optional
        Object with a ``denorm`` method. When given (and ``test`` is true) the
        collected predictions are passed through ``normalizer.denorm``; when
        ``None`` they are collected unchanged. This replaces a reference to a
        global ``normalizer`` that is not defined in this function's scope.

    Returns
    -------
    tuple
        ``(average loss, average MAE)`` over all batches, each weighted by
        ``target.size(0)``. Both are ``0`` when the loader yields no batch.
    """
    # The timing meter is updated each batch but not reported by this function.
    batch_time = AverageMeter()
    losses = AverageMeter()
    mae_errors = AverageMeter()
    if test:
        test_targets = []
        test_preds = []
        test_cif_ids = []

    # switch to evaluate mode
    model.eval()

    end = time.time()
    # No parameter update happens here, so gradient tracking is not needed.
    with torch.no_grad():
        for i, (input, target, batch_cif_ids) in enumerate(val_loader):
            # Targets are used as provided; no normalisation is applied here.
            if cuda:
                # `non_blocking` replaces the former `async` keyword argument,
                # which is a reserved word (SyntaxError) on Python >= 3.7.
                input_var = (input[0].cuda(non_blocking=True),
                             input[1].cuda(non_blocking=True),
                             input[2].cuda(non_blocking=True),
                             [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]])
                target_var = target.cuda(non_blocking=True)
            else:
                input_var = (input[0],
                             input[1],
                             input[2],
                             [crys_idx for crys_idx in input[3]])
                target_var = target

            # compute output
            output = model(*input_var)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prediction = output.data.cpu()
            mae_error = mae(prediction, target)
            losses.update(loss.data.cpu().item(), target.size(0))
            mae_errors.update(mae_error, target.size(0))
            if test:
                if normalizer is None:
                    test_pred = prediction
                else:
                    test_pred = normalizer.denorm(prediction)
                test_preds += test_pred.view(-1).tolist()
                test_targets += target.view(-1).tolist()
                test_cif_ids += batch_cif_ids

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    # NOTE(review): test_preds / test_targets / test_cif_ids are collected but
    # not returned or written anywhere; the return value is left unchanged so
    # existing callers that unpack two values keep working.
    return losses.avg, mae_errors.avg

In [None]:
def perform_graph_training(model, train_loader, val_loader, n_epochs = 50, lr=0.01, cuda = True):
    """Train ``model`` for up to ``n_epochs`` and return the best snapshot.

    Each epoch calls ``train`` on ``train_loader`` and ``validate`` on
    ``val_loader``, using an MSE loss and SGD with momentum 0.9 and no weight
    decay. Whenever the validation MAE improves, a deep copy of the model is
    kept. After every epoch the cell output is cleared and the train and
    validation MAE curves are re-plotted.

    Parameters
    ----------
    model: torch.nn.Module
        Model to optimise; it is updated in place by training.
    train_loader, val_loader: iterable
        Passed unchanged to ``train`` and ``validate`` respectively.
    n_epochs: int
        Maximum number of epochs.
    lr: float
        Learning rate given to the SGD optimizer.
    cuda: bool
        Forwarded to ``train`` and ``validate``.

    Returns
    -------
    The deep-copied model from the epoch with the lowest validation MAE, or
    ``None`` if no epoch produced a validation MAE below the initial 1e10
    (for example when ``n_epochs`` is 0 or the first epoch yields NaN).
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=0)

    best_mae_error = 1e10
    best_mae_epoch = 0
    train_mae = []
    val_mae = []
    best_model = None

    for epoch in range(n_epochs):
        # train for one epoch
        _, train_error = train(train_loader, model, criterion, optimizer, epoch, cuda)
        train_mae.append(train_error)

        # evaluate on validation set
        _, val_error = validate(val_loader, model, criterion, False, cuda)
        val_mae.append(val_error)

        # NaN is the only value that compares unequal to itself.
        if val_error != val_error:
            print('Exit due to NaN')
            # Stop training but keep the interpreter/kernel alive and hand
            # back the best model found so far instead of calling quit().
            break

        if val_error < best_mae_error:
            best_mae_error = val_error
            best_mae_epoch = epoch
            best_model = copy.deepcopy(model)

        # visualize
        display.clear_output(wait=True)
        plt.plot(np.arange(len(train_mae)) + 1, train_mae, label = 'MAE on train dataset')
        plt.plot(np.arange(len(val_mae)) + 1, val_mae, label = 'MAE on val dataset')
        print("MAE on train dataset now: ", float(train_mae[-1]))
        print("MAE on val dataset now: ", float(val_mae[-1]))
        print("best MAE on val dataset up to now: ", float(best_mae_error))
        print("best MAE on val dataset was in epoch number: ", int(best_mae_epoch))
        plt.xlabel('epoch')
        plt.ylabel('MAE')
        plt.legend()
        plt.show()
    return best_model