In [None]:
import warnings
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import tifffile as tiff
import torch
import torchvision.transforms as transforms
from PIL import Image
from safetensors.torch import safe_open

# Checkpoint location (Colab-mounted Google Drive) and the compute device:
# the GPU when PyTorch can see one, otherwise the CPU.
model_path = "/content/drive/MyDrive/models/G_model2.safetensors"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generator: 3 input channels (RGB) -> 10 predicted channels.
# NOTE(review): `Generator` is not defined anywhere in this notebook; it must already
# exist in the kernel (another cell/module). Import it explicitly so that
# Restart & Run All works.
G = Generator(input_nc=3, output_nc=10).to(device)

# Read every tensor stored in the .safetensors file, materialised on the CPU.
with safe_open(model_path, framework="pt", device='cpu') as f:
    raw_state_dict = dict((name, f.get_tensor(name)) for name in f.keys())

def rename_key(k):
    """Map a checkpoint key onto the ``model.``-prefixed naming expected by ``G``.

    One leading ``module.`` (presumably left by a DataParallel-style wrapper) is
    dropped, then ``model.`` is prepended unless the key already starts with it.

    Args:
        k: parameter name as stored in the checkpoint.

    Returns:
        The renamed key.
    """
    stripped = k[len("module."):] if k.startswith("module.") else k
    return stripped if stripped.startswith("model.") else "model." + stripped

# Rename checkpoint keys to the model's naming, keeping only tensor entries.
tensor_items = [(name, t) for name, t in raw_state_dict.items() if isinstance(t, torch.Tensor)]
new_state_dict = OrderedDict((rename_key(name), t) for name, t in tensor_items)

# Two raw keys can collapse onto one name (e.g. "module.x" and "x"); the later one
# would silently overwrite the earlier, so surface it.
if len(new_state_dict) != len(tensor_items):
    warnings.warn(
        f"{len(tensor_items) - len(new_state_dict)} checkpoint keys collided after "
        "renaming; later entries overwrote earlier ones."
    )

# strict=False tolerates naming mismatches, but any missing key means that parameter
# keeps its random initialisation and the generator output is meaningless, so warn
# explicitly instead of relying on the printed tuple alone.
load_result = G.load_state_dict(new_state_dict, strict=False)
print("Load result:", load_result)
if load_result.missing_keys:
    warnings.warn(
        f"{len(load_result.missing_keys)} model parameters were not found in the "
        f"checkpoint and keep their initial values, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    warnings.warn(
        f"{len(load_result.unexpected_keys)} checkpoint entries were ignored, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )

G.to(device)
G.eval()  # inference mode for layers such as dropout / batch-norm

# Sample images
# w1172713800_CH_18_generated.jpg
# w1171277558_CH_18_generated.jpg
# w1171270719_CH_18_generated.jpg
# w1169720360_CH_21_generated.jpg
# w1169335568_DE__generated.jpg
# w1170430875_US_18_generated.jpg
rgb_path = "/content/drive/MyDrive/generated_images/w1169720360_CH_21_generated.jpg"

# ToTensor scales pixels to [0, 1] (C, H, W); Normalize with mean=std=0.5 then
# maps each channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3)
])

input_image = Image.open(rgb_path).convert("RGB").resize((512, 512))
# NOTE(review): `arr` is not used anywhere below in this notebook.
arr = np.asarray(input_image, dtype=np.float32) / 255.0

# Add a batch axis -> (1, 3, 512, 512) on the inference device.
input_tensor = transform(input_image)[None].to(device)

# Sliding Window Inference
# Run G on non-overlapping tiles and assemble a 13-band image whose channel i is
# Sentinel-2 band band_names[i] (cell below): B1, B2, B3, B4, B5, B6, B7, B8, B8A,
# B9, B10, B11, B12.
patch_size = 64
h, w = input_tensor.shape[-2:]  # follow the actual input size instead of hard-coding 512
channels = 13
output_fake = torch.zeros((channels, h, w), device=device)

with torch.no_grad():  # pure inference; no autograd graph for any tile
    for y in range(0, h, patch_size):
        for x in range(0, w, patch_size):
            patch = input_tensor[:, :, y:y+patch_size, x:x+patch_size]  # (1, 3, ph, pw), RGB in [-1, 1]
            fake_extra = G(patch)  # expected (1, 10, ph, pw): B1, then B5..B12 (incl. B8A)
            if fake_extra.shape[1] != channels - 3:
                raise RuntimeError(
                    f"Generator returned {fake_extra.shape[1]} channels, expected {channels - 3}"
                )

            # Sentinel-2 B2/B3/B4 are blue/green/red, so the RGB patch is written in
            # B, G, R order. Writing R into B2 would make every B4-based product
            # (NDVI, natural colour, bathymetric) read the blue channel instead.
            fake_B_patch = torch.cat(
                [fake_extra[:, :1],             # B1
                 torch.flip(patch, dims=[1]),   # B2, B3, B4 = blue, green, red
                 fake_extra[:, 1:]],            # B5..B12
                dim=1,
            )

            ph, pw = patch.shape[-2:]  # tiles on the right/bottom edge may be smaller
            output_fake[:, y:y+ph, x:x+pw] = fake_B_patch.squeeze(0)

