In [1]:
import numpy as np
import cv2
import torch
from functions import show_tensor
from generator import define_G
from discriminator import define_D
import losses
from replay_pool import ReplayPool
from tqdm import tqdm
import os
from moving_average import moving_average

# Compute device for every model and batch in this notebook. Prefer the GPU but
# fall back to CPU: torch.device("cuda") on a machine without CUDA would only
# fail later, at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def make_generator():
    """Build one ``define_G`` generator and place it on ``device``.

    All cascade stages use the same configuration: 3 input / 3 output
    channels, 64 base filters, the "global" network with instance norm,
    3 downsampling layers and 9 residual blocks.

    Returns:
        The generator module, already moved to ``device``.
    """
    config = dict(
        input_nc=3,
        output_nc=3,
        ngf=64,
        netG="global",
        norm="instance",
        n_downsample_global=3,
        n_blocks_global=9,
        n_local_enhancers=1,
        n_blocks_local=3,
    )
    net = define_G(**config)
    return net.to(device)

def save_checkpoint(model, checkpoint_dir, epoch):
    """Save the EMA generators and the discriminators of ``model`` to disk.

    The file is written as ``epoch_{epoch}_{YYYY_MM_DD_HH_MM}.pt`` inside
    ``checkpoint_dir``. Keys "G1".."G3" hold the state dicts of the *EMA*
    generators (``G1_ema``..``G3_ema``); keys "D1".."D3" hold the
    discriminators' state dicts.

    Saving is best-effort: any ``Exception`` is reported and swallowed so a
    long training run is not aborted by a failed save. The data is first
    written to ``<name>.tmp`` and then renamed into place, so an interrupted
    or failed save never leaves a truncated file under the final name.

    Args:
        model: object exposing ``G1_ema``..``G3_ema`` and ``D1``..``D3``
            modules (anything with ``state_dict()``).
        checkpoint_dir (str): destination directory; created if missing.
        epoch (int): epoch index embedded in the file name.

    Returns:
        str | None: path of the written checkpoint, or ``None`` on failure.
    """
    import datetime

    stamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
    out_file = os.path.join(checkpoint_dir, f"epoch_{epoch}_{stamp}.pt")
    tmp_file = out_file + ".tmp"

    try:
        state = {
            "G1": model.G1_ema.state_dict(),
            "G2": model.G2_ema.state_dict(),
            "G3": model.G3_ema.state_dict(),
            "D1": model.D1.state_dict(),
            "D2": model.D2.state_dict(),
            "D3": model.D3.state_dict()
        }
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(state, tmp_file)
        # Rename only after the write completed, so the final name always
        # refers to a complete checkpoint.
        os.replace(tmp_file, out_file)
    except Exception as e:
        print(f"Error saving the file: {e}")
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        return None

    print(f"Saved to {out_file}")
    return out_file

def process_loss(log, losses):
    """Sum a dict of loss terms and accumulate their scalar values into ``log``.

    Tensor-valued terms are recorded in ``log`` via ``.item()``; plain numbers
    are recorded as-is. Keys missing from ``log`` start at 0.

    Args:
        log (dict): running totals keyed by loss name (mutated in place).
        losses (dict): loss name -> tensor or number for the current step.

    Returns:
        The sum of all values in ``losses`` (a tensor when any term is a
        tensor, otherwise a number; ``0`` for an empty dict).
    """
    total = 0
    for key, val in losses.items():
        total += val
        scalar = val.item() if isinstance(val, torch.Tensor) else val
        log[key] = log.get(key, 0) + scalar
    return total



