In [None]:
import os
import csv
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image

# Checkpoint shared by the BLIP preprocessing pipeline and the captioning model.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

# Both objects are module-level because generate_description() reads them as globals.
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)

def generate_description(image_path):
    """Caption one image with the module-level BLIP ``processor`` and ``model``.

    Args:
        image_path: Path of the image file to caption.

    Returns:
        The decoded caption string. If anything fails (unreadable file, model
        error, ...) the error is printed and ``""`` is returned so that a batch
        run can continue with the next file.
    """
    try:
        # The context manager releases the file handle deterministically, even
        # when decoding fails half-way; convert() returns an independent image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs)
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort: report and fall back to an empty caption.
        print(f"Error processing {image_path}: {e}")
        return ""

def create_csv_from_images(image_folder, label_folder, output_csv_path, label_ext=".jpg"):
    """Caption every image in a folder and write an image/label/caption CSV.

    Files in ``image_folder`` whose name ends with .png, .jpg or .jpeg
    (case-insensitive) are processed in sorted order. Each one is paired with a
    label path built from the same base name inside ``label_folder``.

    Args:
        image_folder: Directory containing the input images.
        label_folder: Directory in which the label files are expected.
        output_csv_path: Destination of the CSV (overwritten if it exists).
        label_ext: Extension, including the dot, appended to the image's base
            name to form the label file name. Defaults to ``".jpg"``.

    The label path is not checked for existence. The CSV is written only after
    all captions have been generated.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    rows = [["image_path", "label_path", "text"]]

    for filename in sorted(os.listdir(image_folder)):
        if not filename.lower().endswith(image_exts):
            continue

        image_path = os.path.join(image_folder, filename)
        image_id = os.path.splitext(filename)[0]
        label_path = os.path.join(label_folder, f"{image_id}{label_ext}")

        caption = generate_description(image_path)
        rows.append([image_path, label_path, caption])

    with open(output_csv_path, mode='w', newline='', encoding='utf-8') as file:
        csv.writer(file).writerows(rows)
    print(f"CSV saved to {output_csv_path}")

# Example usage: caption the photos and pair each with its same-named sketch.
image_folder = "/kaggle/input/final-dataset/photo/photo"
label_folder = "/kaggle/input/final-dataset/sketch/sketch"  # Adjust based on your label location
output_csv = "/kaggle/working/image_caption_with_paths.csv"
create_csv_from_images(
    image_folder=image_folder,
    label_folder=label_folder,
    output_csv_path=output_csv,
)

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

# Load your CSV file
df = pd.read_csv('/kaggle/input/newversion/image_caption.csv')  # Replace with your actual CSV file path

# Split the data into train (95%) and test (5%); the fixed seed keeps the split reproducible
train_df, test_df = train_test_split(df, test_size=0.05, random_state=42, shuffle=True)


print(f"Train size: {len(train_df)}, Test size: {len(test_df)}")

# Lower-case the stored paths. `assign` builds new frames instead of writing
# into the objects returned by train_test_split, which avoids pandas'
# SettingWithCopyWarning. Both splits get the same normalisation so that the
# test paths resolve exactly like the training paths.
# NOTE(review): this assumes the files on disk have all-lowercase paths - confirm.
train_df = train_df.assign(
    image_path=train_df['image_path'].str.lower(),
    label_path=train_df['label_path'].str.lower(),
)
test_df = test_df.assign(
    image_path=test_df['image_path'].str.lower(),
    label_path=test_df['label_path'].str.lower(),
)

In [None]:
import torch
import random

class AddGaussianNoise:
    """Tensor transform that adds Gaussian noise with probability ``p``.

    Args:
        mean: Mean of the noise.
        std: Standard deviation of the noise.
        p: Probability that noise is applied on a given call.
    """

    def __init__(self, mean=0.0, std=0.05, p=0.5):
        self.mean, self.std, self.p = mean, std, p

    def __call__(self, tensor):
        # One draw from Python's `random` decides whether this call is augmented.
        if random.random() >= self.p:
            return tensor
        return tensor + (torch.randn_like(tensor) * self.std + self.mean)

import cv2
import numpy as np
from torchvision import transforms
from PIL import Image

class EnhanceCLAHEGamma:
    """PIL-image transform: CLAHE on the LAB lightness channel, then a gamma LUT.

    Args:
        gamma: Gamma value; the lookup table uses the exponent ``1 / gamma``.
        clip_limit: ``clipLimit`` passed to ``cv2.createCLAHE``.
        tile_grid_size: ``tileGridSize`` passed to ``cv2.createCLAHE``.
    """

    def __init__(self, gamma=1.0, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.gamma = gamma
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, img):
        bgr = self._as_bgr(np.array(img))  # PIL to NumPy, then to OpenCV channel order
        equalized = self._equalize_lightness(bgr)
        corrected = cv2.LUT(equalized, self._gamma_table())
        # Back to RGB channel order and to a PIL image.
        return Image.fromarray(cv2.cvtColor(corrected, cv2.COLOR_BGR2RGB))

    @staticmethod
    def _as_bgr(pixels):
        """Convert a grayscale, RGBA or RGB array to 3-channel BGR."""
        if len(pixels.shape) == 2:
            code = cv2.COLOR_GRAY2BGR
        elif pixels.shape[2] == 4:
            code = cv2.COLOR_RGBA2BGR
        else:
            code = cv2.COLOR_RGB2BGR
        return cv2.cvtColor(pixels, code)

    def _equalize_lightness(self, bgr):
        """Apply CLAHE to the L channel of the LAB representation only."""
        lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB))
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
        merged = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
        return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)

    def _gamma_table(self):
        """Build the 256-entry uint8 lookup table for the current gamma."""
        invGamma = 1.0 / self.gamma
        return np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(256)]).astype("uint8")

In [None]:
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class CustomImageDataset(Dataset):
    """Dataset of (image, label image, caption) triples described by a DataFrame.

    The frame must provide the columns ``image_path``, ``label_path`` and
    ``text``.

    Args:
        dataframe: Source rows; its index is discarded and renumbered from 0.
        transform: Optional callable applied to the RGB input image. When it is
            ``None`` the image is returned as a PIL image.
        label_transform: Optional callable applied to the label image. When it
            is ``None`` the label is converted with ``transforms.ToTensor()``.
    """

    def __init__(self, dataframe, transform=None, label_transform=None):
        # reset_index returns a new frame, so the clean-up below never touches
        # the caller's DataFrame.
        self.data = dataframe.reset_index(drop=True)
        if 'text' in self.data.columns:
            # An empty caption in the CSV is read back by pandas as NaN (a
            # float). Normalise every caption to a string so consumers always
            # receive text.
            self.data['text'] = self.data['text'].fillna("").astype(str)
        self.transform = transform
        self.label_transform = label_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label, caption)`` for row ``idx``."""
        image = Image.open(self.data.loc[idx, 'image_path']).convert('RGB')
        label = Image.open(self.data.loc[idx, 'label_path'])  # mode is left as stored on disk

        if self.transform:
            image = self.transform(image)
        if self.label_transform:
            label = self.label_transform(label)
        else:
            label = transforms.ToTensor()(label)  # default label tensor conversion

        caption = self.data.loc[idx, 'text']

        return image, label, caption


# Training-time photo pipeline (includes random noise augmentation)
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),    # Ensures square aspect ratio
    EnhanceCLAHEGamma(gamma=0.8),   # You can tweak gamma value
    transforms.ToTensor(),
    AddGaussianNoise(std=0.03, p=0.5),  # After ToTensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5],  # GAN-friendly [-1,1] range
                         std=[0.5, 0.5, 0.5])
])

# Evaluation pipeline: identical preprocessing but without the random noise,
# so test results are deterministic and not degraded by augmentation.
eval_image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    EnhanceCLAHEGamma(gamma=0.8),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

# NOTE(review): labels stay in ToTensor's [0, 1] range while the photos are
# normalised to [-1, 1] - confirm this matches what the losses expect.
label_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3-channel
    transforms.ToTensor()
])

