In [1]:
import sys
sys.path.append("../")
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Widen the notebook container so wide outputs can use the full browser width.
# `display` is imported from IPython.display: importing it from
# IPython.core.display is deprecated (since IPython 7.14).
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [3]:
#https://github.com/szagoruyko/pytorchviz

In [4]:
import torch
import torchvision
import torchvision.transforms as transforms

In [5]:
import tensorboardX
print("torch:",torch.__version__)
print("tensorboardX:",tensorboardX.__version__)

torch: 1.3.1
tensorboardX: 2.0


In [6]:
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# Sorted names of every attribute of `models` that is callable, all lower-case
# and not a dunder.
model_names = sorted(
    attr_name
    for attr_name, attr in vars(models).items()
    if callable(attr)
    and attr_name.islower()
    and not attr_name.startswith("__")
)

# Module-level placeholder for the best top-1 accuracy; initialised to zero.
best_acc1 = 0


def validate(val_loader, model, criterion, args):
    """Run one evaluation pass over ``val_loader`` and return the top-1 average.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss callable invoked as ``criterion(output, target)``.
        args: namespace providing ``gpu`` (CUDA device index or ``None``) and
            ``print_freq`` (a progress line is printed every ``print_freq``
            batches).

    Returns:
        ``top1.avg``, the running average of the per-batch top-1 accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            # Only move the labels to CUDA when CUDA exists; the previous
            # unconditional call raised on CPU-only hosts.
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):
    """Serialise ``state`` to ``filename`` and optionally keep a "best" copy.

    Args:
        state: object passed to ``torch.save``.
        is_best: when truthy, the checkpoint just written is also copied to
            ``best_filename``.
        filename: destination of the checkpoint.
        best_filename: destination of the copy made when ``is_best`` is set.
            The default keeps the previous hard-coded location.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


class AverageMeter(object):
    """Keeps the most recent value of a metric and its weighted running mean."""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format-spec suffix (e.g. ':6.3f') applied to val and avg.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the current value, the average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        template = ''.join(['{name} {val', spec, '} ({avg', spec, '})'])
        return template.format(**vars(self))


