In [None]:
!git clone https://github.com/harichselvamc/caritry.git

In [None]:
# Install PyTorch (choose the appropriate command from https://pytorch.org/get-started/locally/)
!pip install torch torchvision

# Install other dependencies
!pip install numpy matplotlib pillow


In [1]:
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt


In [2]:
# Detect the number of available GPUs
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")

# Set the primary device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Number of GPUs available: 2
Using device: cuda


In [3]:
!nvidia-smi


Sat Oct 19 17:07:22 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   51C    P8              9W /   70W |       3MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4                       Off |   00

In [5]:
class CaricatureDataset(Dataset):
    def __init__(self, real_dir, caricature_dir, transform=None, split='train', split_ratio=0.8):
        """
        Args:
            real_dir (str): Directory containing real images.
            caricature_dir (str): Directory containing caricature images.
            transform (callable, optional): Transformations to apply to images.
            split (str): 'train' or 'val' to indicate dataset split.
            split_ratio (float): Fraction of data to use for training.
        """
        self.real_dir = real_dir
        self.caricature_dir = caricature_dir
        self.transform = transform
        self.split = split

        # Define allowed image extensions
        self.allowed_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff'}

        self.image_pairs = []

        print(f"Scanning Real images in {self.real_dir}...")
        for identity in os.listdir(self.real_dir):
            real_identity_dir = os.path.join(self.real_dir, identity)
            caricature_identity_dir = os.path.join(self.caricature_dir, identity)
            
            if not os.path.isdir(real_identity_dir):
                continue  # Skip if not a directory
            
            if not os.path.isdir(caricature_identity_dir):
                print(f"Warning: Caricature directory for '{identity}' does not exist. Skipping...")
                continue  # Skip if corresponding caricature directory does not exist
            
            for img_name in os.listdir(real_identity_dir):
                real_img_path = os.path.join(real_identity_dir, img_name)
                
                # Check if the file has an allowed image extension
                _, ext = os.path.splitext(img_name)
                if ext.lower() not in self.allowed_extensions:
                    print(f"Skipping non-image file: {real_img_path}")
                    continue  # Skip non-image files
                
                # Replace '_r_' with '_c_' to find the corresponding caricature image
                caricature_img_name = img_name.replace('_r_', '_c_')
                caricature_img_path = os.path.join(caricature_identity_dir, caricature_img_name)
                
                # Check if the caricature image exists and has an allowed extension
                if os.path.isfile(caricature_img_path):
                    _, caricature_ext = os.path.splitext(caricature_img_name)
                    if caricature_ext.lower() in self.allowed_extensions:
                        self.image_pairs.append((real_img_path, caricature_img_path))
                    else:
                        print(f"Skipping non-image caricature file: {caricature_img_path}")
                else:
                    print(f"Warning: Pair not found for {real_img_path} and {caricature_img_path}. Skipping...")

        print(f"Total image pairs found: {len(self.image_pairs)}")

        # Shuffle and split the data
        random.shuffle(self.image_pairs)
        split_idx = int(len(self.image_pairs) * split_ratio)
        if self.split == 'train':
            self.image_pairs = self.image_pairs[:split_idx]
        else:
            self.image_pairs = self.image_pairs[split_idx:]
        
        print(f"Number of samples for {self.split}: {len(self.image_pairs)}")

        # Print first 5 samples for verification
        for i in range(min(5, len(self.image_pairs))):
            print(f"\nSample {i+1}:")
            print(f"  Real Image: {self.image_pairs[i][0]}")
            print(f"  Caricature Image: {self.image_pairs[i][1]}")

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        real_img_path, caricature_img_path = self.image_pairs[idx]
        
        # Load images
        try:
            real_image = Image.open(real_img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading real image {real_img_path}: {e}")
            # Return a dummy tensor or handle as per your requirement
            real_image = Image.new('RGB', (256, 256))
        
        try:
            caricature_image = Image.open(caricature_img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading caricature image {caricature_img_path}: {e}")
            # Return a dummy tensor or handle as per your requirement
            caricature_image = Image.new('RGB', (256, 256))
        
        # Apply transforms if any
        if self.transform:
            real_image = self.transform(real_image)
            caricature_image = self.transform(caricature_image)
        
        return real_image, caricature_image


In [6]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])


In [7]:
# Define directories
real_dir = r'/kaggle/working/caritry/dataset/Real'                 # Directory with Real images
caricature_dir = r'/kaggle/working/caritry/dataset/Caricature'     # Directory with Caricature images

