# In [None]:
import argparse
import os
import sys
import uuid
from datetime import datetime as dt

import shutil
import logging
import time
import timeit
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data.distributed
import wandb
from tqdm import tqdm

# Distributed training
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.nn.functional as F

import model_io
import models
import utils
from dataloader import DepthDataLoader
# from loss import SILogLoss, BinsChamferLoss
from utils import RunningAverage, colorize
import matplotlib

from datasets import Cityscapes
from loss import CrossEntropy

# Force synchronous CUDA kernel launches so errors are reported at the
# offending call (a debugging aid; it slows training noticeably).
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

def is_distributed():
    """Report whether the default torch.distributed process group is set up."""
    initialized = dist.is_initialized()
    return initialized

def get_world_size():
    """Number of processes in the default group; 1 when not running distributed."""
    return dist.get_world_size() if dist.is_initialized() else 1

def get_rank():
    """Rank of this process in the default group; 0 when not running distributed."""
    return dist.get_rank() if dist.is_initialized() else 0

def get_sampler(dataset):
    """Return a ``DistributedSampler`` over *dataset* under DDP, otherwise None.

    Returning None lets the DataLoader use its default sampling.
    """
    if not dist.is_initialized():
        return None
    return DistributedSampler(dataset)

def is_rank_zero(args):
    """True when ``args.rank`` identifies the rank-0 process."""
    rank = args.rank
    return rank == 0

def reduce_tensor(inp):
    """
    Reduce the loss from all processes so that
    process with rank 0 has the averaged results.

    The tensor is summed onto rank 0 with ``torch.distributed.reduce`` and
    divided by the world size; only rank 0's result is the average. When no
    process group is initialized, or there is a single process, *inp* is
    returned unchanged (previously this raised because the default group did
    not exist).
    """
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size < 2:
        return inp
    with torch.no_grad():
        # Reduce into a copy so the caller's tensor is not overwritten in place.
        reduced_inp = inp.clone()
        dist.reduce(reduced_inp, dst=0)
    return reduced_inp / world_size

######################################
# Run configuration
seed = 304  # RNG seeding later in this script is applied only when seed > 0
gpus = [0]
local_rank = 0

# Cityscapes dataset root and the train/val list files passed to
# Cityscapes(list_path=...)
data_root = '/newHDD/datasets/Cityscapes/'
train_list, val_list = 'train.lst', 'val.lst'
######################################

# Command-line options. Many of them (bins, chamfer weight, depth ranges, NYU
# paths and crops) look carried over from a depth-estimation setup and appear
# unused by the Cityscapes segmentation training in this script.
parser = argparse.ArgumentParser(description='Training script. Default values of all arguments are recommended for reproducibility', fromfile_prefix_chars='@',
                                 conflict_handler='resolve')
parser.add_argument('--epochs', default=25, type=int, help='number of total epochs to run')
parser.add_argument('--n-bins', '--n_bins', default=80, type=int, help='number of bins/buckets to divide depth range into')
parser.add_argument('--lr', '--learning-rate', default=0.000357, type=float, help='max learning rate')
parser.add_argument('--wd', '--weight-decay', default=0.1, type=float, help='weight decay')
parser.add_argument('--w_chamfer', '--w-chamfer', default=0.1, type=float, help="weight value for chamfer loss")
parser.add_argument('--div-factor', '--div_factor', default=25, type=float, help="Initial div factor for lr")
parser.add_argument('--final-div-factor', '--final_div_factor', default=100, type=float, help="final div factor for lr")
parser.add_argument('--batch_size', default=4, type=int, help='batch size')
parser.add_argument('--validate-every', '--validate_every', default=100, type=int, help='validation period')
parser.add_argument('--gpu', default=None, type=int, help='Which gpu to use')

