In [1]:
import argparse
import logging
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

from utils import data
import models, utils

import pandas as pd
from laspy.file import File
from pickle import dump, load

import torch.nn as nn
import torch.optim as optim
import torch.utils.data as udata
from torch.autograd import Variable
from sklearn.preprocessing import MinMaxScaler

%matplotlib inline

In [2]:
class Args:
    """Notebook stand-in for the argparse namespace expected by `utils` and `models`.

    Every option becomes an instance attribute, assigned in the order listed below;
    that order is what `vars(args)` (and the logged "Arguments" line) shows.
    """

    def __init__(self):
        settings = [
            ('data_path', 'data'),                 # not used
            ('dataset', 'masked_pwc'),             # move lidar into datasets
            ('batch_size', 32),
            ('model', 'lstm_unet1d'),
            ('lr', 0.001),
            ('num_epochs', 100),
            ('n_data', 100000),                    # only read by the testing cells (data.build_dataset)
            ('min_sep', 5),                        # not used
            ('valid_interval', 1),
            ('save_interval', 1),
            ('seed', 0),
            ('output_dir', 'lidar_experiments'),
            ('experiment', None),
            ('resume_training', False),
            ('restore_file', None),
            ('no_save', False),
            ('step_checkpoints', False),
            ('no_log', False),
            ('log_interval', 100),
            ('no_visual', False),
            ('visual_interval', 100),
            ('no_progress', False),
            ('draft', False),
            ('dry_run', False),
            ('bias', False),
            # ('in_channels', 1),  # maybe 6?
            ('test_num', 0),
            # UNET
            ('residual', False),
        ]
        for option, value in settings:
            setattr(self, option, value)


