In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from fastai import *
from fastai.vision import *
import torch
from torch import nn
from dsin.ae.data_manager.data_loader import (
    SideinformationImageImageList, ImageSiTuple)
from dsin.ae import config
from dsin.ae.si_ae import SideInformationAutoEncoder
from dsin.ae.si_net import SiNetChannelIn
from dsin.ae.loss_man import LossManager
from dsin.ae.distortions import Distortions, DistTypes
from dsin.ae.kitti_normalizer import ChangeImageStatsToKitti, ChangeState
from dsin.ae import config


In [3]:
# Run-level overrides of `dsin.ae.config`.
# H_target: entropy target. The logging callbacks below report bpp as entropy / 2,
# so 2 * 0.3 presumably targets ~0.3 bpp — confirm against LossManager.
config.H_target = 2* 0.3
# NOTE(review): `MS_SSMIM` looks like a misspelling (cf. `K_MS_SSIM`); check DistTypes before re-enabling.
# config.autoencoder_loss_distortion_to_minimize=DistTypes.MS_SSMIM
config.K_MS_SSIM=500
# Display the current value of config.beta.
config.beta

500

In [4]:
import time
import datetime
import logging

logger = logging.getLogger()

def setup_file_logger(log_file='out.log'):
    """Attach a file handler for `log_file` to the root `logger` and log "start".

    Safe to call repeatedly (e.g. when this notebook cell is re-run): if a
    ``logging.FileHandler`` for the same absolute path is already attached, it is
    reused instead of adding a duplicate. A duplicate would write every message
    once per handler.

    Args:
        log_file: path of the log file (relative paths resolve against the cwd).

    Returns:
        The ``FileHandler`` that writes to `log_file`.
    """
    import os

    target = os.path.abspath(log_file)
    hdlr = next(
        (h for h in logger.handlers
         if isinstance(h, logging.FileHandler) and h.baseFilename == target),
        None,
    )
    if hdlr is None:
        hdlr = logging.FileHandler(log_file)
        hdlr.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    logger.info("start")
    return hdlr

setup_file_logger()

In [5]:
class AverageMetric(Callback):
    """Metric callback: batch-size-weighted mean of `func(prediction, target)` over an epoch.

    `name` is stored as `self.name` (presumably the label fastai shows for the metric).
    """

    def __init__(self, func, name):
        self.func = func
        self.name = name

    def on_epoch_begin(self, **kwargs):
        "Reset the weighted sum and the sample count."
        self.val = 0.
        self.count = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Add this batch's score, weighted by its batch size."
        targets = last_target if is_listy(last_target) else [last_target]
        target = targets[0]
        batch_size = target.size(0)
        self.count += batch_size
        # Score X_HAT_OUT (index 0) when present, otherwise fall back to X_DEC (index 1).
        prediction = last_output[1] if last_output[0] is None else last_output[0]
        score = self.func(prediction, target)
        self.val += batch_size * score.detach().cpu()

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the epoch mean (sum / count) to `last_metrics`."
        return add_metrics(last_metrics, self.val / self.count)
    
class ParameterMetricCallback(Callback):
    "Append the loss manager's current `soft_bit_entropy` to the progress-bar comment."

    def __init__(self, loss_man):
        self.loss_man = loss_man

    def on_backward_begin(self, *args, **kwargs):
        pbar = kwargs["pbar"]
        self.pbar = pbar
        # Nothing to show until the loss manager has computed an entropy value.
        if not hasattr(self.loss_man, 'soft_bit_entropy'):
            return
        pbar.child.comment += f' soft_bit_entropy: {self.loss_man.soft_bit_entropy:.4f}'