parser.add_argument("--norm", default="linear", type=str, help="Type of norm/competition for bin-widths", choices=['linear', 'softmax', 'sigmoid'])
parser.add_argument("--same-lr", '--same_lr', default=False, action="store_true", help="Use same LR for all param groups")
parser.add_argument("--distributed", default=False, action="store_true", help="Use DDP if set")

parser.add_argument("--root", default="./experiments", type=str, help="Root folder to save data in")
parser.add_argument("--name", default="UnetAdaptiveBins")
parser.add_argument("--resume", default='', type=str, help="Resume from checkpoint")

parser.add_argument("--notes", default='', type=str, help="Wandb notes")
parser.add_argument("--tags", default='sweep', type=str, help="Wandb tags")
parser.add_argument("--workers", default=11, type=int, help="Number of workers for data loading")
parser.add_argument("--dataset", default='nyu', type=str, help="Dataset to train on")
parser.add_argument("--data_path", default='../dataset/nyu/sync/', type=str, help="path to dataset")
parser.add_argument("--gt_path", default='../dataset/nyu/sync/', type=str, help="path to dataset")
parser.add_argument('--filenames_file', default="./train_test_inputs/nyudepthv2_train_files_with_gt.txt", type=str, help='path to the filenames text file')
parser.add_argument('--input_height', type=int, help='input height', default=416)
parser.add_argument('--input_width', type=int, help='input width', default=544)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
parser.add_argument('--min_depth', type=float, help='minimum depth in estimation', default=1e-3)
parser.add_argument('--do_random_rotate', default=True, help='if set, will perform random rotation for augmentation', action='store_true')
# `store_true` with default=True can never be switched off from the command
# line, so expose an explicit negative flag that writes to the same dest.
parser.add_argument('--no_do_random_rotate', dest='do_random_rotate', action='store_false', help='disable random rotation augmentation')
parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')
parser.add_argument('--data_path_eval', default="../dataset/nyu/official_splits/test/", type=str, help='path to the data for online evaluation')
parser.add_argument('--gt_path_eval', default="../dataset/nyu/official_splits/test/", type=str, help='path to the groundtruth data for online evaluation')
parser.add_argument('--filenames_file_eval', default="./train_test_inputs/nyudepthv2_test_files_with_gt.txt", type=str, help='path to the filenames text file for online evaluation')
parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=10)
parser.add_argument('--eigen_crop', default=True, help='if set, crops according to Eigen NIPS14', action='store_true')
# Same issue as --do_random_rotate: provide a way to turn the crop off.
parser.add_argument('--no_eigen_crop', dest='eigen_crop', action='store_false', help='disable the Eigen NIPS14 crop')
parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')

# Segmentation
parser.add_argument('--n_semantic_classes', help='Number of semantic classes', default=19, type=int)
parser.add_argument('--img_width', help='Width of input image', default=1024, type=int)
parser.add_argument('--img_height', help='Height of input image', default=512, type=int)
parser.add_argument('--base_size', help='Base size of input image', default=2048, type=int)  # Cityscapes original size: 2048x1024
parser.add_argument('--ignore_label', help='Label to ignore', default=255, type=int)

# Arguments are parsed from this hard-coded list rather than sys.argv
# (presumably because this runs in a notebook, where sys.argv belongs to the
# kernel process).
arg_list = [
    "--batch_size", "4",
]
args = parser.parse_args(arg_list)

def _init_distributed(gpu, ngpus_per_node, args):
    """Join the DDP process group and split batch size / workers per process.

    Must run before the DataLoaders are built: ``get_sampler`` only returns a
    ``DistributedSampler`` once the process group exists, and the loaders read
    ``args.batch_size`` / ``args.workers`` when they are constructed.
    """
    torch.cuda.set_device(gpu)
    args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node)
    print(args.gpu, args.rank, args.batch_size, args.workers)