args = Args()

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Experiment bookkeeping (presumably fills in experiment / experiment_dir / checkpoint_dir
# on args, as seen in the logged Arguments) followed by logging setup.
utils.setup_experiment(args)
utils.init_logging(args)

[2020-09-15 15:05:57] COMMAND: /home/michael/python-virtual-environments/data/lib/python3.6/site-packages/ipykernel_launcher.py -f /home/michael/.local/share/jupyter/runtime/kernel-f465bba2-0619-4a39-8dfe-97045990b5f5.json
[2020-09-15 15:05:57] Arguments: {'data_path': 'data', 'dataset': 'masked_pwc', 'batch_size': 32, 'model': 'lstm_unet1d', 'lr': 0.001, 'num_epochs': 100, 'n_data': 100000, 'min_sep': 5, 'valid_interval': 1, 'save_interval': 1, 'seed': 0, 'output_dir': 'lidar_experiments', 'experiment': 'lstm-unet1d-Sep-15-15:05:57', 'resume_training': False, 'restore_file': None, 'no_save': False, 'step_checkpoints': False, 'no_log': False, 'log_interval': 100, 'no_visual': False, 'visual_interval': 100, 'no_progress': False, 'draft': False, 'dry_run': False, 'bias': False, 'test_num': 0, 'residual': False, 'experiment_dir': 'lidar_experiments/lstm_unet1d/lstm-unet1d-Sep-15-15:05:57', 'checkpoint_dir': 'lidar_experiments/lstm_unet1d/lstm-unet1d-Sep-15-15:05:57/checkpoints', 'log_dir'

In [4]:
# Train from scratch (True) or start from previously saved weights (False).
train_new_model = True
# Weights to load when train_new_model is False, e.g.
#   "models/trained/dncnn1d_partialconv_5kdata_20epoch_08_12_20.pth"
#   "models/trained/unet1d_partialconv_10kdata_30epoch_3minsep_08_14_20.pth"
# Kept separate from MODEL_PATH (the save path used at the end of the notebook) so that
# saving the newly trained model can never overwrite the weights that were loaded.
PRETRAINED_MODEL_PATH = None

# Build the model (optionally loading saved weights) and move it to the device.
model = models.build_model(args)
if not train_new_model:
    if PRETRAINED_MODEL_PATH is None:
        raise ValueError("Set PRETRAINED_MODEL_PATH to the saved weights to load.")
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(PRETRAINED_MODEL_PATH, map_location=device))
model = model.to(device)

print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Halve the learning rate every 25 scheduler steps (the training loop steps it once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5)
logging.info(f"Built a model consisting of {sum(p.numel() for p in model.parameters()):,} parameters")

if args.resume_training:
    state_dict = utils.load_checkpoint(args, model, optimizer, scheduler)
    global_step = state_dict['last_step']
    # NOTE(review): 403200 is a hard-coded training-set size from an earlier dataset; the
    # current pipeline yields a different number of windows (see x_train.shape) — TODO
    # derive steps-per-epoch from len(train_loader) once the loaders exist.
    start_epoch = int(state_dict['last_step']/(403200/state_dict['args'].batch_size))+1
else:
    global_step = -1
    start_epoch = 0

[2020-09-15 15:05:59] Built a model consisting of 73,120 parameters


UNet(
  (conv1): PartialConv1d(6, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
  (conv2): PartialConv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
  (conv3): PartialConv1d(32, 64, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)
  (conv4): PartialConv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
  (conv5): PartialConv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False)
  (conv6): PartialConv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,), bias=False)
  (conv7): ConvTranspose1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,), bias=False)
  (conv8): PartialConv1d(96, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
  (conv9): PartialConv1d(32, 3, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
)


## Load the data

In [5]:
# Training Data parameters
scan_line_gap_break = 7000 # threshold over which scan_gap indicates a new scan line
min_pt_count = 1700 # a scan line needs more than this many points, otherwise line not used
max_pt_count = 2000 # a scan line needs fewer than this many points, otherwise line not used
num_scan_lines = 150 # number of consecutive scan lines used for training + validation
val_split = 0.2 # fraction of windows held out for validation (taken from the end)
seq_len = 64 # points per window
mask_pts_per_seq = 2 # random positions masked per window (drawn independently, may repeat)

# LSTM Model parameters
# NOTE(review): these are not passed to models.build_model (which only receives `args`);
# they currently have no effect.
hidden_size = 100 # hidden features
num_layers = 2 # Default is 1, 2 is a stacked LSTM
output_dim = 3 # x,y,z

# Training parameters: the training loop and optimizer read args.num_epochs and args.lr,
# so mirror those values here instead of keeping separate numbers that are silently ignored.
num_epochs = args.num_epochs
learning_rate = args.lr

In [6]:
# First-return lidar points for the parking-lot flight (trusted local pickle).
FIRST_RETURNS_PKL = "../../lidar/Data/parking_lot/first_returns_modified_164239.pkl"
first_return_df = pd.read_pickle(FIRST_RETURNS_PKL)

In [7]:
# Feature columns copied, in this order, into scan_line_tensor.
# Note: x_scaled, y_scaled, and z_scaled MUST be the first 3 features, and
# 'miss_pts_before' MUST be the last one: generate_samples filters windows on the
# last feature column and then drops it.
feature_list = [
    'x_scaled',
    'y_scaled',
    'z_scaled',
    'scan_line_idx',
    'scan_angle_deg',
    'abs_scan_angle_deg',
    'miss_pts_before'
]
# Fail fast if the ordering the downstream cells rely on is broken by an edit here.
assert feature_list[:3] == ['x_scaled', 'y_scaled', 'z_scaled'], "xyz must be the first 3 features"
assert feature_list[-1] == 'miss_pts_before', "'miss_pts_before' must be the last feature"

### Mask the data

In [8]:
# miss_pts_before: estimated number of missing points immediately before each point,
# computed as round(scan_gap / -5 - 1) and floored at 0.
# NOTE(review): assumes in-line scan gaps are negative multiples of ~5 — confirm.
# Vectorized: clip floors negatives at 0 and fillna(0) maps an undefined gap (NaN) to 0,
# the same values the former per-row max(0, pt) comprehension produced.
first_return_df['miss_pts_before'] = (
    (first_return_df['scan_gap'] / -5 - 1)
    .round()
    .clip(lower=0)
    .fillna(0)
)

# Add 'mask' column, set to one (visible) by default
first_return_df['mask'] = 1

[2020-09-15 15:06:02] NumExpr defaulting to 4 threads.


In [9]:
def add_missing_pts(first_return_df):
    '''
    Insert placeholder rows for points missing inside small scan gaps.

    For every row whose 'miss_pts_before' is between 1 and 5, that many all-zero rows
    with 'mask' = 0 are inserted immediately before it; their 'scan_angle' and
    'scan_angle_deg' are NaN so they can be interpolated afterwards. Rows with 6 or
    more missing points before them get no placeholders.

    NOTE(review): assumes first_return_df has a sorted integer index (e.g. a RangeIndex):
    placeholders are positioned with fractional index values between i-1 and i.

    Returns:
        A new DataFrame with the placeholders interleaved and a fresh 0..n-1 index.
    '''
    # Rows preceded by 1-5 missing points, with how many points are missing before each.
    is_small_gap = (first_return_df['miss_pts_before'] > 0) & (first_return_df['miss_pts_before'] < 6)
    miss_pt_ser = first_return_df.loc[is_small_gap, 'miss_pts_before']
    n_missing = int(miss_pt_ser.sum())
    # One all-zero row per missing point, with the same columns as the input.
    miss_pts_arr = np.zeros([n_missing, first_return_df.shape[1]])
    indices = np.ones(n_missing)

    # The k-th of `row` missing points before `index` gets index - 1 + k/row + 0.01,
    # which sorts strictly between the previous row (index - 1) and `index`.
    i = 0
    for index, row in zip(miss_pt_ser.index, miss_pt_ser):
        n = int(row)
        indices[i:i+n] = index + np.arange(row)/row - 1 + .01
        i += n

    miss_pts_df = pd.DataFrame(miss_pts_arr, index=indices, columns=first_return_df.columns)
    miss_pts_df['mask'] = 0
    # Fill scan fields with NaN so we can interpolate them
    for col in ['scan_angle', 'scan_angle_deg']:
        miss_pts_df[col] = np.nan
    # DataFrame.append was removed in pandas 2.0; concat is the supported equivalent.
    full_df = pd.concat([first_return_df, miss_pts_df])
    # Re-sort so the placeholders are interspersed, then reset the index
    full_df = full_df.sort_index().reset_index(drop=True)
    return full_df

In [10]:
# Don't add these in yet
# first_return_df = add_missing_pts(first_return_df)
# first_return_df[['scan_angle','scan_angle_deg']] = first_return_df[['scan_angle','scan_angle_deg']].interpolate()

In [11]:
# Add abs_scan_angle_deg as a feature (magnitude of the scan angle, sign dropped)
first_return_df['abs_scan_angle_deg'] = first_return_df['scan_angle_deg'].abs()

#### Extract tensor of scan lines

In [12]:
# Number of points per scan line (non-null gps_time values per scan_line_idx)
scan_line_pt_count = first_return_df.groupby('scan_line_idx')['gps_time'].count()

# Index labels of points whose scan_gap exceeds the line-break threshold
scan_break_idx = first_return_df.loc[first_return_df['scan_gap'].gt(scan_line_gap_break)].index

In [13]:
# Create Tensor: one row per scan line whose point count lies strictly between
# min_pt_count and max_pt_count, holding its first min_pt_count points.
line_count = ((scan_line_pt_count>min_pt_count)&(scan_line_pt_count<max_pt_count)).sum()
# Zero-filled (not random) so an unfilled row can never masquerade as data.
scan_line_tensor = torch.zeros([line_count,min_pt_count,len(feature_list)])

i=0
for line,count in enumerate(scan_line_pt_count):
    if min_pt_count < count < max_pt_count:
        try:
            # Line 0 starts at the first row; scan_break_idx[-1] would wrap to the last break.
            # NOTE(review): assumes positional line order matches scan_break_idx order and that
            # index labels equal row positions (iloc below) — confirm.
            line_idx = scan_break_idx[line-1] if line > 0 else 0
            scan_line_tensor[i,:,:] = torch.Tensor(first_return_df.iloc\
                                      [line_idx:line_idx+min_pt_count][feature_list].values)
            i+=1
        except RuntimeError:
            # Typically a slice shorter than min_pt_count; report it and skip the line.
            print("line: ",line)
            print("line_idx: ",line_idx)

# Drop rows left unfilled because a line could not be copied.
scan_line_tensor = scan_line_tensor[:i]

### Min-max normalize the features

In [14]:
def min_max_tensor(tensor):
    '''
    Min-max scale every feature (last dimension) of `tensor` to [0, 1].

    All leading dimensions are flattened, so each feature is scaled with its min/max
    over every sample; the result is reshaped back to the input shape. Works for any
    rank >= 1 (the original 3-D use is unchanged).

    Returns:
        (scaled float32 tensor with the input's shape, fitted MinMaxScaler)
        The scaler operates on the flattened (n_samples, n_features) view.
    '''
    shape = tensor.shape
    # reshape (not view) so non-contiguous inputs such as transposed tensors also work;
    # detach/cpu/numpy hands sklearn a plain array even for GPU or grad-tracking tensors.
    flat = tensor.reshape(-1, shape[-1]).detach().cpu().numpy()
    sc = MinMaxScaler()
    flat_scaled = sc.fit_transform(flat)
    # Reshape to original
    return torch.Tensor(flat_scaled.reshape(shape)), sc

In [15]:
# Scale every feature to [0, 1]; `sc` keeps the fitted scaler.
scan_line_tensor_norm, sc = min_max_tensor(tensor=scan_line_tensor)

### Generate the data

In [16]:
def generate_samples(data,min_pt_count,seq_len,num_scan_lines,val_split,starting_line=1000):
    '''
    Cut fixed-length windows from consecutive scan lines and split them into train/validation sets.

    Inputs:
        data: 3-Tensor with dimensions: i) the number of viable scan lines in the flight pass,
                                        ii) the number of points kept per scan line (min_pt_count),
                                        iii) feature count; the LAST feature must be
                                             'miss_pts_before' (assumed to still be 0 for
                                             "no missing points" after scaling)
        min_pt_count: points per scan line in `data`
        seq_len: window length
        num_scan_lines: number of consecutive scan lines to use
        val_split: fraction of windows reserved for validation (taken from the end)
        starting_line: index of the first scan line used
    Returns:
        (x_train, x_val), each shaped (n_windows, n_features - 1, seq_len)
    Raises:
        ValueError: if `data` has fewer than starting_line + num_scan_lines scan lines.
    '''
    n_lines = data.shape[0]
    if starting_line + num_scan_lines > n_lines:
        raise ValueError(
            f"need scan lines {starting_line}..{starting_line + num_scan_lines - 1}, "
            f"but data only has {n_lines}")
    # Feature count comes from the data itself rather than the global feature_list.
    n_features = data.shape[-1]
    x = torch.ones([(min_pt_count-seq_len)*num_scan_lines, seq_len, n_features])
    # Cycle through the number of scan lines requested, starting somewhere in the middle
    for line_idx in range(starting_line,starting_line+num_scan_lines):
        x = sliding_windows(data[line_idx,:,:],seq_len,line_idx-starting_line, x)

    # Remove windows containing points preceded by missing points (last column is miss_pts_before)
    x = x[x[:,:,-1].sum(axis=1)==0.]

    # Remove the 'miss_pts_before' column
    x = x[:,:,:-1]

    # Train-Val split, then channels-first: (n_windows, n_features - 1, seq_len)
    x_train,x_val = train_val_split(x,val_split)
    return x_train.transpose(1,2),x_val.transpose(1,2)

def sliding_windows(data, seq_length, line_num, x):
    '''
    Write the length-`seq_length` windows of one scan line into `x`, in place.

    Inputs:
        data: 2-Tensor (n_points, n_features) for a single scan line.
        seq_length: window length.
        line_num: position of this line among the lines being windowed; its windows fill
                  rows line_num * n_windows ... (line_num + 1) * n_windows - 1 of x.
        x: preallocated 3-Tensor (n_rows, seq_length, n_features).
    Returns:
        x (the same tensor object).

    Windows start at 0 .. n_points - seq_length - 1, so n_windows = n_points - seq_length
    (no window starts at the last possible position); this matches the per-line row
    count that generate_samples allocates.
    '''
    n_windows = len(data) - seq_length
    if n_windows <= 0:
        return x
    # Rows taken by earlier lines, derived from this line's length rather than a global.
    offset = line_num * n_windows
    # unfold -> (n_points - seq_length + 1, n_features, seq_length); keep the first n_windows.
    windows = data.unfold(0, seq_length, 1)[:n_windows]
    x[offset:offset + n_windows] = windows.transpose(1, 2)
    return x

def train_val_split(x, val_split):
    '''
    Split a 3-D tensor along its first axis into leading (train) and trailing (validation) parts.

    The first int(len * (1 - val_split)) samples go to training and the rest to validation;
    no shuffling is done, so validation is the last part of the dataset.
    '''
    cut = int(x.shape[0] * (1 - val_split))
    head, tail = x[:cut, :, :], x[cut:, :, :]
    return head, tail

def add_mask(tensor,mask_pts_per_seq):
    '''
    Build a random point mask for a batch of channels-first sequences.

    Inputs:
        tensor: 3-Tensor shaped (n_sequences, n_features, seq_len); only its shape is used.
        mask_pts_per_seq: number of sequence positions drawn per sequence (independent draws,
                          so repeats are possible).
    Returns:
        Float tensor of the same shape with 1 = visible and 0 = masked. A masked position is
        zeroed across all features, so each sequence has at most mask_pts_per_seq masked
        positions.
    '''
    mask_tensor = torch.ones(tensor.shape)
    n_seq = mask_tensor.shape[0]
    # The sequence axis is the LAST one; dim 1 holds the features.
    seq_len = mask_tensor.shape[-1]
    mask_idx = torch.randint(seq_len,(n_seq,mask_pts_per_seq))
    # Advanced indexing: row i gets positions mask_idx[i] zeroed in every feature channel.
    rows = torch.arange(n_seq).unsqueeze(1)
    mask_tensor[rows, :, mask_idx] = 0
    return mask_tensor

In [17]:
x_train, x_val = generate_samples(
    scan_line_tensor_norm,
    min_pt_count=min_pt_count,
    seq_len=seq_len,
    num_scan_lines=num_scan_lines,
    val_split=val_split,
)

In [18]:
# Create mask tensors (train first, then validation, drawing from the same RNG stream)
mask_train, mask_val = (add_mask(samples, mask_pts_per_seq) for samples in (x_train, x_val))

In [19]:
# Dataset pairing each window with its mask
class LidarLstmDataset(udata.Dataset):
    """Map-style dataset: item `index` is (x[index], mask[index]); length is x.shape[0]."""

    def __init__(self, x, mask):
        super().__init__()
        self.x, self.mask = x, mask

    def __len__(self):
        n_samples = self.x.shape[0]
        return n_samples

    def __getitem__(self, index):
        sample, sample_mask = self.x[index], self.mask[index]
        return sample, sample_mask

In [20]:
# Create the dataloaders
train_dataset = LidarLstmDataset(x_train,mask_train)
val_dataset = LidarLstmDataset(x_val,mask_val)
# Use the configured batch size instead of a second hard-coded copy of it.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
# batch_size=1: the validation loop logs images for the first 10 sample_ids individually.
valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=4, shuffle=True)

