In [1]:
import shutil, os, sys

# Only remove & clone if needed
if os.path.exists('./pix2pixHD'):
    shutil.rmtree('./pix2pixHD')
!git clone https://github.com/NVIDIA/pix2pixHD.git

sys.path.append('./pix2pixHD')
from models.networks import define_G, define_D

print("Repo cloned and imports successful!")

Cloning into 'pix2pixHD'...
remote: Enumerating objects: 343, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 343 (delta 0), reused 0 (delta 0), pack-reused 340 (from 1)[K
Receiving objects: 100% (343/343), 55.68 MiB | 48.90 MiB/s, done.
Resolving deltas: 100% (156/156), done.
Repo cloned and imports successful!


In [2]:
# Pix2PixHD + EMA + FP16 (amp/GradScaler)
import os, sys, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

!pip install --quiet piq
!pip install --quiet pytorch-fid

import csv
import os
import torch
import piq   # for PSNR/SSIM
from pytorch_fid import fid_score

def calculate_metrics(G, val_loader, device, epoch):
    G.eval()
    mse_loss = torch.nn.MSELoss()
    psnr_list, ssim_list, mse_list = [], [], []
    pred_dir = '/kaggle/working/fid_pred'
    gt_dir = '/kaggle/working/fid_gt'
    os.makedirs(pred_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    with torch.no_grad():
        for i, (cat_input, target_img, _) in enumerate(val_loader):
            cat_input = cat_input.to(device)
            target_img = target_img.to(device)
            fake_img = G(cat_input)
            fake_img = fake_img.clamp(0., 1.)
            target_img = target_img.clamp(0., 1.)
            # Save FID images (first 10 only, for speed)
            for j in range(min(fake_img.size(0), 3)):
                pred = fake_img[j].clamp(0,1).cpu().numpy().transpose(1,2,0) * 255
                pred = pred.astype('uint8')
                gt = target_img[j].clamp(0,1).cpu().numpy().transpose(1,2,0) * 255
                gt = gt.astype('uint8')
                from PIL import Image
                Image.fromarray(pred).save(os.path.join(pred_dir, f"{i}_{j}.png"))
                Image.fromarray(gt).save(os.path.join(gt_dir, f"{i}_{j}.png"))
            # PSNR, SSIM for first batch only
            # psnr_val = piq.psnr(fake_img, target_img).item()
            psnr_val = piq.psnr(fake_img, target_img, data_range=1.).item()
            # ssim_val = piq.ssim(fake_img, target_img).item()
            ssim_val = piq.ssim(fake_img, target_img, data_range=1.).item()
            psnr_list.append(psnr_val)
            ssim_list.append(ssim_val)
            mse = mse_loss(fake_img, target_img).item()
            mse_list.append(mse)
            break  # Only a single (or few) batches for speed—modify if needed.

    # FID score between pred/gt directories
    fid = fid_score.calculate_fid_given_paths([gt_dir, pred_dir],
                                              batch_size=32,
                                              device=device,
                                              dims=2048)
    psnr_avg = sum(psnr_list)/len(psnr_list)
    ssim_avg = sum(ssim_list)/len(ssim_list)
    mse_avg = sum(mse_list)/len(mse_list)
    return psnr_avg, ssim_avg, fid, mse_avg

def save_metrics_csv(epoch, g_loss, d_loss, psnr, ssim, fid, mse, step, file_path="/kaggle/working/metrics.csv"):
    file_exists = os.path.isfile(file_path)
    with open(file_path, "a", newline='') as csvfile:
        fieldnames = ["epoch", "step", "g_loss", "d_loss", "psnr", "ssim", "fid", "mse"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if not file_exists:
            writer.writeheader()
        writer.writerow({
            "epoch": epoch,
            "step": step,
            "g_loss": g_loss,
            "d_loss": d_loss,
            "psnr": psnr,
            "ssim": ssim,
            "fid": fid,
            "mse": mse
        })


import torch
import os

def save_checkpoint(
    epoch, 
    G, D, ema_g, 
    optimizer_G, optimizer_D, 
    scaler_g, scaler_d, 
    checkpoint_dir='/kaggle/working/checkpoints', 
    extra_dict={}
):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint = {
        'epoch': epoch + 1,  # next epoch to continue from
        # Main weights
        'generator_state_dict': G.state_dict(),
        'discriminator_state_dict': D.state_dict(),
        # Exponential moving average weights
        'ema_generator_state_dict': ema_g.ema_model.state_dict(),
        # Optimizers
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        # For AMP/mixed precision
        'scaler_g_state_dict': scaler_g.state_dict(),
        'scaler_d_state_dict': scaler_d.state_dict(),
    }
    # checkpoint.update(extra_dict)
    torch.save(checkpoint, os.path.join(checkpoint_dir, f'pix2pixhd_checkpoint_epoch_{epoch+1}.pth'))

In [3]:
# =================== Data =====================
class LightingDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.image_pairs = self.list_image_pairs()
    def list_image_pairs(self):
        image_pairs = []
        scene_folders = [ f for f in os.listdir(self.data_path) if os.path.isdir(os.path.join(self.data_path, f))][:100] 
        for scene_folder in scene_folders:
            probe_path = os.path.join(self.data_path, scene_folder)
            if not os.path.isdir(probe_path): continue
            source_img_path = os.path.join(probe_path, "dir_0_mip2.jpg")
            if not os.path.exists(source_img_path): continue
            for lighting_idx in range(1, 25):
                target_img_path = os.path.join(probe_path, f"dir_{lighting_idx}_mip2.jpg")
                if os.path.exists(target_img_path):
                    image_pairs.append((source_img_path, target_img_path, lighting_idx))
        return image_pairs
        
    def read_image(self, filepath):
        try:
            img = Image.open(filepath).convert("RGB")
            if self.transform: img = self.transform(img)
            return img
        except Exception as e:
            print(f"Failed to open {filepath}: {e}")
            raise
        
    def one_hot(self, idx, n=25):
        onehot = torch.zeros(n, 1, 1)
        onehot[idx] = 1
        onehot = onehot.expand(n, 256, 256)
        return onehot
    def __len__(self): return len(self.image_pairs)
    def __getitem__(self, idx):
        input_img_path, target_img_path, lighting_idx = self.image_pairs[idx]
        input_img = self.read_image(input_img_path)
        target_img = self.read_image(target_img_path)
        lighting_vec = self.one_hot(lighting_idx)
        cat_input = torch.cat([input_img, lighting_vec], dim=0)
        return cat_input, target_img, lighting_idx

transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])

data_path = '/kaggle/input/multi-illumination-jpg/'
dataset = LightingDataset(data_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4, drop_last=True)

    
batch = next(iter(dataloader))
print("Batch shapes:", [x.shape if hasattr(x, 'shape') else type(x) for x in batch])

print("Total pairs found:", len(dataset))
if len(dataset) > 0:
    sample = dataset[0]
    print(type(sample), [t.shape if hasattr(t, 'shape') else None for t in sample])

# =================== Pix2PixHD code =====================
from models.networks import define_G, define_D

input_nc = 28
output_nc = 3







device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G = define_G(input_nc, output_nc, 64, 'global', 4, 9, 1, 3, 'instance', [])
print("G defined")
D = define_D(input_nc + output_nc, 64, 3, 'instance', False, 2, False, [])
print("D defined")


print("Both initializations complete")

G = nn.DataParallel(G).to(device)
D = nn.DataParallel(D).to(device)

criterionGAN = nn.MSELoss()
criterionL1 = nn.L1Loss()
lr, beta1 = 0.0002, 0.5
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))


