In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt  # Import for visualization

In [2]:
# Dataset class for BSDS500
class BSDS500Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        
        # Load image file paths, filtering for valid image extensions
        self.image_files = [
            os.path.join(root_dir, 'images', split, f)
            for f in os.listdir(os.path.join(root_dir, 'images', split))
            if f.endswith(('.png', '.jpg', '.jpeg'))  # Ensure valid image extensions
        ]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        
        try:
            image = Image.open(img_name).convert('RGB')  # Convert to RGB
        except (UnidentifiedImageError, OSError) as e:
            print(f"Error loading image {img_name}: {e}")
            return None  # Handle the error gracefully

        if self.transform:
            image = self.transform(image)

        return image

In [3]:
# Define transformations for training and validation datasets
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((33, 33))  # Resize for SRCNN input size
])

# Initialize datasets and dataloaders
root_dir = "C:/Users/IPS/Downloads/Super-Resolution-image-using-CNN-in-PyTorch/BSR/BSDS500/data"
train_dataset = BSDS500Dataset(root_dir=root_dir, split='train', transform=transform)
val_dataset = BSDS500Dataset(root_dir=root_dir, split='val', transform=transform)
test_dataset = BSDS500Dataset(root_dir=root_dir, split='test', transform=transform)

# Check if datasets are not empty
if len(train_dataset) == 0 or len(val_dataset) == 0 or len(test_dataset) == 0:
    raise ValueError("Training, validation or test dataset is empty. Please check your data paths.")

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [4]:
# SRCNN Model Definition
class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRCNN().to(device)

In [5]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [6]:
# Function to calculate metrics (as defined previously)
def calculate_metrics(output, target):
    output_np = output.permute(0, 2, 3, 1).cpu().numpy()
    target_np = target.permute(0, 2, 3, 1).cpu().numpy()
    
    ssim_val = np.mean([ssim(o, t, data_range=t.max() - t.min(), channel_axis=-1) for o, t in zip(output_np, target_np)])
    psnr_val = np.mean([psnr(t, o, data_range=t.max() - t.min()) for o, t in zip(output_np, target_np)])
    
    mse_val = np.mean((output_np - target_np) ** 2)
    return ssim_val, psnr_val, mse_val

# Training loop (as defined previously)

In [7]:
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(num_epochs):
    model.train()
    for i, data in enumerate(train_loader):
        inputs = data.to(device)
        inputs_upsampled = F.interpolate(inputs, scale_factor=2, mode='bicubic', align_corners=False)
        outputs = model(inputs_upsampled)
        loss = criterion(outputs, inputs_upsampled)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
    
    # Validation
    model.eval()
    ssim_total, psnr_total, mse_total = 0, 0, 0
    with torch.no_grad():
        for data in val_loader:
            inputs = data.to(device)
            inputs_upsampled = F.interpolate(inputs, scale_factor=2, mode='bicubic', align_corners=False)
            outputs = model(inputs_upsampled)
            ssim_val, psnr_val, mse_val = calculate_metrics(outputs, inputs_upsampled)
            ssim_total += ssim_val
            psnr_total += psnr_val
            mse_total += mse_val

    print(f'Validation - Epoch [{epoch+1}/{num_epochs}], SSIM: {ssim_total/len(val_loader):.4f}, PSNR: {psnr_total/len(val_loader):.4f}, MSE: {mse_total/len(val_loader):.4f}')

Epoch [1/10], Step [10/64], Loss: 0.0050
Epoch [1/10], Step [20/64], Loss: 0.0029
Epoch [1/10], Step [30/64], Loss: 0.0024
Epoch [1/10], Step [40/64], Loss: 0.0024
Epoch [1/10], Step [50/64], Loss: 0.0015
Epoch [1/10], Step [60/64], Loss: 0.0013
Validation - Epoch [1/10], SSIM: 0.7352, PSNR: 18.1953, MSE: 0.0126
Epoch [2/10], Step [10/64], Loss: 0.0008
Epoch [2/10], Step [20/64], Loss: 0.0009
Epoch [2/10], Step [30/64], Loss: 0.0009
Epoch [2/10], Step [40/64], Loss: 0.0004
Epoch [2/10], Step [50/64], Loss: 0.0006
Epoch [2/10], Step [60/64], Loss: 0.0005
Validation - Epoch [2/10], SSIM: 0.8087, PSNR: 20.1826, MSE: 0.0084
Epoch [3/10], Step [10/64], Loss: 0.0005
Epoch [3/10], Step [20/64], Loss: 0.0005
Epoch [3/10], Step [30/64], Loss: 0.0003
Epoch [3/10], Step [40/64], Loss: 0.0003
Epoch [3/10], Step [50/64], Loss: 0.0002
Epoch [3/10], Step [60/64], Loss: 0.0003
Validation - Epoch [3/10], SSIM: 0.8666, PSNR: 21.9513, MSE: 0.0057
Epoch [4/10], Step [10/64], Loss: 0.0002
Epoch [4/10], Ste

