In [6]:
import sys  
sys.path.insert(0, '/home/mila/o/oussama.boussif/pde_oned/datamodule')
sys.path.insert(0, '/home/mila/o/oussama.boussif/pde_oned/models')
sys.path.insert(0, '/home/mila/o/oussama.boussif/pde_oned')
import pickle
from tqdm.notebook import tqdm
import time
import h5py
import numpy as np

#import moviepy.editor as mpe

from models.factory import FACTORY
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataset import *

import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [75]:
class HDF5DatasetImplicitGNN(Dataset):
    """Implicit super-resolution samples read from a 1-D PDE trajectory file (HDF5).

    For trajectory ``idx`` the full-resolution field is returned as
    ``hr_frames`` and a low-resolution copy (every other spatial point, i.e.
    the even indices) as ``lr_frames``. In ``'train'`` mode, ``samples`` of the
    dropped odd-index points are drawn at random as query points. In any other
    mode every spatial point is a query point.

    Args:
        path: Path of the HDF5 file.
        nt, nx: Select the trajectory array ``f'pde_{nt}-{nx}'`` inside the
            ``mode`` group.
        sampling: Stored as ``self.sampling``; not used by this class.
        mode: ``'train'``, ``'valid'`` or ``'test'``; name of the HDF5 group.
        load_all: If True, read the trajectory array together with ``x`` and
            ``t`` into memory and close the file. Otherwise the file stays open
            and items are read lazily.
        samples: Number of odd-index query points drawn per item in
            ``'train'`` mode; must not exceed the number of odd indices.
    """

    def __init__(self,
                 path,
                 nt,
                 nx,
                 sampling='uniform',
                 mode='train',
                 load_all=False,
                 samples=256):

        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"

        f = h5py.File(path, 'r')
        self.mode = mode
        self.data = f[self.mode]
        self.dataset = f'pde_{nt}-{nx}'
        self.samples = samples
        self.sampling = sampling

        if load_all:
            # __getitem__ also reads 'x' and 't'. Load them as well, otherwise
            # every lookup raises KeyError once the file is closed.
            data = {key: self.data[key][:] for key in (self.dataset, 'x', 't')}
            f.close()
            self.data = data

    def __len__(self):
        """Number of trajectories (first axis of the trajectory array)."""
        return self.data[self.dataset].shape[0]

    def __getitem__(self, idx):
        """Return the tensors for trajectory ``idx`` as a dict.

        Keys: ``t``, ``lr_frames`` (T, 1, ceil(L/2)), ``hr_frames`` (T, 1, L),
        ``hr_points`` (T, P, 1), ``coords_hr`` (P,), ``coords_lr`` (ceil(L/2),).
        ``sample_idx`` (P,) is present in ``'train'`` mode only. P is
        ``samples`` in ``'train'`` mode and L otherwise.

        Raises:
            ValueError: in ``'train'`` mode, if ``samples`` exceeds the number
                of odd spatial indices.
        """
        x = self.data['x'][idx]
        t = self.data['t'][idx]
        # assumes the stored trajectory is (T, L) -- TODO confirm; a channel axis is inserted.
        u_hr = torch.from_numpy(self.data[self.dataset][idx]).unsqueeze(1)  # T, 1, L
        L = u_hr.shape[-1]
        # Low-resolution input keeps the even spatial indices.
        u_lr = u_hr[:, :, ::2]  # T, 1, ceil(L/2)
        lr_coord = x[::2]

        if self.mode == 'train':
            # Query points are drawn only among the odd indices dropped from u_lr.
            indices_left = np.arange(1, L, 2)
            if self.samples > len(indices_left):
                raise ValueError(
                    f"samples={self.samples} exceeds the {len(indices_left)} "
                    f"odd spatial indices available (L={L})")
            # NOTE(review): uses the global NumPy RNG. With DataLoader workers,
            # make sure each worker is seeded differently (e.g. worker_init_fn).
            sample_np = np.sort(np.random.choice(indices_left, self.samples, replace=False))
            sample_lst = torch.from_numpy(sample_np)

            return {
                't': t,
                'sample_idx': sample_lst,
                'lr_frames': u_lr,
                'hr_frames': u_hr,
                'hr_points': u_hr[:, :, sample_lst].permute(0, 2, 1),  # T, samples, 1
                'coords_hr': x[sample_np],
                'coords_lr': lr_coord,
            }

        return {
            't': t,
            'lr_frames': u_lr,
            'hr_frames': u_hr,
            'hr_points': u_hr.permute(0, 2, 1),  # T, L, 1
            'coords_hr': x,
            'coords_lr': lr_coord,
        }

