In [1]:
! pip install ITTR_pytorch

Collecting ITTR_pytorch
  Downloading ITTR_pytorch-0.0.4-py3-none-any.whl.metadata (724 bytes)
Downloading ITTR_pytorch-0.0.4-py3-none-any.whl (4.2 kB)
Installing collected packages: ITTR_pytorch
Successfully installed ITTR_pytorch-0.0.4


CustomDataset.py


In [2]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, anime_dir, real_dir, transform=None):
        self.anime_images = os.listdir(anime_dir)
        self.real_images = os.listdir(real_dir)
        self.anime_dir = anime_dir
        self.real_dir = real_dir
        self.transform = transform

    def __len__(self):
        return min(len(self.anime_images), len(self.real_images))

    def __getitem__(self, idx):
        anime_image = Image.open(os.path.join(self.anime_dir, self.anime_images[idx])).convert("RGB")
        real_image = Image.open(os.path.join(self.real_dir, self.real_images[idx])).convert("RGB")

        if self.transform:
            anime_image = self.transform(anime_image)
            real_image = self.transform(real_image)

        return anime_image, real_image

# Usage example
# from torchvision import transforms

# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
# ])

# dataset = CustomDataset('/path/to/dataset/anime', '/path/to/dataset/real', transform)
# dataloader = DataLoader(dataset, batch_size=16, shuffle=True)


ITTR_model.py

In [3]:
# generator.py

import torch
import torch.nn as nn
from ITTR_pytorch import HPB, DPSA  # Assuming these are available from the library

class ITTRGenerator(nn.Module):
    def __init__(self, img_dim=512, num_blocks=9, heads=8, dim_head=32, top_k=16):
        super(ITTRGenerator, self).__init__()

        self.num_blocks = num_blocks
        # Add a convolutional layer to change input from 3 channels (RGB) to 512

        self.conv_layers = nn.Sequential(
            # 7x7 Convolution, Instance Norm, GELU
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),  # Output: (batch_size, 64, 128, 128)
            nn.InstanceNorm2d(64),
            nn.GELU(),

            # 3x3 Convolution, stride=2, Instance Norm, GELU
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=2, padding=1),  # Output: (batch_size, 256, 64, 64)
            nn.InstanceNorm2d(256),
            nn.GELU(),

            # 3x3 Convolution, stride=2, Instance Norm, GELU
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)  # Output: (batch_size, 512, 32, 32)
            )

        # Create the series of HPB blocks
        self.hpb_blocks = nn.ModuleList([
            HPB(
                dim=img_dim,              # input dimension of image embeddings
                dim_head=dim_head,        # dimension per attention head
                heads=heads,              # number of attention heads
                attn_height_top_k=top_k,  # number of top indices to select for attention along height
                attn_width_top_k=top_k,   # number of top indices to select for attention along width
                attn_dropout=0.,          # attention dropout
                ff_mult=4,                # feedforward expansion factor
                ff_dropout=0.             # feedforward dropout
            )
            for _ in range(num_blocks)
        ])

        # Final Dual Pruned Self-Attention (DPSA) layer
        self.dpsa = DPSA(
            dim=img_dim,          # input dimension of the image
            dim_head=dim_head,    # dimension per attention head
            heads=heads,          # number of attention heads
            height_top_k=top_k * 3,  # more top indices for final refinement
            width_top_k=top_k * 3,   # more top indices for final refinement
            dropout=0.            # dropout
        )

        # Final convolutional layer to return the image
        self.final_conv = nn.Conv2d(in_channels=img_dim, out_channels=3, kernel_size=3, padding=1)

        # DECODER
        self.deocder_layer1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),  # Output channels: 128
            nn.InstanceNorm2d(256),
            nn.GELU()
        )

        # Layer 2: Upsample
        self.deocder_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 32 -> 64

        # Layer 3: 3x3 Conv -> IN -> GELU
        self.deocder_layer2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # Output channels: 128
            nn.InstanceNorm2d(256),
            nn.GELU()
        )

        # Layer 4: Upsample
        self.deocder_upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 64 -> 128

        # Layer 5: 3x3 Conv -> IN -> GELU
        self.deocder_layer3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),  # Output channels: 256
            nn.InstanceNorm2d(128),
            nn.GELU()
        )

        # Layer 6: Upsample
        self.deocder_upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 128 -> 256

        # Layer 7: 7x7 Conv
        self.deocder_layer4 = nn.Conv2d(128, 3, kernel_size=7, padding=3)  # Output channels: 3

        # Final activation layer: Tanh
        self.tanh = nn.Tanh()


    def forward(self, x):
        # Pass through the HPB blocks
        # print(f"shape of input at forward : {x.shape} -------------")
        x = self.conv_layers(x)
        # print(f"shape of output after conv2d layers : {x.shape} -------------")
        for block in self.hpb_blocks:
            x = block(x)

        # print(f"shape of output after hpb blocks : {x.shape} -----------")
        # Final pass through the DPSA block
        x = self.dpsa(x)

        # print(f"shape of output after dpsa blocks : {x.shape} -----------")
        # Final image generation
        x = self.deocder_layer1(x)                 # Shape: (B, 128, 32, 32)
        x = self.deocder_upsample1(x)              # Shape: (B, 128, 64, 64)
        x = self.deocder_layer2(x)                 # Shape: (B, 128, 64, 64)
        x = self.deocder_upsample2(x)              # Shape: (B, 128, 128, 128)
        x = self.deocder_layer3(x)                 # Shape: (B, 256, 128, 128)
        x = self.deocder_upsample3(x)              # Shape: (B, 256, 256, 256)
        x = self.deocder_layer4(x)                 # Shape: (B, 3, 256, 256)
        x = self.tanh(x)



        # x = self.final_conv(x)
        # print(f"shape of encoder output after final layer : {x.shape} -----------")

        return x



