In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import wandb
from sklearn.model_selection import train_test_split

In [None]:
# Run-wide configuration shared by the cells below.
# Compute device: use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LR_SHAPE = (16, 16)  # (rows, cols) of one low-resolution field
HR_SHAPE = (128, 128)  # (rows, cols) of one high-resolution field
BATCH_SIZE = 4
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3
# Dataset root. It is concatenated directly with relative names elsewhere in the
# notebook (e.g. f'{INPUT_PATH}{mode}.csv'), so the trailing slash is required.
INPUT_PATH = "/home/diya/Projects/super_resolution/dataset/"  # machine-specific; update before running

In [None]:
def setup_wandb():
    """Start a Weights & Biases run for this training session.

    The run is created in the ``flow-super-resolution-fno`` project and its
    config records the module-level hyperparameters (``LEARNING_RATE``,
    ``BATCH_SIZE``, ``NUM_EPOCHS``) plus a short description of the model.
    """
    wandb.init(
        project="flow-super-resolution-fno",
        config={
            "learning_rate": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "epochs": NUM_EPOCHS,
            # main() builds FNO2d(modes1=8, modes2=8, width=64); keep the logged
            # description in sync with that (there is no third mode axis).
            "model": "FNO2D",
            "architecture": "modes1=8, modes2=8, width=64"
        }
    )

