In [None]:
import os
import torch
import rasterio
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
import matplotlib.pyplot as plt
from transformers import Swin2SRImageProcessor, Swin2SRForImageSuperResolution
from torch.optim import Adam
from tqdm import tqdm
from PIL import Image
import numpy as np

# Spatial sizes (rows, cols) handed to skimage's resize:
# the low-resolution model input and the high-resolution target.
lr_size = (60, 60)
hr_size = (120, 120)

class SRDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution.

    Each image is opened with PIL, converted to RGB and resized twice with
    scikit-image: once to ``lr_size`` and once to ``hr_size``. Both versions
    are divided by 255 and returned as channels-first float32 tensors.

    Args:
        data_paths: Sequence of image file paths.
        lr_size: ``(rows, cols)`` of the low-resolution sample.
        hr_size: ``(rows, cols)`` of the high-resolution sample.
    """

    def __init__(self, data_paths, lr_size=(64, 64), hr_size=(128, 128)):
        self.data_paths = data_paths
        self.lr_size = lr_size
        self.hr_size = hr_size

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.data_paths)

    def _resize_to_tensor(self, img, size):
        """Resize an H x W x 3 array to ``size``; return a [3, rows, cols] float32 tensor."""
        # preserve_range=True keeps the 0-255 scale, so the division below
        # is what brings the values down to the 0-1 scale.
        resized = resize(img, size, anti_aliasing=True, preserve_range=True)
        return torch.tensor(resized / 255.0, dtype=torch.float32).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor, image_path)`` for sample ``idx``.

        ``lr_tensor`` has shape ``[3, *lr_size]`` and ``hr_tensor`` has shape
        ``[3, *hr_size]``.
        """
        image_path = self.data_paths[idx]

        # The context manager closes the underlying file handle as soon as the
        # pixels have been copied, instead of leaving it to garbage collection.
        with Image.open(image_path) as source:
            img = np.array(source.convert('RGB'))

        lr_tensor = self._resize_to_tensor(img, self.lr_size)
        hr_tensor = self._resize_to_tensor(img, self.hr_size)

        return lr_tensor, hr_tensor, image_path

def filter_120x120_images(datasetPath, target_size=(120, 120)):
    """Recursively collect the raster files under ``datasetPath`` that have a given size.

    Every file found by ``os.walk`` is opened with rasterio. Files whose
    width and height match ``target_size`` are kept; all others are reported
    on stdout and skipped.

    Args:
        datasetPath: Root directory to walk.
        target_size: ``(width, height)`` in pixels a raster must have to be
            kept. Defaults to ``(120, 120)``.

    Returns:
        list[str]: Paths of the matching files, in ``os.walk`` order.
    """
    target_width, target_height = target_size
    imagePaths = []
    for root, _dirs, files in os.walk(datasetPath):
        for file in files:
            image_path = os.path.join(root, file)
            try:
                with rasterio.open(image_path) as src:
                    width, height = src.width, src.height
            except rasterio.errors.RasterioIOError as exc:
                # A stray non-raster file (README, .DS_Store, sidecar metadata)
                # must not abort the scan of the whole dataset.
                print(f"Skipping unreadable file {file}: {exc}")
                continue

            if width == target_width and height == target_height:
                imagePaths.append(image_path)
            else:
                print(f"Skipping image {file} with size {width}x{height}")

    return imagePaths

# Gather the usable image tiles, then wrap them in a shuffled DataLoader
# that yields batches of two (lr_tensor, hr_tensor, path) samples.
datasetPath = r"SampleDataset"
imagePaths = filter_120x120_images(datasetPath)

dataset = SRDataset(data_paths=imagePaths, lr_size=lr_size, hr_size=hr_size)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)


In [None]:
import matplotlib.pyplot as plt


# Draw one batch and convert it from channels-first tensors to
# channels-last NumPy arrays, which is the layout imshow expects.
lr_img, hr_img, pths = next(iter(dataloader))
lr_img, hr_img = (batch.permute(0, 2, 3, 1).numpy() for batch in (lr_img, hr_img))

