In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn.functional import mse_loss
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from tqdm import tqdm

# Dataset locations. Set GLO_TRAIN_DIR / GLO_TEST_DIR in the environment to point at the data
# on another machine; the defaults are absolute local Windows paths.
PATH_TRAINING = os.environ.get('GLO_TRAIN_DIR', 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/helmet/_temp/train/imgs')
PATH_TESTING = os.environ.get('GLO_TEST_DIR', 'C:/_sw/eb_python/deep_learning/_dataset/NeRF/images/helmet/_temp/test/imgs')
BATCH_SIZE = 10
NUM_EPOCHS = 10000
LATENT_DIM = 200
SEED = 0  # seeds torch (CPU and CUDA) so weight and latent initialisation are reproducible
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)

In [None]:
class GLOModel_CNN(nn.Module):
    """Convolutional GLO generator: latent vector -> image with values in [-1, 1].

    A Linear layer projects each latent to a 128-channel square feature map of side
    ``output_shape[1] // 8``. Two Upsample(x2) + Conv blocks enlarge it 4x, a final conv
    maps to ``output_shape[0]`` channels followed by Tanh, and the result is bilinearly
    resized to ``output_shape[1:]`` so the output always has the requested height/width.

    NOTE(review): Tanh yields [-1, 1] while ``transforms.ToTensor`` targets are in [0, 1];
    the model can only use the upper half of its range — confirm this is intended.

    Args:
        latent_dim: size of each input latent vector.
        output_shape: (channels, height, width) of the generated image.
    """

    def __init__(self, latent_dim, output_shape=(3, 400, 400)):
        super(GLOModel_CNN, self).__init__()
        self.output_shape = tuple(output_shape)
        # Side of the initial square feature map (at least 1 so small outputs still work).
        self.init_size = max(1, output_shape[1] // 8)
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # NOTE(review): BatchNorm2d's second positional argument is eps, so the original
            # `BatchNorm2d(128, 0.8)` means eps=0.8 (not momentum). Kept to preserve
            # behaviour — confirm intent.
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, output_shape[0], 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Map latents ``z`` of shape (batch, latent_dim) to images of shape (batch, *output_shape)."""
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        # The conv stack only upsamples 4x (to init_size * 4); resize to the configured size
        # instead of a hard-coded 400x400 so non-default output_shape values work.
        img = nn.functional.interpolate(img, size=self.output_shape[1:], mode='bilinear', align_corners=False)
        return img
    
class GLOModel_MLP(nn.Module):
    """Fully connected GLO generator: latent vector -> image with values in (0, 1).

    Five hidden [Linear, ReLU, Dropout] stages feed a final Linear layer with one output
    per channel/pixel; the result is reshaped to ``output_shape`` and passed through a
    sigmoid.

    Args:
        latent_dim: size of each input latent vector.
        output_shape: (channels, height, width) of the generated image.
        hidden_dim: width of each hidden layer (default 256).
        dropout: dropout probability after each hidden layer (default 0.5).
    """

    def __init__(self, latent_dim, output_shape=(3, 400, 400), hidden_dim=256, dropout=0.5):
        super(GLOModel_MLP, self).__init__()
        self.output_shape = tuple(output_shape)
        # nn.Linear expects a plain int for out_features, not a 0-d tensor.
        num_outputs = int(np.prod(self.output_shape))

        layers = []
        in_features = latent_dim
        for _ in range(5):
            layers += [nn.Linear(in_features, hidden_dim), nn.ReLU(), nn.Dropout(p=dropout)]
            in_features = hidden_dim
        # NOTE(review): for the default 3x400x400 output this layer alone holds
        # hidden_dim * 480000 weights (~123M at hidden_dim=256).
        layers.append(nn.Linear(hidden_dim, num_outputs))
        # Layer order (and hence state_dict keys fc_layers.0 ... fc_layers.15) is
        # 5 x [Linear, ReLU, Dropout] followed by the output Linear.
        self.fc_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Map latents of shape (batch, latent_dim) to images of shape (batch, *output_shape)."""
        x = self.fc_layers(x)
        x = x.view(-1, *self.output_shape)
        x = torch.sigmoid(x)
        return x

In [None]:
class CustomImageDataset(Dataset):
    """Images from a single directory, each returned together with its dataset index.

    File names are sorted, so index ``i`` always refers to the same image (``os.listdir``
    order is arbitrary). The training loop uses this index to select the image's row in
    the latent matrix, so a stable order is required.

    Args:
        directory: folder containing the image files (subdirectories are ignored).
        transform: optional callable applied to the RGB PIL image.
        extensions: accepted file extensions (case-insensitive); ``None`` accepts every file.
    """

    # Default accepted extensions (lower-case); other files such as Thumbs.db are skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, directory, transform=None, extensions=IMAGE_EXTENSIONS):
        self.directory = directory
        self.transform = transform
        allowed = None if extensions is None else {ext.lower() for ext in extensions}
        self.images = sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
            and (allowed is None or os.path.splitext(name)[1].lower() in allowed)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(idx, image)`` where image is the (optionally transformed) RGB image."""
        img_name = os.path.join(self.directory, self.images[idx])
        # The context manager closes the file handle once the RGB copy has been made.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return idx, image

def calculate_psnr(gt_batch, recon_batch, max_pixel_value=1.0):
    """Per-image peak signal-to-noise ratio (dB) between two batches.

    Args:
        gt_batch: ground-truth images; the first dimension indexes images.
        recon_batch: reconstructions with the same shape as ``gt_batch``. It may require
            grad; no autograd graph is built here.
        max_pixel_value: largest possible pixel value (1.0 for ``ToTensor`` output).

    Returns:
        list[float]: one value per image, ``20 * log10(max_pixel_value / sqrt(MSE))``.
        Identical images give ``inf``; an empty batch gives ``[]``.
    """
    # A metric needs no gradient tracking.
    with torch.no_grad():
        sq_err = mse_loss(recon_batch, gt_batch, reduction='none')
        # Mean over every non-batch dimension -> one MSE per image, without a Python loop.
        mse = sq_err.flatten(start_dim=1).mean(dim=1) if sq_err.dim() > 1 else sq_err
        psnr = 20 * torch.log10(max_pixel_value / torch.sqrt(mse))
    return psnr.tolist()

In [None]:
# Training data: images are resized to 400x400 and scaled to [0, 1] by ToTensor.
# shuffle=False keeps the batches deterministic; the dataset index returned with each image
# selects that image's row in `latent_vectors`.
data_loader = DataLoader(
    CustomImageDataset(directory=PATH_TRAINING,
                       transform=transforms.Compose([transforms.Resize((400, 400)), transforms.ToTensor()])),
    batch_size=BATCH_SIZE,
    shuffle=False
)
num_images = len(data_loader.dataset)
assert num_images > 0, f'no images found in {PATH_TRAINING}'

glo_model = GLOModel_CNN(LATENT_DIM).to(DEVICE)
# One learnable latent code per training image, optimised alongside the generator.
latent_vectors = nn.Parameter(torch.randn(num_images, LATENT_DIM, device=DEVICE))

optimizer_model = optim.Adam(glo_model.parameters(), lr=1e-4)
# NOTE(review): Adam keeps momentum for every row of `latent_vectors`, so each step also moves
# rows updated earlier (their gradient is zero but exp_avg is not). SGD, or
# torch.optim.SparseAdam with nn.Embedding(sparse=True), would touch only the rows in use.
optimizer_latent = optim.Adam([latent_vectors], lr=1e-4)
loss_fn_latent = nn.MSELoss()
loss_fn_model = nn.MSELoss()

all_epochs_psnr = []  # per epoch: list of per-image PSNR values (dB)
all_epochs_loss = []  # per epoch: mean generator loss over all batches
progress = tqdm(range(NUM_EPOCHS))
for epoch in progress:
    epoch_psnr = []
    epoch_losses = []
    for indices, gt_batch in data_loader:
        indices = indices.to(DEVICE)
        gt_batch = gt_batch.to(DEVICE)

        # Step 1: refine each image's latent code on its own; only optimizer_latent steps here.
        # Generator parameters are frozen so no (discarded) parameter gradients are computed.
        # NOTE(review): this runs BatchNorm in train mode on single-sample batches, which also
        # updates its running statistics.
        glo_model.requires_grad_(False)
        for idx, idx_global in enumerate(indices):
            optimizer_latent.zero_grad()
            latent_vector = latent_vectors[idx_global].unsqueeze(0)
            recon_img = glo_model(latent_vector).squeeze(0)
            loss_latent = loss_fn_latent(recon_img, gt_batch[idx])
            loss_latent.backward()
            optimizer_latent.step()
        glo_model.requires_grad_(True)

        # Step 2: update the generator on the whole batch.
        optimizer_model.zero_grad()
        batch_latent_vectors = latent_vectors[indices]
        recon_batch = glo_model(batch_latent_vectors)
        loss_model = loss_fn_model(recon_batch, gt_batch)
        loss_model.backward()
        optimizer_model.step()

        epoch_losses.append(loss_model.item())
        epoch_psnr.extend(calculate_psnr(gt_batch, recon_batch.detach()))

    all_epochs_psnr.append(epoch_psnr)
    # Record the epoch mean rather than only the last batch's loss.
    all_epochs_loss.append(float(np.mean(epoch_losses)))
    progress.set_postfix(loss=f'{all_epochs_loss[-1]:.4f}', psnr=f'{np.mean(epoch_psnr):.1f}')

In [None]:
# Summarise each epoch's per-image PSNR (min/max/mean/std) and overlay the recorded training
# loss on a secondary y-axis. The lists are prefixed with `psnr_` so the builtins min()/max()
# are not shadowed in the kernel.
psnr_min = [np.min(epoch_psnr) for epoch_psnr in all_epochs_psnr]
psnr_max = [np.max(epoch_psnr) for epoch_psnr in all_epochs_psnr]
psnr_mean = [np.mean(epoch_psnr) for epoch_psnr in all_epochs_psnr]
psnr_std = [np.std(epoch_psnr) for epoch_psnr in all_epochs_psnr]

fig, ax1 = plt.subplots()
ax1.plot(psnr_min, 'g', label='Min psnr', lw=0.75)
ax1.plot(psnr_max, 'b', label='Max psnr', lw=0.75)
ax1.plot(psnr_mean, 'r', label='Mean psnr', lw=0.75)
ax1.plot(psnr_std, 'y', label='Std psnr', lw=0.75)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('PSNR', color='black')
ax1.grid(True)

ax2 = ax1.twinx()
ax2.plot(all_epochs_loss, 'c', label='Loss', lw=0.75)
ax2.set_ylabel('Loss', color='black')

# Merge both axes' entries into a single legend.
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(handles1 + handles2, labels1 + labels2, loc='best')

ax1.set_title('Loss and min/max/mean/std of PSNR per epoch')
plt.show()

In [None]:
# Save the trained generator weights and the latent codes (row i belongs to dataset index i).
# The epoch count in the file names comes from NUM_EPOCHS so it matches the run that produced them.
torch.save(glo_model.state_dict(), f'glo_model_epoch={NUM_EPOCHS}.pth')
torch.save(latent_vectors, f'latent_vectors_epoch={NUM_EPOCHS}.pth')

In [None]:
# Show each training image (left) next to its reconstruction (right), for at most
# MAX_PANELS // 2 images, laid out N_COLS panels per row.
MAX_PANELS = 200
N_COLS = 8
n_panels = 2 * min(len(data_loader.dataset), MAX_PANELS // 2)
n_rows = max(1, -(-n_panels // N_COLS))  # ceiling division
fig, axs = plt.subplots(n_rows, N_COLS, figsize=(25, 3 * n_rows), squeeze=False)
axs = axs.flatten()
for ax in axs:
    ax.axis('off')

# Inference mode: no autograd graph, BatchNorm uses running statistics (and dropout is off).
# NOTE(review): training always ran in train mode, so these reconstructions can differ from
# the PSNR logged during training. The previous mode is restored afterwards.
was_training = glo_model.training
glo_model.eval()
count = 0
with torch.no_grad():
    for indices, gt_batch in data_loader:
        # Stop the outer loop too; otherwise the next batch would index past the last axis.
        if count >= n_panels:
            break
        recon_batch = glo_model(latent_vectors[indices.to(DEVICE)])
        for gt_img, recon_img in zip(gt_batch, recon_batch):
            if count >= n_panels:
                break
            axs[count].imshow(gt_img.cpu().numpy().transpose(1, 2, 0))
            # Clamp: the Tanh output can be negative, which imshow would clip with a warning.
            axs[count + 1].imshow(recon_img.clamp(0, 1).cpu().numpy().transpose(1, 2, 0))
            count += 2
glo_model.train(was_training)

plt.tight_layout()
plt.show()