print("EMA starts")
# =================== EMA ======================
import copy
class EMA:
    def __init__(self, model, decay=0.999):
        self.model = model
        self.ema_model = copy.deepcopy(model)
        self.decay = decay
        for p in self.ema_model.parameters():
            p.requires_grad = False
        self.ema_model.eval()
    def update(self):
        with torch.no_grad():
            for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
                ema_param.data.mul_(self.decay).add_((1. - self.decay) * param.data)
    def to(self, device):
        self.ema_model.to(device)
        return self
    def module(self):
        return self.ema_model

ema_g = EMA(G, decay=0.999).to(device)

print("EMA complete")
# =================== FP16/AMP =====================
scaler_g = torch.amp.GradScaler('cuda')
scaler_d = torch.amp.GradScaler('cuda')

# =================== Training Loop =================

print("Scalers defined")

def denorm(x):
    arr = x.detach().cpu().numpy()
    if arr.dtype != np.float32 and arr.dtype != np.float64:
        arr = arr.astype(np.float32)                # Force to float32
    arr = np.clip(arr, 0, 1)
    arr = (arr * 255).astype(np.uint8)              # Now uint8 for imshow
    if arr.shape[0] == 1:                           # handle grayscale
        arr = arr[0]
    elif arr.shape[0] == 3:                         # CHW -> HWC
        arr = np.transpose(arr, (1,2,0))
    return arr