# Create datasets from DataFrames
train_dataset = CustomImageDataset(train_df, transform=image_transform, label_transform=label_transform)
test_dataset = CustomImageDataset(test_df, transform=eval_image_transform, label_transform=label_transform)


from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from transformers import CLIPModel, CLIPTokenizer

class FeatureExtractor(nn.Module):
    """Encodes an image batch and its captions into feature vectors.

    Images pass through an ImageNet-pretrained MobileNetV3-Small whose
    classifier is replaced by ``nn.Identity``, followed by a trainable
    576 -> 512 projection. Captions pass through a frozen CLIP text encoder.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()

        # Image encoder with partial fine-tuning. `weights=` replaces the
        # deprecated `pretrained=True` flag and selects the same checkpoint.
        self.image_encoder = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
        )
        self.image_encoder.classifier = nn.Identity()

        # Freeze the first 100 parameter tensors (weights and biases count
        # separately, so this is fewer than 100 layers); later ones fine-tune.
        for i, param in enumerate(self.image_encoder.parameters()):
            if i < 100:
                param.requires_grad = False

        # CLIP text encoder and its tokenizer
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        # Freeze CLIP entirely as it's already well-trained
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Projection layer with regularization
        self.image_projection = nn.Sequential(
            nn.Linear(576, 512),
            nn.Dropout(0.1),
            nn.LayerNorm(512)
        )

    def forward(self, image, text):
        """Return ``(image_features, text_features)`` for a batch.

        Args:
            image: Image batch tensor fed to the MobileNet encoder.
            text: Caption(s) accepted by the CLIP tokenizer.
        """
        image_features = self.image_encoder(image)
        image_features = self.image_projection(image_features)

        # Text features: CLIP is frozen, so no graph is needed for this branch.
        with torch.no_grad():
            text_inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
            # Tokenizer output is created on CPU; move it next to the images.
            text_inputs = {k: v.to(image.device) for k, v in text_inputs.items()}
            text_features = self.clip_model.get_text_features(**text_inputs)

        return image_features, text_features


import torch
import torch.nn as nn
import torch.nn.functional as F

class AttributeAttention(nn.Module):
    """Gates image features with a score computed from image and text features.

    Args:
        feature_dim: Size of the last dimension of both feature inputs.
    """

    def __init__(self, feature_dim=512):
        super(AttributeAttention, self).__init__()

        # Project image and text features separately
        self.image_proj = nn.Linear(feature_dim, feature_dim)
        self.text_proj = nn.Linear(feature_dim, feature_dim)

        # Layer Normalization for stability
        self.layer_norm = nn.LayerNorm(feature_dim)

        # Attention scoring network
        self.attention_layer = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(feature_dim // 2, 1)
        )

    def forward(self, image_features, text_features):
        """Return the gated image features with two trailing singleton dims.

        For ``(batch, feature_dim)`` inputs the result is
        ``(batch, feature_dim, 1, 1)``.
        """
        # Project features first
        image_proj = self.image_proj(image_features)
        text_proj = self.text_proj(text_features)

        # Combine (add) and normalize
        combined = self.layer_norm(image_proj + text_proj)

        # One scalar score per feature vector: (..., 1)
        attention_scores = self.attention_layer(combined)

        if attention_scores.dim() == 2:
            # One vector per sample: a softmax here would normalise across the
            # batch, making every output depend on the other samples and on the
            # batch size. Use an independent per-sample sigmoid gate instead.
            attention_weights = torch.sigmoid(attention_scores)
        else:
            # Sequence input: normalise over the second-to-last dimension.
            attention_weights = F.softmax(attention_scores.squeeze(-1), dim=-1).unsqueeze(-1)

        # Scale the (unprojected) image features by their weight
        attended = image_features * attention_weights

        return attended.unsqueeze(-1).unsqueeze(-1)  # For CNN compatibility


class MultiModalFusion(nn.Module):
    """Turns an image batch and its captions into a 16x16 fused feature map.

    The attended feature vector is tiled over a 16x16 grid with ``expand`` and
    then passed through channel-wise dropout.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.attention = AttributeAttention()
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, image, text):
        image_vec, text_vec = self.feature_extractor(image, text)
        # Attention output ends in two singleton dims; broadcast them to 16x16.
        tiled = self.attention(image_vec, text_vec).expand(-1, -1, 16, 16)
        return self.dropout(tiled)


class CoarseGenerator(nn.Module):
    """Small encoder-decoder producing a 3-channel image at the input resolution.

    Two stride-2 convolutions downsample, two stride-2 transposed convolutions
    upsample, and the first encoder activation is concatenated into the last
    decoder stage as a skip connection. The output passes through ``Tanh``.
    """

    def __init__(self):
        super().__init__()
        sn = nn.utils.spectral_norm

        # Encoder with spectral normalization
        self.enc1 = nn.Sequential(
            sn(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.1),
        )
        self.enc2 = nn.Sequential(
            sn(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
        )

        # Decoder; dec2 takes 128 channels because dec1's 64 are joined by enc1's 64
        self.dec1 = nn.Sequential(
            sn(nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
        )
        self.dec2 = nn.Sequential(
            sn(nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)),
            nn.Tanh(),
        )

    def forward(self, x):
        # A 3-D input is treated as a single unbatched image.
        batch = x.unsqueeze(0) if x.dim() == 3 else x

        skip = self.enc1(batch)
        bottleneck = self.enc2(skip)
        upsampled = self.dec1(bottleneck)

        return self.dec2(torch.cat([upsampled, skip], dim=1))

 

class ResidualBlock(nn.Module):
    """Residual unit: two spectral-normalised 3x3 convolutions plus identity skip.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.utils.spectral_norm(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            )

        layers = [
            conv3x3(),
            nn.InstanceNorm2d(channels),
            nn.ReLU(),
            nn.Dropout2d(0.1),
            conv3x3(),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescales each channel by a learned factor.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor for the bottleneck width. The width is clamped to at
            least 1 so that ``channels < reduction`` still yields a usable
            block instead of a zero-channel convolution.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        squeezed = max(1, channels // reduction)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, squeezed, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # One factor in (0, 1) per channel, broadcast over the spatial dims.
        scale = self.fc(self.global_avg_pool(x))
        return x * scale

class DilatedResidualBlock(nn.Module):
    """Residual unit built from two dilated, spectral-normalised 3x3 convolutions.

    Padding equals the dilation, so the spatial size is preserved.

    Args:
        channels: Number of input (and output) channels.
        dilation: Dilation (and padding) of both convolutions.
    """

    def __init__(self, channels, dilation=2):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)

        first_conv = nn.utils.spectral_norm(nn.Conv2d(channels, channels, **conv_kwargs))
        second_conv = nn.utils.spectral_norm(nn.Conv2d(channels, channels, **conv_kwargs))

        self.conv_block = nn.Sequential(
            first_conv,
            nn.InstanceNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            second_conv,
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class RefinementGenerator(nn.Module):
    """Refines a coarse sketch that has been concatenated with fused features.

    Args:
        in_channels: Channel count of the input. The default of 515 matches
            the 3-channel coarse sketch concatenated with 512 fused feature
            channels.

    The spatial size is preserved; the output has 3 channels squashed by
    ``Tanh``.
    """

    def __init__(self, in_channels=515):
        super(RefinementGenerator, self).__init__()

        # Initial reduction
        self.initial = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channels, 256, kernel_size=1)),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.2)
        )

        # Encoder with dilated residual blocks and SE attention
        self.encoder = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(256, 128, kernel_size=3, padding=1)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            DilatedResidualBlock(128, dilation=2),
            SEBlock(128),

            nn.utils.spectral_norm(nn.Conv2d(128, 64, kernel_size=3, padding=1)),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            DilatedResidualBlock(64, dilation=2),
            SEBlock(64)
        )

        # Output layer
        self.output_layer = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(64, 3, kernel_size=3, padding=1)),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.encoder(x)
        return self.output_layer(x)


import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Convolutional real/fake discriminator for 3-channel images.

    Three stride-2 convolutions reduce the resolution by a factor of 8; a final
    1x1 convolution followed by a sigmoid yields one score in (0, 1) per
    remaining spatial location.
    """

    def __init__(self):
        super().__init__()
        sn = nn.utils.spectral_norm

        layers = [
            sn(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.1),

            sn(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),

            sn(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.1),

            # 1x1 convolution head: one real/fake logit per location
            sn(nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0)),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores


class DAASS(nn.Module):
    """Two-stage sketch generator conditioned on an image and its caption.

    A coarse sketch is generated from the image alone; it is then concatenated
    with fused image/text features and passed to the refinement generator.
    """

    def __init__(self):
        super(DAASS, self).__init__()

        self.fusion = MultiModalFusion()
        self.coarse_gen = CoarseGenerator()
        self.refine_gen = RefinementGenerator()

        # Initialize the newly created layers only
        self._initialize_weights()

    @staticmethod
    def _init_target(module):
        """Return the tensor that actually holds ``module``'s learnable weight.

        With ``nn.utils.spectral_norm`` the parameter lives in ``weight_orig``
        and ``weight`` is recomputed from it before every forward pass, so
        initialising ``weight`` would have no lasting effect.
        """
        return module.weight_orig if hasattr(module, 'weight_orig') else module.weight

    def _initialize_weights(self):
        """Initialise conv/linear layers, skipping the pretrained encoders."""
        # Walking self.modules() also reaches the pretrained image encoder and
        # CLIP model inside the fusion branch; re-initialising them would throw
        # away their pretrained weights, so their modules are excluded.
        extractor = self.fusion.feature_extractor
        pretrained_ids = set()
        for attr in ('image_encoder', 'clip_model'):
            backbone = getattr(extractor, attr, None)
            if backbone is not None:
                pretrained_ids.update(id(sub) for sub in backbone.modules())

        for m in self.modules():
            if id(m) in pretrained_ids:
                continue
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(self._init_target(m), mode='fan_out', nonlinearity='leaky_relu')
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(self._init_target(m))
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, image, text):
        """Return the refined sketch for ``image`` conditioned on ``text``."""
        # Feature fusion
        fused_features = self.fusion(image, text)

        # Coarse generation
        coarse_sketch = self.coarse_gen(image)

        # Bring the fused map to the coarse sketch's spatial size
        fused_features_upsampled = F.interpolate(
            fused_features,
            size=coarse_sketch.shape[2:],
            mode='bilinear',
            align_corners=False
        )

        # Refinement on the channel-wise concatenation
        combined_input = torch.cat([coarse_sketch, fused_features_upsampled], dim=1)
        refined_sketch = self.refine_gen(combined_input)

        return refined_sketch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from transformers import CLIPModel, CLIPTokenizer

