In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models.resnet import ResNet, Bottleneck
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, Resize, ToTensor, RandomRotation, RandomHorizontalFlip, ColorJitter, RandomGrayscale, GaussianBlur

In [None]:
class WFLWDataset(Dataset):
    def __init__(self, annotation_file, base_image_dir, img_size=(256, 256)):
        """
        Dataset for WFLW with resizing and augmentations.
        Args:
            annotation_file: Path to the file containing keypoints annotations.
            base_image_dir: Directory containing images.
            img_size: Target size for resizing images.
        """
        self.annotations = pd.read_csv(annotation_file, sep='\s+', header=None)
        self.base_image_dir = base_image_dir
        self.img_size = img_size

        # Define augmentations
        self.transforms = Compose([
            Resize(img_size),  # Resize image
            RandomRotation(10),  # Random rotation
            RandomHorizontalFlip(),  # Random horizontal flip
            ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jitter
            RandomGrayscale(p=0.1),  # Convert to grayscale with 10% probability
            GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0)),  # Random blur
            ToTensor(),  # Convert to tensor
        ])

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        row = self.annotations.iloc[idx]

        # Construct image name as "wflw_train_with_box_{idx}.jpg"
        img_name = f"wflw_train_with_box_{idx + 1}.jpg"  # Assuming idx starts from 0, and image names start from 1
        img_path = os.path.join(self.base_image_dir, img_name)

        # Load the image
        image = Image.open(img_path).convert("RGB")

        # Apply transforms (resize and augmentations)
        image = self.transforms(image)

        # Extract keypoints (normalized x and y coordinates)
        keypoints_data = row.iloc[1:].values.astype('float').reshape(-1, 2)

        # Scale keypoints according to the image size
        keypoints_data[:, 0] = keypoints_data[:, 0] * self.img_size[0]  # Scale x
        keypoints_data[:, 1] = keypoints_data[:, 1] * self.img_size[1]  # Scale y

        return image, torch.tensor(keypoints_data, dtype=torch.float32)

In [4]:
batch_size = 16
train_dataset = WFLWDataset(
    annotation_file='WFLW/train.txt',
    base_image_dir='WFLW/train'
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [5]:
class PixelShuffle(nn.Module):
    def __init__(self):
        super(PixelShuffle, self).__init__()

    def forward(self, x):
        B, C, H, W = x.size()
        x = x.reshape(B, C // 4, 2, 2, H, W).permute(0, 1, 4, 2, 5, 3)
        return x.reshape(B, C // 4, H * 2, W * 2)

In [6]:
class ViT(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim):
        super(ViT, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim)
        )

    def forward(self, x):
        x = x + self.attention(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.mlp(self.norm2(x))
        return x

In [7]:
class ChannelSplitViT(nn.Module):
    def __init__(self, input_channels=256, embed_dim=512, num_heads=8):
        super(ChannelSplitViT, self).__init__()
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=2, stride=2)  # Downsampling to (512, 16, 16)
        self.vit = ViT(256, num_heads, embed_dim * 4)  # Standard ViT with input (512, 256)
        self.pixel_shuffle = PixelShuffle()  # Upsampling to (128, 32, 32)

    def forward(self, x):
        x = self.conv(x)  # (256, 32, 32) → (512, 16, 16)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 1, 2)  # (512, 16, 16) → (512, 256)
        x = self.vit(x)  # (512, 256) → (512, 256)
        x = x.permute(0, 1, 2).view(B, C, H, W)  # (512, 256) → (512, 16, 16)
        x = self.pixel_shuffle(x)  # (512, 16, 16) → (128, 32, 32)
        return x

In [8]:
class SpatialSplitViT(nn.Module):
    def __init__(self, input_channels=256, embed_dim=512, num_heads=8):
        super(SpatialSplitViT, self).__init__()
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=2, stride=2)  # Downsampling to (512, 16, 16)
        self.vit = ViT(embed_dim, num_heads, embed_dim * 4)  # Standard ViT with input (256, 512)
        self.pixel_shuffle = PixelShuffle()  # Upsampling to (128, 32, 32)

    def forward(self, x):
        x = self.conv(x)  # (256, 32, 32) → (512, 16, 16)
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1)  # (512, 16, 16) → (256, 512)
        x = self.vit(x)  # (256, 512) → (256, 512)
        x = x.permute(0, 2, 1).view(B, C, H, W)  # (256, 512) → (512, 16, 16)
        x = self.pixel_shuffle(x)  # (512, 16, 16) → (128, 32, 32)
        return x