In [21]:
# Track moving average of loss values (training) and plain average (validation)
train_meters = {"train_loss": utils.RunningAverageMeter(0.98)}
valid_meters = {"valid_loss": utils.AverageMeter()}
if args.no_visual:
    writer = None
else:
    writer = SummaryWriter(log_dir=args.experiment_dir)

In [22]:
x_train.size()

torch.Size([154517, 6, 64])

In [23]:
# TRAINING
# Loss: summed squared error divided by 2 * batch size. Because the filled output copies
# `inputs` wherever mask == 1, (filled - inputs) = (1 - mask) * (raw - inputs), so the loss
# is already restricted to the masked positions.
# NOTE(review): the fill requires model outputs and inputs to be broadcast-compatible; the
# printed UNet ends in 3 channels while inputs have 6 features, and the recorded RuntimeError
# (mask conv weight [1, 1, 5] vs a 6-channel input) suggests the model expects a 1-channel
# mask — confirm the model/data shapes.
for epoch in range(start_epoch, args.num_epochs):
    if args.resume_training:
        if epoch %10 == 0:
            optimizer.param_groups[0]["lr"] /= 2
            print('learning rate reduced by factor of 2')

    train_bar = utils.ProgressBar(train_loader, epoch)
    for meter in train_meters.values():
        meter.reset()

    # Only the validation block switches to eval mode, so setting train mode once per epoch suffices.
    model.train()
    for batch_id, (clean, mask) in enumerate(train_bar):
        # dataloader returns [clean, mask] list
        global_step += 1
        inputs = clean.to(device)
        mask_inputs = mask.to(device)
        raw_outputs = model(inputs,mask_inputs)
        # Keep known values (mask == 1); use predictions only where mask == 0.
        outputs = (1-mask_inputs)*raw_outputs + mask_inputs*inputs
        loss = F.mse_loss(outputs, inputs, reduction="sum") / (inputs.size(0) * 2)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a float: a tensor here could keep each step's autograd graph alive in the meter.
        train_meters["train_loss"].update(loss.item())
        train_bar.log(dict(**train_meters, lr=optimizer.param_groups[0]["lr"]), verbose=True)

        if writer is not None and global_step % args.log_interval == 0:
            writer.add_scalar("lr", optimizer.param_groups[0]["lr"], global_step)
            writer.add_scalar("loss/train", loss.item(), global_step)
            gradients = torch.cat([p.grad.view(-1) for p in model.parameters() if p.grad is not None], dim=0)
            writer.add_histogram("gradients", gradients, global_step)
            sys.stdout.flush()

    if epoch % args.valid_interval == 0:
        model.eval()
        for meter in valid_meters.values():
            meter.reset()

        valid_bar = utils.ProgressBar(valid_loader)

        for sample_id, (clean, mask) in enumerate(valid_bar):
            with torch.no_grad():
                inputs = clean.to(device)
                mask_inputs = mask.to(device)
                raw_output = model(inputs,mask_inputs)
                output = (1-mask_inputs)*raw_output + mask_inputs*inputs

                # Score the validation prediction `output` (not the last training batch's `outputs`).
                val_loss = F.mse_loss(output, inputs, reduction="sum") / (inputs.size(0) * 2)
                valid_meters["valid_loss"].update(val_loss.item())

                # NOTE(review): torch.cat on dim 0 needs inputs and output to have the same
                # channel count — confirm before relying on these images.
                if writer is not None and sample_id < 10:
                    image = torch.cat([inputs, torch.mul(inputs, mask_inputs), output], dim=0)
                    image = torchvision.utils.make_grid(image.clamp(0, 1), nrow=3, normalize=False)
                    writer.add_image(f"valid_samples/{sample_id}", image, global_step)

        if writer is not None:
            writer.add_scalar("loss/valid", valid_meters['valid_loss'].avg, global_step)
            sys.stdout.flush()

        logging.info(train_bar.print(dict(**train_meters, **valid_meters, lr=optimizer.param_groups[0]["lr"])))
        # The score is a loss, so lower is better: keep the checkpoint with the minimum.
        utils.save_checkpoint(args, global_step, model, optimizer, score=valid_meters["valid_loss"].avg, mode="min")
    scheduler.step()

