In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import os
from tqdm import tqdm

In [None]:
class GLOModel_CNN(nn.Module):
    """Convolutional GLO generator that maps latent vectors to images.

    A linear layer projects each latent vector to a ``128 x s x s`` feature map
    with ``s = output_shape[1] // 8``. Two upsample + conv stages enlarge it by
    a factor of 4, and a final bilinear resize brings the result to the
    requested ``(height, width)``. The last activation is ``Tanh``, so outputs
    lie in ``[-1, 1]``.

    Args:
        latent_dim: Size of each input latent vector.
        output_shape: ``(channels, height, width)`` of the generated image.
    """

    def __init__(self, latent_dim, output_shape=(3, 400, 400)):
        super(GLOModel_CNN, self).__init__()
        self.output_shape = tuple(output_shape)
        # Side length of the square seed feature map. Clamped to 1 so very
        # small targets still produce a non-empty tensor.
        self.init_size = max(1, self.output_shape[1] // 8)
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # NOTE(review): the second positional argument of BatchNorm2d is `eps`,
        # not `momentum`, so `0.8` below sets eps=0.8. It is kept unchanged to
        # preserve the behaviour of existing runs/checkpoints; confirm intent.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, self.output_shape[0], 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images for a batch of latent vectors.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *output_shape)``.
        """
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        # The conv stack only upsamples 4x, so resize to the configured target
        # size (previously hard-coded to 400x400 regardless of output_shape).
        img = nn.functional.interpolate(
            img, size=self.output_shape[1:], mode='bilinear', align_corners=False
        )
        return img
    
class GLOModel_MLP(nn.Module):
    """Fully connected GLO generator that maps latent vectors to images.

    Five hidden ``Linear(…, 256) -> ReLU -> Dropout(0.5)`` stages are followed
    by a linear projection to ``prod(output_shape)`` values, which are reshaped
    to ``output_shape`` and squashed to ``[0, 1]`` with a sigmoid.

    Args:
        latent_dim: Size of each input latent vector.
        output_shape: ``(channels, height, width)`` of the generated image.
    """

    def __init__(self, latent_dim, output_shape=(3, 400, 400)):
        super(GLOModel_MLP, self).__init__()
        self.output_shape = tuple(output_shape)

        # nn.Linear expects a plain int for out_features; the previous
        # torch.prod(...) call handed it a 0-dim tensor instead.
        num_outputs = 1
        for dim in self.output_shape:
            num_outputs *= int(dim)

        hidden_dim = 256
        num_hidden_blocks = 5
        layers = []
        in_features = latent_dim
        # Layers are appended in the same order as before, so the Sequential
        # indices (and therefore state_dict keys) are unchanged.
        for _ in range(num_hidden_blocks):
            layers += [nn.Linear(in_features, hidden_dim), nn.ReLU(), nn.Dropout(p=0.5)]
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_outputs))
        self.fc_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images for a batch of latent vectors.

        Args:
            x: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *output_shape)`` with values in ``[0, 1]``.
        """
        x = self.fc_layers(x)
        x = x.view(-1, *self.output_shape)
        x = torch.sigmoid(x)
        return x

In [None]:
class CustomImageDataset(Dataset):
    """Dataset over the image files of a single directory.

    Each item is ``(idx, image)``. The index is returned alongside the image so
    the training loop can look up the latent vector that belongs to it.

    Args:
        directory: Folder containing the image files.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
    """

    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        # os.listdir returns entries in arbitrary order. Sorting keeps the
        # idx -> file mapping stable across runs and machines, which matters
        # because each index owns one latent vector. Sub-directories are
        # skipped because they cannot be opened as images.
        self.images = sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(idx, image)``."""
        img_name = os.path.join(self.directory, self.images[idx])
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be closed as soon as the with-block exits.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return idx, image

def compute_mean_std(loader):
    """Compute per-channel mean and std of the images produced by ``loader``.

    Each batch may be either an image tensor of shape ``(B, C, ...)`` or a
    tuple/list whose last element is that tensor, as produced by a DataLoader
    over ``(idx, image)`` items.

    The returned std is the average of the per-image, per-channel standard
    deviations, not the pooled standard deviation over the whole dataset.

    Args:
        loader: Iterable of batches.

    Returns:
        ``(mean, std)``: two tensors of shape ``(C,)``.

    Raises:
        ValueError: If the loader yields no images.
    """
    mean = 0.
    std = 0.
    total_images_count = 0
    for batch in loader:
        # Accept both bare image batches and (indices, images) batches.
        images = batch[-1] if isinstance(batch, (list, tuple)) else batch
        # Flatten all spatial dimensions: (B, C, H, W) -> (B, C, H*W).
        images = images.reshape(images.size(0), images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images_count += images.size(0)
    if total_images_count == 0:
        raise ValueError("compute_mean_std received a loader with no images")
    mean /= total_images_count
    std /= total_images_count
    return mean, std

# Data configuration.
# NOTE: IMAGE_DIR is a machine-specific absolute path; point it at your own
# image folder before running this cell.
IMAGE_DIR = 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/helmet/_temp/train/imgs'
IMAGE_SIZE = (400, 400)  # (height, width); matches the models' default output_shape
batch_size = 10  # Adjust batch size according to your system's memory

# Un-normalised pipeline: ToTensor() scales pixel values to [0, 1].
base_transform = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])
temp_loader = DataLoader(
    CustomImageDataset(directory=IMAGE_DIR, transform=base_transform),
    batch_size=batch_size,
    shuffle=False
)

# Disabled experiment: normalise with dataset statistics instead of feeding
# raw [0, 1] images. It is kept as comments rather than a bare string literal
# so it is clearly not executed.
#
# mean, std = compute_mean_std(temp_loader)
# print(f"Dataset Mean: {mean}")
# print(f"Dataset Std: {std}")
#
# transformations = transforms.Compose([
#     transforms.Resize(IMAGE_SIZE),  # Resize to 400x400 if not already
#     transforms.ToTensor(),  # Convert images to PyTorch tensors
#     transforms.Normalize(mean=mean.tolist(), std=std.tolist())  # Custom normalization
# ])
#
# dataset = CustomImageDataset(directory=IMAGE_DIR, transform=transformations)
# data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# mean, std = compute_mean_std(data_loader)
# print(f"Dataset Mean: {mean}")
# print(f"Dataset Std: {std}")

# The training and visualisation cells read `data_loader`.
data_loader = temp_loader

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_epochs = 100
latent_dim = 200  # Dimension of the latent space
num_images = len(data_loader.dataset)
if num_images == 0:
    # Fail early with a clear message instead of a NameError at the first log line.
    raise ValueError("data_loader.dataset is empty; nothing to train on")

glo_model = GLOModel_CNN(latent_dim).to(device)

# One learnable latent code per training image, addressed by dataset index.
latent_vectors = nn.Parameter(torch.randn(num_images, latent_dim, device=device, requires_grad=True))
optimizer_model = optim.Adam(glo_model.parameters(), lr=1e-4)
optimizer_latent = optim.Adam([latent_vectors], lr=1e-4)
loss_fn_latent = nn.MSELoss()
loss_fn_model = nn.MSELoss()

for epoch in tqdm(range(num_epochs)):
    epoch_loss_sum = 0.0
    num_batches = 0
    for indices, gt_batch in data_loader:
        indices = indices.to(device)
        gt_batch = gt_batch.to(device)

        # Step 1: update the latent codes, one image at a time.
        # Gradients that reach the generator's weights here are discarded by
        # optimizer_model.zero_grad() below.
        for idx, idx_global in enumerate(indices):
            optimizer_latent.zero_grad()
            latent_vector = latent_vectors[idx_global].unsqueeze(0)
            recon_img = glo_model(latent_vector).squeeze(0)
            loss_latent = loss_fn_latent(recon_img, gt_batch[idx])
            loss_latent.backward()
            optimizer_latent.step()

        # Step 2: update the generator on the whole batch using the refreshed
        # latent codes. The codes are detached because only optimizer_model
        # steps here; latent gradients from this pass were never used.
        optimizer_model.zero_grad()
        batch_latent_vectors = latent_vectors[indices].detach()
        recon_batch = glo_model(batch_latent_vectors)
        loss_model = loss_fn_model(recon_batch, gt_batch)
        loss_model.backward()
        optimizer_model.step()

        epoch_loss_sum += loss_model.item()
        num_batches += 1

    # Report the mean generator loss over the epoch rather than the last batch only.
    print(f'Epoch {epoch+1}, Loss: {epoch_loss_sum / num_batches}')
    #plt.imshow(recon_batch[5].detach().cpu().numpy().transpose(1, 2, 0))
    #plt.show()
    #plt.imshow(gt_batch[0].detach().cpu().numpy().transpose(1, 2, 0))
    #plt.show()

# Save your model and latent vectors for future use
#torch.save(glo_model.state_dict(), 'glo_model.pth')
#torch.save(latent_vectors, 'latent_vectors.pth')

In [None]:
# Grid of (ground truth, reconstruction) pairs, filled row by row.
GRID_ROWS, GRID_COLS = 25, 8
max_panels = GRID_ROWS * GRID_COLS  # 200 panels -> up to 100 image pairs

fig, axs = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(25, 75))
axs = axs.flatten()
for ax in axs:
    ax.axis('off')  # Hide every frame up front so unused panels stay blank
count = 0

# Inference only, so no autograd graph is needed.
# NOTE(review): glo_model is left in whatever mode it is currently in (training
# mode unless eval() was called elsewhere), so BatchNorm behaves as it did
# during the per-image latent updates. Call glo_model.eval() first if
# running-statistics inference is wanted instead.
with torch.no_grad():
    for indices, gt_batch in data_loader:
        if count >= max_panels:  # Stop if we have filled up our grid
            break

        indices = indices.to(device)

        for idx, idx_global in enumerate(indices):
            if count >= max_panels:  # Check again inside the loop
                break

            # Generate the reconstructed image
            latent_vector = latent_vectors[idx_global].unsqueeze(0)
            recon_img = glo_model(latent_vector).squeeze(0)

            # Plot the ground truth image
            axs[count].imshow(gt_batch[idx].cpu().numpy().transpose(1, 2, 0))
            count += 1

            if count >= max_panels:  # Check again after plotting the real image
                break

            # Plot the reconstructed image. imshow clips float RGB data to
            # [0, 1] anyway; clamping explicitly avoids its warning when the
            # Tanh output falls outside that range.
            axs[count].imshow(recon_img.clamp(0, 1).cpu().numpy().transpose(1, 2, 0))
            count += 1

plt.tight_layout()
plt.show()

# Disabled alternative: show every pair as two separate figures. It is kept as
# comments so the cell does not echo a string literal as its output.
#
# for idx_batch, (indices, gt_batch) in enumerate(data_loader):
#     indices = indices.to(device)
#     gt_batch = gt_batch.to(device)
#     for idx, idx_global in enumerate(indices):
#         latent_vector = latent_vectors[idx_global].unsqueeze(0)
#         recon_img = glo_model(latent_vector).squeeze(0)
#         plt.imshow(gt_batch[idx].detach().cpu().numpy().transpose(1, 2, 0))
#         plt.show()
#         plt.imshow(recon_img.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0))
#         plt.show()