In [None]:
import argparse
import logging
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
from torch.serialization import default_restore_location

from torch.utils.tensorboard import SummaryWriter

from utils import data
import models, utils

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Load the GAN models (GAN-only and GAN+MSE)

In [None]:
def load_GAN_models(restore_file):
    """Load a generator/discriminator pair from a GAN training checkpoint.

    The checkpoint is first loaded with every tensor mapped to the CPU. The
    two networks are then rebuilt with ``models.build_model_gan`` from the
    ``args`` stored in the checkpoint, moved to the global ``device``, and
    given their saved weights.

    Args:
        restore_file: path to a checkpoint containing the keys ``'args'``,
            ``'modelG'`` and ``'modelD'``.

    Returns:
        ``(netG, netD, args)``: the generator and discriminator, both switched
        to eval mode, and the training arguments read from the checkpoint.
    """
    def _unwrap(weights):
        # Older checkpoints store each state_dict wrapped in a one-element list;
        # accept both that layout and a bare state_dict.
        if isinstance(weights, (list, tuple)):
            return weights[0]
        return weights

    # Map every tensor to the CPU so a GPU-saved checkpoint loads on any machine.
    # NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
    # which may refuse to unpickle the stored `args` object; for trusted
    # checkpoints pass weights_only=False if that happens.
    state_dict = torch.load(restore_file, map_location=lambda s, l: default_restore_location(s, "cpu"))

    # Training arguments are needed to rebuild the same architecture.
    args = state_dict['args']

    G, D = models.build_model_gan(args)
    netG = G.to(device)
    netD = D.to(device)

    netG.load_state_dict(_unwrap(state_dict['modelG']))
    netD.load_state_dict(_unwrap(state_dict['modelD']))

    # These networks are only evaluated in this notebook: switch any
    # dropout / batch-norm layers to inference behaviour.
    netG.eval()
    netD.eval()
    return netG, netD, args

# Best checkpoints of the two adversarial runs (GAN-only and, per the variable
# names, GAN+MSE).
restore_file_GAN = "experiments/unet1d-Sep-01-23:17:43_GAN_only/checkpoints/checkpoint_best.pt"
restore_file_GANMSE = "experiments/unet1d-Sep-01-23:18:04_MSE/checkpoints/checkpoint_best.pt"

netG_GAN, netD_GAN, _ = load_GAN_models(restore_file_GAN)
# Only the GAN+MSE run's `args` are kept; later cells reuse them.
netG_GANMSE, netD_GANMSE, args = load_GAN_models(restore_file_GANMSE)

### Load the MSE model

In [None]:
# Load the pure-MSE baseline model.
# NOTE(review): it is built from the `args` of the GAN+MSE checkpoint loaded
# above, so it presumably must share that run's architecture settings; confirm.
mse = models.build_model(args)
netMSE = mse.to(device)

MODEL_PATH = "models/trained/unet1d_partialconv_10kdata_30epoch_3minsep_08_14_20.pth"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and vice versa); the model is already on `device`, so no extra .to() is needed.
netMSE.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Inference only: switch dropout / batch-norm layers (if any) to eval behaviour.
netMSE.eval()

### Testing

In [None]:
# Best PSNR 28.560
def mask_idx_f(mask):
    """Locate the gap (run of zeros) in the first mask of a batch.

    Args:
        mask: array or CPU tensor indexed as ``(batch, channel, length)``,
            holding 1 where the signal is kept and 0 inside the gap. Only
            ``mask[0, 0]`` is inspected. The gap is assumed to be a single
            contiguous run of zeros.

    Returns:
        ``(mask_idx, before, after, mask_length, mask_start)`` where
        ``mask_idx`` is a ``range`` over the gap positions, ``before`` and
        ``after`` are integer arrays of the positions preceding and following
        the gap, and ``mask_length`` / ``mask_start`` are ints.
    """
    # Use one channel explicitly; summing over every channel would multiply
    # the gap length by the number of channels.
    first_mask = np.asarray(mask[0, 0])
    n_positions = first_mask.shape[-1]

    # argmin returns the first occurrence of the minimum, i.e. the first 0.
    mask_start = int(np.argmin(first_mask))
    mask_length = int((1 - first_mask).sum())
    mask_idx = range(mask_start, mask_start + mask_length)

    positions = np.arange(n_positions)
    before = positions[:mask_start]
    after = positions[mask_start + mask_length:]
    return mask_idx, before, after, mask_length, mask_start

def model_outputs(clean, mask, model):
    """Run ``model`` on a batch and splice its prediction into the gap.

    Args:
        clean: batch of clean signals (CPU tensor); only batch element 0 is returned.
        mask: CPU tensor the same shape as ``clean``: 1 where the signal is
            kept, 0 where the model's output should be used.
        model: callable invoked as ``model(clean, mask)`` on the global
            ``device``. Its output is assumed to be shaped like ``clean``.

    Returns:
        CPU tensor for batch element 0 that holds the model output where
        ``mask == 0`` and the clean signal where ``mask == 1``.
    """
    # Pure inference: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(clean.to(device), mask.to(device)).cpu()
    return outputs[0] * (1 - mask[0]) + clean[0] * mask[0]

    
def _draw_mask_bounds(ax, mask_idx):
    """Mark the gap positions on ``ax`` with dashed black lines at y=0 and y=1."""
    ax.plot(mask_idx, np.zeros(len(mask_idx)), '--k')
    ax.plot(mask_idx, np.ones(len(mask_idx)), '--k')


