<a href="https://colab.research.google.com/github/debi201326/SAR-Image-Colorization/blob/main/SAR_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

In [4]:
import zipfile

zip_file_path = 'sar_color.zip'
extract_to_folder = '/content'

# Ensure the destination directory exists, then unpack every archive member into it.
os.makedirs(extract_to_folder, exist_ok=True)
with zipfile.ZipFile(zip_file_path, mode='r') as archive:
    archive.extractall(path=extract_to_folder)

In [3]:
# Device Configuration
# Run on CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [5]:
# Custom Dataset Class for SAR and Optical Images
class SAROpticalDataset(Dataset):
    """Paired SAR / optical image dataset.

    The i-th SAR file is paired with the i-th optical file after both
    directory listings are sorted by filename. ``os.listdir`` returns entries
    in arbitrary, filesystem-dependent order, so without sorting the pairing
    would be effectively random.

    NOTE(review): sorting assumes corresponding files land at the same sorted
    position in both directories (e.g. names that differ only by a shared
    sensor tag) -- confirm against the dataset's naming scheme.

    Args:
        sar_dir: Directory containing the SAR images (loaded as mode "L").
        optical_dir: Directory containing the optical images (loaded as "RGB").
        transform: Optional callable applied to both images of every pair.

    Raises:
        ValueError: If the two directories hold a different number of
            entries, because index-based pairing would then be misaligned.
    """

    def __init__(self, sar_dir, optical_dir, transform=None):
        self.sar_dir = sar_dir
        self.optical_dir = optical_dir
        # Sort so index-based pairing is deterministic across runs/filesystems.
        self.sar_images = sorted(os.listdir(sar_dir))
        self.optical_images = sorted(os.listdir(optical_dir))
        if len(self.sar_images) != len(self.optical_images):
            raise ValueError(
                f"SAR directory {sar_dir!r} has {len(self.sar_images)} entries but "
                f"optical directory {optical_dir!r} has {len(self.optical_images)}; "
                "cannot pair images by index."
            )
        self.transform = transform

    def __len__(self):
        """Return the number of SAR/optical pairs."""
        return len(self.sar_images)

    def __getitem__(self, idx):
        """Return the ``(sar_image, optical_image)`` pair at position *idx*."""
        # Load SAR image. The context manager closes the file handle right
        # away; convert() returns a new, fully loaded image.
        sar_img_path = os.path.join(self.sar_dir, self.sar_images[idx])
        with Image.open(sar_img_path) as img:
            sar_image = img.convert("L")

        # Load corresponding Optical image
        optical_img_path = os.path.join(self.optical_dir, self.optical_images[idx])
        with Image.open(optical_img_path) as img:
            optical_image = img.convert("RGB")

        # Apply the same transformations to both images
        if self.transform:
            sar_image = self.transform(sar_image)
            optical_image = self.transform(optical_image)

        return sar_image, optical_image

# Image Transformations: resize to 256x256, convert to a [0, 1] tensor, then
# map to [-1, 1] via (x - 0.5) / 0.5 so inputs share the range of the
# generators' Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Dataset Directories (relative to the current working directory)
sar_dir, optical_dir = "sar_color/s1", "sar_color/s2"

# Create Dataset and DataLoader: one shuffled SAR/optical pair per batch.
dataset = SAROpticalDataset(sar_dir=sar_dir, optical_dir=optical_dir, transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)


In [6]:
# Define the Generator Network
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: 7x7 conv stem -> two stride-2 downsampling convs
    (64 -> 128 -> 256 channels) -> nine 256-channel ``ResidualBlock``s ->
    two stride-2 transposed convs (256 -> 128 -> 64 channels) -> 7x7 conv
    to ``output_nc`` channels -> Tanh, so outputs lie in [-1, 1]. For a
    square input whose side is divisible by 4, the output has the same
    spatial size as the input.

    Args:
        input_nc: Number of channels of the input image.
        output_nc: Number of channels of the generated image.

    Note:
        All layers are stored, in order, in one ``nn.Sequential`` named
        ``self.model``; keep that order so saved state_dicts still load.
    """

    def __init__(self, input_nc, output_nc):
        super().__init__()
        layers = []

        # Initial Convolution Block (keeps the spatial size)
        layers += [
            nn.Conv2d(input_nc, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves height and width
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Residual Blocks. ResidualBlock is defined in a later cell; it only
        # needs to exist by the time a Generator is instantiated.
        layers.extend(ResidualBlock(256) for _ in range(9))

        # Upsampling: each stage doubles height and width
        for in_ch, out_ch in ((256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Output Layer: project to output_nc channels and squash into [-1, 1]
        layers += [
            nn.Conv2d(64, output_nc, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Translate a batch of images into the target domain."""
        return self.model(x)


