In [1]:
import argparse
import logging
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

from utils import data
import models, utils

In [2]:
class Args(object):
    """Notebook stand-in for an argparse namespace holding the run configuration.

    Every setting becomes an instance attribute, assigned in the order listed
    below (the order shows up when the arguments are logged).
    """

    def __init__(self):
        settings = (
            ('data_path', 'data'),
            ('dataset', 'masked_pwc'),
            ('batch_size', 32),
            ('model', 'unet1d'),
            ('lr', 0.001),
            ('num_epochs', 100),
            ('n_data', 100000),
            ('min_sep', 5),
            ('valid_interval', 1),
            ('save_interval', 1),
            ('seed', 0),
            ('output_dir', 'experiments'),
            ('experiment', None),
            ('resume_training', False),
            ('restore_file', None),
            ('no_save', False),
            ('step_checkpoints', False),
            ('no_log', False),
            ('log_interval', 100),
            ('no_visual', False),
            ('visual_interval', 100),
            ('no_progress', False),
            ('draft', False),
            ('dry_run', False),
            ('in_channels', 1),
            ('bias', False),
            ('test_num', 0),
            # UNET
            ('residual', False),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)


args = Args()

In [3]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Seed the global RNGs from args.seed so training runs are reproducible.
# Harmless if utils.setup_experiment also seeds (TODO confirm whether it does).
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# Project helpers: prepare the experiment directories and configure logging.
utils.setup_experiment(args)
utils.init_logging(args)

[2020-08-20 09:29:41] COMMAND: /home/michael/python-virtual-environments/bfcnn/lib/python3.6/site-packages/ipykernel_launcher.py -f /home/michael/.local/share/jupyter/runtime/kernel-630f7d19-217f-468c-994c-4fd48aa9dcc4.json
[2020-08-20 09:29:41] Arguments: {'data_path': 'data', 'dataset': 'masked_pwc', 'batch_size': 32, 'model': 'unet1d', 'lr': 0.001, 'num_epochs': 100, 'n_data': 100000, 'min_sep': 5, 'valid_interval': 1, 'save_interval': 1, 'seed': 0, 'output_dir': 'experiments', 'experiment': 'unet1d-Aug-20-09:29:41', 'resume_training': False, 'restore_file': None, 'no_save': False, 'step_checkpoints': False, 'no_log': False, 'log_interval': 100, 'no_visual': False, 'visual_interval': 100, 'no_progress': False, 'draft': False, 'dry_run': False, 'in_channels': 1, 'bias': False, 'test_num': 0, 'residual': False, 'experiment_dir': 'experiments/unet1d/unet1d-Aug-20-09:29:41', 'checkpoint_dir': 'experiments/unet1d/unet1d-Aug-20-09:29:41/checkpoints', 'log_dir': 'experiments/unet1d/unet1d-

In [None]:
import os

# NOTE: this cell sits above the cell that builds `model`. On a fresh kernel
# (Restart & Run All), `model` does not exist yet, so only save when a model is
# already in memory. Never overwrite an existing checkpoint at this path: the
# file name marks it as the min_sep=3 model.
MODEL_PATH = "models/trained/unet1d_partialconv_10kdata_30epoch_3minsep_08_14_20.pth"
if "model" in globals() and not os.path.exists(MODEL_PATH):
    torch.save(model.state_dict(), MODEL_PATH)

In [4]:
# Build the model (fresh, or restored from MODEL_PATH), the optimizer and the
# LR scheduler, and set the starting step/epoch.
MODEL_PATH = "models/trained/unet1d_partialconv_10kdata_30epoch_3minsep_08_14_20.pth"

train_new_model = True

model = models.build_model(args)
if not train_new_model:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)

print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Halve the LR at each milestone. scheduler.step() is called once per epoch in the
# training loop, so the milestones count epochs.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60, 70, 80, 90, 100], gamma=0.5)
logging.info(f"Built a model consisting of {sum(p.numel() for p in model.parameters()):,} parameters")

if args.resume_training:
    state_dict = utils.load_checkpoint(args, model, optimizer, scheduler)
    global_step = state_dict['last_step']
    # NOTE(review): 403200 looks like a hard-coded training-set size carried over
    # from another dataset; it does not obviously match args.n_data here.
    # TODO: derive steps-per-epoch from len(train_loader) instead.
    start_epoch = int(state_dict['last_step']/(403200/state_dict['args'].batch_size))+1
else:
    # The training loop increments global_step before use, so the first step is 0.
    global_step = -1
    start_epoch = 0

[2020-08-20 09:29:59] Built a model consisting of 72,000 parameters


UNet(
  (conv1): PartialConv1d(1, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
  (conv2): PartialConv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
  (conv3): PartialConv1d(32, 64, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)
  (conv4): PartialConv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
  (conv5): PartialConv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False)
  (conv6): PartialConv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,), bias=False)
  (conv7): ConvTranspose1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,), bias=False)
  (conv8): PartialConv1d(96, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
  (conv9): PartialConv1d(32, 1, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
)


In [5]:
# Build the training and validation loaders; build_dataset lives in
# utils/data/__init__.py. The third loader it returns is not used for training.
train_loader, valid_loader, _ = data.build_dataset(
    args.dataset,
    args.n_data,
    batch_size=args.batch_size,
    min_sep=args.min_sep,
)

In [6]:
# Moving averages of the training metrics (utils.RunningAverageMeter, factor 0.98)
# and plain averages of the validation metrics.
train_meters = dict(
    (metric, utils.RunningAverageMeter(0.98))
    for metric in ("train_loss", "train_psnr", "train_ssim")
)
valid_meters = dict((metric, utils.AverageMeter()) for metric in ("valid_psnr", "valid_ssim"))
# TensorBoard writer; disabled (None) when args.no_visual is set.
writer = None if args.no_visual else SummaryWriter(log_dir=args.experiment_dir)

In [7]:
# TRAINING
# Train for epochs [start_epoch, args.num_epochs). Every args.valid_interval
# epochs, evaluate on the validation loader, log to TensorBoard and call
# utils.save_checkpoint with the mean validation PSNR as the score.
for epoch in range(start_epoch, args.num_epochs):
    if args.resume_training:
        # When resuming, additionally halve the LR every 10 epochs.
        # NOTE(review): this compounds with the MultiStepLR scheduler stepped at
        # the end of every epoch; confirm the double decay is intended.
        if epoch %10 == 0:
            optimizer.param_groups[0]["lr"] /= 2
            print('learning rate reduced by factor of 2')

    train_bar = utils.ProgressBar(train_loader, epoch)
    # Training meters restart at every epoch.
    for meter in train_meters.values():
        meter.reset()

    for batch_id, (clean, mask) in enumerate(train_bar):
        # dataloader returns [clean, mask] list (mask == 1 marks visible samples)
        # model.train() is re-applied each batch because validation switches to eval().
        model.train()
        global_step += 1
        inputs = clean.to(device)
        mask_inputs = mask.to(device)
        # Use the network output only where mask == 0 (hidden); keep visible inputs.
        raw_outputs = model(inputs,mask_inputs)
        outputs = (1-mask_inputs)*raw_outputs + mask_inputs*inputs

        # For a binary mask, visible samples are copied from `inputs` and contribute
        # zero error, so this loss is effectively computed on the masked part only.
        # Sum of squared errors, normalised by 2 * batch size.
        loss = F.mse_loss(outputs, inputs, reduction="sum") / (inputs.size(0) * 2)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics of the in-filled signal against the clean input.
        train_psnr = utils.psnr(outputs, inputs)
        train_ssim = utils.ssim(outputs, inputs)
        train_meters["train_loss"].update(loss.item())
        train_meters["train_psnr"].update(train_psnr.item())
        train_meters["train_ssim"].update(train_ssim.item())
        train_bar.log(dict(**train_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

        # Every args.log_interval steps: log scalars and a histogram of all
        # parameter gradients concatenated into one flat tensor.
        if writer is not None and global_step % args.log_interval == 0:
            writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("loss/train", loss.item(), global_step)
            writer.add_scalar("psnr/train", train_psnr.item(), global_step)
            writer.add_scalar("ssim/train", train_ssim.item(), global_step)
            gradients = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None], dim=0)
            writer.add_histogram("gradients", gradients, global_step)
            sys.stdout.flush()

    # Validation pass (no gradients).
    if epoch % args.valid_interval == 0:
        model.eval()
        for meter in valid_meters.values():
            meter.reset()

        valid_bar = utils.ProgressBar(valid_loader)

        for sample_id, (clean, mask) in enumerate(valid_bar):
            with torch.no_grad():
                inputs = clean.to(device)
                mask_inputs = mask.to(device)
                # Same in-filling as in training: network output only in the hole.
                raw_output = model(inputs,mask_inputs)
                output = (1-mask_inputs)*raw_output + mask_inputs*inputs

                # NOTE(review): argument order here is (inputs, output), the reverse
                # of the training metrics above; confirm utils.psnr / utils.ssim are
                # symmetric or unify the order.
                valid_psnr = utils.psnr(inputs, output)
                valid_meters["valid_psnr"].update(valid_psnr.item())
                valid_ssim = utils.ssim(inputs, output)
                valid_meters["valid_ssim"].update(valid_ssim.item())

                # Log the first 10 validation batches: clean, masked and in-filled
                # signals concatenated along dim 0 and passed through make_grid.
                # NOTE(review): make_grid/add_image are image APIs; for 1-D signals,
                # check that the rendered TensorBoard image is meaningful.
                if writer is not None and sample_id < 10:
                    image = torch.cat([inputs, torch.mul(inputs, mask_inputs), output], dim=0)
                    image = torchvision.utils.make_grid(image.clamp(0, 1), nrow=3, normalize=False)
                    writer.add_image(f"valid_samples/{sample_id}", image, global_step)

        if writer is not None:
            writer.add_scalar("psnr/valid", valid_meters['valid_psnr'].avg, global_step)
            writer.add_scalar("ssim/valid", valid_meters['valid_ssim'].avg, global_step)
            sys.stdout.flush()

        logging.info(train_bar.print(dict(**train_meters, **valid_meters, lr=optimizer.param_groups[0]["lr"])))
        # Presumably keeps the best checkpoint by mean validation PSNR (score, mode="max").
        utils.save_checkpoint(args, global_step, model, optimizer, score=valid_meters["valid_psnr"].avg, mode="max")
    # One scheduler step per epoch.
    scheduler.step()

logging.info(f"Done training! Best PSNR {utils.save_checkpoint.best_score:.3f} obtained after step {utils.save_checkpoint.best_step}.")


[2020-08-20 09:37:50] epoch 00 | train_loss 0.151 | train_psnr 28.743 | train_ssim 0.908 | valid_psnr 28.927 | valid_ssim 0.907 | lr 1.0e-03             
[2020-08-20 09:45:27] epoch 01 | train_loss 0.147 | train_psnr 28.989 | train_ssim 0.909 | valid_psnr 29.361 | valid_ssim 0.911 | lr 1.0e-03             
[2020-08-20 09:52:50] epoch 02 | train_loss 0.135 | train_psnr 29.441 | train_ssim 0.911 | valid_psnr 29.770 | valid_ssim 0.914 | lr 1.0e-03             
[2020-08-20 10:00:14] epoch 03 | train_loss 0.138 | train_psnr 29.494 | train_ssim 0.911 | valid_psnr 29.768 | valid_ssim 0.913 | lr 1.0e-03             
[2020-08-20 10:07:37] epoch 04 | train_loss 0.144 | train_psnr 29.409 | train_ssim 0.911 | valid_psnr 29.719 | valid_ssim 0.912 | lr 1.0e-03             
[2020-08-20 10:15:00] epoch 05 | train_loss 0.140 | train_psnr 29.619 | train_ssim 0.911 | valid_psnr 29.845 | valid_ssim 0.913 | lr 1.0e-03             
[2020-08-20 10:22:24] epoch 06 | train_loss 0.144 | train_psnr 29.689 | trai

[2020-08-20 16:11:04] epoch 53 | train_loss 0.137 | train_psnr 30.705 | train_ssim 0.916 | valid_psnr 30.846 | valid_ssim 0.917 | lr 5.0e-04             
[2020-08-20 16:18:26] epoch 54 | train_loss 0.133 | train_psnr 30.930 | train_ssim 0.919 | valid_psnr 30.459 | valid_ssim 0.917 | lr 5.0e-04             
[2020-08-20 16:25:48] epoch 55 | train_loss 0.134 | train_psnr 30.871 | train_ssim 0.917 | valid_psnr 30.693 | valid_ssim 0.917 | lr 5.0e-04             
[2020-08-20 16:33:28] epoch 56 | train_loss 0.139 | train_psnr 30.579 | train_ssim 0.915 | valid_psnr 30.872 | valid_ssim 0.917 | lr 5.0e-04             
[2020-08-20 16:41:07] epoch 57 | train_loss 0.132 | train_psnr 30.707 | train_ssim 0.917 | valid_psnr 30.901 | valid_ssim 0.918 | lr 5.0e-04             
[2020-08-20 16:48:47] epoch 58 | train_loss 0.128 | train_psnr 30.844 | train_ssim 0.919 | valid_psnr 30.913 | valid_ssim 0.918 | lr 5.0e-04             
[2020-08-20 16:56:17] epoch 59 | train_loss 0.127 | train_psnr 30.813 | trai

### Testing

In [None]:
def load_trained_model(path):
    """Build a network from ``args`` and load trained weights from ``path`` for inference."""
    net = models.build_model(args)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(path, map_location=device))
    net.to(device)
    net.eval()
    return net


# The min_sep comparison below uses model3, model5 and model10, so all three are loaded.
model3 = load_trained_model("models/trained/unet1d_partialconv_10kdata_30epoch_3minsep_08_14_20.pth")
# Presumably trained with the default min_sep=5 (no min_sep in the file name).
model5 = load_trained_model("models/trained/unet1d_partialconv_10kdata_30epoch_08_13_20.pth")
model10 = load_trained_model("models/trained/unet1d_partialconv_10kdata_30epoch_10minsep_08_14_20.pth")

## Analysis of first predicted point
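Compare the model's first predicted (in-filled) value with three baselines: the mean of the whole unmasked signal, the mean of the unmasked samples within the first 21 samples (the receptive field), and the first visible sample after the mask.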

### Varying `min_sep` (3, 5 and 10)

In [None]:
import pandas as pd

def first_pt_stats(model, min_sep, rf_end=21):
    """Compare the model's first in-filled sample against simple baselines.

    A test set is built with ``test_num=1``. This presumably fixes the hole at
    the start of every signal (TODO confirm in utils/data); the index arithmetic
    below relies on it. For each signal, the model output at index 0 is
    compared with:

    * the mean of the unmasked part of the clean signal,
    * the mean of the unmasked samples inside ``[0, rf_end)`` (receptive field),
    * the first visible sample after the hole.

    Args:
        model: network called as ``model(signal, mask)``; ``mask == 1`` marks
            visible samples.
        min_sep: ``min_sep`` forwarded to ``data.build_dataset``.
        rf_end: exclusive end of the receptive-field window (default 21).

    Returns:
        list: ``[min_sep, mean_unmasked, mean_rf, mean_first_visible,
        mean_first_pred, unmasked_diff_mean, unmasked_diff_sd, rf_diff_mean,
        rf_diff_sd, first_visible_diff_mean, first_visible_diff_sd]``, the row
        layout used by the summary DataFrame. A window that is entirely hidden
        yields nan, like the mean of an empty slice.
    """
    _, _, test_loader = data.build_dataset(args.dataset,
                                           args.n_data,
                                           batch_size=args.n_data,
                                           fix_datapoints=True,
                                           min_sep=min_sep,
                                           test_num=1)
    print("Min_sep: {}".format(min_sep))
    print("*" * 30)

    # batch_size == n_data, so the whole test set is expected as a single batch.
    clean, mask = next(iter(test_loader))
    print("Mean of clean signal: {:2.4f}".format(clean.mean()))
    with torch.no_grad():  # inference only: no autograd graph for the whole test set
        outputs = model(clean.to(device), mask.to(device)).cpu()
    print("Mean first value (min_sep={}): {:2.4f}".format(min_sep, outputs[:, :, 0].mean()))

    signal = clean[:, 0, :]                       # channel 0, as in the original indexing
    length = signal.shape[-1]
    # Number of hidden samples per signal. With the hole at the start, this is
    # also the index of the first visible sample.
    n_hidden = (length - mask[:, 0, :].sum(dim=-1)).long()

    pos = torch.arange(length).unsqueeze(0)       # (1, L)
    visible = (pos >= n_hidden.unsqueeze(1)).to(signal.dtype)
    in_rf = visible * (pos < rf_end).to(signal.dtype)
    # Masked means as sum / count; an empty window gives 0/0 = nan.
    mean_unmasked = ((signal * visible).sum(dim=1) / visible.sum(dim=1)).numpy()
    mean_rf = ((signal * in_rf).sum(dim=1) / in_rf.sum(dim=1)).numpy()
    first_visible = signal.gather(1, n_hidden.unsqueeze(1)).squeeze(1).numpy()

    first_pred = outputs[:, 0, 0].numpy()
    unmasked_diff = np.abs(first_pred - mean_unmasked)
    rf_diff = np.abs(first_pred - mean_rf)
    first_visible_diff = np.abs(first_pred - first_visible)

    print("Mean of full unmasked signal: {:2.4f}".format(np.mean(mean_unmasked)))
    print("Mean of receptive field signal [0,{}]: {:2.4f}".format(rf_end, np.mean(mean_rf)))
    print("Mean of first visible value after mask: {:2.4f}".format(np.mean(first_visible)))

    print("First predicted value mean diff: full unmasked signal: {:2.4f} (SD: {:2.4f})"
          .format(np.mean(unmasked_diff), np.std(unmasked_diff)))
    print("First predicted value mean diff: receptive field signal [0,{}]: {:2.4f} (SD: {:2.4f})"
          .format(rf_end, np.mean(rf_diff), np.std(rf_diff)))
    print("First predicted value mean diff: first visible value after mask: {:2.4f} (SD: {:2.4f})"
          .format(np.mean(first_visible_diff), np.std(first_visible_diff)))

    return [min_sep, np.mean(mean_unmasked), np.mean(mean_rf), np.mean(first_visible),
            float(outputs[:, :, 0].mean()),
            np.mean(unmasked_diff), np.std(unmasked_diff),
            np.mean(rf_diff), np.std(rf_diff),
            np.mean(first_visible_diff), np.std(first_visible_diff)]

In [None]:
df_list3 = first_pt_stats(model3, min_sep=3)  # summary row for the min_sep=3 model

In [None]:
df_list5 = first_pt_stats(model5, min_sep=5)  # summary row for the min_sep=5 model

In [None]:
df_list10 = first_pt_stats(model10, min_sep=10)  # summary row for the min_sep=10 model

In [None]:
# One column per min_sep after transposing; row order matches first_pt_stats' return list.
summary_columns = [
    'min_sep',
    'clean_sig_mean', 'receptive_field_mean',
    'first_visible_mean', 'first_pred_mean',
    'full_unmasked_diff_mean', 'full_unmasked_diff_sd',
    'receptive_field_diff_mean', 'receptive_field_diff_sd',
    'first_visible_diff_mean', 'first_visible_diff_sd',
]
pd.DataFrame([df_list3, df_list5, df_list10], columns=summary_columns).T

## Examples

In [None]:
# Best PSNR 28.560
def mask_idx_f(mask):
    """Locate the hole in the first sample of a mask batch.

    Args:
        mask: array/tensor indexed as ``mask[sample, channel, position]``, with 1
            for visible samples and 0 for hidden ones. Only ``mask[0]`` is used.
            Assumes a single channel and a single contiguous run of zeros.

    Returns:
        tuple: ``(mask_idx, before, after, mask_length, mask_start)``:
        ``mask_idx`` is a ``range`` over the hidden positions; ``before`` and
        ``after`` are integer arrays of the positions on either side of the hole;
        ``mask_length`` is the number of hidden samples; ``mask_start`` is the
        first hidden position (0 when nothing is hidden).
    """
    # argmin finds the first 0, i.e. where the hole starts.
    mask_start = int(np.argmin(mask[0]))
    mask_length = int((1 - mask[0]).sum())
    mask_idx = range(mask_start, mask_start + mask_length)
    positions = np.arange(mask.shape[2])
    before = positions[:mask_start]
    after = positions[mask_start + mask_length:]
    return mask_idx, before, after, mask_length, mask_start

def print_one(loader, model, rf_end=21):
    """Plot one example from ``loader``: true, masked and in-filled signal.

    Shows the first sample of a freshly drawn batch in three stacked panels.
    Dashed lines at 0 and 1 mark the hidden region. Also prints the hole's
    length and start, the first in-filled value, and two reference means over
    the window ``[0, rf_end)``.

    Args:
        loader: DataLoader yielding ``(clean, mask)`` batches; ``mask == 1``
            marks visible samples and 0 the hidden ones.
        model: network called as ``model(signal, mask)``.
        rf_end: exclusive end of the reference window (default 21).
    """
    # Re-seed numpy's global RNG from OS entropy so repeated calls can show
    # different examples. NOTE(review): this discards any earlier fixed seed.
    np.random.seed()
    clean, mask = next(iter(loader))
    # Only sample 0 is shown, so only sample 0 goes through the model.
    # NOTE(review): assumes the model treats samples independently (true for the
    # conv-only UNet built in this notebook).
    clean, mask = clean[:1], mask[:1]
    with torch.no_grad():  # inference only
        outputs = model(clean.to(device), mask.to(device)).cpu()

    mask_idx, before_mask, after_mask, mask_length, mask_start = mask_idx_f(mask)

    # Visible samples come from the input; the model output fills the hole.
    out = outputs[0] * (1 - mask[0]) + clean[0] * mask[0]
    print("Mask Length: {}\tMask Start: {}".format(mask_length, mask_start))

    _, (ax_true, ax_masked, ax_out) = plt.subplots(3, 1, figsize=[15, 10])

    def mark_hole(ax):
        # Dashed lines at y=0 and y=1 across the hidden positions.
        ax.plot(mask_idx, np.zeros(len(mask_idx)), '--k')
        ax.plot(mask_idx, np.ones(len(mask_idx)), '--k')

    ax_true.plot(clean[0, 0, :], 'xb')
    mark_hole(ax_true)
    ax_true.set_title("True signal")

    masked = clean[0] * mask[0]
    ax_masked.plot(before_mask, masked[0, before_mask], 'xb')
    ax_masked.plot(after_mask, masked[0, after_mask], 'xb')
    mark_hole(ax_masked)
    ax_masked.set_title("Masked signal")

    ax_out.plot(out[0, :], 'xb')
    mark_hole(ax_out)
    ax_out.set_title("Denoised signal")

    # Reference means over [0, rf_end): all samples, and visible samples only.
    # The visible-only mean equals clean[0, 0, mask_length:rf_end] when the hole
    # starts at index 0, and stays correct when it starts elsewhere.
    window = clean[0, 0, :rf_end]
    visible_mean = window[mask[0, 0, :rf_end] == 1].mean()
    print("First mask value: {:2.4f}\nMean of signal [0,{}): {:2.4f}\nMean of visible signal: {:2.4f}"
          .format(out[0, mask_start], rf_end, window.mean(), visible_mean))

In [None]:
# Test loader: shuffled; test_num selects a fixed mask shape. With
# batch_size=n_data, the test set presumably arrives as a single batch.
_, _, test_loader = data.build_dataset(
    args.dataset,
    args.n_data,
    batch_size=args.n_data,
    fix_datapoints=True,
    min_sep=10,
    test_num=0,
)

In [None]:
print_one(loader=test_loader, model=model10)  # one random example from the min_sep=10 model

In [None]:
print_one(loader=test_loader, model=model10)  # one random example from the min_sep=10 model

In [None]:
print_one(loader=test_loader, model=model10)  # one random example from the min_sep=10 model

In [None]:
print_one(loader=test_loader, model=model10)  # one random example from the min_sep=10 model

In [None]:
print_one(loader=test_loader, model=model10)  # one random example from the min_sep=10 model

In [None]:
# Scratch check: size of a hand-written 3-D mask (1 = visible, 0 = hidden).
torch.Tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
              1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,1.,0.,0.,0.,0.,0.,0.,0.,0., 
              0., 1., 1., 1.,1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]).size()

In [None]:
# Peek at one test batch and inspect the mask's size.
c, m = next(iter(test_loader))
m.size()

In [8]:
# Save the model trained above. Assign the destination first so the weights go
# to this file rather than to whatever MODEL_PATH held before (which would be
# overwritten).
MODEL_PATH = "models/trained/unet1d_partialconv_100kdata_100epoch_08_21_20.pth"
torch.save(model.state_dict(), MODEL_PATH)