epochs, lambda_L1 = 30, 100


Batch shapes: [torch.Size([8, 28, 256, 256]), torch.Size([8, 3, 256, 256]), torch.Size([8])]
Total pairs found: 2400
<class 'tuple'> [torch.Size([28, 256, 256]), torch.Size([3, 256, 256]), None]
GlobalGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(28, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False,

In [4]:
checkpoint = torch.load('/kaggle/input/pix2pix-hd-lighting/pytorch/epoch30/6/pix2pixhd_checkpoint_epoch_60.pth', map_location=device)
start_epoch = checkpoint['epoch']
G.load_state_dict(checkpoint['generator_state_dict'])
D.load_state_dict(checkpoint['discriminator_state_dict'])
ema_g.ema_model.load_state_dict(checkpoint['ema_generator_state_dict'])

optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

scaler_g.load_state_dict(checkpoint['scaler_g_state_dict'])
scaler_d.load_state_dict(checkpoint['scaler_d_state_dict'])

  checkpoint = torch.load('/kaggle/input/pix2pix-hd-lighting/pytorch/epoch30/6/pix2pixhd_checkpoint_epoch_60.pth', map_location=device)


In [5]:
print("Epoch starts")

for epoch in range(epochs):
    G.train(); D.train()
    for i, (cat_input, target_img, _) in enumerate(dataloader):
        cat_input = cat_input.to(device)
        target_img = target_img.to(device)
        #========== Train G (autocast/FP16) ===========
        optimizer_G.zero_grad()
        with torch.amp.autocast('cuda'):
            fake_out = G(cat_input)
            fake_pair = torch.cat([cat_input, fake_out], 1)
            pred_fake = D(fake_pair)
            loss_G_GAN = 0
            for pred in pred_fake:
                if isinstance(pred, list):
                    pred = pred[-1]
                loss_G_GAN += criterionGAN(pred, torch.ones_like(pred))
            loss_G_GAN /= len(pred_fake)
            loss_G_L1 = criterionL1(fake_out, target_img) * lambda_L1
            loss_G = loss_G_GAN + loss_G_L1
        scaler_g.scale(loss_G).backward()
        scaler_g.step(optimizer_G)
        scaler_g.update()

        ema_g.update()
        #========== Train D (autocast/FP16) ===========
        optimizer_D.zero_grad()
        with torch.amp.autocast('cuda'):
            real_pair = torch.cat([cat_input, target_img], 1)
            pred_real = D(real_pair)
            
            loss_D_real = 0
            for pr in pred_real:
                if isinstance(pr, list):
                    pr = pr[-1]
                loss_D_real += criterionGAN(pr, torch.ones_like(pr))
            loss_D_real /= len(pred_real)
            pred_fake_detach = D(fake_pair.detach())
            
            loss_D_fake = 0
            for pf in pred_fake_detach:
                if isinstance(pf, list):
                    pf = pf[-1]
                loss_D_fake += criterionGAN(pf, torch.ones_like(pf))
            loss_D_fake /= len(pred_fake_detach)
            loss_D = (loss_D_real + loss_D_fake) * 0.5
        scaler_d.scale(loss_D).backward()
        scaler_d.step(optimizer_D)
        scaler_d.update()
        if i % 100 == 0:
            print(f"Epoch [{start_epoch+epoch+1}/{start_epoch+epochs}] Batch [{i}] | LossG: {loss_G.item():.4f} | LossD: {loss_D.item():.4f}")
            # Assume dataloader is your train loader; for FID, use a separate validation or subset if possible
            psnr, ssim, fid, mse = calculate_metrics(ema_g.ema_model, dataloader, device, start_epoch+epoch+1)
            save_metrics_csv(start_epoch+epoch+1, loss_G.item(), loss_D.item(), psnr, ssim, fid, mse, i % 100)

    #========= Visualization (use EMA generator) =========
    import os

    G.eval(); ema_g.ema_model.eval()
    os.makedirs('/kaggle/working/visualizations', exist_ok=True)
    with torch.no_grad(), torch.amp.autocast('cuda'):
        val_input, val_target, _ = next(iter(dataloader))
        val_input = val_input.to(device)
        val_target = val_target.to(device)
        fake_img = ema_g.ema_model(val_input)
    
        # Take first sample in batch
        inp = denorm(val_input[0, :3])
        tgt = denorm(val_target[0])
        pred = denorm(fake_img[0])
    
        import matplotlib.pyplot as plt
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(denorm(val_input[0, :3]))
        axs[0].set_title("Input Image")
        axs[1].imshow(denorm(val_target[0]))
        axs[1].set_title("Target Lighting GT")
        axs[2].imshow(denorm(fake_img[0]))
        axs[2].set_title("Predicted Lighting (EMA G)")


        for a in axs: a.axis('off')
        output_path = f"/kaggle/working/visualizations/epoch_{start_epoch+epoch+1:02d}_viz.png"
        plt.savefig(output_path, bbox_inches='tight')
        plt.close(fig)
    
        # Optionally, also save the images individually if preferred:
        from PIL import Image
        Image.fromarray((inp * 255).astype(np.uint8)).save(f"/kaggle/working/visualizations/epoch_{start_epoch+epoch+1:02d}_input.png")
        Image.fromarray((tgt * 255).astype(np.uint8)).save(f"/kaggle/working/visualizations/epoch_{start_epoch+epoch+1:02d}_gt.png")
        Image.fromarray((pred * 255).astype(np.uint8)).save(f"/kaggle/working/visualizations/epoch_{start_epoch+epoch+1:02d}_pred.png")

# save_checkpoint(start_epoch+epochs, G, D, ema_g, optimizer_G, optimizer_D, scaler_g, scaler_d)

Epoch starts
Epoch [61/90] Batch [0] | LossG: 10.5024 | LossD: 0.0006


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 127MB/s] 