In [12]:
# Validation file for the CE (E1) setting; the absolute path is machine-specific.
test_path = '/home/mila/o/oussama.boussif/scratch/pdeone/data/CE_valid_E1_50.h5'
# Left open on purpose: later cells read data['valid'][...] from this handle.
data = h5py.File(test_path, 'r')
portion = 0.2
# NOTE(review): 50 matches both the `_50` file suffix and the length of the x grid;
# confirm whether it counts trajectories or grid points.
n_items = 50
# Ascending indices in [0, n_items) with a random `portion` of them removed.
idx = np.setdiff1d(np.arange(n_items), np.random.choice(n_items, int(portion * n_items), replace=False))

In [25]:
# First row of the `x` dataset in the `valid` group (h5py resolves 'group/dataset' paths).
data['valid/x'][0]

array([ 0.        ,  0.32653061,  0.65306122,  0.97959184,  1.30612245,
        1.63265306,  1.95918367,  2.28571429,  2.6122449 ,  2.93877551,
        3.26530612,  3.59183673,  3.91836735,  4.24489796,  4.57142857,
        4.89795918,  5.2244898 ,  5.55102041,  5.87755102,  6.20408163,
        6.53061224,  6.85714286,  7.18367347,  7.51020408,  7.83673469,
        8.16326531,  8.48979592,  8.81632653,  9.14285714,  9.46938776,
        9.79591837, 10.12244898, 10.44897959, 10.7755102 , 11.10204082,
       11.42857143, 11.75510204, 12.08163265, 12.40816327, 12.73469388,
       13.06122449, 13.3877551 , 13.71428571, 14.04081633, 14.36734694,
       14.69387755, 15.02040816, 15.34693878, 15.67346939, 16.        ])

In [26]:
# First entry of the `dx` dataset in the `valid` group.
# NOTE(review): this shows 0.32 (= 16/50), but consecutive x values in row 0 are
# 16/49 ≈ 0.3265 apart (the grid includes both endpoints 0 and 16).
# Confirm which spacing is intended.
data['valid/dx'][0]

0.32

In [76]:
# Training split: reads group 'train', array 'pde_250-50', drawing 12 odd-index query points per item.
dataset_ = HDF5DatasetImplicitGNN(
    path='/home/mila/o/oussama.boussif/scratch/pdeone/data/CE_train_E1.h5',
    nt=250,
    nx=50,
    sampling='uniform',
    mode='train',
    samples=12,
)

In [82]:
# In 'train' mode each indexing draws a fresh set of odd-index query points from
# the global NumPy RNG, so seed it to make this cell's output reproducible.
np.random.seed(0)
dataset_[0]['coords_hr']

array([ 0.32653061,  0.97959184,  1.63265306,  2.28571429,  3.59183673,
        6.20408163,  8.16326531,  9.46938776, 10.12244898, 10.7755102 ,
       11.42857143, 14.69387755])

In [81]:
# coords_lr is x[::2] (the even-index grid points); it does not depend on the random draw.
dataset_[0]['coords_lr']

array([ 0.        ,  0.65306122,  1.30612245,  1.95918367,  2.6122449 ,
        3.26530612,  3.91836735,  4.57142857,  5.2244898 ,  5.87755102,
        6.53061224,  7.18367347,  7.83673469,  8.48979592,  9.14285714,
        9.79591837, 10.44897959, 11.10204082, 11.75510204, 12.40816327,
       13.06122449, 13.71428571, 14.36734694, 15.02040816, 15.67346939])

In [108]:
# Toy input for the maximum-subarray exercise; `answer` starts at the first element.
nums = [-2, 1, -3, 4, -1, 2, 1, -5, 4]
answer = nums[0]

In [109]:
def max_subarray_brute_force(values):
    """Return the largest sum of a non-empty contiguous slice of ``values``.

    Checks every slice ``values[i:j+1]`` (O(n^2)) and prints ``i j`` each time
    a new best is found.

    Raises:
        ValueError: if ``values`` is empty.
    """
    if not values:
        raise ValueError("values must be non-empty")
    best = values[0]
    for i in range(len(values)):
        running = 0
        # j starts at i so single-element slices (e.g. the 3 in [-5, 3, -10])
        # are included.
        for j in range(i, len(values)):
            running += values[j]
            if running > best:
                best = running
                print(i, j)
    return best


# Recompute from scratch so re-running this cell never reuses a stale `answer`.
answer = max_subarray_brute_force(nums)

0 1
0 3
0 5
0 6
1 5
1 6
3 5
3 6
