In [None]:
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.bayesopt import BayesOptSearch

import os

def train_cifar(config, checkpoint_dir=None):
    """Ray Tune trainable: train a Siamese GCN and report validation metrics.

    Builds a ``SiameseNetwork_GCN`` from ``config``, trains it with
    ``ContrastiveLoss`` and Adam for 100 epochs, and after every epoch
    saves a checkpoint and reports validation loss and accuracy to Tune.

    The data is built from the notebook-level ``X_train``, ``y_train``,
    ``X_test`` and ``y_test`` through ``train_sampler`` / ``test_sampler``.

    Args:
        config: Mapping of hyperparameters. The keys read here are
            ``"l1"`` (passed as ``indim`` and as ``dim`` to the samplers),
            ``"l2"`` (``hiddendim``), ``"outdim"``, ``"dropoutx"`` and
            ``"lr"`` (Adam learning rate).
        checkpoint_dir: Directory supplied by Ray Tune when a trial should
            be restored; if set, the model and optimizer states are loaded
            from the ``"checkpoint"`` file inside it.
    """
    net = SiameseNetwork_GCN(
        indim=config["l1"], hiddendim=config["l2"], outdim=config['outdim'], dropoutx=config['dropoutx'])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=config["lr"])

    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint
    # should be restored.
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(checkpoint)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    validation_samples, traindata_list = train_sampler(X_train, y_train, dim=config["l1"])
    trainloader = DataLoader(traindata_list, batch_size=1, follow_batch=[
                      'x1', 'x2'], shuffle=False)

    testdata_list = test_sampler(X_test, y_test, validation_samples, dim=config["l1"])
    # The validation loop below calls `.item()` on the pairwise distance, so
    # it relies on exactly one pair per batch.
    valloader = DataLoader(testdata_list, batch_size=1,
                         follow_batch=['x1', 'x2'], shuffle=False)

    for epoch in range(100):  # loop over the dataset multiple times
        # Validation switches the network to eval mode, so re-enable
        # training behaviour (e.g. dropout) at the start of every epoch.
        net.train()
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            x1, x2 = data.x1.to(device), data.x2.to(device)
            x1_index, x2_index = data.x1_index.to(
                device), data.x2_index.to(device)
            x1_batch, x2_batch = data.x1_batch.to(
                device), data.x2_batch.to(device)
            y = data.y.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            output1, output2 = net(x1, x1_index, x1_batch, x2, x2_index, x2_batch)
            loss = criterion(output1, output2, y)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                # Reset both accumulators together so the next printed value
                # is the mean over the following window, not a sum over one
                # window divided by the step count of the whole epoch.
                running_loss = 0.0
                epoch_steps = 0

        # Validation loss and accuracy, computed with dropout disabled so the
        # metrics reported to Tune are deterministic for a given model state.
        net.eval()
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for data in valloader:
                x1, x2 = data.x1.to(device), data.x2.to(device)
                x1_index, x2_index = data.x1_index.to(device), data.x2_index.to(device)
                x1_batch, x2_batch = data.x1_batch.to(
                    device), data.x2_batch.to(device)
                # `data.y` unpacks into three values here; only the first is
                # used as the label.
                # NOTE(review): training passes the whole `data.y` to the
                # criterion while validation passes only `sim` -- confirm
                # that the two samplers really produce different label layouts.
                sim, _1, _2 = data.y.cpu().detach().numpy()

                output1, output2 = net(x1, x1_index, x1_batch, x2, x2_index, x2_batch)
                euclidean_distance = F.pairwise_distance(output1, output2).item()

                # Distances below 1 are predicted as label 0, otherwise 1.
                # NOTE(review): presumably 1 mirrors the contrastive margin
                # and label 0 means "similar" -- verify against ContrastiveLoss.
                predicted = 0 if euclidean_distance < 1 else 1

                total += 1
                correct += (torch.tensor(predicted) == torch.tensor(sim)).sum().item()

                loss = criterion(output1, output2, sim)
                # `.item()` keeps the running total a plain Python float.
                val_loss += loss.item()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and will potentially be passed as the `checkpoint_dir`
        # parameter in future iterations.
        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save(
                (net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")


In [None]:
def main(num_samples, max_num_epochs, gpus_per_trial):
    """Run a Ray Tune search over ``train_cifar`` and print the best trial.

    Args:
        num_samples: Number of trials, forwarded to ``tune.run``.
        max_num_epochs: Passed to ``ASHAScheduler`` as ``max_t``.
        gpus_per_trial: GPUs requested per trial (each trial also requests
            2 CPUs).
    """
    config = {
            "l1": 16,#tune.sample_from(lambda _: 2 ** np.random.randint(2, 6)),
            "l2": 256,#tune.sample_from(lambda _: 2 ** np.random.randint(6, 9)),
            "outdim": 3,#tune.sample_from(lambda _: np.random.randint(1, 10)),
            "lr": tune.choice([0.0001]),
            "dropoutx": tune.choice([0.5, 0.6, 0.7, 0.8, 0.9]),
            "batch_size": tune.choice([1]),
            #"epoch": tune.choice([20,30,40,50])
        }

    # One metric/mode pair drives both the scheduler and the final selection,
    # so the trial printed as "best" is best by the same criterion the search
    # optimised for.
    metric, mode = "accuracy", "max"

    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)

    result = tune.run(
        tune.with_parameters(train_cifar),
        resources_per_trial={
            "cpu": 2,
            "gpu": gpus_per_trial},
        config=config,
        metric=metric,
        mode=mode,
        num_samples=num_samples,
        scheduler=scheduler
    )

    best_trial = result.get_best_trial(metric, mode, "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))

#     if ray.util.client.ray.is_connected():
#         # If using Ray Client, we want to make sure checkpoint access
#         # happens on the server. So we wrap `test_best_model` in a Ray task.
#         # We have to make sure it gets executed on the same node that
#         # ``tune.run`` is called on.
#         from ray.tune.utils.util import force_on_current_node
#         remote_fn = force_on_current_node(ray.remote(test_best_model))
#         ray.get(remote_fn.remote(best_trial))
    # else:
    #     test_best_model(best_trial)


In [None]:
# Launch the search. Request a GPU per trial only when CUDA is available:
# asking for a GPU on a CPU-only host would leave Ray unable to place any
# trial, whereas train_cifar itself falls back to the CPU.
main(
    num_samples=10,
    max_num_epochs=100,
    gpus_per_trial=1 if torch.cuda.is_available() else 0,
)