In [8]:
class SpectralConv2d(nn.Module):
    """2-D spectral convolution: FFT -> learned mixing of low modes -> inverse FFT.

    The layer operates on channels-first tensors ``(batch, in_channels, H, W)``
    and returns ``(batch, out_channels, H, W)``. Only the lowest ``modes1``
    positive and negative frequencies along H and the lowest ``modes2``
    frequencies along W are kept; all other Fourier coefficients are zeroed.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Retained Fourier modes along the height axis (per sign).
        modes2 (int): Retained Fourier modes along the width axis.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        # Complex weights stored as (..., 2) = (real, imag).
        # weights1 acts on the low positive-frequency rows, weights2 on the
        # low negative-frequency rows (the last `modes1` rows of the spectrum).
        self.weights1 = nn.Parameter(self.scale * torch.randn(in_channels, out_channels, modes1, modes2, 2))
        self.weights2 = nn.Parameter(self.scale * torch.randn(in_channels, out_channels, modes1, modes2, 2))

    def compl_mul2d(self, input, weights):
        """Complex channel-mixing product on (real, imag)-stacked tensors.

        Args:
            input: ``(batch, in_channels, x, y, 2)`` Fourier coefficients.
            weights: ``(in_channels, out_channels, x, y, 2)`` complex weights.

        Returns:
            ``(batch, out_channels, x, y, 2)`` tensor; for every mode ``(x, y)``
            the input channels are contracted against the weights.
        """
        in_re, in_im = input[..., 0], input[..., 1]
        w_re, w_im = weights[..., 0], weights[..., 1]
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, summed over input channels.
        real = torch.einsum("bixy,ioxy->boxy", in_re, w_re) - torch.einsum("bixy,ioxy->boxy", in_im, w_im)
        imag = torch.einsum("bixy,ioxy->boxy", in_re, w_im) + torch.einsum("bixy,ioxy->boxy", in_im, w_re)
        return torch.stack([real, imag], dim=-1)

    def forward(self, x):
        """Apply the spectral convolution to ``x`` of shape ``(batch, in_channels, H, W)``."""
        batchsize = x.shape[0]
        size1, size2 = x.shape[-2], x.shape[-1]

        # Transform over the two spatial axes only; the channel axis is mixed
        # by the learned weights, not by the FFT.
        x_ft = torch.fft.rfft2(x, dim=(-2, -1))
        x_ft = torch.stack([x_ft.real, x_ft.imag], dim=-1)

        # Never request more modes than the spectrum actually contains.
        m1 = min(self.modes1, size1)
        m2 = min(self.modes2, size2 // 2 + 1)

        out_ft = torch.zeros(
            batchsize, self.out_channels, size1, size2 // 2 + 1, 2,
            device=x.device, dtype=x_ft.dtype,
        )
        out_ft[:, :, :m1, :m2] = \
            self.compl_mul2d(x_ft[:, :, :m1, :m2], self.weights1[:, :, :m1, :m2])
        out_ft[:, :, -m1:, :m2] = \
            self.compl_mul2d(x_ft[:, :, -m1:, :m2], self.weights2[:, :, :m1, :m2])

        x = torch.complex(out_ft[..., 0], out_ft[..., 1])
        x = torch.fft.irfft2(x, s=(size1, size2), dim=(-2, -1))
        return x

In [9]:
class FNO2d(nn.Module):
    """Fourier Neural Operator that upsamples a field to ``HR_SHAPE``.

    The input is lifted to ``width`` channels, passed through three blocks that
    each add a spectral convolution to a 1x1 convolution, bicubically resized
    to ``HR_SHAPE`` and projected to ``out_channels``. The result is blended
    with a plain bicubic upsampling of the input using the learnable scalar
    ``alpha``.

    Args:
        modes1 (int): Fourier modes handed to every ``SpectralConv2d``.
        modes2 (int): Fourier modes handed to every ``SpectralConv2d``.
        width (int): Number of hidden channels.
        in_channels (int): Channels of the input tensor.
        out_channels (int): Channels of the output tensor.
    """

    def __init__(self, modes1, modes2, width=64, in_channels=4, out_channels=4):
        super().__init__()
        self.modes1, self.modes2, self.width = modes1, modes2, width

        # Pointwise lifting from the input channels to the hidden width.
        self.fc0 = nn.Linear(in_channels, width)

        # Spectral branch of the three Fourier blocks.
        self.conv0 = SpectralConv2d(width, width, modes1, modes2)
        self.conv1 = SpectralConv2d(width, width, modes1, modes2)
        self.conv2 = SpectralConv2d(width, width, modes1, modes2)

        # Pointwise (1x1) branch of the three Fourier blocks.
        self.w0 = nn.Conv2d(width, width, 1)
        self.w1 = nn.Conv2d(width, width, 1)
        self.w2 = nn.Conv2d(width, width, 1)

        # Projection head applied after upsampling.
        self.fc1 = nn.Linear(width, 256)
        self.fc2 = nn.Linear(256, out_channels)

        # Blend factor between the network output and the bicubic baseline.
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        """Map ``x`` of shape ``(batch, in_channels, H, W)`` to ``(batch, out_channels, *HR_SHAPE)``."""
        baseline = F.interpolate(x, size=HR_SHAPE, mode='bicubic', align_corners=False)

        # nn.Linear acts on the last axis, so move channels there and back.
        hidden = self.fc0(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        hidden = F.gelu(self.conv0(hidden) + self.w0(hidden))
        hidden = F.gelu(self.conv1(hidden) + self.w1(hidden))
        hidden = self.conv2(hidden) + self.w2(hidden)  # last block has no activation

        # Resize the hidden features first, then project point-wise.
        hidden = F.interpolate(hidden, size=HR_SHAPE, mode='bicubic', align_corners=False)
        hidden = hidden.permute(0, 2, 3, 1)
        hidden = self.fc2(F.gelu(self.fc1(hidden)))
        hidden = hidden.permute(0, 3, 1, 2)

        return self.alpha * hidden + (1 - self.alpha) * baseline

In [2]:
import os
import numpy as np
import pandas as pd
import torch
from common.config import INPUT_PATH, LR_SHAPE, HR_SHAPE

class FileNotFoundOrEmptyError(Exception):
    """Raised when a required path is missing or refers to a zero-byte file."""

def verify_file_exists(filepath):
    """
    Check that ``filepath`` exists on disk and has a non-zero size.

    Args:
        filepath (str): Path to check

    Returns:
        bool: Always True when the checks pass

    Raises:
        FileNotFoundOrEmptyError: If the path does not exist, or exists with size 0
    """
    if os.path.exists(filepath):
        if os.path.getsize(filepath) == 0:
            raise FileNotFoundOrEmptyError(f"File is empty: {filepath}")
        return True

    raise FileNotFoundOrEmptyError(f"File not found: {filepath}")

def load_binary_file(filepath, shape):
    """
    Read a raw little-endian float32 file and reshape it.

    Args:
        filepath (str): Path to the binary file
        shape (tuple): Shape the flat data must be reshaped to

    Returns:
        np.ndarray: Array of dtype ``<f4`` with the requested shape

    Raises:
        ValueError: If the number of elements in the file differs from ``prod(shape)``
        FileNotFoundOrEmptyError: If the file is missing or empty

    Any exception raised while loading is re-created as the same exception type
    with the file path prepended to its message.
    """
    try:
        verify_file_exists(filepath)
        flat = np.fromfile(filepath, dtype="<f4")
        expected_size = np.prod(shape)

        if flat.size == expected_size:
            return flat.reshape(shape)

        raise ValueError(
            f"Data size mismatch. Expected {expected_size} elements "
            f"for shape {shape}, but got {flat.size} elements"
        )

    except Exception as err:
        # Same exception class, message prefixed with the offending path.
        raise type(err)(f"Error loading file {filepath}: {str(err)}") from err

def load_csv_data(mode='train'):
    """
    Load the CSV metadata file ``{INPUT_PATH}{mode}.csv``.

    Args:
        mode (str): Dataset mode ('train', 'val', or 'test')

    Returns:
        pd.DataFrame: Loaded CSV data

    Raises:
        Exception: Any failure is re-raised as the same exception type with the
            CSV path prepended to the message. If that type cannot be built
            from a single message (e.g. ``UnicodeDecodeError``), a
            ``RuntimeError`` carrying the same message is raised instead. The
            original exception is always attached as ``__cause__``.
    """
    csv_path = f'{INPUT_PATH}{mode}.csv'
    try:
        verify_file_exists(csv_path)
        return pd.read_csv(csv_path)
    except Exception as e:
        message = f"Error loading CSV file {csv_path}: {str(e)}"
        try:
            wrapped = type(e)(message)
        except Exception:
            # Some exception classes need more than one constructor argument;
            # rebuilding them from a message would raise TypeError and hide
            # the real failure.
            raise RuntimeError(message) from e
        raise wrapped from e

def _load_field_stack(directory, row, shape):
    """Load the rho/ux/uy/uz files named in ``row`` from ``directory`` and stack them.

    Each file is read with ``load_binary_file`` using ``shape``; the four
    arrays are stacked in that fixed order along a new axis at position 2.
    """
    return torch.stack([
        torch.from_numpy(
            load_binary_file(f"{directory}/{row[name + '_filename']}", shape)
        ).float()
        for name in ('rho', 'ux', 'uy', 'uz')
    ], dim=2)


def get_xy(idx, csv_file, mode='train'):
    """
    Load LR and HR data for a single sample.

    Args:
        idx (int): Positional index of the sample (0 .. len(csv_file) - 1)
        csv_file (pd.DataFrame): CSV metadata with ``rho_filename``,
            ``ux_filename``, ``uy_filename`` and ``uz_filename`` columns
        mode (str): Dataset mode ('train', 'val', or 'test')

    Returns:
        tuple: (LR data, HR data) for train/val, or LR data for test. Each
        tensor is float32 with the four fields stacked on the last axis, i.e.
        shape ``LR_SHAPE + (4,)`` / ``HR_SHAPE + (4,)``.

    Raises:
        Exception: Any failure is re-raised as the same exception type with
            the sample index and mode prepended to the message.
    """
    try:
        # Validate paths
        LR_path = f"{INPUT_PATH}flowfields/LR/{mode}"
        HR_path = f"{INPUT_PATH}flowfields/HR/{mode}" if mode != 'test' else None

        if not os.path.exists(LR_path):
            raise FileNotFoundOrEmptyError(f"LR directory not found: {LR_path}")
        if HR_path and not os.path.exists(HR_path):
            raise FileNotFoundOrEmptyError(f"HR directory not found: {HR_path}")

        # Select the row by position. Label-based `csv_file[col][idx]` picks
        # the wrong row (or raises KeyError) when the frame's index is not
        # 0..n-1, e.g. after shuffling or splitting the metadata.
        row = csv_file.iloc[idx]

        X = _load_field_stack(LR_path, row, LR_SHAPE)

        if mode != 'test':
            # The HR files share their names with the LR files.
            Y = _load_field_stack(HR_path, row, HR_SHAPE)
            return X, Y

        return X

    except Exception as e:
        raise type(e)(
            f"Error processing sample {idx} in mode {mode}: {str(e)}"
        ) from e

def load_data(mode='train'):
    """
    Read every sample listed in the ``mode`` CSV into memory.

    Samples that fail to load are reported with a printed warning and skipped;
    progress is printed after every tenth successfully loaded index.

    Args:
        mode (str): Dataset mode ('train', 'val', or 'test')

    Returns:
        tuple: (X, Y) stacked tensors for train/val, or a single stacked X for test

    Raises:
        ValueError: If no sample could be loaded
        Exception: Any other failure, re-created as the same type with the
            dataset mode prepended to the message
    """
    try:
        print(f"Loading {mode} dataset...")
        df = load_csv_data(mode)
        print(f"Found {len(df)} samples in {mode} CSV.")

        loaded = []
        for i in range(len(df)):
            try:
                loaded.append(get_xy(i, df, mode))
            except Exception as sample_err:
                # Best effort: one bad sample must not abort the whole load.
                print(f"Warning: Error processing sample {i}: {str(sample_err)}")
                continue
            if (i + 1) % 10 == 0:
                print(f"Processed {i + 1}/{len(df)} samples...")

        if len(loaded) == 0:
            raise ValueError(f"No valid samples found in {mode} dataset")

        if mode == 'test':
            return torch.stack(loaded)

        # Each entry is an (X, Y) pair; split into two parallel sequences.
        inputs, targets = zip(*loaded)
        return torch.stack(inputs), torch.stack(targets)

    except Exception as err:
        raise type(err)(f"Error loading {mode} dataset: {str(err)}") from err



In [16]:
class FlowFieldDataset(Dataset):
    """Map-style dataset that loads one flow-field sample per index via ``get_xy``.

    Args:
        csv_file: Metadata table; its length defines the dataset size and it is
            passed unchanged to ``get_xy``.
        mode (str): Dataset mode; ``'test'`` yields inputs only, any other
            value yields ``(X, Y)`` pairs.
    """

    def __init__(self, csv_file, mode='train'):
        self.csv_file = csv_file
        self.mode = mode

    def __len__(self):
        """Number of rows in the metadata table."""
        return len(self.csv_file)

    def __getitem__(self, idx):
        """Return ``(X, Y)`` for labelled modes, or just ``X`` in test mode."""
        if self.mode == 'test':
            return get_xy(idx, self.csv_file, self.mode)

        X, Y = get_xy(idx, self.csv_file, self.mode)
        return X, Y

In [17]:
def combined_loss(pred, target):
    # MSE Loss
    mse_loss = F.mse_loss(pred, target)
    
    # Structural Similarity Loss (simplified version)
    def ssim(img1, img2, window_size=11):
        C1 = 0.01**2
        C2 = 0.03**2
        
        mu1 = F.avg_pool3d(img1, window_size, stride=1, padding=window_size//2)
        mu2 = F.avg_pool3d(img2, window_size, stride=1, padding=window_size//2)
        
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2
        
        sigma1_sq = F.avg_pool3d(img1 * img1, window_size, stride=1, padding=window_size//2) - mu1_sq
        sigma2_sq = F.avg_pool3d(img2 * img2, window_size, stride=1, padding=window_size//2) - mu2_sq
        sigma12 = F.avg_pool3d(img1 * img2, window_size, stride=1, padding=window_size//2) - mu1_mu2
        
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
        
        return 1 - ssim_map.mean()
    
    ssim_loss = ssim(pred, target)
    
    return mse_loss + 0.1 * ssim_loss

In [10]:
def train_epoch(model, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Every batch is moved to ``DEVICE``, scored with ``combined_loss`` and used
    for one optimizer step; the batch loss is shown on the progress bar and
    logged to W&B together with a global batch counter.

    Args:
        model: Network to train (switched to training mode here).
        train_loader: Iterable of ``(X, Y)`` batches that supports ``len()``.
        optimizer: Optimizer bound to ``model``'s parameters.
        epoch (int): Zero-based epoch index, used for display and step numbering.

    Returns:
        float: Mean of the per-batch losses.
    """
    model.train()
    n_batches = len(train_loader)
    running_loss = 0
    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')

    for step, (inputs, targets) in enumerate(progress):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        loss = combined_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': batch_loss})

        wandb.log({
            "train_batch_loss": batch_loss,
            # Global step: batches seen so far across all epochs.
            "batch": step + epoch * n_batches
        })

    return running_loss / n_batches

