In [None]:
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np

from torch.utils import data
from datasets import VOCSegmentation, Cityscapes
from utils import ext_transforms as et
from metrics import StreamSegMetrics

import torch
import torch.nn as nn
from utils.visualizer import Visualizer

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

# Configure PyTorch's CUDA allocator through its environment variable.
# NOTE(review): this runs after `import torch`; presumably still effective
# because no CUDA work has happened yet — confirm.
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:4096'})

class Params:
    """Hard-coded run configuration passed to ``main`` in place of argparse.

    Every option is a plain instance attribute, so values can be changed
    after construction and before calling ``main``.
    """

    def __init__(self):
        # Dataset options
        self.data_root = '/home/irfan/Desktop/Data/VOCtrainval_11-May-2012'  # VOCdevkit/VOC2012/
        self.dataset = 'voc'  # 'voc' or 'cityscapes'
        self.num_classes = None  # overwritten by main() from `dataset`

        # Deeplab options
        self.model = 'deeplabv3plus_mobilenet'  # key looked up in network.modeling
        self.separable_conv = False
        self.output_stride = 16

        # Train options
        self.test_only = False  # main() runs one validation pass and returns
        self.save_val_results = False  # validate() writes PNGs under results/
        self.total_itrs = 30e3
        self.lr = 0.01
        self.lr_policy = 'poly'  # 'poly' or 'step'

        self.step_size = 10000  # StepLR step size, used when lr_policy == 'step'
        self.crop_val = False
        self.batch_size = 8
        self.val_batch_size = 2
        self.crop_size = 256
        self.ckpt = None  # path of a checkpoint to restore, or None
        self.continue_training = False  # also restore optimizer/scheduler state

        self.loss_type = 'cross_entropy'  # 'cross_entropy' or 'focal_loss'
        self.gpu_id = '0'  # written to CUDA_VISIBLE_DEVICES
        self.weight_decay = 1e-4
        self.random_seed = 1
        self.print_interval = 10
        self.val_interval = 100  # iterations between validation/checkpoint passes
        self.download = False

        # PASCAL VOC options
        self.year = '2012'

        # Visdom options
        self.enable_vis = False
        self.vis_port = '13570'
        self.vis_env = 'main'
        self.vis_num_samples = 8
        