anime2irl_trainer.py

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg16
from torchvision import transforms
# from ITTR_model import ITTRGenerator  # Import the generator class
# from CustomDataset import CustomDataset  # Assuming this is the custom dataset you already have
from torch.utils.data import DataLoader
from PIL import Image
import torchvision.utils as vutils
from tqdm.auto import tqdm
# from modelCheckpointing import save_input_output_images, show_tensor_images

# Training parameters
lr = 0.0002
betas = (0.5, 0.999)
batch_size = 12
num_epochs = 500
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint_dir = '/content/drive/MyDrive/Datasets/anime_2_real/saved_models'  # Directory to save checkpoints

# Loss functions
l1_loss = nn.L1Loss()

# VGG Perceptual loss function
vgg = vgg16(pretrained=True).features[:16].eval().to(device)
for param in vgg.parameters():
    param.requires_grad = False

# Function for perceptual loss
def perceptual_loss(vgg, gen_image, real_image):
    gen_features = vgg(gen_image)
    real_features = vgg(real_image)
    return l1_loss(gen_features, real_features)

# Function for cycle consistency loss
def cycle_consistency_loss(generator, fake_photo, real_anime):
    reconstructed_anime = generator(fake_photo)
    return l1_loss(reconstructed_anime, real_anime)

# Function to save model checkpoints
def save_checkpoint(epoch, model, optimizer, loss):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # old_file_path = os.path.join(checkpoint_dir, f"ITTR_checkpoint_epoch_{epoch - 2}.pth")
    # if os.path.exists(old_file_path):
    #     os.remove(old_file_path)


    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    file_name = f"ITTR_dpsa__checkpoint_epoch_{epoch}_v3.pth"
    torch.save(checkpoint, os.path.join(checkpoint_dir, file_name))
    print(f"Checkpoint saved at epoch {epoch}")