logging.info(f"Done training! Best validation loss {utils.save_checkpoint.best_score:.6f} obtained after step {utils.save_checkpoint.best_step}.")


epoch 00:   0%|          | 0/4829 [00:00<?, ?it/s]

RuntimeError: Given groups=1, weight of size [1, 1, 5], expected input[32, 6, 64] to have 1 channels, but got 6 channels instead

### Testing

In [None]:
# Load the trained checkpoints compared below (trained with min_sep = 3, 5 and 10).
# model3 and model5 are required by the first_pt_stats cells further down.
# NOTE(review): the models are built from the current `args` (model='lstm_unet1d'); confirm
# this architecture matches the saved unet1d_partialconv state dicts.
def _load_trained(path):
    """Build a model from `args`, load weights from `path`, move to `device`, set eval mode."""
    trained = models.build_model(args)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    trained.load_state_dict(torch.load(path, map_location=device))
    return trained.to(device).eval()

model3 = _load_trained("models/trained/unet1d_partialconv_10kdata_30epoch_3minsep_08_14_20.pth")
model5 = _load_trained("models/trained/unet1d_partialconv_10kdata_30epoch_08_13_20.pth")
model10 = _load_trained("models/trained/unet1d_partialconv_10kdata_30epoch_10minsep_08_14_20.pth")
model10