# First sample of the batch: low-resolution input beside its high-resolution target.
plt.figure(figsize=(10, 5))

for position, (image, title) in enumerate(
    [(lr_img[0], "Low-res Image"), (hr_img[0], "High-res Image")], start=1
):
    plt.subplot(1, 2, position)
    plt.imshow(image)
    plt.title(title)
    plt.axis('off')

plt.show()


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size equals its stride (``patch_size``)
    projects every patch to an ``embed_dim``-dimensional vector.
    """

    def __init__(self, img_size=60, patch_size=6, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn a batch of images into a ``[B, num_patches, embed_dim]`` token sequence."""
        # [B, embed_dim, patches_per_col, patches_per_row]
        feature_map = self.proj(x)
        # Collapse the two spatial axes into one, then put tokens before channels.
        tokens = torch.flatten(feature_map, start_dim=2)
        return tokens.transpose(1, 2)

class SwinTransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection. Note that, despite
    the class name, no windowed or shifted attention is implemented here:
    every token attends to every other token of the same sample.

    Args:
        embed_dim: Size of each token embedding.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        drop: Dropout probability passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True is essential: the block receives tokens laid out as
        # [B, num_tokens, embed_dim]. With the default (sequence-first) layout
        # the batch axis would be interpreted as the sequence axis, so
        # attention would mix different images of a batch instead of the
        # patches of one image. The parameter set is unchanged, so existing
        # state_dicts still load, but weights trained under the old layout
        # should be retrained.
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=drop, batch_first=True
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``[B, num_tokens, embed_dim]``; shape is preserved."""
        # Normalise once and reuse the result as query, key and value.
        normed = self.norm1(x)
        # The attention weights are never used, so skip returning them.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class Swin2SR(nn.Module):
    """Transformer-based single-image super-resolution network.

    Pipeline: patch embedding -> learned positional embedding (+ dropout) ->
    ``depth`` transformer blocks -> LayerNorm -> per-token linear projection
    to an upscaled patch -> patches tiled back into an image -> two-layer
    convolutional refinement.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each input patch.
        in_channels: Number of image channels (input and output).
        embed_dim: Token embedding size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        scale_factor: Upscaling factor; each input patch of side
            ``patch_size`` becomes an output patch of side
            ``patch_size * scale_factor``.
    """

    def __init__(self, img_size=60, patch_size=6, in_channels=3, embed_dim=768, depth=12, num_heads=12, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        # One learned embedding per patch position, initialised to zero.
        self.pos_embed = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2, embed_dim))
        self.pos_drop = nn.Dropout(p=0.1)

        self.blocks = nn.ModuleList([SwinTransformerBlock(embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        self.upscaled_patch_size = patch_size * scale_factor
        # Each token is expanded to the pixels of one upscaled patch.
        self.upsample = nn.Sequential(
            nn.Linear(embed_dim, (patch_size * scale_factor) ** 2 * in_channels),
            nn.GELU()
        )

        self.refine = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(64, in_channels, 3, 1, 1)
        )

    def forward(self, x):
        """Super-resolve a batch of images.

        Args:
            x: Tensor of shape ``[B, in_channels, img_size, img_size]``.

        Returns:
            Tensor of shape ``[B, in_channels, S, S]`` with
            ``S = (img_size // patch_size) * patch_size * scale_factor``.

        Raises:
            ValueError: If the input does not produce the number of patches
                the model was built for.
        """
        B = x.shape[0]
        grid = self.img_size // self.patch_size
        num_patches = grid ** 2

        x = self.patch_embed(x)
        if x.shape[1] != num_patches:
            # Without this check a single-token input would silently broadcast
            # against pos_embed and produce a meaningless result.
            raise ValueError(
                f"Input produced {x.shape[1]} patches but the model was built "
                f"for {num_patches} (img_size={self.img_size}, patch_size={self.patch_size})"
            )

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        # [B, num_patches, in_channels * upscaled_patch_size ** 2]
        x = self.upsample(x)

        # Tile the patches back into an image in one shot. Token index
        # row * grid + col holds a [C, p, p] patch that belongs at block
        # (row, col) of the output, so split the token axis into
        # (row, col), move channels to the front and interleave the
        # in-patch rows/cols with the grid rows/cols.
        p = self.upscaled_patch_size
        x = x.view(B, grid, grid, self.in_channels, p, p)
        x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, self.in_channels, grid * p, grid * p)

        x = self.refine(x)

        return x

# Smoke test: build the model on the available device, feed it one random
# 60x60 three-channel image and print the shape of what comes out.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
swin2sr_model = Swin2SR(img_size=60, patch_size=6, in_channels=3, embed_dim=768, depth=12, num_heads=12, scale_factor=2)
swin2sr_model = swin2sr_model.to(device)

x = torch.randn(1, 3, 60, 60).to(device)
output = swin2sr_model(x)
print(output.shape)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from tqdm import tqdm

def psnr_loss(output, target, max_pixel_value=1.0, eps=1e-10):
    """Negative peak signal-to-noise ratio (in dB), usable as a loss.

    Args:
        output: Predicted tensor.
        target: Reference tensor, broadcastable against ``output``.
        max_pixel_value: Peak signal value used in the PSNR formula.
        eps: Lower bound applied to the mean squared error before the
            logarithm. It only takes effect when the MSE is below ``eps``.

    Returns:
        Scalar tensor ``-PSNR``; minimising it maximises PSNR.
    """
    mse = torch.mean((output - target) ** 2)
    # A perfect prediction gives mse == 0, which would turn the loss into
    # -inf and the gradients into NaN. Clamping keeps both finite.
    mse = torch.clamp(mse, min=eps)
    psnr = 20 * torch.log10(max_pixel_value / torch.sqrt(mse))
    return -psnr  # Return negative PSNR to minimize
from torchmetrics.functional import structural_similarity_index_measure as ssim

class PSNR_SSIM_Loss(nn.Module):
    """Weighted combination of negative PSNR and ``1 - SSIM``.

    ``loss = alpha * (-PSNR) + (1 - alpha) * (1 - SSIM)``

    Args:
        alpha: Weight of the PSNR term; the SSIM term gets ``1 - alpha``.

    NOTE(review): PSNR is expressed in dB while ``1 - SSIM`` is a small
    number, so the PSNR term presumably dominates for mid-range ``alpha``;
    confirm the weighting is intended.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, output, target):
        """Return the scalar combined loss for a batch of predictions and targets."""
        psnr_value = psnr_loss(output, target)
        # Pin the SSIM data range to 1.0 so that it agrees with the peak value
        # psnr_loss assumes by default. Left unset, torchmetrics derives the
        # range from the tensors themselves, which would make the loss scale
        # depend on the content of each batch.
        ssim_value = 1 - ssim(output, target, data_range=1.0)
        return self.alpha * psnr_value + (1 - self.alpha) * ssim_value