class ParameterRunningAverageMetricCallback(Callback):
    """Progress-bar readout of a running average of the bit entropy plus loss components.

    On each `on_backward_begin`, this callback does two things:
    - Folds `loss_man.soft_bit_entropy` into an exponential moving average (EMA).
    - Appends the following to the progress-bar comment: the EMA halved (labelled
      `avg_bpp`), the mean/variance of the importance map, and the loss manager's
      loss components.

    Args:
        loss_man: loss object whose `soft_bit_entropy` and loss-value attributes are read.
        alpha: EMA weight given to the newest value (default 0.1).
    """

    def __init__(self, loss_man, alpha=0.1):
        self.loss_man = loss_man
        self.alpha = alpha
        self.val = None  # EMA of soft_bit_entropy; None until the first value is seen

    def on_backward_begin(self, *args, **kwargs):
        self.pbar = kwargs["pbar"]
        # Element 3 of the model output is used as the importance map (same index BitEntropy reads).
        self.importance_map = kwargs["last_output"][3].detach()
        if not hasattr(self.loss_man, 'soft_bit_entropy'):
            return

        entropy = self.loss_man.soft_bit_entropy.detach()
        if self.val is None:
            self.val = entropy
        else:
            # Out-of-place update: `detach()` shares storage with the loss manager's tensor,
            # so an in-place `*=` / `+=` would write into that tensor's memory.
            self.val = (1 - self.alpha) * self.val + self.alpha * entropy

        lm = self.loss_man
        imp_mean = torch.mean(self.importance_map)
        imp_var = torch.var(self.importance_map)
        # Space-separated `name=value` fields; previously labels and values ran together.
        self.pbar.child.comment += (
            f' avg_bpp: {self.val / 2 :.4f} imp-mean-var ({imp_mean:.4f}, {imp_var:.4f}) '
            f'autoencoder_loss_value={lm.autoencoder_loss_value:.1f} '
            f'l2_reg_loss={lm.l2_reg_loss:.1f} '
            f'si_net_loss_value={lm.si_net_loss_value:.1f} '
            f'feat_loss_value={lm.feat_loss_value:.1f} '
            f'bit_cost_loss_value={lm.bit_cost_loss_value:.1f} '
        )

In [6]:
class BitEntropy(Callback):
    """Exponential moving average of `loss_man.soft_bit_entropy`, logged periodically.

    Every `log_every` batches, a summary line is sent to `logger` and printed. The line
    contains the EMA halved (reported as bpp), importance-map mean/variance, and loss
    components. At epoch end, the EMA is appended to `last_metrics`.

    Args:
        loss_man: loss object whose `soft_bit_entropy` and loss-value attributes are read.
        alpha: EMA weight given to the newest batch (default 0.1).
        logger: logger receiving the periodic summary (default: the module-level logger).
        log_every: log every `log_every` batches, counted from epoch start; 0 disables logging.
    """

    def __init__(self, loss_man, alpha=0.1, logger=logger, log_every=500):
        self.loss_man = loss_man
        self.alpha = alpha
        self.logger = logger
        self.log_every = log_every

    def on_epoch_begin(self, **kwargs):
        "Reset the running average and the batch counter."
        self.val = None
        self.iter = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Fold this batch's soft bit entropy into the EMA and log a summary every `log_every` batches."
        entropy = self.loss_man.soft_bit_entropy.detach()
        if self.val is None:
            # Seed with the first value. Starting from 0.0 biased the average toward 0:
            # the iter-0 line reported alpha * entropy (10% of the true value for alpha=0.1).
            self.val = entropy
        else:
            # Out-of-place so the loss manager's tensor (aliased by detach()) is never mutated.
            self.val = (1 - self.alpha) * self.val + self.alpha * entropy

        if self.log_every and self.iter % self.log_every == 0:
            importance_map = last_output[3].detach()
            lm = self.loss_man
            msg = (f"iter {self.iter}: bpp = {self.val / 2:.3f}, "
                   f"impmap- mean {torch.mean(importance_map):.4f} var {torch.var(importance_map):.4f} "
                   f" total_loss={lm.total_loss:.1f}  l2reg_loss={lm.l2_reg_loss:.1f}"
                   f" autoencoder_loss_value={lm.autoencoder_loss_value:.1f}"
                   f" si_loss={lm.si_net_loss_value}")
            self.logger.info(msg)
            print(msg)
        self.iter += 1

    def on_epoch_end(self, last_metrics, **kwargs):
        "Append the final running average to `last_metrics` (0.0 if no batch was seen)."
        return add_metrics(last_metrics, 0.0 if self.val is None else self.val)
    

In [7]:
# Enable side information for this run; the second line displays the flag.
config.use_si_flag = SiNetChannelIn.WithSideInformation
config.use_si_flag

<SiNetChannelIn.WithSideInformation: 6>

In [8]:
# Record the torch version used for this run (reproducibility).
torch.__version__

'1.5.1'

In [9]:
# Build the model, the KITTI stereo train/valid item lists, the DataBunch and the Learner.
# NOTE(review): the recorded run stopped in pdb inside si_net.py `_weight_init` while running this cell
# (see output below), presumably during model construction; likely a leftover breakpoint in that module.
si_autoencoder = SideInformationAutoEncoder(config.use_si_flag)
path = "../src/dsin/data"
# Fraction passed to from_csv as `pct`; 0.0005 / 0.25 were earlier values.
pct= 1 #0.0005 #0.25

valid_image_list = SideinformationImageImageList.from_csv(
    path=path, csv_names=["KITTI_stereo_val.txt"],pct=pct)