In [9]:
class DualViT(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(DualViT, self).__init__()
        self.spatial_vit = SpatialSplitViT()  # Handles spatial splitting
        self.channel_vit = ChannelSplitViT()  # Handles channel splitting
        # self.merge_conv = nn.Conv2d(embed_dim * 2, embed_dim, kernel_size=1)  # To merge the outputs

    def forward(self, x):
        # First DualViT
        spatial_out = self.spatial_vit(x)  # Output: (B, embed_dim, 32, 32)
        # Second DualViT
        channel_out = self.channel_vit(x)  # Output: (B, embed_dim, 32, 32)
        output = torch.cat([spatial_out, channel_out], dim=1)  # Concatenate along channel dimension
        # Further layers
        # output = self.merge_conv(output)  # Merge operation
        # print(f"After merge_conv: {output.shape}")

        return output


In [10]:
class PredictionBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, num_landmarks):
        super(PredictionBlock, self).__init__()
        self.d_vit1 = DualViT(embed_dim, num_heads)  # First DualViT
        self.d_vit2 = DualViT(embed_dim, num_heads)  # Second DualViT
        self.final_conv = nn.Conv2d(embed_dim, num_landmarks, kernel_size=1)  # Heatmap prediction

    def forward(self, x):
        # Step 1: First DualViT
        x = self.d_vit1(x)  # Output: (B, embed_dim, 32, 32)

        # Step 2: Second DualViT
        x = self.d_vit2(x)  # Output: (B, embed_dim, 32, 32)

        # Step 3: Predict heatmaps
        heatmaps = self.final_conv(x)  # Output: (B, num_landmarks, 32, 32)

        return x, heatmaps

