In [None]:
import torch
import torch.nn as nn

class UNet8x(nn.Module):
    def __init__(self):
        super(UNet8x, self).__init__()
        
        # Define encoding layers, now with 2 input channels (e.g. temeprature and elevation data)
        self.encoder1 = self.conv_block(2, 64)  # Change input channels from 1 to 2
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.pool = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)
        
        # Define decoding layers
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(256 + 256, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(128 + 128, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(64 + 64, 64)
        
        # Additional upsampling layers for 10x output resolution
        self.upconv_final1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # 2x
        self.upconv_final2 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # 4x
        self.upconv_final3 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)  # 8x
        self.output = nn.Conv2d(64, 1, kernel_size=1)  # Final output layer
        
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x1, x2):  # Accepting two inputs now
        # Encoder
        e1 = self.encoder1(torch.cat((x1, x2), dim=1))  # Concatenate the two images along the channel dimension
        p1 = self.pool(e1)
        e2 = self.encoder2(p1)
        p2 = self.pool(e2)
        e3 = self.encoder3(p2)
        p3 = self.pool(e3)
        
        # Bottleneck
        b = self.bottleneck(p3)
        
        # Decoder
        d3 = self.upconv3(b)
        d3 = torch.cat((d3, e3), dim=1)  # Skip connection
        d3 = self.decoder3(d3)
        d2 = self.upconv2(d3)
        d2 = torch.cat((d2, e2), dim=1)  # Skip connection
        d2 = self.decoder2(d2)
        d1 = self.upconv1(d2)
        d1 = torch.cat((d1, e1), dim=1)  # Skip connection
        d1 = self.decoder1(d1)
        
        # Additional upsampling for 10x resolution
        d_final1 = self.upconv_final1(d1)
        d_final2 = self.upconv_final2(d_final1)
        d_final3 = self.upconv_final3(d_final2)
        return self.output(d_final3)

# Initialize the model
model = UNet8x()

: 

In [4]:
input_tensor = torch.randn(1, 1, 16, 16)  # Example input
output_tensor = model(input_tensor, input_tensor)
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)

Input shape: torch.Size([1, 1, 16, 16])
Output shape: torch.Size([1, 1, 128, 128])


In [5]:
import os

# Extract the day from the filename
def get_day_from_filename(file_path):
    base_name = os.path.basename(file_path)
    date_str = base_name[8:16]
    day = int(date_str[6:8])
    return day

In [6]:
import glob

# Directory containing the source files
directory_path = "/Users/fquareng/data/1h_2D_sel_cropped_blurred_x8_clustered"  # Replace with the path to your directory
directory_path = "/Users/fquareng/data/1h_2D_sel_cropped_clustered"  # Replace with the path to your directory