# Map [-1, 1] back to [0, 1] and move channels last -> (H, W, 13).
# NOTE(review): assumes G's outputs share the input's [-1, 1] range (e.g. tanh) — confirm.
fake_B_np = output_fake.cpu().numpy()
fake_B_np = (fake_B_np + 1) / 2.0
fake_B_np = np.transpose(fake_B_np, (1, 2, 0))

output_path = "/content/generated_multispectral_512.tif"
# Store one page with 13 samples per pixel. Without an explicit layout, tifffile
# treats an (H, W, 13) array as H grayscale pages of W x 13, which GIS tools misread.
tiff.imwrite(
    output_path,
    fake_B_np.astype(np.float32),
    photometric="minisblack",
    planarconfig="contig",
)
print(f"Multispectral TIFF kaydedildi: {output_path}")

# Band indices follow the order above: B1=0, B2=1, B3=2, B4=3, B8=7, B8A=8, B11=11, B12=12.
spectral_composites = {
    "Natural Color (B4, B3, B2)":      {"bands": [3, 2, 1], "display": "rgb"},
    "Color Infrared (B8, B4, B3)":     {"bands": [7, 3, 2], "cmap": "inferno"},
    "Short-Wave Infrared (B12, B8A, B4)": {"bands": [12, 8, 3], "cmap": "YlOrBr"},
    "Agriculture (B11, B8, B2)":       {"bands": [11, 7, 1], "cmap": "Greens"},
    "Geology (B12, B11, B2)":          {"bands": [12, 11, 1], "cmap": "magma"},
    "Bathymetric (B4, B3, B1)":        {"bands": [3, 2, 0], "cmap": "YlGnBu"}
}

# NDVI and moisture indices:
#   NDVI     = (B8  - B4)  / (B8  + B4)
#   Moisture = (B8A - B11) / (B8A + B11)
# The 1e-6 keeps dark pixels from dividing by zero.
NDVI    = (fake_B_np[:,:,7] - fake_B_np[:,:,3]) / (fake_B_np[:,:,7] + fake_B_np[:,:,3] + 1e-6)
Moisture= (fake_B_np[:,:,8] - fake_B_np[:,:,11]) / (fake_B_np[:,:,8] + fake_B_np[:,:,11] + 1e-6)

index_colormaps = {
    "Vegetation Index (NDVI)": 'RdYlGn',
    "Moisture Index": 'Blues'
}

# 3x3 grid: 6 composites + 2 indices; the spare panel is blanked at the end.
fig, axs = plt.subplots(3, 3, figsize=(18, 18))
axs = axs.flatten()
i = 0

for title, info in spectral_composites.items():
    bands = info["bands"]
    if info.get("display") == "rgb":
        # Float RGB for imshow must lie in [0, 1]; clip instead of relying on its warning.
        img = np.clip(fake_B_np[:, :, bands], 0.0, 1.0)
        axs[i].imshow(img)
    else:
        # NOTE(review): the three bands are averaged into one channel and colour-mapped,
        # not mapped to R/G/B as a classic false-colour composite would be.
        img = np.mean(fake_B_np[:, :, bands], axis=2)
        axs[i].imshow(img, cmap=info.get("cmap"))
    axs[i].set_title(title)
    axs[i].axis("off")
    i += 1

axs[i].imshow(NDVI, cmap=index_colormaps["Vegetation Index (NDVI)"], vmin=-1, vmax=1)
axs[i].set_title("Vegetation Index (NDVI)")
axs[i].axis("off")
i += 1

# Stretch the moisture map to [-1, 1]; a constant map would otherwise divide by zero.
moisture_span = Moisture.max() - Moisture.min()
if moisture_span > 0:
    Moisture_norm = 2 * ((Moisture - Moisture.min()) / moisture_span) - 1
else:
    Moisture_norm = np.zeros_like(Moisture)
axs[i].imshow(Moisture_norm, cmap=index_colormaps["Moisture Index"], vmin=-1, vmax=1)
axs[i].set_title("Moisture Index")
axs[i].axis("off")
i += 1

for j in range(i, len(axs)):
    axs[j].axis("off")

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Label for each channel of `fake_B_np`, in channel order (index i -> band_names[i]).
band_names = ['1','2','3','4','5','6','7','8','8a','9','10','11','12']

# 3x5 grid of single-band panels; imshow auto-scales each band to its own range.
fig, axes = plt.subplots(3, 5, figsize=(15, 9))
axes = axes.flatten()

for i in range(len(band_names)):
    name = band_names[i]
    ax = axes[i]
    band = fake_B_np[:, :, i]
    ax.imshow(band, cmap='gray')
    ax.set_title(f'Band {name}')
    ax.axis('off')

# 15 slots but only 13 bands: remove the leftover axes entirely.
for ax in axes[len(band_names):]:
    fig.delaxes(ax)

plt.tight_layout()
plt.show()