In [None]:
class ModifiedResNet50(nn.Module):
    def __init__(self):
        super(ModifiedResNet50, self).__init__()
        # Initialize the ResNet backbone up to layer4
        self.resnet = ResNet(block=Bottleneck, layers=[3, 4, 6, 3])

        # Remove the fully connected layer and adaptive pooling
        self.resnet.avgpool = nn.Identity()  # Disable AdaptiveAvgPool2d
        self.resnet.fc = nn.Identity()  # Disable the final fully connected layer

    def forward(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        
        return x

In [13]:
def soft_argmax(heatmaps):
    """
    Compute soft-argmax for heatmaps.
    Input:
        heatmaps: (batch_size, num_landmarks, height, width)
    Output:
        coordinates: (batch_size, num_landmarks, 2) - [x, y] coordinates
    """
    b, n, h, w = heatmaps.shape
    heatmaps = F.softmax(heatmaps.view(b, n, -1), dim=-1)  # Flatten and apply softmax
    heatmaps = heatmaps.view(b, n, h, w)  # Reshape back to (b, n, h, w)
    
    # Create grid for coordinates
    x = torch.linspace(0, w - 1, w).to(heatmaps.device)
    y = torch.linspace(0, h - 1, h).to(heatmaps.device)
    x_grid, y_grid = torch.meshgrid(x, y, indexing="xy")

    # Compute weighted sums for x and y
    x_coords = torch.sum(heatmaps * x_grid[None, None, :, :], dim=(2, 3))
    y_coords = torch.sum(heatmaps * y_grid[None, None, :, :], dim=(2, 3))

    return torch.stack([x_coords, y_coords], dim=-1)  # (b, n, 2)

In [None]:
class CascadedDViT(nn.Module):
    def __init__(self, num_blocks, embed_dim, num_heads, num_landmarks):
        super(CascadedDViT, self).__init__()
        self.backbone = ModifiedResNet50()
        
        # Upsample ResNet output to match the spatial dimensions of DualViT input
        self.upsample_resnet = nn.ConvTranspose2d(2048, embed_dim, kernel_size=4, stride=4, padding=0)

        # Create Prediction Blocks
        self.prediction_blocks = nn.ModuleList(
            [PredictionBlock(embed_dim, num_heads, num_landmarks) for _ in range(num_blocks)]
        )

        # Merge concatenated outputs for subsequent blocks
        self.merge_conv = nn.Conv2d(embed_dim * 2, embed_dim, kernel_size=1)

    def forward(self, x):
        # Step 1: Extract features using ResNet
        resnet_output = self.backbone(x)  # ResNet output: (B, 2048, 8, 8)
        resnet_output = self.upsample_resnet(resnet_output)  # Upsampled output: (B, embed_dim, 32, 32)
    
        predictions = []
        prev_features = None

        # Step 2: Process through PredictionBlocks
        for i, block in enumerate(self.prediction_blocks):
            if i == 0:
                # First block gets only ResNet output
                combined_features = resnet_output
            else:
                # Subsequent blocks get concatenated features
                combined_features = torch.cat([resnet_output, prev_features], dim=1)  # (B, embed_dim * 2, 32, 32)
                combined_features = self.merge_conv(combined_features)  # (B, embed_dim, 32, 32)

            # Process the combined features through the PredictionBlock
            prev_features, heatmaps = block(combined_features)
            coords_pred = soft_argmax(heatmaps)
            # Save the heatmaps
            predictions.append((heatmaps, coords_pred))

        return predictions

In [None]:
class AWingLoss(nn.Module):
    def __init__(self, omega=14, epsilon=1, theta=0.5, alpha=2.1):
        """
        Adaptive Wing Loss
        Args:
            omega: Scaling factor (default: 14)
            epsilon: Smoothing parameter (default: 1)
            theta: Threshold for switching between linear and non-linear (default: 0.5)
            alpha: Exponential decay factor for the nonlinear component (default: 2.1)
        """
        super(AWingLoss, self).__init__()
        self.omega = omega
        self.epsilon = epsilon
        self.theta = theta
        self.alpha = alpha

    def forward(self, pred, target):
        """
        Compute AWing loss between predicted heatmaps and ground truth.
        Args:
            pred: Predicted heatmaps (b, n, h, w)
            target: Ground truth heatmaps (b, n, h, w)
        Returns:
            loss: Adaptive Wing Loss
        """
        diff = target - pred  # Difference between ground truth and predictions
        abs_diff = diff.abs()

        # Non-linear component for abs_diff < theta
        A = self.omega * (1 / (1 + (self.theta / self.epsilon) ** (self.alpha - target))) * (self.alpha - target) * ((self.theta / self.alpha) ** (self.alpha - target - 1)) * (1 / self.epsilon)
        nonlinear_loss = self.omega * torch.log(1 + ((abs_diff / self.epsilon) ** (self.alpha - target)))

        # Linear component for abs_diff >= theta
        C = self.theta * A - self.omega * torch.log(1 + ((self.theta / self.epsilon) ** (self.alpha - target)))
        linear_loss = A * abs_diff - C

        # Combine losses
        loss = torch.where(abs_diff < self.theta, nonlinear_loss, linear_loss)
        return loss.mean()

In [16]:
# Smooth L1 loss for coordinate regression
smooth_l1_loss = nn.SmoothL1Loss()

# Training loss class
class TrainingLoss(nn.Module):
    def __init__(self, num_blocks, w, beta):
        super(TrainingLoss, self).__init__()
        self.num_blocks = num_blocks
        self.w = w
        self.beta = beta
        self.heatmap_loss = AWingLoss()

    def forward(self, predictions, heatmaps_gt, coords_gt):
        """
        Input:
            predictions: List of predictions from all blocks [(coords, heatmaps), ...]
            heatmaps_gt: Ground truth heatmaps (b, n, h, w)
            coords_gt: Ground truth coordinates (b, n, 2)
        Output:
            total_loss: Combined loss
        """
        total_loss = 0
        for j, (heatmaps_pred, coords_pred) in enumerate(predictions):
            # Compute intermediate losses
            coord_loss = smooth_l1_loss(coords_pred, coords_gt)
            heatmap_loss = self.heatmap_loss(heatmaps_pred, heatmaps_gt)

            # Combine losses with beta
            intermediate_loss = coord_loss + self.beta * heatmap_loss

            # Weight by w^(j-B)
            weight = self.w ** (j - self.num_blocks)
            total_loss += weight * intermediate_loss

        return total_loss

In [17]:
def generate_heatmaps(keypoints, img_size, heatmap_size, sigma=2):
    """
    Generate heatmaps for keypoints using the coordinate encoding method.
    
    Args:
        keypoints (torch.Tensor): Keypoints of shape (N, num_keypoints, 2) - [(u, v)].
        img_size (tuple): Original image size (height, width).
        heatmap_size (tuple): Heatmap size (height, width).
        sigma (float): Standard deviation for the Gaussian kernel.

    Returns:
        torch.Tensor: Heatmaps of shape (N, num_keypoints, heatmap_height, heatmap_width).
    """
    N, num_keypoints, _ = keypoints.shape
    device = keypoints.device
    heatmaps = torch.zeros((N, num_keypoints, heatmap_size[0], heatmap_size[1]), device=device, dtype=torch.float32)

    # Downsampling ratio
    lambda_x = img_size[1] / heatmap_size[1]  # width ratio
    lambda_y = img_size[0] / heatmap_size[0]  # height ratio

    for i in range(N):  # Iterate over batch
        for j in range(num_keypoints):  # Iterate over each keypoint
            u, v = keypoints[i, j]  # Original coordinates (u, v)

            # Downsample coordinates
            u_prime = u / lambda_x
            v_prime = v / lambda_y
            
            # Quantize (can use floor, ceil, or round)
            u_quant = torch.round(u_prime)
            v_quant = torch.round(v_prime)

            # Create Gaussian kernel centered at quantized location
            y = torch.arange(0, heatmap_size[0], device=device, dtype=torch.float32)
            x = torch.arange(0, heatmap_size[1], device=device, dtype=torch.float32)
            y_grid, x_grid = torch.meshgrid(y, x, indexing="ij")

            heatmaps[i, j] = torch.exp(
                -((x_grid - u_quant)**2 + (y_grid - v_quant)**2) / (2 * sigma**2)
            )
            heatmaps[i, j] /= (2 * np.pi * sigma**2)  # Normalize the Gaussian

    return heatmaps

In [None]:
# Define the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:1")
# Initialize the model and move it to the correct device
model = CascadedDViT(num_blocks=8, embed_dim=256, num_heads=8, num_landmarks=98).to(device)

# Initialize optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = TrainingLoss(num_blocks=8, w=1.2, beta=0.5)

# Original image size
img_size = (256, 256)  # Assuming your input images are cropped and resized to 256x256.

# Heatmap size
heatmap_size = (32, 32)  # Matches the predicted heatmap size.

num_epochs = 10
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for images, coords_gt in train_loader:  # Only images and keypoints are provided
        # Ensure images and ground truth are moved to the correct device
        images = images.to(device)  # Shape: (batch_size, 3, img_H, img_W)
        coords_gt = coords_gt.to(device)  # Shape: (batch_size, num_landmarks, 2)
        
        # Generate ground truth heatmaps
        heatmaps_gt = generate_heatmaps(coords_gt, img_size, heatmap_size)  # (batch_size, num_landmarks, heatmap_H, heatmap_W)
        heatmaps_gt = heatmaps_gt.to(device)  # Move heatmaps to the correct device

        # Forward pass
        predictions = model(images)  # Output: [(coords_pred, heatmaps_pred), ...]
        
        # Compute loss
        loss = criterion(predictions, heatmaps_gt, coords_gt)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print average loss for the epoch
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

# Save the trained model
torch.save(model.state_dict(), "dvit_keypoint_model.pth")
print("Training complete and model saved.")

Epoch 1/10, Loss: 407.9499
Epoch 2/10, Loss: 405.7901
Epoch 3/10, Loss: 405.7766
Epoch 4/10, Loss: 405.7796
Epoch 5/10, Loss: 405.7797
Epoch 6/10, Loss: 405.7864
Epoch 7/10, Loss: 405.7935
Epoch 8/10, Loss: 405.7826
Epoch 9/10, Loss: 405.7921
Epoch 10/10, Loss: 405.7826
Training complete and model saved.