## Analysis of first predicted point
Compare the first predicted value with the global mean, the receptive-field mean, and the next visible point.

### First-point statistics for min_sep = 3, 5 and 10

In [None]:
import pandas as pd

def first_pt_stats(model,min_sep):
    '''
    Compare the model's first predicted value with simple baselines on a test batch.

    Builds a test loader via data.build_dataset (using the notebook's `args` and `device`),
    runs `model` on it and compares outputs[i, 0, 0] with, per sequence:
      * the mean of the unmasked clean signal,
      * the mean of the unmasked clean signal up to index 21 (receptive field),
      * the first unmasked clean value.
    NOTE(review): assumes the masked points form a prefix of each sequence (the mask length
    is used as the index of the first visible point) — confirm against the dataset.

    Returns:
        [min_sep, mean of each baseline (3 values), mean first prediction,
         mean and SD of |first prediction - baseline| for each baseline (6 values)]
    '''
    _,_,test_loader = data.build_dataset(args.dataset,
                                         args.n_data,
                                         batch_size=args.n_data,
                                         fix_datapoints=True,
                                         min_sep = min_sep,
                                         test_num = 1)
    print("Min_sep: {}".format(min_sep))
    print("*"*30)
    # Inference only: no_grad avoids building an autograd graph over the whole test batch.
    # Only the last batch yielded is analysed below (batch_size=args.n_data).
    with torch.no_grad():
        for batch_id,(clean,mask) in enumerate(test_loader):
            print("Mean of clean signal: {:2.4f}".format(clean.mean()))
            outputs = model(clean.to(device),mask.to(device)).cpu()
            print("Mean first value (min_sep={}): {:2.4f}".format(min_sep, outputs[:,:,0].mean()))

    # Collect the "means" we're comparing to
    mean_unmasked_sig = []
    mean_rf_sig = []
    first_unmasked = []

    # Collect the abs diffs with the first predicted value
    mean_unmasked_sig_diff = []
    mean_rf_sig_diff = []
    first_unmasked_diff = []

    # Masked points per sequence (mask is 1 where visible); sequence length read from the mask.
    mask_length = (mask.shape[2]-mask.sum(dim=2))
    for i in range(len(mask_length)):
        first_visible = int(mask_length[i])
        first_pred = outputs[i,0,0].item()
        # Mean of unmasked signal
        mum = clean[i,0,first_visible:].mean().item()
        mean_unmasked_sig.append(mum)
        # Mean of the unmasked receptive field (NaN when the mask covers indices 0..20)
        mrf = clean[i,0,first_visible:21].mean().item()
        mean_rf_sig.append(mrf)
        # First unmasked value
        fu = clean[i,0,first_visible].item()
        first_unmasked.append(fu)

        # The diffs
        mean_unmasked_sig_diff.append(abs(first_pred-mum))
        mean_rf_sig_diff.append(abs(first_pred-mrf))
        first_unmasked_diff.append(abs(first_pred-fu))

    print("Mean of full unmasked signal: {:2.4f}".format(np.mean(mean_unmasked_sig)))
    print("Mean of receptive field signal [0,21]: {:2.4f}".format(np.mean(mean_rf_sig)))
    print("Mean of first visible value after mask: {:2.4f}".format(np.mean(first_unmasked)))

    print("First predicted value mean diff: full unmasked signal: {:2.4f} (SD: {:2.4f})"\
          .format(np.mean(mean_unmasked_sig_diff),np.std(mean_unmasked_sig_diff)))
    print("First predicted value mean diff: receptive field signal [0,21]: {:2.4f} (SD: {:2.4f})"\
          .format(np.mean(mean_rf_sig_diff),np.std(mean_rf_sig_diff)))
    print("First predicted value mean diff: first visible value after mask: {:2.4f} (SD: {:2.4f})"\
          .format(np.mean(first_unmasked_diff),np.std(first_unmasked_diff)))

    df_list = [min_sep,np.mean(mean_unmasked_sig),np.mean(mean_rf_sig),np.mean(first_unmasked),
              float(outputs[:,:,0].mean()),
              np.mean(mean_unmasked_sig_diff),np.std(mean_unmasked_sig_diff),
              np.mean(mean_rf_sig_diff),np.std(mean_rf_sig_diff),
              np.mean(first_unmasked_diff),np.std(first_unmasked_diff)
              ]
    return df_list

