# In[ ]:
def predict(dataloader, device='mps', k_iwae=2, model=None):
    """Draw reconstruction samples from the model for every batch in ``dataloader``.

    Each batch is moved to ``device``. The values in feature 1, multiplied by
    the subsampling mask (feature 4), are concatenated with that mask along
    dim 1 and transposed to form ``context_y``. The time points
    ``batch[:, 0, :, 0]`` serve as both context x and target x for
    ``model.get_reconstruction``.

    The per-batch predictive means and standard deviations
    (``exp(0.5 * logvar)``) are concatenated along axis 1. They are turned into
    samples with the reparameterisation ``mean + std * eps``, ``eps ~ N(0, 1)``.
    NumPy's global RNG is reseeded with 0 first, so the draws are reproducible.

    TODO: decide how target_x should be set, or whether to skip masking so that
    all points are used in the interpolation.

    Args:
        dataloader: iterable of 4-D tensors indexed as ``batch[:, :, :, feature]``
            (presumably batch, channel, time, feature -- TODO confirm).
        device: device each batch is moved to before the forward pass.
        k_iwae: ``num_samples`` requested from ``get_reconstruction``.
        model: network to query; defaults to the module-level ``net``.

    Returns:
        np.ndarray of samples stacked along axis 0, followed by the trailing
        three axes of the concatenated ``px.mean``. There are
        ``max(1, k_iwae // 2) * k_iwae`` samples when ``px.mean`` has
        ``k_iwae`` entries on axis 0.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if model is None:
        model = net  # module-level network defined elsewhere in the notebook
    # k_iwae // 2 is 0 for k_iwae == 1, which used to return an empty array.
    n_draws = max(1, k_iwae // 2)
    pred_mean, pred_std = [], []
    masks, targets, tp = [], [], []
    np.random.seed(0)
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            subsampled_mask = batch[:, :, :, 4]
            context_y = torch.cat(
                (batch[:, :, :, 1] * subsampled_mask, subsampled_mask), 1
            ).transpose(2, 1)
            time_points = batch[:, 0, :, 0]
            px, _ = model.get_reconstruction(
                time_points, context_y, time_points, num_samples=k_iwae
            )
            pred_mean.append(px.mean.cpu().numpy())
            pred_std.append(torch.exp(0.5 * px.logvar).cpu().numpy())
            targets.append(batch[:, :, :, 1].cpu().numpy())
            masks.append(subsampled_mask.cpu().numpy())
            tp.append(time_points.cpu().numpy())

    if not pred_mean:
        raise ValueError("dataloader yielded no batches")

    # Model outputs are joined along axis 1 (axis 0 presumably holds the
    # num_samples draws -- TODO confirm); the inputs are joined along the
    # batch axis 0.
    pred_mean = np.concatenate(pred_mean, axis=1)
    pred_std = np.concatenate(pred_std, axis=1)
    targets = np.concatenate(targets, axis=0)
    masks = np.concatenate(masks, axis=0)
    tp = np.concatenate(tp, axis=0)
    print(pred_mean.shape, pred_std.shape, targets.shape, masks.shape, tp.shape)

    # Reparameterisation trick: mean + std * eps with eps ~ N(0, 1).
    eps = np.random.randn(
        n_draws, k_iwae, pred_mean.shape[1], pred_mean.shape[2], pred_mean.shape[3]
    )
    preds = eps * pred_std + pred_mean
    return preds.reshape(-1, pred_mean.shape[1], pred_mean.shape[2], pred_mean.shape[3])

def get_latent_dist(dataloader, device='mps', k_iwae=2, model=None):
    """Return a Monte Carlo estimate of the latent posterior mean for all batches.

    Each batch is moved to ``device``. The values in feature 1, multiplied by
    the subsampling mask (feature 4), are concatenated with that mask along
    dim 1 and transposed to form ``context_y``. The time points
    ``batch[:, 0, :, 0]`` serve as both context x and target x for
    ``model.get_reconstruction``.

    The posterior ``qz`` means and standard deviations (``exp(0.5 * logvar)``)
    are concatenated along axis 0. The function then averages
    ``max(1, k_iwae // 2) * k_iwae`` reparameterised samples
    ``mean + std * eps``, with ``eps ~ N(0, 1)``. NumPy's global RNG is
    reseeded with 0 first, so the result is reproducible.

    NOTE(review): averaging these samples only adds noise around ``qz.mean``;
    returning the concatenated means directly may be what is intended.

    Args:
        dataloader: iterable of 4-D tensors indexed as ``batch[:, :, :, feature]``
            (presumably batch, channel, time, feature -- TODO confirm).
        device: device each batch is moved to before the forward pass.
        k_iwae: ``num_samples`` requested from ``get_reconstruction``.
        model: network to query; defaults to the module-level ``net``.

    Returns:
        np.ndarray with the same shape as the concatenated ``qz.mean``
        (expected to be 3-D).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if model is None:
        model = net  # module-level network defined elsewhere in the notebook
    # k_iwae // 2 is 0 for k_iwae == 1, which used to average an empty axis (NaN).
    n_draws = max(1, k_iwae // 2)
    qz_mean, qz_std = [], []
    np.random.seed(0)
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            subsampled_mask = batch[:, :, :, 4]
            context_y = torch.cat(
                (batch[:, :, :, 1] * subsampled_mask, subsampled_mask), 1
            ).transpose(2, 1)
            # context_x and target_x are the same time points; a distinct
            # target x could be chosen here instead.
            time_points = batch[:, 0, :, 0]
            _, qz = model.get_reconstruction(
                time_points, context_y, time_points, num_samples=k_iwae
            )
            qz_mean.append(qz.mean.cpu().numpy())
            qz_std.append(torch.exp(0.5 * qz.logvar).cpu().numpy())

    if not qz_mean:
        raise ValueError("dataloader yielded no batches")

    qz_mean = np.concatenate(qz_mean, axis=0)
    qz_std = np.concatenate(qz_std, axis=0)
    print(qz_mean.shape, qz_std.shape)

    # Reparameterisation trick: mean + std * eps with eps ~ N(0, 1).
    eps = np.random.randn(
        n_draws, k_iwae, qz_mean.shape[0], qz_mean.shape[1], qz_mean.shape[2]
    )
    samples = (eps * qz_std + qz_mean).reshape(
        -1, qz_mean.shape[0], qz_mean.shape[1], qz_mean.shape[2]
    )
    latent_mean = samples.mean(0)  # sample mean (previously misnamed "median")
    print(latent_mean.shape)
    return latent_mean

# In[ ]:
def train(net, optimizer, epoch, train_loader, args, device="cuda", frac=0.5, errors=False, beta=0.1):
    """Run one epoch of unsupervised training and return summary metrics.

    For every batch, ``make_masks(batch, frac=frac)`` builds a reconstruction
    mask and a subsampling mask. Both are appended to the batch as features 3
    (subsampled) and 4 (recon). The model then gets:

    * context: the subsampled values ``y * subsampled_mask`` and their mask;
    * target: the values visible under ``recon_mask`` and that mask.

    The loss is ``loss_info.composite_loss`` from
    ``net.compute_unsupervised_loss``. It is back-propagated and the optimizer
    is stepped once per batch. Every 100th epoch a summary line is printed.

    Args:
        net: model exposing ``compute_unsupervised_loss``.
        optimizer: torch optimizer over ``net``'s parameters.
        epoch: current epoch index; only used for logging.
        train_loader: iterable of 3-D tensors indexed as ``batch[:, :, feature]``
            with x in feature 0, y in feature 1 and (when ``errors``) per-point
            weights in feature 2 (presumably (batch, time, feature) -- TODO confirm).
        args: namespace providing ``k_iwae`` (samples per loss evaluation).
        device: device each batch is moved to.
        frac: forwarded to ``make_masks``.
        errors: if True, feature 2 is passed as ``sample_weight`` (in analogy to
            weighted least squares); otherwise a weight of 1 is used.
        beta: forwarded to ``compute_unsupervised_loss`` (presumably the KL
            weight -- confirm against that method).

    Returns:
        Tuple ``(avg_nll, avg_mse, y, recon_mask, subsampled_mask)``. The first
        two are sample-weighted averages over the epoch; the last three are
        tensors from the final batch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    train_loss = 0.
    train_n = 0.
    avg_loglik, avg_kl, mse, mae = 0., 0., 0., 0.
    for train_batch in train_loader:
        batch_len = train_batch.shape[0]
        # Honour the caller's frac (it used to be hard-coded to 0.5).
        recon_mask, subsampled_mask = make_masks(train_batch, frac=frac)
        # Append the masks as trailing features so that they land at indices
        # 3 (subsampled) and 4 (recon) -- assumes 3 input features; TODO confirm.
        train_batch = torch.cat(
            (train_batch, torch.unsqueeze(subsampled_mask, 2), torch.unsqueeze(recon_mask, 2)),
            dim=-1,
        ).to(device)
        x = train_batch[:, :, 0]
        y = train_batch[:, :, 1:2]
        subsampled_mask = train_batch[:, :, 3:4]
        recon_mask = train_batch[:, :, 4:5]
        # Per-point loss weights, in analogy to weighted least squares.
        sample_weight = train_batch[:, :, 2:3] if errors else 1.

        # Context: subsampled values plus their mask.
        context_y = torch.cat((y * subsampled_mask, subsampled_mask), -1)
        # Target: values visible under recon_mask plus that mask.
        recon_context_y = torch.cat((y * recon_mask, recon_mask), -1)
        # Signature: compute_unsupervised_loss(context_x, context_y, target_x,
        #                                      target_y, num_samples=1, beta=1)
        loss_info = net.compute_unsupervised_loss(
            x,
            context_y,
            x,  # target_x: any points we want to project to could go here
            recon_context_y,
            num_samples=args.k_iwae,
            beta=beta,
            sample_weight=sample_weight,
        )
        optimizer.zero_grad()
        loss_info.composite_loss.backward()
        optimizer.step()

        train_loss += loss_info.composite_loss.item() * batch_len
        # NOTE(review): if loglik/kl/mse/mae are tensors still attached to the
        # autograd graph, these accumulators keep every batch's graph alive;
        # consider detaching them or calling .item().
        avg_loglik += loss_info.loglik * batch_len
        avg_kl += loss_info.kl * batch_len
        mse += loss_info.mse * batch_len
        mae += loss_info.mae * batch_len
        train_n += batch_len

    if train_n == 0:
        raise ValueError("train_loader yielded no batches")

    if epoch % 100 == 0:
        print(
            'Iter: {}, train loss: {:.4f}, avg nll: {:.4f}, avg kl: {:.4f}, '
            'mse: {:.6f}, mae: {:.6f}'.format(
                epoch,
                train_loss / train_n,
                -avg_loglik / train_n,
                avg_kl / train_n,
                mse / train_n,
                mae / train_n
            )
        )

    return -avg_loglik / train_n, mse / train_n, y, recon_mask, subsampled_mask