In [43]:
import numpy as np
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from FreeFormDeformation import DeformationLayer
from deepali.core import functional as U
from tqdm import tqdm
import random
from diffusion_unet import Unet
from torch import nn, optim
import os
import cv2
from torch.cuda.amp import GradScaler, autocast
from PIL import Image
import torchvision.transforms.functional as F

In [44]:
class CustomDataset(Dataset):
    def __init__(self, image_paths, transform=None, device = "cpu"):
        """
        Args:
            image_paths (list): List of all image Paths.
            shape: The shape of one image in the dataset.
            mean (float): The mean value for normalization.
            std (float): The standard deviation for normalization.
            transform (bool): Whether to apply the transformation.
        """
        self.image_paths = image_paths
        self.transform = transform
        self.device = device
    
    def __len__(self):
        return len(self.image_paths)
    
    def build_deformation_layer(self, shape, device):
        """
        Build and return a new deformation layer for each call to __getitem__.
        This method returns the created deformation layer.
        """
        deformation_layer = DeformationLayer(shape)
        deformation_layer.new_deformation(device=device)
        return deformation_layer

    def __getitem__(self, idx):
        # Fetch the original image
        image_path = self.image_paths[idx]
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Image not found at path: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = Image.fromarray(img)
        transform = transforms.Resize((256,256))
        img = transform(img)
        img = F.pil_to_tensor(img).float().to(self.device)
        print("imgshape",img.shape)
        shape = img.squeeze(0).shape
        #shape = shape.to(self.device)
        print("squeeze shape", shape)
        #original_image = img.unsqueeze(0)  # Add batch dimension
        original_image= img.to(self.device)
        

        # Build a new deformation layer for the current image
        deformation_layer = self.build_deformation_layer(shape, self.device).to(self.device)

        # Apply deformation to get the deformed image
        deformed_image = deformation_layer.deform(original_image).to(self.device)
        # Fetch the current deformation field
        deformation_field = deformation_layer.get_deformation_field().squeeze(0).to(self.device)
        
        # transform the images
        if self.transform:
            original_image = self.transform(original_image)
            deformed_image = self.transform(deformed_image)

        # Stack the original and deformed images along the channel dimension
        stacked_image = torch.cat([original_image, deformed_image], dim=0).squeeze(0)

        return stacked_image, deformation_field

In [45]:
def train_model(model, train_loader, val_loader, criterion, optimizer, n_epochs, device):
    scaler = GradScaler()  # Initialize the GradScaler for mixed precision training
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        print(111)
        
        for i, (images, deformation_field) in enumerate(train_loader):
            print(1211)
            images = images.float().to(device)
            deformation_field = deformation_field.to(device)

            print(images.shape)
            print(deformation_field.shape)

            
            optimizer.zero_grad()

            # Use autocast for mixed precision
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, deformation_field)
                train_loss += loss.item()

            # Backward pass with scaled loss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        avg_train_loss = train_loss / len(train_loader)
        
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for i, (images, deformation_field) in enumerate(val_loader):
                images = images.float().to(device)
                deformation_field = deformation_field.to(device)
                print(images.shape)
                print(deformation_field.shape)
                # Use autocast for mixed precision
                with autocast():
                    outputs = model(images)
                    batch_loss = criterion(outputs, deformation_field).item()
                    val_loss += batch_loss
            
            avg_val_loss = val_loss / len(val_loader)
            
            print(f'Training Loss (Epoch {epoch+1}/{n_epochs}): {avg_train_loss:.4f}')
            print(f'Validation Loss (Epoch {epoch+1}/{n_epochs}): {avg_val_loss:.4f}')   

In [46]:
def get_image_paths(root_dir): 
    
    image_paths = []
    for category in os.listdir(root_dir):
        category_dir = os.path.join(root_dir, category)
        if os.path.isdir(category_dir):
            for filename in os.listdir(category_dir):
                if filename.endswith(".jpg") or filename.endswith(".png"):
                    image_paths.append(os.path.join(category_dir, filename))   
    return image_paths

In [47]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data_path = '/home/ubuntu/ADLM/data/high-resolution/afhq/train'
    val_data_path = '/home/ubuntu/ADLM/data/high-resolution/afhq/val'

    train_images_paths = get_image_paths(train_data_path)
    val_images_paths = get_image_paths(val_data_path)
    
    mean = 0.5
    std = 0.5
    
    train_dataset = CustomDataset(train_images_paths, transform=transforms.Compose([transforms.Normalize(mean=[mean], std=[std])]), device=device)
    val_dataset = CustomDataset(val_images_paths, transform=transforms.Compose([transforms.Normalize(mean=[mean], std=[std])]), device=device)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)

    model = Unet(
        dim=8,
        init_dim=None,
        out_dim=2,
        dim_mults=(1, 2, 4, 8),
        channels=2,
        resnet_block_groups=8,
        learned_variance=False,
        conditional_dimensions=0,
        patch_size=1,
        attention_layer=None
    )

    model.to(device)
    # Check if weights file exists
   

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

    n_epochs = 5
    train_model(model, train_loader, val_loader, criterion, optimizer, n_epochs, device)
    
    # save (update the number of epochs in name)
   
    

Attention lvl 0 8 8
Attention lvl 1 8 16
Attention lvl 2 16 32
Attention lvl 3 32 64
111
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shape torch.Size([256, 256])
imgshape torch.Size([1, 256, 256])
squeeze shap

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 

In [50]:
import os
import random
import shutil

def reduce_images_in_folder(folder_path, fraction=1/3):
    for subdir, _, files in os.walk(folder_path):
        if files:
            num_files_to_keep = max(1, int(len(files) * fraction))
            files_to_keep = random.sample(files, num_files_to_keep)
            files_to_remove = set(files) - set(files_to_keep)
            
            for file_name in files_to_remove:
                file_path = os.path.join(subdir, file_name)
                os.remove(file_path)
                print(f"Removed {file_path}")

def main():
    base_path = '/home/ubuntu/ADLM/data/high-resolution/afhq'
    train_path = os.path.join(base_path, 'train')
    val_path = os.path.join(base_path, 'val')
    
    # Reduce images in train and val folders
    reduce_images_in_folder(train_path, fraction=1/3)
    reduce_images_in_folder(val_path, fraction=1/3)
    
if __name__ == "__main__":
    main()

Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_002695.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_002502.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_001669.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_003321.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_001395.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_003146.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_002615.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_002369.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_001724.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_002920.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/pixabay_dog_002209.jpg
Removed /home/ubuntu/ADLM/data/high-resolution/afhq/train/dog/fli

In [51]:
import os

def count_images_in_folder(folder_path):
    # 获取文件夹中的所有文件
    files = os.listdir(folder_path)
    
    # 过滤掉非图片文件（假设所有图片文件都是常见的图片格式，如jpg, png, jpeg, bmp, gif等）
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
    image_files = [file for file in files if os.path.splitext(file)[1].lower() in image_extensions]
    
    # 打印图片数量
    print(f"Total number of images in '{folder_path}': {len(image_files)}")

# 文件夹路径
folder_path = '/home/ubuntu/ADLM/data/high-resolution/afhq/train/cat'

# 调用函数并输出结果
count_images_in_folder(folder_path)


Total number of images in '/home/ubuntu/ADLM/data/high-resolution/afhq/train/cat': 572
