In [33]:
# !pip install nbimporter


In [34]:
import os
import random
import numpy as np
from collections import deque
import copy
import cv2
import nbimporter
import SRGAN_body 
import Model_A_Binary_Classifier 

from matplotlib import pyplot as plt
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_, spectral_norm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split, DataLoader, Dataset
from tqdm import tqdm


In [35]:
from SRGAN_body import weights_init, calculate_psnr, \
    Generator, Discriminator, FeatureExtractor

In [36]:
from Model_A_Binary_Classifier import set_seed, AnimalImage, get_transformation, get_test_transformation

In [37]:
class MergedDataset(Dataset):
    def __init__(self, dataset_32, dataset_128):
        self.dataset_32 = dataset_32
        self.dataset_128 = dataset_128
        assert len(dataset_32) == len(dataset_128), "Datasets must be the same size"

    def __len__(self):
        return len(self.dataset_32)

    def __getitem__(self, idx):
        data_32 = self.dataset_32[idx]
        data_128 = self.dataset_128[idx]
        return (data_128, data_32)

In [38]:
def save_images(hr_image, lr_image, fake_image, label, epoch, show=False):
    lr_image = np.transpose(lr_image.cpu().numpy(), (1, 2, 0))
    hr_image = np.transpose(hr_image.cpu().numpy(), (1, 2, 0))
    fake_image = np.transpose(fake_image.cpu().numpy(), (1, 2, 0))

    name = "Cat" if label == 1 else "Dog"
    cv2.imwrite(f"results/{name}_{epoch}_generated.png", fake_image * 255)
    cv2.imwrite(f"results/{name}_{epoch}_hr.png", hr_image * 255)
    cv2.imwrite(f"results/{name}_{epoch}_lr.png", lr_image * 255)

    if show:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        axs = axs.ravel()

        axs[0].imshow(lr_image, cmap='gray')
        axs[0].set_title('Low Resolution')

        axs[1].imshow(hr_image, cmap='gray')
        axs[1].set_title('High Resolution')

        axs[2].imshow(fake_image, cmap='gray')
        axs[2].set_title('Generated Image')

        for ax in axs:
            ax.set_xticks([])
            ax.set_yticks([])

        # fig.suptitle(definitions[label])
        plt.show()


In [39]:
def apply_spectral_norm(model):
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            torch.nn.utils.spectral_norm(module)

