# Optimization script using PyTorch and Ray Tune
## Import of libraries
Compared with the previous notebook, only Ray has been added. You need to run `pip install ray[tune]` in your environment for the next cell to work.

In [None]:
from __future__ import print_function
from functools import partial
import numpy as np
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
from models.binarized_modules import  BinarizeLinear,BinarizeConv2d
from models.binarized_modules import  Binarize,HingeLoss
import matplotlib.pyplot as plt
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

## Creating the dataloader with a shared directory and the network model

In [None]:
# first lets define again the function for binarizing the image 

class ThresholdTransform(object):
    """Binarize a tensor by comparing every element against a threshold.

    Elements greater than or equal to the threshold map to 1 and all other
    elements map to 0; the result is cast back to the dtype of the input.
    """

    def __init__(self, thr_255):
        # The threshold is compared directly against the tensor elements.
        self.thr = thr_255

    def __call__(self, x):
        mask = x >= self.thr
        return mask.to(x.dtype)
    
#declare the transform  with a shared dir 
def load_data(data_dir="./data"):
    """Return the MNIST training and test datasets with binarized images.

    Every image goes through ``ToTensor``, is normalized with mean 0.1307
    and std 0.3081, and is finally thresholded at 0 by ``ThresholdTransform``.
    ``download=True`` is passed, so torchvision may fetch the files into
    ``data_dir``.

    Args:
        data_dir: Root directory handed to ``datasets.MNIST``.

    Returns:
        A ``(train_data, test_data)`` tuple.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        ThresholdTransform(thr_255=0),
    ])
    # Both splits share the same directory and preprocessing pipeline;
    # the training split is built first, then the test split.
    train_data, test_data = (
        datasets.MNIST(data_dir, train=is_train, download=True, transform=pipeline)
        for is_train in (True, False)
    )
    return train_data, test_data

# I will now create a definition of my model as the previous one  
class MY_BNN(nn.Module):
    """Fully connected binarized network with three hidden layers.

    Each hidden stage is ``BinarizeLinear -> BatchNorm1d -> Hardtanh``; the
    third stage additionally applies dropout (p=0.5) between the linear layer
    and batch normalization. The output layer is followed by a log-softmax
    over the class dimension.

    Args:
        in_features: Number of input features per sample after flattening.
        neurons_l1: Width of the first hidden layer.
        neurons_l2: Width of the second hidden layer.
        neurons_l3: Width of the third hidden layer.
        out_features: Number of output classes.
    """

    def __init__(self, in_features=28*28, neurons_l1=100, neurons_l2=100,
                 neurons_l3=100, out_features=10):
        super(MY_BNN, self).__init__()
        # Remembered so forward() flattens to the configured input size
        # instead of a hard-coded 28*28.
        self.in_features = in_features
        self.fc1 = BinarizeLinear(in_features, neurons_l1, bias=False)
        self.htanh1 = nn.Hardtanh()
        self.bn1 = nn.BatchNorm1d(neurons_l1)
        self.fc2 = BinarizeLinear(neurons_l1, neurons_l2, bias=False)
        self.htanh2 = nn.Hardtanh()
        self.bn2 = nn.BatchNorm1d(neurons_l2)
        self.fc3 = BinarizeLinear(neurons_l2, neurons_l3, bias=False)
        self.htanh3 = nn.Hardtanh()
        self.bn3 = nn.BatchNorm1d(neurons_l3)
        self.fc4 = BinarizeLinear(neurons_l3, out_features, bias=False)
        self.drop = nn.Dropout(0.5)
        # dim=1 is the class dimension of the (batch, classes) output; an
        # explicit dim avoids the deprecated implicit-dimension behaviour.
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of inputs."""
        # Flatten every sample to a vector of in_features values.
        x = x.view(-1, self.in_features)
        x = self.htanh1(self.bn1(self.fc1(x)))
        x = self.htanh2(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        x = self.drop(x)
        x = self.htanh3(self.bn3(x))
        x = self.fc4(x)
        return self.logsoftmax(x)




## Define a training function 

In [None]:
def train_bnn(config, checkpoint_dir=None, data_dir=None):
    """Train ``MY_BNN`` for one Ray Tune trial and report validation metrics.

    The MNIST training set is split 80/20 into training and validation
    subsets. After every epoch a checkpoint holding the model and optimizer
    state is written and the mean validation loss and accuracy are reported
    to Tune.

    Args:
        config: Trial hyperparameters with the keys ``"neurons_l1"``,
            ``"neurons_l2"``, ``"neurons_l3"``, ``"lr"`` and ``"batch_size"``.
        checkpoint_dir: Optional directory containing a ``"checkpoint"`` file
            (a ``(model_state, optimizer_state)`` tuple) to resume from.
        data_dir: Directory passed to ``load_data``; when ``None`` the
            default directory of ``load_data`` is used.
    """
    # Keyword arguments are required here: the first positional parameter of
    # MY_BNN is in_features, not neurons_l1.
    net = MY_BNN(
        neurons_l1=config["neurons_l1"],
        neurons_l2=config["neurons_l2"],
        neurons_l3=config["neurons_l3"])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    if data_dir is None:
        trainset, _ = load_data()
    else:
        trainset, _ = load_data(data_dir)

    # 80 % of the training set is used for fitting, the rest for validation.
    train_size = int(len(trainset) * 0.8)
    train_subset, val_subset = torch.utils.data.random_split(
        trainset, [train_size, len(trainset) - train_size])

    batch_size = int(config["batch_size"])
    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=True)
    # Order is irrelevant for validation metrics, so no shuffling is needed.
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=batch_size, shuffle=False)

    for epoch in range(10):  # loop over the dataset multiple times
        net.train()  # enable dropout and batch-statistics BatchNorm
        running_loss = 0.0
        running_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / running_steps))
                # Reset both accumulators so the next printed value is the
                # mean over the following window only.
                running_loss = 0.0
                running_steps = 0

        # Validation pass: eval mode disables dropout and makes BatchNorm
        # use its running statistics.
        net.eval()
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                outputs = net(inputs)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_loss += criterion(outputs, labels).item()
                val_steps += 1

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")