# FeatureExtractor (unchanged)
class FeatureExtractor(nn.Module):
    """Encodes an image batch and its captions into feature vectors.

    Images pass through an ImageNet-pretrained MobileNetV3-Small whose
    classifier is replaced by ``nn.Identity``, followed by a trainable
    576 -> 512 projection. Captions pass through a frozen CLIP text encoder.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag and selects
        # the same ImageNet checkpoint.
        self.image_encoder = models.mobilenet_v3_small(
            weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
        )
        self.image_encoder.classifier = nn.Identity()
        # Freeze the first 100 parameter tensors (weights and biases count
        # separately); the remaining ones stay trainable.
        for i, param in enumerate(self.image_encoder.parameters()):
            if i < 100:
                param.requires_grad = False
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        # CLIP is used as a fixed text encoder.
        for param in self.clip_model.parameters():
            param.requires_grad = False
        self.image_projection = nn.Sequential(
            nn.Linear(576, 512),
            nn.Dropout(0.1),
            nn.LayerNorm(512)
        )

    def forward(self, image, text):
        """Return ``(image_features, text_features)`` for a batch.

        Args:
            image: Image batch tensor fed to the MobileNet encoder.
            text: Caption(s) accepted by the CLIP tokenizer.
        """
        image_features = self.image_encoder(image)
        image_features = self.image_projection(image_features)
        # CLIP is frozen, so no graph is needed for the text branch.
        with torch.no_grad():
            text_inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
            # Tokenizer output is created on CPU; move it next to the images.
            text_inputs = {k: v.to(image.device) for k, v in text_inputs.items()}
            text_features = self.clip_model.get_text_features(**text_inputs)
        return image_features, text_features

# AttributeAttention (unchanged)
class AttributeAttention(nn.Module):
    """Gates image features with a score computed from image and text features.

    Args:
        feature_dim: Size of the last dimension of both feature inputs.
    """

    def __init__(self, feature_dim=512):
        super(AttributeAttention, self).__init__()
        self.image_proj = nn.Linear(feature_dim, feature_dim)
        self.text_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)
        # Scoring network: one scalar per combined feature vector.
        self.attention_layer = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(feature_dim // 2, 1)
        )

    def forward(self, image_features, text_features):
        """Return the gated image features with two trailing singleton dims.

        For ``(batch, feature_dim)`` inputs the result is
        ``(batch, feature_dim, 1, 1)``.
        """
        image_proj = self.image_proj(image_features)
        text_proj = self.text_proj(text_features)
        combined = self.layer_norm(image_proj + text_proj)
        attention_scores = self.attention_layer(combined)  # (..., 1)
        if attention_scores.dim() == 2:
            # One vector per sample: a softmax here would normalise across the
            # batch, making every output depend on the other samples and on the
            # batch size. Use an independent per-sample sigmoid gate instead.
            attention_weights = torch.sigmoid(attention_scores)
        else:
            # Sequence input: normalise over the second-to-last dimension.
            attention_weights = F.softmax(attention_scores.squeeze(-1), dim=-1).unsqueeze(-1)
        attended = image_features * attention_weights
        return attended.unsqueeze(-1).unsqueeze(-1)

# MultiModalFusion (unchanged)
class MultiModalFusion(nn.Module):
    """Turns an image batch and its captions into a fused feature map.

    The attended feature vector is resized to the spatial size of the input
    image with bilinear interpolation and then passed through channel-wise
    dropout.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.attention = AttributeAttention()
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, image, text):
        image_vec, text_vec = self.feature_extractor(image, text)
        gated = self.attention(image_vec, text_vec)
        # Spread the 1x1 attention output over the image's height and width.
        spatial = F.interpolate(gated, size=image.shape[2:], mode='bilinear', align_corners=False)
        return self.dropout(spatial)