# Function to load model checkpoints if available
def load_checkpoint(model, optimizer, current_epoch):
    file_name = f"ITTR_checkpoint_epoch_{current_epoch}.pth"
    checkpoint_path = os.path.join(checkpoint_dir, file_name)
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Checkpoint loaded: starting from epoch {start_epoch}, loss: {loss}")
        return start_epoch, loss
    else:
        print("No checkpoint found, starting training from scratch")
        return 0, float('inf')  # Start from scratch if no checkpoint exists

# Function to generate outputs for test images and save them
# def generate_and_save_test_outputs(generator, epoch,test_dir,checkpoint_output_dir):
#     if not os.path.exists(checkpoint_output_dir):
#         os.makedirs(checkpoint_output_dir)

#     transform = transforms.Compose([
#         transforms.Resize((256, 256)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for Tanh
#     ])

#     count = 0
#     for file in os.listdir(test_dir):
#         if count == 5:
#             break
#         count += 1
#         if file.endswith(('.png', '.jpg', '.jpeg')):
#             img_path = os.path.join(test_dir, file)
#             image = Image.open(img_path).convert('RGB')
#             input_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension

#             # Generate output
#             generator.eval()  # Set generator to evaluation mode
#             with torch.no_grad():
#                 generated_image = generator(input_tensor)

#             # Save output image
#             output_image_path = os.path.join(checkpoint_output_dir, f"{file.split('.')[0]}_epoch_{epoch}.png")
#             vutils.save_image(generated_image, output_image_path, normalize=True)
#             print(f"Generated output saved for {file} at epoch {epoch}")

def generate_and_save_test_outputs(generator, epoch, test_dir, checkpoint_output_dir):
    if not os.path.exists(checkpoint_output_dir):
        os.makedirs(checkpoint_output_dir)

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for Tanh
    ])

    count = 0
    for file in os.listdir(test_dir):
        # if count == 5:
        #     break
        # count += 1
        if file.endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(test_dir, file)
            image = Image.open(img_path).convert('RGB')
            input_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension

            # Generate output
            generator.eval()  # Set generator to evaluation mode
            with torch.no_grad():
                generated_image = generator(input_tensor)

            # Denormalize and convert tensors to images for both input and generated images
            def denormalize(tensor):
                return tensor * 0.5 + 0.5  # Reverse the normalization (for Tanh normalization)

            # Process original test image
            test_image_np = denormalize(input_tensor.squeeze(0)).cpu()  # Remove batch dimension, move to CPU
            test_image_pil = transforms.ToPILImage()(test_image_np)  # Convert to PIL image

            # Process generated image
            generated_image_np = denormalize(generated_image.squeeze(0)).cpu()  # Remove batch dimension
            generated_image_pil = transforms.ToPILImage()(generated_image_np)  # Convert to PIL image

            # Stack the test and generated images horizontally
            combined_image = Image.new('RGB', (test_image_pil.width + generated_image_pil.width, test_image_pil.height))
            combined_image.paste(test_image_pil, (0, 0))
            combined_image.paste(generated_image_pil, (test_image_pil.width, 0))

            # Save the combined image
            output_image_path = os.path.join(checkpoint_output_dir, f"{file.split('.')[0]}_epoch_{epoch}.png")
            combined_image.save(output_image_path)

            print(f"Generated output saved for {file} at epoch {epoch}")

