In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file shipped with the attached Kaggle input datasets.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import os
from sklearn.model_selection import train_test_split

# Paths to directories
sar_dir = "/kaggle/input/sentinel12-image-pairs-segregated-by-terrain/v_2/urban/s1"
optical_dir = "/kaggle/input/sentinel12-image-pairs-segregated-by-terrain/v_2/urban/s2"

# SAR and optical images are paired by POSITION after sorting: sar_images[i] goes with
# optical_images[i]. Filenames are not required to be identical across the two folders,
# but the two folders must contain the same number of files.
sar_images = sorted(os.listdir(sar_dir))
optical_images = sorted(os.listdir(optical_dir))
if len(sar_images) != len(optical_images):
    raise ValueError(
        f"Found {len(sar_images)} SAR images but {len(optical_images)} optical images; "
        "cannot pair them by position."
    )

# Split data into training, validation, and test sets (80% train, 10% validation, 10% test).
# Passing both lists to one train_test_split call keeps each (SAR, optical) pair together.
train_sar, temp_sar, train_optical, temp_optical = train_test_split(sar_images, optical_images, test_size=0.2, random_state=42)
val_sar, test_sar, val_optical, test_optical = train_test_split(temp_sar, temp_optical, test_size=0.5, random_state=42)

# Full (SAR path, optical path) tuples for each split. Each optical path comes from the
# optical split list, so these tuples agree with how SAROpticalDataset pairs files.
train_set = [(os.path.join(sar_dir, s), os.path.join(optical_dir, o)) for s, o in zip(train_sar, train_optical)]
val_set = [(os.path.join(sar_dir, s), os.path.join(optical_dir, o)) for s, o in zip(val_sar, val_optical)]
test_set = [(os.path.join(sar_dir, s), os.path.join(optical_dir, o)) for s, o in zip(test_sar, test_optical)]

# Print the number of images in each set
print(f"Training set: {len(train_set)} images")
print(f"Validation set: {len(val_set)} images")
print(f"Test set: {len(test_set)} images")


In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class SAROpticalDataset(Dataset):
    """Paired SAR / optical image dataset.

    Item ``idx`` pairs ``sar_images[idx]`` (under ``sar_dir``) with
    ``optical_images[idx]`` (under ``optical_dir``). Both images are loaded as RGB
    and passed through the same ``transform``.

    Args:
        sar_images: list of SAR image filenames.
        optical_images: list of optical image filenames, aligned by index with ``sar_images``.
        sar_dir: directory containing the SAR images.
        optical_dir: directory containing the optical images.
        transform: optional callable applied to each image. It is invoked separately
            on the SAR and optical image, so a random transform would not be applied
            identically to both.
    """

    def __init__(self, sar_images, optical_images, sar_dir, optical_dir, transform=None):
        self.sar_images = sar_images  # list of SAR image filenames
        self.optical_images = optical_images  # list of Optical image filenames
        self.sar_dir = sar_dir  # directory of SAR images
        self.optical_dir = optical_dir  # directory of Optical images
        self.transform = transform

    def __len__(self):
        # Only indices present in both lists form a complete pair.
        return min(len(self.sar_images), len(self.optical_images))

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path``, return an RGB copy, and close the file handle."""
        # convert() returns a new, fully loaded image, so closing the source file is safe.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return ``(sar_image, optical_image)`` for index ``idx``, transformed if configured."""
        sar_image = self._load_rgb(os.path.join(self.sar_dir, self.sar_images[idx]))
        optical_image = self._load_rgb(os.path.join(self.optical_dir, self.optical_images[idx]))

        if self.transform:
            sar_image = self.transform(sar_image)
            optical_image = self.transform(optical_image)

        return sar_image, optical_image

# Define transformations
# Shared preprocessing for SAR and optical images: resize to 256x256, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)


In [None]:
# Create instances of the dataset
# One paired dataset per split; all splits share the same directories and preprocessing.
train_dataset = SAROpticalDataset(
    sar_images=train_sar, optical_images=train_optical,
    sar_dir=sar_dir, optical_dir=optical_dir, transform=transform,
)
val_dataset = SAROpticalDataset(
    sar_images=val_sar, optical_images=val_optical,
    sar_dir=sar_dir, optical_dir=optical_dir, transform=transform,
)
test_dataset = SAROpticalDataset(
    sar_images=test_sar, optical_images=test_optical,
    sar_dir=sar_dir, optical_dir=optical_dir, transform=transform,
)

from torch.utils.data import DataLoader

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
import torch
import torch.nn as nn

class UNetGenerator(nn.Module):
    """U-Net generator mapping a 3-channel image to a 3-channel image of the same size.

    Five stride-2 contracting blocks each halve the spatial size, so H and W should be
    multiples of 32 for the skip connections to line up. Four transposed-conv blocks
    upsample back to H/2 x W/2, concatenating the matching encoder output at each stage.
    The final 128-channel map is bilinearly resized to the input's H x W and projected to
    3 channels by a 1x1 conv. There is no output activation, so values are unbounded.

    NOTE: the submodule layout (including the ReLU-before-BatchNorm order inside each
    block) determines the ``state_dict`` keys. Do not change it, or saved checkpoints
    will stop loading.
    """

    def __init__(self):
        super(UNetGenerator, self).__init__()

        # Encoder (sizes shown for a 256x256 input)
        self.encoder1 = self.contracting_block(3, 64)    # Input: 256x256, Output: 128x128
        self.encoder2 = self.contracting_block(64, 128)   # Output: 64x64
        self.encoder3 = self.contracting_block(128, 256)  # Output: 32x32
        self.encoder4 = self.contracting_block(256, 512)  # Output: 16x16
        self.bottleneck = self.contracting_block(512, 1024)  # Output: 8x8

        # Decoder: in_channels of decoders 3..1 are doubled by the preceding skip concatenation.
        self.decoder4 = self.expansive_block(1024, 512)
        self.decoder3 = self.expansive_block(1024, 256)
        self.decoder2 = self.expansive_block(512, 128)
        self.decoder1 = self.expansive_block(256, 64)

        # 64 decoder channels + 64 channels from the encoder1 skip connection = 128.
        self.final_conv = nn.Conv2d(128, 3, kernel_size=1)

    def contracting_block(self, in_channels, out_channels):
        """Stride-2 4x4 conv (halves H and W), then a 3x3 conv, each followed by ReLU + BatchNorm."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),  # Downsampling
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def expansive_block(self, in_channels, out_channels):
        """Stride-2 4x4 transposed conv (doubles H and W), then a 3x3 conv, each followed by ReLU + BatchNorm."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),  # Upsampling
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        """Map ``x`` of shape (N, 3, H, W) to an output of shape (N, 3, H, W)."""
        # Encoder (sizes shown for a 256x256 input)
        enc1 = self.encoder1(x)  # 128x128
        enc2 = self.encoder2(enc1)  # 64x64
        enc3 = self.encoder3(enc2)  # 32x32
        enc4 = self.encoder4(enc3)  # 16x16
        bottleneck = self.bottleneck(enc4)  # 8x8

        # Decoder with skip connections (channel concatenation)
        dec4 = torch.cat((self.decoder4(bottleneck), enc4), dim=1)  # 16x16
        dec3 = torch.cat((self.decoder3(dec4), enc3), dim=1)  # 32x32
        dec2 = torch.cat((self.decoder2(dec3), enc2), dim=1)  # 64x64
        dec1 = torch.cat((self.decoder1(dec2), enc1), dim=1)  # 128x128

        # Resize the H/2 x W/2 map back to the input's own spatial size. This used to be a
        # hard-coded 256x256, which was wrong for any other input size.
        output = nn.functional.interpolate(
            dec1, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=False
        )
        return self.final_conv(output)

# Example usage
# Smoke test: a 256x256 RGB input should produce a 3-channel 256x256 output.
model = UNetGenerator()
input_tensor = torch.randn(1, 3, 256, 256)  # Batch size of 1, 3 channels, 256x256 image
with torch.no_grad():  # shape check only; no autograd graph is needed
    output_tensor = model(input_tensor)
print("Output tensor shape:", output_tensor.shape)  # Should print: Output tensor shape: torch.Size([1, 3, 256, 256])
assert output_tensor.shape == (1, 3, 256, 256), f"unexpected output shape {tuple(output_tensor.shape)}"


In [None]:
# import torch
# import torchvision.transforms as transforms
# import matplotlib.pyplot as plt
# from PIL import Image
# import os

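# # NOTE: superseded by the epoch-110 inference cell below; this epoch-80 version is kept for reference only.
# # Load the trained generator model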
# checkpoint = torch.load("/kaggle/input/checkpoint_epoch_80_.pth/pytorch/default/1/checkpoint_epoch_80 (1).pth")  # Load the entire checkpoint
# generator = UNetGenerator().to(device)  # Initialize the model

# # Load only the generator's state_dict
# generator.load_state_dict(checkpoint['generator_state_dict'])  # Load the state_dict for the generator
# generator.eval()

# # Function to perform inference and display SAR, original optical, and generated optical images
# def infer_sar_to_optical(test_loader, num_images=5):
#     generator.eval()  # Set the generator to evaluation mode
#     with torch.no_grad():  # Disable gradient computation for inference
#         for i, (sar_images, optical_images) in enumerate(test_loader):  # Iterate through the test DataLoader
#             sar_images = sar_images.to(device)  # Move SAR images to the appropriate device
            
#             # Generate optical images from SAR images
#             generated_images = generator(sar_images)  # Forward pass through the generator
            
#             # Post-process the output
#             generated_images = generated_images.squeeze().cpu()  # Remove batch dimension and move to CPU
#             generated_images = generated_images.clamp(0, 1)  # Clamp values to the valid range [0, 1]

#             # Display SAR, original optical, and generated optical images side by side
#             for sar_img, orig_optical_img, gen_img in zip(sar_images.cpu(), optical_images.cpu(), generated_images):
#                 sar_img_np = sar_img.squeeze().numpy().transpose(1, 2, 0)  # Convert SAR image to HWC format
#                 orig_optical_img_np = orig_optical_img.squeeze().numpy().transpose(1, 2, 0)  # Convert original optical image to HWC format
#                 gen_img_np = gen_img.numpy().transpose(1, 2, 0)  # Convert generated image to HWC format

#                 # Create a figure to display images
#                 plt.figure(figsize=(15, 5))  # Set the figure size

#                 plt.subplot(1, 3, 1)  # First subplot for SAR image
#                 plt.imshow(sar_img_np)  # Display SAR image
#                 plt.title('Original SAR Image')
#                 plt.axis('off')  # Hide axis

#                 plt.subplot(1, 3, 2)  # Second subplot for original optical image
#                 plt.imshow(orig_optical_img_np)  # Display original optical image
#                 plt.title('Original Optical Image')
#                 plt.axis('off')  # Hide axis

#                 plt.subplot(1, 3, 3)  # Third subplot for generated optical image
#                 plt.imshow(gen_img_np)  # Display generated optical image
#                 plt.title('Generated Optical Image')
#                 plt.axis('off')  # Hide axis

#                 plt.show()  # Show the figure

#             if i + 1 >= num_images:  # Stop after processing the desired number of images
#                 break

# # Example usage (assuming test_loader is defined)
# infer_sar_to_optical(test_loader, num_images=5)


In [None]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import os

# Load the trained generator model
# Load the entire checkpoint. map_location remaps stored tensors onto the current device,
# so a checkpoint saved from a GPU run also loads in a CPU-only session.
checkpoint = torch.load(
    "/kaggle/input/check_point110.pth/pytorch/default/1/checkpoint_epoch_110.pth",
    map_location=device,
)
generator = UNetGenerator().to(device)  # Initialize the model

# Load only the generator's state_dict (the checkpoint stores it under 'generator_state_dict')
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

# Function to perform inference and display SAR, original optical, and generated optical images
def infer_sar_to_optical(test_loader, num_images=5):
    """Run the generator on SAR batches and plot SAR / real optical / generated optical side by side.

    Uses the module-level ``generator`` and ``device``.

    Args:
        test_loader: iterable of ``(sar_images, optical_images)`` batches, presumably
            (N, 3, H, W) tensors in [0, 1] as produced by the DataLoaders above (TODO
            confirm if other loaders are passed).
        num_images: maximum number of image triplets (figures) to display. It used to
            count batches, so ``num_images=5`` with ``batch_size=16`` showed up to 80 figures.
    """
    if num_images <= 0:
        return

    generator.eval()  # Set the generator to evaluation mode
    shown = 0
    with torch.no_grad():  # Disable gradient computation for inference
        for sar_images, optical_images in test_loader:
            generated_images = generator(sar_images.to(device))

            # Keep the batch dimension. A bare .squeeze() would also drop it when a batch
            # holds a single image, and the zip below would then iterate over channels.
            generated_images = generated_images.cpu().clamp(0, 1)  # Clamp to the displayable range [0, 1]

            for sar_img, orig_optical_img, gen_img in zip(sar_images.cpu(), optical_images.cpu(), generated_images):
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                panels = (
                    (sar_img, 'Original SAR Image'),
                    (orig_optical_img, 'Original Optical Image'),
                    (gen_img, 'Generated Optical Image'),
                )
                for ax, (img, title) in zip(axes, panels):
                    ax.imshow(img.numpy().transpose(1, 2, 0))  # CHW -> HWC for matplotlib
                    ax.set_title(title)
                    ax.axis('off')
                plt.show()
                plt.close(fig)  # avoid accumulating open figures

                shown += 1
                if shown >= num_images:  # Stop after the requested number of images
                    return

# Visualise a handful of test-set predictions (test_loader is built in an earlier cell).
infer_sar_to_optical(test_loader=test_loader, num_images=5)