train_image_list = SideinformationImageImageList.from_csv(
    path=path, csv_names=["KITTI_stereo_train.txt"],pct=pct)

image_lists = ItemLists(
    path=path, train=train_image_list, valid=valid_image_list)

# ll = image_lists.label_from_func(lambda x: x)

# NOTE(review): `tfms` is built but never used — `.transform(None, ...)` below applies no
# augmentation. Pass `tfms` there if left-right flips were intended.
tfms =  get_transforms(do_flip=True, flip_vert=False, max_rotate=None, max_zoom=1., max_lighting=None, max_warp=None, p_affine=0.0, p_lighting=0.0)
 #get_transforms(do_flip=True, max_rotate=0.0) # None #
batchsize = 1
# [flip_lr(p=0.5),[]]
data = (image_lists
        # Identity labelling: each item's target is the item itself (autoencoder reconstruction).
        .label_from_func(lambda x: x)
        # Crop to 336x1224; tfm_y=True applies the same resize to the targets.
        .transform(None, size=(336, 1224), resize_method=ResizeMethod.CROP, tfm_y=True)
        .databunch(bs=batchsize))
# NOTE(review): LossManager gets a hard-coded WithSideInformation while the model uses
# config.use_si_flag — keep the two in sync if the flag is changed.
learn = Learner(data=data,
                     model=si_autoencoder,
                     opt_func=torch.optim.Adam,
                     loss_func=LossManager(SiNetChannelIn.WithSideInformation),
                     metrics=[AverageMetric(Distortions._calc_dist,"MS_SSIM")])

# learn.metrics.append(BitEntropy(loss_man=learn.loss_func))

> /home/ubuntu/tDSIN/src/dsin/ae/si_net.py(89)_weight_init()
-> for layer in self.layers:
(Pdb) C
*** NameError: name 'C' is not defined
(Pdb) C
*** NameError: name 'C' is not defined
(Pdb) c


In [10]:
# Warm-start from the baseline checkpoint.
# NOTE(review): in the recorded run this load raised RuntimeError — the checkpoint's si_net keys and
# shapes differ from the current SideInformationAutoEncoder (see traceback below), so it presumably
# predates the current si_net layout.
learn.load('try2_200807MAE-l2reg-baseline-2')

# Train 4 rounds of one epoch each (i = 1..4). Each round saves `{model_fname}-{i}` and uploads
# the checkpoint plus out.log (presumably the file from setup_file_logger) to S3.
model_fname = '200814MAE-l2reg-si'
for i in range(1,5):
    # From round 2 on, reload the previous round's checkpoint, including optimizer state.
    if i != 1 :
        learn.load(f'{model_fname}-{i - 1}',with_opt=True)
    
    # These settings do not depend on `i`; they are re-applied every round.
    config.si_loss_weight_alpha = 0.7
    learn.model.true_tuple_loss_false_just_out = True
    # NOTE(review): attribute is spelled `use_side_infomation` (sic). If the model/loss read a
    # differently spelled attribute, these two assignments have no effect — verify the name.
    learn.model.use_side_infomation = SiNetChannelIn.WithSideInformation
    learn.loss_func.use_side_infomation = SiNetChannelIn.WithSideInformation
    # Fresh callback instances each round, so their running averages restart every round.
    learn.fit(1, lr=0.0001,wd=0,callbacks=[ParameterRunningAverageMetricCallback(learn.loss_func),BitEntropy(loss_man=learn.loss_func)])
    learn.save(f'{model_fname}-{i}')
    # NOTE(review): hard-coded home-directory paths and bucket; consider config constants.
    !aws s3 cp ~/tDSIN/src/dsin/data/models/{model_fname}-{i}.pth  s3://dsin-us/models/
    !aws s3 cp ~/tDSIN/tutorials/out.log s3://dsin-us/models/