class ProgressMeter(object):
    """Prints a tab-separated progress line: prefix, batch counter, then meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the line for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header] + [str(m) for m in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        # The counter field is as wide as the total, e.g. '[{:4d}/4852]'.
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[{}/{}]'.format(field, field.format(num_batches))


def adjust_learning_rate(optimizer, epoch, lr):
    """Step schedule: write ``lr`` scaled by 0.1 per completed block of 30 epochs
    into every parameter group of ``optimizer``."""
    completed_blocks = epoch // 30
    scheduled_lr = lr * (0.1 ** completed_blocks)
    for group in optimizer.param_groups:
        group['lr'] = scheduled_lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: 2-D score tensor; ``topk`` is taken along dimension 1.
        target: 1-D tensor of label indices, one per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``, holding the percentage
        of rows whose target is among the ``k`` highest-scoring entries.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` comes from a transposed tensor and may be
            # non-contiguous, in which case view(-1) raises for k > 1;
            # reshape copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


In [7]:
model_names

['alexnet',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'googlenet',
 'inception_v3',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet_v2',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'shufflenet_v2_x1_5',
 'shufflenet_v2_x2_0',
 'squeezenet1_0',
 'squeezenet1_1',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'wide_resnet101_2',
 'wide_resnet50_2']

In [8]:
data_path = "~/image_net"

In [9]:
data_path

'~/image_net'

In [10]:
!ls {data_path}

train  val


In [11]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)

cuda:0


In [12]:
arch='resnet18'
lr=0.1

In [13]:
batch_size = 4
# create model
# if args.pretrained:
#     print("=> using pre-trained model '{}'".format(args.arch))
#     model = models.__dict__[args.arch](pretrained=True)
# else:
print("=> creating model '{}'".format(arch))
model = models.__dict__[arch]()

model.to(device)


# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                            momentum=0.9,
                            weight_decay=1e-4)


# Data loading code
traindir = os.path.join(data_path, 'train')
valdir = os.path.join(data_path, 'val')
# Per-channel mean/std applied to every image after ToTensor().
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))


# shuffle=True is required for training: ImageFolder lists its samples one
# class directory after another, so without shuffling every small batch would
# contain a single class.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=8)

val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=2)

# if args.evaluate:
#     validate(val_loader, model, criterion, args)
# else:
    

=> creating model 'resnet18'


In [14]:
device

device(type='cuda', index=0)

In [15]:
from modelinspector.inspector import build_state_images, stats_suffix, DataGroupType_BUFFER, DataGroupType_INPUT, DataGroupType_LABEL, DataGroupType_OUTPUT, DataGroupType_PARAM, DataGroupType_LOSS
from modelinspector.vis_utils import tensor_to_image, tensor_to_dist, fig2img
from IPython.core.display import HTML

import os
import json

In [16]:
from modelinspector.vis_utils import get_concat_v_blank, resize_img, heatmap_legend, get_concat_h_blank

def make_image(v,
               value_name,
               use_color_for_3channel_data=False):
    """Render ``v[value_name]`` as an image, or return None when either the
    value or its stats entry (``value_name + stats_suffix``) is missing.

    The rendered image is resized with a width of at least 400 and a minimum
    height of at least 60. Unless ``use_color_for_3channel_data`` is set, a
    heat-map legend built from the stats' 'min' and 'max' is concatenated
    horizontally to the image.
    """
    stats_key = '{}{}'.format(value_name, stats_suffix)
    if not (value_name in v and stats_key in v):
        return None

    value_stats = v[stats_key]
    rendered = tensor_to_image(
        val=v[value_name],
        vstats=value_stats,
        use_color_for_3channel_data=use_color_for_3channel_data)

    target_w = max(rendered.size[0], 400)
    target_h = max(rendered.size[1], 60)
    rendered = resize_img(rendered, w=target_w, min_h=target_h)

    if use_color_for_3channel_data:
        return rendered

    # Add HM legend
    legend = heatmap_legend(value_stats['min'], value_stats['max'])
    return get_concat_h_blank(rendered, legend)

In [17]:
def build_state_images(
                state,
                state_stats=None):
    """Build the images for every data group in ``state['data']``.

    For each group, ``make_image`` is tried for the 'tensor', 'grad' and
    'first_delta' values. Results that are not None are stored under
    '<value_name>__image'. Groups whose type equals ``DataGroupType_INPUT``
    are rendered with ``use_color_for_3channel_data=True``.

    ``state_stats`` is accepted but not used.

    Returns:
        dict mapping group name -> {'<value_name>__image': image}.
    """
    images_by_group = {}
    for group_name, group_data in state['data'].items():
        rendered = (
            (value_name,
             make_image(
                 v=group_data['value'],
                 value_name=value_name,
                 use_color_for_3channel_data=group_data['data_group_type'] == DataGroupType_INPUT))
            for value_name in ('tensor', 'grad', 'first_delta')
        )
        images_by_group[group_name] = {
            "{}__image".format(value_name): img
            for value_name, img in rendered
            if img is not None
        }
    return images_by_group

In [18]:
def build_cyto_graph_file(graph):
    """Convert ``graph`` into a ``{'nodes': [...], 'edges': [...]}`` dict.

    Args:
        graph: object with ``nodes`` and ``edges`` mappings. Each node value
            provides 'op_type' and 'component_ids'; each edge value provides
            'source_id' and 'target_id'.

    Returns:
        dict with one ``{'data': {...}}`` entry per node and per edge, in the
        iteration order of the input mappings.
    """
    nodes = [
        {"data": {"id": nid,
                  'label': node['op_type'],
                  'component_ids': node['component_ids']}}
        for nid, node in graph.nodes.items()
    ]
    edges = [
        {'data': {'id': eid,
                  "source": edge['source_id'],
                  'target': edge['target_id']}}
        for eid, edge in graph.edges.items()
    ]
    return {'nodes': nodes, 'edges': edges}

In [19]:
import simplejson as json
import numpy as np
import numpy
import math
class StateEncoder(json.JSONEncoder):
    """JSON encoder that also accepts classes and NumPy scalars/arrays.

    Source: Modified from stack overflow answer
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        * classes -> their ``__name__``
        * NumPy bool / integer / floating scalars -> Python bool / int / float
        * 1-D arrays with at most 200 elements -> list
        * any other array -> the string "NOT SERIALIZED"

        Anything else is delegated to the base encoder, which raises
        ``TypeError``; the offending type is printed before re-raising.
        """
        try:
            if isinstance(obj, type):
                return str(obj.__name__)
            if isinstance(obj, np.bool_):
                return bool(obj)
            # The abstract scalar bases cover every concrete width and avoid
            # aliases such as np.float_, which was removed in NumPy 2.0.
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.ndarray):
                if obj.ndim == 1 and obj.size <= 200:
                    return obj.tolist()
                return "NOT SERIALIZED"
            return json.JSONEncoder.default(self, obj)
        except TypeError:
            print(type(obj))
            raise

In [20]:
def save_images(state_image_data, output_path):
    """Write every image of ``state_image_data`` to ``output_path``.

    Args:
        state_image_data: mapping group name -> {image name: image}. Each
            image must provide ``size`` and ``save(path)``.
        output_path: destination directory; created when missing.

    Returns:
        dict mapping group name -> list of ``{'filename', 'size'}`` records,
        one per image. Files are named '<group_name>_<image_name>.jpg'.

    The input mapping is not modified.
    """
    os.makedirs(output_path, exist_ok=True)
    filenames = {}
    for group_name, image_data in state_image_data.items():
        saved = []
        for img_name, img in image_data.items():
            filename = "{}_{}.jpg".format(group_name, img_name)
            img.save(os.path.join(output_path, filename))
            saved.append({'filename': filename, 'size': img.size})
        filenames[group_name] = saved
    return filenames

In [21]:
def save_session(inspector,session_id,session_root):
    """Write ``<session_root>/session.json`` summarising the inspector session.

    The file holds the session id, the logged metrics keyed by their id, the
    ordered metric and state ids, and ``inspector.global_stats`` under
    'component_stats'. ``session_root`` is created when missing.
    """
    metrics_by_id = {entry['id']: entry for entry in inspector.metrics_log}
    metric_ids = [entry['id'] for entry in inspector.metrics_log]
    state_ids = [entry['id'] for entry in inspector.state_log]
    session_data = {
        'session_id': session_id,
        'session_metrics': metrics_by_id,
        'metric_ids': metric_ids,
        'state_ids': state_ids,
        'component_stats': inspector.global_stats,
    }

    os.makedirs(session_root, exist_ok=True)

    session_file = os.path.join(session_root, "session.json")
    with open(session_file, 'w') as f:
        json.dump(session_data, f, cls=StateEncoder, ignore_nan=True)

In [22]:
def save_last_state(inspector,session_root):
    """Persist the most recent entry of ``inspector.state_log``.

    Images built from the state are saved to
    ``<session_root>/<state id>/images`` and the state itself is dumped to
    ``<session_root>/<state id>/state.json``. The graph dict built for the
    state is attached to it under the 'graph' key, so the logged state object
    is mutated.
    """
    latest = inspector.state_log[-1]
    print("Saving {}".format(latest['id']))

    images = build_state_images(latest)
    latest['graph'] = build_cyto_graph_file(
        inspector.graph_data[latest['graph_id']])

    # Paths
    state_dir = os.path.join(session_root, latest['id'])
    image_dir = os.path.join(state_dir, "images")
    os.makedirs(image_dir, exist_ok=True)
    save_images(images, image_dir)

    state_file = os.path.join(state_dir, "state.json")
    with open(state_file, 'w') as f:
        print(state_dir)
        print(latest.keys())
        json.dump(latest, f, cls=StateEncoder, ignore_nan=True)

In [23]:
# CONFIGURE
# Root folder under which every session gets its own sub-directory
# (presumably read by the graph_web viewer - TODO confirm).
data_root = 'graph_web/session_data'
# Identifier of this run; also used as the session directory name.
session_id = "test7"
session_root = os.path.join(data_root,session_id)
# NOTE(review): the next cell runs `rm -rf {session_root}`, which deletes the
# directory created here. save_session and save_last_state re-create it with
# os.makedirs before writing, so this call is effectively redundant.
os.makedirs(session_root,exist_ok=True)

In [24]:
!rm -rf {session_root}

In [25]:
from modelinspector.inspector import Inspector
inspector = Inspector()

In [None]:
# Every `state_log_freq` iterations the inspector logs a state, recomputes its
# stats and the state is written to disk; every `metric_log_freq` iterations
# (and on every state-log iteration) the metrics are logged and the session
# file is rewritten.
state_log_freq = 1000
metric_log_freq = 50

for epoch in range(0, 5000):
        # Step schedule: lr is rescaled from the base `lr` at each epoch start.
        adjust_learning_rate(optimizer, epoch, lr)

        # train for one epoch
        # (inlined version of train(train_loader, model, criterion, optimizer, epoch, args))
        # Meters are re-created per epoch, so averages cover the current epoch only.
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        progress = ProgressMeter(
            len(train_loader),
            [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        # switch to train mode
        model.train()

        end = time.time()
        for i, (images, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            # Inspector logging runs after backward() and before
            # optimizer.step(), i.e. with this batch's gradients computed and
            # the parameters not yet updated. Any exception raised while
            # logging or saving is printed and training continues, because
            # optimizer.step() sits outside the try block.
            try:

                if i % state_log_freq == 0:
                    inspector.log_state(epoch=epoch,
                              itr=i, 
                              model=model,
                              input_dict={'input.1':images},
                              output_dict={'output.1':output},
                              loss_dict={'loss':loss},
                              label_dict={'class_label':target})
                    inspector.compute_stats()
                    save_last_state(inspector,session_root)
                    # NOTE(review): the progress line is printed only on
                    # state-log iterations and before batch_time is updated
                    # below, so the displayed 'Time' is the previous batch's
                    # (0.000 on the first iteration of an epoch).
                    progress.display(i)


                # The second condition makes every state-log iteration also log
                # metrics, even when state_log_freq is not a multiple of
                # metric_log_freq.
                if i % metric_log_freq == 0 or i % state_log_freq == 0:
                    inspector.log_metrics(
                        epoch=epoch,
                        itr=i, 
                        metrics={
                            'loss':loss.item(),
                            'acc1':acc1[0].item(),
                            'acc5':acc5[0].item()})

                    save_session(inspector,session_id,session_root)
            except Exception as e:
                print(e)

            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            
        # evaluate on validation set
        #acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        #is_best = acc1 > best_acc1
        #best_acc1 = max(acc1, best_acc1)
        
# ON CPU
# Computing graph (0, 0)..
# Computing graph (0, 0)..
# Epoch: [0][   0/4852]	Time  1.260 ( 1.260)	Data  0.072 ( 0.072)	Loss 6.6802e+00 (6.6802e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
# Epoch: [0][ 200/4852]	Time  0.593 ( 0.675)	Data  0.002 ( 0.002)	Loss 7.8017e+00 (9.1996e+00)	Acc@1   0.00 (  5.10)	Acc@5   0.00 ( 19.78)
# Epoch: [1][   0/4852]	Time  0.864 ( 0.864)	Data  0.155 ( 0.155)	Loss 9.2710e+00 (9.2710e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
# Epoch: [1][ 200/4852]	Time  0.706 ( 0.599)	Data  0.002 ( 0.002)	Loss 6.3036e+00 (5.9345e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)


Computing graph (0, 0)..


  **kwargs)
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


Saving 0_0
graph_web/session_data/test7/0_0
dict_keys(['id', 'meta', 'step_info', 'additional_info', 'graph_id', 'data', 'graph'])
Epoch: [0][   0/4852]	Time  0.000 ( 0.000)	Data  0.351 ( 0.351)	Loss 6.9990e+00 (6.9990e+00)	Acc@1   0.00 (  0.00)	Acc@5   0.00 (  0.00)
Saving 0_1000
graph_web/session_data/test7/0_1000
dict_keys(['id', 'meta', 'step_info', 'additional_info', 'graph_id', 'data', 'graph'])
Epoch: [0][1000/4852]	Time  0.019 ( 0.027)	Data  0.002 ( 0.002)	Loss 7.9070e+00 (7.5817e+00)	Acc@1   0.00 (  1.00)	Acc@5   0.00 (  3.97)
Saving 0_2000
graph_web/session_data/test7/0_2000
dict_keys(['id', 'meta', 'step_info', 'additional_info', 'graph_id', 'data', 'graph'])
Epoch: [0][2000/4852]	Time  0.017 ( 0.027)	Data  0.001 ( 0.002)	Loss 7.9069e+00 (7.6247e+00)	Acc@1   0.00 (  0.50)	Acc@5   0.00 (  1.99)
Saving 0_3000
graph_web/session_data/test7/0_3000
dict_keys(['id', 'meta', 'step_info', 'additional_info', 'graph_id', 'data', 'graph'])
Epoch: [0][3000/4852]	Time  0.017 ( 0.027)	Data

In [None]:
!df -h

In [None]:
save_session(inspector,session_id,session_root)


In [None]:
print('hi')

# OLD CODE

In [None]:
# from modelinspector.vis_utils import get_concat_v_blank, resize_img, heatmap_legend, get_concat_h_blank

In [None]:
# %matplotlib agg
# import matplotlib.pyplot as plt
# #datasets = [('name',{'xmin':-10,'xmax':30},dd) for x in range(0,5)]
# def build_figs(datasets,n_bins=40):
#     fig, axs = plt.subplots(len(datasets), 1, tight_layout=True,figsize=(8,8));

#     for i,(name,stats,data) in enumerate(datasets):
#         if len(datasets)==1:
#             ax = axs
#         else:
#             ax = axs[i]

#         if data.ndim >1:
#             data = data.ravel()

#         bins = min(int(data.size/10)+2, n_bins)

#         ax.hist(data, bins=bins,range=(stats['min'],stats['max']))
#         ax.set_title(name)

#     figimg = fig2img(fig).convert("RGB")
#     plt.close()
#     return figimg

In [None]:
# ## HEATMAP
# from PIL import Image, ImageOps, ImageDraw
# image_border = (10,15,10,10)
# border_color = 'black'

# def build_heatmaps(state_image_data,graph,state,data_group_types):
#     data_group_images = {}
#     min_width= 400
#     min_height = 40
#     img_created_counter = 0
#     for data_group_type,value_name in data_group_types:
#         node_images ={}
#         for nid,node in graph.nodes.items():
#             img = None
#             for cid in node['component_ids']:
#                 component = graph.components[cid]
#                 if component['data_group_type'] != data_group_type:
#                     continue
#                 image_type = "{}__image".format(value_name)
#                 component_state = state['data'][cid]

#                 #We get our component image here, however we could generate it here instead (TODO)
#                 cimg = state_image_data.get(cid,{}).get(image_type,None)
#                 if cimg is not None:
#                     h = max(cimg.size[1],min_height) 
#                     w = max(cimg.size[0],min_width) 
#                     cimg = resize_img(cimg,w=w,min_h=h)

#                     cimg = ImageOps.expand(cimg, border=image_border,fill=border_color)
#                     draw = ImageDraw.Draw(cimg)
#                     draw.text((image_border[0], 0), "{},{},{}, shape={}".format(cid, value_name,image_type,component_state['shape']),(255, 255, 255))  # ,font=font))
#     # HEATMAP LEGEND, TODO              
#     #                 cstats = component_state['value']["{}__stats".format(value_name)]
#     #                 legimg = heatmap_legend(cstats['min'],cstats['max'])
#     #                 cimg = get_concat_h_blank(cimg,legimg)
#                     if img is None:
#                         img = cimg
#                     else:
#                         img = get_concat_v_blank(img,cimg)
#             if img is not None:
#                 node_images[nid] = img
#                 img_created_counter +=1
#         data_group_images[(data_group_type,value_name)] = node_images
#     print("{} images created".format(img_created_counter))
#     return data_group_images

In [None]:
## DIST FIGURES
# def build_dist_figures(graph,state):
#     data_group_dists = {}
#     img_created_counter = 0
#     for data_group_type,value_name in data_group_types:
#         node_images ={}
#         for nid,node in graph.nodes.items():
#             datasets=[]
#             for cid in node['component_ids']:
#                 component = graph.components[cid]
#                 if component['data_group_type'] != data_group_type:
#                     continue
#                 component_state = state['data'][cid]
#                 value = component_state['value'][value_name]
#                 stats = component_state['value']['{}__stats'.format(value_name)]
#                 datasets.append(("{},{}".format(cid,value_name),stats,value))
#             if len(datasets) >=1:
#                 img = build_figs(datasets)
#                 node_images[nid] = img
#                 img_created_counter+=1
#         data_group_dists[(data_group_type,value_name)] = node_images
#     print("{} figure created".format(img_created_counter))
#     return data_group_dists

In [None]:
# `graph_out` exists only as a local inside save_last_state, which attaches it
# to the logged state under 'graph'; read it back from there so this cell runs
# on a fresh kernel.
inspector.state_log[-1]['graph']['nodes']

In [None]:
!ls -lstrh {state_path}

In [None]:
# import os
# def save_node_images(group_images,image_type,output_path): #TODO: use these):
#     filenames = {}
#     os.makedirs(output_path,exist_ok=True)
#     for gid,node_images in group_images.items():
#         nfilenames = {}
#         for nid, image_data in node_images.items():
#             filename = "{}_{}_{}__{}.jpg".format(nid,gid[0],gid[1],image_type)
#             image_data.save(os.path.join(output_path,filename))
#             nfilenames[nid] = filename
#         filenames[gid] = nfilenames
#     return filenames

In [None]:
# `state` is never bound at notebook scope (it is a local of save_last_state);
# the same object is the last entry of inspector.state_log.
inspector.state_log[-1]['graph']