# Training setup: model, combined PSNR/SSIM loss and Adam optimiser.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_model = Swin2SR(img_size=60, patch_size=6, in_channels=3, embed_dim=768, depth=12, num_heads=12).to(device)
criterion = PSNR_SSIM_Loss(alpha=0.7)
optimizer = optim.Adam(vit_model.parameters(), lr=1e-4)

num_epochs = 5

for epoch in range(num_epochs):
    vit_model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx, (lr_tensor, hr_tensor, pths) in enumerate(progress_bar):
        lr_tensor = lr_tensor.to(device)
        hr_tensor = hr_tensor.to(device)

        optimizer.zero_grad()
        outputs = vit_model(lr_tensor)

        # Fail loudly if the network does not produce target-sized images.
        if outputs.shape != hr_tensor.shape:
            print(f"Output shape: {outputs.shape}, HR shape: {hr_tensor.shape}")
            raise ValueError("Output shape does not match HR shape")

        loss = criterion(outputs, hr_tensor)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running sum and the progress bar.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix({"Loss": f"{batch_loss:.4f}"})

    # Mean of the per-batch losses over the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(dataloader):.4f}")


In [10]:
# Persist the model's state_dict (its parameters and buffers) to disk.
torch.save(obj=vit_model.state_dict(), f="vit_model2.pth")

