In [None]:
# Install PyTorch (choose the appropriate command from https://pytorch.org/get-started/locally/)
!pip install torch torchvision

# Install other dependencies
!pip install numpy matplotlib pillow


In [12]:
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class CaricatureDataset(Dataset):
    """Paired (real photo, caricature) image dataset driven by a whitespace-separated index file.

    Each non-blank line of ``data_file`` is expected to look like::

        <identity> <caricature path> <real path> <label>

    with image paths relative to ``root_dir`` (Windows ``\\`` separators are accepted). The
    entries are shuffled and cut at ``split_ratio``: ``split='train'`` keeps the head of the
    permutation, any other value keeps the tail.
    """

    def __init__(self, data_file, root_dir, transform=None, split='train', split_ratio=0.8, seed=42):
        """
        Args:
            data_file (str): Path to the data.txt file.
            root_dir (str): Root directory containing the images.
            transform (callable, optional): Applied independently to the real and the caricature image.
            split (str): 'train' or 'val' to indicate dataset split (anything other than 'train'
                selects the validation tail).
            split_ratio (float): Fraction of entries assigned to the training split.
            seed (int or None): Seed for the shuffle. Train and val instances built with the same
                seed see the same permutation, so their splits are disjoint and together cover
                every entry. ``None`` falls back to the global ``random`` module (the previous
                behaviour), in which case separately constructed splits can overlap.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.split = split

        print(f"Loading data from {data_file}...")
        with open(data_file, 'r') as f:
            # Skip blank / whitespace-only lines (e.g. a trailing newline): __getitem__ could not
            # index their fields and would raise IndexError.
            lines = [line for line in f if line.strip()]
        print(f"Total data entries found: {len(lines)}")

        # A private seeded RNG gives every instance the same permutation. With the unseeded global
        # RNG, the train and val instances would each shuffle differently and their slices would
        # overlap, leaking training pairs into validation.
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(lines)
        print("Data shuffled.")

        split_idx = int(len(lines) * split_ratio)
        if split == 'train':
            self.lines = lines[:split_idx]
            print(f"Selected {len(self.lines)} samples for training.")
        else:
            self.lines = lines[split_idx:]
            print(f"Selected {len(self.lines)} samples for validation.")

    def __len__(self):
        """Number of entries in this split."""
        return len(self.lines)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image and close the underlying file handle."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image, so it stays valid after the file closes.
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(real_image, caricature_image)`` for entry ``idx``, transformed if a transform is set."""
        fields = self.lines[idx].split()
        # Assumed layout: [Identity] [Caricature Image Path] [Real Image Path] [Label]
        # Adjust indices if the index file differs.
        caricature_path = fields[1].replace('\\', '/')
        real_path = fields[2].replace('\\', '/')

        caricature_full_path = os.path.join(self.root_dir, caricature_path)
        real_full_path = os.path.join(self.root_dir, real_path)

        # Debug: print the paths of the first few samples.
        if idx < 5:
            print(f"Loading sample {idx}:")
            print(f"  Caricature Image Path: {caricature_full_path}")
            print(f"  Real Image Path: {real_full_path}")

        try:
            real_image = self._load_rgb(real_full_path)
            caricature_image = self._load_rgb(caricature_full_path)
        except Exception as e:
            print(f"Error loading images at index {idx}: {e}")
            # Bare raise keeps the original traceback.
            raise

        if self.transform:
            real_image = self.transform(real_image)
            caricature_image = self.transform(caricature_image)

        return real_image, caricature_image


In [13]:
import torch.nn as nn