NameError: name 'UnidentifiedImageError' is not defined

In [None]:
def evaluate(model, dataloader):
    model.eval()
    ssim_total, psnr_total, mse_total = 0, 0, 0
    with torch.no_grad():
        for data in dataloader:
            inputs = data.to(device)
            inputs_upsampled = F.interpolate(inputs, scale_factor=2, mode='bicubic', align_corners=False)
            outputs = model(inputs_upsampled)
            ssim_val, psnr_val, mse_val = calculate_metrics(outputs, inputs_upsampled)
            ssim_total += ssim_val
            psnr_total += psnr_val
            mse_total += mse_val

    print(f'Test - SSIM: {ssim_total/len(dataloader):.4f}, PSNR: {psnr_total/len(dataloader):.4f}, MSE: {mse_total/len(dataloader):.4f}')

evaluate(model, test_loader)

In [None]:
# Visualization function to visualize results on test set images
def visualize_results(model, dataloader):
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            if data is None:  # Skip if data is None due to loading error
                continue
            
            inputs = data.to(device)
            inputs_upsampled = F.interpolate(inputs.clone(), scale_factor=2, mode='bicubic', align_corners=False)
            outputs = model(inputs_upsampled)

            # Display the input (original), upsampled image and output images from SRCNN.
            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            
            axs[0].imshow(inputs[1].cpu().permute(1, 2, 0))
            axs[0].set_title('Original Image')
            axs[0].axis('off')  # Hide axes
            
            axs[1].imshow(inputs_upsampled[1].cpu().permute(1, 2, 0))
            axs[1].set_title('Upsampled Image')
            axs[1].axis('off')  
            
            axs[2].imshow(outputs[1].cpu().permute(1, 2, 0))
            axs[2].set_title('SRCNN Output')
            axs[2].axis('off')  
            
            plt.show()

            # Only visualize one batch for simplicity; break after first batch.
            break

# Visualize results on the test set after evaluation.
visualize_results(model,test_loader)


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((128, 128))  # Set to a higher resolution
])

root_dir = "BSR/BSDS500/data"
train_dataset = BSDS500Dataset(root_dir=root_dir, split='train', transform=transform)
val_dataset = BSDS500Dataset(root_dir=root_dir, split='val', transform=transform)
test_dataset = BSDS500Dataset(root_dir=root_dir, split='test', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) #  load data in batches 
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
class SRCNN(nn.Module): # neural network module 
    def __init__(self):
        super(SRCNN, self).__init__() # calls the constructor of the parent class (nn.Module) # it ensures that the nn.Module part of the SRCNN object is properly set up.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x

model = SRCNN().to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
def calculate_metrics(output, target):
    output_np = output.permute(0, 2, 3, 1).cpu().numpy() #  Changes the order of dimensions of the tensor. If the original shape is (batch_size, channels, height, width), it changes it to (batch_size, height, width, channels).
    target_np = target.permute(0, 2, 3, 1).cpu().numpy()
    ssim_val = np.mean([ssim(o, t, data_range=t.max() - t.min(), channel_axis=-1, win_size=5) for o, t in zip(output_np, target_np)])
    psnr_val = np.mean([psnr(t, o, data_range=t.max() - t.min()) for o, t in zip(output_np, target_np)])
    mse_val = np.mean((output_np - target_np) ** 2)
    return ssim_val, psnr_val, mse_val

In [None]:
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def visualize_progress(model, dataloader):
    model.eval()
    with torch.no_grad():  # No gradients: During evaluation or inference, we don't need to update the model, so calculating gradients is unnecessary and takes extra memory and computation.
        for i, data in enumerate(dataloader):
            inputs = data.to(device)
            inputs_low_res = F.interpolate(inputs, scale_factor=0.5, mode='bicubic', align_corners=False)
            outputs = model(inputs_low_res)
            outputs_upsampled = F.interpolate(outputs, size=inputs.shape[2:], mode='bicubic', align_corners=False) # inputs (batch_size, channels, height, width) make the size of the target image as the input

            # Display the input, low-res, and output images
            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            axs[0].imshow(inputs[0].cpu().permute(1, 2, 0))
            axs[0].set_title('Original Image')
            axs[1].imshow(inputs_low_res[0].cpu().permute(1, 2, 0))
            axs[1].set_title('Low-Resolution Image')
            axs[2].imshow(outputs_upsampled[0].cpu().permute(1, 2, 0))
            axs[2].set_title('SRCNN Output')
            plt.show()

            # Only visualize one batch for simplicity
            break