# Training loop
def train(generator, dataloader, test_dataloader, optimizer, num_epochs, current_epoch):
    generator.train()

    start_epoch, prev_loss = load_checkpoint(generator, optimizer, current_epoch)

    for epoch in range(start_epoch, num_epochs):
        total_train_loss = 0.0

        for batch in tqdm(dataloader):
            real_anime, real_photo = batch  # Assuming the Dataset returns a tuple of (anime_image, photo_image)

            real_anime, real_photo = real_anime.to(device), real_photo.to(device)

            # Ensure the input is in the expected format, i.e., 3 channels (RGB)
            assert real_anime.shape[1] == 3, "Input anime images must have 3 channels (RGB)"
            assert real_photo.shape[1] == 3, "Input photo images must have 3 channels (RGB)"


            # Generate fake photorealistic images from anime images
            fake_photo = generator(real_anime)

            # Calculate L1 pixel loss
            pixel_loss = l1_loss(fake_photo, real_photo)

            # Perceptual loss
            vgg_loss = perceptual_loss(vgg, fake_photo, real_photo)

            # Cycle consistency loss
            cycle_loss = cycle_consistency_loss(generator, fake_photo, real_anime)

            # Total loss
            total_train_loss = pixel_loss + 0.1 * vgg_loss + 10 * cycle_loss  # Adjust weights as necessary

            # Backpropagation
            optimizer.zero_grad()
            total_train_loss.backward()
            optimizer.step()
        # Train loss
        avg_train_loss = total_train_loss / len(dataloader)
        # Test loss :
        # Validation phase
        total_test_loss = 0.0
        generator.eval()  # Set generator to evaluation mode
        with torch.no_grad():
            for test_batch in tqdm(test_dataloader):
                test_anime, test_photo = test_batch
                test_anime, test_photo = test_anime.to(device), test_photo.to(device)

                # Generate fake photorealistic images from anime images
                test_fake_photo = generator(test_anime)

                # Calculate L1 pixel loss
                test_pixel_loss = l1_loss(test_fake_photo, test_photo)

                # Perceptual loss
                test_vgg_loss = perceptual_loss(vgg, test_fake_photo, test_photo)

                # Cycle consistency loss
                test_cycle_loss = cycle_consistency_loss(generator, test_fake_photo, test_anime)

                # Total validation loss
                test_loss = test_pixel_loss + 0.1 * test_vgg_loss + 10 * test_cycle_loss
                total_test_loss += test_loss.item()

        avg_test_loss = total_test_loss / len(test_dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}  Test Loss: {avg_test_loss:.4f}")

        # print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss.item()}")

        # Save checkpoint every 10 epochs
        test_dir = '/content/drive/MyDrive/Datasets/anime_2_real/TestA'
        save_img_dir = '/content/drive/MyDrive/Datasets/anime_2_real/generated_images'
        test_img_dir = '/content/drive/MyDrive/Datasets/anime_2_real/test_img'
        if (epoch + 1) % 10 == 0:
            save_checkpoint(epoch + 1, generator, optimizer, avg_train_loss)
            generate_and_save_test_outputs(generator, epoch + 1,test_img_dir,save_img_dir)

# Main function
if __name__ == "__main__":
    # Initialize the generator
    generator = ITTRGenerator().to(device)

    # Optimizer
    optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)

    # Dataset and DataLoader
    custom_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize for Tanh
    ])
    dataset = CustomDataset('/content/drive/MyDrive/Datasets/anime_2_real/TrainA', '/content/drive/MyDrive/Datasets/anime_2_real/TrainB', custom_transforms)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # test dataloader
    test_dataset = CustomDataset('/content/drive/MyDrive/Datasets/anime_2_real/TestA', '/content/drive/MyDrive/Datasets/anime_2_real/TestB', custom_transforms)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)



    # Start training
    train(generator, dataloader,test_dataloader, optimizer, num_epochs, current_epoch=0)


No checkpoint found, starting training from scratch


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [1/500], Training Loss: 0.0561  Test Loss: 4.5960


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [2/500], Training Loss: 0.0560  Test Loss: 4.6939


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [3/500], Training Loss: 0.0483  Test Loss: 4.6511


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [4/500], Training Loss: 0.0432  Test Loss: 3.9046


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [5/500], Training Loss: 0.0328  Test Loss: 3.5955


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [6/500], Training Loss: 0.0320  Test Loss: 3.3383


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [7/500], Training Loss: 0.0350  Test Loss: 3.3645


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch [8/500], Training Loss: 0.0365  Test Loss: 2.9848


  0%|          | 0/78 [00:00<?, ?it/s]