In [1]:


import torch
import torch.nn as nn

# Random scenario setup
B = 1       # batch size
z_dim = 5      # latent size
K = 3          # number of classes
e = 2          # embedding dimension
img_dim = 6    # fake image flattened size (toy)

# Create random z and labels
z = torch.randn(B, z_dim)
print(z.shape)

y = torch.randint(0, K, (B,))

# Embedding for classes
embedding = nn.Embedding(K, e)
print(embedding)
y_emb = embedding(y)
print(y_emb.shape)

# Generator input
g_in = torch.cat([z, y_emb], dim=1)
print(g_in.shape)
# Fake generator (just linear for demo)
G = nn.Linear(z_dim + e, img_dim)
x_hat = G(g_in)

# Flatten real "images" (random)
x = torch.randn(B, img_dim)
x_flat = x.view(B, -1)

# Discriminator input
d_in = torch.cat([x_flat, y_emb], dim=1)
D = nn.Linear(img_dim + e, 1)
score = D(d_in)




torch.Size([1, 5])
Embedding(3, 2)
torch.Size([1, 2])
torch.Size([1, 7])


In [3]:
import torch
from src.models.shs_gan.shs_generator import Generator
from src.models.shs_gan.shs_discriminator import Critic3D

def test_shapes():
    gen = Generator()
    critic = Critic3D()
    
    # Test input
    x = torch.randn(2, 3, 224, 224)
    print(f"Generator input: {x.shape}")
    

    fake_hsi = gen(x)
    print(f"Generator output: {fake_hsi.shape}")  # Should be [2, 16, 224, 224]
    
  
    score = critic(fake_hsi)
    print(f"Critic output: {score.shape}")  # Should be [2, 1]
    
    return fake_hsi, score

fake_hsi, score = test_shapes()

Generator input: torch.Size([2, 3, 224, 224])
Generator output: torch.Size([2, 16, 224, 224])
Critic output: torch.Size([2, 1])


In [13]:
# Cell 1: Import dependencies
import torch
import yaml
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf


In [16]:
# Test the dataset and configuration only

# 1. Load and analyze the config file
import yaml
from pathlib import Path

config_path = "/mnt/datahdd/kris_volume/dgm-2025.2/projects/hyperskin/configs/data/hsi_dermoscopy_synth.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("=== CONFIG FILE ANALYSIS ===")
print(f"Class path: {config['data']['class_path']}")
print(f"Image size: {config['data']['init_args']['image_size']}")
print(f"Batch size: {config['data']['init_args']['batch_size']}")
print(f"Allowed labels: {config['data']['init_args']['allowed_labels']}")
print(f"Data directory: {config['data']['init_args']['data_dir']}")

# 2. Fix the interpolation issues in transforms
image_size = config['data']['init_args']['image_size']

# Replace the interpolation variables with actual values
if 'transforms' in config['data']['init_args']:
    transforms = config['data']['init_args']['transforms']
    for stage in ['train', 'val', 'test']:
        if stage in transforms:
            for transform in transforms[stage]:
                if 'init_args' in transform:
                    for key, value in transform['init_args'].items():
                        if isinstance(value, str) and '${data.init_args.image_size}' in value:
                            transform['init_args'][key] = image_size

print("\n=== FIXED TRANSFORMS ===")
print("Replaced interpolation variables with actual values")

# 3. Import and setup the data module
from src.data_modules.hsi_dermoscopy import HSIDermoscopyDataModule

# Use the fixed config directly instead of OmegaConf
init_args = config['data']['init_args']
init_args['class_path'] = "src.data_modules.HSIDermoscopyDataModule"

datamodule = HSIDermoscopyDataModule(**init_args)

print("\n=== DATAMODULE CREATED ===")
print(f"Task: {datamodule.hparams.task}")
print(f"Image size: {datamodule.hparams.image_size}")
print(f"Batch size: {datamodule.hparams.batch_size}")

# 4. Prepare data (download if needed)
print("\n=== PREPARING DATA ===")
datamodule.prepare_data()

# 5. Setup for training
print("\n=== SETUP DATA SPLITS ===")
datamodule.setup(stage='fit')

# 6. Check dataset sizes
print("\n=== DATASET SIZES ===")
print(f"Training samples: {len(datamodule.data_train)}")
print(f"Validation samples: {len(datamodule.data_val)}")
if hasattr(datamodule, 'data_test'):
    print(f"Test samples: {len(datamodule.data_test)}")

# 7. Test one batch from training loader
print("\n=== TESTING ONE BATCH ===")
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

if isinstance(batch, (list, tuple)):
    images, labels = batch
    print(f"Batch images shape: {images.shape}")  # [B, C, H, W]
    print(f"Batch labels shape: {labels.shape}")
    print(f"Image dtype: {images.dtype}")
    print(f"Label dtype: {labels.dtype}")
    print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"Labels: {labels}")
else:
    print(f"Unexpected batch format: {type(batch)}")

print("\n=== CONFIGURATION TEST COMPLETE ===")

=== CONFIG FILE ANALYSIS ===
Class path: data_modules.HSIDermoscopyDataModule
Image size: 256
Batch size: 4
Allowed labels: ['melanoma']
Data directory: data/hsi_dermoscopy