def _build_loaders(args):
    """Build the Cityscapes train/val datasets and their DataLoaders.

    Returns ``(train_dataset, train_loader, train_sampler, val_loader)``; the
    train sampler is None outside distributed mode.
    """
    crop_size = (args.img_height, args.img_width)
    train_dataset = Cityscapes(root=data_root,
                               list_path=train_list,
                               num_samples=None,
                               num_classes=args.n_semantic_classes,
                               multi_scale=True,
                               flip=True,
                               ignore_label=args.ignore_label,
                               base_size=args.base_size,
                               crop_size=crop_size,
                               downsample_rate=1,
                               scale_factor=16)
    train_sampler = get_sampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        # DataLoader rejects shuffle=True together with a sampler; the
        # DistributedSampler shuffles by itself.
        shuffle=train_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler)

    val_dataset = Cityscapes(root=data_root,
                             list_path=val_list,
                             num_samples=None,
                             num_classes=args.n_semantic_classes,
                             multi_scale=False,
                             flip=False,
                             ignore_label=args.ignore_label,
                             base_size=args.base_size,
                             crop_size=crop_size,
                             downsample_rate=1)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        sampler=get_sampler(val_dataset))
    return train_dataset, train_loader, train_sampler, val_loader


def _prepare_model(model, args):
    """Move *model* to GPU(s) and wrap it for DDP or DataParallel.

    Sets ``args.multigpu`` to True when the returned model is a wrapper whose
    underlying network is reachable through ``.module``.
    """
    if args.gpu is not None:  # a specific GPU was chosen for this process
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    args.multigpu = False
    if args.distributed:
        args.multigpu = True
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], output_device=args.gpu,
                                                          find_unused_parameters=True)
    elif args.gpu is None:
        # No GPU chosen: replicate over all visible GPUs with DataParallel.
        args.multigpu = True
        model = model.cuda()
        model = torch.nn.DataParallel(model)
    return model


def main_worker(gpu, ngpus_per_node, args):
    """Train the segmentation model in one process.

    Args:
        gpu: GPU index for this process, or None to use DataParallel.
        ngpus_per_node: GPUs on this node; used for the global rank and the
            per-process batch size / worker count under DDP.
        args: parsed CLI namespace; mutated in place (gpu, rank, batch_size,
            workers, multigpu, epoch, last_epoch).
    """
    args.gpu = gpu

    # Join the process group first so the loaders get a DistributedSampler
    # and the per-process batch size / worker count.
    if args.distributed:
        _init_distributed(gpu, ngpus_per_node, args)

    train_dataset, train_loader, train_sampler, val_loader = _build_loaders(args)

    criterion_entropy = CrossEntropy(ignore_label=args.ignore_label,
                                     weight=train_dataset.class_weights)

    print("Load model")
    model = models.UnetAdaptiveSegmentation.build(n_classes=args.n_semantic_classes)
    model = _prepare_model(model, args)

    args.epoch = 0
    args.last_epoch = -1

    # Only one process writes checkpoints.
    should_write = ((not args.distributed) or args.rank == 0)
    device = torch.device('cuda') if args.gpu is None else args.gpu

    model.train()

    if args.same_lr:
        print("Using same LR")
        params = model.parameters()
    else:
        print("Using diff LR")
        m = model.module if args.multigpu else model
        params = [{"params": m.get_1x_lr_params(), "lr": args.lr / 10},
                  {"params": m.get_10x_lr_params(), "lr": args.lr}]

    optimizer = optim.AdamW(params, weight_decay=args.wd, lr=args.lr)

    iters = len(train_loader)
    step = args.epoch * iters
    best_mIoU = 0

    scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                              args.lr,
                                              epochs=args.epochs,
                                              steps_per_epoch=iters,
                                              cycle_momentum=True,
                                              base_momentum=0.85, max_momentum=0.95, last_epoch=args.last_epoch,
                                              div_factor=args.div_factor, final_div_factor=args.final_div_factor)
    checkpoint_dir = os.path.join(args.root, "checkpoints")

    print("Start training")
    for cur_epoch in range(args.epoch, args.epochs):
        if train_sampler is not None:
            # DistributedSampler needs a new epoch seed to reshuffle.
            train_sampler.set_epoch(cur_epoch)

        batches = train_loader
        if is_rank_zero(args):
            batches = tqdm(train_loader, desc=f"Epoch: {cur_epoch + 1}/{args.epochs}. Loop: Train",
                           total=len(train_loader))

        for batch in batches:
            optimizer.zero_grad()

            images, labels, _, _ = batch
            images = images.cuda()
            labels = labels.long().cuda()

            outputs = model(images)
            loss = criterion_entropy(outputs, labels)

            loss.backward()
            optimizer.step()

            if step % 10 == 0 and is_rank_zero(args):
                msg = 'Epoch: [{}/{}] Step:[{}], lr: {}, Loss: {:.6f}'.format(
                    cur_epoch, args.epochs, step, [x['lr'] for x in optimizer.param_groups], loss.item())
                logging.info(msg)

            step += 1
            scheduler.step()

            if step % args.validate_every == 0:
                # Every rank must call validate(): under DDP it issues
                # collective reduce ops, so a rank-0-only call would hang.
                model.eval()
                valid_loss, mean_IoU, IoU_array = validate(args, val_loader, model, criterion_entropy,
                                                           cur_epoch, args.epochs, device)

                if should_write:
                    model_io.save_checkpoint(model, optimizer, cur_epoch, f"{args.name}_latest.pt",
                                             root=checkpoint_dir)
                    if mean_IoU > best_mIoU:
                        model_io.save_checkpoint(model, optimizer, cur_epoch, f"{args.name}_best.pt",
                                                 root=checkpoint_dir)
                        best_mIoU = mean_IoU

                model.train()