for epoch in range(num_epochs):
    model.train()
    for i, data in enumerate(train_loader):
        inputs = data.to(device)
        inputs_low_res = F.interpolate(inputs, scale_factor=0.5, mode='bicubic', align_corners=False)
        outputs = model(inputs_low_res)
        outputs_upsampled = F.interpolate(outputs, size=inputs.shape[2:], mode='bicubic', align_corners=False)
        loss = criterion(outputs_upsampled, inputs)

        optimizer.zero_grad() # resetting the gradients of all the model's parameters to zero. # Resetting: We need to reset (clear) the gradients to zero before computing the new gradients for the next batch.
        loss.backward() # Calculates the gradients of the loss with respect to the model's parameters # understanding how much each parameter contributed to the error.
        optimizer.step() # Updates the model's parameters using the gradients calculated 

        if (i+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}') # 13 batches of data.
    
    # Validation
    model.eval()
    ssim_total, psnr_total, mse_total = 0, 0, 0
    with torch.no_grad():
        for data in val_loader:
            inputs = data.to(device)
            inputs_low_res = F.interpolate(inputs, scale_factor=0.5, mode='bicubic', align_corners=False)
            outputs = model(inputs_low_res)
            outputs_upsampled = F.interpolate(outputs, size=inputs.shape[2:], mode='bicubic', align_corners=False)
            ssim_val, psnr_val, mse_val = calculate_metrics(outputs_upsampled, inputs)
            ssim_total += ssim_val
            psnr_total += psnr_val
            mse_total += mse_val

    print(f'Validation - Epoch [{epoch+1}/{num_epochs}], SSIM: {ssim_total/len(val_loader):.4f}, PSNR: {psnr_total/len(val_loader):.4f}, MSE: {mse_total/len(val_loader):.4f}')
    
    # Visualize progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        visualize_progress(model, test_loader)


In [None]:
def visualize_one_random_image(model, dataloader):
    model.eval()
    with torch.no_grad():
        data_iter = iter(dataloader)
        data = next(data_iter)
        inputs = data.to(device)
        inputs_low_res = F.interpolate(inputs, scale_factor=0.5, mode='bicubic', align_corners=False)
        outputs = model(inputs_low_res)
        outputs_upsampled = F.interpolate(outputs, size=inputs.shape[2:], mode='bicubic', align_corners=False)

        # Calculate metrics
        ssim_val, psnr_val, mse_val = calculate_metrics(outputs_upsampled, inputs)

        # Display the input, low-res, and output images
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs[0].imshow(inputs[0].cpu().permute(1, 2, 0))
        axs[0].set_title('Original Image')
        axs[1].imshow(inputs_low_res[0].cpu().permute(1, 2, 0))
        axs[1].set_title('Low-Resolution Image')
        axs[2].imshow(outputs_upsampled[0].cpu().permute(1, 2, 0))
        axs[2].set_title(f'SRCNN Output\nPSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}, MSE: {mse_val:.6f}')
        plt.show()

visualize_one_random_image(model, DataLoader(test_dataset, batch_size=1, shuffle=True))

In [14]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Define a Residual Swin Transformer Block
class ResidualSwinTransformerBlock(tf.keras.Model):
    def __init__(self, embed_dim, num_heads):
        super(ResidualSwinTransformerBlock, self).__init__()
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.ffn = models.Sequential([
            layers.Dense(embed_dim * 4, activation='relu'),
            layers.Dense(embed_dim)
        ])

    def call(self, x):
        attn_output = self.attention(x, x)
        x = self.norm1(x + attn_output)  # Residual connection
        ffn_output = self.ffn(x)
        return self.norm2(x + ffn_output)  # Residual connection

# Define the SwinIR Model
class SwinIR(tf.keras.Model):
    def __init__(self, upscale_factor):
        super(SwinIR, self).__init__()
        self.conv1 = layers.Conv2D(64, kernel_size=9, padding='same', activation='relu')
        
        # Residual Swin Transformer Blocks
        self.rstb1 = ResidualSwinTransformerBlock(embed_dim=64, num_heads=4)
        self.rstb2 = ResidualSwinTransformerBlock(embed_dim=64, num_heads=4)

        # Final Convolution Layer
        self.conv2 = layers.Conv2D(3 * (upscale_factor ** 2), kernel_size=5, padding='same')
        
        # Pixel Shuffle for upscaling
        self.pixel_shuffle = layers.Lambda(lambda x: tf.nn.depth_to_space(x, upscale_factor))

    def call(self, x):
        x = self.conv1(x)
        
        # Reshape for attention blocks (batch_size, height*width, channels)
        batch_size = tf.shape(x)[0]
        height = tf.shape(x)[1]
        width = tf.shape(x)[2]
        x = tf.reshape(x, (batch_size, height * width, 64))  # Reshape for attention blocks
        
        x = self.rstb1(x)
        x = self.rstb2(x)

        # Reshape back to spatial dimensions
        x = tf.reshape(x, (batch_size, height, width, 64)) 
        
        x = self.conv2(x)
        
        return self.pixel_shuffle(x)