100%|██████████| 1/1 [00:00<00:00,  4.63it/s]




100%|██████████| 1/1 [00:00<00:00,  8.11it/s]


Epoch [61/90] Batch [100] | LossG: 10.7174 | LossD: 0.0018


100%|██████████| 1/1 [00:00<00:00,  9.06it/s]




100%|██████████| 1/1 [00:00<00:00,  8.87it/s]


Epoch [61/90] Batch [200] | LossG: 11.6048 | LossD: 0.0005


100%|██████████| 1/1 [00:00<00:00,  8.88it/s]




100%|██████████| 1/1 [00:00<00:00,  8.67it/s]


Epoch [62/90] Batch [0] | LossG: 8.5878 | LossD: 0.0008


100%|██████████| 1/1 [00:00<00:00,  9.04it/s]




100%|██████████| 1/1 [00:00<00:00,  8.99it/s]


Epoch [62/90] Batch [100] | LossG: 9.7755 | LossD: 0.0023


100%|██████████| 1/1 [00:00<00:00,  9.07it/s]




100%|██████████| 1/1 [00:00<00:00,  8.98it/s]


Epoch [62/90] Batch [200] | LossG: 8.1203 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  9.17it/s]




100%|██████████| 1/1 [00:00<00:00,  8.78it/s]


Epoch [63/90] Batch [0] | LossG: 10.7820 | LossD: 0.0008


100%|██████████| 1/1 [00:00<00:00,  8.93it/s]




100%|██████████| 1/1 [00:00<00:00,  8.87it/s]


Epoch [63/90] Batch [100] | LossG: 10.6848 | LossD: 0.0012


100%|██████████| 1/1 [00:00<00:00,  8.99it/s]




100%|██████████| 1/1 [00:00<00:00,  8.99it/s]


Epoch [63/90] Batch [200] | LossG: 10.6207 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  9.03it/s]




100%|██████████| 1/1 [00:00<00:00,  8.84it/s]


Epoch [64/90] Batch [0] | LossG: 14.7170 | LossD: 0.0032


100%|██████████| 1/1 [00:00<00:00,  7.53it/s]




100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


Epoch [64/90] Batch [100] | LossG: 11.4855 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  9.03it/s]




100%|██████████| 1/1 [00:00<00:00,  8.93it/s]


Epoch [64/90] Batch [200] | LossG: 7.8446 | LossD: 0.0010


100%|██████████| 1/1 [00:00<00:00,  9.05it/s]




100%|██████████| 1/1 [00:00<00:00,  8.95it/s]