class AverageMeter:
    """Keep the latest value and a weighted running mean of some metric."""

    def __init__(self):
        # All statistics stay None until the first update() call.
        self.val = self.avg = self.sum = self.count = None
        self.initialized = False

    def initialize(self, val, weight):
        """Seed the statistics with a first observation *val* of weight *weight*."""
        self.val, self.avg = val, val
        self.sum, self.count = val * weight, weight
        self.initialized = True

    def update(self, val, weight=1):
        """Record *val* with *weight*; the first call initializes the meter."""
        record = self.add if self.initialized else self.initialize
        record(val, weight)

    def add(self, val, weight):
        """Fold another weighted observation into the running totals."""
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        """Most recently recorded value (None before any update)."""
        return self.val

    def average(self):
        """Weighted mean of everything recorded (None before any update)."""
        return self.avg
    
def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
    """
    Calculate the confusion matrix for a batch of labels and predictions.

    Args:
        label: integer label tensor; it is sliced as ``[:, :size[-2], :size[-1]]``.
        pred: score tensor whose axis 1 holds the per-class scores; the
            predicted class is its argmax over that axis.
        size: sequence whose last two entries give the spatial size to keep.
        num_class: number of classes C.
        ignore: label value whose pixels are excluded.

    Returns:
        float64 array of shape (C, C); entry [i, j] counts pixels with ground
        truth i predicted as j. Label/prediction pairs outside the C x C grid
        are dropped.
    """
    # Keep predictions as int64: a uint8 cast silently wraps for > 256 classes.
    seg_pred = np.argmax(pred.cpu().numpy(), axis=1).astype(np.int64)
    seg_gt = np.asarray(label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=np.int64)

    keep = seg_gt != ignore
    seg_gt = seg_gt[keep]
    seg_pred = seg_pred[keep]

    n_cells = num_class * num_class
    index = seg_gt * num_class + seg_pred
    # minlength pads missing cells with 0; the slice drops out-of-range pairs.
    counts = np.bincount(index, minlength=n_cells)[:n_cells]
    return counts.reshape(num_class, num_class).astype(np.float64)