# Print the number of files found
print(f"Number of files found: {len(sorted(glob.glob(f"{directory_path}/0/*.nz")))}")
print(f"Number of files found: {len(sorted(glob.glob(f"{directory_path}/1/*.nz")))}")
print(f"Number of files found: {len(sorted(glob.glob(f"{directory_path}/2/*.nz")))}")

Number of files found: 9914
Number of files found: 3322
Number of files found: 1210


In [50]:
import torch
from torch.utils.data import Dataset
from netCDF4 import Dataset as NetCDFDataset
import numpy as np

class MultiVariableDataset(Dataset):
    def __init__(self, input_files, target_files, input_vars, target_vars, transform=None):
        """
        Args:
            input_files (list): List of input NetCDF file paths.
            target_files (list): List of target NetCDF file paths.
            input_vars (list): List of variable names to extract as input.
            target_vars (list): List of variable names to extract as target.
            transform (callable, optional): Optional transform to apply to the data.
        """
        # assert len(input_files) == len(target_files), "Input and target file lists must have the same length"
        ## assert # Check that the input and target variables are from the same square at the same time from the name of the file
        self.input_files = input_files
        self.target_files = target_files
        self.input_vars = input_vars
        self.target_vars = target_vars
        self.transform = transform

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_file = self.input_files[idx]
        target_file = self.target_files[idx]

        # final check that all variables are correct (e.g. same time, same square) from the actual lat lon data and the time in the dataset

        # Load input variables
        with NetCDFDataset(input_file, mode='r') as nc_file:
            inputs = np.stack([nc_file.variables[var][:] for var in self.input_vars], axis=0)

        # Load target variables
        with NetCDFDataset(target_file, mode='r') as nc_file:
            targets = np.stack([nc_file.variables[var][:] for var in self.target_vars], axis=0)

        # Load DEM
        base_name = os.path.basename(input_file)
        square_id = base_name.split('_')[:2]
        dem_file = f"/Users/fquareng/data/dem_squares_blurred_x8/dem_{square_id[0]}_{square_id[1]}_blurred_x8.nc"
        with NetCDFDataset(dem_file, mode='r') as nc_file:
            dem = nc_file.variables["HSURF"][:]

        # Convert to PyTorch tensors
        inputs = torch.tensor(inputs, dtype=torch.float32)
        dem = torch.tensor(dem, dtype=torch.float32)
        targets = torch.tensor(targets, dtype=torch.float32)

        # Apply optional transformations
        if self.transform:
            inputs, targets = self.transform(inputs, targets)

        return inputs, dem, targets


In [56]:
source_path = "/Users/fquareng/data/1h_2D_sel_cropped_blurred_x8_clustered"  
target_path = "/Users/fquareng/data/1h_2D_sel_cropped_clustered"  

# Get source files
source_file_paths = sorted(glob.glob(f"{source_path}/0/*.nz"))
# file_paths = sorted(glob.glob(f"{source_path}/1/*.nz"))
# file_paths = sorted(glob.glob(f"{source_path}/2/*.nz"))

# Get target files
target_file_paths = sorted(glob.glob(f"{target_path}/0/*.nz"))
# file_paths = sorted(glob.glob(f"{target_path}/1/*.nz"))
# file_paths = sorted(glob.glob(f"{target_path}/2/*.nz"))

In [57]:
# Partition the files
train_source_files = [f for f in source_file_paths if get_day_from_filename(f) <= 19]
val_source_files = [f for f in source_file_paths if get_day_from_filename(f) <= 19]
test_source_files = [f for f in source_file_paths if 26 <= get_day_from_filename(f) <= 30]

train_target_files = [f for f in target_file_paths if 26 <= get_day_from_filename(f) <= 30]
val_target_files = [f for f in target_file_paths if 20 <= get_day_from_filename(f) <= 25]
test_target_files = [f for f in target_file_paths if 20 <= get_day_from_filename(f) <= 25]

# check that training and target files are correct

print(f"Training files: {len(train_source_files)}")
print(f"Validation files: {len(val_source_files)}")
print(f"Test files: {len(test_source_files)}")

Training files: 62038
Validation files: 62038
Test files: 11232


In [55]:
from torch.utils.data import DataLoader

# Example file list and variable names
input_vars = ["T_2M"]
target_vars = ["T_2M"]

# Create datasets
train_dataset = MultiVariableDataset(train_source_files, train_target_files, input_vars, target_vars)
val_dataset = MultiVariableDataset(val_source_files, val_target_files, input_vars, target_vars)
test_dataset = MultiVariableDataset(test_source_files, test_target_files, input_vars, target_vars)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Store loaders in a dictionary
dataloaders = {"train": train_loader, "val": val_loader, "test": test_loader}

print("Data loaders created successfully")
print("Number of training batches:", len(train_loader))
print("Number of validation batches:", len(val_loader))
print("Number of test batches:", len(test_loader))
print("Number of training samples:", len(train_dataset))
print("Number of validation samples:", len(val_dataset))
print("Number of test samples:", len(test_dataset))

Data loaders created successfully
Number of training batches: 3878
Number of validation batches: 3878
Number of test batches: 702
Number of training samples: 62038
Number of validation samples: 62038
Number of test samples: 11232


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Define the training function
def train_model(model, dataloaders, criterion, optimizer, num_epochs, device):
    model.to(device)
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 10)

        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            for inputs, dem, targets in dataloaders[phase]:
                inputs = inputs.to(device)
                dem = dem.to(device).unsqueeze(1)  # Add channel dimension to DEM
                targets = targets.to(device).unsqueeze(1)  # Add channel dimension to targets

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs, dem)
                    loss = criterion(outputs, targets)

                    # Backward pass and optimization in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # Accumulate loss
                running_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            print(f"{phase} Loss: {epoch_loss:.4f}")

    print("Training complete.")
    return model

# Main script
if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(0)

    source_path = "/Users/fquareng/data/1h_2D_sel_cropped_blurred_x8_clustered"  
    target_path = "/Users/fquareng/data/1h_2D_sel_cropped_clustered"  

    # Get source files
    source_file_paths = sorted(glob.glob(f"{source_path}/0/*.nz"))
    # source_file_paths = sorted(glob.glob(f"{source_path}/1/*.nz"))
    # source_file_paths = sorted(glob.glob(f"{source_path}/2/*.nz"))

    # Get target files
    target_file_paths = sorted(glob.glob(f"{target_path}/0/*.nz"))
    # target_file_paths = sorted(glob.glob(f"{target_path}/1/*.nz"))
    # target_file_paths = sorted(glob.glob(f"{target_path}/2/*.nz"))
    
    input_vars = ["T_2M"]
    target_vars = ["T_2M"]

    # Partition the files
    train_source_files = [f for f in source_file_paths if get_day_from_filename(f) <= 19]
    val_source_files = [f for f in source_file_paths if get_day_from_filename(f) <= 19]
    test_source_files = [f for f in source_file_paths if 26 <= get_day_from_filename(f) <= 30]

    train_target_files = [f for f in target_file_paths if 26 <= get_day_from_filename(f) <= 30]
    val_target_files = [f for f in target_file_paths if 20 <= get_day_from_filename(f) <= 25]
    test_target_files = [f for f in target_file_paths if 20 <= get_day_from_filename(f) <= 25]

    # Create datasets
    train_dataset = MultiVariableDataset(train_source_files, train_target_files, input_vars, target_vars)
    val_dataset = MultiVariableDataset(val_source_files, val_target_files, input_vars, target_vars)
    test_dataset = MultiVariableDataset(test_source_files, test_target_files, input_vars, target_vars)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    # Store loaders in a dictionary
    dataloaders = {"train": train_loader, "val": val_loader, "test": test_loader}

    # Initialize model, loss function, and optimizer
    model = UNet8x()
    criterion = nn.MSELoss()  # Mean Squared Error Loss
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Train the model
    model = train_model(model, dataloaders, criterion, optimizer, num_epochs=25, device=device)

    # Save the trained model
    torch.save(model.state_dict(), "unet8x_model.pth")
    print("Model saved as 'unet8x_model.pth'")