class CascadeGAN:
    """Three-stage cascaded conditional GAN: A -> A_mask -> B_mask -> B.

    Every stage k owns:
      * a generator ``Gk`` trained with a VGG loss, an adversarial loss and a
        discriminator feature-matching loss;
      * a smoothed copy ``Gk_ema``, refreshed by ``moving_average`` after each
        generator step and used by :meth:`test`;
      * a conditional discriminator ``Dk`` scoring the pair
        ``[stage input, stage output]`` concatenated along dim 1;
      * an optimizer per network and a replay pool of past fakes that feeds
        the discriminator update.

    Stages are trained one after another: each generator consumes a
    *detached* copy of the previous stage's output, so no gradient crosses a
    stage boundary.

    All loss functions come from the ``losses`` dict given to ``__init__``
    (keys "GAN", "Feat", "VGG").
    """

    def __init__(self, device, generators, discriminators, optimizers, replay_pools, losses):
        self.device = device

        # Generators
        self.G1 = generators['G1']
        self.G2 = generators['G2']
        self.G3 = generators['G3']
        self.G1_ema = generators['G1_ema']
        self.G2_ema = generators['G2_ema']
        self.G3_ema = generators['G3_ema']

        # Discriminators
        self.D1 = discriminators['D1']
        self.D2 = discriminators['D2']
        self.D3 = discriminators['D3']

        # Optimizers
        self.G1_optim = optimizers['G1']
        self.G2_optim = optimizers['G2']
        self.G3_optim = optimizers['G3']
        self.D1_optim = optimizers['D1']
        self.D2_optim = optimizers['D2']
        self.D3_optim = optimizers['D3']

        # Replay pools
        self.replay_pool1 = replay_pools['pool1']
        self.replay_pool2 = replay_pools['pool2']
        self.replay_pool3 = replay_pools['pool3']

        # Loss functions
        self.criterionGAN = losses['GAN']
        self.criterionFeat = losses['Feat']
        self.criterionVGG = losses['VGG']

    # ------------------------------------------------------------------
    # Generator losses
    # ------------------------------------------------------------------
    def _calc_stage_G_losses(self, G, D, cond, target, prefix):
        """Compute the loss terms for one generator stage.

        Args:
            G: generator producing a fake of ``target`` from ``cond``.
            D: the stage's discriminator (output consumed by
                ``criterionGAN`` and :meth:`calc_feature_matching_loss`).
            cond: stage input; also the discriminator's condition.
            target: ground-truth stage output.
            prefix (str): log-key prefix, e.g. "G1".

        Returns:
            (dict, Tensor): the ``{prefix}_vgg``, ``{prefix}_adv`` and
            ``{prefix}_adv_feat`` terms, and the generated fake.
        """
        fake = G(cond)
        loss_vgg = self.criterionVGG(fake, target)

        pred_fake = D(torch.cat([cond, fake], dim=1))
        # The generator wants D to label its output as real (target 1).
        loss_adv = self.criterionGAN(pred_fake, 1)

        # Real-pair features only serve as matching targets: no graph needed.
        with torch.no_grad():
            pred_true = D(torch.cat([cond, target], dim=1))

        loss_adv_feat = self.calc_feature_matching_loss(pred_fake, pred_true)

        return {
            f"{prefix}_vgg": loss_vgg,
            f"{prefix}_adv": loss_adv,
            f"{prefix}_adv_feat": 10 * loss_adv_feat,  # feature-matching weight
        }, fake

    def calc_G1_losses(self, img_A, img_A_mask):
        """Calculate losses for first generator (A → A_mask)"""
        return self._calc_stage_G_losses(self.G1, self.D1, img_A, img_A_mask, "G1")

    def calc_G2_losses(self, fake_A_mask, img_B_mask):
        """Calculate losses for second generator (A_mask → B_mask)"""
        # Detach: stage 2 must not back-propagate into G1.
        return self._calc_stage_G_losses(
            self.G2, self.D2, fake_A_mask.detach(), img_B_mask, "G2"
        )

    def calc_G3_losses(self, fake_B_mask, img_B):
        """Calculate losses for third generator (B_mask → B)"""
        # Detach: stage 3 must not back-propagate into G2.
        return self._calc_stage_G_losses(
            self.G3, self.D3, fake_B_mask.detach(), img_B, "G3"
        )

    def calc_feature_matching_loss(self, pred_fake, pred_true):
        """L1 distance between discriminator features on fake and real pairs.

        ``pred_fake`` / ``pred_true`` are iterated per discriminator output
        (presumably one per scale); within each, every entry except the last
        is compared with ``criterionFeat`` (the last one is skipped —
        presumably the final prediction map). The sum is rescaled by
        ``4 / number_of_compared_features``.

        Returns:
            The weighted loss, or ``0`` when nothing was compared.
        """
        total = 0
        count = 0
        for d_fake_out, d_true_out in zip(pred_fake, pred_true):
            for l_fake, l_true in zip(d_fake_out[:-1], d_true_out[:-1]):
                total += self.criterionFeat(l_fake, l_true)
                count += 1
        return (4 / count) * total if count > 0 else 0

    # ------------------------------------------------------------------
    # Discriminator losses
    # ------------------------------------------------------------------
    def calc_D_losses(self, D, real_input, real_target, fake_input, fake_output, name):
        """Compute ``D``'s losses on a real pair (target 1) and a fake pair (target 0).

        Returns:
            dict: ``{"D{name}_true": ..., "D{name}_false": ...}``.
        """
        pred_true = D(torch.cat([real_input, real_target], dim=1))
        loss_true = self.criterionGAN(pred_true, 1)

        pred_fake = D(torch.cat([fake_input, fake_output], dim=1))
        loss_false = self.criterionGAN(pred_fake, 0)

        return {
            f"D{name}_true": loss_true,
            f"D{name}_false": loss_false
        }

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def _optimize(self, optimizer, loss_terms, log):
        """Log ``loss_terms``, back-propagate their sum and step ``optimizer``."""
        # Zeroing right before backward is equivalent to zeroing before the
        # forward pass: gradients only accumulate during backward().
        optimizer.zero_grad()
        total = process_loss(log, loss_terms)
        total.backward()
        optimizer.step()

    def _update_discriminator(self, D, optimizer, pool, cond, real, fake, name, log):
        """Train ``D`` on the real pair ``(cond, real)`` and a replay-pool fake pair."""
        sampled = pool.query({"input": cond.detach(), "output": fake.detach()})
        D_losses = self.calc_D_losses(
            D, cond, real, sampled["input"], sampled["output"], name
        )
        self._optimize(optimizer, D_losses, log)

    def train_step(self, batch):
        """Run one optimization step of all three stages on ``batch``.

        ``batch`` must provide the tensors "A", "A_mask", "B_mask" and "B".
        For each stage the generator is updated first (then its EMA copy via
        ``moving_average``), followed by the discriminator on a replay-pool
        sample of fakes.

        Returns:
            (dict, tuple): scalar loss values keyed by loss name, and the
            generated ``(fake_A_mask, fake_B_mask, fake_B)``.
        """
        img_A = batch['A'].to(self.device)
        img_A_mask = batch['A_mask'].to(self.device)
        img_B_mask = batch['B_mask'].to(self.device)
        img_B = batch['B'].to(self.device)

        log = {}

        # Stage 1: A → A_mask
        G1_losses, fake_A_mask = self.calc_G1_losses(img_A, img_A_mask)
        self._optimize(self.G1_optim, G1_losses, log)
        moving_average(self.G1, self.G1_ema)
        fake_A_mask = fake_A_mask.detach()  # break the graph from stage 1
        self._update_discriminator(
            self.D1, self.D1_optim, self.replay_pool1,
            img_A, img_A_mask, fake_A_mask, "1", log
        )

        # Stage 2: A_mask → B_mask
        G2_losses, fake_B_mask = self.calc_G2_losses(fake_A_mask, img_B_mask)
        self._optimize(self.G2_optim, G2_losses, log)
        moving_average(self.G2, self.G2_ema)
        fake_B_mask = fake_B_mask.detach()  # break the graph from stage 2
        self._update_discriminator(
            self.D2, self.D2_optim, self.replay_pool2,
            fake_A_mask, img_B_mask, fake_B_mask, "2", log
        )

        # Stage 3: B_mask → B
        G3_losses, fake_B = self.calc_G3_losses(fake_B_mask, img_B)
        self._optimize(self.G3_optim, G3_losses, log)
        moving_average(self.G3, self.G3_ema)
        self._update_discriminator(
            self.D3, self.D3_optim, self.replay_pool3,
            fake_B_mask, img_B, fake_B, "3", log
        )

        return log, (fake_A_mask, fake_B_mask, fake_B)

    # ------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------
    def test(self, test_loader, epoch, iteration, output_dir):
        """Run the EMA cascade on the first test batch and save a preview image.

        Each row of the image is one sample laid out left to right as
        A | fake A_mask | fake B_mask | fake B | real B. Tensors are mapped
        from [-1, 1] to [0, 255]; NOTE(review): this assumes the Dataset
        yields [-1, 1], RGB-ordered, NCHW tensors — confirm.
        The image is written to ``{output_dir}/{epoch}_{iteration}.jpg``.
        The EMA generators' previous train/eval modes are restored afterwards,
        even if generation fails.
        """
        ema_nets = (self.G1_ema, self.G2_ema, self.G3_ema)
        previous_modes = [net.training for net in ema_nets]
        for net in ema_nets:
            net.eval()

        try:
            with torch.no_grad():
                batch = next(iter(test_loader))
                img_A = batch['A'].to(self.device)

                fake_A_mask = self.G1_ema(img_A)
                fake_B_mask = self.G2_ema(fake_A_mask)
                fake_B = self.G3_ema(fake_B_mask)

                # Concatenate along the last (width) axis -> one row per sample.
                pairs = torch.cat([
                    img_A,
                    fake_A_mask,
                    fake_B_mask,
                    fake_B,
                    batch['B'].to(self.device)
                ], -1)

                rows = (255 * (pairs + 1) / 2).clip(0, 255)
                rows = rows.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
                matrix = np.vstack(list(rows))
        finally:
            for net, mode in zip(ema_nets, previous_modes):
                net.train(mode)

        matrix = cv2.cvtColor(matrix, cv2.COLOR_RGB2BGR)
        os.makedirs(output_dir, exist_ok=True)
        out_file = os.path.join(output_dir, f"{epoch}_{iteration}.jpg")
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(out_file, matrix):
            print(f"Error writing test image to {out_file}")

