In [1]:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, TensorDataset
from torch.autograd import Variable
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from resnet import ResNet18

In [2]:
exp_root = "/data/user/ladilova/fed_colearning/FederatedLearningViaCoTraining/cifar10"

In [3]:
def load_data():
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='/data/user/ladilova/fed_colearning/FederatedLearningViaCoTraining/cifar10/data', train=True, download=True, transform=transform_train)

    testset = torchvision.datasets.CIFAR10(
        root='/data/user/ladilova/fed_colearning/FederatedLearningViaCoTraining/cifar10/data', train=False, download=True, transform=transform_test)

    return trainset, testset

In [4]:
trainset, _ = load_data()
trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=8,
        shuffle=True,
        num_workers=8)
net = ResNet18()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs, labels
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if i > 5:
            print("training works")
            break

Files already downloaded and verified
Files already downloaded and verified
training works


In [5]:
def train_net(config, checkpoint_dir=None, data_dir=None):
    net = ResNet18()

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    trainset, _ = load_data()

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(trainset, [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)

    for epoch in range(10):  # loop over the dataset multiple times
        print("Training started...")
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / epoch_steps))
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
    print("Finished Training")

In [6]:
def test_accuracy(net, device="cpu"):
    _, testset = load_data()

    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [7]:
data_dir = "network_tune"
config = {
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 16, 64, 128, 512])
}
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=10, #not sure
    grace_period=1,
    reduction_factor=2)
reporter = CLIReporter(
    parameter_columns=["lr", "batch_size"],
    metric_columns=["loss", "accuracy", "training_iteration"])
result = tune.run(
    partial(train_net, data_dir=data_dir),
    local_dir = './network_tune',
    name='resnet18_cifar10',
    resources_per_trial={"cpu": 8, "gpu": 1},
    config=config,
    num_samples=20, # not sure
    scheduler=scheduler,
    progress_reporter=reporter)

best_trial = result.get_best_trial("loss", "min", "last")
print("Best trial config: {}".format(best_trial.config))
print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
print("Best trial final validation accuracy: {}".format(best_trial.last_result["accuracy"]))

2021-05-06 18:44:54,839	INFO services.py:1174 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2021-05-06 18:44:57,343	INFO registry.py:65 -- Detected unknown callable for trainable. Converting to class.


== Status ==
Memory usage on this node: 27.7/754.6 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 8/64 CPUs, 1/2 GPUs, 0.0/542.72 GiB heap, 0.0/128.52 GiB objects (0/1.0 accelerator_type:V100S)
Result logdir: /data/user/ladilova/fed_colearning/FederatedLearningViaCoTraining/cifar10/network_tune/resnet18_cifar10
Number of trials: 1/20 (1 RUNNING)
+---------------------+----------+-------+-------------+--------------+
| Trial name          | status   | loc   |          lr |   batch_size |
|---------------------+----------+-------+-------------+--------------|
| DEFAULT_6429b_00000 | RUNNING  |       | 0.000174246 |          128 |
+---------------------+----------+-------+-------------+--------------+


[2m[36m(pid=58040)[0m Files already downloaded and verified
[2m[36m(pid=58040)[0m Files already downloaded and verified
[2m[36m(pid=58040)[0m Training started...
[2m[36m(pid=58037)[

2021-05-06 19:52:53,459	INFO tune.py:450 -- Total run time: 4079.62 seconds (4075.94 seconds for the tuning loop).


Result for DEFAULT_6429b_00019:
  accuracy: 0.3822
  date: 2021-05-06_19-52-53
  done: true
  experiment_id: 310546b08415423fa4019a43fb400abf
  hostname: dpl20
  iterations_since_restore: 1
  loss: 1.6884864147007466
  node_ip: 10.116.44.44
  pid: 58099
  should_checkpoint: true
  time_since_restore: 444.35638976097107
  time_this_iter_s: 444.35638976097107
  time_total_s: 444.35638976097107
  timestamp: 1620323573
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 6429b_00019
  
== Status ==
Memory usage on this node: 34.1/754.6 GiB
Using AsyncHyperBand: num_stopped=20
Bracket: Iter 8.000: -0.5059078049056137 | Iter 4.000: -0.7324895179724391 | Iter 2.000: -1.0640824724183702 | Iter 1.000: -1.490436634707451
Resources requested: 8/64 CPUs, 1/2 GPUs, 0.0/542.72 GiB heap, 0.0/128.52 GiB objects (0/1.0 accelerator_type:V100S)
Result logdir: /data/user/ladilova/fed_colearning/FederatedLearningViaCoTraining/cifar10/network_tune/resnet18_cifar10
Number of trials: 20/20 (1 RUNN

In [11]:
trainset, _ = load_data()
trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=64,
        shuffle=True,
        num_workers=8)
net = ResNet18()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)
for e in range(200):
    total = 0
    correct = 0
    train_loss = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs, labels
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    print('Epoch %d, Loss: %.3f | Acc: %.3f%% (%d/%d)' % (e, train_loss/(i+1), 100.*correct/total, correct, total))

Files already downloaded and verified
Files already downloaded and verified
Epoch 0, Loss: 1.411 | Acc: 48.440% (24220/50000)
Epoch 1, Loss: 0.889 | Acc: 68.674% (34337/50000)
Epoch 2, Loss: 0.690 | Acc: 75.978% (37989/50000)
Epoch 3, Loss: 0.584 | Acc: 79.698% (39849/50000)
Epoch 4, Loss: 0.507 | Acc: 82.342% (41171/50000)
Epoch 5, Loss: 0.454 | Acc: 84.176% (42088/50000)
Epoch 6, Loss: 0.410 | Acc: 85.712% (42856/50000)
Epoch 7, Loss: 0.378 | Acc: 86.834% (43417/50000)
Epoch 8, Loss: 0.349 | Acc: 87.988% (43994/50000)
Epoch 9, Loss: 0.319 | Acc: 88.824% (44412/50000)
Epoch 10, Loss: 0.292 | Acc: 89.950% (44975/50000)
Epoch 11, Loss: 0.276 | Acc: 90.366% (45183/50000)
Epoch 12, Loss: 0.262 | Acc: 90.970% (45485/50000)
Epoch 13, Loss: 0.244 | Acc: 91.412% (45706/50000)
Epoch 14, Loss: 0.224 | Acc: 92.242% (46121/50000)
Epoch 15, Loss: 0.212 | Acc: 92.540% (46270/50000)
Epoch 16, Loss: 0.199 | Acc: 93.138% (46569/50000)
Epoch 17, Loss: 0.185 | Acc: 93.436% (46718/50000)
Epoch 18, Loss: 

KeyboardInterrupt: 