In [40]:
def train_SRGAN(gen, disc, feature_extractor, train_loader, val_loader,
                epochs, device="cpu"):
    gen.train()
    disc.train()
    best_psnr = 0
    epochs_since_improvement = 0

    # Define loss functions
    criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
    criterion_content = torch.nn.MSELoss().to(device)
    criterion_feature_matching = torch.nn.MSELoss().to(device)

    # Setup AdamW optimizers with different learning rates
    optimizer_G = torch.optim.AdamW(gen.parameters(), lr=1e-5, betas=(0.5, 0.999))
    optimizer_D = torch.optim.AdamW(disc.parameters(), lr=1e-6, betas=(0.5, 0.999))

    # Scheduler for the optimizers (Reduce learning rate when a metric has stopped improving)
    scheduler_G = ReduceLROnPlateau(optimizer_G, mode='min', factor=0.1, patience=10, verbose=True)
    scheduler_D = ReduceLROnPlateau(optimizer_D, mode='min', factor=0.1, patience=10, verbose=True)

    # Initialize historical discriminators
    historical_discriminators = []
    historical_disc = Discriminator().to(device)  # Create one instance to be used for all historical states

    # Training loop
    for epoch in range(epochs):
        genLossSum = 0
        disLossSum = 0

        for i, (high_data, low_data) in enumerate(tqdm(train_loader)):
            # Transfer data to device
            low_res = low_data["img"].to(device)
            high_res = high_data["img"].to(device)

            # Adversarial ground truths with label smoothing and noise
            valid = torch.FloatTensor(high_res.size(0), 1).uniform_(0.7, 1.2).to(device)
            fake = torch.FloatTensor(high_res.size(0), 1).uniform_(0.0, 0.3).to(device)
            valid = torch.clamp(valid, 0, 1)
            fake = torch.clamp(fake, 0, 1)

            # ------------------
            #  Train Generator
            # ------------------
            optimizer_G.zero_grad()

            # Generate a high resolution image from low resolution input
            gen_high_res = gen(low_res)

            # Adversarial loss
            validity = disc(gen_high_res)
            loss_GAN = criterion_GAN(validity, valid)

            # Content loss
            loss_content = criterion_content(gen_high_res, high_res)

            # Feature matching loss
            real_features = feature_extractor(high_res).detach()
            fake_features = feature_extractor(gen_high_res)
            loss_feature_matching = criterion_feature_matching(fake_features, real_features)

            # Total loss for the generator
            loss_G = loss_content + 0.001 * loss_GAN + 0.006 * loss_feature_matching
            loss_G.backward()
            clip_grad_norm_(gen.parameters(), 1)  # Clip gradients for generator
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = criterion_GAN(disc(high_res), valid)
            fake_loss = criterion_GAN(disc(gen_high_res.detach()), fake)
            loss_D = (real_loss + fake_loss) / 2

            # Historical averaging
            historical_losses = []
            for disc_state in historical_discriminators:
                historical_disc.load_state_dict(disc_state)
                historical_disc.eval()  # Set to evaluation mode for inference
                with torch.no_grad():  # No need to track gradients for historical discriminators
                    historical_losses.append(criterion_GAN(historical_disc(high_res), valid))

            if historical_losses:
                historical_loss = sum(historical_losses) / len(historical_losses)
                loss_D += 0.001 * historical_loss

            loss_D.backward()
            clip_grad_norm_(disc.parameters(), 1)  # Clip gradients for discriminator
            optimizer_D.step()

            # Save the current discriminator state
            if len(historical_discriminators) >= 10:
                historical_discriminators.pop(0)
            historical_discriminators.append(copy.deepcopy(disc.state_dict()))

            # Calculate the losses
            genLossSum += loss_G.item()
            disLossSum += loss_D.item()

        # Update learning rates
        scheduler_G.step(genLossSum)
        scheduler_D.step(disLossSum)
        
         # Clear GPU cache to avoid memory build-up
        torch.cuda.empty_cache()

        # Print the losses
        print(
            f"Epoch {epoch + 1}/{epochs} - Generator Loss: {genLossSum / (i + 1)} - Discriminator Loss: {disLossSum / (i + 1)}")

        psnr_values = []
        gen.eval()  # Set the generator to evaluation mode
        rnd = random.randint(0, batch_size - 1)
        with torch.no_grad():
            for i, (high_res, low_res) in enumerate(val_loader):
                high_res_real = high_res["img"].to(device)
                label = low_res["label"]
                low_res = low_res["img"].to(device)
                high_res_fake = gen(low_res)
                # Save some sample images
                if i == 0:  # Change this condition to save images as needed
                    save_images(high_res_real[rnd], low_res[rnd], high_res_fake[rnd], label[rnd], epoch)

                # Calculate PSNR
                psnr_value = calculate_psnr(high_res_real, high_res_fake, device)
                psnr_values.append(psnr_value)

        avg_psnr = sum(psnr_values) / len(psnr_values)
        print(f"Average PSNR on validation set for epoch {epoch + 1}: {avg_psnr} dB")

        # Check for improvement
        if avg_psnr > best_psnr:
            best_psnr = avg_psnr
            epochs_since_improvement = 0
            torch.save(gen.state_dict(), 'best_generator.pth')
            torch.save(disc.state_dict(), 'best_discriminator.pth')
        else:
            epochs_since_improvement += 1

        # Early stopping
        if epochs_since_improvement == 10:
            print("No improvement in PSNR for 10 consecutive epochs, stopping training.")
            break

        # Set the generator back to train mode
        gen.train()

In [41]:
if __name__ == "__main__":

    batch_size = 32
    seed = 10
    set_seed(seed)
    # Create dataset and split it into train, and validation
    cat_dog_dataset = AnimalImage(data_root="/home/mdnurualabsarsiddiky/Desktop/Absar/ECGR8119/Midterm_2024/dogs_vs_cats/train",
                                   transformation=get_test_transformation(size=128))
    train_data_128, val_data_128, _ = random_split(cat_dog_dataset, [0.65, 0.05, 0.3],
                                                   generator=torch.Generator().manual_seed(seed))

    # Create dataset and split it into train, and validation
    cat_dog_dataset = AnimalImage(data_root="/home/mdnurualabsarsiddiky/Desktop/Absar/ECGR8119/Midterm_2024/dogs_vs_cats/train",
                                   transformation=get_test_transformation(size=32))
    train_data_32, val_data_32, _ = random_split(cat_dog_dataset, [0.65, 0.05, 0.3],
                                                 generator=torch.Generator().manual_seed(seed))

    merged_train_data = MergedDataset(train_data_32, train_data_128)
    merged_val_data = MergedDataset(val_data_32, val_data_128)

    train_loader = DataLoader(merged_train_data, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=False,
                              num_workers=4)
    val_loader = DataLoader(merged_val_data, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=4,
                            drop_last=False)

    # Sets string definitions based on integer values
    definitions = {0: "Dog", 1: "Cat"}
    samples = 10
    epochs = 150
    lr_g = 1e-3  # Learning rate for generator
    lr_d = 1e-4  # Learning rate for discriminator
    device = "cuda:0"
    # Create the generator
    generator = Generator(num_residual_blocks=24).to(device)
    # Run a forward pass with a dummy 32 to initialize parameters
    dummy_input = torch.randn(1, 3, 32, 32).to(device)  # Adjust the size as necessary for your model
    generator(dummy_input)
    # Apply the weight initialization
    generator.apply(weights_init)
    # Create the discriminator
    discriminator = Discriminator().to(device)
    dummy_input = torch.randn(1, 3, 128, 128).to(device)  # The size should match the discriminator's expected input
    discriminator(dummy_input)
    discriminator.apply(weights_init)
    # Instantiate the feature extractor
    feature_extractor = FeatureExtractor().eval()  # Set to evaluation mode

    # Move to the device and make sure to not track gradients
    feature_extractor.to('cuda' if torch.cuda.is_available() else 'cpu')
    for parameter in feature_extractor.parameters():
        parameter.requires_grad = False

    train_SRGAN(gen=generator, disc=discriminator, feature_extractor=feature_extractor,
                train_loader=train_loader, val_loader=val_loader, epochs=epochs,
                device=device)