# Example usage of the model
if __name__ == "__main__":
    # Sample input image (batch size of 1 and low resolution)
    input_image = np.random.rand(1, 66, 66, 3).astype(np.float32)  # Low-resolution input

    # Create the model
    upscale_factor = 2
    model = SwinIR(upscale_factor)

    # Forward pass through the model
    output_image = model(tf.convert_to_tensor(input_image))

    print("Output shape:", output_image.shape)  # Should be (1, 132, 132, 3)

    # Compile the model for training (example only)
    model.compile(optimizer='adam', loss='mean_squared_error')

    # Example training data (dummy data for demonstration)
    train_images = np.random.rand(100, 66, 66, 3).astype(np.float32)  # Low-resolution images
    target_images = np.random.rand(100, 132, 132, 3).astype(np.float32)  # High-resolution images

    # Reduce batch size to address potential OOM error
    batch_size = 4  # Adjust this value based on your GPU memory limitations

    # Train the model (example only)
    model.fit(train_images, target_images, epochs=10, batch_size=batch_size)

Output shape: (1, 132, 132, 3)
Epoch 1/10


ResourceExhaustedError: Graph execution error:

Detected at node 'swin_ir_6/residual_swin_transformer_block_13/multi_head_attention_13/softmax_13/Softmax' defined at (most recent call last):
    File "C:\Users\IPS\.conda\envs\tf\lib\runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\IPS\.conda\envs\tf\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
      app.start()
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel\kernelapp.py", line 701, in start
      self.io_loop.start()
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\tornado\platform\asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\IPS\.conda\envs\tf\lib\asyncio\windows_events.py", line 321, in run_forever
      super().run_forever()
    File "C:\Users\IPS\.conda\envs\tf\lib\asyncio\base_events.py", line 601, in run_forever
      self._run_once()
    File "C:\Users\IPS\.conda\envs\tf\lib\asyncio\base_events.py", line 1905, in _run_once
      handle._run()
    File "C:\Users\IPS\.conda\envs\tf\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel\kernelbase.py", line 534, in dispatch_queue
      await self.process_one()
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel\kernelbase.py", line 523, in process_one
      await dispatch(*args)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel\kernelbase.py", line 429, in dispatch_shell
      await result
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel\kernelbase.py", line 767, in execute_request
      reply_content = await reply_content
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel\ipkernel.py", line 429, in do_execute
      res = shell.run_cell(
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\IPython\core\interactiveshell.py", line 3024, in run_cell
      result = self._run_cell(
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\IPython\core\interactiveshell.py", line 3079, in _run_cell
      result = runner(coro)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\IPython\core\interactiveshell.py", line 3284, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\IPython\core\interactiveshell.py", line 3466, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\IPython\core\interactiveshell.py", line 3526, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\IPS\AppData\Local\Temp\ipykernel_824\1762359009.py", line 83, in <module>
      model.fit(train_images, target_images, epochs=10, batch_size=batch_size)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\AppData\Local\Temp\ipykernel_824\1762359009.py", line 49, in call
      x = self.rstb2(x)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\AppData\Local\Temp\ipykernel_824\1022104731.py", line 18, in call
      attn_output = self.attention(x, x)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\layers\attention\multi_head_attention.py", line 596, in call
      attention_output, attention_scores = self._compute_attention(
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\layers\attention\multi_head_attention.py", line 527, in _compute_attention
      attention_scores = self._masked_softmax(
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\layers\attention\multi_head_attention.py", line 493, in _masked_softmax
      return self._softmax(attention_scores, attention_mask)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\layers\activation\softmax.py", line 103, in call
      return backend.softmax(inputs, axis=self.axis[0])
    File "C:\Users\IPS\.conda\envs\tf\lib\site-packages\keras\backend.py", line 5413, in softmax
      return tf.nn.softmax(x, axis=axis)
Node: 'swin_ir_6/residual_swin_transformer_block_13/multi_head_attention_13/softmax_13/Softmax'
OOM when allocating tensor with shape[4,4,4356,4356] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node swin_ir_6/residual_swin_transformer_block_13/multi_head_attention_13/softmax_13/Softmax}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_15107]