Dataset Loader

In [9]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

# class ImageDataset(Dataset):
#     def __init__(self, rootA, rootB, transform=None):
#         self.transform = transform
#         self.files_A = sorted([os.path.join(rootA, x) for x in os.listdir(rootA) if x.endswith(('.png', '.jpg', '.jpeg'))])
#         self.files_B = sorted([os.path.join(rootB, x) for x in os.listdir(rootB) if x.endswith(('.png', '.jpg', '.jpeg'))])

#     def __len__(self):
#         return max(len(self.files_A), len(self.files_B))  # Ensuring both datasets are the same size

#     def __getitem__(self, index):
#         image_A = Image.open(self.files_A[index % len(self.files_A)]).convert('RGB')
#         image_B = Image.open(self.files_B[index % len(self.files_B)]).convert('RGB')

#         if self.transform:
#             image_A = self.transform(image_A)
#             image_B = self.transform(image_B)

#         return image_A, image_B


from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, rootA, rootB, transform=None):
        self.transform = transform
        self.files_A = sorted([os.path.join(rootA, x) for x in os.listdir(rootA) if x.endswith(('.png', '.jpg', '.jpeg'))])
        self.files_B = sorted([os.path.join(rootB, x) for x in os.listdir(rootB) if x.endswith(('.png', '.jpg', '.jpeg'))])

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))  # Ensuring both datasets are the same size

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)]).convert('RGB')
        image_B = Image.open(self.files_B[index % len(self.files_B)]).convert('RGB')

        if self.transform:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)

        # Convert PIL images to tensors
        to_tensor = transforms.ToTensor()
        image_A = to_tensor(image_A)
        image_B = to_tensor(image_B)

        return image_A, image_B




In [10]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to a fixed size for all images
    transforms.ToTensor()
])


In [None]:
def get_data_loader( image_type,image_size=256, batch_size=16, num_workers=4):
    """Returns training and test data loaders for a given image type, either 'summer' or 'winter'. 
       These images will be resized to 128x128x3, by default, converted into Tensors, and normalized.
    """
    
    # resize and normalize the images
    transform = transforms.Compose([transforms.Resize((image_size,image_size)), # resize to 128x128
                                    transforms.ToTensor()])

    # get training and test directories
    image_path = ".\\cycle_gan\\dataset\\"
    train_path = image_path+"train{}".format(image_type)
    test_path = image_path+"test{}".format(image_type)

    # define datasets using ImageFolder
    train_dataset = datasets.ImageFolder(train_path, transform)
    test_dataset = datasets.ImageFolder(test_path, transform)

    # create and return DataLoaders
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [11]:
def create_data_loaders(trainA_path, trainB_path, testA_path, testB_path, batch_size=1, num_workers=0, transform=None):
    train_loader = DataLoader(
        ImageDataset(trainA_path, trainB_path, transform=transform),
        batch_size=batch_size, shuffle=True, num_workers=num_workers,pin_memory=True)

    test_loader = DataLoader(
        ImageDataset(testA_path, testB_path, transform=transform),
        batch_size=batch_size, shuffle=False, num_workers=num_workers,pin_memory=True)

    return train_loader, test_loader


In [14]:
train_dataset = ImageDataset('/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Train/Low',  '/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Train/Normal', transform=transform)
test_dataset = ImageDataset('/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Test/Low', '/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Test/Normal', transform=transform)
batch_size=16
num_workers=4
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
train_loader, test_loader = create_data_loaders(
    '/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Train/Low', '/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Train/Normal',
    '/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Test/Low', '/home/dgreen/sem2/dl/finalProject/LOL-v2/LOL-v2/Real_captured/Test/Normal',
    batch_size=16, num_workers=4
)


# Generator

In [15]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os

# -----------------------------------
# 1. Generator and Residual Block
# -----------------------------------
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, input_channels, output_channels, n_residual_blocks=9):
        super(Generator, self).__init__()
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_channels, 64, 7),
                 nn.InstanceNorm2d(64),
                 nn.ReLU(inplace=True)]
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model.extend([nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                          nn.InstanceNorm2d(out_features),
                          nn.ReLU(inplace=True)])
            in_features = out_features
            out_features *= 2
        for _ in range(n_residual_blocks):
            model.append(ResidualBlock(in_features))
        out_features = in_features // 2
        for _ in range(2):
            model.extend([nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                          nn.InstanceNorm2d(out_features),
                          nn.ReLU(inplace=True)])
            in_features = out_features
            out_features = in_features // 2
        model.extend([nn.ReflectionPad2d(3),
                      nn.Conv2d(64, output_channels, 7),
                      nn.Tanh()])
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# -----------------------------------
# 2. Discriminator
# -----------------------------------
class Discriminator(nn.Module):
    def __init__(self, input_channels):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, x):
        return self.model(x)