In [None]:
# load
# Rebuild the architecture with the same hyper-parameters, then restore the weights.
vit_model = Swin2SR(img_size=60, patch_size=6, in_channels=3, embed_dim=768, depth=12, num_heads=12)

# map_location="cpu" lets a checkpoint that was written from a GPU be restored
# on a machine without CUDA; the freshly built model lives on the CPU anyway.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
vit_model.load_state_dict(torch.load("vit_model2.pth", map_location="cpu"))

In [None]:
# Peek at the file path of the first sample in a freshly drawn batch.
first_batch = next(iter(dataloader))
first_batch[2][0]

In [None]:
# rejoin list
# File name of the first sample with its last two "_"-separated tokens dropped.
# os.path.basename replaces a hard-coded "\\" split, which only isolated the
# file name for Windows-style paths.
sample_path = next(iter(dataloader))[2][0]
"_".join(os.path.basename(sample_path).split("_")[:-2])


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import numpy as np

# Inference setup: choose the GPU when one is available, move the model
# there and switch it to evaluation mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_model = vit_model.to(device).eval()

# Criterion used for the per-image MSE metric
mse_criterion = torch.nn.MSELoss()

def calculate_psnr(sr_img, hr_img):
    """Return the PSNR of ``sr_img`` measured against the reference ``hr_img``.

    Both tensors are detached, moved to the CPU and converted to NumPy before
    being passed to scikit-image with a data range of 1.0.
    """
    reference, estimate = (t.detach().cpu().numpy() for t in (hr_img, sr_img))
    return peak_signal_noise_ratio(reference, estimate, data_range=1.0)

def calculate_ssim(sr_img, hr_img, win_size=3):
    """Return the SSIM of ``sr_img`` measured against the reference ``hr_img``.

    Args:
        sr_img: Channels-first image tensor ``[C, H, W]``.
        hr_img: Reference tensor with the same shape.
        win_size: Odd side length of the sliding window. Defaults to 3.

    Returns:
        The mean structural similarity reported by scikit-image, computed
        with a data range of 1.0.
    """
    # scikit-image expects channels last.
    sr_img_np = sr_img.permute(1, 2, 0).cpu().detach().numpy()
    hr_img_np = hr_img.permute(1, 2, 0).cpu().detach().numpy()
    # channel_axis replaces the deprecated `multichannel=True` flag. On
    # scikit-image releases that no longer recognise `multichannel`, the
    # channel dimension is not treated as channels - NOTE(review): this is
    # presumably why a 3-pixel window was needed originally; confirm whether
    # the library default window should be used now.
    ssim_value = structural_similarity(
        hr_img_np, sr_img_np, channel_axis=-1, data_range=1.0, win_size=win_size
    )
    return ssim_value

def visualize_super_resolve_dataloader(dataloader):
    """Run ``vit_model`` over ``dataloader`` and print average PSNR, SSIM and MSE.

    Each batch is super-resolved without gradient tracking, and every
    image is compared with its high-resolution target. The averages are taken
    over the number of images actually processed, so a final partial batch
    does not skew them.

    Args:
        dataloader: Iterable yielding ``(lr_tensor, hr_tensor, paths)`` batches.
    """
    psnr_total = 0.0
    ssim_total = 0.0
    mse_total = 0.0
    num_samples = 0

    for lr_tensor, hr_tensor, _pths in tqdm(dataloader):
        lr_tensor, hr_tensor = lr_tensor.to(device), hr_tensor.to(device)

        with torch.no_grad():
            sr_batch = vit_model(lr_tensor)

        for sr_img, hr_img in zip(sr_batch, hr_tensor):
            psnr_total += calculate_psnr(sr_img, hr_img)
            ssim_total += calculate_ssim(sr_img, hr_img)
            mse_total += mse_criterion(sr_img, hr_img).item()
            # Count real samples: len(dataloader) * batch_size over-counts
            # whenever the last batch is smaller than batch_size.
            num_samples += 1

    if num_samples == 0:
        print("No samples to evaluate.")
        return

    print(f"Average PSNR: {psnr_total/num_samples:.2f}")
    print(f"Average SSIM: {ssim_total/num_samples:.4f}")
    print(f"Average MSE: {mse_total/num_samples:.4f}")