# ResidualBlock (unchanged)
class ResidualBlock(nn.Module):
    """Residual unit: two spectral-normalised 3x3 convolutions plus identity skip.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.utils.spectral_norm(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            )

        layers = [
            conv3x3(),
            nn.InstanceNorm2d(channels),
            nn.ReLU(),
            nn.Dropout2d(0.1),
            conv3x3(),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# SEBlock (unchanged)
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescales each channel by a learned factor.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor for the bottleneck width. The width is clamped to at
            least 1 so that ``channels < reduction`` still yields a usable
            block instead of a zero-channel convolution.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        squeezed = max(1, channels // reduction)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, squeezed, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # One factor in (0, 1) per channel, broadcast over the spatial dims.
        scale = self.fc(self.global_avg_pool(x))
        return x * scale

# DilatedResidualBlock (unchanged)
class DilatedResidualBlock(nn.Module):
    """Residual unit built from two dilated, spectral-normalised 3x3 convolutions.

    Padding equals the dilation, so the spatial size is preserved.

    Args:
        channels: Number of input (and output) channels.
        dilation: Dilation (and padding) of both convolutions.
    """

    def __init__(self, channels, dilation=2):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)

        first_conv = nn.utils.spectral_norm(nn.Conv2d(channels, channels, **conv_kwargs))
        second_conv = nn.utils.spectral_norm(nn.Conv2d(channels, channels, **conv_kwargs))

        self.conv_block = nn.Sequential(
            first_conv,
            nn.InstanceNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            second_conv,
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

# RefinementGenerator (unchanged)
class RefinementGenerator(nn.Module):
    """Refines a coarse sketch that has been concatenated with fused features.

    Args:
        in_channels: Channel count of the input. The default of 515 matches
            the 3-channel coarse sketch concatenated with 512 fused feature
            channels.

    The spatial size is preserved; the output has 3 channels squashed by
    ``Tanh``.
    """

    def __init__(self, in_channels=515):
        super(RefinementGenerator, self).__init__()
        # 1x1 convolution that reduces the wide input to 256 channels
        self.initial = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channels, 256, kernel_size=1)),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.2)
        )
        # Two stages, each: 3x3 conv -> dilated residual block -> SE gate
        self.encoder = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(256, 128, kernel_size=3, padding=1)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            DilatedResidualBlock(128, dilation=2),
            SEBlock(128),
            nn.utils.spectral_norm(nn.Conv2d(128, 64, kernel_size=3, padding=1)),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            DilatedResidualBlock(64, dilation=2),
            SEBlock(64)
        )
        self.output_layer = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(64, 3, kernel_size=3, padding=1)),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.encoder(x)
        return self.output_layer(x)

# Discriminator (unchanged)
class Discriminator(nn.Module):
    """Convolutional real/fake discriminator for 3-channel images.

    Three stride-2 convolutions reduce the resolution by a factor of 8; a final
    1x1 convolution followed by a sigmoid yields one score in (0, 1) per
    remaining spatial location.
    """

    def __init__(self):
        super().__init__()
        sn = nn.utils.spectral_norm

        layers = [
            sn(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.1),
            sn(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            sn(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(0.2),
            # 1x1 convolution head: one real/fake logit per location
            sn(nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0)),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

# RSUBlock (unchanged)
class RSUBlock(nn.Module):
    """Residual U-block (RSU) for U^2-Net.

    A small encoder-decoder nested inside a residual connection: the input is
    first mapped to ``out_channels``, run through ``num_layers`` encoder
    convolutions (max-pooled between levels), a dilated bottleneck, and a
    decoder that upsamples and concatenates the stored encoder features. The
    result is added to the first convolution's output.

    Args:
        in_channels: Channels of the input tensor.
        mid_channels: Base width of the inner encoder/decoder.
        out_channels: Channels of the output tensor.
        num_layers: Number of encoder convolutions.
    """

    def __init__(self, in_channels, mid_channels, out_channels, num_layers=4):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

        # Initial convolution
        self.conv1 = self._conv_unit(in_channels, out_channels)

        # Encoder path: conv units interleaved with pooling (none after the last)
        widths = [mid_channels] + [mid_channels * 2] * (num_layers - 1)
        self.encoder = nn.ModuleList()
        prev_width = out_channels
        for depth in range(num_layers):
            self.encoder.append(self._conv_unit(prev_width, widths[depth]))
            if depth < num_layers - 1:
                self.encoder.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
            prev_width = widths[depth]

        # Dilated convolution at bottleneck
        self.bottleneck = self._conv_unit(prev_width, prev_width, dilation=2)

        # Decoder path (conv units only; upsampling happens in forward)
        dec_in = [prev_width] + [mid_channels * 2] * (num_layers - 2)
        dec_out = [mid_channels * 2] * (num_layers - 2) + [mid_channels]
        skip_widths = widths[:-1][::-1]
        self.decoder = nn.ModuleList(
            self._conv_unit(dec_in[i] + skip_widths[i], dec_out[i])
            for i in range(num_layers - 1)
        )

        # Final convolution over decoder output concatenated with conv1 output
        self.conv_out = self._conv_unit(mid_channels + out_channels, out_channels)

    @staticmethod
    def _conv_unit(cin, cout, dilation=1):
        """Spectral-normalised 3x3 conv -> InstanceNorm -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(cin, cout, kernel_size=3, padding=dilation, dilation=dilation)
            ),
            nn.InstanceNorm2d(cout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        stem = self.conv1(x)
        feats = [stem]
        h = stem
        for level in range(self.num_layers):
            h = self.encoder[2 * level](h)
            if level < self.num_layers - 1:
                h = self.encoder[2 * level + 1](h)
            # Features are stored after pooling (when a pooling step exists).
            feats.append(h)

        h = self.bottleneck(feats[-1])

        # Walk the stored features from second-to-last down to index 1.
        for dec, skip in zip(self.decoder, feats[-2:0:-1]):
            # Upsample to match the skip connection size
            h = F.interpolate(h, size=skip.shape[2:], mode='bilinear', align_corners=False)
            h = dec(torch.cat([h, skip], dim=1))

        # Final upsampling to the resolution of the initial convolution
        h = F.interpolate(h, size=stem.shape[2:], mode='bilinear', align_corners=False)
        return self.conv_out(torch.cat([h, stem], dim=1)) + stem

# Updated U2NetCoarseGenerator
class U2NetCoarseGenerator(nn.Module):
    """U^2-Net-based coarse generator.

    Six RSU encoder stages separated by max-pooling, five RSU decoder stages
    that each receive the upsampled, channel-reduced deeper output added to the
    matching encoder output, and a 3-channel ``Tanh`` output layer.

    Args:
        model_type: ``'u2net'`` selects the wide channel configuration; any
            other value selects the narrow one.
    """

    def __init__(self, model_type='u2net'):
        super().__init__()
        self.model_type = model_type

        # Channel configurations
        if model_type == 'u2net':
            cfg, mid_cfg = [64, 128, 256, 512, 512, 512], [32, 32, 64, 128, 256, 256]
        else:  # u2netp
            cfg, mid_cfg = [16, 32, 64, 128, 256, 256], [16, 16, 32, 64, 128, 128]

        def halve():
            return nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        # Encoder stages
        self.stage1 = RSUBlock(3, mid_cfg[0], cfg[0], num_layers=4)
        self.pool12 = halve()
        self.stage2 = RSUBlock(cfg[0], mid_cfg[1], cfg[1], num_layers=4)
        self.pool23 = halve()
        self.stage3 = RSUBlock(cfg[1], mid_cfg[2], cfg[2], num_layers=4)
        self.pool34 = halve()
        self.stage4 = RSUBlock(cfg[2], mid_cfg[3], cfg[3], num_layers=3)
        self.pool45 = halve()
        self.stage5 = RSUBlock(cfg[3], mid_cfg[4], cfg[4], num_layers=3)
        self.pool56 = halve()
        self.stage6 = RSUBlock(cfg[4], mid_cfg[5], cfg[5], num_layers=3)

        # Decoder stages
        self.stage5d = RSUBlock(cfg[4], mid_cfg[4], cfg[4], num_layers=3)
        self.stage4d = RSUBlock(cfg[3], mid_cfg[3], cfg[3], num_layers=3)
        self.stage3d = RSUBlock(cfg[2], mid_cfg[2], cfg[2], num_layers=4)
        self.stage2d = RSUBlock(cfg[1], mid_cfg[1], cfg[1], num_layers=4)
        self.stage1d = RSUBlock(cfg[0], mid_cfg[0], cfg[0], num_layers=4)

        # 1x1 channel reductions applied to the deeper output before each merge
        self.conv5d = self._reduce(cfg[5], cfg[4])
        self.conv4d = self._reduce(cfg[4], cfg[3])
        self.conv3d = self._reduce(cfg[3], cfg[2])
        self.conv2d = self._reduce(cfg[2], cfg[1])
        self.conv1d = self._reduce(cfg[1], cfg[0])

        # Output layer
        self.output_layer = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(cfg[0], 3, kernel_size=3, padding=1)),
            nn.Tanh(),
        )

    @staticmethod
    def _reduce(cin, cout):
        """Spectral-normalised 1x1 conv -> InstanceNorm -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(cin, cout, kernel_size=1)),
            nn.InstanceNorm2d(cout),
            nn.LeakyReLU(0.2, inplace=True),
        )

    @staticmethod
    def _merge(deep, skip, reduce):
        """Upsample ``deep`` to ``skip``'s size, reduce channels, add ``skip``."""
        resized = F.interpolate(deep, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return reduce(resized) + skip

    def forward(self, x):
        # A 3-D input is treated as a single unbatched image.
        if x.dim() == 3:
            x = x.unsqueeze(0)

        # Encoder path
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # Decoder path
        hx5d = self.stage5d(self._merge(hx6, hx5, self.conv5d))
        hx4d = self.stage4d(self._merge(hx5d, hx4, self.conv4d))
        hx3d = self.stage3d(self._merge(hx4d, hx3, self.conv3d))
        hx2d = self.stage2d(self._merge(hx3d, hx2, self.conv2d))
        hx1d = self.stage1d(self._merge(hx2d, hx1, self.conv1d))

        # Output
        return self.output_layer(hx1d)

# DAASS (unchanged)
class DAASS(nn.Module):
    """Two-stage generator: a coarse prediction followed by a refinement pass.

    Sub-modules:
        fusion     -- ``MultiModalFusion``; called with the image and its text.
        coarse_gen -- ``U2NetCoarseGenerator``; called with the image only.
        refine_gen -- ``RefinementGenerator``; called with the coarse output
                      concatenated (channel axis) with the resized fused features.
    """

    def __init__(self, coarse_model_type='u2net'):
        super().__init__()
        self.fusion = MultiModalFusion()
        self.coarse_gen = U2NetCoarseGenerator(model_type=coarse_model_type)
        self.refine_gen = RefinementGenerator()
        # Runs over every sub-module created above, including the coarse generator.
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise conv and linear layers in the whole module tree.

        Conv2d / ConvTranspose2d weights get Kaiming-normal init (fan_out,
        leaky_relu gain); Linear weights get Xavier-normal init. Any bias that
        exists on those layers is set to zero. Other module types are untouched.
        """
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
            else:
                continue

            bias = getattr(module, 'bias', None)
            if bias is not None:
                nn.init.constant_(bias, 0)

    def forward(self, image, text):
        """Return the refined output for `image` conditioned on `text`."""
        fused = self.fusion(image, text)
        coarse = self.coarse_gen(image)

        # Bring the fused features to the coarse output's spatial size so the
        # two can be stacked along the channel axis.
        fused_resized = F.interpolate(
            fused,
            size=coarse.shape[2:],
            mode='bilinear',
            align_corners=False
        )
        refine_input = torch.cat([coarse, fused_resized], dim=1)
        return self.refine_gen(refine_input)

New: perceptual loss, metrics, training and evaluation cells

In [None]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor

class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: summed L1 distance between frozen VGG19 feature maps.

    Both inputs are normalised with the ImageNet mean/std and passed through
    the ImageNet-pretrained ``vgg19().features`` stack; the loss is the sum,
    over the requested layers, of the mean absolute difference between the two
    feature maps. Inputs must have 3 channels (the mean/std buffers are
    3-channel); callers should pass images scaled to [0, 1].

    Args:
        layers: Layer name, or iterable of layer names, to compare. Names
            follow the ``conv{block}_{n}`` / ``relu{block}_{n}`` /
            ``pool{block}`` convention, derived from the actual VGG19 layout
            (blocks 3-5 have four convolutions each, e.g. ``relu3_4``).

    Raises:
        ValueError: if `layers` is empty or contains a name that does not
            exist in VGG19.
    """

    def __init__(self, layers=('relu3_3',)):
        super().__init__()

        # Accept a single name as well as an iterable; copy so the caller's
        # sequence is never shared with the module.
        layers = [layers] if isinstance(layers, str) else list(layers)
        if not layers:
            raise ValueError("VGGPerceptualLoss needs at least one layer name")

        # Load pre-trained VGG19 feature stack and freeze it.
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()
        for param in vgg.parameters():
            param.requires_grad = False

        # Derive index -> name from the real module sequence. A hand-written
        # table is easy to get wrong: the VGG16 layout (three convs in blocks
        # 3-5) mislabels every VGG19 layer from index 16 onwards.
        name_by_index = self._name_vgg_layers(vgg)
        known_names = set(name_by_index.values())
        unknown = [name for name in layers if name not in known_names]
        if unknown:
            raise ValueError(
                f"Unknown VGG19 layer name(s): {unknown}. "
                f"Valid names: {sorted(known_names)}"
            )

        # Only the requested layers are returned by the extractor.
        return_nodes = {idx: name for idx, name in name_by_index.items() if name in layers}
        self.feature_extractor = create_feature_extractor(vgg, return_nodes=return_nodes)

        # Layers whose features contribute to the loss.
        self.selected_layers = layers

        # ImageNet normalization constants, broadcastable over (N, 3, H, W).
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @staticmethod
    def _name_vgg_layers(features):
        """Map each child index of a VGG ``features`` stack to its layer name.

        Walks the children in order: every Conv2d increments the in-block
        counter, a ReLU takes the number of the conv before it, and a
        MaxPool2d closes the current block.
        """
        names = {}
        block, conv = 1, 0
        for index, layer in features.named_children():
            if isinstance(layer, nn.Conv2d):
                conv += 1
                names[index] = f"conv{block}_{conv}"
            elif isinstance(layer, nn.ReLU):
                names[index] = f"relu{block}_{conv}"
            elif isinstance(layer, nn.MaxPool2d):
                names[index] = f"pool{block}"
                block += 1
                conv = 0
        return names

    def normalize(self, x):
        """Apply ImageNet mean/std normalisation."""
        return (x - self.mean) / self.std

    def forward(self, x, y):
        """Return the summed per-layer L1 feature distance between `x` and `y`."""
        x_feats = self.feature_extractor(self.normalize(x))
        y_feats = self.feature_extractor(self.normalize(y))

        loss = 0.0
        for layer in self.selected_layers:
            loss += nn.functional.l1_loss(x_feats[layer], y_feats[layer])
        return loss

In [None]:
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import numpy as np
def calculate_foreground_accuracy(fake_imgs, real_imgs, threshold=0.3):
    """Fraction of ground-truth foreground pixels that the prediction also marks.

    Both inputs are rescaled with ``(x + 1) / 2`` (i.e. treated as [-1, 1]
    images) and binarised with ``>= threshold``. Only positions where the
    binarised `real_imgs` is 1 are scored.

    NOTE(review): "foreground" here means pixels at or above the threshold.
    For dark strokes on a light background that selects the paper rather than
    the strokes -- confirm this is the intended definition.

    Args:
        fake_imgs: predicted images.
        real_imgs: ground-truth images, broadcast-compatible with `fake_imgs`.
        threshold: binarisation cut-off applied after rescaling.

    Returns:
        0-dim tensor on the device of `real_imgs`. When the ground truth has
        no foreground pixel at all the result is defined as 1.0.
    """
    with torch.no_grad():
        # Rescale to [0, 1], then binarise.
        fake_bin = (fake_imgs + 1) / 2 >= threshold
        real_bin = (real_imgs + 1) / 2 >= threshold

        total_foreground = real_bin.sum()
        if total_foreground == 0:
            # Keep the result on the same device as the normal path so callers
            # can combine the two without a device mismatch.
            return torch.tensor(1.0, device=real_imgs.device)

        # A foreground pixel is correct exactly when the prediction is 1 there too.
        correct_foreground = (fake_bin & real_bin).sum()
        acc = correct_foreground / total_foreground
    return acc


def calculate_discriminator_accuracy_biased(real_preds, fake_preds, bias=0.05):
    """Discriminator accuracy with both decision thresholds relaxed by `bias`.

    A real prediction counts as correct when it is ``>= 0.5 - bias``; a fake
    prediction counts as correct when it is ``< 0.5 + bias``. Predictions
    inside the ``[0.5 - bias, 0.5 + bias)`` band are therefore counted as
    correct for both classes, so this is more lenient than a plain 0.5 cut.

    Returns:
        0-dim tensor on the device of `real_preds`: correct predictions
        divided by the total number of elements in both inputs.
    """
    with torch.no_grad():
        real_hits = (real_preds >= 0.5 - bias).float().sum()
        fake_hits = (fake_preds < 0.5 + bias).float().sum()
        n_preds = real_preds.numel() + fake_preds.numel()
        accuracy = ((real_hits + fake_hits) / n_preds).to(real_preds.device)

    return accuracy
# def calculate_metrics(fake_imgs, real_imgs):
#     """Calculates PSNR and SSIM between fake and real images"""
#     psnr_total = 0.0
#     ssim_total = 0.0
#     batch_size = fake_imgs.shape[0]

#     for i in range(batch_size):
#         # Convert torch tensors to numpy, rescale to [0, 1]
#         fake_np = fake_imgs[i].cpu().detach().numpy().transpose(1, 2, 0)
#         real_np = real_imgs[i].cpu().detach().numpy().transpose(1, 2, 0)

#         fake_np = ((fake_np + 1) / 2).clip(0, 1)  # Tanh to [0, 1]
#         real_np = ((real_np + 1) / 2).clip(0, 1)

#         psnr = peak_signal_noise_ratio(real_np, fake_np, data_range=1)
#         ssim = structural_similarity(real_np, fake_np, data_range=1, win_size=5, channel_axis=-1)


#         psnr_total += psnr
#         ssim_total += ssim

#     return psnr_total / batch_size, ssim_total / batch_size
def calculate_metrics(fake_imgs, real_imgs):
    """Return the batch-averaged ``(PSNR, SSIM)`` between fake and real images.

    Each image is moved to the CPU, converted to a channels-last numpy array,
    rescaled with ``(x + 1) / 2`` and clipped to [0, 1] before scoring, so
    both inputs are expected in the [-1, 1] (Tanh) range. SSIM uses a 5x5
    window with the last axis as the channel axis.
    """
    def to_unit_hwc(img):
        # (C, H, W) tensor -> (H, W, C) array in [0, 1].
        arr = img.cpu().detach().numpy().transpose(1, 2, 0)
        return np.clip((arr + 1) / 2, 0, 1)

    batch_size = fake_imgs.shape[0]
    psnr_sum = 0.0
    ssim_sum = 0.0

    for idx in range(batch_size):
        fake_arr = to_unit_hwc(fake_imgs[idx])
        real_arr = to_unit_hwc(real_imgs[idx])

        psnr_sum += peak_signal_noise_ratio(real_arr, fake_arr, data_range=1)
        ssim_sum += structural_similarity(real_arr, fake_arr, data_range=1, win_size=5, channel_axis=-1)

    # Per-image scores averaged over the batch.
    return psnr_sum / batch_size, ssim_sum / batch_size

def count_parameters(model):
    """Return the number of scalar parameters in `model` that require grad."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


def train_daass(model, discriminator, dataloader, optimizer_G, optimizer_D, criterion, device,  scheduler_G=None, scheduler_D=None,num_epochs=100, perceptual_weight=0.1):
    """Adversarially train `model` (the generator) against `discriminator`.

    For every batch the generator is run once, then:
      1. the discriminator is updated on the real labels and on the detached
         generator output (loss = mean of the two criterion values);
      2. the generator is updated on ``criterion(D(fake), ones)`` plus
         `perceptual_weight` times the VGG perceptual loss between the fake and
         real sketches, both rescaled with ``(x + 1) / 2``.
    PSNR, SSIM, the biased discriminator accuracy and the foreground accuracy
    are accumulated per batch and printed as an epoch summary. Nothing is
    returned.

    Args:
        model: generator, called as ``model(images, texts)``.
        discriminator: called on sketches; its output is fed to `criterion`.
        dataloader: yields ``(images, labels, texts)``; images and labels are
            moved to `device`, texts are passed to the model unchanged.
        optimizer_G, optimizer_D: optimizers for generator / discriminator.
        criterion: adversarial loss, called as ``criterion(pred, target)``
            with targets of all ones or all zeros.
        device: device both networks and the perceptual loss are moved to.
        scheduler_G, scheduler_D: optional LR schedulers, stepped once at the
            end of every epoch.
        num_epochs: number of passes over `dataloader`.
        perceptual_weight: weight of the perceptual term in the generator loss.

    Raises:
        ValueError: if `dataloader` yields no batches in an epoch.
    """
    model.to(device)
    discriminator.to(device)
    perceptual_loss = VGGPerceptualLoss().to(device)

    # Report the sizes of the networks actually passed in, not of whatever
    # happens to be bound to notebook-level globals.
    print(f"Generator parameters: {count_parameters(model)}")
    print(f"Discriminator parameters: {count_parameters(discriminator)}")

    for epoch in range(num_epochs):
        model.train()
        discriminator.train()

        total_d_loss = 0.0
        total_g_loss = 0.0
        total_psnr = 0.0
        total_ssim = 0.0
        total_acc = 0.0
        total_acc2 = 0.0
        total_batches = 0

        print(f"\n--- Epoch [{epoch+1}/{num_epochs}] ---")

        for batch_idx, (images, labels, texts) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass through generator
            fake_sketches = model(images, texts)

            # --- Discriminator step ---
            # The fake batch is detached so this step does not touch the generator.
            real_preds = discriminator(labels)
            fake_preds = discriminator(fake_sketches.detach())

            real_labels = torch.ones_like(real_preds)
            fake_labels = torch.zeros_like(fake_preds)

            d_loss_real = criterion(real_preds, real_labels)
            d_loss_fake = criterion(fake_preds, fake_labels)
            d_loss = (d_loss_real + d_loss_fake) / 2

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # --- Generator step ---
            # Re-score the fake batch with the just-updated discriminator,
            # this time keeping the graph back to the generator.
            fake_preds_for_g = discriminator(fake_sketches)

            # The perceptual loss works on [0, 1] images with 3 channels.
            fake_norm = (fake_sketches + 1) / 2
            real_norm = (labels + 1) / 2
            if fake_norm.shape[1] == 1:
                fake_norm = fake_norm.repeat(1, 3, 1, 1)
                real_norm = real_norm.repeat(1, 3, 1, 1)

            g_adv_loss = criterion(fake_preds_for_g, real_labels)
            p_loss = perceptual_loss(fake_norm, real_norm)
            g_loss = g_adv_loss + perceptual_weight * p_loss

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            # --- Metrics (computed on the pre-update predictions) ---
            psnr_score, ssim_score = calculate_metrics(fake_sketches, labels)
            acc = calculate_discriminator_accuracy_biased(real_preds, fake_preds, bias=0.05)
            acc2 = calculate_foreground_accuracy(fake_sketches, labels, threshold=0.5)

            # Accumulate
            total_d_loss += d_loss.item()
            total_g_loss += g_loss.item()
            total_psnr += psnr_score
            total_ssim += ssim_score
            total_acc += acc.item()
            total_acc2 += acc2.item()
            total_batches += 1

        if total_batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("train_daass: dataloader yielded no batches")

        # Print epoch summary
        avg_d_loss = total_d_loss / total_batches
        avg_g_loss = total_g_loss / total_batches
        avg_psnr = total_psnr / total_batches
        avg_ssim = total_ssim / total_batches
        avg_acc = total_acc / total_batches
        avg_acc2 = total_acc2 / total_batches

        print(f"\nEpoch {epoch+1} Summary → D Loss: {avg_d_loss:.4f} | G Loss: {avg_g_loss:.4f} | "
              f"PSNR: {avg_psnr:.2f} | SSIM: {avg_ssim:.3f} | Acc: {avg_acc:.3f} | Acc2: {avg_acc2:.3f}")

        # Step the learning rate schedulers once per epoch.
        if scheduler_G is not None:
            scheduler_G.step()
        if scheduler_D is not None:
            scheduler_D.step()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the generator (DAASS) and the discriminator on that device.
generator_model = DAASS().to(device)
discriminator_model = Discriminator().to(device)

# Adversarial criterion. nn.BCELoss takes probabilities, not logits, so this
# presumably relies on Discriminator ending in a sigmoid -- TODO confirm.
criterion = nn.BCELoss()



# Optimizers: AdamW for both networks (lr 2e-4, betas (0.5, 0.999), weight decay 1e-4).
# An earlier variant used plain Adam with weight_decay=1e-5.
optimizer_G = optim.AdamW(generator_model.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)
optimizer_D = optim.AdamW(discriminator_model.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-4)
# Learning-rate schedulers: cosine annealing, stepped once per epoch by train_daass.
# NOTE(review): T_max=50 but the run below uses num_epochs=20, so only the first
# 20 steps of the 50-step cosine curve are ever taken -- confirm this is intended.
scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_G, T_max=50)
scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_D, T_max=50)

# `train_loader` must already be defined by an earlier cell.
# It should return: images (tensor), labels (target sketch images), texts (list of strings)

# Run the training loop.
train_daass(
    model=generator_model,
    discriminator=discriminator_model,
    dataloader=train_loader,
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    criterion=criterion,
    device=device,
    scheduler_G=scheduler_G,  # stepped after every epoch
    scheduler_D=scheduler_D,  # stepped after every epoch
    num_epochs=20
)

In [None]:
import torch
import matplotlib.pyplot as plt

# Set device (repeats the definition from the training cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def calculate_accuracy(fake, real, threshold=0.3):
    """Pixel-wise agreement between two images after binarising both.

    Both tensors are binarised with the same rule, ``>= threshold``, so a
    pixel sitting exactly on the threshold lands in the same class in both
    images and identical inputs always score 1.0. The inputs are not rescaled
    here; callers must pass them in the same value range.

    Returns:
        Python float: fraction of positions where the two binary maps agree.
    """
    fake_bin = fake >= threshold
    real_bin = real >= threshold
    return (fake_bin == real_bin).float().mean().item()

def test_daass(model, discriminator, dataloader, criterion, device):
    """Evaluate the generator on `dataloader` and print PSNR, SSIM and accuracy.

    Both networks are switched to eval mode and everything runs under
    ``torch.no_grad()``. The first batch is also plotted as rows of
    (input image, real sketch, generated sketch). `discriminator` is only put
    in eval mode and `criterion` is not used; both are kept in the signature
    for callers.

    Value ranges: the generator output and, as everywhere else in this
    notebook's metrics, the labels are treated as [-1, 1] images.
    ``calculate_metrics`` performs that rescaling itself, so it receives the
    raw tensors; ``calculate_accuracy`` and the plots receive [0, 1] copies of
    both the generated and the real sketches.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    model.eval()  # Set the model to evaluation mode
    discriminator.eval()

    total_psnr = 0.0
    total_ssim = 0.0
    total_accuracy = 0.0
    total_batches = 0

    with torch.no_grad():
        print("\n--- Testing ---")

        for batch_idx, (images, labels, texts) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass through the generator (Tanh output, [-1, 1]).
            raw_sketches = model(images, texts)

            # calculate_metrics applies (x + 1) / 2 to both arguments, so it
            # must see the raw output; rescaling first would apply it twice.
            psnr_score, ssim_score = calculate_metrics(raw_sketches, labels)

            # [0, 1] versions for thresholding and display.
            fake_sketches = ((raw_sketches + 1) / 2).clamp(0, 1)
            real_sketches = ((labels + 1) / 2).clamp(0, 1)
            accuracy_score = calculate_accuracy(fake_sketches, real_sketches)

            total_psnr += psnr_score
            total_ssim += ssim_score
            total_accuracy += accuracy_score
            total_batches += 1

            # Visualize all images in the first batch only
            if batch_idx == 0:
                num_images = images.size(0)
                # squeeze=False keeps `axes` 2-D even for a batch of one.
                fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)

                for idx in range(num_images):
                    # NOTE(review): the input image is only clipped to [0, 1];
                    # if the loader normalises inputs differently, undo that here.
                    panels = (
                        ("Input Image", images[idx]),
                        ("Real Sketch", real_sketches[idx]),
                        ("Generated Sketch", fake_sketches[idx]),
                    )
                    for ax, (title, tensor) in zip(axes[idx], panels):
                        ax.imshow(tensor.cpu().permute(1, 2, 0).numpy().clip(0, 1))
                        ax.set_title(title)
                        ax.axis("off")

                plt.tight_layout()
                plt.show()

        if total_batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("test_daass: dataloader yielded no batches")

        # Print average metrics
        avg_psnr = total_psnr / total_batches
        avg_ssim = total_ssim / total_batches
        avg_accuracy = total_accuracy / total_batches

        print(f"\nTest Summary → PSNR: {avg_psnr:.2f} | SSIM: {avg_ssim:.3f} | Accuracy: {avg_accuracy:.3f}")

