In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import random
import numpy as np
import time
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './')))

In [2]:
COCO_ROOT_DIR = r"D:\temp_dataset\coco\images\train2017.1\train2017"  # Thư mục gốc chứa COCO
TRAIN_IMAGES_SUBDIR = ""     # Thư mục con chứa ảnh train2017
TRAIN_ANNOTATIONS_FILENAME = "captions_train2017.json" # File chú thích cho train2017

PATCH_SIZE = (256, 256)  # Kích thước patch mong muốn (height, width)
BATCH_SIZE = 64          # Kích thước batch cho DataLoader
NUM_WORKERS = 0         # Số luồng để tải dữ liệu

# --- Xây dựng đường dẫn đầy đủ ---
train_images_path = os.path.join(COCO_ROOT_DIR, TRAIN_IMAGES_SUBDIR)
train_annotations_path = os.path.join(r"D:\temp_dataset\coco\images\train2017.1\annotations_trainval2017\annotations", TRAIN_ANNOTATIONS_FILENAME)

In [3]:
import gc
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageOnlyDataset(Dataset):
    """
    A dataset that loads only images from a directory (no labels).
    All images are assumed to be the same size.
    """
    def __init__(self, image_dir, transform=None, cache_size=0, max_images=None):
        """
        Args:
            image_dir (str): Path to directory containing images
            transform (callable, optional): Transform to apply to images
            cache_size (int): Number of images to cache in memory (0 for no caching)
            max_images (int): Limit number of images loaded from the folder
        """
        self.image_dir = image_dir
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),  # Converts images to tensors [0-1]
        ])
        self.cache_size = cache_size
        self.cache = {}
        
        # Get all image paths but don't load them yet
        self.image_paths = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
            self.image_paths.extend(glob.glob(os.path.join(image_dir, ext)))
            
        # Sort for reproducibility
        self.image_paths.sort()
        
        # Only keep a limited number of images
        if max_images is not None:
            self.image_paths = self.image_paths[:max_images]
        
        print(f"Found {len(self.image_paths)} images in {image_dir}")
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Check if image is in cache
        if idx in self.cache:
            return self.cache[idx]
        
        # Load image only when needed
        image_path = self.image_paths[idx]
        try:
            with Image.open(image_path) as img:
                image = img.convert('RGB')
                if self.transform:
                    image = self.transform(image)
                if self.cache_size > 0 and len(self.cache) < self.cache_size:
                    self.cache[idx] = image
                return image
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return torch.zeros(3, 256, 256)  # Placeholder on failure

    def clear_cache(self):
        self.cache.clear()
        gc.collect()


# Setup the dataset and dataloader
def get_coco_patches_loader(
    data_dir="D:\\ds_coco_patches", 
    batch_size=64,
    num_workers=0,
    pin_memory=False,
    shuffle=True,
    cache_size=0,
    max_images=None
):
    """
    Creates a DataLoader for loading images from the specified directory.
    """
    dataset = ImageOnlyDataset(data_dir, cache_size=cache_size, max_images=max_images)
    
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    
    return dataloader, dataset


# ------------------------ Example usage ------------------------

BATCH_SIZE = 32  # Adjust based on your GPU memory
NUM_WORKERS = 0  # 0 = no multiprocessing
MAX_IMAGES = 3000  # Limit to first 3000 images

train_loader, train_dataset = get_coco_patches_loader(
    data_dir="D:\\ds_coco_patches",
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    cache_size=0,
    max_images=MAX_IMAGES
)

# Move a batch to CUDA and measure its size
for images in train_loader:
    images = images.cuda()  # Move to CUDA
    print(f"Batch size: {images.size(0)}")
    print(f"Batch shape: {images.shape}")  # Should be [B, 3, H, W]
    
    # Calculate memory usage
    batch_memory = images.element_size() * images.nelement() / (1024 ** 2)
    print(f"Estimated memory usage of the batch on CUDA: {batch_memory:.2f} MB")
    break


Found 3000 images in D:\ds_coco_patches
Batch size: 32
Batch shape: torch.Size([32, 3, 256, 256])
Estimated memory usage of the batch on CUDA: 24.00 MB


In [4]:
sample_batch = next(iter(train_loader))
print(f"Batch shape: {sample_batch.shape}")
print(f"Estimated batch size in MB: {sample_batch.element_size() * sample_batch.nelement() / 1024**2:.2f}")


Batch shape: torch.Size([32, 3, 256, 256])
Estimated batch size in MB: 24.00


In [5]:
from compressor.compressor import Compressor
from decompressor.diffusion_manager import DiffusionManager
from decompressor.unet_module import UnetModule


# Initialize the compressor
compressor = Compressor()

# Generate a dummy tensor to simulate input data
dummy_tensor = torch.randn(64, 3, 256, 256).cuda()

# Pass the dummy tensor through the compressor
output_dict = compressor(dummy_tensor)

# Extract the shapes of the output tensors
output_shapes = [output.shape[1] for output in output_dict['output']]

# Initialize the UNet module with the extracted channel dimensions
unet_module = UnetModule(context_channels=output_shapes)