## Define a test function

In [None]:
def test_accuracy(net, device="cpu"):
    """Return the fraction of MNIST test images that ``net`` classifies correctly.

    The model is moved to ``device`` and switched to evaluation mode for the
    measurement, so dropout is disabled and BatchNorm uses its running
    statistics; its previous training/eval mode is restored afterwards.

    Args:
        net: Model mapping a batch of images to per-class scores.
        device: Device the model and every batch are moved to.

    Returns:
        Accuracy on the test set as a float in ``[0, 1]``.
    """
    _, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False)

    # Inputs are moved to ``device`` below, so the parameters must live there too.
    net.to(device)

    correct = 0
    total = 0
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Leave the caller's model in the mode it arrived in.
        net.train(was_training)

    return correct / total

## Define the configuration for the search space

In [None]:
def _sample_layer_width(_spec):
    """Draw a hidden-layer width: a power of two between 2**2 and 2**8."""
    return 2**np.random.randint(2, 9)


# Search space handed to Tune: three hidden-layer widths, the SGD learning
# rate (log-uniform) and the mini-batch size.
config = {
    "neurons_l1": tune.sample_from(_sample_layer_width),
    "neurons_l2": tune.sample_from(_sample_layer_width),
    "neurons_l3": tune.sample_from(_sample_layer_width),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
}

In [None]:
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=0):
    """Run the hyperparameter search and evaluate the best model on the test set.

    Args:
        num_samples: Number of configurations sampled from ``config``.
        max_num_epochs: Upper bound ``max_t`` given to the ASHA scheduler.
        gpus_per_trial: GPUs requested from Ray for every trial. Note that
            ``train_bnn`` does not move the model to a GPU itself; the value
            also decides whether the best model is wrapped in
            ``nn.DataParallel`` for the final evaluation.
    """
    data_dir = os.path.abspath("./data")
    # Download the dataset once up front so the trials find it on disk.
    load_data(data_dir)

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "training_iteration"])
    result = tune.run(
        partial(train_bnn, data_dir=data_dir),
        # One CPU per trial matches Tune's default resource request.
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter)

    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))

    # Rebuild the network with the winning layer widths; the keys are the
    # ones defined in the search space.
    best_config = best_trial.config
    best_trained_model = MY_BNN(
        neurons_l1=best_config["neurons_l1"],
        neurons_l2=best_config["neurons_l2"],
        neurons_l3=best_config["neurons_l3"])

    # Load the weights before any DataParallel wrapping so the state-dict
    # keys match the plain module.
    best_checkpoint_dir = best_trial.checkpoint.value
    model_state, _ = torch.load(os.path.join(
        best_checkpoint_dir, "checkpoint"))
    best_trained_model.load_state_dict(model_state)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    test_acc = test_accuracy(best_trained_model, device)
    print("Best trial test set accuracy: {}".format(test_acc))


if __name__ == "__main__":
    # Run the search with 10 sampled configurations; max_num_epochs is the
    # upper bound handed to the scheduler inside main().
    main(num_samples=10, max_num_epochs=10)