In [None]:

# Requires `test_loader`, `generator_model`, `discriminator_model`, `criterion`
# and `device` to be defined by earlier cells.

# Optionally restore saved weights before evaluating:
# generator_model.load_state_dict(torch.load('generator_epoch_final.pth'))
# discriminator_model.load_state_dict(torch.load('discriminator_epoch_final.pth'))

# Evaluate: prints PSNR / SSIM / accuracy and plots the first batch.
test_daass(
    model=generator_model,
    discriminator=discriminator_model,
    dataloader=test_loader,
    criterion=criterion,
    device=device
)

In [None]:

# NOTE(review): this cell repeats the previous evaluation cell verbatim;
# consider deleting one of the two.
# Requires `test_loader`, `generator_model`, `discriminator_model`, `criterion`
# and `device` to be defined by earlier cells.

# Optionally restore saved weights before evaluating:
# generator_model.load_state_dict(torch.load('generator_epoch_final.pth'))
# discriminator_model.load_state_dict(torch.load('discriminator_epoch_final.pth'))

# Evaluate: prints PSNR / SSIM / accuracy and plots the first batch.
test_daass(
    model=generator_model,
    discriminator=discriminator_model,
    dataloader=test_loader,
    criterion=criterion,
    device=device
)

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
from skimage.feature import local_binary_pattern
from skimage.filters import sobel