# Evaluate the model on every batch of the DataLoader and print the averaged metrics
visualize_super_resolve_dataloader(dataloader=dataloader)


In [None]:
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import numpy as np
import os

import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import numpy as np
import os

def super_resolve_images(dataloader, output_dir="SR_Basic"):
    """Super-resolve every image of ``dataloader``, display it and save it as PNG.

    For each sample the low-resolution input and the model output are
    clipped to ``[0, 1]``, shown side by side, and written to ``output_dir``
    as ``<base_name>_LR.png`` and ``<base_name>_SR.png``, where ``base_name``
    is the source file name without its extension.

    Args:
        dataloader: Iterable yielding ``(lr_tensor, hr_tensor, paths)`` batches.
        output_dir: Directory for the PNG files; created if missing.

    Returns:
        tuple: ``(lr_img, sr_img)`` arrays of the last processed sample, or
        ``(None, None)`` when the dataloader yields nothing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Bound up front so an empty dataloader does not hit an unbound name at return.
    lr_img = sr_img = None

    vit_model.eval()
    with torch.no_grad():
        for lr_tensor, _hr_tensor, paths in tqdm(dataloader):
            # Expand single-channel input to three identical channels.
            if lr_tensor.shape[1] == 1:
                lr_tensor = lr_tensor.repeat(1, 3, 1, 1)

            lr_tensor = lr_tensor.to(device)
            sr_images = vit_model(lr_tensor)

            for i in range(lr_tensor.shape[0]):
                # Channels-last arrays clipped to the displayable range.
                lr_img = np.clip(lr_tensor[i].permute(1, 2, 0).cpu().numpy(), 0, 1)
                sr_img = np.clip(sr_images[i].permute(1, 2, 0).cpu().numpy(), 0, 1)

                # Display images
                fig, (lr_ax, sr_ax) = plt.subplots(1, 2, figsize=(10, 5))
                lr_ax.imshow(lr_img)
                lr_ax.set_title(f"Low-Resolution: {lr_img.shape[:2]}")
                sr_ax.imshow(sr_img)
                sr_ax.set_title(f"Super-Resolved: {sr_img.shape[:2]}")
                plt.show()
                # Release the figure; otherwise one figure per image stays
                # open for the lifetime of the kernel.
                plt.close(fig)

                # Get the current path
                current_path = paths[i] if isinstance(paths[i], str) else paths[i][0]

                # File name without directory or extension
                base_name = os.path.splitext(os.path.basename(current_path))[0]

                lr_save_path = os.path.join(output_dir, f"{base_name}_LR.png")
                sr_save_path = os.path.join(output_dir, f"{base_name}_SR.png")

                # Save images as 8-bit PNGs
                Image.fromarray((lr_img * 255).astype(np.uint8)).save(lr_save_path)
                Image.fromarray((sr_img * 255).astype(np.uint8)).save(sr_save_path)

                print(f"Saved images:\n  LR: {lr_save_path}\n  SR: {sr_save_path}")

    return lr_img, sr_img
# Usage: process the whole DataLoader, writing the PNG pairs to "SR_Basic"
lr_img, sr_img = super_resolve_images(dataloader, output_dir="SR_Basic")

In [None]:
# Fetch a batch explicitly: `paths` is not bound at notebook scope before
# this cell, so a fresh Restart & Run All would otherwise hit a NameError.
paths = next(iter(dataloader))[2]
print(paths[0])

In [None]:
# Print the file paths of the first 10 batches only.
# enumerate supplies the counter; the original hand-rolled one was only
# incremented inside the branch that required it to be >= 10 already, so the
# loop never stopped early.
for i, (lr_tensor, hr_tensor, paths) in enumerate(dataloader):
    if i >= 10:
        break
    print(paths)