#### Import libraries

In [2]:
# Standard library imports
import os
import time
import json
import datetime
import argparse
import yaml

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader

from loss_functions import DiceLoss  #, FocalLoss
from models_multiple_GPUs import *

from torch.utils.data import Dataset

In this notebook, we test several CNN architectures on the heat-plume prediction dataset. We start by defining helper functions for plotting results and loading the data.

In [3]:
def plot_results(unet, savepath, epoch_number, train_dataset, val_dataset):
    """
    Plot model predictions for the first training and first validation sample.

    The 4x3 figure holds two rows per sample: the first three input channels,
    then the prediction (colour range fixed to [0, 1]), the first mask channel,
    and the absolute error between them. The figure is saved to
    ``<savepath>/figures/epoch_<epoch_number>.png``.

    Args:
        unet: Model called with a list of subdomain tensors (each given a batch
            dimension) that also exposes ``concatenate_tensors`` to reassemble them.
        savepath (str): Output directory; a ``figures`` subdirectory is created.
        epoch_number: Used only in the output file name.
        train_dataset, val_dataset: Datasets whose item 0 is ``(subdomains, mask)``.

    The model's train/eval mode is restored before returning (even on error),
    so this can safely be called from inside a training loop.
    """
    was_training = unet.training

    fig = plt.figure(figsize=(9, 12))
    # Adjust spacing between plots
    fig.subplots_adjust(hspace=0.1, wspace=0.1)

    def plot_subplot(position, image, title='', vmin=None, vmax=None):
        ax = fig.add_subplot(4, 3, position)
        ax.axis("off")
        if title:
            ax.set_title(title)
        ax.imshow(image, cmap="RdBu_r", vmin=vmin, vmax=vmax)

    def process_and_plot(images, masks, start_pos):
        # Add a batch dimension to every subdomain before handing them to the model.
        batched = [img.unsqueeze(0) for img in images]
        with torch.no_grad():
            predictions = unet(batched).cpu()
            full_images = unet.concatenate_tensors(batched).squeeze().cpu()

        mask = masks.cpu()[0]
        prediction = predictions[0, 0]

        for i in range(3):
            plot_subplot(start_pos + i, full_images[i])
        plot_subplot(start_pos + 3, prediction, vmin=0, vmax=1)
        plot_subplot(start_pos + 4, mask)
        plot_subplot(start_pos + 5, torch.abs(mask - prediction))

    try:
        unet.eval()

        train_image, train_mask = train_dataset[0]
        process_and_plot(train_image, train_mask, 1)

        val_image, val_mask = val_dataset[0]
        process_and_plot(val_image, val_mask, 7)

        figures_dir = os.path.join(savepath, "figures")
        os.makedirs(figures_dir, exist_ok=True)
        fig.savefig(os.path.join(figures_dir, f"epoch_{epoch_number}.png"), bbox_inches='tight')
    finally:
        # Always release the figure and hand the model back in the mode we found it.
        plt.close(fig)
        unet.train(was_training)

