In [None]:
from hops import hdfs
from torchvision import models

In [None]:
def train_fn(module, hparams, train_set, test_set):
    """Train and evaluate a model, logging per-epoch timings and accuracy to HDFS.

    Args:
        module: Callable that builds the model; invoked as ``module(**hparams)``.
        hparams: Keyword arguments forwarded to ``module``.
        train_set: Training data, handed to ``DataLoader`` unchanged.
        test_set: Evaluation data, handed to ``DataLoader`` unchanged.

    Returns:
        Top-1 accuracy of the last evaluation as a float (0.0 when no epoch ran).

    Notes:
        - Reads the notebook-level ``config`` object to build the log path, so
          ``config`` must exist before this function is called.
        - Reads the ``RANK`` environment variable (KeyError if it is unset).
        - NOTE(review): ``DataLoader(..., transform_spec=...)`` and ``model.join()``
          do not exist on a stock torch DataLoader / nn.Module; presumably the
          launcher patches the loader and wraps the model - confirm.
    """

    import time
    import os
    import pickle

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torch.cuda.amp import GradScaler, autocast
    from torchvision import transforms as T

    from hops import hdfs

    model = module(**hparams)

    n_epochs = 1
    n_exec = 2
    batch_size = 128
    # Base learning rate of 0.1 scaled by the global batch size (n_exec * batch_size) over 256.
    lr_base = 0.1 * n_exec*batch_size/256

    def train_transform(image_net_row):
        """Return the row with a randomly cropped/flipped, normalised image tensor."""
        transform = T.Compose([
            T.ToTensor(),
            T.RandomCrop(224),
            T.RandomHorizontalFlip(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        return {"image": transform(image_net_row['image']), "label": image_net_row['label']}

    def test_transform(image_net_row):
        """Return the row with a centre-cropped, normalised image tensor."""
        transform = T.Compose([
            T.ToTensor(),
            T.CenterCrop(224),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        return {"image": transform(image_net_row['image']), "label": image_net_row['label']}

    # Parameters as in https://arxiv.org/pdf/1706.02677.pdf
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_base, momentum=0.9, weight_decay=0.0001, nesterov=True)

    loss_criterion = nn.CrossEntropyLoss()

    train_loader = DataLoader(train_set, pin_memory=True, batch_size=batch_size, transform_spec=train_transform)
    test_loader = DataLoader(test_set, pin_memory=True, batch_size=batch_size, transform_spec=test_transform)

    def time_to_h_m_s(t_diff):
        """Split a duration in seconds into (hours, minutes, seconds)."""
        minutes, seconds = divmod(t_diff, 60)
        hours, minutes = divmod(minutes, 60)
        return hours, minutes, seconds

    def print_train_time(t_0, batch, n_batches, epoch, n_epochs):
        """Print elapsed time since ``t_0`` and a linear estimate of the time left.

        ``batch`` is the zero-based index of the batch just finished in ``epoch``.
        """
        t_diff = time.time() - t_0
        tr_time = time_to_h_m_s(t_diff)
        # Fix: use the ``batch`` argument instead of the enclosing loop variable.
        t_est = t_diff * (n_epochs*n_batches/(epoch*n_batches + batch+1) - 1)
        est_time = time_to_h_m_s(t_est)
        print("Training time: {:.0f}h {:.0f}m {:.0f}s\nEstimated remaining time: {:.0f}h {:.0f}m {:.0f}s.".format(*tr_time, *est_time))

    ### Logging ###
    # One CSV-style log per rank; an existing log for the same rank is replaced.
    log_path = hdfs.project_path() + "Experiments/" + config.name + "/training_log_" + os.environ["RANK"] + ".log"
    if hdfs.exists(log_path) and hdfs.isfile(log_path):
        hdfs.delete(log_path)
    hdfs.dump("ep,lr,top1acc,t_load,t_forward,t_backward,t_step,t_ep,mem_f_g,mem_b_g,mem_op_g,checkpoint_time", log_path)

    def log_training(ep, top1acc, t_load, t_forward, t_backward, t_step, t_ep, mem_f_g, mem_b_g, mem_op_g, checkpoint_time):
        """Append one row of epoch statistics to the log file at ``log_path``."""
        f = hdfs.open_file(log_path, flags="at")
        try:
            f.write("\n")
            f.write(f"{ep},{lr_schedule(ep)},{top1acc},{t_load},{t_forward},{t_backward},{t_step},{t_ep},{mem_f_g},{mem_b_g},{mem_op_g},{checkpoint_time}")
        finally:
            f.close()

    def eval_model(model, test_loader):
        """Return the top-1 accuracy of ``model`` over ``test_loader``.

        Returns 0.0 when the loader yields no samples instead of dividing by zero.
        """
        acc = 0
        model.eval()
        img_cnt = 0
        with torch.no_grad():
            with model.join():
                for idx, data in enumerate(test_loader):
                    print("Testing batch {}".format(idx))
                    img, label = data["image"].float(), data["label"].float()
                    prediction = model(img)
                    acc += torch.sum(torch.argmax(prediction, dim=1) == label).detach()
                    img_cnt += len(label.detach())
        acc = acc/float(img_cnt) if img_cnt else 0.0
        print("Test accuracy: {:.3f}".format(acc))
        print("-"*20)
        return acc

    def checkpoint(model, optimizer, epoch):
        """Best-effort dump of model/optimizer state to HDFS; failures are reported, not raised."""
        print("Saving model...")
        checkpoint_ = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}
        path = hdfs.project_path() + "Experiments/checkpoint/" + "model_e" + str(epoch)
        try:
            hdfs.dump(pickle.dumps(checkpoint_), path)
            print("Model saved.")
        except Exception as exc:
            # Still best-effort, but report why and let KeyboardInterrupt/SystemExit through.
            print("Save abort.")
            print("Checkpoint error: {!r}".format(exc))

    def lr_schedule(epoch):
        """Learning rate for ``epoch``: linear warm-up over 5 epochs, then x0.1 at 30, 60 and 80."""
        if epoch in range(5):
            return lr_base * (epoch+1)/5
        elif epoch < 30:
            return lr_base
        elif epoch < 60:
            return lr_base * 0.1
        elif epoch < 80:
            return lr_base * 0.1**2
        return lr_base * 0.1**3

    # Mixed precision training with learning rate scheduling.
    scaler = GradScaler()

    # Defined up front so the final return is valid even if no epoch runs.
    acc = 0.0

    model.train()
    t_0 = time.time()
    for epoch in range(n_epochs):
        print("-"*20 + "\nStarting new epoch\n")
        model.train()
        for g in optimizer.param_groups:
            g['lr'] = lr_schedule(epoch)
        print("Epoch {}, lr: {}".format(epoch, lr_schedule(epoch)))
        t_load = 0
        t_forward = 0
        t_backward = 0
        t_step = 0
        # Memory deltas default to 0 so logging works when no batch is trained.
        mem_f_g = mem_b_g = mem_op_g = 0
        # Sample memory on the first batch actually trained, which need not be index 0.
        sample_memory = True
        t_ep_0 = time.time()

        with model.join():
            t_load_0 = time.time()

            for idx, data in enumerate(train_loader):
                t_load_1 = time.time()
                print("DataLoader load time: {:.3f}s, Batch: {}".format(t_load_1-t_load_0, idx))
                print("Batch size: {}".format(len(data["image"])))
                if len(data["image"]) != batch_size:
                    # Incomplete batches are skipped entirely.
                    print("Batch size mismatch detected")
                    continue
                img, label = data["image"].float(), data["label"].float()
                with autocast():
                    if sample_memory:
                        mem_f_pre = torch.cuda.max_memory_allocated(0)
                    t_forward_0 = time.time()
                    prediction = model(img)
                    t_forward_1 = time.time()
                    if sample_memory:
                        mem_f_post = mem_b_pre = torch.cuda.max_memory_allocated(0)
                    loss = loss_criterion(prediction, label.long())
                t_backward_0 = time.time()
                scaler.scale(loss).backward()
                t_backward_1 = time.time()
                if sample_memory:
                    mem_b_post = mem_op_pre = torch.cuda.max_memory_allocated(0)
                t_step_0 = time.time()
                scaler.step(optimizer)
                t_step_1 = time.time()
                if sample_memory:
                    mem_op_post = torch.cuda.max_memory_allocated(0)
                    # Growth of peak allocated memory across forward, backward and optimizer step.
                    mem_f_g = mem_f_post - mem_f_pre
                    mem_b_g = mem_b_post - mem_b_pre
                    mem_op_g = mem_op_post - mem_op_pre
                    sample_memory = False

                scaler.update()
                optimizer.zero_grad()

                if idx%(10) == 0:
                    print(f"Working on batch {idx}")

                t_load += t_load_1 - t_load_0
                t_forward += t_forward_1 - t_forward_0
                t_backward += t_backward_1 - t_backward_0
                t_step += t_step_1 - t_step_0

                t_load_0 = time.time()
                print("Batch computation time: {:.3f}s, Batch: {}".format(t_load_0-t_load_1, idx))

        t_ep_1 = t_check_0 = time.time()
        # Only rank 0 writes checkpoints, every 10th epoch.
        if os.environ["RANK"] == "0" and epoch%10 == 0:
            checkpoint(model, optimizer, epoch)
        t_check_1 = time.time()
        print("Epoch training took {:.0f}s.\n".format(t_ep_1-t_ep_0))
        acc = eval_model(model, test_loader)
        log_training(epoch, acc, t_load, t_forward, t_backward, t_step, t_ep_1-t_ep_0, mem_f_g, mem_b_g, mem_op_g, t_check_1-t_check_0)
    hours, minutes, seconds = time_to_h_m_s(time.time() - t_0)
    print("-"*20 + "\nTotal training time: {:.0f}h {:.0f}m {:.0f}s.".format(hours, minutes, seconds))
    return float(acc)

In [None]:
# Both Petastorm ImageNette splits live under the same project directory.
imagenette_root = hdfs.project_path() + "DataSets/ImageNet/PetastormImageNette/"
train_ds = imagenette_root + "train"
test_ds = imagenette_root + "test"
# Sanity check: print whether each split path exists.
print(hdfs.exists(train_ds), hdfs.exists(test_ds))

In [None]:
from maggy import experiment
from maggy.experiment_config import TorchDistributedConfig

# Experiment settings passed to experiment.lagom together with train_fn.
# NOTE(review): train_fn concatenates config.name into its log path, so the "/"
# in the name ends up as a path separator there - confirm this is intended.
config = TorchDistributedConfig(
    name='ImageNet_ddp_Z1/2',
    module=models.resnet50,
    hparams={"pretrained": False},
    train_set=train_ds,
    test_set=test_ds,
    backend="ddp",
    zero_lvl=1,
)

In [None]:
# Hand train_fn and config to maggy's lagom entry point and keep what it returns.
# train_fn reads the notebook-level `config`, so the config cell must run first.
result = experiment.lagom(train_fn, config)