def get_dataset(opts):
    """Build the training and validation datasets with their augmentations.

    Args:
        opts: options object; reads ``dataset``, ``data_root`` and
            ``crop_size`` and, for VOC, ``crop_val``, ``year`` and ``download``.

    Returns:
        A ``(train_dst, val_dst)`` tuple of dataset objects.

    Raises:
        ValueError: if ``opts.dataset`` is neither ``'voc'`` nor
            ``'cityscapes'`` (compared case-insensitively).
    """
    # Compare case-insensitively, as main() does when it picks num_classes.
    dataset_name = opts.dataset.lower()
    if dataset_name not in ('voc', 'cityscapes'):
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(
            "Unsupported dataset %r (expected 'voc' or 'cityscapes')" % (opts.dataset,))

    def make_normalize():
        # Fresh transform per pipeline, with the statistics used throughout this file.
        return et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])

    if dataset_name == 'voc':
        train_transform = et.ExtCompose([
            # et.ExtResize(size=opts.crop_size),
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            make_normalize(),
        ])
        # Validation images are only resized/cropped when crop_val is set.
        val_steps = []
        if opts.crop_val:
            val_steps = [
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
            ]
        val_transform = et.ExtCompose(val_steps + [
            et.ExtToTensor(),
            make_normalize(),
        ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download,
                                    transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False,
                                  transform=val_transform)
    else:  # 'cityscapes'
        train_transform = et.ExtCompose([
            # et.ExtResize( 512 ),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            make_normalize(),
        ])
        val_transform = et.ExtCompose([
            # et.ExtResize( 512 ),
            et.ExtToTensor(),
            make_normalize(),
        ])
        train_dst = Cityscapes(root=opts.data_root,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val', transform=val_transform)

    return train_dst, val_dst


def _save_val_sample(dataset, denorm, image, target, pred, img_id):
    """Write one validation sample under ``results/`` as four PNG files.

    Saves the denormalized input image, the color-decoded target, the
    color-decoded prediction, and an overlay of the prediction on the image.
    """
    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
    target = dataset.decode_target(target).astype(np.uint8)
    pred = dataset.decode_target(pred).astype(np.uint8)

    Image.fromarray(image).save('results/%d_image.png' % img_id)
    Image.fromarray(target).save('results/%d_target.png' % img_id)
    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

    plt.figure()
    plt.imshow(image)
    plt.axis('off')
    plt.imshow(pred, alpha=0.7)
    ax = plt.gca()
    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
    plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
    plt.close()  # release the figure so they do not accumulate across samples


def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
    """Run the model over ``loader``, update ``metrics`` and collect samples.

    Args:
        opts: options object; only ``save_val_results`` is read here.
        model: callable mapping a batch of images to per-class score maps.
        loader: iterable of ``(images, labels)`` batches; ``loader.dataset``
            must provide ``decode_target`` when results are saved.
        device: torch device the batches are moved to.
        metrics: accumulator providing ``reset``, ``update`` and ``get_results``.
        ret_samples_ids: optional collection of batch indices; the first
            sample of each listed batch is returned for visualization.

    Returns:
        ``(score, ret_samples)`` where ``score`` is ``metrics.get_results()``
        and ``ret_samples`` is a list of ``(image, target, pred)`` arrays.
    """
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs('results', exist_ok=True)
        denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
        img_id = 0

    with torch.no_grad():
        # Wrap the loader itself so tqdm can pick up its length for the bar.
        for batch_idx, (images, labels) in enumerate(tqdm(loader)):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # Index of the highest-scoring class along dim 1.
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if ret_samples_ids is not None and batch_idx in ret_samples_ids:  # get vis samples
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:
                images_np = images.detach().cpu().numpy()
                for image, target, pred in zip(images_np, targets, preds):
                    _save_val_sample(loader.dataset, denorm, image, target, pred, img_id)
                    img_id += 1

        score = metrics.get_results()
    return score, ret_samples


def main(opts):
    """Train, or only evaluate, a segmentation model as configured by ``opts``.

    Sets ``opts.num_classes`` from the dataset name, builds the data loaders,
    model, optimizer, LR scheduler and loss, and optionally restores a
    checkpoint. It then either runs a single validation pass
    (``opts.test_only``) or trains until ``opts.total_itrs`` iterations,
    validating and checkpointing every ``opts.val_interval`` iterations.

    Args:
        opts: options object (e.g. a ``Params`` instance). ``num_classes`` and
            possibly ``val_batch_size`` are overwritten in place.

    Raises:
        ValueError: if ``opts.dataset``, ``opts.lr_policy`` or
            ``opts.loss_type`` is not one of the supported values.
    """
    # opts = get_argparser().parse_args()
    dataset_name = opts.dataset.lower()
    if dataset_name == 'voc':
        opts.num_classes = 21
    elif dataset_name == 'cityscapes':
        opts.num_classes = 19
    else:
        # Otherwise num_classes stays unset and the failure surfaces much later.
        raise ValueError(
            "Unsupported dataset %r (expected 'voc' or 'cityscapes')" % (opts.dataset,))

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    # NOTE(review): presumably forced to 1 because uncropped VOC images differ
    # in size and cannot be stacked into one batch — confirm.
    if dataset_name == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(
        train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2,
        drop_last=True)  # drop_last=True to ignore single-image batches.
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model (all models are constructed in network.modeling)
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer: the backbone is trained at a tenth of the classifier's rate.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
    else:
        raise ValueError(
            "Unsupported lr_policy %r (expected 'poly' or 'step')" % (opts.lr_policy,))

    # Set up criterion
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
    else:
        raise ValueError(
            "Unsupported loss_type %r (expected 'focal_loss' or 'cross_entropy')"
            % (opts.loss_type,))

    def save_ckpt(path):
        """Save the current model, optimizer, scheduler and progress to ``path``."""
        torch.save({
            "cur_itrs": cur_itrs,
            # Unwrap DataParallel so the weights load into a bare model.
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    # Checkpoints are written below checkpoints/exp-1/, so create the whole
    # path; creating only 'checkpoints' leaves torch.save without a directory.
    os.makedirs('checkpoints/exp-1', exist_ok=True)

    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    # Honor the configured interval; fall back to the previous hard-coded 10
    # for option objects that do not define it.
    print_interval = getattr(opts, 'print_interval', 10)

    interval_loss = 0
    while True:  # exits via the cur_itrs check at the end of each iteration
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if cur_itrs % print_interval == 0:
                interval_loss = interval_loss / print_interval
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if cur_itrs % opts.val_interval == 0:
                save_ckpt('checkpoints/exp-1/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/exp-1/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()  # validation switched to eval mode; switch back
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return

#torch.cuda.memory._record_memory_history()
# Guarded so that importing this file does not immediately start a training
# run; inside a notebook cell __name__ is '__main__', so the cell still runs.
if __name__ == '__main__':
    params = Params()
    main(params)

Device: cuda
/home/irfan/Desktop/Data/VOCtrainval_11-May-2012/VOCdevkit/VOC2012
/home/irfan/Desktop/Data/VOCtrainval_11-May-2012/VOCdevkit/VOC2012
> [0;32m/tmp/ipykernel_3540827/2123625121.py[0m(210)[0;36mmain[0;34m()[0m
[0;32m    208 [0;31m    [0mtrain_dst[0m[0;34m,[0m [0mval_dst[0m [0;34m=[0m [0mget_dataset[0m[0;34m([0m[0mopts[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    209 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 210 [0;31m    train_loader = data.DataLoader(
[0m[0;32m    211 [0;31m        [0mtrain_dst[0m[0;34m,[0m [0mbatch_size[0m[0;34m=[0m[0mopts[0m[0;34m.[0m[0mbatch_size[0m[0;34m,[0m [0mshuffle[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m [0mnum_workers[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    212 [0;31m        drop_last=True)  # drop_last=True to ignore single-image batche

ipdb>  aa = val_dst.__getitem__(0)
ipdb>  aa[0].shape


torch.Size([3, 366, 500])


ipdb>  aa[1].shape


torch.Size([366, 500])


ipdb>  ll


[1;32m    182 [0m[0;32mdef[0m [0mmain[0m[0;34m([0m[0mopts[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m    183 [0m    [0;31m#opts = get_argparser().parse_args()[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[1;32m    184 [0m    [0;32mif[0m [0mopts[0m[0;34m.[0m[0mdataset[0m[0;34m.[0m[0mlower[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;34m'voc'[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m    185 [0m        [0mopts[0m[0;34m.[0m[0mnum_classes[0m [0;34m=[0m [0;36m21[0m[0;34m[0m[0;34m[0m[0m
[1;32m    186 [0m    [0;32melif[0m [0mopts[0m[0;34m.[0m[0mdataset[0m[0;34m.[0m[0mlower[0m[0;34m([0m[0;34m)[0m [0;34m==[0m [0;34m'cityscapes'[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m    187 [0m        [0mopts[0m[0;34m.[0m[0mnum_classes[0m [0;34m=[0m [0;36m19[0m[0;34m[0m[0;34m[0m[0m
[1;32m    188 [0m[0;34m[0m[0m
[1;32m    189 [0m    [0;31m# Setup visualization[0m[0;34m[0m[0;34m[0m[0;34m[0m

ipdb>  val_loader = data.DataLoader(     214         val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)


*** SyntaxError: invalid syntax


ipdb>  val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
ipdb>  for elem in val_loader: break
ipdb>  elem


[tensor([[[[-0.3883, -1.0562, -1.7069,  ..., -1.7412, -1.7412, -1.7240],
          [-0.0116, -0.4739, -0.8507,  ..., -1.5014, -1.6384, -1.6727],
          [ 0.2967,  0.0398,  0.0398,  ..., -1.2274, -1.4672, -1.6727],
          ...,
          [-1.5014, -1.5870, -1.5185,  ...,  0.7077,  0.7248,  0.7419],
          [-1.5870, -1.6213, -1.5185,  ...,  0.6734,  0.6734,  0.6906],
          [-1.6042, -1.6042, -1.5528,  ...,  0.8276,  0.8104,  0.8104]],

         [[-0.2325, -0.9328, -1.6331,  ..., -1.6506, -1.6506, -1.6331],
          [ 0.1001, -0.3725, -0.7752,  ..., -1.4405, -1.5630, -1.6331],
          [ 0.3803,  0.1352,  0.1176,  ..., -1.0903, -1.3529, -1.5805],
          ...,
          [-1.6681, -1.7731, -1.7381,  ...,  0.8179,  0.8354,  0.8529],
          [-1.7381, -1.7731, -1.7031,  ...,  0.7829,  0.7829,  0.8004],
          [-1.7206, -1.7206, -1.7031,  ...,  0.9405,  0.9230,  0.9230]],

         [[-0.0615, -0.7936, -1.4733,  ..., -1.3861, -1.4559, -1.4384],
          [ 0.4614, -0.0441, 

ipdb>  l


[1;32m    353 [0m[0;31m#torch.cuda.memory._record_memory_history()[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[1;32m    354 [0m[0mparams[0m [0;34m=[0m [0mParams[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    355 [0m[0mmain[0m[0;34m([0m[0mparams[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m



In [None]:
# Write a CUDA memory snapshot to my_snapshot.pickle (underscore-prefixed, private torch API).
# NOTE(review): the torch.cuda.memory._record_memory_history() call in the cell
# above is commented out — confirm the snapshot will contain any history.
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")