class UNetGenerator(nn.Module):
    """Eight-level U-Net encoder/decoder with skip connections.

    Every encoder stage halves H and W and every decoder stage doubles them, so a 256x256 input
    reaches 1x1 at the bottleneck and is restored to 256x256 at the output. Input height and
    width should be multiples of 256; otherwise the skip tensors will not line up for concat.
    Decoder stage ``up{k}`` (k = 1..7) is concatenated along the channel axis with the mirrored
    encoder output ``down{8-k}``. The final layer applies ``Tanh``, so outputs lie in [-1, 1].

    Args:
        in_channels (int): Channels of the input image.
        out_channels (int): Channels of the generated image.
        features (int): Width of the first encoder stage; deeper stages use multiples of it,
            capped at ``8 * features``.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f = features

        # Encoder stages down1..down8: (in, out, use batch-norm). Resolution 256 -> 1.
        encoder_spec = [
            (in_channels, f, False),
            (f, 2 * f, True),
            (2 * f, 4 * f, True),
            (4 * f, 8 * f, True),
            (8 * f, 8 * f, True),
            (8 * f, 8 * f, True),
            (8 * f, 8 * f, True),
            (8 * f, 8 * f, False),
        ]
        for stage, (c_in, c_out, use_bn) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{stage}", self.contracting_block(c_in, c_out, bn=use_bn))

        # Decoder stages up1..up7: (in, out, dropout). Inputs after up1 are doubled by the skip concat.
        decoder_spec = [
            (8 * f, 8 * f, 0.5),
            (16 * f, 8 * f, 0.5),
            (16 * f, 8 * f, 0.5),
            (16 * f, 8 * f, 0),
            (16 * f, 4 * f, 0),
            (8 * f, 2 * f, 0),
            (4 * f, f, 0),
        ]
        for stage, (c_in, c_out, p_drop) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{stage}", self.expansive_block(c_in, c_out, dropout=p_drop))

        # Output head: 128 -> 256, squashed to [-1, 1].
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(2 * f, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def contracting_block(self, in_channels, out_channels, bn=True):
        """Conv(k4, s2, p1, no bias) -> optional BatchNorm -> LeakyReLU(0.2); halves H and W."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
        norm = [nn.BatchNorm2d(out_channels)] if bn else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

    def expansive_block(self, in_channels, out_channels, dropout=0):
        """ConvTranspose(k4, s2, p1, no bias) -> BatchNorm -> ReLU -> optional Dropout; doubles H and W."""
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []
        return nn.Sequential(*stages)

    def forward(self, x):
        encoder = [self.down1, self.down2, self.down3, self.down4,
                   self.down5, self.down6, self.down7, self.down8]
        decoder = [self.up1, self.up2, self.up3, self.up4,
                   self.up5, self.up6, self.up7]

        skips = []
        h = x
        for stage in encoder:
            h = stage(h)
            skips.append(h)

        # h is now the bottleneck (down8). Pair up1..up7 with down7..down1.
        for stage, skip in zip(decoder, reversed(skips[:-1])):
            h = torch.cat([stage(h), skip], 1)

        return self.up8(h)


In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt

def main():
    """Train a pix2pix-style real-photo -> caricature model, then optionally run inference.

    Pipeline:
      1. Build train/val ``CaricatureDataset`` splits from ``data.txt`` and wrap them in DataLoaders.
      2. Train ``UNetGenerator`` against ``PatchDiscriminator``. The adversarial loss is
         least-squares: MSE against 1 for real pairs and 0 for fake pairs. It is combined with an
         L1 reconstruction term weighted by ``L1_LAMBDA``.
      3. After each epoch, save input / generated / target grids of the first validation batch
         to ``OUTPUT_DIR``. Save both models' ``state_dict`` to ``CHECKPOINT_DIR`` every
         ``CHECKPOINT_EVERY`` epochs and after the last epoch.
      4. If the ``CARI_TEST_IMAGE`` environment variable is set, translate that image with the
         trained generator.

    Paths default to the original author's machine. Override them with the ``CARI_DATA_FILE``,
    ``CARI_ROOT_DIR``, ``CARI_OUTPUT_DIR`` and ``CARI_CHECKPOINT_DIR`` environment variables.

    NOTE(review): ``PatchDiscriminator`` is not defined in any visible cell of this notebook. Make
    sure its definition is present before running on a fresh kernel.
    """
    # ============================
    # Hyperparameters and Paths
    # ============================
    NUM_EPOCHS = 200
    BATCH_SIZE = 16
    LEARNING_RATE = 2e-4
    L1_LAMBDA = 100          # weight of the L1 reconstruction term relative to the GAN term
    SPLIT_RATIO = 0.8
    SEED = 42
    CHECKPOINT_EVERY = 1     # save model weights every N epochs; raise it to save disk space
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # On Windows, DataLoader workers are spawned processes that must re-import the dataset class.
    # A class defined inside a Jupyter kernel cannot be re-imported that way, so the workers
    # typically hang or crash. Load data in the main process on Windows.
    NUM_WORKERS = 0 if os.name == 'nt' else 4
    # Pinned host memory only speeds up host->GPU copies; on CPU it is useless.
    PIN_MEMORY = DEVICE.type == 'cuda'
    print(f"Using device: {DEVICE}")

    DATA_FILE = os.environ.get('CARI_DATA_FILE', r'C:\Users\haric\Downloads\cari\dataset\data.txt')
    ROOT_DIR = os.environ.get('CARI_ROOT_DIR', r'C:\Users\haric\Downloads\cari\dataset')
    OUTPUT_DIR = os.environ.get('CARI_OUTPUT_DIR', 'outputs')
    CHECKPOINT_DIR = os.environ.get('CARI_CHECKPOINT_DIR', 'checkpoints')
    # Optional: a real photo to caricature once training finishes.
    TEST_IMAGE = os.environ.get('CARI_TEST_IMAGE')

    torch.manual_seed(SEED)

    # ============================
    # Transforms
    # ============================
    # Normalize with mean=std=0.5 maps [0, 1] pixels to [-1, 1], matching the generator's Tanh output.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])

    # ============================
    # Create Datasets and DataLoaders
    # ============================
    print("Initializing datasets...")
    # Each CaricatureDataset reads and shuffles data.txt on its own. Re-seeding the global
    # `random` module before each one gives both instances the same permutation, so the train
    # and val splits are complementary instead of overlapping.
    random.seed(SEED)
    train_dataset = CaricatureDataset(data_file=DATA_FILE, root_dir=ROOT_DIR, transform=transform,
                                      split='train', split_ratio=SPLIT_RATIO)
    random.seed(SEED)
    val_dataset = CaricatureDataset(data_file=DATA_FILE, root_dir=ROOT_DIR, transform=transform,
                                    split='val', split_ratio=SPLIT_RATIO)
    print("Datasets initialized.")

    print("Creating DataLoaders...")
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    print("DataLoaders created.")

    # ============================
    # Initialize Models
    # ============================
    print("Initializing models...")
    generator = UNetGenerator().to(DEVICE)
    discriminator = PatchDiscriminator().to(DEVICE)
    print("Models initialized.")

    # ============================
    # Loss Functions and Optimizers
    # ============================
    criterion_GAN = nn.MSELoss()
    criterion_L1 = nn.L1Loss()

    optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    print("Loss functions and optimizers set.")

    # ============================
    # Create Output Directories
    # ============================
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    print(f"Output directory: {OUTPUT_DIR}")
    print(f"Checkpoint directory: {CHECKPOINT_DIR}")

    def denormalize(tensors):
        """Invert Normalize(0.5, 0.5): map [-1, 1] back to [0, 1]."""
        return (tensors * 0.5) + 0.5

    # ============================
    # Training Loop
    # ============================
    for epoch in range(NUM_EPOCHS):
        generator.train()
        discriminator.train()
        epoch_loss_G = 0.0
        epoch_loss_D = 0.0

        print(f"\nStarting Epoch {epoch+1}/{NUM_EPOCHS}")
        for i, (real_img, caricature_img) in enumerate(train_loader):
            real_img = real_img.to(DEVICE)
            caricature_img = caricature_img.to(DEVICE)

            # ---- Discriminator: (input, target) pairs -> 1, (input, generated) pairs -> 0 ----
            optimizer_D.zero_grad()

            pred_real = discriminator(torch.cat((real_img, caricature_img), 1))
            # ones_like / zeros_like already allocate on pred's device.
            real_loss = criterion_GAN(pred_real, torch.ones_like(pred_real))
            if i == 0:
                print(f"Batch {i+1}: Real loss: {real_loss.item():.4f}")

            fake_img = generator(real_img)
            # detach(): the discriminator update must not backpropagate into the generator.
            pred_fake = discriminator(torch.cat((real_img, fake_img.detach()), 1))
            fake_loss = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            if i == 0:
                print(f"Batch {i+1}: Fake loss: {fake_loss.item():.4f}")

            loss_D = (real_loss + fake_loss) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # ---- Generator: fool the freshly updated discriminator and stay close to the target ----
            optimizer_G.zero_grad()

            pred_fake = discriminator(torch.cat((real_img, fake_img), 1))
            loss_G_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            loss_G_L1 = criterion_L1(fake_img, caricature_img) * L1_LAMBDA
            loss_G = loss_G_GAN + loss_G_L1
            loss_G.backward()
            optimizer_G.step()

            epoch_loss_G += loss_G.item()
            epoch_loss_D += loss_D.item()

            if (i+1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{i+1}/{len(train_loader)}], "
                      f"Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

        # Guard against an empty training loader (would otherwise divide by zero).
        num_batches = max(len(train_loader), 1)
        avg_loss_G = epoch_loss_G / num_batches
        avg_loss_D = epoch_loss_D / num_batches
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] completed. Average Loss D: {avg_loss_D:.4f}, Average Loss G: {avg_loss_G:.4f}")

        # ============================
        # Validation samples (first val batch only)
        # ============================
        generator.eval()
        val_batch = next(iter(val_loader), None)
        if val_batch is None:
            print("Validation set is empty; skipping sample images.")
        else:
            val_input, val_target = val_batch
            with torch.no_grad():
                fake_val = generator(val_input.to(DEVICE)).cpu()
            for tag, images in (('input', val_input), ('fake', fake_val), ('target', val_target)):
                # save_image lays the batch out with make_grid(nrow=4) before writing.
                torchvision.utils.save_image(
                    denormalize(images),
                    os.path.join(OUTPUT_DIR, f'epoch_{epoch+1}_{tag}.png'),
                    nrow=4,
                )
            print(f"Saved generated images for Epoch {epoch+1}.")

        # ============================
        # Checkpoints
        # ============================
        is_last_epoch = epoch + 1 == NUM_EPOCHS
        if (epoch + 1) % CHECKPOINT_EVERY == 0 or is_last_epoch:
            torch.save(generator.state_dict(), os.path.join(CHECKPOINT_DIR, f'generator_epoch_{epoch+1}.pth'))
            torch.save(discriminator.state_dict(), os.path.join(CHECKPOINT_DIR, f'discriminator_epoch_{epoch+1}.pth'))
            print(f"Saved model checkpoints for Epoch {epoch+1}.")

    print("\nTraining completed!")

    # ============================
    # Inference Function
    # ============================
    def generate_caricature(input_image_path, generator, transform, device, output_path):
        """
        Transforms a real image into a caricature using the trained generator.

        Args:
            input_image_path (str): Path to the input real image.
            generator (nn.Module): Trained generator model.
            transform (callable): Transformations to apply to the input image.
            device (torch.device): Device to perform computations on.
            output_path (str): Path to save the generated caricature.
        """
        generator.eval()
        with torch.no_grad():
            try:
                print(f"Loading test image from {input_image_path}...")
                with Image.open(input_image_path) as img:
                    image = img.convert('RGB')
                input_tensor = transform(image).unsqueeze(0).to(device)
                print("Image loaded and transformed.")

                print("Generating caricature...")
                fake_tensor = generator(input_tensor)
                print("Caricature generated.")

                # [-1, 1] -> [0, 1], then round (rather than truncate) to 8-bit pixels.
                fake_tensor = denormalize(fake_tensor.cpu()).clamp(0, 1)
                fake_array = fake_tensor.squeeze(0).permute(1, 2, 0).numpy()
                fake_image = Image.fromarray((fake_array * 255).round().astype('uint8'))

                fake_image.save(output_path)
                print(f"Caricature saved to {output_path}")
            except Exception as e:
                # Best-effort: training and checkpoints are already on disk, so just report it.
                print(f"Error during inference: {e}")

    # ============================
    # Optional Inference
    # ============================
    if TEST_IMAGE:
        generate_caricature(TEST_IMAGE, generator, transform, DEVICE,
                            os.path.join(OUTPUT_DIR, 'test_caricature.png'))


if __name__ == '__main__':  # also true inside a notebook kernel, so running this cell starts training
    main()


Using device: cpu
Initializing datasets...
Loading data from C:\Users\haric\Downloads\cari\dataset\data.txt...
Total data entries found: 21064
Data shuffled.
Selected 16851 samples for training.
Loading data from C:\Users\haric\Downloads\cari\dataset\data.txt...
Total data entries found: 21064
Data shuffled.
Selected 4213 samples for validation.
Datasets initialized.
Creating DataLoaders...
DataLoaders created.
Initializing models...
Models initialized.
Loss functions and optimizers set.
Output directory: outputs
Checkpoint directory: checkpoints

Starting Epoch 1/200