Seed set to 10
100%|██████████| 507/507 [03:28<00:00,  2.43it/s]

Epoch 1/150 - Generator Loss: 8.445867410072914 - Discriminator Loss: 0.6236695330994134





Average PSNR on validation set for epoch 1: 46.24361801147461 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 2/150 - Generator Loss: 3.8256716229741623 - Discriminator Loss: 0.4947392616634068





Average PSNR on validation set for epoch 2: 47.51630783081055 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 3/150 - Generator Loss: 3.2516847246967595 - Discriminator Loss: 0.4258946911587988





Average PSNR on validation set for epoch 3: 48.31959915161133 dB


100%|██████████| 507/507 [03:26<00:00,  2.45it/s]

Epoch 4/150 - Generator Loss: 2.9258799005072027 - Discriminator Loss: 0.40109030293995107





Average PSNR on validation set for epoch 4: 49.028167724609375 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 5/150 - Generator Loss: 2.724883165350092 - Discriminator Loss: 0.3887469950159626





Average PSNR on validation set for epoch 5: 49.442256927490234 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 6/150 - Generator Loss: 2.598806256373253 - Discriminator Loss: 0.38329430815030835





Average PSNR on validation set for epoch 6: 49.89835739135742 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 7/150 - Generator Loss: 2.4969367268522817 - Discriminator Loss: 0.37972468014299515





Average PSNR on validation set for epoch 7: 50.20587921142578 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 8/150 - Generator Loss: 2.390003817321281 - Discriminator Loss: 0.37545060388435275





Average PSNR on validation set for epoch 8: 50.49558639526367 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 9/150 - Generator Loss: 2.302128168487925 - Discriminator Loss: 0.37250262587028143





Average PSNR on validation set for epoch 9: 50.67613983154297 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 10/150 - Generator Loss: 2.2385949169155173 - Discriminator Loss: 0.3730031282473833





Average PSNR on validation set for epoch 10: 50.888179779052734 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 11/150 - Generator Loss: 2.173594906720419 - Discriminator Loss: 0.3718345900611765





Average PSNR on validation set for epoch 11: 51.176326751708984 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 12/150 - Generator Loss: 2.1023623146248993 - Discriminator Loss: 0.3723484237753663





Average PSNR on validation set for epoch 12: 51.33570098876953 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 13/150 - Generator Loss: 2.0498493742895785 - Discriminator Loss: 0.3679527302935749





Average PSNR on validation set for epoch 13: 51.461181640625 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 14/150 - Generator Loss: 1.9983774872691438 - Discriminator Loss: 0.36858196847537567





Average PSNR on validation set for epoch 14: 51.60194778442383 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 15/150 - Generator Loss: 1.9632683115362886 - Discriminator Loss: 0.3691092768952221





Average PSNR on validation set for epoch 15: 51.79150390625 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 16/150 - Generator Loss: 1.908023936743802 - Discriminator Loss: 0.37095085915023757





Average PSNR on validation set for epoch 16: 51.87042999267578 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 17/150 - Generator Loss: 1.8862027595264201 - Discriminator Loss: 0.36816320804919483





Average PSNR on validation set for epoch 17: 52.124263763427734 dB


100%|██████████| 507/507 [03:23<00:00,  2.49it/s]

Epoch 18/150 - Generator Loss: 1.8512784577685701 - Discriminator Loss: 0.3660199474065732





Average PSNR on validation set for epoch 18: 52.147674560546875 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 19/150 - Generator Loss: 1.8122864126923992 - Discriminator Loss: 0.36723650148399245





Average PSNR on validation set for epoch 19: 52.10749053955078 dB


100%|██████████| 507/507 [03:25<00:00,  2.46it/s]

Epoch 20/150 - Generator Loss: 1.7775229314375205 - Discriminator Loss: 0.3657993891182736





Average PSNR on validation set for epoch 20: 52.303009033203125 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 21/150 - Generator Loss: 1.7407547677527282 - Discriminator Loss: 0.3666233279883744





Average PSNR on validation set for epoch 21: 52.27093505859375 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 22/150 - Generator Loss: 1.7100287073932927 - Discriminator Loss: 0.3690716104041895