def validate(args, test_loader, model, criterion, epoch, epochs, device='cuda'):
    """
    Evaluate *model* on *test_loader*: mean loss, per-class IoU and mean IoU.

    Args:
        args: CLI namespace (reads n_semantic_classes, ignore_label,
            distributed, rank).
        test_loader: DataLoader yielding (images, labels, _, _) batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss function applied to the resized outputs.
        epoch: current epoch, used only for logging.
        epochs: total number of epochs, used only for logging.
        device: device the batches are moved to.

    Returns:
        (average loss, mean IoU, per-class IoU array). Under DDP the reduced
        values are only meaningful on rank 0.
    """
    model.eval()
    n_classes = args.n_semantic_classes
    loss_meter = AverageMeter()
    conf_mat = np.zeros((n_classes, n_classes))

    def _resize(logits, target_hw):
        # Upsample logits to the label resolution before loss / argmax.
        return F.interpolate(logits, size=target_hw, mode='bilinear', align_corners=True)

    with torch.no_grad():
        for batch_idx, (images, labels, _, _) in enumerate(test_loader):
            label_size = labels.size()
            images = images.to(device)
            labels = labels.long().to(device)

            preds = model(images)
            multi_head = isinstance(preds, (list, tuple))
            if multi_head:
                preds = [_resize(head, label_size[-2:]) for head in preds]
            else:
                preds = _resize(preds, label_size[-2:])

            loss = criterion(preds, labels)
            if args.distributed:
                loss = reduce_tensor(loss)
            loss_meter.update(loss.item())

            # Every head contributes to the same confusion matrix.
            for head in (preds if multi_head else [preds]):
                conf_mat += get_confusion_matrix(
                    labels, head, label_size, args.n_semantic_classes, args.ignore_label
                )

            if batch_idx % 10 == 0 and is_rank_zero(args):
                print(f"Validation: Iteration [{batch_idx}/{len(test_loader)}], Loss: {loss.item()}")

    # Combine confusion matrices from all processes onto rank 0.
    if args.distributed:
        conf_mat = reduce_tensor(torch.from_numpy(conf_mat).to(device)).cpu().numpy()

    row_totals = conf_mat.sum(1)
    col_totals = conf_mat.sum(0)
    hits = np.diag(conf_mat)
    IoU_array = hits / np.maximum(1.0, row_totals + col_totals - hits)
    mean_IoU = IoU_array.mean()

    if is_rank_zero(args):
        print(f"Epoch: {epoch}/{epochs}, Validation Loss: {loss_meter.average():.6f}, Mean IoU: {mean_IoU:.6f}")
        logging.info(f"Epoch: {epoch}/{epochs}, IoU per class: {IoU_array}, Mean IoU: {mean_IoU:.6f}")

    return loss_meter.average(), mean_IoU, IoU_array





# Seed the Python, NumPy and torch RNGs for reproducibility. NumPy is seeded
# in case the dataset's augmentation draws from np.random (TODO confirm).
if seed > 0:
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

args.num_threads = args.workers
args.mode = 'train'

# Single-node setup: SLURM-based world_size/rank discovery is disabled.
args.world_size = 1
args.rank = 0
nodes = ["127.0.0.1"]

if args.distributed:
    # main_worker() passes these to dist.init_process_group(); without them
    # --distributed fails with an AttributeError.
    port = np.random.randint(15000, 15025)
    args.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
    args.dist_backend = 'nccl'

ngpus_per_node = torch.cuda.device_count()
args.num_workers = args.workers
args.ngpus_per_node = ngpus_per_node

# Guard the launch: with the spawn start method, mp.spawn re-imports the main
# module in each child, which would otherwise launch training recursively.
if __name__ == "__main__":
    if args.distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        if ngpus_per_node == 1:
            args.gpu = 0
        main_worker(args.gpu, ngpus_per_node, args)
