In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import os
import shutil
import glob


In [4]:
# --- CONFIGURATION ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PHOTO_PATH = "../input/gan-getting-started/photo_jpg"
MODEL_PATH = "../input/experiment3/epoch_20.pth"

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, x):
        return x + self.block(x)

class GeneratorResNet(nn.Module):
    def __init__(self, num_residual_blocks=9):
        super(GeneratorResNet, self).__init__()
        
        out_features = 64
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True)
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual Blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output Layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# --- DATASET LOADER ---
class PhotoDataset(Dataset):
    def __init__(self, root):
        # Read all JPG files
        self.files = sorted(glob.glob(os.path.join(root, "*.jpg")))
        
        # Standard Resize & Normalize for Inference
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        img_path = self.files[index]
        img = Image.open(img_path).convert('RGB')
        return self.transform(img), os.path.basename(img_path)

    def __len__(self):
        return len(self.files)

# --- MAIN SUBMISSION FUNCTION ---
def make_submission():
    print(f"⏳ Setting up on {DEVICE}...")
    
    # 1. Initialize Generator
    G_AB = GeneratorResNet().to(DEVICE)
    
    if os.path.exists(MODEL_PATH):
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        
        if 'G_AB' in checkpoint:
            G_AB.load_state_dict(checkpoint['G_AB'])
            print("✅ Loaded weights from checkpoint['G_AB']")
        else:
            G_AB.load_state_dict(checkpoint)
            print("✅ Loaded raw state dict")
    else:
        print(f"Error: Model file not found at {MODEL_PATH}")
        return

    G_AB.eval()

    dataset = PhotoDataset(PHOTO_PATH)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
    os.makedirs("../images", exist_ok=True)

    print(f"Generating images for {len(dataset)} photos...")
    
    for i, (imgs, filenames) in enumerate(dataloader):
        imgs = imgs.to(DEVICE)
        
        with torch.no_grad():
            fake_monets = G_AB(imgs)
        
        fake_monets = fake_monets * 0.5 + 0.5
        
        for j in range(len(fake_monets)):
            save_path = os.path.join("../images", filenames[j])
            save_image(fake_monets[j], save_path)
            
        if i % 1000 == 0:
            print(f"Processed {i} images...")

    print("Zipping images...")
    shutil.make_archive("/kaggle/working/images", 'zip', "../images")
    shutil.rmtree("../images")
    print("Done! 'images.zip' is ready in Output section.")

if __name__ == "__main__":
    make_submission()

⏳ Setting up on cpu...
✅ Loaded weights from checkpoint['G_AB']
Generating images for 7038 photos...
Processed 0 images...


KeyboardInterrupt: 