Average PSNR on validation set for epoch 22: 52.5561408996582 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 23/150 - Generator Loss: 1.6906639527992384 - Discriminator Loss: 0.3666894174891816





Average PSNR on validation set for epoch 23: 52.57078170776367 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 24/150 - Generator Loss: 1.6661596471035973 - Discriminator Loss: 0.3684050367311143





Average PSNR on validation set for epoch 24: 52.60528564453125 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 25/150 - Generator Loss: 1.625820533881291 - Discriminator Loss: 0.3681782171923733





Average PSNR on validation set for epoch 25: 52.74153518676758 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 26/150 - Generator Loss: 1.6091833566067486 - Discriminator Loss: 0.3665512176542828





Average PSNR on validation set for epoch 26: 52.77332305908203 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 27/150 - Generator Loss: 1.5812549437056396 - Discriminator Loss: 0.36623633043065346





Average PSNR on validation set for epoch 27: 52.758270263671875 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 28/150 - Generator Loss: 1.5642098537563571 - Discriminator Loss: 0.36569602890362635





Average PSNR on validation set for epoch 28: 52.95424270629883 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 29/150 - Generator Loss: 1.5333328312905818 - Discriminator Loss: 0.36819210736709235





Average PSNR on validation set for epoch 29: 52.96441650390625 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 30/150 - Generator Loss: 1.520469663881456 - Discriminator Loss: 0.36432930562623156





Average PSNR on validation set for epoch 30: 53.06980514526367 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 31/150 - Generator Loss: 1.491246321732711 - Discriminator Loss: 0.36590526401408796





Average PSNR on validation set for epoch 31: 53.15208053588867 dB


100%|██████████| 507/507 [03:24<00:00,  2.47it/s]

Epoch 32/150 - Generator Loss: 1.4690235868478432 - Discriminator Loss: 0.3669504560309754





Average PSNR on validation set for epoch 32: 53.14925003051758 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 33/150 - Generator Loss: 1.4617862399278072 - Discriminator Loss: 0.3662076373189629





Average PSNR on validation set for epoch 33: 53.19648361206055 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 34/150 - Generator Loss: 1.4316276416505818 - Discriminator Loss: 0.36526397846389336





Average PSNR on validation set for epoch 34: 53.20427322387695 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 35/150 - Generator Loss: 1.4177043268666465 - Discriminator Loss: 0.3635050057189234





Average PSNR on validation set for epoch 35: 53.27700424194336 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 36/150 - Generator Loss: 1.3909474051680555 - Discriminator Loss: 0.364446583878124





Average PSNR on validation set for epoch 36: 53.43949508666992 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 37/150 - Generator Loss: 1.3789389077022935 - Discriminator Loss: 0.3633651624064474





Average PSNR on validation set for epoch 37: 53.4709587097168 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 38/150 - Generator Loss: 1.3588339460672005 - Discriminator Loss: 0.3653279156727198





Average PSNR on validation set for epoch 38: 53.56309127807617 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 39/150 - Generator Loss: 1.345283312557717 - Discriminator Loss: 0.36413237793440884





Average PSNR on validation set for epoch 39: 53.54164505004883 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 40/150 - Generator Loss: 1.3248656999900261 - Discriminator Loss: 0.36485608318853663





Average PSNR on validation set for epoch 40: 53.5218620300293 dB


100%|██████████| 507/507 [03:24<00:00,  2.47it/s]

Epoch 41/150 - Generator Loss: 1.3122417106195785 - Discriminator Loss: 0.3646486398618837





Average PSNR on validation set for epoch 41: 53.60142135620117 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 42/150 - Generator Loss: 1.2949837416821919 - Discriminator Loss: 0.36377774017332104





Average PSNR on validation set for epoch 42: 53.69816970825195 dB


100%|██████████| 507/507 [03:25<00:00,  2.46it/s]

Epoch 43/150 - Generator Loss: 1.2836038842944233 - Discriminator Loss: 0.3648703449223874





Average PSNR on validation set for epoch 43: 53.68769073486328 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 44/150 - Generator Loss: 1.2693722343303748 - Discriminator Loss: 0.3625728947641346





Average PSNR on validation set for epoch 44: 53.876033782958984 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 45/150 - Generator Loss: 1.2573473393564394 - Discriminator Loss: 0.3654031948107469





Average PSNR on validation set for epoch 45: 53.72756576538086 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 46/150 - Generator Loss: 1.241044681335577 - Discriminator Loss: 0.3640593101051902





Average PSNR on validation set for epoch 46: 53.90177536010742 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 47/150 - Generator Loss: 1.232711358653488 - Discriminator Loss: 0.36524821274144176





Average PSNR on validation set for epoch 47: 53.88472366333008 dB


100%|██████████| 507/507 [03:24<00:00,  2.47it/s]

Epoch 48/150 - Generator Loss: 1.2147323096527385 - Discriminator Loss: 0.36442446920293325





