In [1]:
# Echo the PyTorch build this notebook was executed with.
import torch
torch.__version__

'1.9.0+cu111'

In [2]:
import os, sys

# The deletions below are positional, so running this cell a second time would
# strip two further (unrelated) entries from sys.path and append the source
# directory again. A run-once flag keeps the cell idempotent while leaving the
# first execution exactly as before.
if not globals().get('_SYS_PATH_PATCHED', False):
    del sys.path[:2]  # remove local utils.py
    sys.path.append(os.environ['HOME']+'/mlcomhpc/deepcam/src/deepCam')
    _SYS_PATH_PATCHED = True

# NOTE(review): presumably consumed by the MPI/PMI layer that the deepCam
# modules pull in; confirm which component reads it.
os.environ['PMI_NO_PREINITIALIZE']='1'

In [3]:
import random
import numpy as np
from attrdict import AttrDict
from utils import parser
from data import get_dataloaders, get_datashapes

seed=42


def _seed_all_rngs(value):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) generators with ``value``."""
    # NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting
    # it here can influence child processes at most, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(value)
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)


# fix random seed
_seed_all_rngs(seed)

In [4]:
# Parse an empty argument string (presumably yielding the parser's defaults -
# compare the Namespace echoed below), then override max_inter_threads.
pargs = parser.parse_arguments('')
pargs.max_inter_threads=4
# NOTE(review): the namespace carries its own `seed` (333 in the recorded
# output), which differs from the notebook-level seed=42 that is handed to
# get_dataloaders - confirm which one the data pipeline actually uses.
pargs

Namespace(batchnorm_group_size=1, channels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], checkpoint=None, data_dir_prefix='/', gradient_accumulation_frequency=1, local_batch_size=1, logging_frequency=100, lr_schedule=None, lr_warmup_factor=1.0, lr_warmup_steps=0, max_epochs=30, max_inter_threads=4, model_prefix='model', optimizer='Adam', optimizer_betas=[0.9, 0.999], output_dir=None, resume_logging=False, run_tag=None, save_frequency=0, seed=333, start_lr=0.001, target_iou=0.82, valid_batch_size=2, wandb_certdir='/opt/certs', weight_decay=1e-06, wireup_method='nccl-openmpi')

In [5]:
# Dataset location. The default is the path the recorded outputs were produced
# with; export DEEPCAM_DATA_DIR to run the notebook against another copy
# without editing this cell.
root_dir=os.environ.get('DEEPCAM_DATA_DIR', "/scratch/snx3000/dealmeih/ds/mlperf/deepcam/All-Hist/")
device='cuda'
# NOTE(review): presumably emulates rank 0 of a 128-rank job inside a single
# process - confirm how get_dataloaders uses these two values.
comm_size=128
comm_rank=0
# Set explicitly so that re-running this cell restores a validation batch size
# of 2 after a later cell has changed it on the shared pargs object.
pargs.valid_batch_size=2
train_loader, train_size, validation_loader, validation_size = get_dataloaders(pargs, root_dir, device, seed, comm_size, comm_rank)
validation_size

Initialized dataset with  121216  samples.
Initialized dataset with  15158  samples.


15158

In [6]:
%%time
# Iterate over the whole validation loader without using the batches, so the
# reported time covers data loading only.
for i, b in enumerate(validation_loader):
    pass
# Number of batches seen and the shape of element 0 of the *last* batch.
# NOTE(review): `i` and `b` are undefined here if the loader yields nothing.
i+1, b[0].shape

CPU times: user 3.15 s, sys: 2.06 s, total: 5.21 s
Wall time: 8.7 s


(60, torch.Size([1, 16, 768, 1152]))

In [7]:
from architecture import deeplab_xception
from utils import losses
from driver import train_step, validate
loss_pow = -0.125
# Every class weight is its base value raised to loss_pow.
# NOTE(review): the bases are presumably per-class pixel fractions - confirm.
_weight_bases = (0.986267818390377, 0.0004578708870701058, 0.01327431072255291)
class_weights = [base**loss_pow for base in _weight_bases]

class MockLogger:
    """Stand-in logger exposing log_start / log_end / log_event.

    Start and end records are discarded; events are echoed to stdout.
    """

    def log_start(self, *args, **kwargs):
        """Accept and ignore an interval-start record."""
        return None

    def log_end(self, *args, **kwargs):
        """Accept and ignore an interval-end record."""
        return None

    def log_event(self, *args, **kwargs):
        """Print the positional arguments followed by the keyword arguments as one dict."""
        fields = args + (kwargs,)
        print(*fields)

# Loss criterion built from the class weights, then placed on the target device.
criterion = losses.CELoss(class_weights)
criterion = criterion.to(device)

# DeepLabv3+ model: 16 input channels, 3 classes, output stride 16, no
# pretrained weights and no process group; moved to the target device.
_net_kwargs = dict(n_input=16, n_classes=3,
                   os=16, pretrained=False,
                   rank=0,
                   process_group=None)
net = deeplab_xception.DeepLabv3_plus(**_net_kwargs).to(device)

Constructing DeepLabv3+ model...
Number of output channels: 3
Output stride: 16
Number of Input Channels: 16


In [8]:
%%time
# Time one validation pass of the freshly constructed network (no checkpoint
# is loaded in this notebook) fed directly from the DataLoader.
# NOTE(review): comm_size is passed before comm_rank here, whereas the
# orig_validate copy defined later in this notebook declares
# (comm_rank, comm_size) - confirm the order driver.validate expects.
# The two literal zeros are presumably step and epoch - confirm.
validate(pargs, comm_size, comm_rank,
         device, 0, 0,
         net, criterion, validation_loader,
         MockLogger())

{'key': 'eval_accuracy', 'value': 0.13863393639316077, 'metadata': {'epoch_num': 1, 'step_num': 0}}
{'key': 'eval_loss', 'value': 30375641.277310923, 'metadata': {'epoch_num': 1, 'step_num': 0}}
CPU times: user 10.4 s, sys: 5.42 s, total: 15.8 s
Wall time: 14.5 s


False

# Prefetch: validate from an iterator that was created ahead of the timed cell

In [9]:
import time

# Create the iterator before the timed cell and then wait, so that whatever the
# loader starts on iter() gets a head start.
# NOTE(review): presumably this lets background workers prefetch - confirm.
_PREFETCH_HEAD_START_S = 3
validation_iter = iter(validation_loader)
time.sleep(_PREFETCH_HEAD_START_S)

In [10]:
%%time
# Same validation call as for the plain loader, but fed from the iterator
# created in the previous cell.
# NOTE(review): validation_iter is single-use - re-running this cell without
# first re-creating the iterator would presumably find it exhausted.
validate(pargs, comm_size, comm_rank,
         device, 0, 0,
         net, criterion, validation_iter,
         MockLogger())

{'key': 'eval_accuracy', 'value': 0.13863393639316077, 'metadata': {'epoch_num': 1, 'step_num': 0}}
{'key': 'eval_loss', 'value': 30375641.277310923, 'metadata': {'epoch_num': 1, 'step_num': 0}}
CPU times: user 10.5 s, sys: 5.11 s, total: 15.6 s
Wall time: 14 s


False

# Test the original `validate` function (loaders rebuilt with validation batch size 1)

In [11]:
# base stuff
import os

# torch
import torch
import torch.distributed as dist

# custom stuff
from utils import metric


def orig_validate(pargs, comm_rank, comm_size,
             device, step, epoch,
             net, criterion, validation_loader,
             logger):
    """Run one evaluation pass over ``validation_loader`` and log loss and IoU.

    Args:
        pargs: argument namespace; only ``pargs.target_iou`` is read.
        comm_rank: accepted for signature compatibility; not used in the body.
        comm_size: accepted for signature compatibility; not used in the body.
        device: device on which the accumulators are created and to which each
            batch is moved.
        step: step number, used only in the logged metadata.
        epoch: zero-based epoch, logged as ``epoch + 1``.
        net: model; switched to eval mode for the pass and to train mode
            afterwards.
        criterion: callable invoked as ``criterion(outputs, labels)``.
        validation_loader: iterable yielding ``(inputs, labels, filename)``.
        logger: object providing ``log_start``, ``log_event`` and ``log_end``.

    Returns:
        True when the average IoU reaches ``pargs.target_iou``, else False.

    Raises:
        ZeroDivisionError: if no batch was counted, so no average exists.

    Notes:
        The counter is incremented once per batch, so both averages are
        per-batch means. When torch.distributed is initialised the loss sum is
        reduced to rank 0 only, while count and IoU are all-reduced; the
        logged loss is therefore only a global average on rank 0.
    """

    logger.log_start(key = "eval_start", metadata = {'epoch_num': epoch+1})

    #eval
    net.eval()

    # try/finally: never leave the model in eval mode, even if a batch raises.
    try:
        count_sum_val = torch.zeros((1), dtype=torch.float32, device=device)
        loss_sum_val = torch.zeros((1), dtype=torch.float32, device=device)
        iou_sum_val = torch.zeros((1), dtype=torch.float32, device=device)

        # disable gradients
        with torch.no_grad():

            # iterate over validation sample
            for inputs_val, label_val, filename_val in validation_loader:

                #send to device
                inputs_val = inputs_val.to(device)
                label_val = label_val.to(device)

                # forward pass; call the module itself rather than .forward so
                # that any registered hooks are honoured
                outputs_val = net(inputs_val)
                loss_val = criterion(outputs_val, label_val)

                # accumulate loss
                loss_sum_val += loss_val

                #increase counter
                count_sum_val += 1.

                # Compute score
                predictions_val = torch.argmax(torch.softmax(outputs_val, 1), 1)
                iou_val = metric.compute_score(predictions_val, label_val, num_classes=3)
                iou_sum_val += iou_val

            # average the validation loss
            if dist.is_initialized():
                dist.all_reduce(count_sum_val, op=dist.ReduceOp.SUM, async_op=False)
                dist.reduce(loss_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.all_reduce(iou_sum_val, op=dist.ReduceOp.SUM, async_op=False)

            batch_total = count_sum_val.item()
            if batch_total == 0:
                # Same exception type as the bare division would raise, but
                # with a message that names the actual cause.
                raise ZeroDivisionError("validation_loader yielded no batches; cannot average loss and IoU")
            loss_avg_val = loss_sum_val.item() / batch_total
            iou_avg_val = iou_sum_val.item() / batch_total

        # print results
        logger.log_event(key = "eval_accuracy", value = iou_avg_val, metadata = {'epoch_num': epoch+1, 'step_num': step})
        logger.log_event(key = "eval_loss", value = loss_avg_val, metadata = {'epoch_num': epoch+1, 'step_num': step})

        stop_training = False
        if (iou_avg_val >= pargs.target_iou):
            logger.log_event(key = "target_accuracy_reached", value = pargs.target_iou, metadata = {'epoch_num': epoch+1, 'step_num': step})
            stop_training = True
    finally:
        # set to train
        net.train()

    logger.log_end(key = "eval_stop", metadata = {'epoch_num': epoch+1})

    return stop_training

In [12]:
# Rebuild the loaders with a validation batch size of 1 for the orig_validate run.
# NOTE(review): this mutates the shared pargs object and rebinds
# train_loader / validation_loader, so earlier cells re-run after this point
# see batch size 1 until the cell that sets valid_batch_size=2 is run again.
pargs.valid_batch_size=1
train_loader, train_size, validation_loader, validation_size = get_dataloaders(pargs, root_dir, device, seed, comm_size, comm_rank)

Initialized dataset with  121216  samples.
Initialized dataset with  15158  samples.


In [13]:
%%time
orig_validate(pargs, comm_size, comm_rank,
         device, 0, 0,
         net, criterion, validation_loader,
         MockLogger())



{'key': 'eval_accuracy', 'value': 0.13863393639316077, 'metadata': {'epoch_num': 1, 'step_num': 0}}
{'key': 'eval_loss', 'value': 30375630.521008402, 'metadata': {'epoch_num': 1, 'step_num': 0}}
CPU times: user 13 s, sys: 4.43 s, total: 17.5 s
Wall time: 15.2 s


False