# Set device (repeats the definition from the training cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def calculate_vif(fake, real, sigma_n_sq=2.0):
    """Simplified Visual Information Fidelity-style score, averaged over a batch.

    This is not the reference VIF algorithm. Per image, both inputs are
    converted to grayscale, blurred with a Gaussian (sigma=2), and scored as
    ``log2(1 + (var_fake + cov(fake, real)) / (var_real + sigma_n_sq))``.
    An image whose blurred reference has (near-)zero variance scores 0.0.
    The inputs are not rescaled here; pass both in the same value range.

    Args:
        fake: (N, C, H, W) tensor of generated images.
        real: (N, C, H, W) tensor of reference images.
        sigma_n_sq: noise-variance term added to the denominator.

    Returns:
        Mean score over the batch.
    """
    def to_gray(batch):
        # detach() so tensors that require grad can be converted as well.
        arr = batch.detach().cpu().permute(0, 2, 3, 1).numpy()  # (N, H, W, C)
        channels = arr.shape[-1]
        if channels == 3:  # RGB -> luminance
            return 0.2989 * arr[..., 0] + 0.5870 * arr[..., 1] + 0.1140 * arr[..., 2]
        if channels == 1:  # already grayscale: drop the channel axis
            return arr[..., 0]
        return arr

    fake_gray = to_gray(fake)
    real_gray = to_gray(real)

    vif_scores = []
    for i in range(fake_gray.shape[0]):
        # Gaussian blur as a crude stand-in for the visual system's low-pass behaviour.
        fake_blur = ndimage.gaussian_filter(fake_gray[i], sigma=2)
        real_blur = ndimage.gaussian_filter(real_gray[i], sigma=2)

        var_fake = np.var(fake_blur)
        var_real = np.var(real_blur)
        cov_fake_real = np.cov(fake_blur.flatten(), real_blur.flatten())[0, 1]

        # A flat reference carries no information to preserve.
        if var_real < 1e-10:
            vif = 0.0
        else:
            vif = np.log2(1 + (var_fake + cov_fake_real) / (var_real + sigma_n_sq))
        vif_scores.append(vif)

    return np.mean(vif_scores)