def print_one(loader, model_GAN, model_GANMSE, model_MSE):
    """Plot one test example next to the reconstructions of three models.

    Draws one batch from ``loader`` and prints the gap length and start for
    batch element 0. It then plots five stacked panels for that element: the
    true signal, the masked signal (samples outside the gap only), and the
    GAN, GAN+MSE and MSE outputs. Each model's prediction is spliced into the
    gap by ``model_outputs``. Dashed lines at y=0 and y=1 mark the gap.

    Args:
        loader: iterable yielding ``(clean, mask)`` batches.
        model_GAN, model_GANMSE, model_MSE: models called as ``model(clean, mask)``.
    """
    # Re-seed NumPy's global RNG from OS entropy.
    # NOTE(review): presumably so any NumPy-driven randomness (e.g. in the
    # dataset) differs between calls; confirm.
    np.random.seed()
    clean, mask = next(iter(loader))
    mask_idx, before_mask, after_mask, mask_length, mask_start = mask_idx_f(mask)

    reconstructions = [
        (model_outputs(clean, mask, model_GAN), "GAN denoised signal"),
        (model_outputs(clean, mask, model_GANMSE), "GAN+MSE denoised signal"),
        (model_outputs(clean, mask, model_MSE), "MSE denoised signal"),
    ]

    print("Mask Length: {}\tMask Start: {}".format(mask_length, mask_start))

    fig, axes = plt.subplots(5, 1, figsize=[21, 14])

    axes[0].plot(clean[0, 0, :], 'xb')
    _draw_mask_bounds(axes[0], mask_idx)
    axes[0].set_title("True signal")

    # Only the observed samples (outside the gap) are drawn for the masked signal.
    masked = clean[0] * mask[0]
    axes[1].plot(before_mask, masked[0, before_mask], 'xb')
    axes[1].plot(after_mask, masked[0, after_mask], 'xb')
    _draw_mask_bounds(axes[1], mask_idx)
    axes[1].set_title("Masked signal")

    for ax, (out, title) in zip(axes[2:], reconstructions):
        ax.plot(out[0, :].detach(), 'xb')
        _draw_mask_bounds(ax, mask_idx)
        ax.set_title(title)

    plt.show()

### Test the generators (GAN-only, GAN+MSE) against the MSE baseline

In [None]:
# Build the (shuffled) test loader; `test_num` forces a particular mask shape.
# One example per batch: print_one only looks at batch element 0.
_, _, test_loader = data.build_dataset(
    args.datasetG,
    batch_size=1,
    fix_datapoints=False,
    min_sep=args.min_sep,
    test_num=0,
)

In [None]:
# Visualise several random test examples. The loader is shuffled, so each call
# draws a different one; replaces six copy-pasted cells.
N_EXAMPLES = 6
for _example in range(N_EXAMPLES):
    print_one(test_loader, netG_GAN, netG_GANMSE, netMSE)
    # Flush each figure right after its printed mask info so they stay paired.
    plt.show()

## Test the Discriminator
Draw a batch from the test loader, generate fakes with the GAN+MSE generator, and score both the real and the fake signals with its discriminator.

In [None]:
# Draw one batch from the test loader; `clean` and `mask` are reused below.
clean, mask = next(iter(test_loader))
# Discriminator scores for the real (clean) signals, shown as the cell output.
netD_GANMSE(clean.to(device))

In [None]:
# Score generated (fake) signals with the GAN+MSE discriminator.
criterion = torch.nn.BCELoss()
inputs = clean.to(device)
mask_inputs = mask.to(device)

with torch.no_grad():
    # Keep the generator output only inside the gap; use the real signal elsewhere.
    raw_outputs = netG_GANMSE(inputs, mask_inputs)
    fake = (1 - mask_inputs) * raw_outputs + mask_inputs * inputs

    # BCELoss requires a floating-point target. torch.full with the integer
    # fill value 0 would produce an int64 tensor, which BCELoss rejects.
    label = torch.zeros(inputs.shape[0], device=device)

    # Classify the all-fake batch with D. NOTE(review): BCELoss assumes D's
    # outputs are probabilities in [0, 1] (e.g. a final sigmoid); confirm.
    output = netD_GANMSE(fake).view(-1)
    # Batch-mean loss on the fakes.
    errD_fake = criterion(output, label)
    # Per-example losses: taking .max() of the mean-reduced scalar above is a
    # no-op, so compute per-example values to find the worst-scored fake.
    errD_fake_per_example = F.binary_cross_entropy(output, label, reduction='none')

print(errD_fake_per_example.max())

In [None]:
netD_GANMSE(fake.detach()).flatten()[4]  # NOTE(review): index 4 needs a batch of >= 5, but test_loader uses batch_size=1

In [None]:
inputs[4] * mask_inputs[4]  # NOTE(review): index 4 needs a batch of >= 5, but test_loader uses batch_size=1

In [None]:
mask_inputs[4]  # NOTE(review): index 4 needs a batch of >= 5, but test_loader uses batch_size=1

In [None]:
mask[4]  # NOTE(review): index 4 needs a batch of >= 5, but test_loader uses batch_size=1

In [None]:
fake[4]  # NOTE(review): index 4 needs a batch of >= 5, but test_loader uses batch_size=1