In [4]:
class DatasetMultipleGPUs(Dataset):
    """
    Dataset that loads pre-saved input/label tensors, randomly crops a square
    patch, optionally applies augmentation/transforms, and splits the input
    patch into a grid of subdomains for multi-GPU training.

    Each item is ``(subdomains, mask)``. ``subdomains`` is a list of
    ``subdomains_dist[0] * subdomains_dist[1]`` tensors in row-major order, and
    ``mask`` is the cropped (not split) label.

    Attributes:
        img_labels (list): File names, looked up in both ``img_dir`` and ``mask_dir``.
        img_dir (str): Directory containing input tensors (read with ``torch.load``).
        mask_dir (str): Directory containing label tensors with the same file names.
        transform (callable, optional): Applied to the cropped input.
        target_transform (callable, optional): Applied to the cropped mask.
        data_augmentation (callable, optional): Called as ``f(image, mask)``;
            must return ``(image, mask)``.
        size (int, optional): Stored for API compatibility; not used by this class.
        patch_size (int, optional): Side length of the square random crop.
        subdomains_dist (tuple, optional): Subdomain grid as ``(rows, cols)``.
    """

    def __init__(self, image_labels, image_dir, mask_dir, transform=None, target_transform=None, 
                 data_augmentation=None, size=2560, patch_size=1280, subdomains_dist=(2, 2)):
        self.img_labels = image_labels
        self.img_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.target_transform = target_transform
        self.data_augmentation = data_augmentation
        self.size = size  # not read anywhere in this class
        self.patch_size = patch_size
        self.subdomains_dist = subdomains_dist

    def __len__(self):
        return len(self.img_labels)

    def __split_image(self, full_image):
        """
        Split the last two (height, width) dimensions of ``full_image`` into a
        ``subdomains_dist[0] x subdomains_dist[1]`` grid, returned row-major.

        If a dimension is not divisible by its grid count, the trailing
        remainder rows/columns are not included in any subdomain.
        """
        rows, cols = self.subdomains_dist
        sub_h = full_image.shape[-2] // rows
        sub_w = full_image.shape[-1] // cols
        return [
            full_image[..., i * sub_h:(i + 1) * sub_h, j * sub_w:(j + 1) * sub_w]
            for i in range(rows)
            for j in range(cols)
        ]

    def __crop_patch(self, full_image, full_mask):
        """
        Crop the same random ``patch_size x patch_size`` window from image and mask.

        Works on tensors whose last two dimensions are (height, width), so both
        ``(C, H, W)`` and ``(H, W)`` masks are supported.

        Raises:
            ValueError: if the patch is larger than the image, or if the image
                and mask spatial sizes differ (the crop would be misaligned).
        """
        height, width = full_image.shape[-2], full_image.shape[-1]
        patch = self.patch_size

        if height < patch or width < patch:
            raise ValueError(
                f"Patch size ({patch}) must not exceed image size ({height}x{width})."
            )
        if tuple(full_mask.shape[-2:]) != (height, width):
            raise ValueError(
                f"Mask spatial size {tuple(full_mask.shape[-2:])} does not match "
                f"image spatial size {(height, width)}."
            )

        # Draw top before left to keep the RNG consumption order stable.
        top = random.randint(0, height - patch)
        left = random.randint(0, width - patch)
        window_rows = slice(top, top + patch)
        window_cols = slice(left, left + patch)

        return full_image[..., window_rows, window_cols], full_mask[..., window_rows, window_cols]

    def __getitem__(self, idx):
        """
        Load sample ``idx`` and return ``(subdomains, mask)``.

        Pipeline: load -> random crop -> augmentation -> transforms -> split input.
        """
        img_name = self.img_labels[idx]

        img_path = os.path.join(self.img_dir, f"{img_name}")
        mask_path = os.path.join(self.mask_dir, f"{img_name}")

        # NOTE(review): torch.load unpickles its input; only load trusted files.
        image = torch.load(img_path)
        mask = torch.load(mask_path)

        image, mask = self.__crop_patch(image, mask)

        if self.data_augmentation:
            image, mask = self.data_augmentation(image, mask)

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            mask = self.target_transform(mask)

        images = self.__split_image(image)

        return images, mask


### Variables


In [49]:
# Reproducibility: dataset crops draw from `random`, while weight initialisation
# and DataLoader shuffling draw from torch's RNG. Seed every generator here,
# before any dataset or model is constructed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Subdomain grid (rows, cols); (1, 1) keeps each patch in one piece.
subdomains_dist = (1, 1)
image_dir = os.path.join("data", "Inputs")
mask_dir = os.path.join("data", "Labels")
# Side length of the square patch randomly cropped from each sample.
patch_size = 1280

#### Define datasets and dataloaders

In [50]:
def _make_dataset(run_files):
    """Build a DatasetMultipleGPUs over `run_files` using the shared notebook settings."""
    return DatasetMultipleGPUs(
        image_labels=run_files,
        image_dir=image_dir,
        mask_dir=mask_dir,
        transform=None,
        target_transform=None,
        data_augmentation=None,
        patch_size=patch_size,
        subdomains_dist=subdomains_dist,
    )


# One simulation run per split.
train_dataset = _make_dataset(["RUN_1.pt"])
val_dataset = _make_dataset(["RUN_2.pt"])
test_dataset = _make_dataset(["RUN_4.pt"])

# Train hyperparams
batch_size = 1
batch_size_test = 1

# Only the training loader shuffles; validation and test iterate in order.
dataloader_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False)
dataloader_test = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False)

In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderBlock(nn.Module):
    """Two 3x3 same-padding convolutions with ReLU, followed by 2x2 max pooling.

    For an input of shape ``(N, in_channels, H, W)`` the output has shape
    ``(N, out_channels, H // 2, W // 2)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        return self.pool(features)

class DecoderBlock(nn.Module):
    """Two 3x3 same-padding convolutions with ReLU, then a stride-2 transposed
    convolution that doubles the spatial size.

    For an input of shape ``(N, in_channels, H, W)`` the output has shape
    ``(N, out_channels, 2 * H, 2 * W)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.up = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        return self.up(x)