In [12]:
def main():
    """Train ``FNO2d`` for ``NUM_EPOCHS`` epochs, then evaluate on the test set.

    Starts a W&B run, builds the data loaders, trains with Adam and a
    ``ReduceLROnPlateau`` schedule driven by the validation loss, logs the
    per-epoch metrics, and finally saves the test predictions to
    ``test_predictions.pt``.

    The helpers ``load_and_split_data``, ``create_dataloaders``, ``validate``
    and ``evaluate`` must already be defined in the notebook namespace.

    The W&B run is always closed: with a failure exit code if anything raises
    (including a kernel interrupt), normally otherwise.
    """
    # Setup
    setup_wandb()
    try:
        train_data, val_data, test_data = load_and_split_data()
        train_loader, val_loader, test_loader = create_dataloaders(train_data, val_data, test_data)

        # Initialize model with 2D architecture
        model = FNO2d(
            modes1=8, modes2=8,
            width=64,
            in_channels=4,
            out_channels=4
        ).to(DEVICE)

        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )

        best_val_loss = float('inf')
        for epoch in range(NUM_EPOCHS):
            train_loss = train_epoch(model, train_loader, optimizer, epoch)
            val_loss = validate(model, val_loader)

            wandb.log({
                "train_epoch_loss": train_loss,
                "val_loss": val_loss,
                "epoch": epoch,
                "learning_rate": optimizer.param_groups[0]['lr']
            })

            print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
            print(f"Train Loss: {train_loss:.6f}")
            print(f"Val Loss: {val_loss:.6f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # NOTE: checkpointing is currently disabled, so the best
                # weights are not persisted.
                #save_checkpoint(model, optimizer, epoch, val_loss, 'best_model.pth')

            # Halve the learning rate when the validation loss plateaus.
            scheduler.step(val_loss)

        # Final evaluation
        print("Evaluating on test set...")
        predictions, metrics = evaluate(model, test_loader)

        wandb.log({
            "test_mse": np.mean(metrics['mse']) if metrics['mse'] else None,
            "test_psnr": np.mean(metrics['psnr']) if metrics['psnr'] else None,
            "test_ssim": np.mean(metrics['ssim']) if metrics['ssim'] else None
        })

        torch.save(predictions, 'test_predictions.pt')
    except BaseException:
        # Do not leave the run open in the kernel when training fails or is
        # interrupted; mark it as failed and propagate the original error.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()



In [3]:
import os
import numpy as np
import pandas as pd
import torch
from common.config import INPUT_PATH, LR_SHAPE, HR_SHAPE

class FileNotFoundOrEmptyError(Exception):
    """Raised when a required path is missing or refers to a zero-byte file."""

def verify_file_exists(filepath):
    """
    Check that ``filepath`` exists on disk and has a non-zero size.

    Args:
        filepath (str): Path to check

    Returns:
        bool: Always True when the checks pass

    Raises:
        FileNotFoundOrEmptyError: If the path does not exist, or exists with size 0
    """
    if os.path.exists(filepath):
        if os.path.getsize(filepath) == 0:
            raise FileNotFoundOrEmptyError(f"File is empty: {filepath}")
        return True

    raise FileNotFoundOrEmptyError(f"File not found: {filepath}")

def load_binary_file(filepath, shape):
    """
    Read a raw little-endian float32 file and reshape it.

    Args:
        filepath (str): Path to the binary file
        shape (tuple): Shape the flat data must be reshaped to

    Returns:
        np.ndarray: Array of dtype ``<f4`` with the requested shape

    Raises:
        ValueError: If the number of elements in the file differs from ``prod(shape)``
        FileNotFoundOrEmptyError: If the file is missing or empty

    Any exception raised while loading is re-created as the same exception type
    with the file path prepended to its message.
    """
    try:
        verify_file_exists(filepath)
        flat = np.fromfile(filepath, dtype="<f4")
        expected_size = np.prod(shape)

        if flat.size == expected_size:
            return flat.reshape(shape)

        raise ValueError(
            f"Data size mismatch. Expected {expected_size} elements "
            f"for shape {shape}, but got {flat.size} elements"
        )

    except Exception as err:
        # Same exception class, message prefixed with the offending path.
        raise type(err)(f"Error loading file {filepath}: {str(err)}") from err

def load_csv_data(mode='train'):
    """
    Load the CSV metadata file ``{INPUT_PATH}{mode}.csv``.

    Args:
        mode (str): Dataset mode ('train', 'val', or 'test')

    Returns:
        pd.DataFrame: Loaded CSV data

    Raises:
        Exception: Any failure is re-raised as the same exception type with the
            CSV path prepended to the message. If that type cannot be built
            from a single message (e.g. ``UnicodeDecodeError``), a
            ``RuntimeError`` carrying the same message is raised instead. The
            original exception is always attached as ``__cause__``.
    """
    csv_path = f'{INPUT_PATH}{mode}.csv'
    try:
        verify_file_exists(csv_path)
        return pd.read_csv(csv_path)
    except Exception as e:
        message = f"Error loading CSV file {csv_path}: {str(e)}"
        try:
            wrapped = type(e)(message)
        except Exception:
            # Some exception classes need more than one constructor argument;
            # rebuilding them from a message would raise TypeError and hide
            # the real failure.
            raise RuntimeError(message) from e
        raise wrapped from e

def _load_field_stack(directory, row, shape):
    """Load the rho/ux/uy/uz files named in ``row`` from ``directory`` and stack them.

    Each file is read with ``load_binary_file`` using ``shape``; the four
    arrays are stacked in that fixed order along a new axis at position 2.
    """
    return torch.stack([
        torch.from_numpy(
            load_binary_file(f"{directory}/{row[name + '_filename']}", shape)
        ).float()
        for name in ('rho', 'ux', 'uy', 'uz')
    ], dim=2)


def get_xy(idx, csv_file, mode='train'):
    """
    Load LR and HR data for a single sample.

    Args:
        idx (int): Positional index of the sample (0 .. len(csv_file) - 1)
        csv_file (pd.DataFrame): CSV metadata with ``rho_filename``,
            ``ux_filename``, ``uy_filename`` and ``uz_filename`` columns
        mode (str): Dataset mode ('train', 'val', or 'test')

    Returns:
        tuple: (LR data, HR data) for train/val, or LR data for test. Each
        tensor is float32 with the four fields stacked on the last axis, i.e.
        shape ``LR_SHAPE + (4,)`` / ``HR_SHAPE + (4,)``.

    Raises:
        Exception: Any failure is re-raised as the same exception type with
            the sample index and mode prepended to the message.
    """
    try:
        # Validate paths
        LR_path = f"{INPUT_PATH}flowfields/LR/{mode}"
        HR_path = f"{INPUT_PATH}flowfields/HR/{mode}" if mode != 'test' else None

        if not os.path.exists(LR_path):
            raise FileNotFoundOrEmptyError(f"LR directory not found: {LR_path}")
        if HR_path and not os.path.exists(HR_path):
            raise FileNotFoundOrEmptyError(f"HR directory not found: {HR_path}")

        # Select the row by position. Label-based `csv_file[col][idx]` picks
        # the wrong row (or raises KeyError) when the frame's index is not
        # 0..n-1, e.g. after shuffling or splitting the metadata.
        row = csv_file.iloc[idx]

        X = _load_field_stack(LR_path, row, LR_SHAPE)

        if mode != 'test':
            # The HR files share their names with the LR files.
            Y = _load_field_stack(HR_path, row, HR_SHAPE)
            return X, Y

        return X

    except Exception as e:
        raise type(e)(
            f"Error processing sample {idx} in mode {mode}: {str(e)}"
        ) from e

def load_data(mode='train'):
    """
    Read every sample listed in the ``mode`` CSV into memory.

    Samples that fail to load are reported with a printed warning and skipped;
    progress is printed after every tenth successfully loaded index.

    Args:
        mode (str): Dataset mode ('train', 'val', or 'test')

    Returns:
        tuple: (X, Y) stacked tensors for train/val, or a single stacked X for test

    Raises:
        ValueError: If no sample could be loaded
        Exception: Any other failure, re-created as the same type with the
            dataset mode prepended to the message
    """
    try:
        print(f"Loading {mode} dataset...")
        df = load_csv_data(mode)
        print(f"Found {len(df)} samples in {mode} CSV.")

        loaded = []
        for i in range(len(df)):
            try:
                loaded.append(get_xy(i, df, mode))
            except Exception as sample_err:
                # Best effort: one bad sample must not abort the whole load.
                print(f"Warning: Error processing sample {i}: {str(sample_err)}")
                continue
            if (i + 1) % 10 == 0:
                print(f"Processed {i + 1}/{len(df)} samples...")

        if len(loaded) == 0:
            raise ValueError(f"No valid samples found in {mode} dataset")

        if mode == 'test':
            return torch.stack(loaded)

        # Each entry is an (X, Y) pair; split into two parallel sequences.
        inputs, targets = zip(*loaded)
        return torch.stack(inputs), torch.stack(targets)

    except Exception as err:
        raise type(err)(f"Error loading {mode} dataset: {str(err)}") from err

class FlowFieldDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads its own metadata CSV and loads samples lazily.

    Args:
        mode (str): Dataset mode passed to ``load_csv_data`` and ``get_xy``.
    """

    def __init__(self, mode='train'):
        # The metadata table is loaded once; field files are read on access.
        self.df = load_csv_data(mode)
        self.mode = mode

    def __len__(self):
        """Number of rows in the metadata table."""
        n_samples = len(self.df)
        return n_samples

    def __getitem__(self, idx):
        """Return whatever ``get_xy`` yields for sample ``idx`` in this mode."""
        sample = get_xy(idx, self.df, self.mode)
        return sample
    

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import os
from tqdm import tqdm
import wandb
import math
from common.config import INPUT_PATH, LR_SHAPE, HR_SHAPE

# Run-wide configuration for this cell.
# NOTE: INPUT_PATH, LR_SHAPE and HR_SHAPE are also imported from common.config
# in the import block above; the assignments below rebind those names, so the
# values written here are the ones in effect for the rest of the cell.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LR_SHAPE = (16, 16)  # (rows, cols) of one low-resolution field
HR_SHAPE = (128, 128)  # (rows, cols) of one high-resolution field
BATCH_SIZE = 8
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3
INPUT_PATH = "/home/diya/Projects/super_resolution/dataset/"  # machine-specific; update before running

class SpectralConv2d(nn.Module):
    """Spectral convolution that mixes channels on the lowest Fourier modes.

    The input ``(batch, in_channels, H, W)`` is transformed with an orthonormal
    real 2-D FFT, the block of coefficients with the lowest row and column
    indices is multiplied by a learned complex weight per mode, every other
    coefficient is set to zero, and the result is transformed back to
    ``(batch, out_channels, H, W)``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        modes1 (int): Maximum number of retained modes along the height axis.
        modes2 (int): Maximum number of retained modes along the width axis.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.modes1, self.modes2 = modes1, modes2

        self.scale = (1 / (in_channels * out_channels))
        # Complex weights stored as a trailing (real, imag) pair.
        self.weights = nn.Parameter(self.scale * torch.randn(in_channels, out_channels, modes1, modes2, 2))

    def forward(self, x):
        """Apply the spectral convolution to a channels-first 4-D tensor."""
        n_batch = x.shape[0]
        height, width = x.shape[-2], x.shape[-1]
        n_freq_w = width//2 + 1

        spectrum = torch.fft.rfft2(x, norm='ortho')

        # Clamp the retained modes to what this input resolution provides.
        m1 = min(self.modes1, height)
        m2 = min(self.modes2, n_freq_w)

        low = spectrum[:, :, :m1, :m2]
        w_real = self.weights[:, :, :m1, :m2, 0]
        w_imag = self.weights[:, :, :m1, :m2, 1]

        def mix(coeffs, weight):
            # Contract input channels against the per-mode weights.
            return torch.einsum("bixy,ioxy->boxy", coeffs, weight)

        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        real_part = mix(low.real, w_real) - mix(low.imag, w_imag)
        imag_part = mix(low.real, w_imag) + mix(low.imag, w_real)

        out_spectrum = torch.zeros(n_batch, self.out_channels, height, n_freq_w,
                                   device=x.device, dtype=torch.cfloat)
        out_spectrum[:, :, :m1, :m2] = torch.complex(real_part, imag_part)

        # Back to physical space at the original resolution.
        return torch.fft.irfft2(out_spectrum, s=(height, width), norm='ortho')


class CombinedLoss(nn.Module):
    """Weighted sum of MSE and (1 - SSIM) for channels-first image batches.

    Args:
        alpha (float): Weight of the MSE term; the SSIM term gets ``1 - alpha``.
    """

    def __init__(self, alpha=0.9):
        super().__init__()
        self.alpha = alpha
        self.ssim_window_size = 11

    def _gaussian_kernel(self, reference):
        """Build the normalised 2-D Gaussian window (sigma 1.5) as a conv2d weight.

        The kernel is created on ``reference``'s device and in its dtype so it
        can be convolved with that tensor whatever its precision.
        """
        size = self.ssim_window_size
        sigma = 1.5
        offsets = torch.arange(size, device=reference.device, dtype=reference.dtype) - size // 2
        gauss = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
        gauss = gauss / gauss.sum()
        # Outer product of the 1-D window with itself, shaped (1, 1, size, size).
        return torch.outer(gauss, gauss).unsqueeze(0).unsqueeze(0)

    def ssim(self, img1, img2, kernel=None):
        """Mean SSIM between two single-channel batches ``(N, 1, H, W)``.

        Args:
            img1, img2: Tensors of identical shape with one channel.
            kernel: Optional precomputed window from ``_gaussian_kernel``;
                built from ``img1`` when omitted.

        Returns:
            Scalar tensor: mean of the SSIM map (zero padding at the borders).
        """
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        if kernel is None:
            kernel = self._gaussian_kernel(img1)
        pad = kernel.shape[-1] // 2

        def blur(t):
            # Gaussian-weighted local average, same spatial size as the input.
            return F.conv2d(t, kernel, padding=pad)

        # Local means
        mu1 = blur(img1)
        mu2 = blur(img2)
        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2

        # Local variances and covariance via E[xy] - E[x]E[y]
        sigma1_sq = blur(img1 * img1) - mu1_sq
        sigma2_sq = blur(img2 * img2) - mu2_sq
        sigma12 = blur(img1 * img2) - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
        return ssim_map.mean()

    def forward(self, pred, target):
        """
        Combined MSE and SSIM loss
        pred, target: (batch_size, channels, height, width)

        Returns ``alpha * mse + (1 - alpha) * (1 - mean per-channel SSIM)``.
        """
        # Ensure same shape
        assert pred.shape == target.shape, f"Shape mismatch: {pred.shape} vs {target.shape}"

        mse = F.mse_loss(pred, target)

        # The window does not depend on the channel, so build it once per call
        # instead of once per channel.
        kernel = self._gaussian_kernel(pred)

        ssim_value = 0
        for i in range(pred.shape[1]):  # SSIM is evaluated channel by channel
            ssim_value += self.ssim(pred[:, i:i+1], target[:, i:i+1], kernel)
        ssim_value /= pred.shape[1]
        ssim_loss = 1 - ssim_value  # 1 - SSIM so that lower is better

        return self.alpha * mse + (1 - self.alpha) * ssim_loss


class FNO2d(nn.Module):
    """Fourier Neural Operator that upsamples only the height axis.

    The input is lifted to ``width`` channels, passed through four blocks that
    each add a spectral convolution to a 1x1 convolution, resized in three
    bilinear steps to ``(input_height * 8, 4)`` and projected to
    ``out_channels`` with two 1x1 convolutions.

    Args:
        modes1 (int): Fourier modes handed to every ``SpectralConv2d``.
        modes2 (int): Fourier modes handed to every ``SpectralConv2d``.
        width (int): Number of hidden channels.
        in_channels (int): Size of axis 1 of the input.
        out_channels (int): Size of axis 1 of the output.
        input_height (int): Size of axis 2 of the input; the upsampling
            targets are 2x, 4x and 8x this value.
    """

    def __init__(self, modes1, modes2, width, in_channels=16, out_channels=128, input_height=16):
        super().__init__()
        self.modes1, self.modes2, self.width = modes1, modes2, width

        # Pointwise lifting to the hidden width.
        self.fc0 = nn.Linear(in_channels, width)

        # Spectral branch of the four Fourier blocks.
        self.conv0 = SpectralConv2d(width, width, modes1, modes2)
        self.conv1 = SpectralConv2d(width, width, modes1, modes2)
        self.conv2 = SpectralConv2d(width, width, modes1, modes2)
        self.conv3 = SpectralConv2d(width, width, modes1, modes2)

        # Pointwise (1x1) branch of the four Fourier blocks.
        self.w0 = nn.Conv2d(width, width, 1)
        self.w1 = nn.Conv2d(width, width, 1)
        self.w2 = nn.Conv2d(width, width, 1)
        self.w3 = nn.Conv2d(width, width, 1)

        # Three resize + 1x1 conv + GELU stages. The last axis is pinned to
        # size 4 in every stage, so only the height grows (2x, 4x, 8x).
        self.upsample_layers = nn.Sequential(
            nn.Upsample(size=(input_height*2, 4), mode='bilinear', align_corners=True),
            nn.Conv2d(width, width, 1),
            nn.GELU(),
            nn.Upsample(size=(input_height*4, 4), mode='bilinear', align_corners=True),
            nn.Conv2d(width, width, 1),
            nn.GELU(),
            nn.Upsample(size=(input_height*8, 4), mode='bilinear', align_corners=True),
            nn.Conv2d(width, width, 1),
            nn.GELU(),
        )

        # Projection head to the output channels.
        self.output_layer = nn.Sequential(
            nn.Conv2d(width, 256, 1),
            nn.GELU(),
            nn.Conv2d(256, out_channels, 1)
        )

    def forward(self, x):
        """Map ``(batch, in_channels, height, width)`` to ``(batch, out_channels, height*8, 4)``."""
        # nn.Linear acts on the last axis, so move axis 1 there and back.
        hidden = self.fc0(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        # Fourier blocks: spectral path + pointwise path, GELU between blocks.
        hidden = F.gelu(self.conv0(hidden) + self.w0(hidden))
        hidden = F.gelu(self.conv1(hidden) + self.w1(hidden))
        hidden = F.gelu(self.conv2(hidden) + self.w2(hidden))
        hidden = self.conv3(hidden) + self.w3(hidden)  # last block has no activation

        hidden = self.upsample_layers(hidden)
        return self.output_layer(hidden)


def train_model():
    """Train ``FNO2d`` on the flow-field dataset and log the run to W&B.

    Builds train/val ``FlowFieldDataset`` loaders, sizes the model from one
    sample batch, then runs ``NUM_EPOCHS`` epochs of Adam with a
    ``ReduceLROnPlateau`` schedule driven by the validation loss. The
    checkpoint with the lowest validation loss is written to
    ``best_model.pth``.
    """
    # Single source of truth for the MSE/SSIM mixing weight, so the run name,
    # the logged config and the criterion cannot disagree.
    loss_alpha = 0.5

    # Initialize wandb
    wandb.init(
        project="flow-super-resolution-fno",
        name=f"fno2d epochs = {NUM_EPOCHS} batch size = {BATCH_SIZE} lr = {LEARNING_RATE} alpha = {loss_alpha}",
        config={
            "learning_rate": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "epochs": NUM_EPOCHS,
            "architecture": "FNO2D",
            "loss_alpha": loss_alpha
        }
    )

    # Initialize datasets and dataloaders
    train_dataset = FlowFieldDataset(mode='train')
    val_dataset = FlowFieldDataset(mode='val')

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Get a sample batch to determine dimensions
    # NOTE(review): get_xy stacks the four fields on the LAST axis, so batches
    # are (batch, rows, cols, 4). Axis 1 below is therefore a spatial axis that
    # the model treats as its channel axis, and the 4 physical fields end up on
    # the model's width axis. Confirm this layout is intended.
    sample_batch = next(iter(train_loader))
    input_channels = sample_batch[0].shape[1]
    output_channels = sample_batch[1].shape[1]
    input_height = sample_batch[0].shape[2]
    target_height = sample_batch[1].shape[2]
    upscale_factor = target_height // input_height

    print(f"Input shape: {sample_batch[0].shape}")
    print(f"Target shape: {sample_batch[1].shape}")
    print(f"Upscale factor: {upscale_factor}")

    # Initialize model with the dimensions measured from the data. The
    # measured input height is passed explicitly instead of relying on the
    # constructor default.
    model = FNO2d(
        modes1=8,
        modes2=8,
        width=64,
        in_channels=input_channels,
        out_channels=output_channels,
        input_height=input_height,
    ).to(DEVICE)

    # Initialize optimizer, scheduler, and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    criterion = CombinedLoss(alpha=loss_alpha).to(DEVICE)

    best_val_loss = float('inf')
    for epoch in range(NUM_EPOCHS):
        # ---- training pass ----
        model.train()
        total_train_loss = 0

        for batch_idx, (X, Y) in enumerate(train_loader):
            X, Y = X.to(DEVICE), Y.to(DEVICE)

            optimizer.zero_grad()
            pred = model(X)
            loss = criterion(pred, Y)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_train_loss += batch_loss

            wandb.log({
                "train_batch_loss": batch_loss,
                # Global step: batches seen so far across all epochs.
                "batch": batch_idx + epoch * len(train_loader)
            })

        avg_train_loss = total_train_loss / len(train_loader)

        # ---- validation pass (no gradients) ----
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for X, Y in val_loader:
                X, Y = X.to(DEVICE), Y.to(DEVICE)
                pred = model(X)
                loss = criterion(pred, Y)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        wandb.log({
            "train_epoch_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "epoch": epoch,
            "learning_rate": optimizer.param_groups[0]['lr']
        })

        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        print(f"Train Loss: {avg_train_loss:.6f}")
        print(f"Val Loss: {avg_val_loss:.6f}")

        # Keep only the checkpoint with the best validation loss so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
            }, 'best_model.pth')

        # Halve the learning rate when the validation loss plateaus.
        scheduler.step(avg_val_loss)

    wandb.finish()
    
