In [8]:
import torch
import torch.nn as nn
from vit_pytorch import ViT
import netCDF4 as nc
import numpy as np
from tqdm import tqdm
import logging
import xarray as xr

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

datadir = '../ERA5_reduced/'

class ViTFeatureExtractor(nn.Module):
    def __init__(self, image_size=(40, 45), patch_size=5, dim=512, depth=6, heads=8, mlp_dim=2048, channels=6):
        super(ViTFeatureExtractor, self).__init__()
        self.vit = ViT(
            image_size=image_size,
            patch_size=patch_size,
            num_classes=dim,  # Output dimension of the ViT
            dim=dim,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
            channels=channels
        )
    
    def forward(self, x):
        return self.vit(x)

class ViT_BiLSTM_FeatureExtractor(nn.Module):
    def __init__(self, image_size=(40, 45), patch_size=5, hidden_dim=512, num_layers=2):
        super(ViT_BiLSTM_FeatureExtractor, self).__init__()
        self.feature_extractor = ViTFeatureExtractor(image_size=image_size, patch_size=patch_size)
        
        # Bidirectional LSTM to model temporal dependencies
        self.bilstm = nn.LSTM(
            input_size=512,  # This should match the output dimension of ViT
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )
    
    def forward(self, x):
        batch_size, seq_length, num_channels, height, width = x.size()
        
        # Reshape to (batch_size * seq_length, num_channels, height, width)
        x = x.view(batch_size * seq_length, num_channels, height, width)
        
        # Feature extraction using ViT
        features = self.feature_extractor(x)
        
        # Reshape to (batch_size, seq_length, feature_dim)
        features = features.view(batch_size, seq_length, -1)
        
        # BiLSTM for temporal dependencies
        bilstm_out, _ = self.bilstm(features)
        
        return bilstm_out

def load_and_norm_single_variable_netcdf_data(my_var, vname, time=None):
    fname = f'{datadir}/{my_var}.2001-2020.anomaly.nc'
    
    with xr.open_dataset(fname) as ds:
        da = ds[vname]
        if time is not None:
            da = da.sel(time=slice(*time))
        da_max = da.max(['time','latitude','longitude']).data
        da_min = da.max(['time','latitude','longitude']).data
        
        my_vmax = np.maximum(da_max, -1*da_min)
        vmax, vmin = my_vmax, -1*my_vmax
        logging.debug(f'{my_var}: {vmin}, {vmax}')
   
        out = (da - vmin) / (vmax - vmin)
        
    return out

def load_and_norm_netcdf_data(time=None):
    t500_full = load_and_norm_single_variable_netcdf_data('t500', 'T', time=time)
    t850_full = load_and_norm_single_variable_netcdf_data('t850', 'T', time=time)
    z500_full = load_and_norm_single_variable_netcdf_data('z500', 'Z', time=time)
    z850_full = load_and_norm_single_variable_netcdf_data('z850', 'Z', time=time)
    t2_full   = load_and_norm_single_variable_netcdf_data('2t', 'VAR_2T', time=time)
    sp_full   = load_and_norm_single_variable_netcdf_data('sp', 'SP', time=time)

    data_list = [t500_full, t850_full, z500_full, z850_full, t2_full, sp_full]

    time = data_list[0].time
    for v in data_list:
        time = np.intersect1d(time, v.time)

    for v in data_list:
        v = v.sel(time=time)
    
    normalized_data = np.stack(data_list, axis=1)

    logging.debug(f"Data shape after stacking and normalization: {normalized_data.shape}")

    return normalized_data

# Ensure the input image size is compatible with the patch size
def pad_image_to_fit_patch_size(data, patch_size):
    _, _, height, width = data.shape
    pad_height = (patch_size - height % patch_size) % patch_size
    pad_width = (patch_size - width % patch_size) % patch_size
    padded_data = np.pad(data, ((0, 0), (0, 0), (0, pad_height), (0, pad_width)), mode='constant', constant_values=0)
    return padded_data