def calculate_fsim(fake, real):
    """Simplified feature-similarity score from gradients and LBP texture.

    This is not the reference FSIM algorithm. Per image, both inputs are
    converted to grayscale; a Sobel gradient-magnitude map and a uniform
    Local Binary Pattern map (P=8, R=1) are computed for each; each pair of
    maps is compared with ``(2ab + eps) / (a^2 + b^2 + eps)`` averaged over
    pixels; the final score is ``0.8 * gradient + 0.2 * texture``.
    The inputs are not rescaled here; pass both in the same value range.

    NOTE(review): scikit-image warns that ``local_binary_pattern`` on
    floating-point images may give unexpected results; consider quantising to
    an integer dtype once the input range is pinned down.

    Args:
        fake: (N, C, H, W) tensor of generated images.
        real: (N, C, H, W) tensor of reference images.

    Returns:
        Mean score over the batch.
    """
    def to_gray(batch):
        # detach() so tensors that require grad can be converted as well.
        arr = batch.detach().cpu().permute(0, 2, 3, 1).numpy()  # (N, H, W, C)
        channels = arr.shape[-1]
        if channels == 3:  # RGB -> luminance
            return 0.2989 * arr[..., 0] + 0.5870 * arr[..., 1] + 0.1140 * arr[..., 2]
        if channels == 1:  # drop the channel axis so the filters get 2-D images
            return arr[..., 0]
        return arr

    def ratio_similarity(a, b):
        # 1.0 where a == b, smaller as they diverge; eps guards 0/0.
        return np.mean((2 * a * b + 1e-6) / (a**2 + b**2 + 1e-6))

    fake_gray = to_gray(fake)
    real_gray = to_gray(real)

    fsim_scores = []
    for i in range(fake_gray.shape[0]):
        fake_img = fake_gray[i]
        real_img = real_gray[i]

        # Gradient magnitude (Sobel) similarity.
        grad_sim = ratio_similarity(sobel(fake_img), sobel(real_img))

        # Texture (uniform LBP codes) similarity.
        lbp_fake = local_binary_pattern(fake_img, P=8, R=1, method='uniform')
        lbp_real = local_binary_pattern(real_img, P=8, R=1, method='uniform')
        texture_sim = ratio_similarity(lbp_fake, lbp_real)

        # Weighted combination.
        fsim_scores.append(0.8 * grad_sim + 0.2 * texture_sim)

    return np.mean(fsim_scores)


