In [None]:
import gdown
import shutil
import os
from pathlib import Path
import pickle
import torch
from torch import nn, optim
import torchvision.models as models
import torch.nn.functional as F
import lightning as L
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score
import seaborn as sns

# Google Drive id of the zipped dataset, plus the local paths it is staged at.
file_id = '1YHaxS8f6fQ5IMT3JJVeBH40eZSS-UC9h'
zip_file_path = '/content/Images.zip'
extract_folder = '/content/Images'

# The extraction folder doubles as the "already done" marker, so a re-run of
# this cell does not download or unpack the archive a second time.
if os.path.exists(extract_folder):
    print("Files already extracted.")
else:
    download_url = f'https://drive.google.com/uc?export=download&id={file_id}'
    gdown.download(download_url, zip_file_path, quiet=False)
    shutil.unpack_archive(zip_file_path, extract_folder)
    print("Files extracted:", os.listdir(extract_folder))

# Source images are read from the 'Images' sub-folder of the extraction folder
# (presumably the archive's own top-level directory - verify against the zip).
path = Path(extract_folder) / 'Images'

In [None]:
from fastai.vision.all import *
from PIL import Image
import os

# Output location for the cropped tiles.
# The source folder `path` comes from the download cell; only the destination
# is configured here. The directory is created if missing and reused if present.
save_path = Path("cropped") / "images"
save_path.mkdir(parents=True, exist_ok=True)

# Function to crop images into overlapping tiles
def crop_with_overlap(image_path, save_path, tile_size=(512, 512), stride=(256, 256)):
    """Cut one image into fixed-size, overlapping tiles and save them as JPEGs.

    Tiles are written to ``save_path / <parent folder name of image_path>`` and
    named ``<stem>_tile_<left>_<upper>.jpg``, where ``left``/``upper`` are the
    pixel offsets of the tile's top-left corner in the source image.

    Args:
        image_path: ``pathlib.Path`` of the source image; the name of its
            parent folder is used as the class label.
        save_path: ``pathlib.Path`` of the root output directory.
        tile_size: ``(width, height)`` of every tile in pixels.
        stride: ``(x_step, y_step)`` between consecutive tile origins.

    Only tiles that fit completely inside the image are produced, so an image
    smaller than ``tile_size`` yields no tiles, and a strip along the right or
    bottom edge is skipped when the size is not aligned with the stride.
    """
    tile_w, tile_h = tile_size
    step_x, step_y = stride

    # Modes Pillow's JPEG writer accepts; anything else (e.g. RGBA or palette
    # images) would make `tile.save(... .jpg)` raise OSError.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    # The context manager closes the underlying file once all tiles are
    # written, instead of leaking one open handle per source image.
    with Image.open(image_path) as img:
        source = img if img.mode in jpeg_modes else img.convert("RGB")
        width, height = source.size

        label = image_path.parent.name  # Extract label from parent folder
        class_save_path = save_path / label
        class_save_path.mkdir(parents=True, exist_ok=True)

        for left in range(0, width - tile_w + 1, step_x):
            for upper in range(0, height - tile_h + 1, step_y):
                box = (left, upper, left + tile_w, upper + tile_h)
                tile = source.crop(box)
                tile.save(class_save_path / f"{image_path.stem}_tile_{left}_{upper}.jpg")

# Tile every source image: 512x512 windows whose origins advance 256 px at a
# time, so neighbouring tiles overlap by half in each direction.
tile_size = (512, 512)
stride = (256, 256)

image_files = get_image_files(path)
for image_file in image_files:
    crop_with_overlap(
        image_file,
        save_path,
        tile_size=tile_size,
        stride=stride,
    )

# Sanity check: count the tiles that are now on disk
sliced_image_files = get_image_files(save_path)
print(f"Total cropped images: {len(sliced_image_files)}")

In [None]:
from fastai.vision.all import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Step 1: Define the VAE model
# Convolutional Variational Autoencoder (VAE); the notebook uses it on 512x512 tiles.
class VAE(nn.Module):
    """Convolutional variational autoencoder for square 3-channel images.

    The encoder is four stride-2 convolutions (3 -> 64 -> 128 -> 256 -> 512
    channels), each halving the spatial size, followed by two linear heads
    that produce the mean and log-variance of the latent distribution. The
    decoder mirrors it with a linear layer and four stride-2 transposed
    convolutions, ending in a sigmoid so outputs lie in [0, 1].
    """

    def __init__(self, img_size=512, latent_dim=512):
        """
        Build the encoder, the latent heads and the decoder.

        - img_size: side length of the (square) input image.
        - latent_dim: number of dimensions of the latent vector.
        """
        super().__init__()
        self.img_size = img_size
        self.latent_dim = latent_dim

        # After four halvings the feature map is (img_size // 16) on a side
        # with 512 channels; this is the flattened size seen by the linear layers.
        feat_side = img_size // 16
        flat_features = 512 * feat_side ** 2

        # Encoder: image -> flattened feature vector
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),     # B, 64, img_size/2, img_size/2
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),   # B, 128, img_size/4, img_size/4
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # B, 256, img_size/8, img_size/8
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # B, 512, img_size/16, img_size/16
            nn.ReLU(),
            nn.Flatten(),
        )

        # Latent heads: mean and log-variance of the approximate posterior
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder: latent vector -> feature map -> image
        self.fc_decode = nn.Linear(latent_dim, flat_features)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # keeps every output value in [0, 1]
        )

    def encode(self, x):
        """
        Run the encoder on a batch of images and return ``(mu, logvar)``,
        the mean and log-variance of the latent distribution.
        """
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """
        Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` so that sampling
        stays differentiable with respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)  # fresh noise on every call
        return mu + noise * sigma

    def decode(self, z):
        """
        Map latent vectors ``z`` back to images with values in [0, 1].
        """
        side = self.img_size // 16
        feature_map = self.fc_decode(z).view(-1, 512, side, side)
        return self.decoder(feature_map)

    def forward(self, x):
        """
        Encode, sample a latent vector and decode.

        Returns the tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

