In [1]:
# Imports
from einops import rearrange, einsum
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

import sys, os

# sys.path.append("/pscratch/sd/j/jwl50/icl_odes_dev/src/")

In [22]:
import numpy as np
import torch
import h5py
import os
import torch.nn.functional as F
            
class NSTK(torch.utils.data.Dataset):
    """Random-patch dataset over the 2048x2048 NSKT HDF5 vorticity files.

    Each file stores its field under the ``'w'`` key. One file is opened for
    each Reynolds number in ``REYNOLDS_NUMBERS``, and the path is built from
    that number so the label and the file always agree.

    Args:
        factor: stride used to subsample the low-res patch
            (``patch[:, ::factor, ::factor]``). It is only used by the
            training return that is kept commented out in ``__getitem__``.
        num_pred_steps: exclusive upper bound of the forecast shift, which is
            drawn from ``[1, num_pred_steps - 1]`` (clamped to 1 when this is
            1). It is also the one-hot width in the commented-out return.
        patch_size: side length of the square spatial patch.
        stride: grid spacing of the possible patch origins.
        train: if True, sample even frames; otherwise sample odd frames.
        scratch_dir: directory containing the ``.h5`` files.
    """

    # Reynolds number of each simulation file. The file paths are derived from
    # these values, so a label cannot drift away from its file.
    REYNOLDS_NUMBERS = (1000, 8000, 16000)
    FILE_TEMPLATE = '{}_2048_2048_seed_2150.h5'

    def __init__(self,
                 factor,
                 num_pred_steps=1,
                 patch_size=256,
                 stride=128,
                 train=True,
                 scratch_dir='./'):
        super().__init__()

        self.RN = list(self.REYNOLDS_NUMBERS)
        self.paths = [os.path.join(scratch_dir, self.FILE_TEMPLATE.format(re_num))
                      for re_num in self.RN]

        self.factor = factor
        self.num_pred_steps = num_pred_steps
        self.train = train
        self.patch_size = patch_size
        self.stride = stride

        # Only the first file is inspected. All files are assumed to share its
        # shape.
        with h5py.File(self.paths[0], 'r') as f:
            self.data_shape = f['w'].shape

        # Number of patch origins per spatial axis on the stride grid.
        self.max_row = (self.data_shape[1] - self.patch_size) // self.stride + 1
        self.max_col = (self.data_shape[2] - self.patch_size) // self.stride + 1

    def open_hdf5(self):
        """Open every file in ``self.paths`` read-only and keep its ``'w'`` dataset."""
        # Keep the File objects so that close() can release the handles.
        self._files = [h5py.File(path, 'r') for path in self.paths]
        self.datasets = [f['w'] for f in self._files]
        print([dataset.shape for dataset in self.datasets])

    def close(self):
        """Close any files opened by ``open_hdf5``; the next access reopens them."""
        for handle in getattr(self, '_files', []):
            handle.close()
        self._files = []
        if hasattr(self, 'datasets'):
            del self.datasets

    def __getitem__(self, index):
        """Return the sample for the flat index ``index``.

        NOTE(review): this currently returns the *entire* ``'w'`` HDF5 dataset
        of a randomly chosen file (an exploratory return), not a patch. The
        time index, shift, patch origin and Reynolds number below are sampled
        but are only consumed by the training return that is kept as a comment
        at the end of this method. ``default_collate`` cannot batch an h5py
        Dataset, so restore a tensor return before using a DataLoader.
        """
        # Open lazily on first access. The check must name the attribute that
        # open_hdf5 actually sets, otherwise the files are reopened every call.
        if not hasattr(self, 'datasets'):
            self.open_hdf5()

        # Forecast offset in [1, num_pred_steps - 1], since randint's high bound
        # is exclusive. The bound is clamped so num_pred_steps=1 yields 1
        # instead of raising ValueError.
        shift = np.random.randint(1, max(self.num_pred_steps, 2))

        # 75 consecutive flat indices map to the same frame. The training split
        # uses even frames and the validation split uses odd frames.
        time_index = (index // 75) * 2 + (0 if self.train else 1)

        # Random patch origin on the stride grid.
        patch_row = np.random.randint(0, self.max_row) * self.stride
        patch_col = np.random.randint(0, self.max_col) * self.stride

        # Pick one of the simulation files uniformly at random.
        file_idx = np.random.randint(0, len(self.datasets))
        reynolds_number = self.RN[file_idx]
        dataset = self.datasets[file_idx]

        # Debug return: the whole dataset handle. No frame data is read here.
        return dataset

        # Intended training sample (kept for reference; unreachable):
        # rows = slice(patch_row, patch_row + self.patch_size)
        # cols = slice(patch_col, patch_col + self.patch_size)
        # patch = torch.from_numpy(dataset[time_index, rows, cols]).float().unsqueeze(0)
        # future_patch = torch.from_numpy(dataset[time_index + shift, rows, cols]).float().unsqueeze(0)
        # lowres_patch = patch[:, ::self.factor, ::self.factor]
        # return (lowres_patch, patch, future_patch,
        #         F.one_hot(torch.tensor(shift), self.num_pred_steps),
        #         torch.tensor(reynolds_number / 40_000.))

    def __len__(self):
        # Fixed epoch length. NOTE(review): index // 75 * 2 + 1 reaches frame
        # 1199 (plus shift), which must stay below the files' frame count.
        return 45000

In [23]:
# Experiment configuration.
factor = 4          # spatial subsampling factor passed to NSTK
num_pred_steps = 3  # NSTK forecast-shift bound (shift drawn from [1, num_pred_steps - 1])
# Directory holding the *_2048_2048_seed_2150.h5 files; override with $NSKT_SCRATCH_DIR.
scratch_dir = os.environ.get(
    "NSKT_SCRATCH_DIR", "/global/cfs/cdirs/m4633/foundationmodel/nskt_tensor")

In [24]:
def load_train_objs(factor, num_pred_steps, scratch_dir):
    """Build the training and validation NSTK datasets.

    Both splits share ``factor``, ``num_pred_steps`` and ``scratch_dir``. They
    differ only in ``train`` (default True for training, False for validation).

    Returns:
        ``(train_set, val_set)``.

    NOTE(review): a later cell redefines ``load_train_objs(args)``, which
    shadows this function once that cell has run.
    """
    shared = dict(factor=factor, num_pred_steps=num_pred_steps,
                  scratch_dir=scratch_dir)
    # Tuple elements are evaluated left to right: training split first.
    return NSTK(**shared), NSTK(train=False, **shared)

In [25]:
# Build the train/validation splits from the configuration cell.
train_set, val_set = load_train_objs(factor=factor, num_pred_steps=num_pred_steps,
                                     scratch_dir=scratch_dir)

In [26]:
# Draw sample 0. NSTK opens the HDF5 files inside __getitem__ and prints their shapes.
a = train_set[0]

[(1501, 2048, 2048), (1501, 2048, 2048), (1501, 2048, 2048)]


In [29]:
# Element 2 of the sample above. The current NSTK returns the whole HDF5
# dataset, so this reads frame 2 as a full 2048x2048 array.
a[2]

array([[-4.2934012, -4.302176 , -4.3105164, ..., -4.256658 , -4.271743 ,
        -4.2835813],
       [-4.2780557, -4.2881775, -4.298044 , ..., -4.2388167, -4.254458 ,
        -4.26713  ],
       [-4.262793 , -4.274328 , -4.285776 , ..., -4.220864 , -4.237121 ,
        -4.250693 ],
       ...,
       [-4.3399734, -4.3451343, -4.3492737, ..., -4.309483 , -4.3232775,
        -4.3330336],
       [-4.3243566, -4.3306475, -4.33612  , ..., -4.291994 , -4.306154 ,
        -4.316532 ],
       [-4.3088336, -4.31633  , -4.3232045, ..., -4.2743845, -4.2889752,
        -4.3000484]], dtype=float32)

In [21]:
# NOTE(review): the output below is stale. Execution count 21 predates the
# NSTK cell (22); it presumably came from a version that returned the
# (time, 256, 256) patch slice.
a[0].shape

(1501, 256, 256)

In [9]:
# NOTE(review): this cell and the shape/value cells after it carry outputs from
# an earlier NSTK whose __getitem__ returned (lowres, patch, future, one-hot
# shift, Re/40000). They will not reproduce with the current class.
a = train_set[1]

(1501, 2048, 2048)


In [16]:
# Stale output from the earlier 5-tuple return: element 0 is the low-res patch.
a[0].shape

torch.Size([1, 64, 64])

In [18]:
# Stale output from the earlier 5-tuple return: element 1 is the full-res patch.
a[1].shape

torch.Size([1, 256, 256])

In [20]:
# Stale output from the earlier 5-tuple return: element 2 is the future patch.
a[2].shape

torch.Size([1, 256, 256])

In [21]:
# Stale output from the earlier 5-tuple return: element 3 is one_hot(shift, num_pred_steps).
a[3]

tensor([0, 0, 1])

In [22]:
# Stale output from the earlier 5-tuple return: element 4 is Re / 40000 (0.2 -> Re 8000).
a[4]

tensor(0.2000)

In [4]:
def load_train_objs(args, device='cuda'):
    """Build the datasets, the diffusion model and its optimizer from ``args``.

    Reads ``args.factor``, ``args.num_pred_steps``, ``args.scratch_dir``,
    ``args.base_width``, ``args.time_steps``, ``args.prediction_type``,
    ``args.sampler`` and ``args.learning_rate``.

    Args:
        args: namespace-like object providing the attributes above.
        device: device the three UNet modules are moved to before being
            wrapped. Defaults to ``'cuda'``, which matches ``.cuda()``.

    Returns:
        ``(train_set, val_set, model, optimizer)``.

    NOTE(review): this redefines (and shadows) the 3-argument
    ``load_train_objs`` from an earlier cell. It also relies on ``UNet`` and
    ``GaussianDiffusionModel`` being defined in the kernel; neither is
    imported in the visible cells.
    """
    dataset_kwargs = dict(factor=args.factor, num_pred_steps=args.num_pred_steps,
                          scratch_dir=args.scratch_dir)
    train_set = NSTK(**dataset_kwargs)
    val_set = NSTK(train=False, **dataset_kwargs)

    # UNet is unpacked into three modules: a backbone plus low-res and forecast
    # heads, judging by how they are passed to GaussianDiffusionModel below.
    unet_model, lowres_head, future_head = UNet(image_size=256, in_channels=1, out_channels=1,
                                                base_width=args.base_width,
                                                num_pred_steps=args.num_pred_steps,
                                                Reynolds_number=True)

    model = GaussianDiffusionModel(base_model=unet_model.to(device),
                                   lowres_model=lowres_head.to(device),
                                   forecast_model=future_head.to(device),
                                   # presumably (beta_start, beta_end) of the noise schedule; TODO confirm
                                   betas=(1e-4, 0.02),
                                   n_T=args.time_steps,
                                   prediction_type=args.prediction_type,
                                   sampler=args.sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    return train_set, val_set, model, optimizer