<a href="https://colab.research.google.com/github/FOGuzman/PulseIlluminationVideo/blob/main/Compressive_Video_via_Pulsed_Illumination.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Compressive video via Pulsed Illumination

## Setup

*   Clone the repository
*   Download the DAVIS 2017 dataset (480p train/val)
*   Install the extra dependencies

In [1]:
!git clone https://github.com/FOGuzman/PulseIlluminationVideo.git
%cd PulseIlluminationVideo/
!wget https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
!mkdir -p ./dataset/
!unzip DAVIS-2017-trainval-480p.zip -d ./dataset/
!pip install -r requirements.txt

Cloning into 'PulseIlluminationVideo'...
remote: Enumerating objects: 429, done.[K
remote: Total 429 (delta 0), reused 0 (delta 0), pack-reused 429[K
Receiving objects: 100% (429/429), 77.38 MiB | 25.85 MiB/s, done.
Resolving deltas: 100% (203/203), done.
/content/PulseIlluminationVideo


## Import libraries

In [4]:
# Project-local modules from the cloned repository (cwd after the setup cell).
from unet import UNet
from inverse import  StandardConv2D
import os
import os.path as osp
import sys 
# NOTE(review): "__file__" is a string literal (notebooks have no __file__), so
# abspath() yields <cwd>/__file__ and BASE_DIR is the parent of the current
# working directory. It is appended to sys.path before `import utils` below.
BASE_DIR = osp.dirname(osp.dirname(osp.abspath("__file__")))
sys.path.append(BASE_DIR)
import utils
from cacti.datasets.builder import build_dataset 
from cacti.models.builder import build_model
from cacti.utils.optim_builder import  build_optimizer
from cacti.utils.loss_builder import build_loss
from torch.utils.data import DataLoader
from cacti.utils.mask import generate_masks
from cacti.utils.config import Config
from cacti.utils.logger import Logger
from cacti.utils.utils import save_image, load_checkpoints, get_device_info
from cacti.utils.eval_coarse import eval_psnr_ssim
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import StepLR
from google.colab import drive
# Side effect: mounts Google Drive at /content/gdrive. The training cells
# mirror checkpoints to Drive only when that path exists.
drive.mount('/content/gdrive')
import time
import argparse 
import json 
import einops

[Errno 2] No such file or directory: 'PulseIlluminationVideo/'
/content/PulseIlluminationVideo


## Train the init block (UNet)

In [5]:
# Settings for training the init block (UNet).
# argparse is used only as an attribute container; no command line is parsed.
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])
args.config = './configs/STFormer/stformer_base.py'
args.work_dir = './train_results/3meas_coarse_nm/'
args.dataset_path = './dataset/DAVIS/JPEGImages/480p/'
args.resolution = [256,256]
args.frames = 16
args.dataset_crop = [128,128]
args.distributed = False
args.resume = None
args.Epochs = 400
args.batch_size = 18
args.learning_rate = 0.0001
args.saveImageEach = 500
args.saveModelEach = 2
args.checkpoints = None
args.local_rank = -1
# Fall back to CPU when no GPU is available.
args.device = "cuda" if torch.cuda.is_available() else "cpu"

# When Google Drive is mounted, also mirror checkpoints there. The results
# folder is created if missing; exist_ok=True keeps this safe on re-runs.
# gdFlag is always defined so the checkpoint code can test it unconditionally.
gdFlag = False
if os.path.exists('/content/gdrive'):
    print("GDrive Mounted, saving results on MyDrive/PulsedIlluminationRepository/results/")
    args.gdrivepath = "/content/gdrive/MyDrive/PulsedIlluminationRepository/results/"
    os.makedirs(args.gdrivepath, exist_ok=True)
    gdFlag = True