RuntimeError: Error(s) in loading state_dict for SideInformationAutoEncoder:
	Missing key(s) in state_dict: "si_net.model.2.running_mean", "si_net.model.2.running_var", "si_net.model.3.model.0.weight", "si_net.model.3.model.0.bias", "si_net.model.3.model.2.weight", "si_net.model.3.model.2.bias", "si_net.model.3.model.2.running_mean", "si_net.model.3.model.2.running_var", "si_net.model.4.model.0.weight", "si_net.model.4.model.0.bias", "si_net.model.4.model.2.weight", "si_net.model.4.model.2.bias", "si_net.model.4.model.2.running_mean", "si_net.model.4.model.2.running_var", "si_net.model.5.model.0.weight", "si_net.model.5.model.0.bias", "si_net.model.5.model.2.weight", "si_net.model.5.model.2.bias", "si_net.model.5.model.2.running_mean", "si_net.model.5.model.2.running_var", "si_net.model.6.model.0.weight", "si_net.model.6.model.0.bias", "si_net.model.6.model.2.weight", "si_net.model.6.model.2.bias", "si_net.model.6.model.2.running_mean", "si_net.model.6.model.2.running_var", "si_net.model.7.model.0.weight", "si_net.model.7.model.0.bias", "si_net.model.7.model.2.weight", "si_net.model.7.model.2.bias", "si_net.model.7.model.2.running_mean", "si_net.model.7.model.2.running_var", "si_net.model.8.model.0.weight", "si_net.model.8.model.0.bias", "si_net.model.8.model.2.weight", "si_net.model.8.model.2.bias", "si_net.model.8.model.2.running_mean", "si_net.model.8.model.2.running_var", "si_net.model.9.model.0.weight", "si_net.model.9.model.0.bias", "si_net.model.9.model.2.weight", "si_net.model.9.model.2.bias", "si_net.model.9.model.2.running_mean", "si_net.model.9.model.2.running_var", "si_net.model.11.mean", "si_net.model.11.var". 
	Unexpected key(s) in state_dict: "si_net.model.12.weight", "si_net.model.12.bias", "si_net.model.14.weight", "si_net.model.14.bias", "si_net.model.16.weight", "si_net.model.16.bias", "si_net.model.17.mean", "si_net.model.17.var", "si_net.model.4.weight", "si_net.model.4.bias", "si_net.model.6.weight", "si_net.model.6.bias", "si_net.model.8.weight", "si_net.model.8.bias". 
	size mismatch for si_net.model.2.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for si_net.model.10.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 32, 1, 1]).
	size mismatch for si_net.model.10.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([3]).

In [None]:
# Inspect element 5 of the model's `my_tuple` (presumably cached during the last forward pass — confirm in si_ae).
learn.model.my_tuple[5]


In [None]:
# Main image of the second training item (`train_ds[1]` is an (x, y) pair; x presumably an ImageSiTuple — confirm).
learn.data.train_ds[1][0].img

In [None]:
# Side-information image of the same training item.
learn.data.train_ds[1][0].si_img

In [None]:
# Last soft bit entropy stored on the loss manager.
learn.loss_func.soft_bit_entropy

In [None]:
# Current side-information loss weight (the training loop sets it to 0.7).
config.si_loss_weight_alpha

In [None]:
# Last autoencoder loss component stored on the loss manager.
learn.loss_func.autoencoder_loss_value

In [None]:
# NOTE(review): this checkpoint's load raised the RuntimeError (si_net key/shape mismatch) in the
# training cell; expect the same error until the checkpoint and architecture match.
learn.load('try2_200807MAE-l2reg-baseline-2')


In [None]:
# Min-max normalise the model's cached tensor `my_tuple[-2]` to [0, 1] and display it.
# `.squeeze()` (not `.squeeze_()`) so the tensor stored on the model is not reshaped in place.
mt = learn.model.my_tuple[-2].detach().squeeze()
mx, mn = torch.max(mt), torch.min(mt)
# Clamp so a constant tensor (max == min) does not produce a NaN image.
diff = torch.clamp(mx - mn, min=1e-8)
show_image(Image((mt - mn) / diff), figsize=(30, 30))

In [None]:

# Min-max normalise the model's cached tensor `my_tuple[2]` to [0, 1] and display it.
# `mt`, `mn` and `diff` are reused by the flip_lr cell that follows.
# `.squeeze()` (not `.squeeze_()`) so the tensor stored on the model is not reshaped in place.
mt = learn.model.my_tuple[2].detach().squeeze()
mx, mn = torch.max(mt), torch.min(mt)
# Clamp so a constant tensor (max == min) does not produce a NaN image.
diff = torch.clamp(mx - mn, min=1e-8)
show_image(Image((mt - mn) / diff), figsize=(30, 30))

In [None]:
# NOTE(review): hidden state — `mt`, `mn` and `diff` come from whichever image cell ran last; re-run it first.
aaa = (mt - mn)/(diff)

# Show the normalised tensor mirrored left-right with fastai's `flip_lr`.
flip_lr(Image(aaa.cpu()))

In [None]:
# learn.model.true_tuple_loss_false_just_out = False

# learn.show_results(figsize=(30,30))

In [None]:
# Raw values of the model's cached tensor `my_tuple[-2]`.
learn.model.my_tuple[-2].data