Average PSNR on validation set for epoch 48: 53.76026153564453 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 49/150 - Generator Loss: 1.20551215861676 - Discriminator Loss: 0.36383677016818783





Average PSNR on validation set for epoch 49: 53.95154571533203 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 50/150 - Generator Loss: 1.1892767527399684 - Discriminator Loss: 0.3649316333689868





Average PSNR on validation set for epoch 50: 54.03174591064453 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 51/150 - Generator Loss: 1.1860073421598656 - Discriminator Loss: 0.3640191194103549





Average PSNR on validation set for epoch 51: 53.93659591674805 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 52/150 - Generator Loss: 1.172135128424718 - Discriminator Loss: 0.36711821036461073





Average PSNR on validation set for epoch 52: 53.9938850402832 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 53/150 - Generator Loss: 1.1632920210883462 - Discriminator Loss: 0.36608686055657425





Average PSNR on validation set for epoch 53: 53.97730255126953 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 54/150 - Generator Loss: 1.1437210091472378 - Discriminator Loss: 0.3651806195459422





Average PSNR on validation set for epoch 54: 54.10622024536133 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 55/150 - Generator Loss: 1.126633738624979 - Discriminator Loss: 0.36516621667722743





Average PSNR on validation set for epoch 55: 54.0724983215332 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 56/150 - Generator Loss: 1.1274806681939602 - Discriminator Loss: 0.36473357418820235





Average PSNR on validation set for epoch 56: 54.25751876831055 dB


100%|██████████| 507/507 [03:27<00:00,  2.45it/s]

Epoch 57/150 - Generator Loss: 1.1169888289017085 - Discriminator Loss: 0.36404432794281244





Average PSNR on validation set for epoch 57: 54.1523551940918 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 58/150 - Generator Loss: 1.108026935504033 - Discriminator Loss: 0.3629579731581009





Average PSNR on validation set for epoch 58: 54.21673583984375 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 59/150 - Generator Loss: 1.0945898085185997 - Discriminator Loss: 0.365994780080088





Average PSNR on validation set for epoch 59: 54.12691116333008 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 60/150 - Generator Loss: 1.083158016204834 - Discriminator Loss: 0.36402477861861504





Average PSNR on validation set for epoch 60: 54.18302536010742 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 61/150 - Generator Loss: 1.0778387429916412 - Discriminator Loss: 0.36376640529792453





Average PSNR on validation set for epoch 61: 54.336212158203125 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 62/150 - Generator Loss: 1.0650695373085595 - Discriminator Loss: 0.36462413712130964





Average PSNR on validation set for epoch 62: 54.35433578491211 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 63/150 - Generator Loss: 1.0541888819878857 - Discriminator Loss: 0.36364339718216737





Average PSNR on validation set for epoch 63: 54.38190460205078 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 64/150 - Generator Loss: 1.0524908875572612 - Discriminator Loss: 0.36471590399742126





Average PSNR on validation set for epoch 64: 54.32004928588867 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 65/150 - Generator Loss: 1.0378124122553793 - Discriminator Loss: 0.3648638983332429





Average PSNR on validation set for epoch 65: 54.434730529785156 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 66/150 - Generator Loss: 1.036761811616623 - Discriminator Loss: 0.3643178595946385





Average PSNR on validation set for epoch 66: 54.3920783996582 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 67/150 - Generator Loss: 1.0194703077188376 - Discriminator Loss: 0.36487122182780235





Average PSNR on validation set for epoch 67: 54.50261306762695 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 68/150 - Generator Loss: 1.011619448896931 - Discriminator Loss: 0.3661404624140474





Average PSNR on validation set for epoch 68: 54.48934555053711 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 69/150 - Generator Loss: 1.0063020022663138 - Discriminator Loss: 0.3649842829041227





Average PSNR on validation set for epoch 69: 54.395259857177734 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 70/150 - Generator Loss: 0.996083654007733 - Discriminator Loss: 0.36678352334795616





Average PSNR on validation set for epoch 70: 54.475154876708984 dB


100%|██████████| 507/507 [03:25<00:00,  2.46it/s]

Epoch 71/150 - Generator Loss: 0.9901779359849482 - Discriminator Loss: 0.3654566084492136





Average PSNR on validation set for epoch 71: 54.497650146484375 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 72/150 - Generator Loss: 0.9853447576951698 - Discriminator Loss: 0.36605897833844847





Average PSNR on validation set for epoch 72: 54.48891830444336 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 73/150 - Generator Loss: 0.9720454659217443 - Discriminator Loss: 0.3665255806620069





Average PSNR on validation set for epoch 73: 54.52702713012695 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 74/150 - Generator Loss: 0.9679385755189072 - Discriminator Loss: 0.3642153989163611





Average PSNR on validation set for epoch 74: 54.478363037109375 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 75/150 - Generator Loss: 0.9566735083299746 - Discriminator Loss: 0.3643917681197443





