In [1]:
import torch
import numpy as np
import os
from tqdm import tqdm

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 47
torch.manual_seed(seed)
np.random.seed(seed)
# data_file = '{}/exp01_dense_test.pt'.format(load_dir)
load_dir = './dat/scratch'
data_file = '{}/exp02_sparse_test.pt'.format(load_dir)
theta, x, y, y_idx = torch.load(data_file, map_location=torch.device('cpu'))
y_dict = {'theta': theta,
          'y': y,
          'y_idx': y_idx}
batch_size = theta.shape[0]
theta.shape, x.shape, y.shape, y_idx.shape

(torch.Size([1024, 4]),
 torch.Size([1024, 151, 2]),
 torch.Size([1024, 151, 2]),
 torch.Size([1024, 151]))

In [3]:
def save_obs_files(y, y_idx, out_dir='.', mode='sparse_test'):
    T = 30
    dt = 0.2
    batch_size = int((T/dt) + 1) #151
    num_batch = y_idx.shape[0] #1024
    N = num_batch * batch_size
    print('num_batch:', num_batch)
    print('N:', N)
    print('batch_size:', batch_size)

    data_all = np.zeros((2,N))
    time_till_all = np.zeros((2,N))
    obs_bins_all = np.ones((2,N))

    for i in tqdm(range(num_batch)):
        obs_bins = np.ones((2, batch_size))
        time_till = np.zeros((2, batch_size)) - 1
        data = y.detach().cpu().numpy()[i,:].T
        obs_bins[data==0] = 0
        obs_locs = np.where(obs_bins[0]==1)[0]
        data[obs_bins==0] = -1
        for j in range(batch_size):
            next_locs = obs_locs[obs_locs>j]
            if len(next_locs) > 0:
                time_till[:,j] = next_locs[0] - j
            else:
                time_till[:,j] = 0 ## setting 0 for no more observations
        time_till *= dt
        data_all[:,i*batch_size:(i+1)*batch_size] = data
        time_till_all[:,i*batch_size:(i+1)*batch_size] = time_till
        obs_bins_all[:,i*batch_size:(i+1)*batch_size] = obs_bins
        
    np.savetxt(os.path.join(out_dir, f'LV_obs_partial_{mode}.txt'), data_all.astype(np.float32))
    np.savetxt(os.path.join(out_dir, f'LV_obs_binary_{mode}.txt'), obs_bins_all.astype(np.float32))
    np.savetxt(os.path.join(out_dir, f'LV_time_till_{mode}.txt'), time_till_all.astype(np.float32))

save_obs_files(y, y_idx)

 31%|███       | 315/1024 [00:00<00:00, 3143.81it/s]

num_batch: 1024
N: 154624
batch_size: 151


100%|██████████| 1024/1024 [00:00<00:00, 3457.24it/s]


In [6]:
# Check files
data = np.loadtxt('LV_obs_partial_theta_test.txt')
obs_bins = np.loadtxt('LV_obs_binary_theta_test.txt')
time_till = np.loadtxt('LV_time_till_theta_test.txt')
data.shape, obs_bins.shape, time_till.shape

((2, 154624), (2, 154624), (2, 154624))