# -----------------------------------
# 3. Initialize Models
# -----------------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G_AB = Generator(3, 3).to(device)
G_BA = Generator(3, 3).to(device)
D_A = Discriminator(3).to(device)
D_B = Discriminator(3).to(device)
print('Models moved to GPU.')

# -----------------------------------
# 4. Losses and Optimizers
# -----------------------------------
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

optimizer_G = optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))


Models moved to GPU.


In [16]:
def train_cycle_gan(loader_A, loader_B, epochs=100, save_interval=10):
    G_losses, D_losses = [], []

    for epoch in range(epochs):
        iteration=0
        for real_A, real_B in zip(loader_A, loader_B):

            print(real_A[0][0])
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            valid = torch.ones(real_A.size(0), 1, device=device, requires_grad=False)
            fake = torch.zeros(real_A.size(0), 1, device=device, requires_grad=False)

            ######################
            #  Train Generators  #
            ######################
            optimizer_G.zero_grad()

            # Identity loss
            loss_identity_A = criterion_identity(G_BA(real_A), real_A)
            loss_identity_B = criterion_identity(G_AB(real_B), real_B)

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            # Cycle loss
            recovered_A = G_BA(fake_B)
            recovered_B = G_AB(fake_A)
            loss_cycle_ABA = criterion_cycle(recovered_A, real_A)
            loss_cycle_BAB = criterion_cycle(recovered_B, real_B)

            # Total loss for generators
            total_loss_G = (loss_identity_A + loss_identity_B) * 5.0 + (loss_GAN_AB + loss_GAN_BA) + (loss_cycle_ABA + loss_cycle_BAB) * 10.0
            total_loss_G.backward()
            optimizer_G.step()

            #############################
            #  Train Discriminator D_A  #
            #############################
            optimizer_D_A.zero_grad()

            # Real loss
            loss_D_real_A = criterion_GAN(D_A(real_A), valid)

            # Fake loss (on detached fake_A to avoid training G on these labels)
            fake_A = G_BA(real_B).detach()
            loss_D_fake_A = criterion_GAN(D_A(fake_A), fake)

            # Total loss for discriminator A
            total_loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            total_loss_D_A.backward()
            optimizer_D_A.step()

            #############################
            #  Train Discriminator D_B  #
            #############################
            optimizer_D_B.zero_grad()

            # Real loss
            loss_D_real_B = criterion_GAN(D_B(real_B), valid)

            # Fake loss (on detached fake_B to avoid training G on these labels)
            fake_B = G_AB(real_A).detach()
            loss_D_fake_B = criterion_GAN(D_B(fake_B), fake)

            # Total loss for discriminator B
            total_loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            total_loss_D_B.backward()
            optimizer_D_B.step()

            # Collect losses for logging
            G_losses.append(total_loss_G.item())
            D_losses.append((total_loss_D_A.item() + total_loss_D_B.item()) / 2)
            print(f"{epoch + 1}:{iteration} , G Loss: {total_loss_G.item():.4f}, D Loss: {(total_loss_D_A.item() + total_loss_D_B.item()) / 2:.4f}")
            iteration+=1

        # Logging and saving checkpoints
        print(f"Epoch [{epoch + 1}/{epochs}], G Loss: {total_loss_G.item():.4f}, D Loss: {(total_loss_D_A.item() + total_loss_D_B.item()) / 2:.4f}")

        if (epoch + 1) % save_interval == 0:
            torch.save(G_AB.state_dict(), f'G_AB_epoch_{epoch + 1}.pth')
            torch.save(G_BA.state_dict(), f'G_BA_epoch_{epoch + 1}.pth')
            torch.save(D_A.state_dict(), f'D_A_epoch_{epoch + 1}.pth')
            torch.save(D_B.state_dict(), f'D_B_epoch_{epoch + 1}.pth')

    return G_losses, D_losses


In [17]:
type(train_loader)

torch.utils.data.dataloader.DataLoader

In [18]:
train_cycle_gan(train_loader,test_loader)

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>