Average PSNR on validation set for epoch 75: 54.633888244628906 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 76/150 - Generator Loss: 0.9529841168392339 - Discriminator Loss: 0.36347359283671105





Average PSNR on validation set for epoch 76: 54.594295501708984 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 77/150 - Generator Loss: 0.9461367160846026 - Discriminator Loss: 0.3665943471402576





Average PSNR on validation set for epoch 77: 54.692237854003906 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 78/150 - Generator Loss: 0.9388041122425237 - Discriminator Loss: 0.3649735523633938





Average PSNR on validation set for epoch 78: 54.601619720458984 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 79/150 - Generator Loss: 0.9294101953271343 - Discriminator Loss: 0.36514774399866484





Average PSNR on validation set for epoch 79: 54.68299102783203 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 80/150 - Generator Loss: 0.9278866268944223 - Discriminator Loss: 0.36588517754270716





Average PSNR on validation set for epoch 80: 54.78017807006836 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 81/150 - Generator Loss: 0.9226874572755788 - Discriminator Loss: 0.3665394071173621





Average PSNR on validation set for epoch 81: 54.75513458251953 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 82/150 - Generator Loss: 0.9104920640969888 - Discriminator Loss: 0.36430150711324794





Average PSNR on validation set for epoch 82: 54.770843505859375 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 83/150 - Generator Loss: 0.9020964389251769 - Discriminator Loss: 0.36218199780471694





Average PSNR on validation set for epoch 83: 54.6517219543457 dB


100%|██████████| 507/507 [03:27<00:00,  2.45it/s]

Epoch 84/150 - Generator Loss: 0.8972653274000044 - Discriminator Loss: 0.3662517130139782





Average PSNR on validation set for epoch 84: 54.86133575439453 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 85/150 - Generator Loss: 0.8907801448946169 - Discriminator Loss: 0.36438920517880063





Average PSNR on validation set for epoch 85: 54.78298568725586 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 86/150 - Generator Loss: 0.882398362343128 - Discriminator Loss: 0.36440629736911617





Average PSNR on validation set for epoch 86: 54.7578010559082 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 87/150 - Generator Loss: 0.8754564052267658 - Discriminator Loss: 0.3656778532489987





Average PSNR on validation set for epoch 87: 54.81816482543945 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 88/150 - Generator Loss: 0.8719692384232666 - Discriminator Loss: 0.3660483106941396





Average PSNR on validation set for epoch 88: 54.8297233581543 dB


100%|██████████| 507/507 [03:25<00:00,  2.46it/s]

Epoch 89/150 - Generator Loss: 0.8633832822184591 - Discriminator Loss: 0.36442936689425737





Average PSNR on validation set for epoch 89: 54.86068344116211 dB


100%|██████████| 507/507 [03:27<00:00,  2.44it/s]

Epoch 90/150 - Generator Loss: 0.8601954388312805 - Discriminator Loss: 0.36457732043557856





Average PSNR on validation set for epoch 90: 54.92749786376953 dB


100%|██████████| 507/507 [03:27<00:00,  2.45it/s]

Epoch 91/150 - Generator Loss: 0.8529285141937362 - Discriminator Loss: 0.36492619717850017





Average PSNR on validation set for epoch 91: 54.81856155395508 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 92/150 - Generator Loss: 0.844328097102675 - Discriminator Loss: 0.36383330710305734





Average PSNR on validation set for epoch 92: 54.92055130004883 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 93/150 - Generator Loss: 0.843089644372816 - Discriminator Loss: 0.36595274881263223





Average PSNR on validation set for epoch 93: 54.969356536865234 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 94/150 - Generator Loss: 0.8397542867199674 - Discriminator Loss: 0.36616695567936114





Average PSNR on validation set for epoch 94: 54.97880935668945 dB


100%|██████████| 507/507 [03:24<00:00,  2.47it/s]

Epoch 95/150 - Generator Loss: 0.827791907848456 - Discriminator Loss: 0.36381285424533444





Average PSNR on validation set for epoch 95: 55.04548263549805 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 96/150 - Generator Loss: 0.8213386767244433 - Discriminator Loss: 0.3661485956853193





Average PSNR on validation set for epoch 96: 54.9566650390625 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 97/150 - Generator Loss: 0.8160472512245178 - Discriminator Loss: 0.36454173634508424





Average PSNR on validation set for epoch 97: 55.00181198120117 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 98/150 - Generator Loss: 0.8109593775850781 - Discriminator Loss: 0.3641351203946672





Average PSNR on validation set for epoch 98: 55.06757736206055 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 99/150 - Generator Loss: 0.8044818053113874 - Discriminator Loss: 0.3652168862683298





Average PSNR on validation set for epoch 99: 55.09757614135742 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 100/150 - Generator Loss: 0.7975642736376625 - Discriminator Loss: 0.36520826563797526





Average PSNR on validation set for epoch 100: 55.07421875 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 101/150 - Generator Loss: 0.7945449527433873 - Discriminator Loss: 0.3641912158894586