# Example usage to extract and store hidden states
model = ViT_BiLSTM_FeatureExtractor(image_size=(37, 45), patch_size=(37, 45))

# Load NetCDF data
netcdf_file_path = 'path_to_your_netcdf_file.nc'
input_data = load_and_norm_netcdf_data(time=('2016-06-26 00:00', '2016-06-30 12:00'))

# # Pad image size to be compatible with patch size (e.g., 5x5)
# input_data = pad_image_to_fit_patch_size(input_data, patch_size=5)

# Convert input_data to torch tensor and ensure the correct shape
input_data = torch.tensor(input_data, dtype=torch.float32)

# Ensure input_data has shape (batch_size, num_channels, seq_length, height, width)
# For example, if you have a single sequence of length 3 (hours), batch size 1:
input_data = input_data.unsqueeze(0)  # Add batch dimension


# Check if multiple GPUs are available and use DataParallel if so
if torch.cuda.device_count() > 1:
    logging.info(f"Using {torch.cuda.device_count()} GPUs for model parallelism.")
    model = nn.DataParallel(model)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
input_data = input_data.to(device)

# Process data in batches using tqdm for progress tracking
batch_size = 1
num_batches = input_data.size(0) // batch_size

logging.info(f"Starting feature extraction for {num_batches} batches")
all_hidden_states = []

for i in tqdm(range(num_batches), desc="Processing Batches"):
    batch_data = input_data[i*batch_size:(i+1)*batch_size]
    print(batch_data.shape)
    hidden_states = model(batch_data)
    all_hidden_states.append(hidden_states)

# Concatenate all hidden states
all_hidden_states = torch.cat(all_hidden_states, dim=0)

# Save hidden states for later use
torch.save(all_hidden_states, 'hidden_states.ViT.pt')
logging.info("Hidden states saved to hidden_states.pt")

2024-06-28 17:00:29,312 - DEBUG - t500: -6.6453399658203125, 6.6453399658203125
2024-06-28 17:00:29,328 - DEBUG - t850: -14.48138427734375, 14.48138427734375
2024-06-28 17:00:29,350 - DEBUG - z500: -1520.5, 1520.5
2024-06-28 17:00:29,371 - DEBUG - z850: -927.9365234375, 927.9365234375
2024-06-28 17:00:29,388 - DEBUG - 2t: -17.207305908203125, 17.207305908203125
2024-06-28 17:00:29,409 - DEBUG - sp: -952.28125, 952.28125
2024-06-28 17:00:29,412 - DEBUG - Data shape after stacking and normalization: (109, 6, 37, 45)
2024-06-28 17:00:29,416 - INFO - Starting feature extraction for 1 batches
Processing Batches:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 109, 6, 37, 45])


Processing Batches: 100%|██████████| 1/1 [00:00<00:00,  4.96it/s]
2024-06-28 17:00:29,620 - INFO - Hidden states saved to hidden_states.pt


In [2]:
device

device(type='cpu')

In [7]:
print(all_hidden_states.shape)
all_hidden_states

torch.Size([1, 109, 1024])


tensor([[[-0.0116,  0.0036,  0.0403,  ..., -0.0088, -0.0434, -0.0825],
         [-0.0189,  0.0049,  0.0592,  ..., -0.0127, -0.0551, -0.0879],
         [-0.0248,  0.0059,  0.0673,  ..., -0.0170, -0.0624, -0.0904],
         ...,
         [-0.0505,  0.0172,  0.0843,  ..., -0.0145, -0.0649, -0.0602],
         [-0.0456,  0.0146,  0.0796,  ..., -0.0109, -0.0533, -0.0457],
         [-0.0393,  0.0129,  0.0686,  ..., -0.0055, -0.0329, -0.0240]]],
       grad_fn=<CatBackward0>)