=== FIXED TRANSFORMS ===
Replaced interpolation variables with actual values


TypeError: HSIDermoscopyDataModule.__init__() got an unexpected keyword argument 'class_path'

In [None]:
python src/main.py fit -c configs/data/hsi_dermoscopy_synth.yaml -c configs/model/shs_gan.yaml
d1f0e4f28a2d0dfd6eb977721654e79967e39dfa

In [17]:
%load_ext autoreload
%autoreload 2

import torch
from src.models.shs_gan.shs_generator import Generator
from src.models.shs_gan.shs_discriminator import Critic3D
from src.data_modules.hsi_dermoscopy import HSIDermoscopyDataModule
from omegaconf import OmegaConf


In [19]:
# Load YAML configs
cfg_data = OmegaConf.load("/mnt/datahdd/kris_volume/dgm-2025.2/projects/hyperskin/configs/data/hsi_dermoscopy_synth.yaml")
cfg_model = OmegaConf.load("/mnt/datahdd/kris_volume/dgm-2025.2/projects/hyperskin/configs/model/shs_gan.yaml")

# Print them to verify
print("Data config:")
print(OmegaConf.to_yaml(cfg_data))

print("\nModel config:")
print(OmegaConf.to_yaml(cfg_model))


Data config:
data:
  class_path: data_modules.HSIDermoscopyDataModule
  init_args:
    task: GENERATION
    train_val_test_split:
    - 0.7
    - 0.15
    - 0.15
    image_size: 256
    batch_size: 4
    num_workers: 8
    pin_memory: true
    allowed_labels:
    - melanoma
    data_dir: data/hsi_dermoscopy
    google_drive_id: 1WyIHxY1zh_f3uXwUVRvX9CzuFtfJchmx
    balanced_sampling: false
    transforms:
      train:
      - class_path: SmallestMaxSize
        init_args:
          max_size: ${data.init_args.image_size}
      - class_path: CenterCrop
        init_args:
          height: ${data.init_args.image_size}
          width: ${data.init_args.image_size}
      - class_path: Resize
        init_args:
          height: ${data.init_args.image_size}
          width: ${data.init_args.image_size}
      - class_path: ToTensorV2
        init_args: {}
      val:
      - class_path: SmallestMaxSize
        init_args:
          max_size: ${data.init_args.image_size}
      - class_path: Cent

In [20]:
# Instantiate datamodule from YAML
datamodule = HSIDermoscopyDataModule(**cfg_data)
datamodule.prepare_data()
datamodule.setup("fit")

train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
real_hsi, _ = batch

print(f"Real HSI batch shape: {real_hsi.shape}")


TypeError: HSIDermoscopyDataModule.__init__() got an unexpected keyword argument 'data'

In [4]:
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

from src.modules.generative.gan.wgan import WGANModule
from src.data_modules import HSIDermoscopyDataModule

# ---- CONFIG ----
ckpt_path = None        # or path to a trained checkpoint if you want to load weights
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- instantiate data module ---
data_module = HSIDermoscopyDataModule(
    task="GENERATION",
    train_val_test_split=(0.7, 0.15, 0.15),
    batch_size=4,
    data_dir="data/hsi_dermoscopy_croppedv2_256",
    image_size=64,
)
data_module.setup("fit")

train_loader = data_module.train_dataloader()
real_imgs, _ = next(iter(train_loader))
real_imgs = real_imgs.to(device)

# --- instantiate model ---
model = WGANModule(
    img_channels=16,
    input_channels=1,
    img_size=64,
    constraint_method="gp",  # or "clip"
)
if ckpt_path:
    model = model.load_from_checkpoint(ckpt_path)
model = model.to(device)
model.eval()

# --- generate fake images ---
z = torch.randn_like(real_imgs[:, :model.hparams.input_channels, :, :])
with torch.no_grad():
    fake_imgs = model(z)

# --- check statistics ---
def print_stats(name, tensor):
    print(f"\n{name} stats:")
    print(f"  shape: {tuple(tensor.shape)}")
    print(f"  global min: {tensor.min().item():.4f}")
    print(f"  global max: {tensor.max().item():.4f}")
    print(f"  mean: {tensor.mean().item():.4f}")
    print(f"  std:  {tensor.std().item():.4f}")
    # per-channel stats (first few channels)
    if tensor.size(1) > 1:
        for c in range(min(3, tensor.size(1))):
            print(f"  channel {c}: min={tensor[:, c].min().item():.3f}, max={tensor[:, c].max().item():.3f}")

print_stats("REAL", real_imgs)
print_stats("FAKE", fake_imgs)

# --- visualize one sample ---
def show_tensor_grid(tensor, title):
    grid = make_grid(tensor[:4].detach().cpu(), nrow=4, normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(6, 3))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis("off")
    plt.title(title)
    plt.show()

show_tensor_grid(real_imgs, "Real Images")
show_tensor_grid(fake_imgs, "Fake Images")


FileNotFoundError: One or more data directories do not exist in data/hsi_dermoscopy_croppedv2_256