Epoch [65/90] Batch [0] | LossG: 12.6195 | LossD: 0.0019


100%|██████████| 1/1 [00:00<00:00,  8.65it/s]




100%|██████████| 1/1 [00:00<00:00,  9.01it/s]


Epoch [65/90] Batch [100] | LossG: 8.3376 | LossD: 0.0015


100%|██████████| 1/1 [00:00<00:00,  8.90it/s]




100%|██████████| 1/1 [00:00<00:00,  8.61it/s]


Epoch [65/90] Batch [200] | LossG: 15.3534 | LossD: 0.0015


100%|██████████| 1/1 [00:00<00:00,  9.01it/s]




100%|██████████| 1/1 [00:00<00:00,  8.69it/s]


Epoch [66/90] Batch [0] | LossG: 12.3447 | LossD: 0.0008


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]




100%|██████████| 1/1 [00:00<00:00,  8.76it/s]


Epoch [66/90] Batch [100] | LossG: 10.9028 | LossD: 0.0020


100%|██████████| 1/1 [00:00<00:00,  8.99it/s]




100%|██████████| 1/1 [00:00<00:00,  8.84it/s]


Epoch [66/90] Batch [200] | LossG: 8.2832 | LossD: 0.0015


100%|██████████| 1/1 [00:00<00:00,  8.97it/s]




100%|██████████| 1/1 [00:00<00:00,  8.80it/s]


Epoch [67/90] Batch [0] | LossG: 7.0323 | LossD: 0.0024


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]




100%|██████████| 1/1 [00:00<00:00,  8.68it/s]


Epoch [67/90] Batch [100] | LossG: 14.3666 | LossD: 0.0011


100%|██████████| 1/1 [00:00<00:00,  8.82it/s]




100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Epoch [67/90] Batch [200] | LossG: 14.4106 | LossD: 0.0006


100%|██████████| 1/1 [00:00<00:00,  8.96it/s]




100%|██████████| 1/1 [00:00<00:00,  8.87it/s]


Epoch [68/90] Batch [0] | LossG: 13.2963 | LossD: 0.0042


100%|██████████| 1/1 [00:00<00:00,  9.05it/s]




100%|██████████| 1/1 [00:00<00:00,  8.74it/s]


Epoch [68/90] Batch [100] | LossG: 10.2468 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  8.94it/s]




100%|██████████| 1/1 [00:00<00:00,  8.77it/s]


Epoch [68/90] Batch [200] | LossG: 11.7740 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.94it/s]




100%|██████████| 1/1 [00:00<00:00,  8.64it/s]


Epoch [69/90] Batch [0] | LossG: 12.8868 | LossD: 0.0017


100%|██████████| 1/1 [00:00<00:00,  8.85it/s]




100%|██████████| 1/1 [00:00<00:00,  8.75it/s]


Epoch [69/90] Batch [100] | LossG: 9.3391 | LossD: 0.0002


100%|██████████| 1/1 [00:00<00:00,  8.88it/s]




100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Epoch [69/90] Batch [200] | LossG: 13.9560 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.94it/s]




100%|██████████| 1/1 [00:00<00:00,  8.69it/s]


Epoch [70/90] Batch [0] | LossG: 8.9516 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.79it/s]




100%|██████████| 1/1 [00:00<00:00,  8.65it/s]


Epoch [70/90] Batch [100] | LossG: 9.7954 | LossD: 0.0013


100%|██████████| 1/1 [00:00<00:00,  8.98it/s]




100%|██████████| 1/1 [00:00<00:00,  8.93it/s]


Epoch [70/90] Batch [200] | LossG: 13.0166 | LossD: 0.0037


100%|██████████| 1/1 [00:00<00:00,  8.81it/s]




100%|██████████| 1/1 [00:00<00:00,  8.56it/s]


Epoch [71/90] Batch [0] | LossG: 10.8316 | LossD: 0.0005


100%|██████████| 1/1 [00:00<00:00,  8.76it/s]




100%|██████████| 1/1 [00:00<00:00,  8.55it/s]


Epoch [71/90] Batch [100] | LossG: 13.1223 | LossD: 0.0025


100%|██████████| 1/1 [00:00<00:00,  8.89it/s]




100%|██████████| 1/1 [00:00<00:00,  8.77it/s]


Epoch [71/90] Batch [200] | LossG: 10.1105 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.90it/s]