In [None]:
df_list3 = first_pt_stats(model=model3, min_sep=3)

In [None]:
df_list5 = first_pt_stats(model=model5, min_sep=5)

In [None]:
df_list10 = first_pt_stats(model=model10, min_sep=10)

In [None]:
# One column per min_sep after transposing; one row per statistic.
pd.DataFrame(
    [df_list3, df_list5, df_list10],
    columns=[
        'min_sep',
        'clean_sig_mean', 'receptive_field_mean',
        'first_visible_mean', 'first_pred_mean',
        'full_unmasked_diff_mean', 'full_unmasked_diff_sd',
        'receptive_field_diff_mean', 'receptive_field_diff_sd',
        'first_visible_diff_mean', 'first_visible_diff_sd',
    ],
).T

## Examples

In [None]:
# Best PSNR 28.560
def mask_idx_f(mask):
    '''
    Locate the masked span of the first sequence in a batch of masks.

    Assumes mask[0] is 1 where visible and 0 where masked, and that the masked points
    form one contiguous run (only its start and length are used).

    Returns:
        (mask_idx, before, after, mask_length, mask_start):
        mask_idx    -- range of masked positions
        before      -- numpy array of positions before the masked run
        after       -- numpy array of positions after the masked run
        mask_length -- number of masked points
        mask_start  -- first masked position (0 when nothing is masked)
    '''
    mask_start = int(np.argmin(mask[0]))
    mask_length = int((1-mask[0]).sum())
    mask_idx = range(mask_start,mask_start+mask_length)
    # Visible positions on either side of the masked run
    positions = np.arange(mask.shape[2])
    before = positions[:mask_start]
    after = positions[mask_start+mask_length:]
    return mask_idx,before, after, mask_length, mask_start

