In [1]:
import torch
import torch.nn as nn

In [25]:
import torch
import torch.nn as nn

class UNet10x(nn.Module):
    def __init__(self):
        super(UNet10x, self).__init__()
        
        # Define encoding layers
        self.encoder1 = self.conv_block(1, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.pool = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)
        
        # Define decoding layers
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(256 + 256, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(128 + 128, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(64 + 64, 64)
        
        # Additional upsampling layers for 10x output resolution
        self.upconv_final1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # 2x
        self.upconv_final2 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # 4x
        self.upconv_final3 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # 8x
        self.output = nn.Conv2d(64, 1, kernel_size=1)  # Final output layer
        
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # Encoder
        e1 = self.encoder1(x)
        p1 = self.pool(e1)
        e2 = self.encoder2(p1)
        p2 = self.pool(e2)
        e3 = self.encoder3(p2)
        p3 = self.pool(e3)
        
        # Bottleneck
        b = self.bottleneck(p3)
        
        # Decoder
        d3 = self.upconv3(b)
        d3 = torch.cat((d3, e3), dim=1)  # Skip connection
        d3 = self.decoder3(d3)
        d2 = self.upconv2(d3)
        d2 = torch.cat((d2, e2), dim=1)  # Skip connection
        d2 = self.decoder2(d2)
        d1 = self.upconv1(d2)
        d1 = torch.cat((d1, e1), dim=1)  # Skip connection
        d1 = self.decoder1(d1)
        
        # Additional upsampling for 10x resolution
        d_final1 = self.upconv_final1(d1)
        d_final2 = self.upconv_final2(d_final1)
        d_final3 = self.upconv_final3(d_final2)
        return self.output(d_final3)

# Initialize the model
model = UNet10x()

In [26]:
input_tensor = torch.randn(1, 1, 64, 64)  # Example input
output_tensor = model(input_tensor)
print("Input shape:", input_tensor.shape)   # (1, 1, 64, 64)
print("Output shape:", output_tensor.shape)

Input shape: torch.Size([1, 1, 64, 64])
Output shape: torch.Size([1, 1, 512, 512])


In [27]:
import os

# Extract the day from the filename
def get_day_from_filename(file_path):
    base_name = os.path.basename(file_path)
    date_str = base_name[4:12]  # Extract YYYYMMDD part
    day = int(date_str[6:8])    # Extract the DD part
    return day

In [28]:
import glob

# Directory containing the files
directory_path = "/Users/fquareng/data/1h_2D_sel_blurred"  # Replace with the path to your directory

# Get all matching files (recursively if needed)
file_paths = sorted(glob.glob(f"{directory_path}/lffd*.nz"))  # Adjust pattern as necessary

# Print the number of files found
print(f"Number of files found: {len(file_paths)}")

Number of files found: 3651


In [29]:
# Partition the files
train_files = [f for f in file_paths if get_day_from_filename(f) <= 19]
val_files = [f for f in file_paths if 26 <= get_day_from_filename(f) <= 30]
test_files = [f for f in file_paths if 20 <= get_day_from_filename(f) <= 25]

print(f"Training files: {len(train_files)}")
print(f"Validation files: {len(val_files)}")
print(f"Test files: {len(test_files)}")

Training files: 2254
Validation files: 599
Test files: 726


In [49]:
import torch
from torch.utils.data import Dataset
from netCDF4 import Dataset as NetCDFDataset
import numpy as np

class MultiVariableDataset(Dataset):
    def __init__(self, input_files, target_files, input_vars, target_vars, transform=None):
        """
        Args:
            input_files (list): List of input NetCDF file paths.
            target_files (list): List of target NetCDF file paths.
            input_vars (list): List of variable names to extract as input.
            target_vars (list): List of variable names to extract as target.
            transform (callable, optional): Optional transform to apply to the data.
        """
        assert len(input_files) == len(target_files), "Input and target file lists must have the same length"
        self.input_files = input_files
        self.target_files = target_files
        self.input_vars = input_vars
        self.target_vars = target_vars
        self.transform = transform

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_file = self.input_files[idx]
        target_file = self.target_files[idx]

        # Load input variables
        with NetCDFDataset(input_file, mode='r') as nc_file:
            inputs = np.stack([nc_file.variables[var][:] for var in self.input_vars], axis=0)

        # Load target variables
        with NetCDFDataset(target_file, mode='r') as nc_file:
            targets = np.stack([nc_file.variables[var][:] for var in self.target_vars], axis=0)

        # Convert to PyTorch tensors
        inputs = torch.tensor(inputs, dtype=torch.float32)
        targets = torch.tensor(targets, dtype=torch.float32)

        # Apply optional transformations
        if self.transform:
            inputs, targets = self.transform(inputs, targets)

        return inputs, targets

In [50]:
import glob

# Directory containing the files
inputs_path = "/Users/fquareng/data/1h_2D_sel_blurred"  # Replace with the path to your directory
targets_path = "/Users/fquareng/data/1h_2D_sel"  # Replace with the path to your directory

# Get all matching files (recursively if needed)
input_files = sorted(glob.glob(f"{inputs_path}/lffd*.nz"))  # Adjust pattern as necessary
target_files = sorted(glob.glob(f"{targets_path}/lffd*.nz"))  # Adjust pattern as necessary

# Print the number of files found
assert len(input_files) == len(target_files), print("Size source and target not matching.")
print(f"Number of files found: {len(input_files)}")
print(f"Number of files found: {len(target_files)}")

Number of files found: 3651
Number of files found: 3651


In [51]:
from torch.utils.data import DataLoader

# Example file list and variable names
input_vars = ["RELHUM_2M", "T_2M", "PS"]
target_vars = ["RELHUM_2M", "T_2M", "PS"]

# Initialize the dataset
dataset = MultiVariableDataset(input_files, target_files, input_vars, target_vars)

# Create a DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Iterate through the DataLoader
for inputs, targets in dataloader:
    print(f"Inputs shape: {inputs.shape}, Targets shape: {targets.shape}")  # Shape will be [batch_size, len(variable_names), H, W]

Inputs shape: torch.Size([16, 3, 1, 193, 193]), Targets shape: torch.Size([16, 3, 1, 1542, 1542])
Inputs shape: torch.Size([16, 3, 1, 193, 193]), Targets shape: torch.Size([16, 3, 1, 1542, 1542])
Inputs shape: torch.Size([16, 3, 1, 193, 193]), Targets shape: torch.Size([16, 3, 1, 1542, 1542])


KeyboardInterrupt: 