100%|██████████| 1/1 [00:00<00:00,  8.63it/s]


Epoch [72/90] Batch [0] | LossG: 11.3466 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  8.81it/s]




100%|██████████| 1/1 [00:00<00:00,  8.79it/s]


Epoch [72/90] Batch [100] | LossG: 11.1357 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.63it/s]




100%|██████████| 1/1 [00:00<00:00,  8.36it/s]


Epoch [72/90] Batch [200] | LossG: 13.0202 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.88it/s]




100%|██████████| 1/1 [00:00<00:00,  8.68it/s]


Epoch [73/90] Batch [0] | LossG: 5.9009 | LossD: 0.0008


100%|██████████| 1/1 [00:00<00:00,  8.88it/s]




100%|██████████| 1/1 [00:00<00:00,  8.65it/s]


Epoch [73/90] Batch [100] | LossG: 13.2584 | LossD: 0.0003


100%|██████████| 1/1 [00:00<00:00,  9.06it/s]




100%|██████████| 1/1 [00:00<00:00,  8.81it/s]


Epoch [73/90] Batch [200] | LossG: 7.3984 | LossD: 0.0005


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]




100%|██████████| 1/1 [00:00<00:00,  7.66it/s]


Epoch [74/90] Batch [0] | LossG: 8.4028 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.79it/s]




100%|██████████| 1/1 [00:00<00:00,  8.47it/s]


Epoch [74/90] Batch [100] | LossG: 9.2887 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.87it/s]




100%|██████████| 1/1 [00:00<00:00,  8.82it/s]


Epoch [74/90] Batch [200] | LossG: 12.7967 | LossD: 0.0010


100%|██████████| 1/1 [00:00<00:00,  8.80it/s]




100%|██████████| 1/1 [00:00<00:00,  8.87it/s]


Epoch [75/90] Batch [0] | LossG: 10.5360 | LossD: 0.0006


100%|██████████| 1/1 [00:00<00:00,  8.64it/s]




100%|██████████| 1/1 [00:00<00:00,  8.53it/s]


Epoch [75/90] Batch [100] | LossG: 8.8216 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  8.62it/s]




100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Epoch [75/90] Batch [200] | LossG: 14.8369 | LossD: 0.0005


100%|██████████| 1/1 [00:00<00:00,  8.81it/s]




100%|██████████| 1/1 [00:00<00:00,  8.72it/s]


Epoch [76/90] Batch [0] | LossG: 12.0696 | LossD: 0.0005


100%|██████████| 1/1 [00:00<00:00,  8.84it/s]




100%|██████████| 1/1 [00:00<00:00,  8.77it/s]


Epoch [76/90] Batch [100] | LossG: 7.7638 | LossD: 0.0017


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]




100%|██████████| 1/1 [00:00<00:00,  8.78it/s]


Epoch [76/90] Batch [200] | LossG: 10.9216 | LossD: 0.0003


100%|██████████| 1/1 [00:00<00:00,  8.76it/s]




100%|██████████| 1/1 [00:00<00:00,  8.84it/s]


Epoch [77/90] Batch [0] | LossG: 13.4808 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  8.82it/s]




100%|██████████| 1/1 [00:00<00:00,  8.66it/s]


Epoch [77/90] Batch [100] | LossG: 9.9744 | LossD: 0.0014


100%|██████████| 1/1 [00:00<00:00,  8.90it/s]




100%|██████████| 1/1 [00:00<00:00,  8.48it/s]


Epoch [77/90] Batch [200] | LossG: 10.3152 | LossD: 0.0003


100%|██████████| 1/1 [00:00<00:00,  8.97it/s]




100%|██████████| 1/1 [00:00<00:00,  8.74it/s]


Epoch [78/90] Batch [0] | LossG: 8.9631 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.85it/s]




100%|██████████| 1/1 [00:00<00:00,  8.78it/s]


Epoch [78/90] Batch [100] | LossG: 10.5753 | LossD: 0.0181


100%|██████████| 1/1 [00:00<00:00,  8.55it/s]




100%|██████████| 1/1 [00:00<00:00,  8.56it/s]


Epoch [78/90] Batch [200] | LossG: 10.2856 | LossD: 0.0003


100%|██████████| 1/1 [00:00<00:00,  8.32it/s]




100%|██████████| 1/1 [00:00<00:00,  8.06it/s]


Epoch [79/90] Batch [0] | LossG: 11.8805 | LossD: 0.0005