def train(model, train_loader, test_loader, num_epochs, checkpoint_dir,
          test_every=100, save_every=20):
    """Training loop.

    For each epoch: one ``model.train_step`` per batch, with the running mean
    of every logged loss shown in the tqdm bar; a preview image via
    ``model.test`` every ``test_every`` iterations and once after the last
    batch; a checkpoint every ``save_every`` epochs and after the final epoch.

    Args:
        model: object with ``train_step(batch)`` and
            ``test(loader, epoch, iteration, output_dir)``.
        train_loader: training batches (must support ``len``).
        test_loader: source of preview batches.
        num_epochs (int): number of epochs to run.
        checkpoint_dir (str): checkpoints go here, previews go to
            ``<checkpoint_dir>/images``.
        test_every (int): preview interval in iterations (positive).
        save_every (int): checkpoint interval in epochs (positive).
    """
    images_dir = os.path.join(checkpoint_dir, "images")

    for epoch in range(num_epochs):
        print(f"Training epoch {epoch}...")
        totals = {}
        num_batches = len(train_loader)

        pbar = tqdm(train_loader)
        for N, batch in enumerate(pbar, start=1):
            batch_log, _ = model.train_step(batch)
            for k, v in batch_log.items():
                totals[k] = totals.get(k, 0) + v

            # Preview periodically and exactly once after the last batch.
            # One call is enough: model.test draws the first batch from a
            # fresh test_loader iterator each time, so repeated calls would
            # re-render the same inputs.
            if N % test_every == 0 or N == num_batches:
                model.test(test_loader, epoch, N, images_dir)

            # Running means of each logged loss over this epoch
            txt = " ".join(f"{k}: {v / N:.3e}" for k, v in totals.items())
            pbar.set_description(txt)

        # Save periodically, and always after the final epoch.
        is_last_epoch = epoch + 1 == num_epochs
        if (epoch + 1) % save_every == 0 or is_last_epoch:
            save_checkpoint(model, checkpoint_dir, epoch)