if __name__ == '__main__':
    # ---- Configuration: override the base config with the settings above ----
    cfg = Config.fromfile(args.config)
    cfg.resize_h, cfg.resize_w = args.resolution
    cfg.crop_h, cfg.crop_w = args.dataset_crop

    # NOTE(review): indices 4 and 1 assume the resize / crop steps sit at those
    # positions in cfg.train_pipeline -- confirm against the config file.
    cfg.train_pipeline[4]['resize_h'], cfg.train_pipeline[4]['resize_w'] = args.resolution
    cfg.train_pipeline[1]['crop_h'], cfg.train_pipeline[1]['crop_w'] = args.dataset_crop
    cfg.train_data.mask_shape = (args.resolution[0], args.resolution[1], args.frames)

    cfg.save_image_config['interval'] = args.saveImageEach
    # The epoch loop below reads cfg.runner.max_epochs, so that is the key to override.
    cfg.runner['max_epochs'] = args.Epochs
    cfg.optimizer['lr'] = args.learning_rate
    cfg.data['samples_per_gpu'] = args.batch_size
    cfg.train_data['data_root'] = args.dataset_path
    cfg.checkpoints = args.checkpoints
    cfg.checkpoint_config['interval'] = args.saveModelEach

    if args.work_dir is None:
        args.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])

    if args.resume is not None:
        cfg.resume = args.resume

    # ---- Output folders, logging ----
    log_dir = osp.join(args.work_dir, "log")
    show_dir = osp.join(args.work_dir, "show")
    train_image_save_dir = osp.join(args.work_dir, "train_images")
    checkpoints_dir = osp.join(args.work_dir, "checkpoints")
    for out_dir in (log_dir, show_dir, train_image_save_dir, checkpoints_dir):
        os.makedirs(out_dir, exist_ok=True)

    logger = Logger(log_dir)
    writer = SummaryWriter(log_dir=show_dir)

    rank = 0
    if args.distributed:
        local_rank = int(args.local_rank)
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()

    dash_line = '-' * 80 + '\n'
    device_info = get_device_info()
    env_info = '\n'.join(['{}: {}'.format(k, v) for k, v in device_info.items()])

    # ---- Model ----
    # Honour the CPU fallback chosen in the settings cell instead of forcing CUDA.
    device = args.device
    model = UNet(in_channel=16, out_channel=14, instance_norm=False).to(device)

    if rank == 0:
        logger.info('GPU info:\n'
                + dash_line +
                env_info + '\n' +
                dash_line)
        logger.info('cfg info:\n'
                + dash_line +
                json.dumps(cfg, indent=4) + '\n' +
                dash_line)
        logger.info('Model info:\n'
                + dash_line +
                str(model) + '\n' +
                dash_line)

    # ---- Data ----
    mask, mask_s = generate_masks(cfg.train_data.mask_path, cfg.train_data.mask_shape)
    train_data = build_dataset(cfg.train_data, {"mask": mask})
    if cfg.eval.flag:
        test_data = build_dataset(cfg.test_data, {"mask": mask})

    if args.distributed:
        dist_sampler = DistributedSampler(train_data, shuffle=True)
        train_data_loader = DataLoader(dataset=train_data,
                                       batch_size=args.batch_size,
                                       sampler=dist_sampler,
                                       num_workers=cfg.data.workers_per_gpu)
    else:
        train_data_loader = DataLoader(dataset=train_data,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=cfg.data.workers_per_gpu)

    optimizer = build_optimizer(cfg.optimizer, {"params": model.parameters()})
    # NOTE(review): criterion is built from cfg.loss but this stage trains with
    # utils.weighted_L1loss + a gradient (TV-style) penalty instead.
    criterion = build_loss(cfg.loss).to(args.device)

    # ---- Optional pre-trained weights / resume ----
    start_epoch = 0
    if rank == 0:
        if cfg.checkpoints is not None:
            logger.info("Load pre_train model...")
            # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
            resume_dict = torch.load(cfg.checkpoints, map_location=device)
            # Accept both a bare state dict and a full training checkpoint.
            model_state_dict = resume_dict.get("model_state_dict", resume_dict)
            load_checkpoints(model, model_state_dict)
        else:
            logger.info("No pre_train model")

        if cfg.resume is not None:
            logger.info("Load resume...")
            resume_dict = torch.load(cfg.resume, map_location=device)
            # Checkpoint "epoch" is written after that epoch finished training,
            # so continue with the next one instead of repeating it.
            start_epoch = resume_dict["epoch"] + 1
            load_checkpoints(model, resume_dict["model_state_dict"])
            optimizer.load_state_dict(resume_dict["optim_state_dict"])

    if args.distributed:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    # ---- Training loop ----
    iter_num = len(train_data_loader)
    for epoch in range(start_epoch, cfg.runner.max_epochs):
        if args.distributed:
            # Required for DistributedSampler to reshuffle differently each epoch.
            dist_sampler.set_epoch(epoch)
        epoch_loss = 0
        model = model.train()
        start_time = time.time()
        for iteration, data in enumerate(train_data_loader):
            gt, meas = data
            gt = gt.float().to(args.device)
            meas = meas.unsqueeze(1).float().to(args.device)

            # Network input: first GT frame, measurement, last GT frame (channel dim).
            meas_f = torch.cat((gt[:, 0:1, :, :], meas, gt[:, -1:, :, :]), 1)

            batch_size = meas.shape[0]
            Phi = torch.from_numpy(einops.repeat(mask, 'cr h w->b cr h w', b=batch_size)).to(args.device)
            Phi_s = torch.from_numpy(einops.repeat(mask_s, 'h w->b 1 h w', b=batch_size)).to(args.device)

            optimizer.zero_grad()
            model_out = model(meas_f, Phi, Phi_s)

            # The network predicts the inner frames; put the first/last GT frames
            # back around them before comparing with the full GT sequence.
            model_out_f = torch.cat((gt[:, 0:1, :, :], model_out, gt[:, -1:, :, :]), 1)

            final_loss = utils.weighted_L1loss(model_out_f, gt)
            # Mean absolute spatial gradient of the predicted frames (TV-style penalty).
            tv_loss = utils.gradx(model_out).abs().mean() + utils.grady(model_out).abs().mean()
            loss = final_loss + 0.1 * tv_loss
            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

            if rank == 0 and (iteration % cfg.log_config.interval) == 0:
                lr = optimizer.state_dict()["param_groups"][0]["lr"]
                iter_len = len(str(iter_num))
                logger.info("epoch: [{}][{:>{}}/{}], lr: {:.6f}, loss: {:.5f}.".format(epoch, iteration, iter_len, iter_num, lr, loss.item()))
                writer.add_scalar("loss", loss.item(), epoch * iter_num + iteration)
            if rank == 0 and (iteration % cfg.save_image_config.interval) == 0:
                sing_out = model_out_f[0].detach().cpu().numpy()
                sing_gt = gt[0].cpu().numpy()
                image_name = osp.join(train_image_save_dir, str(epoch) + "_" + str(iteration) + ".png")
                save_image(sing_out, sing_gt, mask, image_name)
        end_time = time.time()
        if rank == 0:
            # max(..., 1) guards against an empty loader.
            logger.info("epoch: {}, avg_loss: {:.5f}, time: {:.2f}s.\n".format(epoch, epoch_loss / max(iter_num, 1), end_time - start_time))

        # ---- Checkpointing ----
        if rank == 0 and (epoch % cfg.checkpoint_config.interval) == 0:
            save_model = model.module if args.distributed else model
            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": save_model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
            }
            torch.save(checkpoint_dict, osp.join(checkpoints_dir, "epoch_" + str(epoch) + ".pth"))
            if gdFlag:
                torch.save(checkpoint_dict, osp.join(args.gdrivepath, "epoch_" + str(epoch) + ".pth"))

        # ---- Evaluation ----
        if rank == 0 and cfg.eval.flag and epoch % cfg.eval.interval == 0:
            eval_model = model.module if args.distributed else model
            psnr_dict, ssim_dict = eval_psnr_ssim(eval_model, test_data, mask, mask_s, args)

            psnr_str = ", ".join([key + ": " + "{:.4f}".format(psnr_dict[key]) for key in psnr_dict.keys()])
            ssim_str = ", ".join([key + ": " + "{:.4f}".format(ssim_dict[key]) for key in ssim_dict.keys()])
            logger.info("Mean PSNR: \n{}.\n".format(psnr_str))
            logger.info("Mean SSIM: \n{}.\n".format(ssim_str))
        
        