100%|██████████| 1/1 [00:00<00:00,  8.79it/s]




100%|██████████| 1/1 [00:00<00:00,  8.77it/s]


Epoch [79/90] Batch [100] | LossG: 14.6743 | LossD: 0.0012


100%|██████████| 1/1 [00:00<00:00,  9.01it/s]




100%|██████████| 1/1 [00:00<00:00,  8.56it/s]


Epoch [79/90] Batch [200] | LossG: 10.6452 | LossD: 0.0008


100%|██████████| 1/1 [00:00<00:00,  8.95it/s]




100%|██████████| 1/1 [00:00<00:00,  8.66it/s]


Epoch [80/90] Batch [0] | LossG: 10.9828 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.84it/s]




100%|██████████| 1/1 [00:00<00:00,  8.68it/s]


Epoch [80/90] Batch [100] | LossG: 10.6492 | LossD: 0.0009


100%|██████████| 1/1 [00:00<00:00,  8.55it/s]




100%|██████████| 1/1 [00:00<00:00,  8.81it/s]


Epoch [80/90] Batch [200] | LossG: 6.9537 | LossD: 0.0011


100%|██████████| 1/1 [00:00<00:00,  8.76it/s]




100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Epoch [81/90] Batch [0] | LossG: 11.3383 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  8.41it/s]




100%|██████████| 1/1 [00:00<00:00,  8.53it/s]


Epoch [81/90] Batch [100] | LossG: 7.7830 | LossD: 0.0008


100%|██████████| 1/1 [00:00<00:00,  8.71it/s]




100%|██████████| 1/1 [00:00<00:00,  8.51it/s]


Epoch [81/90] Batch [200] | LossG: 11.7373 | LossD: 0.0010


100%|██████████| 1/1 [00:00<00:00,  8.88it/s]




100%|██████████| 1/1 [00:00<00:00,  8.81it/s]


Epoch [82/90] Batch [0] | LossG: 8.6503 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.71it/s]




100%|██████████| 1/1 [00:00<00:00,  8.50it/s]


Epoch [82/90] Batch [100] | LossG: 13.0216 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  8.75it/s]




100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Epoch [82/90] Batch [200] | LossG: 11.7111 | LossD: 0.0010


100%|██████████| 1/1 [00:00<00:00,  8.68it/s]




100%|██████████| 1/1 [00:00<00:00,  8.56it/s]


Epoch [83/90] Batch [0] | LossG: 11.0620 | LossD: 0.0016


100%|██████████| 1/1 [00:00<00:00,  8.55it/s]




100%|██████████| 1/1 [00:00<00:00,  8.45it/s]


Epoch [83/90] Batch [100] | LossG: 13.3080 | LossD: 0.0019


100%|██████████| 1/1 [00:00<00:00,  8.88it/s]




100%|██████████| 1/1 [00:00<00:00,  8.74it/s]


Epoch [83/90] Batch [200] | LossG: 7.6828 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  8.67it/s]




100%|██████████| 1/1 [00:00<00:00,  8.72it/s]


Epoch [84/90] Batch [0] | LossG: 11.2751 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.79it/s]




100%|██████████| 1/1 [00:00<00:00,  8.48it/s]


Epoch [84/90] Batch [100] | LossG: 8.2437 | LossD: 0.0010


100%|██████████| 1/1 [00:00<00:00,  8.52it/s]




100%|██████████| 1/1 [00:00<00:00,  8.73it/s]


Epoch [84/90] Batch [200] | LossG: 10.3661 | LossD: 0.0031


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]




100%|██████████| 1/1 [00:00<00:00,  8.65it/s]


Epoch [85/90] Batch [0] | LossG: 12.7440 | LossD: 0.0005


100%|██████████| 1/1 [00:00<00:00,  8.42it/s]




100%|██████████| 1/1 [00:00<00:00,  8.75it/s]


Epoch [85/90] Batch [100] | LossG: 13.3737 | LossD: 0.0049


100%|██████████| 1/1 [00:00<00:00,  8.64it/s]




100%|██████████| 1/1 [00:00<00:00,  8.75it/s]


Epoch [85/90] Batch [200] | LossG: 10.3173 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.82it/s]




100%|██████████| 1/1 [00:00<00:00,  8.70it/s]


Epoch [86/90] Batch [0] | LossG: 9.6743 | LossD: 0.0002