In [7]:
# Residual Block
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: ``x + F(x)``.

    ``F`` is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm, with
    ``features`` input and output channels and padding 1, so ``F(x)`` has
    the same shape as ``x``.

    Args:
        features: Number of channels of the input (and output).
    """

    def __init__(self, features):
        super().__init__()
        # Build the layers in forward order; kept under ``self.block`` so
        # state_dict keys stay "block.<index>.*".
        first_conv = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1)
        first_norm = nn.InstanceNorm2d(features)
        activation = nn.ReLU(inplace=True)
        second_conv = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1)
        second_norm = nn.InstanceNorm2d(features)
        self.block = nn.Sequential(first_conv, first_norm, activation, second_conv, second_norm)

    def forward(self, x):
        """Add the learned residual to the identity path."""
        residual = self.block(x)
        return x + residual

In [8]:
# Define the Discriminator Network
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a map of patch scores.

    Four 4x4 stride-2 convolutions (64 -> 128 -> 256 -> 512 channels), each
    followed by LeakyReLU(0.2) and, except the first, preceded by
    InstanceNorm on its output. A final 4x4 stride-1 conv reduces to one
    channel, and a Sigmoid maps every spatial position to a probability in
    (0, 1). The result is a spatial map, not a single scalar.

    Args:
        input_nc: Number of channels of the images being scored.
    """

    def __init__(self, input_nc):
        super().__init__()
        # First stage has no normalization.
        layers = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling stages: conv -> InstanceNorm -> LeakyReLU.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            layers.append(nn.InstanceNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Per-patch real/fake probability head.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of images; higher values mean "looks real"."""
        return self.model(x)

In [9]:
# Initialize Generators and Discriminators
# G1: 1-channel SAR -> 3-channel optical; G2: optical -> SAR.
# D1 judges 3-channel (optical) images, D2 judges 1-channel (SAR) images.
G1 = Generator(input_nc=1, output_nc=3)
G2 = Generator(input_nc=3, output_nc=1)
D1 = Discriminator(input_nc=3)
D2 = Discriminator(input_nc=1)
G1, G2, D1, D2 = (net.to(device) for net in (G1, G2, D1, D2))

In [10]:
# Define Loss Functions and Optimizers
criterion_GAN = nn.BCELoss()        # adversarial loss on the discriminators' Sigmoid outputs
criterion_cycle = nn.L1Loss()       # cycle-consistency reconstruction loss
criterion_identity = nn.L1Loss()    # defined here but not used by the training loop

# Every network gets its own Adam optimizer with lr=2e-4, betas=(0.5, 0.999).
adam_hparams = {"lr": 0.0002, "betas": (0.5, 0.999)}
optimizer_G1, optimizer_G2, optimizer_D1, optimizer_D2 = (
    optim.Adam(net.parameters(), **adam_hparams) for net in (G1, G2, D1, D2)
)

In [13]:
# Training Loop
# Each step first updates both generators, then both discriminators.
# The discriminator optimizers are zeroed *after* the generator step: the
# generators' adversarial loss backpropagates through D1/D2, and those
# gradients must not leak into the discriminator update.
num_epochs = 100
lambda_cycle = 10.0  # weight of the L1 cycle-consistency terms
for epoch in range(num_epochs):
    for i, (sar_images, optical_images) in enumerate(dataloader):

        sar_images = sar_images.to(device)
        optical_images = optical_images.to(device)

        # ---------------- Generator update ----------------
        optimizer_G1.zero_grad()
        optimizer_G2.zero_grad()

        fake_optical = G1(sar_images)            # SAR -> optical
        reconstructed_sar = G2(fake_optical)     # SAR -> optical -> SAR
        fake_sar = G2(optical_images)            # optical -> SAR
        reconstructed_optical = G1(fake_sar)     # optical -> SAR -> optical

        # Adversarial terms: generators want their fakes labelled real (1).
        # Evaluate each discriminator once and reuse the prediction for the target.
        pred_fake_optical = D1(fake_optical)
        pred_fake_sar = D2(fake_sar)
        loss_GAN_G1 = criterion_GAN(pred_fake_optical, torch.ones_like(pred_fake_optical))
        loss_GAN_G2 = criterion_GAN(pred_fake_sar, torch.ones_like(pred_fake_sar))
        loss_cycle_SAR = criterion_cycle(reconstructed_sar, sar_images) * lambda_cycle
        loss_cycle_Optical = criterion_cycle(reconstructed_optical, optical_images) * lambda_cycle

        # Total generator loss
        loss_G1 = loss_GAN_G1 + loss_cycle_SAR
        loss_G2 = loss_GAN_G2 + loss_cycle_Optical

        # G1 and G2 both appear in both loss graphs; one backward on the sum
        # accumulates the same gradients as two separate backward calls.
        (loss_G1 + loss_G2).backward()
        optimizer_G1.step()
        optimizer_G2.step()

        # -------------- Discriminator update --------------
        # Discard the gradients the generator loss deposited on D1/D2.
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()

        # Fakes are detached so the discriminator loss does not reach the generators.
        pred_real_optical = D1(optical_images)
        pred_fake_optical = D1(fake_optical.detach())
        pred_real_sar = D2(sar_images)
        pred_fake_sar = D2(fake_sar.detach())

        loss_D1_real = criterion_GAN(pred_real_optical, torch.ones_like(pred_real_optical))
        loss_D1_fake = criterion_GAN(pred_fake_optical, torch.zeros_like(pred_fake_optical))
        loss_D2_real = criterion_GAN(pred_real_sar, torch.ones_like(pred_real_sar))
        loss_D2_fake = criterion_GAN(pred_fake_sar, torch.zeros_like(pred_fake_sar))

        loss_D1 = (loss_D1_real + loss_D1_fake) * 0.5
        loss_D2 = (loss_D2_real + loss_D2_fake) * 0.5

        # D1 and D2 graphs are independent; summing is equivalent to two backward calls.
        (loss_D1 + loss_D2).backward()
        optimizer_D1.step()
        optimizer_D2.step()


        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
              f"Loss_G1: {loss_G1.item():.4f}, Loss_G2: {loss_G2.item():.4f}, "
              f"Loss_D1: {loss_D1.item():.4f}, Loss_D2: {loss_D2.item():.4f}")

Streaming output truncated to the last 5000 lines.
Epoch [51/100], Step [1/100], Loss_G1: 1.3874, Loss_G2: 1.5293, Loss_D1: 0.6177, Loss_D2: 0.6379
Epoch [51/100], Step [2/100], Loss_G1: 1.2608, Loss_G2: 2.4320, Loss_D1: 0.5716, Loss_D2: 0.5535
Epoch [51/100], Step [3/100], Loss_G1: 1.3691, Loss_G2: 1.4665, Loss_D1: 0.6035, Loss_D2: 0.8325
Epoch [51/100], Step [4/100], Loss_G1: 1.2906, Loss_G2: 1.6934, Loss_D1: 0.5787, Loss_D2: 0.4705
Epoch [51/100], Step [5/100], Loss_G1: 0.9976, Loss_G2: 1.5625, Loss_D1: 0.6795, Loss_D2: 0.6749
Epoch [51/100], Step [6/100], Loss_G1: 1.0387, Loss_G2: 1.1986, Loss_D1: 0.6615, Loss_D2: 0.7642
Epoch [51/100], Step [7/100], Loss_G1: 1.6822, Loss_G2: 1.2954, Loss_D1: 0.5788, Loss_D2: 0.5937
Epoch [51/100], Step [8/100], Loss_G1: 1.0478, Loss_G2: 1.6910, Loss_D1: 0.9674, Loss_D2: 0.5839
Epoch [51/100], Step [9/100], Loss_G1: 1.6253, Loss_G2: 1.2080, Loss_D1: 0.4937, Loss_D2: 0.6112
Epoch [51/100], Step [10/100], Loss_G1: 1.0750, Loss_G2: 1.233

In [14]:
# Persist only the generators' weights (state_dicts), one file per network.
for _net, _ckpt_path in ((G1, "generator_G1.pth"), (G2, "generator_G2.pth")):
    torch.save(_net.state_dict(), _ckpt_path)

In [None]:
# Load the models
# Rebuild fresh generators and restore their saved weights. map_location
# remaps the stored tensors onto the current device, so checkpoints written
# on a GPU runtime also load on a CPU-only runtime (and vice versa).
G1 = Generator(input_nc=1, output_nc=3).to(device)
G2 = Generator(input_nc=3, output_nc=1).to(device)

# NOTE: torch.load unpickles the file -- only load checkpoints you trust
# (recent PyTorch versions also accept weights_only=True to restrict this).
G1.load_state_dict(torch.load("generator_G1.pth", map_location=device))
G2.load_state_dict(torch.load("generator_G2.pth", map_location=device))


# Switch both generators to evaluation mode for inference.
G1.eval()
G2.eval()


In [16]:
from PIL import Image
import torchvision.transforms as transforms

# Same preprocessing as the training data: 256x256, tensor, scaled to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load and preprocess the grayscale image. The context manager closes the
# file handle immediately; convert() returns a new, fully loaded image.
test_image_path = '/content/ROIs1970_fall_s1_50_p305.png'
with Image.open(test_image_path) as img:
    sar_image = img.convert("L")
# ToTensor gives (1, 256, 256) for a mode-"L" image; unsqueeze adds the
# batch dimension -> (1, 1, 256, 256).
sar_image = transform(sar_image).unsqueeze(0).to(device)


In [17]:
# Translate the SAR input to an optical image without recording an autograd graph.
with torch.set_grad_enabled(False):
    generated_optical_image = G1(sar_image)

In [18]:
from torchvision.utils import save_image

# Undo the [-1, 1] normalization (x * 0.5 + 0.5 -> [0, 1]) before writing to disk.
generated_optical_image = generated_optical_image.mul(0.5).add(0.5)
save_image(generated_optical_image, "generated_optical_image.png")