def print_one(loader,model):
    '''
    Plot one example: true signal, masked signal and reconstruction (channel 0).

    Draws a batch from `loader`, runs `model` on it (on `device`) and plots the first
    sample. The reconstruction keeps clean values where mask == 1 and the model output
    where mask == 0; dashed lines at 0 and 1 mark the masked positions. Also prints the
    mask length/start and a few summary values.
    '''
    # Reseed numpy's global RNG from fresh entropy so repeated calls can differ.
    np.random.seed()
    clean,mask = next(iter(loader))
    with torch.no_grad():
        outputs = model(clean.to(device),mask.to(device)).cpu()

    mask_idx,before_mask,after_mask,mask_length, mask_start = mask_idx_f(mask)

    # Visible points come from the clean signal, masked points from the model.
    out = outputs[0] * (1-mask[0]) + clean[0]*mask[0]
    print("Mask Length: {}\tMask Start: {}".format(mask_length,mask_start))

    def _mark_mask():
        # Dashed lines at 0 and 1 spanning the masked positions.
        plt.plot(mask_idx,np.zeros(len(mask_idx)),'--k')
        plt.plot(mask_idx,np.ones(len(mask_idx)),'--k')

    plt.figure(figsize=[15,10])
    plt.subplot(3,1,1)
    plt.plot(clean[0,0,:],'xb')
    _mark_mask()
    plt.title("True signal")

    plt.subplot(3,1,2)
    masked = clean[0]*mask[0]
    plt.plot(before_mask,masked[0,before_mask],'xb')
    plt.plot(after_mask,masked[0,after_mask],'xb')
    _mark_mask()
    plt.title("Masked signal")

    plt.subplot(3,1,3)
    plt.plot(out[0,:],'xb')
    _mark_mask()
    plt.title("Denoised signal")

    # NOTE(review): the "visible signal" slice starts at mask_length, which is the first
    # visible index only when the mask starts at position 0; the "full signal" mean
    # actually covers indices 0..20 only.
    sig_mean = clean[0,0,mask_length:21].mean()
    print("First mask value: {:2.4f}\nMean of full signal: {:2.4f}\nMean of visible signal: {:2.4f}"\
          .format(out[0,0],clean[0,0,:21].mean(),sig_mean))

In [None]:
# Test loader is shuffled and allows test_num to force a certain mask shape
_, _, test_loader = data.build_dataset(
    args.dataset,
    args.n_data,
    batch_size=args.n_data,
    fix_datapoints=True,
    min_sep=10,
    test_num=0,
)

In [None]:
print_one(loader=test_loader, model=model10)

In [None]:
print_one(loader=test_loader, model=model10)

In [None]:
print_one(loader=test_loader, model=model10)

In [None]:
print_one(loader=test_loader, model=model10)

In [None]:
print_one(loader=test_loader, model=model10)

In [None]:
# Example 50-point mask: 22 visible, 9 masked, 19 visible -> shape (1, 1, 50)
torch.Tensor([[[1.] * 22 + [0.] * 9 + [1.] * 19]]).shape

In [None]:
# Peek at one test batch and show the mask's shape
c, m = next(iter(test_loader))
m.size()

In [None]:
# Save the newly trained weights. The path must be assigned before saving; otherwise a
# fresh kernel raises NameError, or a stale MODEL_PATH value is silently overwritten.
MODEL_PATH = "models/trained/unet1d_partialconv_100kdata_100epoch_08_21_20.pth"
torch.save(model.state_dict(), MODEL_PATH)
