In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk every directory under the Kaggle input mount. The inner loop body is
# `pass`, so no file names are printed or collected here.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        pass
# Emits a fixed marker line once the walk has finished.
print("hello")

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, utils
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# Reproducibility: drive both the PyTorch and NumPy global RNGs from one seed value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Encoder class
class Encoder(nn.Module):
    """Convolutional encoder producing Gaussian latent parameters.

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    resolution, so an ``H x W`` image ends up as a
    ``512 x (H // 16) x (W // 16)`` feature map. That map is flattened and
    projected by two linear heads to ``mu`` and ``logvar``.

    Args:
        in_channels: Number of channels in the input images.
        latent_dim: Size of the latent vector.
        input_size: ``(height, width)`` of the images that will be fed to
            ``forward``. Defaults to ``(256, 512)``, which reproduces the
            previously hard-coded ``512 * 16 * 32`` flatten width.

    Raises:
        ValueError: If either side of ``input_size`` is smaller than 16, in
            which case the four downsampling steps would leave no pixels.
    """

    def __init__(self, in_channels=3, latent_dim=256, input_size=(256, 512)):
        super().__init__()
        height, width = input_size
        # Each conv maps a side of length n to n // 2; four of them give n // 16.
        feat_h, feat_w = height // 16, width // 16
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size must be at least 16x16, got {tuple(input_size)}"
            )

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels, 64, 4, 2, 1)
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)
        self.conv4 = nn.Conv2d(256, 512, 4, 2, 1)
        self.flatten = nn.Flatten()
        flat_features = 512 * feat_h * feat_w
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch whose spatial size matches ``input_size``.

        Returns:
            Tuple ``(mu, logvar, skips)`` where ``skips`` is the list of the
            four post-ReLU feature maps, shallowest first.
        """
        skips = []
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))
            skips.append(x)
        x = self.flatten(x)
        return self.fc_mu(x), self.fc_logvar(x), skips

# Decoder class
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to RGB images.

    A linear layer expands the latent vector to a ``512 x h x w`` feature map,
    which four stride-2 transposed convolutions (kernel 4, padding 1) upsample
    by a factor of 16 to a 3-channel image squashed into ``[0, 1]`` by a
    sigmoid.

    Args:
        latent_dim: Size of the latent vector.
        output_size: ``(height, width)`` of the generated images. Both sides
            must be positive multiples of 16. Defaults to ``(256, 512)``,
            which reproduces the previously hard-coded ``16 x 32`` base map.

    Raises:
        ValueError: If ``output_size`` is not a positive multiple of 16 in
            both dimensions.
    """

    def __init__(self, latent_dim, output_size=(256, 512)):
        super().__init__()
        height, width = output_size
        if height < 16 or width < 16 or height % 16 or width % 16:
            raise ValueError(
                "output_size sides must be positive multiples of 16, "
                f"got {tuple(output_size)}"
            )
        # Spatial size of the feature map the linear layer is reshaped into.
        self.base_size = (height // 16, width // 16)

        # Layer creation order is kept so seeded initialisation and
        # state_dict keys stay compatible with existing checkpoints.
        self.fc = nn.Linear(latent_dim, 512 * self.base_size[0] * self.base_size[1])
        self.deconv1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.output_layer = nn.ConvTranspose2d(64, 3, 4, 2, 1)

    def forward(self, z):
        """Decode latent vectors.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, height, width)`` with values in
            ``[0, 1]``.
        """
        base_h, base_w = self.base_size
        x = self.fc(z)
        x = x.view(-1, 512, base_h, base_w)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = F.relu(self.deconv3(x))
        x = torch.sigmoid(self.output_layer(x))
        return x

# VAE
class VAE(nn.Module):
    """Variational autoencoder pairing ``Encoder`` with ``Decoder``.

    Args:
        latent_dim: Size of the latent vector shared by encoder and decoder.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        # Remembered so generate() samples vectors of the right width instead
        # of assuming the default of 256.
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(recon, mu, logvar)``. The encoder's intermediate feature
            maps are not passed to the decoder.
        """
        mu, logvar, _skips = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar

    def generate(self, batch_size=1):
        """Decode ``batch_size`` latent vectors drawn from a standard normal.

        Returns:
            The decoder output for the sampled vectors, on the model's device.
        """
        device = next(self.parameters()).device
        # Use the configured latent width; a hard-coded 256 breaks any model
        # built with a different latent_dim.
        z = torch.randn(batch_size, self.latent_dim).to(device)
        return self.decoder(z)

# Loss Function (BCE + KL)
def vae_loss(recon_x, x, mu, logvar):
    """Return the summed VAE objective: reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstructed batch with values in ``[0, 1]``.
        x: Target batch with the same shape as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over all elements.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims.
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_term + kl_term

# Dataset Path and Transform
dataset_path = '/kaggle/input/mangoleaf-dataset/dataset'

# Every image is resized to 256x512 (height x width) and converted to a tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 512)), transforms.ToTensor()]
)

# One ImageFolder over all class directories; its samples are filtered by class
# label inside train_vae_per_class.
full_dataset = datasets.ImageFolder(dataset_path, transform=transform)

# Helper: Train VAE per class
def _show_reconstruction(vae, image, title):
    """Plot one image side by side with its VAE reconstruction.

    Args:
        vae: Model whose forward pass returns ``(recon, mu, logvar)``.
        image: Single image tensor without a batch dimension, on the model's device.
        title: Title drawn above the figure.
    """
    vae.eval()
    with torch.no_grad():
        sample_img = image.unsqueeze(0)
        recon_img, _, _ = vae(sample_img)
        comparison = torch.cat([sample_img.cpu(), recon_img.cpu()])
        grid = utils.make_grid(comparison, nrow=2)
    fig = plt.figure(figsize=(10, 5))
    plt.axis('off')
    plt.title(title)
    plt.imshow(grid.permute(1, 2, 0).clamp(0, 1))
    plt.show()
    # One figure is created per epoch; release it explicitly so figures do not
    # accumulate over a long run.
    plt.close(fig)


def train_vae_per_class(class_name, latent_dim=256, num_epochs=200, batch_size=4, patience=20):
    """Train a VAE on the images of a single class of ``full_dataset``.

    The per-sample average training loss is tracked each epoch. Whenever it
    improves, the weights are saved to ``vae_weights/vae_<class_name>.pth``;
    training stops once ``patience`` consecutive epochs bring no improvement.
    After every non-stopping epoch an input/reconstruction pair is plotted.

    Args:
        class_name: Name of a class directory known to ``full_dataset``.
        latent_dim: Latent size passed to ``VAE``.
        num_epochs: Maximum number of epochs.
        batch_size: Mini-batch size for the ``DataLoader``.
        patience: Number of non-improving epochs tolerated before stopping.

    Returns:
        The trained ``VAE`` carrying the best (lowest-loss) weights saved
        during this call, or the untrained/last weights if no checkpoint was
        written.

    Raises:
        KeyError: If ``class_name`` is not a class of ``full_dataset``.
        ValueError: If the class contains no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if class_name not in full_dataset.class_to_idx:
        raise KeyError(
            f"Unknown class {class_name!r}; available: {sorted(full_dataset.class_to_idx)}"
        )
    class_index = full_dataset.class_to_idx[class_name]
    class_indices = [i for i, (_, label) in enumerate(full_dataset.samples) if label == class_index]
    if not class_indices:
        # Fail early instead of dividing by zero after an empty epoch.
        raise ValueError(f"Class {class_name!r} contains no samples")
    dataset = Subset(full_dataset, class_indices)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    vae = VAE(latent_dim=latent_dim).to(device)
    optimizer = optim.Adam(vae.parameters(), lr=1e-4)

    best_loss = float("inf")
    early_stopping_counter = 0
    os.makedirs("vae_weights", exist_ok=True)
    checkpoint_path = f"vae_weights/vae_{class_name}.pth"
    # Tracks whether *this* call wrote the checkpoint, so a stale file from an
    # earlier run is never loaded by mistake.
    checkpoint_written = False

    print(f"\nTraining VAE for class: {class_name}")
    for epoch in range(num_epochs):
        vae.train()
        total_loss = 0
        for images, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images = images.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = vae(images)
            loss = vae_loss(recon, images, mu, logvar)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # vae_loss sums over the batch, so divide by the sample count.
        avg_loss = total_loss / len(dataloader.dataset)
        print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            early_stopping_counter = 0
            torch.save(vae.state_dict(), checkpoint_path)
            checkpoint_written = True
            print(f"Saved best model for {class_name}")
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Visualize reconstruction of the first image of the last batch.
        _show_reconstruction(vae, images[0], f"{class_name} - Epoch {epoch+1}")

    if checkpoint_written:
        # The in-memory weights may be up to `patience` epochs past the best
        # epoch; hand back the checkpointed best weights instead.
        vae.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return vae

# Generate synthetic images
def generate_images(vae, batch_size=5):
    """Sample images from the VAE prior and show them in a single row.

    Args:
        vae: Model exposing ``generate(batch_size=...)`` that returns an image batch.
        batch_size: Number of images to sample and display.
    """
    vae.eval()
    with torch.no_grad():
        generated_images = vae.generate(batch_size=batch_size)
        grid = utils.make_grid(generated_images, nrow=batch_size)
        # matplotlib converts its input to a NumPy array, which fails for
        # tensors living on a CUDA device, so move the grid to the CPU first.
        image = grid.permute(1, 2, 0).clamp(0, 1).cpu()
        plt.figure(figsize=(batch_size * 2, 5))
        plt.imshow(image)
        plt.axis('off')
        plt.title("Generated Samples")
        plt.show()

# Train for a class and generate
# Fit a VAE on the DIEBACK class only, then preview five samples drawn from the prior.
vae = train_vae_per_class('DIEBACK', latent_dim=256, num_epochs=300, batch_size=4)
generate_images(vae, 5)


In [None]:
# import os
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
# from torch.utils.data import DataLoader
# from torchvision import datasets, transforms, utils
# from tqdm import tqdm
# import matplotlib.pyplot as plt

# # Encoder class
# class Encoder(nn.Module):
#     def __init__(self, in_channels=3, latent_dim=256):
#         super(Encoder, self).__init__()
#         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1)  # 128x256
#         self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)          # 64x128
#         self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)         # 32x64
#         self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)         # 16x32
#         self.flatten = nn.Flatten()
#         self.fc_mu = nn.Linear(512 * 16 * 32, latent_dim)
#         self.fc_logvar = nn.Linear(512 * 16 * 32, latent_dim)

#     def forward(self, x):
#         skips = []
#         x = F.relu(self.conv1(x)); skips.append(x)
#         x = F.relu(self.conv2(x)); skips.append(x)
#         x = F.relu(self.conv3(x)); skips.append(x)
#         x = F.relu(self.conv4(x)); skips.append(x)
#         x = self.flatten(x)
#         mu = self.fc_mu(x)
#         logvar = self.fc_logvar(x)
#         return mu, logvar, skips

# # Decoder class (Modified to ignore skip connections for generation)
# class Decoder(nn.Module):
#     def __init__(self, latent_dim):
#         super(Decoder, self).__init__()
#         self.fc = nn.Linear(latent_dim, 512 * 16 * 32)

#         self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)  # 32x64
#         self.deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)  # 64x128
#         self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)    # 128x256
#         self.output_layer = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)      # 256x512

#     def forward(self, z, skips=None):
#         x = self.fc(z)
#         x = x.view(-1, 512, 16, 32)

#         # Skip connections are ignored during generation
#         x = F.relu(self.deconv1(x))
#         x = F.relu(self.deconv2(x))
#         x = F.relu(self.deconv3(x))

#         # Generate final output
#         x = torch.sigmoid(self.output_layer(x))
#         return x

# # VAE class
# class VAE(nn.Module):
#     def __init__(self, latent_dim=256):
#         super(VAE, self).__init__()
#         self.encoder = Encoder(latent_dim=latent_dim)
#         self.decoder = Decoder(latent_dim=latent_dim)

#     def reparameterize(self, mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         eps = torch.randn_like(std)
#         return mu + eps * std

#     def forward(self, x):
#         mu, logvar, skips = self.encoder(x)  # For training, we still use the encoder
#         z = self.reparameterize(mu, logvar)
#         recon = self.decoder(z)  # During training, decoder uses the latent vector from encoder
#         return recon, mu, logvar

#     def generate(self, batch_size=1):
#         # For generation, we sample random noise (z) from a normal distribution
#         z = torch.randn(batch_size, 256).to(next(self.parameters()).device)
#         recon = self.decoder(z)  # Generate images from random noise
#         return recon

# # VAE Loss Function
# def vae_loss(recon_x, x, mu, logvar):
#     recon_loss = F.mse_loss(recon_x, x, reduction='sum')
#     kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#     return recon_loss + kl_div

# # Dataset Loading
# image_size = (512, 1024)
# transform = transforms.Compose([
#     transforms.Resize(image_size),
#     transforms.ToTensor(),
# ])

# dataset_path = '/kaggle/input/mangoleaf-dataset/dataset'
# dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
# class_names = dataset.classes
# print("Loaded classes:", class_names)

# # Training function
# def train_vae_per_class(class_name, dataset_path, image_size=(256, 512), latent_dim=256, num_epochs=200, batch_size=2, patience=20):
#     transform = transforms.Compose([
#         transforms.Resize(image_size),
#         transforms.ToTensor()
#     ])

#     class_dir = os.path.join(dataset_path, class_name)
#     dataset = datasets.ImageFolder(root=os.path.join(dataset_path), transform=transform)
#     filtered_dataset = [d for d in dataset.samples if class_name in d[0]]
#     dataset.samples = filtered_dataset
#     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     vae = VAE(latent_dim=latent_dim).to(device)
#     optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)

#     best_loss = float("inf")
#     early_stopping_counter = 0
#     os.makedirs("vae_weights", exist_ok=True)

#     print(f"\nTraining VAE for class: {class_name}")
#     for epoch in range(num_epochs):
#         vae.train()
#         total_loss = 0
#         for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
#             images, _ = batch
#             images = images.to(device)
#             optimizer.zero_grad()
#             recon_images, mu, logvar = vae(images)
#             loss = vae_loss(recon_images, images, mu, logvar)
#             loss.backward()
#             optimizer.step()
#             total_loss += loss.item()

#         avg_loss = total_loss / len(dataloader)
#         print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.2f}")

#         # Early Stopping
#         if avg_loss < best_loss:
#             best_loss = avg_loss
#             early_stopping_counter = 0
#             torch.save(vae.state_dict(), f"vae_weights/vae_{class_name}.pth")
#             print(f"Saved best model for {class_name} with loss {best_loss:.2f}")
#         else:
#             early_stopping_counter += 1
#             print(f"No improvement. Early stopping counter: {early_stopping_counter}/{patience}")
#             if early_stopping_counter >= patience:
#                 print(f"Early stopping at epoch {epoch+1}")
#                 break

#         # Generate and show image
#         vae.eval()
#         with torch.no_grad():
#             sample_img = images[0].unsqueeze(0)
#             recon_img, _, _ = vae(sample_img)
#             comparison = torch.cat([sample_img.cpu(), recon_img.cpu()])
#             grid = utils.make_grid(comparison, nrow=2)
#             plt.figure(figsize=(10, 5))
#             plt.axis('off')
#             plt.title(f"{class_name} - Epoch {epoch+1}")
#             plt.imshow(grid.permute(1, 2, 0))
#             plt.show()

#     return vae

# # Usage: Generate random images after training
# def generate_images(vae, batch_size=5):
#     vae.eval()
#     with torch.no_grad():
#         generated_images = vae.generate(batch_size=batch_size)
#         grid = utils.make_grid(generated_images, nrow=batch_size)
#         plt.figure(figsize=(10, 5))
#         plt.imshow(grid.permute(1, 2, 0))
#         plt.axis('off')
#         plt.show()

# # Example: Train the VAE for a specific class
# vae = train_vae_per_class(class_name='DIEBACK', dataset_path=dataset_path, image_size=(256, 512), latent_dim=256, num_epochs=300, batch_size=2)

# # After training, you can generate synthetic images
# generate_images(vae, batch_size=5)  # Generate and show 5 images from random noise


In [None]:
def generate_images(vae, batch_size=5):
    """Sample ``batch_size`` images from the VAE and display them in one row.

    This definition replaces any earlier ``generate_images`` in the kernel and
    converts the image grid to a CPU NumPy array before plotting.
    """
    vae.eval()
    with torch.no_grad():
        samples = vae.generate(batch_size=batch_size)
        sample_grid = utils.make_grid(samples, nrow=batch_size)
        # Channels-last layout on the host, as expected by imshow.
        hwc_array = sample_grid.permute(1, 2, 0).cpu().numpy()
        plt.figure(figsize=(10, 5))
        plt.imshow(hwc_array)
        plt.axis('off')
        plt.show()


In [None]:
# Preview a single sample from the trained model.
generate_images(vae, 1)