class FCN(nn.Module):
    """Encoder-decoder fully convolutional network with additive skip connections.

    Args:
        num_blocks (int): Number of EncoderBlocks; ``channels[:num_blocks]`` are
            their output widths.
        channels (sequence of int): Encoder widths followed by decoder widths;
            ``channels[num_blocks:]`` are the DecoderBlock output widths.
        in_channels (int, optional): Channel count of the input tensor.
            Defaults to 3 (the value that was previously hard-coded).

    Each EncoderBlock halves the spatial size and each DecoderBlock doubles it.
    After decoder block ``i``, for ``i < len(encoder_blocks) - 1``, the encoder
    output ``encoder_outputs[-(i + 2)]`` is added element-wise. Its width and
    resolution must therefore match, e.g. the symmetric
    ``[16, 32, 64, 128, 64, 32, 16, 1]`` with ``num_blocks=4``.
    """

    def __init__(self, num_blocks, channels, in_channels=3):
        super(FCN, self).__init__()
        self.encoder_blocks = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()

        # Encoder
        current_channels = in_channels
        for out_channels in channels[:num_blocks]:
            self.encoder_blocks.append(EncoderBlock(current_channels, out_channels))
            current_channels = out_channels

        # Decoder
        for out_channels in channels[num_blocks:]:
            self.decoder_blocks.append(DecoderBlock(current_channels, out_channels))
            current_channels = out_channels

    def forward(self, x):
        # Encoder: keep every block's output for the skip connections.
        encoder_outputs = []
        for block in self.encoder_blocks:
            x = block(x)
            encoder_outputs.append(x)

        # Decoder
        for i, block in enumerate(self.decoder_blocks):
            x = block(x)
            if i < len(self.encoder_blocks) - 1:
                # Skip connection: after decoder block i, x is back at the
                # resolution of encoder output -(i + 2).
                x = x + encoder_outputs[-(i + 2)]

        return x


In [52]:
def run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass of `model` over `loader` and return the mean loss per sample.

    If `optimizer` is given, the model is trained: train mode, gradients enabled,
    one optimizer step per batch. Otherwise it is evaluated in eval mode with
    gradients disabled.

    Each batch is ``(subdomains, targets)`` from a DatasetMultipleGPUs loader.
    The FCN consumes a single tensor, so batches with more than one subdomain
    are rejected instead of silently training on the first tile only.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for subdomains, targets in tqdm(loader, desc=desc):
            if len(subdomains) != 1:
                raise ValueError(
                    f"FCN expects one subdomain per sample, got {len(subdomains)}; "
                    "use subdomains_dist=(1, 1)."
                )
            inputs, targets = subdomains[0].to(device), targets.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()

            # criterion averages over the batch; re-weight by batch size so the
            # result is a per-sample mean even when the last batch is smaller.
            total_loss += loss.item() * inputs.size(0)
    return total_loss / len(loader.dataset)


# Example usage
num_blocks = 4
channels = [16, 32, 64, 128, 64, 32, 16, 1]  # encoder widths, then decoder widths
learning_rate = 1e-3
weight_decay = 1e-4
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to its device before constructing the optimizer (torch.optim guidance).
model = FCN(num_blocks, channels).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    epoch_loss = run_epoch(model, dataloader_train, criterion, device,
                           f'Epoch {epoch+1}/{num_epochs}', optimizer=optimizer)
    print(f'Training Loss: {epoch_loss:.4f}')

    val_loss = run_epoch(model, dataloader_val, criterion, device, 'Validation')
    print(f'Validation Loss: {val_loss:.4f}')

test_loss = run_epoch(model, dataloader_test, criterion, device, 'Testing')
print(f'Test Loss: {test_loss:.4f}')


Epoch 1/10:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])
torch.Size([1, 16, 640, 640])


Epoch 1/10: 100%|██████████| 1/1 [00:01<00:00,  1.69s/it]


Training Loss: 0.0788


Validation:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])
torch.Size([1, 16, 640, 640])


Validation: 100%|██████████| 1/1 [00:00<00:00,  2.41it/s]


Validation Loss: 0.0815


Epoch 2/10:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])
torch.Size([1, 16, 640, 640])


Epoch 2/10: 100%|██████████| 1/1 [00:01<00:00,  1.09s/it]


Training Loss: 0.0728


Validation:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])


Validation: 100%|██████████| 1/1 [00:00<00:00,  2.78it/s]


torch.Size([1, 16, 640, 640])
Validation Loss: 0.0771


Epoch 3/10:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])
torch.Size([1, 16, 640, 640])


Epoch 3/10: 100%|██████████| 1/1 [00:01<00:00,  1.15s/it]


Training Loss: 0.0770


Validation:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])


Validation: 100%|██████████| 1/1 [00:00<00:00,  2.66it/s]


torch.Size([1, 16, 640, 640])
Validation Loss: 0.0785


Epoch 4/10:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])
torch.Size([1, 16, 640, 640])


Epoch 4/10: 100%|██████████| 1/1 [00:01<00:00,  1.79s/it]


Training Loss: 0.0780


Validation:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])


Validation: 100%|██████████| 1/1 [00:00<00:00,  1.24it/s]


torch.Size([1, 16, 640, 640])
Validation Loss: 0.0809


Epoch 5/10:   0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([1, 16, 640, 640])
torch.Size([1, 32, 320, 320])
torch.Size([1, 64, 160, 160])
torch.Size([1, 128, 80, 80])
torch.Size([1, 64, 160, 160])
torch.Size([1, 32, 320, 320])
torch.Size([1, 16, 640, 640])


Epoch 5/10:   0%|          | 0/1 [00:02<?, ?it/s]


KeyboardInterrupt: 