# Initialize all three generators and their EMA counterparts
generator1 = make_generator()  # A → A_mask
generator2 = make_generator()  # A_mask → B_mask
generator3 = make_generator()  # B_mask → B

generator1_ema = make_generator()
generator2_ema = make_generator()
generator3_ema = make_generator()

# Start each EMA generator as an independent *copy* of its live generator.
# load_state_dict copies values (parameters and buffers). Assigning
# `ep.data = gp.data` instead would make both parameters share one storage,
# so the "EMA" would just mirror the live weights.
for g, g_ema in [
    (generator1, generator1_ema),
    (generator2, generator2_ema),
    (generator3, generator3_ema)
]:
    g_ema.load_state_dict(g.state_dict())

def _make_discriminator():
    """Build one conditional discriminator over a 6-channel (3 + 3) input pair.

    Configuration shared by all stages: 64 base filters, 3 layers, num_D=3,
    instance norm, intermediate features returned (getIntermFeat=True).
    """
    return define_D(
        input_nc=3 + 3,
        ndf=64,
        n_layers_D=3,
        num_D=3,
        norm="instance",
        getIntermFeat=True
    ).to(device)


# One discriminator per stage; each sees (stage input, stage output)
discriminator1 = _make_discriminator()  # A + A_mask
discriminator2 = _make_discriminator()  # A_mask + B_mask
discriminator3 = _make_discriminator()  # B_mask + B

# Loss functions
criterionGAN = losses.GANLoss(use_lsgan=True).to(device)
criterionFeat = torch.nn.L1Loss().to(device)
criterionVGG = losses.VGGLoss().to(device)

# One replay pool per stage (constructor argument 10)
replay_pool1, replay_pool2, replay_pool3 = (ReplayPool(10) for _ in range(3))

# AdamW, lr 1e-4, for every network
G1_optim, G2_optim, G3_optim = (
    torch.optim.AdamW(net.parameters(), lr=1e-4)
    for net in (generator1, generator2, generator3)
)
D1_optim, D2_optim, D3_optim = (
    torch.optim.AdamW(net.parameters(), lr=1e-4)
    for net in (discriminator1, discriminator2, discriminator3)
)


from dataloader import Dataset, print_dataset, plot_batch_samples

# DataLoader setup
batch_size = 8
num_workers = 4

# Root folder holding the "train" and "test" image splits
data_root = "/media/irfan/New Volume/Dentalverse/code/pix2pix/simplified_pix2pixHD/data/ortho"
train_dataset = Dataset(images_dir=f"{data_root}/train", mode="train")
test_dataset = Dataset(images_dir=f"{data_root}/test", mode="test")

loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Summaries of both splits
for split_loader, split_name in ((train_loader, "train"), (test_loader, "test")):
    print_dataset(split_loader, mode=split_name)

# Peek at one training batch
batch = next(iter(train_loader))
plot_batch_samples(batch)

# Usage
checkpoint_dir = "/media/irfan/New Volume/Dentalverse/code/pix2pix/simplified_pix2pixHD/checkpoints/ortho"
images_output_dir = os.path.join(checkpoint_dir, "images")
os.makedirs(images_output_dir, exist_ok=True)
# `device` is the module-level one defined at the top of the file. Every
# network above was already moved there, so it is deliberately not
# re-assigned here.

# Create the CascadeGAN model
model = CascadeGAN(
    device=device,
    generators={
        'G1': generator1,
        'G2': generator2,
        'G3': generator3,
        'G1_ema': generator1_ema,
        'G2_ema': generator2_ema,
        'G3_ema': generator3_ema
    },
    discriminators={
        'D1': discriminator1,
        'D2': discriminator2,
        'D3': discriminator3
    },
    optimizers={
        'G1': G1_optim,
        'G2': G2_optim,
        'G3': G3_optim,
        'D1': D1_optim,
        'D2': D2_optim,
        'D3': D3_optim
    },
    replay_pools={
        'pool1': replay_pool1,
        'pool2': replay_pool2,
        'pool3': replay_pool3
    },
    losses={
        'GAN': criterionGAN,
        'Feat': criterionFeat,
        'VGG': criterionVGG
    }
)