Average PSNR on validation set for epoch 101: 55.23804473876953 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 102/150 - Generator Loss: 0.7889350485754671 - Discriminator Loss: 0.36525002098412673





Average PSNR on validation set for epoch 102: 55.05287551879883 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 103/150 - Generator Loss: 0.7818902488173347 - Discriminator Loss: 0.36498548064006153





Average PSNR on validation set for epoch 103: 55.05230712890625 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 104/150 - Generator Loss: 0.779811381941011 - Discriminator Loss: 0.36498126211251025





Average PSNR on validation set for epoch 104: 55.080657958984375 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 105/150 - Generator Loss: 0.7710764617374428 - Discriminator Loss: 0.36510228640464165





Average PSNR on validation set for epoch 105: 55.2003059387207 dB


100%|██████████| 507/507 [03:26<00:00,  2.45it/s]

Epoch 106/150 - Generator Loss: 0.7716576086698904 - Discriminator Loss: 0.3651625909043487





Average PSNR on validation set for epoch 106: 55.1777229309082 dB


100%|██████████| 507/507 [03:26<00:00,  2.46it/s]

Epoch 107/150 - Generator Loss: 0.7626209430087952 - Discriminator Loss: 0.3669130742785023





Average PSNR on validation set for epoch 107: 55.1212272644043 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 108/150 - Generator Loss: 0.7605829513989962 - Discriminator Loss: 0.36366084200390697





Average PSNR on validation set for epoch 108: 55.14411544799805 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 109/150 - Generator Loss: 0.7539018968623535 - Discriminator Loss: 0.36462199676201423





Average PSNR on validation set for epoch 109: 55.1330451965332 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 110/150 - Generator Loss: 0.7463894044388916 - Discriminator Loss: 0.36641249466224535





Average PSNR on validation set for epoch 110: 55.218135833740234 dB


100%|██████████| 507/507 [03:24<00:00,  2.47it/s]

Epoch 111/150 - Generator Loss: 0.745053122617319 - Discriminator Loss: 0.36393643818663424





Average PSNR on validation set for epoch 111: 55.333919525146484 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 112/150 - Generator Loss: 0.7373065426504823 - Discriminator Loss: 0.3656440946830095





Average PSNR on validation set for epoch 112: 55.25189971923828 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 113/150 - Generator Loss: 0.732897055102053 - Discriminator Loss: 0.36467049325241374





Average PSNR on validation set for epoch 113: 55.198272705078125 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 114/150 - Generator Loss: 0.7279311510232779 - Discriminator Loss: 0.36519053208052055





Average PSNR on validation set for epoch 114: 55.26275634765625 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 115/150 - Generator Loss: 0.7247989948919774 - Discriminator Loss: 0.36640802137480216





Average PSNR on validation set for epoch 115: 55.39803695678711 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 116/150 - Generator Loss: 0.7205634485099913 - Discriminator Loss: 0.36431317833753735





Average PSNR on validation set for epoch 116: 55.23508834838867 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 117/150 - Generator Loss: 0.7177322841372481 - Discriminator Loss: 0.36370464783213313





Average PSNR on validation set for epoch 117: 55.27956771850586 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 118/150 - Generator Loss: 0.71176485972997 - Discriminator Loss: 0.36629651431031013





Average PSNR on validation set for epoch 118: 55.29615020751953 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 119/150 - Generator Loss: 0.7086840356595417 - Discriminator Loss: 0.36582176584228726





Average PSNR on validation set for epoch 119: 55.26626205444336 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 120/150 - Generator Loss: 0.7018570550800075 - Discriminator Loss: 0.36298964532639616





Average PSNR on validation set for epoch 120: 55.277523040771484 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 121/150 - Generator Loss: 0.6971062435318497 - Discriminator Loss: 0.3650661272527668





Average PSNR on validation set for epoch 121: 55.344451904296875 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 122/150 - Generator Loss: 0.6930490410069272 - Discriminator Loss: 0.36583533783165895





Average PSNR on validation set for epoch 122: 55.31608963012695 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 123/150 - Generator Loss: 0.6890897776130623 - Discriminator Loss: 0.3631160025883473





Average PSNR on validation set for epoch 123: 55.34975814819336 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 124/150 - Generator Loss: 0.6840475179857521 - Discriminator Loss: 0.36502604736143784





Average PSNR on validation set for epoch 124: 55.483577728271484 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 125/150 - Generator Loss: 0.6808296159291879 - Discriminator Loss: 0.36567979445589127





Average PSNR on validation set for epoch 125: 55.38534164428711 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 126/150 - Generator Loss: 0.6779609829833992 - Discriminator Loss: 0.3654625799058692





Average PSNR on validation set for epoch 126: 55.419979095458984 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 127/150 - Generator Loss: 0.6755002091034394 - Discriminator Loss: 0.3644945060125234