INFO:root:GPU info:
--------------------------------------------------------------------------------
CUDA available: True
GPU numbers: 1
GPU INFO: [{'GPU 0': 'Tesla T4'}]
--------------------------------------------------------------------------------

2023-05-20 03:12:47,707 - <ipython-input-5-1c8d53db53e1> [line: 79] - GPU info:
--------------------------------------------------------------------------------
CUDA available: True
GPU numbers: 1
GPU INFO: [{'GPU 0': 'Tesla T4'}]
--------------------------------------------------------------------------------

INFO:root:cfg info:
--------------------------------------------------------------------------------
{
    "test_data": {
        "type": "SixGraySimData",
        "data_root": "test_datasets/simulation",
        "mask_path": "test_datasets/mask/shutter_mask16.mat",
        "mask_shape": null
    },
    "resize_h": 256,
    "resize_w": 256,
    "train_pipeline": [
        {
            "type": "RandomResize"
        },
        {
 

KeyboardInterrupt: ignored

## Transfer learning: load the trained init block (UNet) and train the ST Transformer

In [None]:
# Settings for the second stage: the UNet init block (loaded from
# args.initModelPath) feeds the model built from args.config.
# argparse is used only as an attribute container; no command line is parsed.
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])
args.config = './configs/STFormer/stformer_base.py'
args.work_dir = './train_results/3meas_coarse_nm/'
args.dataset_path = './dataset/DAVIS/JPEGImages/480p/'
args.initModelPath = './train_results/3meas_coarse_nm/checkpoints/epoch_0.pth'
args.resolution = [256,256]
args.frames = 16
args.dataset_crop = [128,128]
args.distributed = False
args.resume = None
args.Epochs = 400
args.batch_size = 1
args.learning_rate = 0.0001
args.saveImageEach = 500
args.saveModelEach = 2
args.checkpoints = None
args.local_rank = -1
# Fall back to CPU when no GPU is available.
args.device = "cuda" if torch.cuda.is_available() else "cpu"

# When Google Drive is mounted, also mirror checkpoints there. The results
# folder is created if missing; exist_ok=True keeps this safe on re-runs.
# gdFlag is always defined so the checkpoint code can test it unconditionally.
gdFlag = False
if os.path.exists('/content/gdrive'):
    print("GDrive Mounted, saving results on MyDrive/PulsedIlluminationRepository/results/")
    args.gdrivepath = "/content/gdrive/MyDrive/PulsedIlluminationRepository/results/"
    os.makedirs(args.gdrivepath, exist_ok=True)
    gdFlag = True



if __name__ == '__main__':
    # ---- Configuration: override the base config with the settings above ----
    cfg = Config.fromfile(args.config)
    cfg.resize_h, cfg.resize_w = args.resolution
    cfg.crop_h, cfg.crop_w = args.dataset_crop

    # NOTE(review): indices 4 and 1 assume the resize / crop steps sit at those
    # positions in cfg.train_pipeline -- confirm against the config file.
    cfg.train_pipeline[4]['resize_h'], cfg.train_pipeline[4]['resize_w'] = args.resolution
    cfg.train_pipeline[1]['crop_h'], cfg.train_pipeline[1]['crop_w'] = args.dataset_crop
    cfg.train_data.mask_shape = (args.resolution[0], args.resolution[1], args.frames)

    cfg.save_image_config['interval'] = args.saveImageEach
    # The epoch loop below reads cfg.runner.max_epochs, so that is the key to override.
    cfg.runner['max_epochs'] = args.Epochs
    cfg.train_data['data_root'] = args.dataset_path
    cfg.checkpoints = args.checkpoints
    cfg.checkpoint_config['interval'] = args.saveModelEach
    cfg.optimizer['lr'] = args.learning_rate
    cfg.data['samples_per_gpu'] = args.batch_size
    if args.work_dir is None:
        args.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])

    if args.resume is not None:
        cfg.resume = args.resume

    # ---- Output folders, logging ----
    log_dir = osp.join(args.work_dir, "log")
    show_dir = osp.join(args.work_dir, "show")
    train_image_save_dir = osp.join(args.work_dir, "train_images")
    checkpoints_dir = osp.join(args.work_dir, "checkpoints")
    for out_dir in (log_dir, show_dir, train_image_save_dir, checkpoints_dir):
        os.makedirs(out_dir, exist_ok=True)

    logger = Logger(log_dir)
    writer = SummaryWriter(log_dir=show_dir)

    rank = 0
    if args.distributed:
        local_rank = int(args.local_rank)
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()

    dash_line = '-' * 80 + '\n'
    device_info = get_device_info()
    env_info = '\n'.join(['{}: {}'.format(k, v) for k, v in device_info.items()])

    # ---- Models ----
    device = args.device
    # Main reconstruction model from cfg.model (presumably the STFormer, given the config path).
    model = build_model(cfg.model).to(device)

    # Init block (UNet) from the previous section; its output is the input of `model`.
    DeModel = UNet(in_channel=16, out_channel=14, instance_norm=False).to(device)
    if os.path.exists(args.initModelPath):
        # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
        resume_dict = torch.load(args.initModelPath, map_location=device)
        load_checkpoints(DeModel, resume_dict["model_state_dict"])
        print("Init Model checkpoint loaded")
    else:
        print("Init Model checkpoint not found. Starting from scratch")

    # Only model.parameters() are handed to the optimizer, so DeModel's weights
    # are never updated. Disabling their gradients avoids computing and
    # accumulating gradients that nothing consumes. To fine-tune the init block
    # instead, drop this line and add DeModel.parameters() to the optimizer.
    DeModel.requires_grad_(False)

    if rank == 0:
        logger.info('GPU info:\n'
                + dash_line +
                env_info + '\n' +
                dash_line)
        logger.info('cfg info:\n'
                + dash_line +
                json.dumps(cfg, indent=4) + '\n' +
                dash_line)
        logger.info('Model info:\n'
                + dash_line +
                str(model) + '\n' +
                dash_line)

    # ---- Data ----
    mask, mask_s = generate_masks(cfg.train_data.mask_path, cfg.train_data.mask_shape)
    train_data = build_dataset(cfg.train_data, {"mask": mask})
    if cfg.eval.flag:
        test_data = build_dataset(cfg.test_data, {"mask": mask})

    if args.distributed:
        dist_sampler = DistributedSampler(train_data, shuffle=True)
        train_data_loader = DataLoader(dataset=train_data,
                                       batch_size=cfg.data.samples_per_gpu,
                                       sampler=dist_sampler,
                                       num_workers=cfg.data.workers_per_gpu)
    else:
        train_data_loader = DataLoader(dataset=train_data,
                                       batch_size=cfg.data.samples_per_gpu,
                                       shuffle=True,
                                       num_workers=cfg.data.workers_per_gpu)

    optimizer = build_optimizer(cfg.optimizer, {"params": model.parameters()})
    criterion = build_loss(cfg.loss).to(args.device)

    # ---- Optional pre-trained weights / resume ----
    start_epoch = 0
    if rank == 0:
        if cfg.checkpoints is not None:
            logger.info("Load pre_train model...")
            resume_dict = torch.load(cfg.checkpoints, map_location=device)
            # Accept both a bare state dict and a full training checkpoint.
            model_state_dict = resume_dict.get("model_state_dict", resume_dict)
            load_checkpoints(model, model_state_dict)
        else:
            logger.info("No pre_train model")

        if cfg.resume is not None:
            logger.info("Load resume...")
            resume_dict = torch.load(cfg.resume, map_location=device)
            # Checkpoint "epoch" is written after that epoch finished training,
            # so continue with the next one instead of repeating it.
            start_epoch = resume_dict["epoch"] + 1
            load_checkpoints(model, resume_dict["model_state_dict"])
            load_checkpoints(DeModel, resume_dict["Demodel_state_dict"])
            optimizer.load_state_dict(resume_dict["optim_state_dict"])

    # Only `model` is wrapped in DDP; DeModel stays a plain module everywhere below.
    if args.distributed:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    # ---- Training loop ----
    iter_num = len(train_data_loader)
    for epoch in range(start_epoch, cfg.runner.max_epochs):
        if args.distributed:
            # Required for DistributedSampler to reshuffle differently each epoch.
            dist_sampler.set_epoch(epoch)
        epoch_loss = 0
        # NOTE(review): only `model` is put back in train mode here; if
        # eval_psnr_ssim switches DeModel to eval mode it stays there -- confirm.
        model = model.train()
        start_time = time.time()
        for iteration, data in enumerate(train_data_loader):
            gt, meas = data
            gt = gt.float().to(args.device)
            meas = meas.unsqueeze(1).float().to(args.device)
            batch_size = meas.shape[0]

            Phi = torch.from_numpy(einops.repeat(mask, 'cr h w->b cr h w', b=batch_size)).to(args.device)
            Phi_s = torch.from_numpy(einops.repeat(mask_s, 'h w->b 1 h w', b=batch_size)).to(args.device)

            optimizer.zero_grad()
            # Network input: first GT frame, measurement, last GT frame (channel dim).
            meas_f = torch.cat((gt[:, 0:1, :, :], meas, gt[:, -1:, :, :]), 1)

            # Two stages: init block first, then the main model on its output.
            de_meas = DeModel(meas_f, Phi, Phi_s)
            model_out = model(de_meas, Phi, Phi_s)
            # model_out is indexed with [0]; presumably `model` returns a list of
            # outputs whose first entry is the predicted inner frames -- confirm.
            model_out_f = torch.cat((gt[:, 0:1, :, :], model_out[0], gt[:, -1:, :, :]), 1)

            loss = torch.sqrt(criterion(model_out_f, gt))
            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

            if rank == 0 and (iteration % cfg.log_config.interval) == 0:
                lr = optimizer.state_dict()["param_groups"][0]["lr"]
                iter_len = len(str(iter_num))
                logger.info("epoch: [{}][{:>{}}/{}], lr: {:.6f}, loss: {:.5f}.".format(epoch, iteration, iter_len, iter_num, lr, loss.item()))
                writer.add_scalar("loss", loss.item(), epoch * iter_num + iteration)
            if rank == 0 and (iteration % cfg.save_image_config.interval) == 0:
                sing_out = model_out_f[0].detach().cpu().numpy()
                sing_gt = gt[0].cpu().numpy()
                image_name = osp.join(train_image_save_dir, str(epoch) + "_" + str(iteration) + ".png")
                save_image(sing_out, sing_gt, mask, image_name)
        end_time = time.time()
        if rank == 0:
            # max(..., 1) guards against an empty loader.
            logger.info("epoch: {}, avg_loss: {:.5f}, time: {:.2f}s.\n".format(epoch, epoch_loss / max(iter_num, 1), end_time - start_time))

        # ---- Checkpointing ----
        if rank == 0 and (epoch % cfg.checkpoint_config.interval) == 0:
            save_model = model.module if args.distributed else model
            checkpoint_dict = {
                "epoch": epoch,
                "model_state_dict": save_model.state_dict(),
                "Demodel_state_dict": DeModel.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
            }
            torch.save(checkpoint_dict, osp.join(checkpoints_dir, "epoch_" + str(epoch) + ".pth"))
            if gdFlag:
                torch.save(checkpoint_dict, osp.join(args.gdrivepath, "epoch_" + str(epoch) + ".pth"))

        # ---- Evaluation (same call signature in both modes) ----
        if rank == 0 and cfg.eval.flag and epoch % cfg.eval.interval == 0:
            eval_model = model.module if args.distributed else model
            psnr_dict, ssim_dict = eval_psnr_ssim(eval_model, DeModel, test_data, mask, mask_s, args)

            psnr_str = ", ".join([key + ": " + "{:.4f}".format(psnr_dict[key]) for key in psnr_dict.keys()])
            ssim_str = ", ".join([key + ": " + "{:.4f}".format(ssim_dict[key]) for key in ssim_dict.keys()])
            logger.info("Mean PSNR: \n{}.\n".format(psnr_str))
            logger.info("Mean SSIM: \n{}.\n".format(ssim_str))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