# Entry point: start training when this code runs as the main module.
# NOTE(review): a notebook kernel normally runs cells with __name__ set to
# '__main__', so executing this cell launches the full training run.
if __name__ == '__main__':
    train_model()
    

Input shape: torch.Size([8, 16, 16, 4])
Target shape: torch.Size([8, 128, 128, 4])
Upscale factor: 8




Epoch 1/100
Train Loss: 195.860714
Val Loss: 83.779851
Epoch 2/100
Train Loss: 78.941830
Val Loss: 81.501471
Epoch 3/100
Train Loss: 78.143372
Val Loss: 80.184355
Epoch 4/100
Train Loss: 78.223399
Val Loss: 80.799161
Epoch 5/100
Train Loss: 76.796222
Val Loss: 79.517639
Epoch 6/100
Train Loss: 56.567350
Val Loss: 49.831389
Epoch 7/100
Train Loss: 43.183241
Val Loss: 41.777826
Epoch 8/100
Train Loss: 39.114831
Val Loss: 39.539489
Epoch 9/100
Train Loss: 37.268109
Val Loss: 39.669312
Epoch 10/100
Train Loss: 37.135518
Val Loss: 39.589305
Epoch 11/100
Train Loss: 36.522755
Val Loss: 39.921098
Epoch 12/100
Train Loss: 37.095409
Val Loss: 38.957983
Epoch 13/100
Train Loss: 36.259190
Val Loss: 39.151574
Epoch 14/100
Train Loss: 36.028069
Val Loss: 40.025992
Epoch 15/100
Train Loss: 38.318005
Val Loss: 45.530299
Epoch 16/100
Train Loss: 36.208187
Val Loss: 38.879397
Epoch 17/100
Train Loss: 35.098593
Val Loss: 39.647652
Epoch 18/100
Train Loss: 37.002303
Val Loss: 45.896144
Epoch 19/100
Train

VBox(children=(Label(value='0.013 MB of 0.013 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch,▁▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇████
epoch,▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇█████
learning_rate,█████████████▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁▁
train_batch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_epoch_loss,█▄▄▄▄▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,██▄▃▃▃▄▂▂▂▂▂▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
batch,17299.0
epoch,99.0
learning_rate,3e-05
train_batch_loss,12.95627
train_epoch_loss,9.26121
val_loss,22.04964


In [None]:
# Entry point for the main() pipeline defined earlier in the notebook.
# NOTE(review): main() calls load_and_split_data, create_dataloaders, validate
# and evaluate, which must be defined before this cell is executed.
if __name__ == "__main__":
    main()