In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import random
import numpy as np
import time
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './')))

In [2]:
# --- Raw COCO train2017 locations ---
COCO_ROOT_DIR = r"D:\temp_dataset\coco\images\train2017.1\train2017"  # Root directory of the COCO images
TRAIN_IMAGES_SUBDIR = ""  # Sub-directory of COCO_ROOT_DIR holding the train2017 images ("" = the root itself)
TRAIN_ANNOTATIONS_FILENAME = "captions_train2017.json"  # Caption annotation file for train2017
COCO_ANNOTATIONS_DIR = r"D:\temp_dataset\coco\images\train2017.1\annotations_trainval2017\annotations"

# --- Data-loading options ---
# NOTE: the example-usage cell below redefines BATCH_SIZE and NUM_WORKERS for its own run.
PATCH_SIZE = (256, 256)  # Desired patch size as (height, width)
BATCH_SIZE = 64  # DataLoader batch size
NUM_WORKERS = 4  # Number of DataLoader worker processes

# --- Full paths derived from the settings above ---
train_annotations_path = os.path.join(COCO_ANNOTATIONS_DIR, TRAIN_ANNOTATIONS_FILENAME)
train_images_path = os.path.join(COCO_ROOT_DIR, TRAIN_IMAGES_SUBDIR)

In [3]:

# ------------------------ Example usage ------------------------
# Build the pre-extracted patch DataLoader and inspect one batch on the target device.

from data_loader import get_coco_patches_loader


BATCH_SIZE = 1  # Adjust based on your GPU memory
NUM_WORKERS = 4  # Worker processes used when CUDA is available; 0 = no multiprocessing
MAX_IMAGES = 1  # Limit the dataset to the first MAX_IMAGES images

# Fall back to the CPU (and to single-process loading) when no GPU is present,
# instead of crashing on a hard-coded .cuda() call.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

train_loader, train_dataset = get_coco_patches_loader(
    data_dir="D:\\ds_coco_patches",
    batch_size=BATCH_SIZE,
    pin_memory=use_cuda,  # pinned host memory only helps host->GPU copies
    num_workers=NUM_WORKERS if use_cuda else 0,
    shuffle=True,
    cache_size=0,
    max_images=MAX_IMAGES
)

# Move the first batch to the device and measure its size
for images in train_loader:
    images = images.to(device)
    print(f"Batch size: {images.size(0)}")
    print(f"Batch shape: {images.shape}")  # Expected [B, 3, H, W]

    # Memory held by this batch tensor alone (bytes per element * element count), in MB
    batch_memory = images.element_size() * images.nelement() / (1024 ** 2)
    print(f"Estimated memory usage of the batch on {images.device.type.upper()}: {batch_memory:.2f} MB")
    break


Found 1 images in D:\ds_coco_patches
Batch size: 1
Batch shape: torch.Size([1, 3, 256, 256])
Estimated memory usage of the batch on CUDA: 0.75 MB


In [4]:
# Pull a single batch straight from the loader and report its shape and raw tensor size.
sample_batch = next(iter(train_loader))
sample_batch_mb = sample_batch.element_size() * sample_batch.nelement() / 1024**2
print(f"Batch shape: {sample_batch.shape}")
print(f"Estimated batch size in MB: {sample_batch_mb:.2f}")


Batch shape: torch.Size([1, 3, 256, 256])
Estimated batch size in MB: 0.75


In [5]:
from compressor.compressor import Compressor
from decompressor.diffusion_manager import DiffusionManager
from decompressor.unet_module import UnetModule


# Initialize the compressor (the encoder that produces the context tensors for the UNet)
compressor = Compressor(
    channel_multiplier=[1, 3, 3, 12, 52, 64],
    hyperprior_channel_multiplier=[64, 64, 64]
)

# Probe the compressor once with random data to discover the channel count of each
# tensor in output_dict['output']; the UNet is then built to match those counts.
# NOTE(review): the probe runs in the compressor's current mode; if it contains
# stateful layers (e.g. BatchNorm running stats) the random input may update them.
PROBE_BATCH_SIZE = 64
probe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dummy_tensor = torch.randn(PROBE_BATCH_SIZE, 3, 256, 256, device=probe_device)

# Only shapes are needed, so disable autograd: otherwise output_dict keeps the whole
# forward graph (every intermediate activation for the probe batch) alive in memory.
with torch.no_grad():
    output_dict = compressor(dummy_tensor)

# Channel dimension (dim 1) of each context tensor the compressor emits
output_shapes = [output.shape[1] for output in output_dict['output']]

# The probe input and outputs are throwaway; release them before building the rest
del dummy_tensor, output_dict

# Initialize the UNet module with the extracted channel dimensions
unet_module = UnetModule(
    base_channels=3,
    context_channels=output_shapes
)

model = DiffusionManager(
    encoder=compressor,
    u_net=unet_module,
)

Initializing HyperPrior on device: cuda
DiffusionManager initialized on device: cuda