del dummy_tensor

model = DiffusionManager(
    encoder=compressor,
    u_net=unet_module,
)

Initializing HyperPrior on device: cuda
DiffusionManager initialized on device: cuda




In [6]:
def get_model_size_in_mb(model):
    total_size = sum(param.element_size() * param.nelement() for param in model.parameters())
    total_size += sum(buffer.element_size() * buffer.nelement() for buffer in model.buffers())
    return total_size / (1024 ** 2)  # Convert bytes to MB

model_size_mb = get_model_size_in_mb(model)
print(f"Model size: {model_size_mb:.2f} MB")

Model size: 15.23 MB


In [7]:
torch.cuda.empty_cache()
import gc
gc.collect()

17

In [None]:
torch.autograd.set_detect_anomaly(True)
from torch.utils.data import DataLoader
import tqdm
import datetime
import os

import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())

def print_gpu_mem(tag=""):
    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    allocated = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"[{tag}] GPU Mem Used (NVML): {info.used / 1024**2:.2f} MB | Allocated (PyTorch): {allocated:.2f} MB | Reserved (PyTorch): {reserved:.2f} MB")

NUM_WORKERS = 0 if torch.cuda.is_available() else 0

# Create optimizer for both compressor and UNet
optimizer = torch.optim.Adam([
    {'params': model.encoder.parameters()},
    {'params': model.u_net.parameters()}
], lr=1e-4)

# Training parameters
num_epochs = 5
log_interval = 50  # Log every 50 batches
save_interval = 1   # Save checkpoint every epoch

# Create directory for saving checkpoints
save_dir = os.path.join("checkpoints", 
                        f"cdc_training_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}")
os.makedirs(save_dir, exist_ok=True)

# Training statistics
train_losses = []
prior_losses = []

# Move models to CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Starting training on {device}")
print(f"Training for {num_epochs} epochs with batch size {BATCH_SIZE}")

# Training loop
for epoch in range(num_epochs):
    train_dataset.clear_cache()  # Clear cache at the start of each epoch
    epoch_losses = []
    epoch_prior_losses = []

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"len train_loader: {len(train_loader)}")
    # print_gpu_mem("Epoch Start")
    
    # Create tqdm progress bar
    progress_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader),
                             desc=f"Epoch {epoch+1}/{num_epochs}")
    
    model.train()  # Set model to training mode
    
    for batch_idx, images in progress_bar:
        # print_gpu_mem("Batch Start")
        images = images.to(device)  # Move images to GPU
        # Zero gradients
        optimizer.zero_grad(set_to_none=True) # set_to_none=True có thể tiết kiệm chút mem
        # print_gpu_mem("After optimizer.zero_grad")
        
        total_loss, prior_loss = model(images)
        # print_gpu_mem("After model(images)")
        
        total_loss.backward()
        prior_loss.backward()
        # print_gpu_mem("After total_loss.backward()")

        # --- Optimizer step ---
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # print_gpu_mem("After clip_grad_norm_")
        optimizer.step()
        # print_gpu_mem("After optimizer.step()")
        
        loss_item = total_loss.item() # Lấy item() NGAY LẬP TỨC
        prior_loss_item = prior_loss.item()
        epoch_losses.append(loss_item)
        epoch_prior_losses.append(prior_loss_item)
        # print_gpu_mem("After loss .item()")
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f"{epoch_losses[-1]:.4f}",
            'prior_loss': f"{epoch_prior_losses[-1]:.4f}"
        })
        
        # Log periodically
        if batch_idx % log_interval == 0 and batch_idx > 0:
            print(f"\nBatch {batch_idx}/{len(train_loader)}, "
                  f"Loss: {sum(epoch_losses[-log_interval:]) / log_interval:.4f}, "
                  f"Prior Loss: {sum(epoch_prior_losses[-log_interval:]) / log_interval:.4f}")
            
        time.sleep(5)

        del images, total_loss, prior_loss, loss_item, prior_loss_item
        # print_gpu_mem("Batch End (after explicit del)")
    
    # Calculate average epoch loss
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    avg_prior_loss = sum(epoch_prior_losses) / len(epoch_prior_losses)
    train_losses.append(avg_loss)
    prior_losses.append(avg_prior_loss)
    
    print(f"\nEpoch {epoch+1}/{num_epochs} completed, "
          f"Avg Loss: {avg_loss:.4f}, "
          f"Avg Prior Loss: {avg_prior_loss:.4f}")
    
    # Save checkpoints
    if (epoch + 1) % save_interval == 0:
        checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'encoder_state_dict': model.encoder.state_dict(),
            'unet_state_dict': model.u_net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

# Plot training curve
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs + 1), train_losses, label='Total Loss')
plt.plot(range(1, num_epochs + 1), prior_losses, label='Prior Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.grid(True)
plt.show()

print("Training completed!")

Starting training on cuda
Training for 5 epochs with batch size 32

Epoch 1/5
len train_loader: 94


Epoch 1/5:  19%|█▉        | 18/94 [04:00<16:19, 12.88s/it, loss=4.4239, prior_loss=3.2104]     