INFO:root:GPU info:
--------------------------------------------------------------------------------
CUDA available: True
GPU numbers: 1
GPU INFO: [{'GPU 0': 'Tesla T4'}]
--------------------------------------------------------------------------------

2023-05-20 03:54:56,568 - <ipython-input-6-b68ba15ca55d> [line: 93] - GPU info:
--------------------------------------------------------------------------------
CUDA available: True
GPU numbers: 1
GPU INFO: [{'GPU 0': 'Tesla T4'}]
--------------------------------------------------------------------------------

2023-05-20 03:54:56,568 - <ipython-input-6-b68ba15ca55d> [line: 93] - GPU info:
--------------------------------------------------------------------------------
CUDA available: True
GPU numbers: 1
GPU INFO: [{'GPU 0': 'Tesla T4'}]
--------------------------------------------------------------------------------

INFO:root:cfg info:
------------------------------

Init Model checkpoint loaded


INFO:root:No pre_train model
2023-05-20 03:54:56,827 - <ipython-input-6-b68ba15ca55d> [line: 137] - No pre_train model
2023-05-20 03:54:56,827 - <ipython-input-6-b68ba15ca55d> [line: 137] - No pre_train model
INFO:root:epoch: [0][    0/11481], lr: 0.000100, loss: 0.55150.
2023-05-20 03:55:00,569 - <ipython-input-6-b68ba15ca55d> [line: 190] - epoch: [0][    0/11481], lr: 0.000100, loss: 0.55150.
2023-05-20 03:55:00,569 - <ipython-input-6-b68ba15ca55d> [line: 190] - epoch: [0][    0/11481], lr: 0.000100, loss: 0.55150.
