In [1]:
import datetime
import gc
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm
from PIL import Image
from torch.amp import autocast, GradScaler
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './')))

In [2]:
# Root directory that holds the COCO images.
COCO_ROOT_DIR = r"D:\temp_dataset\coco\images\train2017.1\train2017"
# Sub-directory with the train2017 images; left empty, so the image path resolves to the root itself.
TRAIN_IMAGES_SUBDIR = ""
# Annotation file for train2017.
TRAIN_ANNOTATIONS_FILENAME = "captions_train2017.json"

# Desired patch size as (height, width).
PATCH_SIZE = (256, 256)
# Batch size for the DataLoader.
BATCH_SIZE = 32
# Number of workers used to load data.
NUM_WORKERS = 4

# --- Build the full paths ---
_annotations_dir = r"D:\temp_dataset\coco\images\train2017.1\annotations_trainval2017\annotations"
train_images_path = os.path.join(COCO_ROOT_DIR, TRAIN_IMAGES_SUBDIR)
train_annotations_path = os.path.join(_annotations_dir, TRAIN_ANNOTATIONS_FILENAME)

In [3]:

# ------------------------ Example usage ------------------------
# Build the patch DataLoader and probe a single batch for its shape and size.

from data_loader import get_coco_patches_loader


BATCH_SIZE = 32  # Adjust based on your GPU memory
NUM_WORKERS = 0  # 0 = no multiprocessing (the value this cell has always passed to the loader)
MAX_IMAGES = 1   # Passed to the loader as max_images; 1 means a single-image run

# Fall back to the CPU so the probe below also works on machines without CUDA.
probe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader, train_dataset = get_coco_patches_loader(
    data_dir="D:\\ds_coco_patches",
    batch_size=BATCH_SIZE,
    pin_memory=probe_device.type == "cuda",  # pinned host memory only helps when copying to a GPU
    num_workers=NUM_WORKERS,
    shuffle=True,
    cache_size=0,
    max_images=MAX_IMAGES
)

# Move the first batch to the probe device and measure its size
for images in train_loader:
    images = images.to(probe_device)
    print(f"Batch size: {images.size(0)}")
    print(f"Batch shape: {images.shape}")  # Should be [B, 3, H, W]

    # Bytes per element times element count, converted to MiB
    batch_memory = images.element_size() * images.nelement() / (1024 ** 2)
    print(f"Estimated memory usage of the batch on {probe_device.type.upper()}: {batch_memory:.2f} MB")
    break


Found 1 images in D:\ds_coco_patches
Batch size: 1
Batch shape: torch.Size([1, 3, 256, 256])
Estimated memory usage of the batch on CUDA: 0.75 MB


In [4]:
# Pull one batch straight from the loader and report its shape and raw size.
sample_batch = next(iter(train_loader))
sample_batch_bytes = sample_batch.element_size() * sample_batch.nelement()
print(f"Batch shape: {sample_batch.shape}")
print(f"Estimated batch size in MB: {sample_batch_bytes / 1024**2:.2f}")


Batch shape: torch.Size([1, 3, 256, 256])
Estimated batch size in MB: 0.75


In [5]:
from cdc_trainable import CDCTrainable
from decompressor.diffusion_manager import DiffusionManager


# Use the GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the model for that device, then move its parameters and buffers there.
model = CDCTrainable(device=device)
model = model.to(device)

UNet Encoder Input-Output Pairs: [(3, 32), (32, 64), (64, 96), (96, 128), (128, 160), (160, 192), (192, 224), (224, 256)]
UNet Encoder Channels Pairs: [(6, 32), (64, 64), (128, 96), (192, 128), (256, 160), (320, 192), (192, 224), (224, 256)]
UNet Decoder Channels Pairs: [(512, 224), (448, 192), (384, 160), (320, 128), (256, 96), (192, 64), (128, 32), (64, 3)]
Context Channels: [3, 32, 64, 96, 128, 160]
UNet Encoder Input-Output Pairs: [(3, 32), (32, 64), (64, 96), (96, 128), (128, 160), (160, 192), (192, 224), (224, 256)]
UNet Encoder Channels Pairs: [(6, 32), (64, 64), (128, 96), (192, 128), (256, 160), (320, 192), (192, 224), (224, 256)]
UNet Decoder Channels Pairs: [(512, 224), (448, 192), (384, 160), (320, 128), (256, 96), (192, 64), (128, 32), (64, 3)]
Context Channels: [3, 32, 64, 96, 128, 160]


In [6]:
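# NOTE: disabled reference code. It assembles the model by hand from Compressor,
# UnetModule and DiffusionManager; the notebook currently builds it via CDCTrainable.
# from compressor.compressor import Compressor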
# from decompressor.diffusion_manager import DiffusionManager
# from decompressor.unet_module import UnetModule


# # Initialize the compressor
# compressor = Compressor(
#     channel_multiplier=[1, 3, 3, 12, 52, 64],
#     hyperprior_channel_multiplier=[64,64,64]
# )

# # Generate a dummy tensor to simulate input data
# dummy_tensor = torch.randn(64, 3, 256, 256).cuda()

# # Pass the dummy tensor through the compressor
# output_dict = compressor(dummy_tensor)

# # Extract the shapes of the output tensors
# output_shapes = [output.shape[1] for output in output_dict['output']]

# # Initialize the UNet module with the extracted channel dimensions
# unet_module = UnetModule(
#     base_channels=3,
#     context_channels=output_shapes
# )

# del dummy_tensor

# model = DiffusionManager(
#     encoder=compressor,
#     u_net=unet_module,
# )

In [7]:
def get_model_size_in_mb(model):
    """Return the combined size of a module's parameters and buffers in MiB.

    Each tensor contributes ``element_size() * nelement()`` bytes; the byte
    total is divided by 1024**2.
    """
    byte_total = 0
    for tensor_group in (model.parameters(), model.buffers()):
        for tensor in tensor_group:
            byte_total += tensor.element_size() * tensor.nelement()
    return byte_total / (1024 ** 2)  # Convert bytes to MB

# Report the parameter + buffer footprint of the model built earlier in the notebook.
model_size_mb = get_model_size_in_mb(model)
print(f"Model size: {model_size_mb:.2f} MB")

Model size: 187.36 MB


In [8]:
# Collect unreachable Python objects first, so any tensor memory they held is
# returned to PyTorch's allocator before empty_cache() releases unused cached blocks.
gc.collect()
torch.cuda.empty_cache()

0

In [9]:
def draw_image(input, save_dir="visualizations"):
    """Save the first image of a batch as a PNG file inside ``save_dir``.

    Args:
        input: Indexable batch of image tensors; only ``input[0]`` is used.
            It is transposed with axes (1, 2, 0), i.e. treated as
            channel-first, and scaled by 255, i.e. treated as lying in [0, 1].
        save_dir: Output directory; created when it does not exist.

    The file name embeds the current time in milliseconds.
    """
    os.makedirs(save_dir, exist_ok=True)
    # .float() lets reduced-precision tensors (e.g. produced under autocast) convert to NumPy.
    image = input[0].detach().float().cpu().numpy()
    image = np.transpose(image, (1, 2, 0))
    # Clamp before the uint8 cast; values outside [0, 1] would otherwise wrap around.
    image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    # Generate a unique filename using timestamp
    filename = f"image_{int(time.time() * 1000)}.png"
    save_path = os.path.join(save_dir, filename)
    image = Image.fromarray(image)
    try:
        image.save(save_path)
    finally:
        # Release the image even when saving fails.
        image.close()

In [10]:
# Training configuration: debug switch, schedule lengths, optimizer, LR scheduler
# and the checkpoint directory for this run.

# Anomaly detection adds extra checks to every backward pass and slows training;
# keep it enabled only while debugging.
DETECT_ANOMALY = True
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

NUM_WORKERS = 0  # the former conditional evaluated to 0 on both branches

num_epochs = 100000
log_interval = 100    # batches between running-average prints
save_interval = 1000  # epochs between checkpoints

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-4,  # Typical for diffusion models
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)

# Cosine annealing scheduler is common for diffusion models
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
    eta_min=1e-7
)

# One timestamped checkpoint directory per execution of this cell.
save_dir = os.path.join("checkpoints",
                        f"cdc_training_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
os.makedirs(save_dir, exist_ok=True)

# Average loss of every completed epoch, appended by the training loop.
train_losses = []

In [11]:
# save_dir = r"E:\DUT Courses\Academic year 4\semester 2\PBL\PBL7 CDC Compression\checkpoints\cdc_training_20250605_133848"

# # Load checkpoint 
# checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{20000}.pt")
# checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(checkpoint['model_state_dict'])
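# epochs = checkpoint['epoch']  # NOTE: the training loop iterates range(num_epochs) and never reads `epochs`, so this alone does not resume the epoch count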
# # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# model.train()  # Set model to training mode
# print(f"Loaded checkpoint from {checkpoint_path}")


In [None]:
# Main training loop: one optimizer step per batch, one scheduler step per epoch,
# a checkpoint every `save_interval` epochs and visualizations of the first batch.
# NOTE(review): the forward pass runs under autocast but GradScaler (imported, never
# used) is not applied. Confirm gradient underflow is not a concern for this model.
for epoch in range(num_epochs):
    train_dataset.clear_cache()
    epoch_losses = []
    model.train()

    progress_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    image_for_viz = None
    predicted_x0 = None
    test_image = None

    for batch_idx, images in progress_bar:
        images = images.to(device)
        test_image = images[0].unsqueeze(0)  # For visualization purposes

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type='cuda', enabled=torch.cuda.is_available()):
            loss, x0_img, image_noise,  mse_loss, reconstruction_loss, bpp_loss  = model(images)

        # Keep the first batch's outputs for the end-of-epoch visualization.
        if batch_idx == 0:
            predicted_x0 = x0_img
            image_for_viz = image_noise[0].unsqueeze(0)

        if torch.isnan(loss).any():
            print(f"NaN in loss at batch {batch_idx}, skipping...")
            continue

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # clip_grad_norm_ returns the total gradient norm, which is non-finite as soon
        # as any gradient is NaN/inf, so one scalar check replaces a scan of every parameter.
        if not torch.isfinite(grad_norm):
            print(f"NaN in gradients at batch {batch_idx}, skipping...")
            continue

        optimizer.step()

        epoch_losses.append(loss.item())

        progress_bar.set_postfix({'loss': f"{epoch_losses[-1]:.4f}", 'mse_loss': f"{mse_loss.item():.4f}", 'reconstruction_loss': f"{reconstruction_loss.item():.4f}", 'bpp_loss': f"{bpp_loss.item():.4f}"})

        if batch_idx % log_interval == 0 and batch_idx > 0:
            # Skipped batches leave fewer than log_interval entries, so divide by the real count.
            recent_losses = epoch_losses[-log_interval:]
            avg_recent_loss = sum(recent_losses) / len(recent_losses)
            print(f"Batch {batch_idx}/{len(train_loader)} - Avg Loss: {avg_recent_loss:.4f}")

        del loss

    if epoch_losses:
        avg_loss = sum(epoch_losses) / len(epoch_losses)
    else:
        # Every batch was skipped or the loader was empty; avoid a ZeroDivisionError.
        avg_loss = float("nan")
        print(f"Epoch {epoch+1}: no batch produced a usable loss")
    train_losses.append(avg_loss)

    print(f"\nEpoch {epoch+1} completed - Avg Loss: {avg_loss:.4f}")

    if (epoch + 1) % save_interval == 0:
        checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

    # NOTE(review): this writes two PNG files every epoch; gate it by epoch if that is too many.
    # if (epoch) >= 10000 and predicted_x0 is not None and (epoch + 1) % 500 == 0:
    if predicted_x0 is not None:
        draw_image(predicted_x0)
        draw_image(image_for_viz)

    if (epoch) >= 10000 and (epoch + 1) % 500 == 0 and test_image is not None:
        model.eval()
        # draw_image(test_image, save_dir="visualizations_test")
        # Sampling needs no gradients; no_grad avoids building an autograd graph over the DDIM steps.
        with torch.no_grad():
            model.ddim_forward(test_image)
        model.train()

        # Collect Python garbage first so the freed tensor memory can then be released from the cache.
        gc.collect()
        torch.cuda.empty_cache()
        # NOTE(review): this break ends training right after the first DDIM check, so the
        # run stops long before num_epochs. Confirm that this is intended.
        break

    scheduler.step()


Epoch 1/100000: 100%|██████████| 1/1 [00:02<00:00,  2.71s/it, loss=2.3447, mse_loss=1.5789, reconstruction_loss=4.1313, bpp_loss=0.0001]



Epoch 1 completed - Avg Loss: 2.3447


Epoch 2/100000: 100%|██████████| 1/1 [00:01<00:00,  1.54s/it, loss=1.0428, mse_loss=1.4124, reconstruction_loss=0.1769, bpp_loss=0.0010]



Epoch 2 completed - Avg Loss: 1.0428


Epoch 3/100000: 100%|██████████| 1/1 [00:01<00:00,  1.49s/it, loss=0.9642, mse_loss=1.3212, reconstruction_loss=0.1309, bpp_loss=0.0001]



Epoch 3 completed - Avg Loss: 0.9642


Epoch 4/100000: 100%|██████████| 1/1 [00:01<00:00,  1.57s/it, loss=0.8970, mse_loss=1.2587, reconstruction_loss=0.0523, bpp_loss=0.0002]



Epoch 4 completed - Avg Loss: 0.8970


Epoch 5/100000: 100%|██████████| 1/1 [00:01<00:00,  1.62s/it, loss=1.7253, mse_loss=1.1906, reconstruction_loss=2.9723, bpp_loss=0.0002]



Epoch 5 completed - Avg Loss: 1.7253
DDIM Step: 50/50
DDIM Step: 49/50
DDIM Step: 48/50
DDIM Step: 47/50
DDIM Step: 46/50
DDIM Step: 45/50
DDIM Step: 44/50
DDIM Step: 43/50
DDIM Step: 42/50
DDIM Step: 41/50
DDIM Step: 40/50
DDIM Step: 39/50
DDIM Step: 38/50
DDIM Step: 37/50
DDIM Step: 36/50
DDIM Step: 35/50
DDIM Step: 34/50
DDIM Step: 33/50
DDIM Step: 32/50
DDIM Step: 31/50
DDIM Step: 30/50
DDIM Step: 29/50
DDIM Step: 28/50
DDIM Step: 27/50
DDIM Step: 26/50
DDIM Step: 25/50
DDIM Step: 24/50
DDIM Step: 23/50
DDIM Step: 22/50
DDIM Step: 21/50
DDIM Step: 20/50
DDIM Step: 19/50
DDIM Step: 18/50
DDIM Step: 17/50
DDIM Step: 16/50
DDIM Step: 15/50
DDIM Step: 14/50
DDIM Step: 13/50
DDIM Step: 12/50
DDIM Step: 11/50
DDIM Step: 10/50
DDIM Step: 9/50
DDIM Step: 8/50
DDIM Step: 7/50
DDIM Step: 6/50
DDIM Step: 5/50
DDIM Step: 4/50
DDIM Step: 3/50
DDIM Step: 2/50
DDIM Step: 1/50


In [13]:
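# NOTE: disabled scratch cell. It makes a one-off call to model.forward_fake on a
# single batch (the loop breaks on its first pass).
# model.eval()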

# images = next(iter(train_loader))

# for epoch in range(0, 50):
#     train_dataset.clear_cache()
#     epoch_losses = []

#     progress_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
#     image_for_viz = None
#     predicted_x0 = None
#     test_image = None


#     images = images.to(device)
#     test_image = images[0].unsqueeze(0)  # For visualization purposes

#     optimizer.zero_grad(set_to_none=True)

#     with autocast(device_type='cuda', enabled=torch.cuda.is_available()):
#         model.forward_fake(images, denoise_steps=50)
#         break

#     # if True:
#     #     predicted_x0 = x0_img
#     #     image_for_viz = image_noise[0].unsqueeze(0)

#     # # if (epoch) >= 10000 and predicted_x0 is not None and (epoch + 1) % 500 == 0:
#     # draw_image(predicted_x0)
#     # draw_image(image_for_viz)

#     # images = predicted_x0

#     # if (epoch) >= 1000 and (epoch + 1) % 100 == 0:
#     #     model.eval()
#     #     # draw_image(test_image, save_dir="visualizations_test")
#     #     model.ddim_forward(test_image, denoise_steps=50)
#     #     model.train()

#     #     torch.cuda.empty_cache()
#     #     gc.collect()
#     #     # break

#     scheduler.step()


In [14]:
# Save model as final (for evaluation/inference)
model.eval()
final_checkpoint_path = os.path.join(save_dir, "final_model.pt")
# The training loop can exit early through its `break`, so store the number of
# epochs that actually completed (one train_losses entry per epoch) rather than
# the configured num_epochs.
torch.save({
    'epoch': len(train_losses),
    'model_state_dict': model.state_dict(),
}, final_checkpoint_path)
print(f"Final model saved for evaluation at {final_checkpoint_path}")

Final model saved for evaluation at checkpoints\cdc_training_20250607_111541\final_model.pt
