In [3]:
import torch

In [8]:
# Checkpoint to inspect: run directory followed by the checkpoint file name.
checkpoint_path = (
    'logs/2024_0722_0109_29_base128_bs4_lmdb/'
    'epoch_00015_iteration_000016000_checkpoint.pt'
)

In [None]:
# Load the checkpoint with every storage mapped onto the CPU, so the file can be
# inspected without the device it was saved from.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only open checkpoints that come from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

In [7]:
# Show the top-level entries stored in the checkpoint. The call parentheses are
# required: a bare `checkpoint.keys` only displays the bound method.
# NOTE(review): assumes the loaded checkpoint is a dict-like object — confirm.
checkpoint.keys()

KeyboardInterrupt: 

In [1]:
import argparse
import os
import sys
import random

import torch.autograd.profiler as profiler
import wandb

import imaginaire.config
from imaginaire.config import Config
from imaginaire.utils.cudnn import init_cudnn
from imaginaire.utils.dataset import get_train_and_val_dataloader
from imaginaire.utils.distributed import init_dist, is_master, get_world_size
from imaginaire.utils.distributed import master_only_print as print
# from imaginaire.utils.gpu_affinity import set_affinity
from imaginaire.utils.misc import slice_tensor
from imaginaire.utils.logging import init_logging, make_logging_dir
from imaginaire.utils.trainer import (get_model_optimizer_and_scheduler,
                                      get_trainer, set_random_seed)

# Make the directory named by $SUBMIT_SCRIPTS importable; fall back to the
# current directory when the variable is not set.
sys.path.append(os.getenv('SUBMIT_SCRIPTS', '.'))

In [2]:
# Path of the YAML training config handed to Config below.
# NOTE(review): this is a machine-specific absolute path; adjust it when running
# on another host.
config = (
    "/home/hpc/i9vl/i9vl106h/imaginaire/configs/projects/"
    "spade/kitti/base128_bs4_lmdb.yaml"
)

In [3]:
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv (list of str, optional): Argument strings to parse. When left as
            ``None`` the arguments are taken from ``sys.argv``, which is what
            the script entry point relies on. Passing an explicit list makes
            the parser usable from a notebook, where ``sys.argv`` holds the
            kernel's own arguments, and from tests.

    Returns:
        argparse.Namespace: The parsed options. ``--config`` is mandatory;
        every other option falls back to its default.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config',
                        help='Path to the training config file.', required=True)
    parser.add_argument('--logdir', help='Dir for saving logs and models.')
    parser.add_argument('--checkpoint', default='', help='Checkpoint path.')
    parser.add_argument('--seed', type=int, default=2, help='Random seed.')
    parser.add_argument('--randomized_seed', action='store_true', help='Use a random seed between 0-10000.')
    # The default comes from the LOCAL_RANK environment variable. argparse runs
    # string defaults through `type`, so the value ends up as an int either way.
    parser.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--single_gpu', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--use_jit', action='store_true')
    parser.add_argument('--profile', action='store_true')
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--wandb_name', default='default', type=str)
    parser.add_argument('--wandb_id', type=str)
    parser.add_argument('--resume', type=int)
    parser.add_argument('--num_workers', type=int)
    return parser.parse_args(argv)


In [4]:
# Build the Config object from the config file path stored in `config`.
cfg = Config(config)

In [5]:
# Create the training and validation data loaders described by `cfg`.
# NOTE(review): the second positional argument (42) is presumably a random seed,
# like the seed=42 passed when the networks are built — confirm against the
# helper's signature.
train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg, 42)

LMDB file at /home/woody/i9vl/i9vl106h/data/kitti/lmdb/train/images opened.
LMDB file at /home/woody/i9vl/i9vl106h/data/kitti/lmdb/train/label opened.
Num datasets: 1
Num sequences: 1
Max sequence length: 4022
LMDB file at /home/woody/i9vl/i9vl106h/data/kitti/lmdb/val/images opened.
LMDB file at /home/woody/i9vl/i9vl106h/data/kitti/lmdb/val/label opened.
Num datasets: 1
Num sequences: 1
Max sequence length: 1005
Train dataset length: 4022
Val dataset length: 1005


In [6]:
from matplotlib import pyplot as plt

In [13]:
# Build the generator/discriminator pair together with their optimizers and
# learning-rate schedulers, all from `cfg` and a fixed seed.
(net_G, net_D,
 opt_G, opt_D,
 sch_G, sch_D) = get_model_optimizer_and_scheduler(cfg, seed=42)

Using random seed 42
SPADE generator initialization.
Concatenate images:
    ext: png
    num_channels: 1
    normalize: False
    is_mask: False
    pre_aug_ops: None
    post_aug_ops: None
    computed_on_the_fly: False for input.
	Num. of channels in the input image: 1