def calculate_gsm(fake, real):
    """Gradient-similarity score based on Sobel gradient magnitude.

    Per image, both inputs are converted to grayscale, their Sobel
    gradient-magnitude maps are compared with
    ``(2ab + eps) / (a^2 + b^2 + eps)``, and the result is averaged over
    pixels. The inputs are not rescaled here; pass both in the same value range.

    Args:
        fake: (N, C, H, W) tensor of generated images.
        real: (N, C, H, W) tensor of reference images.

    Returns:
        Mean score over the batch.
    """
    def to_gray(batch):
        # detach() so tensors that require grad can be converted as well.
        arr = batch.detach().cpu().permute(0, 2, 3, 1).numpy()  # (N, H, W, C)
        channels = arr.shape[-1]
        if channels == 3:  # RGB -> luminance
            return 0.2989 * arr[..., 0] + 0.5870 * arr[..., 1] + 0.1140 * arr[..., 2]
        if channels == 1:  # drop the channel axis so Sobel sees a 2-D image
            return arr[..., 0]
        return arr

    fake_gray = to_gray(fake)
    real_gray = to_gray(real)

    gsm_scores = []
    for i in range(fake_gray.shape[0]):
        fake_grad = sobel(fake_gray[i])
        real_grad = sobel(real_gray[i])

        # 1.0 where the gradients match; eps guards 0/0 in flat regions.
        grad_sim = (2 * fake_grad * real_grad + 1e-6) / (fake_grad**2 + real_grad**2 + 1e-6)
        gsm_scores.append(np.mean(grad_sim))

    return np.mean(gsm_scores)


def test_daass2(model, discriminator, dataloader, criterion, device):
    """Evaluate the generator on `dataloader` and print VIF, FSIM and GSM.

    Both networks are switched to eval mode and everything runs under
    ``torch.no_grad()``. The first batch is also plotted as rows of
    (input image, real sketch, generated sketch). `discriminator` is only put
    in eval mode and `criterion` is not used; both are kept in the signature
    for callers.

    Value ranges: the generator output and, as in this notebook's training
    metrics, the labels are treated as [-1, 1] images. The three metric
    functions do no rescaling of their own, so both the generated and the
    real sketches are mapped to [0, 1] before being compared or plotted.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    model.eval()  # Set the model to evaluation mode
    discriminator.eval()

    total_vif = 0.0
    total_fsim = 0.0
    total_gsm = 0.0
    total_batches = 0

    with torch.no_grad():
        print("\n--- Testing ---")

        for batch_idx, (images, labels, texts) in enumerate(dataloader):
            images, labels = images.to(device), labels.to(device)

            # Generator output and labels, both mapped from [-1, 1] to [0, 1]
            # so the metrics compare like with like.
            fake_sketches = ((model(images, texts) + 1) / 2).clamp(0, 1)
            real_sketches = ((labels + 1) / 2).clamp(0, 1)

            # Metrics
            vif_score = calculate_vif(fake_sketches, real_sketches)
            fsim_score = calculate_fsim(fake_sketches, real_sketches)
            gsm_score = calculate_gsm(fake_sketches, real_sketches)

            total_vif += vif_score
            total_fsim += fsim_score
            total_gsm += gsm_score
            total_batches += 1

            # Visualize all images in the first batch only
            if batch_idx == 0:
                num_images = images.size(0)
                # squeeze=False keeps `axes` 2-D even for a batch of one.
                fig, axes = plt.subplots(num_images, 3, figsize=(15, 5 * num_images), squeeze=False)

                for idx in range(num_images):
                    # NOTE(review): the input image is only clipped to [0, 1];
                    # if the loader normalises inputs differently, undo that here.
                    panels = (
                        ("Input Image", images[idx]),
                        ("Real Sketch", real_sketches[idx]),
                        ("Generated Sketch", fake_sketches[idx]),
                    )
                    for ax, (title, tensor) in zip(axes[idx], panels):
                        ax.imshow(tensor.cpu().permute(1, 2, 0).numpy().clip(0, 1))
                        ax.set_title(title)
                        ax.axis("off")

                plt.tight_layout()
                plt.show()

        if total_batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("test_daass2: dataloader yielded no batches")

        # Print average metrics
        avg_vif = total_vif / total_batches
        avg_fsim = total_fsim / total_batches
        avg_gsm = total_gsm / total_batches

        print(f"\nTest Summary → VIF: {avg_vif:.3f} | FSIM: {avg_fsim:.3f} | GSM: {avg_gsm:.3f}")

In [None]:

# Requires `test_loader`, `generator_model`, `discriminator_model`, `criterion`
# and `device` to be defined by earlier cells.

# Optionally restore saved weights before evaluating:
# generator_model.load_state_dict(torch.load('generator_epoch_final.pth'))
# discriminator_model.load_state_dict(torch.load('discriminator_epoch_final.pth'))

# Evaluate: prints VIF / FSIM / GSM and plots the first batch.
test_daass2(
    model=generator_model,
    discriminator=discriminator_model,
    dataloader=test_loader,
    criterion=criterion,
    device=device
)