# Initialize datasets
print("\nInitializing datasets...")
train_dataset = CaricatureDataset(
    real_dir=real_dir,
    caricature_dir=caricature_dir,
    transform=transform,
    split='train',
    split_ratio=0.8
)
val_dataset = CaricatureDataset(
    real_dir=real_dir,
    caricature_dir=caricature_dir,
    transform=transform,
    split='val',
    split_ratio=0.8
)
print("Datasets initialized.")

# Create DataLoaders
print("\nCreating DataLoaders...")
batch_size = 16
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)
print("DataLoaders created.")



Initializing datasets...
Scanning Real images in /kaggle/working/caritry/dataset/Real...
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Will_Smith/Thumbs.db
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/Will_Smith
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/Carlos_Tevez
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/Bruce_Willis
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/Lady_Gaga
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/Lucille_Ball
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/brazilian_ronaldo
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/shaquille_O_neal
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/Marilyn_Monroe
Skipping non-image file: /kaggle/working/caritry/dataset/Real/Vladimir_Putin/Justin_Bieber
Skipping n

In [8]:
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=64):
        super(UNetGenerator, self).__init__()
        self.down1 = self.contracting_block(in_channels, features, bn=False)      # 256 -> 128
        self.down2 = self.contracting_block(features, features*2)               # 128 -> 64
        self.down3 = self.contracting_block(features*2, features*4)             # 64 -> 32
        self.down4 = self.contracting_block(features*4, features*8)             # 32 -> 16
        self.down5 = self.contracting_block(features*8, features*8)             # 16 -> 8
        self.down6 = self.contracting_block(features*8, features*8)             # 8 -> 4
        self.down7 = self.contracting_block(features*8, features*8)             # 4 -> 2
        self.down8 = self.contracting_block(features*8, features*8, bn=False)   # 2 -> 1

        self.up1 = self.expansive_block(features*8, features*8, dropout=0.5)     # 1 -> 2
        self.up2 = self.expansive_block(features*16, features*8, dropout=0.5)    # 2 -> 4
        self.up3 = self.expansive_block(features*16, features*8, dropout=0.5)    # 4 -> 8
        self.up4 = self.expansive_block(features*16, features*8)                 # 8 -> 16
        self.up5 = self.expansive_block(features*16, features*4)                 # 16 -> 32
        self.up6 = self.expansive_block(features*8, features*2)                  # 32 -> 64
        self.up7 = self.expansive_block(features*4, features)                    # 64 -> 128
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(features*2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output values between -1 and 1
        )  # 128 -> 256

    def contracting_block(self, in_channels, out_channels, bn=True):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        ]
        if bn:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def expansive_block(self, in_channels, out_channels, dropout=0):
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # Decoder with skip connections
        u1 = self.up1(d8)
        u1 = torch.cat([u1, d7], 1)
        u2 = self.up2(u1)
        u2 = torch.cat([u2, d6], 1)
        u3 = self.up3(u2)
        u3 = torch.cat([u3, d5], 1)
        u4 = self.up4(u3)
        u4 = torch.cat([u4, d4], 1)
        u5 = self.up5(u4)
        u5 = torch.cat([u5, d3], 1)
        u6 = self.up6(u5)
        u6 = torch.cat([u6, d2], 1)
        u7 = self.up7(u6)
        u7 = torch.cat([u7, d1], 1)
        u8 = self.up8(u7)

        return u8