100%|██████████| 1/1 [00:00<00:00,  8.44it/s]




100%|██████████| 1/1 [00:00<00:00,  8.69it/s]


Epoch [86/90] Batch [100] | LossG: 11.1696 | LossD: 0.0001


100%|██████████| 1/1 [00:00<00:00,  8.60it/s]




100%|██████████| 1/1 [00:00<00:00,  7.94it/s]


Epoch [86/90] Batch [200] | LossG: 15.3175 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  8.51it/s]




100%|██████████| 1/1 [00:00<00:00,  8.51it/s]


Epoch [87/90] Batch [0] | LossG: 8.8004 | LossD: 0.0003


100%|██████████| 1/1 [00:00<00:00,  8.64it/s]




100%|██████████| 1/1 [00:00<00:00,  8.26it/s]


Epoch [87/90] Batch [100] | LossG: 9.8021 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  8.55it/s]




100%|██████████| 1/1 [00:00<00:00,  8.47it/s]


Epoch [87/90] Batch [200] | LossG: 10.7459 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  8.44it/s]




100%|██████████| 1/1 [00:00<00:00,  8.39it/s]


Epoch [88/90] Batch [0] | LossG: 12.0105 | LossD: 0.0018


100%|██████████| 1/1 [00:00<00:00,  8.66it/s]




100%|██████████| 1/1 [00:00<00:00,  8.71it/s]


Epoch [88/90] Batch [100] | LossG: 14.6024 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]




100%|██████████| 1/1 [00:00<00:00,  8.77it/s]


Epoch [88/90] Batch [200] | LossG: 10.4517 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.52it/s]




100%|██████████| 1/1 [00:00<00:00,  8.37it/s]


Epoch [89/90] Batch [0] | LossG: 12.1056 | LossD: 0.0060


100%|██████████| 1/1 [00:00<00:00,  8.58it/s]




100%|██████████| 1/1 [00:00<00:00,  8.19it/s]


Epoch [89/90] Batch [100] | LossG: 10.5975 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  8.35it/s]




100%|██████████| 1/1 [00:00<00:00,  8.43it/s]


Epoch [89/90] Batch [200] | LossG: 13.7431 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  6.99it/s]




100%|██████████| 1/1 [00:00<00:00,  8.26it/s]


Epoch [90/90] Batch [0] | LossG: 9.7675 | LossD: 0.0000


100%|██████████| 1/1 [00:00<00:00,  8.57it/s]




100%|██████████| 1/1 [00:00<00:00,  8.25it/s]


Epoch [90/90] Batch [100] | LossG: 12.6716 | LossD: 0.0007


100%|██████████| 1/1 [00:00<00:00,  8.21it/s]




100%|██████████| 1/1 [00:00<00:00,  8.21it/s]


Epoch [90/90] Batch [200] | LossG: 9.7791 | LossD: 0.0004


100%|██████████| 1/1 [00:00<00:00,  8.26it/s]




100%|██████████| 1/1 [00:00<00:00,  8.32it/s]


In [6]:

# save_checkpoint(start_epoch+epochs, G, D, ema_g, optimizer_G, optimizer_D, scaler_g, scaler_d)

In [7]:
# os.makedirs('/kaggle/working/checkpoints/', exist_ok=True)
checkpoint = {
    'epoch': start_epoch+epoch + 1,
    'generator_state_dict': G.state_dict(),
    'discriminator_state_dict': D.state_dict(),
    'ema_generator_state_dict': ema_g.ema_model.state_dict(),
    'optimizer_G_state_dict': optimizer_G.state_dict(),
    'optimizer_D_state_dict': optimizer_D.state_dict(),
    'scaler_g_state_dict': scaler_g.state_dict(),
    'scaler_d_state_dict': scaler_d.state_dict(),
}
torch.save(checkpoint, os.path.join('/kaggle/working/', f'pix2pixhd_checkpoint_epoch_{start_epoch+epoch+1}.pth'))

In [8]:
!rm file.zip

rm: cannot remove 'file.zip': No such file or directory


In [9]:
# import shutil
# shutil.make_archive('all', 'zip', '/kaggle/working/')

!zip -r file.zip /kaggle/working

  adding: kaggle/working/ (stored 0%)
  adding: kaggle/working/pix2pixhd_checkpoint_epoch_60.pth^C



zip error: Interrupted (aborting)
