In [None]:
import os
from glob import glob

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


# Root of the GOES-16 GEOCOLOR tiles; each resolution (in pixels) has its own subfolder.
base_path = "data"
folders = {
    300: os.path.join(base_path, "300"),
    600: os.path.join(base_path, "600"),
    1200: os.path.join(base_path, "1200"),
    2400: os.path.join(base_path, "2400"),
}


def tile_filename(prefix, size):
    """Return the file name of the ``size`` x ``size`` tile for ``prefix``.

    Tiles are named ``{prefix}_GOES16-ABI-ne-GEOCOLOR-{size}x{size}.png``; this
    is the single place that naming scheme is spelled out.
    """
    return f"{prefix}_GOES16-ABI-ne-GEOCOLOR-{size}x{size}.png"


def prefix_of(path, size):
    """Return the prefix of a ``size``-px tile path, or ``None`` if the name doesn't match.

    The known suffix is stripped (instead of splitting on ``_``) so prefixes that
    themselves contain underscores are recovered intact.
    """
    name = os.path.basename(path)
    suffix = tile_filename("", size)
    return name[: -len(suffix)] if name.endswith(suffix) else None


def pair_tiles(prefixes, folders_by_size):
    """Return one ``{size: path}`` dict per prefix whose tile exists at every resolution.

    Args:
        prefixes: iterable of tile prefixes to look up.
        folders_by_size: mapping of resolution -> folder holding that resolution.

    Prefixes are visited in sorted order so the resulting list (and any dataset
    indexed by it) is identical on every run; iterating a ``set`` of strings
    directly depends on the per-process hash seed.
    """
    paired = []
    for prefix in sorted(prefixes):
        matched = {
            size: os.path.join(folder, tile_filename(prefix, size))
            for size, folder in folders_by_size.items()
        }
        # Keep only complete sets: a sample needs every resolution on disk.
        if all(os.path.isfile(path) for path in matched.values()):
            paired.append(matched)
    return paired


all_files = {size: sorted(glob(os.path.join(folder, "*.png"))) for size, folder in folders.items()}
# Prefixes are discovered from the lowest resolution; files not following the naming scheme are ignored.
prefixes = {prefix for prefix in (prefix_of(f, 300) for f in all_files[300]) if prefix is not None}
paired_images = pair_tiles(prefixes, folders)

# Define the PyTorch Dataset that serves (low-res, high-res) image pairs
class SuperResDataset(Dataset):
    """Dataset of (low-resolution, high-resolution) image pairs for super-resolution.

    Args:
        paired_images: list of ``{size: path}`` dicts; each dict must contain both
            ``low_res_size`` and ``high_res_size`` as keys.
        low_res_size: key selecting the model-input image in each dict.
        high_res_size: key selecting the target image in each dict.
        transform: optional callable applied to the low-res image and, when
            ``target_transform`` is not given, to the high-res image as well.
        target_transform: optional callable applied to the high-res image in
            place of ``transform`` (e.g. a different resize or normalization).

    Note: the two images are transformed by separate calls, so a transform with
    random parameters (random crop/flip) is drawn independently for each image
    and can misalign the pair.
    """

    def __init__(self, paired_images, low_res_size, high_res_size, transform=None, target_transform=None):
        self.paired_images = paired_images
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.paired_images)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image and close the underlying file before returning.

        ``convert`` returns a new, fully loaded image, so it stays valid after the
        ``with`` block closes the file.
        """
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return the ``(low_res, high_res)`` pair at ``idx``, after any transforms."""
        img_paths = self.paired_images[idx]

        low_res = self._load_rgb(img_paths[self.low_res_size])
        high_res = self._load_rgb(img_paths[self.high_res_size])

        # Fall back to the shared transform when no target-specific one is given.
        high_res_transform = self.target_transform if self.target_transform is not None else self.transform
        if self.transform is not None:
            low_res = self.transform(low_res)
        if high_res_transform is not None:
            high_res = high_res_transform(high_res)

        return low_res, high_res
# 2x super-resolution: 300 px inputs, 600 px targets.
# ToTensor is required: DataLoader's default collate cannot batch PIL images.
train_dataset = SuperResDataset(
    paired_images,
    low_res_size=300,
    high_res_size=600,
    transform=transforms.ToTensor(),
)

# Shuffle so batches are not drawn in the fixed order of `paired_images`.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)


In [None]:
import torch
from diffusers import AutoencoderKL
from torchvision import transforms
from PIL import Image

# Prefer CUDA (e.g. Colab), then Apple-silicon "mps", then CPU, instead of hard-coding "mps".
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the fine-tuned Stable Diffusion VAE
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)

# Sanity check: encode a blank 300x300 image and inspect the latent shape.
dummy_image = Image.new("RGB", (300, 300), color="white")
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL -> float tensor in [0, 1], shape (C, H, W)
    # The SD VAE expects inputs scaled to [-1, 1].
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
image_tensor = transform(dummy_image).unsqueeze(0).to(device)  # add batch dim -> (1, 3, 300, 300)

# Encode and check the shape
with torch.no_grad():
    latent = vae.encode(image_tensor).latent_dist.sample()

# The recorded output shows 300 px -> 37 latent px: the VAE downsamples by 8 and 300 is
# not a multiple of 8, so decoding would presumably give 37 * 8 = 296 px, not 300.
print("Latent shape:", latent.shape)


  from .autonotebook import tqdm as notebook_tqdm


Latent shape: torch.Size([1, 4, 37, 37])


In [None]:
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel

# Prefer CUDA (e.g. Colab), then Apple-silicon "mps", then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Autoencoder: load only the VAE component of the SD v1.4 checkpoint. Loading it with
# DiffusionPipeline would download the whole pipeline (UNet, text encoder, ...) and
# bind a pipeline object, not an autoencoder, to the name `vae`.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)

# Diffusion model operating in the VAE latent space.
# NOTE(review): with the 8x VAE, 600 px targets give roughly 75x75 latents, not 32x32;
# sample_size is only a config default here — confirm the intended value.
# NOTE(review): UNet2DConditionModel is a cross-attention (conditioned) UNet; how the
# low-res image conditions it is not defined in this cell — TODO confirm.
model = UNet2DConditionModel(
    sample_size=32,  # Latent space size
    in_channels=4,  # Latent space channels (not RGB)
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(128, 256, 512),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
).to(device)



  from .autonotebook import tqdm as notebook_tqdm
Fetching 16 files:  19%|█▉        | 3/16 [01:05<04:43, 21.82s/it]