In [9]:
class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=6, features=64):
        super(PatchDiscriminator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer1 = self.discriminator_block(features, features*2)
        self.layer2 = self.discriminator_block(features*2, features*4)
        self.layer3 = self.discriminator_block(features*4, features*8, stride=1)
        self.final = nn.Conv2d(features*8, 1, kernel_size=4, stride=1, padding=1)

    def discriminator_block(self, in_channels, out_channels, stride=2):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
        return block

    def forward(self, x):
        x = self.initial(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.final(x)
        return x


In [10]:
# Initialize models
print("\nInitializing models...")
generator = UNetGenerator().to(device)
discriminator = PatchDiscriminator().to(device)
print("Models initialized.")

# Check if multiple GPUs are available
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for training.")
    generator = nn.DataParallel(generator)
    discriminator = nn.DataParallel(discriminator)
else:
    print("Using a single GPU or CPU for training.")



Initializing models...
Models initialized.
Using 2 GPUs for training.


In [11]:
# Define loss functions
criterion_GAN = nn.MSELoss()
criterion_L1 = nn.L1Loss()

# Define optimizers
learning_rate = 2e-4
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
print("Loss functions and optimizers set.")


Loss functions and optimizers set.


In [12]:
# Create output directories if they don't exist
output_dir = r'/kaggle/working/caritry/outputs'           # Directory to save output images
checkpoint_dir = r'/kaggle/working/caritry/checkpoints'   # Directory to save model checkpoints
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
print(f"\nOutput directory: {output_dir}")
print(f"Checkpoint directory: {checkpoint_dir}")



Output directory: /kaggle/working/caritry/outputs
Checkpoint directory: /kaggle/working/caritry/checkpoints


In [13]:
def denormalize(tensors):
    return (tensors * 0.5) + 0.5


In [None]:
# Set training parameters
num_epochs = 200
l1_lambda = 100

for epoch in range(1, num_epochs + 1):
    generator.train()
    discriminator.train()
    epoch_loss_G = 0
    epoch_loss_D = 0

    print(f"\n=== Epoch {epoch}/{num_epochs} ===")
    for i, (real_imgs, caricature_imgs) in enumerate(train_loader, 1):
        real_imgs = real_imgs.to(device)
        caricature_imgs = caricature_imgs.to(device)

        # ============================
        # Train Discriminator
        # ============================
        optimizer_D.zero_grad()

        # Real pairs
        real_pair = torch.cat((real_imgs, caricature_imgs), 1)
        pred_real = discriminator(real_pair)
        target_real = torch.ones_like(pred_real).to(device)
        loss_real = criterion_GAN(pred_real, target_real)

        # Fake pairs
        fake_imgs = generator(real_imgs)
        fake_pair = torch.cat((real_imgs, fake_imgs.detach()), 1)
        pred_fake = discriminator(fake_pair)
        target_fake = torch.zeros_like(pred_fake).to(device)
        loss_fake = criterion_GAN(pred_fake, target_fake)

        # Total Discriminator Loss
        loss_D = (loss_real + loss_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # ============================
        # Train Generator
        # ============================
        optimizer_G.zero_grad()

        # Adversarial loss
        pred_fake = discriminator(torch.cat((real_imgs, fake_imgs), 1))
        loss_G_GAN = criterion_GAN(pred_fake, target_real)

        # L1 loss
        loss_G_L1 = criterion_L1(fake_imgs, caricature_imgs) * l1_lambda

        # Total Generator Loss
        loss_G = loss_G_GAN + loss_G_L1
        loss_G.backward()
        optimizer_G.step()

        epoch_loss_G += loss_G.item()
        epoch_loss_D += loss_D.item()

        # Print losses every 100 batches
        if i % 100 == 0:
            print(f"  Batch {i}/{len(train_loader)} | Loss D: {loss_D.item():.4f} | Loss G: {loss_G.item():.4f}")

    avg_loss_G = epoch_loss_G / len(train_loader)
    avg_loss_D = epoch_loss_D / len(train_loader)
    print(f"Epoch {epoch} | Average Loss D: {avg_loss_D:.4f} | Average Loss G: {avg_loss_G:.4f}")

    # ============================
    # Validation and Checkpoints
    # ============================
    generator.eval()
    discriminator.eval()
    with torch.no_grad():
        for val_real, val_caricature in val_loader:
            val_real = val_real.to(device)
            fake_val = generator(val_real)
            break  # Only process the first batch for validation

    # Denormalize images
    val_real_denorm = denormalize(val_real.cpu())
    fake_val_denorm = denormalize(fake_val.cpu())
    val_caricature_denorm = denormalize(val_caricature.cpu())

    # Create image grids
    grid_real = torchvision.utils.make_grid(val_real_denorm, nrow=4, normalize=True)
    grid_fake = torchvision.utils.make_grid(fake_val_denorm, nrow=4, normalize=True)
    grid_caricature = torchvision.utils.make_grid(val_caricature_denorm, nrow=4, normalize=True)

    # Save images
    torchvision.utils.save_image(grid_real, os.path.join(output_dir, f'epoch_{epoch}_real.png'))
    torchvision.utils.save_image(grid_fake, os.path.join(output_dir, f'epoch_{epoch}_fake.png'))
    torchvision.utils.save_image(grid_caricature, os.path.join(output_dir, f'epoch_{epoch}_caricature.png'))

    print(f"  Saved images for Epoch {epoch}.")

    # Save model checkpoints
    torch.save(generator.state_dict(), os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth'))
    torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))

    print(f"  Saved model checkpoints for Epoch {epoch}.")



=== Epoch 1/200 ===


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


  Batch 100/191 | Loss D: 0.0558 | Loss G: 52.5385
Epoch 1 | Average Loss D: 0.1604 | Average Loss G: 54.9374
  Saved images for Epoch 1.
  Saved model checkpoints for Epoch 1.

=== Epoch 2/200 ===
  Batch 100/191 | Loss D: 0.0206 | Loss G: 55.6323
Epoch 2 | Average Loss D: 0.0233 | Average Loss G: 53.4370
  Saved images for Epoch 2.
  Saved model checkpoints for Epoch 2.

=== Epoch 3/200 ===
  Batch 100/191 | Loss D: 0.0259 | Loss G: 50.8431
Epoch 3 | Average Loss D: 0.0131 | Average Loss G: 52.9054
  Saved images for Epoch 3.
  Saved model checkpoints for Epoch 3.

=== Epoch 4/200 ===
  Batch 100/191 | Loss D: 3.3158 | Loss G: 57.9757
Epoch 4 | Average Loss D: 0.1654 | Average Loss G: 52.3659
  Saved images for Epoch 4.
  Saved model checkpoints for Epoch 4.

=== Epoch 5/200 ===
  Batch 100/191 | Loss D: 0.0160 | Loss G: 52.1642
Epoch 5 | Average Loss D: 0.0258 | Average Loss G: 52.2982
  Saved images for Epoch 5.
  Saved model checkpoints for Epoch 5.

=== Epoch 6/200 ===
  Batch 10

In [None]:
# Saving models remains the same
torch.save(generator.state_dict(), os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth'))
torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))

# Loading models without DataParallel
state_dict = torch.load(generator_checkpoint, map_location=device)
# Remove the 'module.' prefix
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # remove 'module.' prefix
    new_state_dict[name] = v
generator.load_state_dict(new_state_dict)


In [None]:
def perform_inference(test_image_path, output_path, generator_checkpoint, device='cuda'):
    """
    Generates a caricature for a given real image using the trained generator.

    Args:
        test_image_path (str): Path to the input real image.
        output_path (str): Path to save the generated caricature.
        generator_checkpoint (str): Path to the trained generator model checkpoint.
        device (str): Device to perform computations on ('cuda' or 'cpu').
    """
    # Define transforms
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])

    # Initialize generator
    generator = UNetGenerator().to(device)

    # Handle DataParallel checkpoints
    checkpoint = torch.load(generator_checkpoint, map_location=device)
    if 'module.' in list(checkpoint.keys())[0]:
        # If the model was wrapped with DataParallel, remove 'module.' prefix
        from collections import OrderedDict
        new_checkpoint = OrderedDict()
        for k, v in checkpoint.items():
            new_key = k.replace('module.', '')
            new_checkpoint[new_key] = v
        generator.load_state_dict(new_checkpoint)
        print(f"Loaded generator weights from {generator_checkpoint} (DataParallel format).")
    else:
        generator.load_state_dict(checkpoint)
        print(f"Loaded generator weights from {generator_checkpoint}.")

    # Set generator to evaluation mode
    generator.eval()

    # Generate caricature
    with torch.no_grad():
        try:
            # Load and preprocess the image
            print(f"\nLoading test image from {test_image_path}...")
            image = Image.open(test_image_path).convert('RGB')
            input_tensor = transform(image).unsqueeze(0).to(device)
            print("Image loaded and transformed.")

            # Generate fake caricature
            fake_tensor = generator(input_tensor)
            print("Generated caricature.")

            # Denormalize
            fake_tensor = (fake_tensor * 0.5) + 0.5

            # Convert to PIL Image
            fake_image = fake_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
            fake_image = (fake_image * 255).astype('uint8')
            fake_image = Image.fromarray(fake_image)

            # Save the image
            fake_image.save(output_path)
            print(f"Caricature saved to {output_path}")
        except Exception as e:
            print(f"Error during inference: {e}")


In [None]:
# Define paths
test_image_path = r'/kaggle/working/caritry/dataset/Real/Aamir_Khan/Aamir_Khan_r_0.jpg'      # Path to a real image
output_image_path = r'/kaggle/working/caritry/outputs/test_output_caricature.jpg'            # Path to save the caricature
generator_checkpoint = os.path.join(checkpoint_dir, f'generator_epoch_{num_epochs}.pth')  # Path to the trained generator

# Perform inference
perform_inference(
    test_image_path=test_image_path,
    output_path=output_image_path,
    generator_checkpoint=generator_checkpoint,
    device=device
)
