In [1]:
from torch.utils.data import Dataset
import numpy as np
import os
from glob import glob
import torch

class HyperspectralDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.samples = glob(os.path.join(root_dir, "*", "*.npy"))
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path = self.samples[idx]
        image = np.load(file_path).astype(np.float32)  # shape (H, W, C)
        label_str = os.path.basename(os.path.dirname(file_path))
        label = int(label_str)
        image = np.transpose(image, (2, 0, 1))  # (C, H, W)

        image = torch.from_numpy(image)
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)


## CVAE model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAE(nn.Module):
    def __init__(self, img_channels=250, condition_dim=10, latent_dim=128, hidden_dims=None):
        super(CVAE, self).__init__()
        self.img_channels = img_channels
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256]

        self.condition_embed = nn.Linear(condition_dim, 64)

        # Encoder
        encoder_layers = []
        in_channels = img_channels + 1  # 1 for condition broadcast
        for h_dim in hidden_dims:
            encoder_layers.append(nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1))
            encoder_layers.append(nn.ReLU())
            in_channels = h_dim
        self.encoder = nn.Sequential(*encoder_layers)

        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(hidden_dims[-1]*8*8, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dims[-1]*8*8, latent_dim)


        # Decoder
        self.decoder_input = nn.Linear(latent_dim + 64, hidden_dims[-1] * 4 * 4)

        hidden_dims.reverse()
        decoder_layers = []
        
        for i in range(len(hidden_dims) - 1):
            decoder_layers.append(nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1],
                                                     kernel_size=4, stride=2, padding=1))
            decoder_layers.append(nn.ReLU())
        
        # 32 → 16 → 125
        decoder_layers.append(nn.ConvTranspose2d(hidden_dims[-1], 64, kernel_size=4, stride=2, padding=1))
        decoder_layers.append(nn.ReLU())
        
        decoder_layers.append(nn.ConvTranspose2d(64, img_channels, kernel_size=4, stride=2, padding=1))
        decoder_layers.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers)

#
    def encode(self, x, c):
        B, _, H, W = x.shape
        c_broadcast = c.argmax(dim=1).view(B, 1, 1, 1).float().expand(-1, 1, H, W)
        x_cond = torch.cat([x, c_broadcast], dim=1)
        x_enc = self.encoder(x_cond)
        x_flat = self.flatten(x_enc)
        mu = self.fc_mu(x_flat)
        logvar = self.fc_logvar(x_flat)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        c_embed = self.condition_embed(c)
        zc = torch.cat([z, c_embed], dim=1)
        x = self.decoder_input(zc)
        x = x.view(x.size(0), -1, 4, 4)
        x = self.decoder(x)
        return x

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, c)
        return recon, mu, logvar


## Loss Function

In [3]:
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kld_loss


## Training Loop

In [4]:
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm

def train_cvae(model, dataloader, device, num_epochs=20, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for x, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)
            c = F.one_hot(labels, num_classes=10).float().to(device)

            recon, mu, logvar = model(x, c)
            loss = vae_loss(recon, x, mu, logvar)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f}")
     

## Usage

In [5]:
# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = HyperspectralDataset(root_dir="/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2025p2/Train")
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

model = CVAE(img_channels=125)
train_cvae(model, dataloader, device, num_epochs=10)


Epoch 1/10: 100%|██████████| 131/131 [00:51<00:00,  2.56it/s]


Epoch [1/10] - Loss: nan


Epoch 2/10: 100%|██████████| 131/131 [00:24<00:00,  5.28it/s]


Epoch [2/10] - Loss: nan


Epoch 3/10: 100%|██████████| 131/131 [00:25<00:00,  5.18it/s]


Epoch [3/10] - Loss: nan


Epoch 4/10: 100%|██████████| 131/131 [00:25<00:00,  5.09it/s]


Epoch [4/10] - Loss: nan


Epoch 5/10: 100%|██████████| 131/131 [00:25<00:00,  5.22it/s]


Epoch [5/10] - Loss: nan


Epoch 6/10: 100%|██████████| 131/131 [00:25<00:00,  5.20it/s]


Epoch [6/10] - Loss: nan


Epoch 7/10: 100%|██████████| 131/131 [00:24<00:00,  5.25it/s]


Epoch [7/10] - Loss: nan


Epoch 8/10: 100%|██████████| 131/131 [00:25<00:00,  5.18it/s]


Epoch [8/10] - Loss: nan


Epoch 9/10: 100%|██████████| 131/131 [00:25<00:00,  5.22it/s]


Epoch [9/10] - Loss: nan


Epoch 10/10: 100%|██████████| 131/131 [00:24<00:00,  5.25it/s]

Epoch [10/10] - Loss: nan





## Generation

In [6]:
def generate_samples(model, disease_level, num_samples=50, latent_dim=128, device='cpu'):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(device)
        labels = torch.full((num_samples,), disease_level, dtype=torch.long).to(device)
        c = F.one_hot(labels, num_classes=10).float()
        samples = model.decode(z, c)
        return samples.cpu().numpy()  # (B, C, H, W)