In [6]:
def get_model_size_in_mb(model):
    """Return the memory taken by a module's parameters and buffers, in MiB.

    Each tensor contributes ``element_size() * nelement()`` bytes; the tensors
    counted are exactly those yielded by ``model.parameters()`` followed by
    ``model.buffers()``. The byte total is divided by 1024**2.
    """
    byte_count = 0
    for tensor_source in (model.parameters(), model.buffers()):
        for tensor in tensor_source:
            byte_count += tensor.element_size() * tensor.nelement()
    return byte_count / (1024 ** 2)

# Report the footprint of the full diffusion model (compressor + UNet).
model_size_mb = get_model_size_in_mb(model)
print("Model size: {:.2f} MB".format(model_size_mb))

Model size: 41.35 MB


In [7]:
import gc

# Collect unreachable Python objects FIRST so their tensors are returned to PyTorch's
# caching allocator; only afterwards can empty_cache() release that memory to the driver.
# (Calling empty_cache() before gc.collect() leaves the just-freed blocks cached.)
collected = gc.collect()
torch.cuda.empty_cache()
collected  # number of unreachable objects found, displayed as the cell output

17

In [8]:
def draw_image(input, save_dir="visualizations"):
    """Save one image from ``input`` as a PNG in ``save_dir`` and print its path.

    Args:
        input: Either a batched image tensor of shape (B, C, H, W), of which only
            ``input[0]`` is saved, or a single (C, H, W) image. Pixel values are
            expected in [0, 1] (NOTE(review): confirm this matches the dataset /
            model output range); values outside it are clipped, not wrapped.
        save_dir: Output directory; created if it does not already exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    image_chw = input[0] if input.dim() == 4 else input
    # Detach from autograd and move to host memory, then clamp: the float -> 8-bit
    # conversion is mul(255) followed by a byte cast, so values outside [0, 1]
    # (e.g. Gaussian noise) would otherwise wrap around modulo 256.
    image_chw = image_chw.detach().cpu().clamp(0.0, 1.0)
    # `transforms` is the torchvision.transforms module imported at the top of the
    # notebook; ToPILImage turns a (C, H, W) tensor into a PIL image.
    pil_image = transforms.ToPILImage()(image_chw)

    # Millisecond timestamps can repeat for back-to-back calls (the system clock may
    # also be coarser than 1 ms), so append a counter rather than overwrite a file.
    stamp = int(time.time() * 1000)
    save_path = os.path.join(save_dir, f"image_{stamp}.png")
    suffix = 1
    while os.path.exists(save_path):
        save_path = os.path.join(save_dir, f"image_{stamp}_{suffix}.png")
        suffix += 1

    try:
        pil_image.save(save_path)
    finally:
        pil_image.close()
    print(f"Image saved to {save_path}")

In [9]:
# torch.autograd.set_detect_anomaly(True)
# from torch.utils.data import DataLoader
# import tqdm
# import datetime
# import os
# from PIL import Image

# import pynvml
# pynvml.nvmlInit()
# handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())

# def print_gpu_mem(tag=""):
#     info = pynvml.nvmlDeviceGetMemoryInfo(handle)
#     allocated = torch.cuda.memory_allocated() / 1024**2
#     reserved = torch.cuda.memory_reserved() / 1024**2
#     print(f"[{tag}] GPU Mem Used (NVML): {info.used / 1024**2:.2f} MB | Allocated (PyTorch): {allocated:.2f} MB | Reserved (PyTorch): {reserved:.2f} MB")

# NUM_WORKERS = 0 if torch.cuda.is_available() else 0

# # Training parameters
# num_epochs = 20000
# log_interval = 50  # Log every 50 batches
# save_interval = 100   # Save checkpoint every epoch

# # Create optimizer for both compressor and UNet - REMOVED mixed precision components
# optimizer = torch.optim.AdamW(
#     [
#         {'params': model.encoder.parameters()},
#         {'params': model.u_net.parameters()}
#     ],
#     lr=2e-3,  # Typical for diffusion models
#     weight_decay=1e-3,
#     betas=(0.9, 0.999)
# )

# # Cosine annealing scheduler is common for diffusion models
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer,
#     T_max=num_epochs,
#     eta_min=1e-5
# )

# # Create directory for saving checkpoints
# save_dir = os.path.join("checkpoints", 
#                         f"cdc_training_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
# os.makedirs(save_dir, exist_ok=True)

# # Training statistics
# train_losses = []
# prior_losses = []

# # Move models to CUDA
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)

# print(f"Starting training on {device}")
# print(f"Training for {num_epochs} epochs with batch size {BATCH_SIZE}")

# # Training loop
# for epoch in range(num_epochs):
#     train_dataset.clear_cache()  # Clear cache at the start of each epoch
#     epoch_losses = []
#     epoch_prior_losses = []
    
#     # Create tqdm progress bar
#     progress_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader),
#                             desc=f"Epoch {epoch+1}/{num_epochs}")
    
#     model.train()  # Set model to training mode
#     predict_x0_viz = None
#     noise = None

#     sav_images = None
    
#     for batch_idx, images in progress_bar:
#         images = images.to(device)  # Move images to GPU
#         sav_images = images
#         # Zero gradients
#         optimizer.zero_grad(set_to_none=True)
        
#         # Normal forward pass without autocast
#         total_loss, prior_loss, estimate_x0, loss_dict, noise_add = model(images)
#         # Save a batch for visualization at end of epoch
#         if batch_idx == 0:
#             predict_x0_viz = estimate_x0    
#             noise = noise_add
        
#         # Check for NaN in losses
#         if torch.isnan(total_loss).any() or torch.isnan(prior_loss).any():
#             print(f"NaN detected in losses at batch {batch_idx}. Skipping this batch.")
#             continue

#         # combined_loss = total_loss + prior_loss

#         # # Check for NaN in loss
#         # if torch.isnan(combined_loss).any():
#         #     print(f"NaN detected in combined_loss at batch {batch_idx}. Skipping this batch.")
#         #     continue

#         # Standard backward pass
#         # combined_loss.backward()
#         total_loss.backward()
#         # prior_loss.backward()

#         # Clip gradients to prevent exploding gradients
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

#         # Check for NaN in gradients
#         has_nan_grad = False
#         for name, param in model.named_parameters():
#             if param.grad is not None and torch.isnan(param.grad).any():
#                 print(f"NaN detected in gradients of parameter {name}. Skipping this batch.")
#                 has_nan_grad = True
#                 break
        
#         if has_nan_grad:
#             continue

#         # Standard optimizer step
#         optimizer.step()
        
#         loss_item = total_loss.item()
#         prior_loss_item = prior_loss.item()
#         epoch_losses.append(loss_item)
#         epoch_prior_losses.append(prior_loss_item)
        
#         # Update progress bar
#         progress_bar.set_postfix({
#             'loss': f"{epoch_losses[-1]:.4f}",
#             'prior_loss': f"{epoch_prior_losses[-1]:.4f}",
#             **{key: f"{value:.4f}" for key, value in loss_dict.items()}
#         })
        
#         # Log periodically
#         if batch_idx % log_interval == 0 and batch_idx > 0:
#             print(f"\nBatch {batch_idx}/{len(train_loader)}, "
#                 f"Loss: {sum(epoch_losses[-log_interval:]) / log_interval:.4f}, "
#                 f"Prior Loss: {sum(epoch_prior_losses[-log_interval:]) / log_interval:.4f}")

#         del total_loss, prior_loss, loss_item, prior_loss_item
    
#     # Calculate average epoch loss
#     avg_loss = sum(epoch_losses) / len(epoch_losses)
#     avg_prior_loss = sum(epoch_prior_losses) / len(epoch_prior_losses)
#     train_losses.append(avg_loss)
#     prior_losses.append(avg_prior_loss)
    
#     print(f"\nEpoch {epoch+1}/{num_epochs} completed, "
#         f"Avg Loss: {avg_loss:.4f}, "
#         f"Avg Prior Loss: {avg_prior_loss:.4f}")
    
#     # Save checkpoints
#     if (epoch + 1) % save_interval == 0:
#         checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt")
#         torch.save({
#             'epoch': epoch + 1,
#             'encoder_state_dict': model.encoder.state_dict(),
#             'unet_state_dict': model.u_net.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': avg_loss,
#         }, checkpoint_path)
#         print(f"Saved checkpoint to {checkpoint_path}")

#     if (epoch + 1) >= 3000 and (epoch + 1) % 1 == 0 and predict_x0_viz is not None:
#         # show image
#         draw_image(predict_x0_viz)
#         draw_image(noise)

#     if (epoch + 1) >= 4000 and (epoch + 1) % 100 == 0:
#         print ((epoch + 1) % 10)
#         model.eval()
        
#         image = sav_images

#         start_noise = image
#         start_noise = torch.rand_like(start_noise) 
#         draw_image(start_noise)
#         model.evaluate_ddim(image, start_noise, denoise_steps=30)
#         model.train()

#         torch.cuda.empty_cache()
#         gc.collect()


#     # Step the learning rate scheduler
#     scheduler.step()


        

In [11]:
save_dir = r"E:\DUT Courses\Academic year 4\semester 2\PBL\PBL7 CDC Compression\checkpoints\cdc_training_20250519_013027"
CHECKPOINT_EPOCH = 4000  # Which checkpoint_epoch_<N>.pt file in save_dir to restore

# Load checkpoint
checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{CHECKPOINT_EPOCH}.pt")
# Map every stored tensor onto the device of the model's first parameter, so a
# checkpoint saved on a GPU also loads on a machine without that GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=next(model.parameters()).device)
model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
model.u_net.load_state_dict(checkpoint['unet_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # needs the optimizer from the training cell
model.train()  # Set model to training mode
print(f"Loaded checkpoint from {checkpoint_path}")


Loaded checkpoint from E:\DUT Courses\Academic year 4\semester 2\PBL\PBL7 CDC Compression\checkpoints\cdc_training_20250519_013027\checkpoint_epoch_4000.pt