train(model, train_loader, test_loader, num_epochs=200, checkpoint_dir=checkpoint_dir)


GlobalGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (12): ReLU(inplace=True)
    (13): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(512, 512, kernel_size



Training epoch 0...


G1_vgg: 7.452e-01 G1_adv: 1.071e+00 G1_adv_feat: 4.430e+00 D1_true: 9.385e-01 D1_false: 9.075e-01 G2_vgg: 9.728e-01 G2_adv: 1.259e+00 G2_adv_feat: 7.831e+00 D2_true: 7.732e-01 D2_false: 6.996e-01 G3_vgg: 8.126e-01 G3_adv: 1.249e+00 G3_adv_feat: 7.644e+00 D3_true: 7.218e-01 D3_false: 6.696e-01: 100%|██████████| 533/533 [01:19<00:00,  6.68it/s]


Training epoch 1...


G1_vgg: 6.063e-01 G1_adv: 8.386e-01 G1_adv_feat: 2.676e+00 D1_true: 7.147e-01 D1_false: 7.267e-01 G2_vgg: 9.948e-01 G2_adv: 1.286e+00 G2_adv_feat: 7.671e+00 D2_true: 5.328e-01 D2_false: 4.914e-01 G3_vgg: 8.183e-01 G3_adv: 1.146e+00 G3_adv_feat: 8.110e+00 D3_true: 5.521e-01 D3_false: 5.188e-01: 100%|██████████| 533/533 [01:19<00:00,  6.67it/s]


Training epoch 2...


G1_vgg: 5.090e-01 G1_adv: 8.098e-01 G1_adv_feat: 2.307e+00 D1_true: 7.335e-01 D1_false: 7.409e-01 G2_vgg: 9.916e-01 G2_adv: 1.326e+00 G2_adv_feat: 7.771e+00 D2_true: 5.128e-01 D2_false: 4.621e-01 G3_vgg: 7.641e-01 G3_adv: 1.074e+00 G3_adv_feat: 8.222e+00 D3_true: 5.950e-01 D3_false: 5.384e-01: 100%|██████████| 533/533 [01:20<00:00,  6.66it/s]


Training epoch 3...


G1_vgg: 4.546e-01 G1_adv: 7.908e-01 G1_adv_feat: 2.174e+00 D1_true: 7.407e-01 D1_false: 7.443e-01 G2_vgg: 9.555e-01 G2_adv: 1.207e+00 G2_adv_feat: 7.787e+00 D2_true: 5.622e-01 D2_false: 5.211e-01 G3_vgg: 7.453e-01 G3_adv: 1.056e+00 G3_adv_feat: 8.350e+00 D3_true: 6.123e-01 D3_false: 5.401e-01: 100%|██████████| 533/533 [01:19<00:00,  6.68it/s]


Training epoch 4...


G1_vgg: 4.400e-01 G1_adv: 7.947e-01 G1_adv_feat: 2.246e+00 D1_true: 7.380e-01 D1_false: 7.416e-01 G2_vgg: 9.454e-01 G2_adv: 1.221e+00 G2_adv_feat: 7.972e+00 D2_true: 5.343e-01 D2_false: 4.973e-01 G3_vgg: 7.427e-01 G3_adv: 1.080e+00 G3_adv_feat: 8.621e+00 D3_true: 6.063e-01 D3_false: 5.120e-01: 100%|██████████| 533/533 [01:19<00:00,  6.74it/s]


Training epoch 5...


G1_vgg: 4.210e-01 G1_adv: 7.977e-01 G1_adv_feat: 2.221e+00 D1_true: 7.459e-01 D1_false: 7.464e-01 G2_vgg: 9.071e-01 G2_adv: 1.161e+00 G2_adv_feat: 8.157e+00 D2_true: 5.971e-01 D2_false: 5.432e-01 G3_vgg: 7.385e-01 G3_adv: 1.095e+00 G3_adv_feat: 8.606e+00 D3_true: 6.209e-01 D3_false: 5.222e-01: 100%|██████████| 533/533 [01:19<00:00,  6.72it/s]


Training epoch 6...


G1_vgg: 4.021e-01 G1_adv: 7.882e-01 G1_adv_feat: 2.133e+00 D1_true: 7.485e-01 D1_false: 7.497e-01 G2_vgg: 8.953e-01 G2_adv: 1.112e+00 G2_adv_feat: 8.201e+00 D2_true: 6.224e-01 D2_false: 5.602e-01 G3_vgg: 7.372e-01 G3_adv: 1.129e+00 G3_adv_feat: 9.011e+00 D3_true: 6.144e-01 D3_false: 4.939e-01: 100%|██████████| 533/533 [01:19<00:00,  6.71it/s]


Training epoch 7...


G1_vgg: 3.842e-01 G1_adv: 7.887e-01 G1_adv_feat: 2.044e+00 D1_true: 7.494e-01 D1_false: 7.483e-01 G2_vgg: 8.864e-01 G2_adv: 1.138e+00 G2_adv_feat: 8.338e+00 D2_true: 6.072e-01 D2_false: 5.497e-01 G3_vgg: 7.399e-01 G3_adv: 1.075e+00 G3_adv_feat: 9.289e+00 D3_true: 6.036e-01 D3_false: 4.987e-01: 100%|██████████| 533/533 [01:19<00:00,  6.69it/s]


Training epoch 8...


G1_vgg: 3.811e-01 G1_adv: 7.904e-01 G1_adv_feat: 2.057e+00 D1_true: 7.478e-01 D1_false: 7.475e-01 G2_vgg: 8.800e-01 G2_adv: 1.167e+00 G2_adv_feat: 8.408e+00 D2_true: 6.122e-01 D2_false: 5.393e-01 G3_vgg: 7.459e-01 G3_adv: 1.121e+00 G3_adv_feat: 9.109e+00 D3_true: 6.043e-01 D3_false: 4.933e-01: 100%|██████████| 533/533 [01:18<00:00,  6.78it/s]


Training epoch 9...


G1_vgg: 3.815e-01 G1_adv: 7.876e-01 G1_adv_feat: 2.071e+00 D1_true: 7.478e-01 D1_false: 7.478e-01 G2_vgg: 8.774e-01 G2_adv: 1.183e+00 G2_adv_feat: 8.554e+00 D2_true: 5.977e-01 D2_false: 5.388e-01 G3_vgg: 7.407e-01 G3_adv: 1.119e+00 G3_adv_feat: 9.478e+00 D3_true: 5.854e-01 D3_false: 4.852e-01: 100%|██████████| 533/533 [01:16<00:00,  6.98it/s]


Training epoch 10...


G1_vgg: 3.630e-01 G1_adv: 7.810e-01 G1_adv_feat: 1.978e+00 D1_true: 7.481e-01 D1_false: 7.487e-01 G2_vgg: 8.754e-01 G2_adv: 1.154e+00 G2_adv_feat: 8.641e+00 D2_true: 6.132e-01 D2_false: 5.481e-01 G3_vgg: 7.432e-01 G3_adv: 1.121e+00 G3_adv_feat: 9.882e+00 D3_true: 6.071e-01 D3_false: 4.850e-01: 100%|██████████| 533/533 [01:15<00:00,  7.07it/s]


Training epoch 11...


G1_vgg: 3.510e-01 G1_adv: 7.871e-01 G1_adv_feat: 1.940e+00 D1_true: 7.479e-01 D1_false: 7.462e-01 G2_vgg: 8.717e-01 G2_adv: 1.170e+00 G2_adv_feat: 8.653e+00 D2_true: 6.043e-01 D2_false: 5.379e-01 G3_vgg: 7.377e-01 G3_adv: 1.156e+00 G3_adv_feat: 9.839e+00 D3_true: 5.835e-01 D3_false: 4.514e-01: 100%|██████████| 533/533 [01:15<00:00,  7.07it/s]


Training epoch 12...


G1_vgg: 3.560e-01 G1_adv: 7.855e-01 G1_adv_feat: 2.025e+00 D1_true: 7.459e-01 D1_false: 7.459e-01 G2_vgg: 8.718e-01 G2_adv: 1.166e+00 G2_adv_feat: 8.746e+00 D2_true: 6.157e-01 D2_false: 5.406e-01 G3_vgg: 7.403e-01 G3_adv: 1.153e+00 G3_adv_feat: 1.007e+01 D3_true: 5.956e-01 D3_false: 4.594e-01: 100%|██████████| 533/533 [01:16<00:00,  7.00it/s]


Training epoch 13...


G1_vgg: 3.437e-01 G1_adv: 7.850e-01 G1_adv_feat: 1.982e+00 D1_true: 7.477e-01 D1_false: 7.471e-01 G2_vgg: 8.686e-01 G2_adv: 1.177e+00 G2_adv_feat: 8.775e+00 D2_true: 6.140e-01 D2_false: 5.403e-01 G3_vgg: 7.376e-01 G3_adv: 1.107e+00 G3_adv_feat: 1.006e+01 D3_true: 6.017e-01 D3_false: 4.785e-01: 100%|██████████| 533/533 [01:18<00:00,  6.78it/s]


Training epoch 14...


G1_vgg: 3.333e-01 G1_adv: 7.839e-01 G1_adv_feat: 1.976e+00 D1_true: 7.436e-01 D1_false: 7.456e-01 G2_vgg: 8.710e-01 G2_adv: 1.148e+00 G2_adv_feat: 8.802e+00 D2_true: 6.185e-01 D2_false: 5.517e-01 G3_vgg: 7.421e-01 G3_adv: 1.135e+00 G3_adv_feat: 1.013e+01 D3_true: 5.847e-01 D3_false: 4.618e-01: 100%|██████████| 533/533 [01:19<00:00,  6.69it/s]


Training epoch 15...


G1_vgg: 3.406e-01 G1_adv: 7.807e-01 G1_adv_feat: 2.065e+00 D1_true: 7.441e-01 D1_false: 7.434e-01 G2_vgg: 8.694e-01 G2_adv: 1.148e+00 G2_adv_feat: 8.824e+00 D2_true: 6.128e-01 D2_false: 5.452e-01 G3_vgg: 7.383e-01 G3_adv: 1.153e+00 G3_adv_feat: 1.052e+01 D3_true: 6.103e-01 D3_false: 4.619e-01: 100%|██████████| 533/533 [01:19<00:00,  6.69it/s]


Training epoch 16...


G1_vgg: 3.196e-01 G1_adv: 7.864e-01 G1_adv_feat: 1.948e+00 D1_true: 7.464e-01 D1_false: 7.466e-01 G2_vgg: 8.679e-01 G2_adv: 1.175e+00 G2_adv_feat: 8.919e+00 D2_true: 6.085e-01 D2_false: 5.377e-01 G3_vgg: 7.435e-01 G3_adv: 1.223e+00 G3_adv_feat: 1.088e+01 D3_true: 5.701e-01 D3_false: 4.224e-01: 100%|██████████| 533/533 [01:19<00:00,  6.69it/s]


Training epoch 17...


G1_vgg: 3.091e-01 G1_adv: 7.852e-01 G1_adv_feat: 1.904e+00 D1_true: 7.462e-01 D1_false: 7.458e-01 G2_vgg: 8.697e-01 G2_adv: 1.158e+00 G2_adv_feat: 8.912e+00 D2_true: 6.007e-01 D2_false: 5.319e-01 G3_vgg: 7.424e-01 G3_adv: 1.198e+00 G3_adv_feat: 1.081e+01 D3_true: 5.774e-01 D3_false: 4.361e-01: 100%|██████████| 533/533 [01:19<00:00,  6.75it/s]


Training epoch 18...


G1_vgg: 3.017e-01 G1_adv: 7.855e-01 G1_adv_feat: 1.884e+00 D1_true: 7.469e-01 D1_false: 7.459e-01 G2_vgg: 8.677e-01 G2_adv: 1.195e+00 G2_adv_feat: 8.968e+00 D2_true: 6.096e-01 D2_false: 5.217e-01 G3_vgg: 7.384e-01 G3_adv: 1.202e+00 G3_adv_feat: 1.081e+01 D3_true: 5.828e-01 D3_false: 4.261e-01: 100%|██████████| 533/533 [01:19<00:00,  6.74it/s]


Training epoch 19...


G1_vgg: 2.948e-01 G1_adv: 7.856e-01 G1_adv_feat: 1.876e+00 D1_true: 7.452e-01 D1_false: 7.453e-01 G2_vgg: 8.685e-01 G2_adv: 1.186e+00 G2_adv_feat: 9.030e+00 D2_true: 5.849e-01 D2_false: 5.231e-01 G3_vgg: 7.399e-01 G3_adv: 1.202e+00 G3_adv_feat: 1.105e+01 D3_true: 5.853e-01 D3_false: 4.366e-01: 100%|██████████| 533/533 [01:19<00:00,  6.71it/s]


Saved to /media/irfan/New Volume/Dentalverse/code/pix2pix/simplified_pix2pixHD/checkpoints/ortho/epoch_19_2025_01_08_14_49.pt
Training epoch 20...


G1_vgg: 2.873e-01 G1_adv: 7.834e-01 G1_adv_feat: 1.837e+00 D1_true: 7.482e-01 D1_false: 7.476e-01 G2_vgg: 8.676e-01 G2_adv: 1.210e+00 G2_adv_feat: 9.079e+00 D2_true: 5.850e-01 D2_false: 5.165e-01 G3_vgg: 7.373e-01 G3_adv: 1.214e+00 G3_adv_feat: 1.100e+01 D3_true: 5.588e-01 D3_false: 4.235e-01: 100%|██████████| 533/533 [01:19<00:00,  6.70it/s]


Training epoch 21...


G1_vgg: 2.788e-01 G1_adv: 7.854e-01 G1_adv_feat: 1.811e+00 D1_true: 7.469e-01 D1_false: 7.461e-01 G2_vgg: 8.662e-01 G2_adv: 1.227e+00 G2_adv_feat: 9.096e+00 D2_true: 5.862e-01 D2_false: 5.161e-01 G3_vgg: 7.407e-01 G3_adv: 1.239e+00 G3_adv_feat: 1.123e+01 D3_true: 5.605e-01 D3_false: 4.118e-01: 100%|██████████| 533/533 [01:20<00:00,  6.65it/s]


Training epoch 22...


G1_vgg: 2.716e-01 G1_adv: 7.834e-01 G1_adv_feat: 1.771e+00 D1_true: 7.468e-01 D1_false: 7.458e-01 G2_vgg: 8.647e-01 G2_adv: 1.189e+00 G2_adv_feat: 9.117e+00 D2_true: 6.026e-01 D2_false: 5.322e-01 G3_vgg: 7.379e-01 G3_adv: 1.198e+00 G3_adv_feat: 1.090e+01 D3_true: 5.421e-01 D3_false: 4.170e-01: 100%|██████████| 533/533 [01:19<00:00,  6.69it/s]


Training epoch 23...


G1_vgg: 2.631e-01 G1_adv: 7.791e-01 G1_adv_feat: 1.704e+00 D1_true: 7.488e-01 D1_false: 7.484e-01 G2_vgg: 8.644e-01 G2_adv: 1.215e+00 G2_adv_feat: 9.111e+00 D2_true: 5.855e-01 D2_false: 5.165e-01 G3_vgg: 7.439e-01 G3_adv: 1.256e+00 G3_adv_feat: 1.123e+01 D3_true: 5.430e-01 D3_false: 4.043e-01: 100%|██████████| 533/533 [01:20<00:00,  6.65it/s]


Training epoch 24...


G1_vgg: 2.550e-01 G1_adv: 7.839e-01 G1_adv_feat: 1.674e+00 D1_true: 7.476e-01 D1_false: 7.470e-01 G2_vgg: 8.672e-01 G2_adv: 1.230e+00 G2_adv_feat: 9.189e+00 D2_true: 5.836e-01 D2_false: 5.090e-01 G3_vgg: 7.388e-01 G3_adv: 1.210e+00 G3_adv_feat: 1.097e+01 D3_true: 5.504e-01 D3_false: 4.107e-01: 100%|██████████| 533/533 [01:19<00:00,  6.69it/s]


Training epoch 25...


G1_vgg: 2.474e-01 G1_adv: 7.803e-01 G1_adv_feat: 1.637e+00 D1_true: 7.467e-01 D1_false: 7.470e-01 G2_vgg: 8.660e-01 G2_adv: 1.258e+00 G2_adv_feat: 9.228e+00 D2_true: 5.698e-01 D2_false: 4.979e-01 G3_vgg: 7.399e-01 G3_adv: 1.261e+00 G3_adv_feat: 1.132e+01 D3_true: 5.565e-01 D3_false: 4.017e-01: 100%|██████████| 533/533 [01:19<00:00,  6.70it/s]


Training epoch 26...


G1_vgg: 2.399e-01 G1_adv: 7.854e-01 G1_adv_feat: 1.588e+00 D1_true: 7.476e-01 D1_false: 7.477e-01 G2_vgg: 8.699e-01 G2_adv: 1.237e+00 G2_adv_feat: 9.329e+00 D2_true: 5.745e-01 D2_false: 5.131e-01 G3_vgg: 7.449e-01 G3_adv: 1.220e+00 G3_adv_feat: 1.122e+01 D3_true: 5.487e-01 D3_false: 4.089e-01: 100%|██████████| 533/533 [01:19<00:00,  6.69it/s]


Training epoch 27...


G1_vgg: 2.329e-01 G1_adv: 7.839e-01 G1_adv_feat: 1.559e+00 D1_true: 7.465e-01 D1_false: 7.459e-01 G2_vgg: 8.690e-01 G2_adv: 1.286e+00 G2_adv_feat: 9.357e+00 D2_true: 5.628e-01 D2_false: 4.936e-01 G3_vgg: 7.409e-01 G3_adv: 1.262e+00 G3_adv_feat: 1.138e+01 D3_true: 5.407e-01 D3_false: 3.960e-01: 100%|██████████| 533/533 [01:19<00:00,  6.67it/s]


Training epoch 28...


G1_vgg: 2.257e-01 G1_adv: 7.839e-01 G1_adv_feat: 1.515e+00 D1_true: 7.479e-01 D1_false: 7.471e-01 G2_vgg: 8.632e-01 G2_adv: 1.261e+00 G2_adv_feat: 9.330e+00 D2_true: 5.732e-01 D2_false: 5.134e-01 G3_vgg: 7.454e-01 G3_adv: 1.256e+00 G3_adv_feat: 1.149e+01 D3_true: 5.581e-01 D3_false: 4.099e-01: 100%|██████████| 533/533 [01:21<00:00,  6.56it/s]


Training epoch 29...


G1_vgg: 2.186e-01 G1_adv: 7.831e-01 G1_adv_feat: 1.474e+00 D1_true: 7.480e-01 D1_false: 7.482e-01 G2_vgg: 8.649e-01 G2_adv: 1.311e+00 G2_adv_feat: 9.317e+00 D2_true: 5.392e-01 D2_false: 4.886e-01 G3_vgg: 7.413e-01 G3_adv: 1.246e+00 G3_adv_feat: 1.154e+01 D3_true: 5.630e-01 D3_false: 3.888e-01: 100%|██████████| 533/533 [01:18<00:00,  6.76it/s]


Training epoch 30...


G1_vgg: 2.147e-01 G1_adv: 7.867e-01 G1_adv_feat: 1.457e+00 D1_true: 7.480e-01 D1_false: 7.450e-01 G2_vgg: 8.990e-01 G2_adv: 1.348e+00 G2_adv_feat: 9.479e+00 D2_true: 5.364e-01 D2_false: 4.807e-01 G3_vgg: 7.398e-01 G3_adv: 1.223e+00 G3_adv_feat: 1.120e+01 D3_true: 5.361e-01 D3_false: 4.141e-01:  21%|██        | 112/533 [00:16<01:00,  6.92it/s]


KeyboardInterrupt: 