# Step 2: Define custom VAE loss function
# This combines reconstruction loss (how well the image is reconstructed)
# and KL divergence (how close the latent distribution is to a standard normal distribution).
def vae_loss_fn(preds, x, mu=None, logvar=None):
    """Summed VAE loss: reconstruction error plus KL divergence.

    Args:
        preds: the model output tuple ``(recon_x, mu, logvar)``.
        x: target images, same shape as ``recon_x``.
        mu: optional latent means; defaults to the ``mu`` carried in ``preds``.
        logvar: optional latent log-variances; defaults to the ``logvar``
            carried in ``preds``.

    Returns:
        A scalar tensor: ``sum((recon_x - x)^2) + KL(N(mu, exp(logvar)) || N(0, I))``,
        both terms summed over the whole batch (not averaged).

    ``mu`` and ``logvar`` are optional because fastai's ``Learner`` invokes the
    loss as ``loss_func(pred, *yb)``, i.e. with only the model output and the
    target. Callers that pass all four arguments keep working unchanged.
    """
    recon_x, pred_mu, pred_logvar = preds
    if mu is None:
        mu = pred_mu
    if logvar is None:
        logvar = pred_logvar

    # Reconstruction loss: measures pixel-wise difference
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    # KL divergence: encourages latent space to follow a normal distribution
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld_loss

# Step 3: Prepare the data using Fastai's DataBlock
# Define the structure of the dataset and transformations.
data_path = Path("cropped/images/")  # Path to your dataset
dblock = DataBlock(
    blocks=(ImageBlock, ImageBlock),  # Input and output both are images
    get_items=get_image_files,        # Function to retrieve all image files
    splitter=RandomSplitter(),        # Split into training and validation sets
    item_tfms=Resize(512),            # Resize all images to 512x512
    # No Normalize batch transform on purpose: the decoder ends in a Sigmoid,
    # so reconstruction targets must stay in [0, 1]. ImageBlock's default
    # IntToFloatTensor already scales pixels to that range; ImageNet
    # normalisation would push targets outside what the model can output.
)

# Create DataLoaders for training and validation
dls = dblock.dataloaders(data_path, bs=8)  # Use smaller batch size for large images

# Step 4: Train the VAE model
# Initialize the VAE model
vae = VAE(img_size=512, latent_dim=512)

# Use Fastai's Learner to manage training
learn = Learner(dls, vae, loss_func=vae_loss_fn, metrics=[])

# Fit the model for 10 epochs
learn.fit(10, lr=1e-3)

# Step 5: Generate new images by sampling from the latent space
vae.eval()  # Switch to evaluation mode
# Training may have moved the model to the GPU; sample on whatever device the
# weights live on, and bring the result back to the CPU before NumPy conversion.
device = next(vae.parameters()).device
with torch.no_grad():
    # Sample a random point in the latent space
    z = torch.randn(1, vae.latent_dim, device=device)
    # Decode the latent vector to generate an image: (1, C, H, W) -> (H, W, C)
    generated_img = vae.decode(z).squeeze(0).permute(1, 2, 0)
    # Convert to a displayable format (0-255 pixel range)
    generated_img = (generated_img * 255).cpu().numpy().astype("uint8")

# Step 6: Visualize the generated image
# Display the generated image
plt.imshow(generated_img)
plt.axis("off")
plt.show()


In [None]:
# Save both model and optimizer state_dict
torch.save({
    'model_state_dict': vae.state_dict(),
    'optimizer_state_dict': learn.opt.state_dict()
}, "vae_model_and_optimizer.pth")

# Load the saved states
# map_location="cpu" lets the checkpoint be read on a machine without a GPU.
# weights_only=False: the optimizer entry comes from fastai's optimizer and
# presumably holds non-tensor Python objects that the restricted loader
# rejects; this is acceptable only because the file was written just above.
checkpoint = torch.load("vae_model_and_optimizer.pth", map_location="cpu", weights_only=False)

# The weights need a model instance of the same architecture to load into.
loaded_vae = VAE(img_size=vae.img_size, latent_dim=vae.latent_dim)
loaded_vae.load_state_dict(checkpoint['model_state_dict'])
loaded_vae.eval()

# The saved optimizer state is fastai's format, so restore it into a fastai
# optimizer built for the new model's parameters.
# NOTE(review): model and optimizer state are left on the CPU here; move both
# to the training device before resuming training.
loaded_learn = Learner(dls, loaded_vae, loss_func=vae_loss_fn, metrics=[])
loaded_learn.create_opt()
optimizer = loaded_learn.opt
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Inference example
loaded_vae.eval()
# Sample on the device that holds the model weights so decoding works on both
# CPU and GPU, and copy the result to the CPU before converting to NumPy.
inference_device = next(loaded_vae.parameters()).device
with torch.no_grad():
    # Sample from the latent space; its size is read from the model
    z = torch.randn(1, loaded_vae.latent_dim, device=inference_device)
    # (1, C, H, W) -> (H, W, C) for matplotlib
    generated_img = loaded_vae.decode(z).squeeze(0).permute(1, 2, 0)
    generated_img = (generated_img * 255).cpu().numpy().astype("uint8")

# Visualize the generated image
plt.imshow(generated_img)
plt.axis("off")
plt.show()