Concatenate images:
    ext: png
    num_channels: 1
    normalize: False
    is_mask: False
    pre_aug_ops: None
    post_aug_ops: None
    computed_on_the_fly: False for input.
Concatenate label:
    ext: png
    num_channels: 3
    normalize: True
    use_dont_care: False
    is_mask: False
    pre_aug_ops: None
    post_aug_ops: None
    computed_on_the_fly: False for input.
	Num. of channels in the input label: 3
dict_keys(['resize_smallest_side', 'rotate', 'random_scale_limit', 'horizontal_flip', 'random_crop_h_w'])
	Crop size: (256, 1000)
	Style code dimensions: 256
	Base filter number: 128
	Convolution kernel size: 3
	Weight norm type: spectral
num_filters: 128
kernel_size: 5
separate_projection: True
activat

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Done with the SPADE generator initialization.
Multi-resolution patch discriminator initialization.
Concatenate images:
    ext: png
    num_channels: 1
    normalize: False
    is_mask: False
    pre_aug_ops: None
    post_aug_ops: None
    computed_on_the_fly: False for input.
	Num. of channels in the input image: 1
Concatenate images:
    ext: png
    num_channels: 1
    normalize: False
    is_mask: False
    pre_aug_ops: None
    post_aug_ops: None
    computed_on_the_fly: False for input.
Concatenate label:
    ext: png
    num_channels: 3
    normalize: True
    use_dont_care: False
    is_mask: False
    pre_aug_ops: None
    post_aug_ops: None
    computed_on_the_fly: False for input.
	Num. of channels in the input label: 3
	Base filter number: 128
	Number of discriminators: 2
	Number of layers in a discriminator: 5
	Weight norm type: spectral
Done with the Multi-resolution patch discriminator initialization.
Initialize net_G and net_D weights using type: xavier gain: 0.02
Usin

In [14]:
# Bundle the networks, optimizers, schedulers and data loaders into a trainer.
trainer = get_trainer(
    cfg,
    net_G, net_D,
    opt_G, opt_D,
    sch_G, sch_D,
    train_data_loader, val_data_loader,
)
    

Setup trainer.
Using automatic mixed precision training.
Augmentation policy: 
GAN mode: hinge




Perceptual loss:
	Mode: vgg19
Loss GAN                  Weight 1.0
Loss Perceptual           Weight 10.0
Loss FeatureMatching      Weight 10.0
Loss GaussianKL           Weight 0.05


In [6]:
# Pull one batch from a fresh iterator over the training loader, for inspection.
batch = next(iter(train_data_loader))

  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(


In [11]:
# Epoch and iteration counters both start at zero.
current_epoch = current_iteration = 0
# Width of the slice of each loaded batch handed to one dis/gen update.
batch_size = 2

In [15]:
# Main training loop: for every epoch, run the configured number of
# discriminator and generator updates on each batch from the training loader.
#
# NOTE(review): the recorded run of this cell stopped with
# "Exception: Log writer not set." after iteration 100. init_logging and
# make_logging_dir are imported but never called in this notebook; presumably
# logging has to be initialised before training — confirm.
for epoch in range(current_epoch, cfg.max_epoch):
    print('Epoch {} ...'.format(epoch))
    trainer.start_of_epoch(current_epoch)
    for it, data in enumerate(train_data_loader):
        # The profiler is constructed but disabled (enabled=False).
        with profiler.profile(enabled=False,
                              use_cuda=True,
                              profile_memory=True,
                              record_shapes=True) as prof:
            data = trainer.start_of_iteration(data, current_iteration)

            # Each update step gets its own batch_size-wide slice of the batch.
            for i in range(cfg.trainer.dis_step):
                trainer.dis_update(
                    slice_tensor(data, i * batch_size,
                                 (i + 1) * batch_size))
            for i in range(cfg.trainer.gen_step):
                trainer.gen_update(
                    slice_tensor(data, i * batch_size,
                                 (i + 1) * batch_size))

            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
    # Advance the epoch counter. Without this the trainer hooks above are told
    # "epoch 0" for the whole run, and re-running the cell restarts from epoch 0.
    current_epoch += 1

Epoch 0 ...


  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(
  warn(


Generator overflowed!
Generator overflowed!
Generator overflowed!
Generator overflowed!
Generator overflowed!
Generator overflowed!
Generator overflowed!
Generator overflowed!
Discriminator overflowed!
Discriminator overflowed!
Discriminator overflowed!
Discriminator overflowed!
Discriminator overflowed!
Discriminator overflowed!
Discriminator overflowed!
Iteration: 100, average iter time: 0.620528.


Exception: Log writer not set.