Average PSNR on validation set for epoch 127: 55.58315658569336 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 128/150 - Generator Loss: 0.668398513949129 - Discriminator Loss: 0.36581494610690507





Average PSNR on validation set for epoch 128: 55.411476135253906 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 129/150 - Generator Loss: 0.6647427254879968 - Discriminator Loss: 0.3646735441990388





Average PSNR on validation set for epoch 129: 55.50555419921875 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 130/150 - Generator Loss: 0.6599185760674863 - Discriminator Loss: 0.3645207573911377





Average PSNR on validation set for epoch 130: 55.48220443725586 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 131/150 - Generator Loss: 0.6582747934132638 - Discriminator Loss: 0.36622753715844314





Average PSNR on validation set for epoch 131: 55.531463623046875 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 132/150 - Generator Loss: 0.6542617928464503 - Discriminator Loss: 0.36540899521265274





Average PSNR on validation set for epoch 132: 55.41168212890625 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 133/150 - Generator Loss: 0.6507042234348357 - Discriminator Loss: 0.3658305618184558





Average PSNR on validation set for epoch 133: 55.639156341552734 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 134/150 - Generator Loss: 0.6459016458405076 - Discriminator Loss: 0.3659945293761334





Average PSNR on validation set for epoch 134: 55.491878509521484 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 135/150 - Generator Loss: 0.6421870702117153 - Discriminator Loss: 0.3652859322535686





Average PSNR on validation set for epoch 135: 55.55810546875 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 136/150 - Generator Loss: 0.6376330796078112 - Discriminator Loss: 0.36484551441504876





Average PSNR on validation set for epoch 136: 55.578426361083984 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 137/150 - Generator Loss: 0.6350324823423721 - Discriminator Loss: 0.3664535599112276





Average PSNR on validation set for epoch 137: 55.514495849609375 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 138/150 - Generator Loss: 0.6326039978973494 - Discriminator Loss: 0.3661435938328211





Average PSNR on validation set for epoch 138: 55.537818908691406 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 139/150 - Generator Loss: 0.6263467825258507 - Discriminator Loss: 0.36457499426732637





Average PSNR on validation set for epoch 139: 55.68385696411133 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 140/150 - Generator Loss: 0.6225134219526068 - Discriminator Loss: 0.36461316775052977





Average PSNR on validation set for epoch 140: 55.5973014831543 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 141/150 - Generator Loss: 0.6210246156658646 - Discriminator Loss: 0.3638759790910536





Average PSNR on validation set for epoch 141: 55.6185188293457 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 142/150 - Generator Loss: 0.6176115715879895 - Discriminator Loss: 0.36330656428073754





Average PSNR on validation set for epoch 142: 55.717811584472656 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 143/150 - Generator Loss: 0.6144720892581714 - Discriminator Loss: 0.3647657416511099





Average PSNR on validation set for epoch 143: 55.6130256652832 dB


100%|██████████| 507/507 [03:25<00:00,  2.46it/s]

Epoch 144/150 - Generator Loss: 0.6091331863074143 - Discriminator Loss: 0.3641623859339682





Average PSNR on validation set for epoch 144: 55.629940032958984 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 145/150 - Generator Loss: 0.6094561912251647 - Discriminator Loss: 0.3638029422397914





Average PSNR on validation set for epoch 145: 55.7081298828125 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 146/150 - Generator Loss: 0.6047518258616769 - Discriminator Loss: 0.36441113139985815





Average PSNR on validation set for epoch 146: 55.685333251953125 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 147/150 - Generator Loss: 0.6011139969148579 - Discriminator Loss: 0.36526539604339375





Average PSNR on validation set for epoch 147: 55.67371368408203 dB


100%|██████████| 507/507 [03:24<00:00,  2.48it/s]

Epoch 148/150 - Generator Loss: 0.593902774228617 - Discriminator Loss: 0.36296016389331404





Average PSNR on validation set for epoch 148: 55.66559982299805 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 149/150 - Generator Loss: 0.5936466058095297 - Discriminator Loss: 0.36690531591691916





Average PSNR on validation set for epoch 149: 55.738800048828125 dB


100%|██████████| 507/507 [03:25<00:00,  2.47it/s]

Epoch 150/150 - Generator Loss: 0.5916595290280893 - Discriminator Loss: 0.3650570534154978





Average PSNR on validation set for epoch 150: 55.7865104675293 dB


In [31]:
import os
print(os.cpu_count())  # Number of CPU cores

64


In [42]:
import torch

# Print GPU memory usage
gpu_memory = torch.cuda.memory_allocated() / (1024 ** 2)
print(f"GPU Memory Usage: {gpu_memory:.2f} MB")

# Track total GPU memory used, including cached memory
total_gpu_memory = torch.cuda.memory_reserved() / (1024 ** 2)
print(f"Total GPU Memory Reserved: {total_gpu_memory:.2f} MB")


GPU Memory Usage: 596.15 MB
Total GPU Memory Reserved: 1496.00 MB
