In [1]:
!pip install rawpy



In [1]:
# Mount Google Drive so the project folder is reachable from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

import os
# Root folder of the project on Drive; it also becomes the working directory.
PROJECT_ROOT = '/content/drive/MyDrive/ML_ECOL_project'
os.chdir(PROJECT_ROOT)

# DATA_PATH: directory scanned for .nef files by NEFConverter.
DATA_PATH = os.path.join(PROJECT_ROOT, 'data')
# DATASET_PATH: directory the training data loader reads images from.
DATASET_PATH = os.path.join(PROJECT_ROOT, 'dataset')
# CHECKPOINTS_PATH: where model checkpoints and progress plots are written.
CHECKPOINTS_PATH = os.path.join(PROJECT_ROOT, 'checkpoints')
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

import numpy as np
import matplotlib.pyplot as plt 
import cv2
import rawpy 
import imageio 
from PIL import Image
from tqdm import tqdm


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# === NEF - JPG ===

In [None]:
class NEFConverter():
    """Locate Nikon RAW (.nef) files and convert each one to a JPG next to it.

    Usage: call `find_nef_images(root_dir)` to collect the files, then
    `convert()` to write `<name>.jpg` beside every `<name>.nef`.
    """

    def __init__(self):
        # Initialised here so that convert() before find_nef_images() is a
        # harmless no-op instead of an AttributeError.
        self.root_dir = None
        self.img_paths = []

    def find_nef_images(self, root_dir):
        """Collect the paths of all .nef files in the direct subdirectories of `root_dir`.

        Only one level is searched: files directly inside `root_dir` and files
        in deeper nested folders are ignored. The extension is matched
        case-insensitively ('.nef', '.NEF', ...). Results are stored in
        `self.img_paths` in sorted order.

        Args:
            root_dir: directory whose immediate subdirectories hold the images.
        """
        self.root_dir = root_dir
        self.img_paths = []

        # sorted() makes the processing order independent of the filesystem.
        for subdir in sorted(os.listdir(self.root_dir)):
            subdir_path = os.path.join(self.root_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue

            for img_name in sorted(os.listdir(subdir_path)):
                _, extension = os.path.splitext(img_name)
                if extension.lower() == '.nef':
                    self.img_paths.append(os.path.join(subdir_path, img_name))

    def convert(self):
        """Convert every collected .nef file to a .jpg with the same base name.

        Files whose .jpg already exists are skipped, so the conversion can be
        resumed after an interruption. Progress is printed per image.
        """
        total = len(self.img_paths)
        for i, img_path in enumerate(self.img_paths, start=1):

            # 1-based counter so the last image reads total/total.
            print(f'Image {i}/{total}')
            path_wo_extension, _ = os.path.splitext(img_path)
            save_path_jpg = f'{path_wo_extension}.jpg'

            if os.path.isfile(save_path_jpg):
                print(f'Already exists: {save_path_jpg}')
                print('Skipping\n')
                continue

            # postprocess() returns a standalone array, so the RAW handle can
            # be released before the (slow) JPG write.
            with rawpy.imread(img_path) as raw:
                rgb = raw.postprocess()

            print(f'Saving to {save_path_jpg}')
            imageio.imsave(save_path_jpg, rgb)
            print('Done\n')






In [None]:
# Collect every .nef file found in the subdirectories of DATA_PATH.
converter = NEFConverter()
converter.find_nef_images(DATA_PATH)

In [4]:
from time import time 

# Run the NEF -> JPG conversion and measure its wall-clock duration.
start = time()
converter.convert()
end = time()

# Elapsed time in seconds.
print(end - start)


# === MODEL ===

In [2]:
class AEContractingBlock(nn.Module):
    """Encoder building block: strided conv -> [batch norm] -> [dropout] -> LeakyReLU(0.2).

    The 4x4 convolution with stride 2 and padding 1 halves even spatial sizes.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output channels; defaults to twice `in_channels`.
        use_bn: apply `BatchNorm2d` after the convolution.
        use_dropout: apply `Dropout` (default probability) before the activation.
    """

    def __init__(
        self, in_channels, out_channels=None, 
        use_bn = True, use_dropout = False
        ):
        super().__init__()

        self.use_bn = use_bn
        self.use_dropout = use_dropout
        width_out = 2 * in_channels if out_channels is None else out_channels

        # Sub-modules are registered in the order they are applied.
        self.conv = nn.Conv2d(
            in_channels, width_out, kernel_size=4, stride=2, padding=1
        )
        if use_bn:
            self.bn = nn.BatchNorm2d(width_out)
        if use_dropout:
            self.dropout = nn.Dropout()
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply the enabled stages in sequence and return the result."""
        stages = [self.conv]
        if self.use_bn:
            stages.append(self.bn)
        if self.use_dropout:
            stages.append(self.dropout)
        stages.append(self.activation)

        out = x
        for stage in stages:
            out = stage(out)
        return out



class AEExpandingBlock(nn.Module):
    """Decoder building block: transposed conv -> [batch norm] -> [dropout] -> ReLU.

    The 4x4 transposed convolution with stride 2 and padding 1 doubles the
    spatial size.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output channels; defaults to `in_channels // 2`.
        use_bn: apply `BatchNorm2d` after the transposed convolution.
        use_dropout: apply `Dropout` (default probability) before the activation.
    """

    def __init__(
        self, in_channels, out_channels=None, 
        use_bn = True, use_dropout = False
        ):
        super().__init__()

        self.use_bn = use_bn
        self.use_dropout = use_dropout
        width_out = in_channels // 2 if out_channels is None else out_channels

        # Sub-modules are registered in the order they are applied.
        self.conv = nn.ConvTranspose2d(
            in_channels, width_out, kernel_size=4, stride=2, padding=1
        )
        if use_bn:
            self.bn = nn.BatchNorm2d(width_out)
        if use_dropout:
            self.dropout = nn.Dropout()
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply the enabled stages in sequence and return the result."""
        stages = [self.conv]
        if self.use_bn:
            stages.append(self.bn)
        if self.use_dropout:
            stages.append(self.dropout)
        stages.append(self.activation)

        out = x
        for stage in stages:
            out = stage(out)
        return out



class Encoder(nn.Module):
    """Convolutional encoder: a 1x1 conv to `hidden_channels`, then `depth` contracting blocks.

    Channel widths grow linearly: block `i` maps `hidden_channels * (i + 1)`
    channels to `hidden_channels * (i + 2)`, so the encoding has
    `hidden_channels * (depth + 1)` channels. Each block halves the spatial
    resolution.

    Args:
        in_channels: channels of the input image.
        hidden_channels: width after the initial 1x1 convolution.
        depth: number of contracting (downsampling) blocks.
    """

    def __init__(self, in_channels, hidden_channels, depth):
        super(Encoder, self).__init__()

        self.depth = depth

        # Input -> hidden mapping
        self.set_hidden_channels = nn.Conv2d(
            in_channels, hidden_channels, kernel_size=1
            )

        # Contracting layers.
        # Widths are multiples of hidden_channels itself. The earlier form
        # 2 * i * int(hidden_channels / 2) gives the same values for even
        # widths, but for odd widths the first block's input no longer matched
        # the 1x1 conv output (and hidden_channels=1 gave 0-channel layers).
        self.contracting_layers = nn.ModuleList()
        contracting_scales = [
            hidden_channels * i for i in range(1, self.depth + 2)
            ]
        for curr_in_channels, curr_out_channels in zip(
            contracting_scales[:-1], contracting_scales[1:]
            ):
            self.contracting_layers.append(
                AEContractingBlock(
                    curr_in_channels, curr_out_channels, use_dropout=True
                    )
                )

    def forward(self, x):
        """Return the encoding of `x`."""
        out = self.set_hidden_channels(x)

        for layer in self.contracting_layers:
            out = layer(out)

        return out



class Decoder(nn.Module):
    """Convolutional decoder: `depth` expanding blocks, then a 1x1 conv to `out_channels`.

    Mirrors the encoder's linear channel schedule: the input is expected to
    have `hidden_channels * (depth + 1)` channels and each block removes
    `hidden_channels` channels while doubling the spatial resolution. No
    activation is applied after the final 1x1 convolution.

    Args:
        hidden_channels: width fed to the final 1x1 convolution.
        out_channels: channels of the reconstructed image.
        depth: number of expanding (upsampling) blocks.
    """

    def __init__(self, hidden_channels, out_channels, depth):
        super(Decoder, self).__init__()

        self.depth = depth

        # Expanding layers.
        # Widths are multiples of hidden_channels itself. The earlier form
        # 2 * i * int(hidden_channels / 2) gives the same values for even
        # widths, but for odd widths the last block's output no longer matched
        # the input of set_output_channels.
        self.expanding_layers = nn.ModuleList()
        expanding_scales = [
            hidden_channels * i for i in range(self.depth + 1, 0, -1)
            ]
        for curr_in_channels, curr_out_channels in zip(
            expanding_scales[:-1], expanding_scales[1:]
            ):
            self.expanding_layers.append(
                AEExpandingBlock(
                    curr_in_channels, curr_out_channels, use_dropout=True
                    )
                )

        # Hidden -> output mapping
        self.set_output_channels = nn.Conv2d(
            hidden_channels, out_channels, kernel_size=1
            )

    def forward(self, x):
        """Return the reconstruction decoded from the encoding `x`."""
        out = x
        for layer in self.expanding_layers:
            out = layer(out)

        out = self.set_output_channels(out)

        return out



class AutoencoderCNN(nn.Module):
    """Convolutional autoencoder (Encoder + Decoder) with training utilities.

    The model is trained to reconstruct its input under an MSE loss. Loaders
    passed to the training/evaluation helpers must yield plain image batches
    (tensors), not `(image, label)` pairs.

    Args:
        in_channels: channels of the input images.
        out_channels: channels of the reconstruction.
        hidden_channels: base channel width passed to the encoder and decoder.
        depth: number of down-/up-sampling blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=64, depth=4):
        super(AutoencoderCNN, self).__init__()

        self.depth = depth
        # Directory used for checkpoints and progress plots.
        self.checkpoint_dir = CHECKPOINTS_PATH

        # Encoder
        self.encoder = Encoder(in_channels, hidden_channels, depth)
        # Decoder 
        self.decoder = Decoder(hidden_channels, out_channels, depth)

    def forward(self, x):
        """Encode and reconstruct `x`.

        Returns:
            dict with keys 'out' (reconstruction) and 'encoding' (bottleneck).
        """
        encoding = self.encoder(x)
        out = self.decoder(encoding)

        return {'out': out, 'encoding': encoding}

    def train_epoch(self, loader, optimizer, epoch_idx, device, max_iters=None):
        """Run one training epoch and return the mean per-batch MSE loss.

        Args:
            loader: iterable of image batches (must support `len`).
            optimizer: optimizer over this model's parameters.
            epoch_idx: epoch number, used only for the progress-bar label.
            device: device the batches are moved to.
            max_iters: optional cap on the number of batches processed.

        Returns:
            Average loss over the batches actually processed (NaN if none were).
        """
        self.train()
        running_loss = 0.0
        batches_done = 0

        # The bar total can never exceed what the loader can deliver.
        if max_iters is None:
            train_len = len(loader)
        else:
            train_len = min(max_iters, len(loader))

        progress_bar = tqdm(
            loader, total=train_len, desc=f'Epoch {epoch_idx}'
            )

        for batch_idx, data in enumerate(progress_bar):
            if max_iters is not None and batch_idx >= max_iters:
                break

            data = data.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            output = self(data)['out']
            loss = nn.functional.mse_loss(output, data)
            loss.backward()
            optimizer.step()

            # log statistics
            running_loss += loss.item()
            batches_done += 1
            avg_loss = running_loss / batches_done
            progress_bar.set_postfix(loss=avg_loss, stage="train")

        # Divide by the batches actually seen: dividing by max_iters
        # under-reports the loss whenever the loader is shorter than the cap.
        if batches_done == 0:
            return float('nan')
        return running_loss / batches_done

    def evaluate(self, loader, device, epoch_idx=None, should_print=True):
        """Compute the mean per-batch MSE reconstruction loss without gradient tracking.

        Uses the same loss as `train_epoch`, so training and validation
        numbers are directly comparable.

        Args:
            loader: iterable of image batches (must support `len`).
            device: device the batches are moved to.
            epoch_idx: optional epoch number for the progress-bar label.
            should_print: print the average loss when done.

        Returns:
            `(avg_loss, None)`. The second slot is kept so callers that unpack
            two values keep working; accuracy is not defined for
            reconstruction, so it is always None.
        """
        self.eval()
        test_loss = 0.0
        batches_done = 0

        if epoch_idx is None:
            description = 'Evaluating'
        else:
            description = f'Validation epoch {epoch_idx}'

        progress_bar = tqdm(loader, total=len(loader), desc=description)

        with torch.no_grad():
            for data in progress_bar:
                data = data.to(device)

                output = self(data)['out']

                loss = nn.functional.mse_loss(output, data)
                test_loss += loss.item()
                batches_done += 1
                avg_loss = test_loss / batches_done

                progress_bar.set_postfix(loss=avg_loss, stage="validation")

        # Each loss.item() is already a per-batch mean, so average over the
        # number of batches (not the number of samples).
        test_loss = test_loss / batches_done if batches_done else float('nan')

        if should_print:
            print(f'\nAvg. loss: {test_loss:.4f}')

        return test_loss, None

    def save_checkpoint(self, epoch=None, checkpoint_name=None):
        """Save the model and epoch index to `self.checkpoint_dir`.

        Args:
            epoch: epoch index stored (as a string) alongside the model.
            checkpoint_name: file name; defaults to 'latest_checkpoint.pt'.
        """
        if checkpoint_name is None:
            checkpoint_name = 'latest_checkpoint.pt'
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        save_path = os.path.join(self.checkpoint_dir, checkpoint_name)
        # NOTE(review): this pickles the whole module object, so loading needs
        # the same class definitions to be available; saving state_dict() is
        # the more portable option. Kept as-is because existing loading code
        # reads the 'model' key.
        torch.save({'model': self, 'epoch': str(epoch)}, save_path)

    def plot_progress(
        self, 
        loader, 
        num_batches, 
        save_dir, 
        device, 
        show_plot=False, 
        normalized=True
        ):
        """Save input/reconstruction comparison grids for the first batches of `loader`.

        For each of the first `num_batches` batches a PNG `batch_<i>.png` is
        written to `save_dir` with the inputs in the top row and the
        reconstructions below. The model is left in eval mode.

        Args:
            loader: iterable of image batches.
            num_batches: number of batches to plot.
            save_dir: output directory (created if missing).
            device: device used for the forward pass.
            show_plot: also display each figure.
            normalized: passed to `show_image` so it un-normalizes before plotting.
        """
        self.eval()
        os.makedirs(save_dir, exist_ok=True)

        # Inference only: without no_grad the grid would carry autograd
        # history, which wastes memory and cannot be converted for plotting.
        with torch.no_grad():
            for i, data in enumerate(loader):
                if i == num_batches:
                    break 
                data = data.to(device)
                pred = self(data)['out']

                # Use the real batch size so a smaller final batch still
                # yields one row of inputs above one row of outputs.
                n_imgs = data.shape[0]
                in_out_batch = torch.cat([data, pred], dim=0)
                grid_img = torchvision.utils.make_grid(
                    in_out_batch, nrow=n_imgs
                    )

                save_path = os.path.join(save_dir, f'batch_{i}.png')

                figsize = (4 * n_imgs, 10)
                show_image(
                    grid_img.cpu(), 
                    normalized=normalized, 
                    figsize=figsize, 
                    save_path=save_path, 
                    show_plot=show_plot
                    )

    def fit(self, loaders, optimizer, epochs, device, max_iters=None):
        """Train for `epochs` epochs with optional validation, checkpoints and progress plots.

        After every epoch the latest checkpoint is overwritten and
        reconstruction plots for 5 training batches are written to
        `<checkpoint_dir>/plots/progress_epoch_<epoch>`. Loss histories are
        available as `self.latest_train_losses` / `self.latest_val_losses`.

        Args:
            loaders: `(train_loader, val_loader)`; `val_loader` may be None.
            optimizer: optimizer over this model's parameters.
            epochs: number of epochs to run.
            device: device used for training and evaluation.
            max_iters: optional per-epoch cap on training batches.
        """
        train_loader, val_loader = loaders

        # Bound to the instance up front so the history collected so far
        # survives an interrupted run.
        train_losses, val_losses = [], []
        self.latest_train_losses = train_losses
        self.latest_val_losses = val_losses

        for epoch in range(epochs):
            # Training epoch
            running_loss = self.train_epoch(
                loader=train_loader, 
                optimizer=optimizer, 
                epoch_idx=epoch,
                device=device,
                max_iters=max_iters
                )
            train_losses.append(running_loss)

            # Validation
            if val_loader is not None:
                val_loss, _ = self.evaluate(
                    val_loader, device, should_print=False, epoch_idx=epoch
                    )
                val_losses.append(val_loss)

            # Save checkpoint
            self.save_checkpoint(epoch=epoch)

            # Plot current output images (kept next to the checkpoints)
            progress_plot_dir = os.path.join(
                self.checkpoint_dir, 'plots', f'progress_epoch_{epoch}'
                )
            self.plot_progress(train_loader, 5, progress_plot_dir, device)


# === DATASET ===

In [3]:
class UnNormalize(object):
  """Inverse of a per-channel `Normalize(mean, std)`: computes `t * std + mean`."""

  def __init__(self, mean, std):
    # Per-channel statistics that were used for the forward normalization.
    self.mean = mean
    self.std = std

  def __call__(self, tensor):
    """
    Args:
      tensor (Tensor): Normalized image of size (C, H, W).
    Returns:
      Tensor: A new, un-normalized image; the input tensor is left untouched.
    """
    # Work on a copy: the in-place version silently modified the caller's
    # tensor, so applying it twice to the same image gave wrong results.
    restored = tensor.clone()
    for channel, m, s in zip(restored, self.mean, self.std):
      # Inverse of the normalize step t.sub_(m).div_(s)
      channel.mul_(s).add_(m)
    return restored


def show_image(img, figsize=(9, 6), normalized=False, save_path=None, show_plot=True):
    """Plot a (C, H, W) image tensor, optionally saving and/or displaying it.

    Args:
        img: image tensor laid out as (C, H, W).
        figsize: matplotlib figure size.
        normalized: if True, the per-channel normalization with MEANS/STDS
            below is undone before plotting.
        save_path: if given, the figure is saved to this path.
        show_plot: if True, the figure is displayed.
    """
    # Same statistics as used by the dataset's default Normalize transform.
    MEANS = (0.485, 0.456, 0.406)
    STDS = (0.229, 0.224, 0.225)

    # Detached CPU copy: the caller's tensor is never modified, and tensors
    # that live on the GPU or track gradients can be plotted as well.
    img_to_show = img.detach().cpu().clone()
    if normalized:
        img_to_show = UnNormalize(MEANS, STDS)(img_to_show)

    # matplotlib expects channels last.
    img_to_show = img_to_show.permute(1, 2, 0)
    if img_to_show.is_floating_point():
        # imshow clips float RGB data to [0, 1] anyway; clamping here gives the
        # same picture without the clipping warning.
        img_to_show = img_to_show.clamp(0.0, 1.0)

    fig = plt.figure(figsize=figsize)
    plt.imshow(img_to_show)
    if save_path is not None:
        plt.savefig(save_path)
    if show_plot:
        plt.show()
    # Close explicitly so figures do not pile up when called in a loop.
    plt.close(fig)

In [4]:
class ECOLDataset(Dataset):
    """Dataset over all files located directly inside a single image directory.

    Every regular file in `root_dir` is treated as an image (subdirectories
    are ignored). Items are RGB images passed through `transform`.

    Args:
        root_dir: directory containing the images.
        transform: callable applied to each RGB image array. Defaults to
            ToTensor -> Resize(img_size) -> Normalize(MEANS, STDS).
        img_size: target (height, width) used by the default transform.
    """

    def __init__(self, root_dir, transform=None, img_size=(512, 512)):
        super(ECOLDataset, self).__init__()

        self.root_dir = os.path.abspath(root_dir)
        # Normalization statistics (these values match the widely used
        # ImageNet channel statistics).
        self.MEANS = (0.485, 0.456, 0.406)
        self.STDS = (0.229, 0.224, 0.225)
        self.img_size = img_size

        # Index the directory once. Re-scanning it on every __getitem__ costs
        # O(n) per sample and relies on the scan order staying stable between
        # calls; a sorted, cached list makes lookups O(1) and deterministic.
        with os.scandir(self.root_dir) as entries:
            self.img_paths = sorted(
                os.path.join(self.root_dir, entry.name)
                for entry in entries
                if entry.is_file()
                )

        self.len = len(self.img_paths)

        # Torchvision transformation to be applied to each image
        if transform is not None:
            self.transform = transform 
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize(self.img_size),
                transforms.Normalize(mean=self.MEANS, std=self.STDS),
                ])

    def __len__(self):
        # Has to exist in order to inherit the parent class properly!
        return self.len

    def __getitem__(self, idx):
        """Load image `idx`, convert BGR -> RGB and apply the transform.

        Raises:
            IndexError: if `idx` is out of range.
            OSError: if the file cannot be decoded as an image.
        """
        img_path = self.img_paths[idx]

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None; fail with the
            # offending path instead of a cryptic cvtColor assertion.
            raise OSError(f'Could not read image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return self.transform(image)



def load_dataset_ECOL(
    root_dir, 
    img_size, 
    batch_size, 
    num_workers=0, 
    shuffle=True,
    transform=None,
    ):
    """Build a `DataLoader` over an `ECOLDataset` rooted at `root_dir`.

    Args:
        root_dir: image directory handed to the dataset.
        img_size: target image size handed to the dataset.
        batch_size: number of samples per batch.
        num_workers: number of loader worker processes.
        shuffle: whether the loader shuffles the samples.
        transform: optional transform handed to the dataset.

    Returns:
        The configured `DataLoader`.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
    }
    return DataLoader(
        ECOLDataset(root_dir, transform, img_size), **loader_options
    )

In [8]:
# --- Model hyper-parameters ---
IN_CHANNELS = 3        # channels of the input images
OUT_CHANNELS = 3       # channels of the reconstruction
HIDDEN_CHANNELS = 2    # base channel width of the autoencoder
DEPTH = 4              # number of down-/up-sampling blocks
# --- Training hyper-parameters ---
EPOCHS = 10
LEARNING_RATE = 0.001
MAX_ITERS = 1000       # passed to fit() as the per-epoch cap on training batches

# Positional arguments for AutoencoderCNN, in the order of its signature:
# (in_channels, out_channels, hidden_channels, depth).
AE_PARAMS = (IN_CHANNELS, OUT_CHANNELS, HIDDEN_CHANNELS, DEPTH)

# Keyword arguments for load_dataset_ECOL.
LOADER_PARAMS = {
    'root_dir': DATASET_PATH, 'img_size': (256, 256), 'batch_size': 4
    }

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda


In [9]:
# Build the training DataLoader from LOADER_PARAMS
dataloader = load_dataset_ECOL(**LOADER_PARAMS)

# Instantiate the autoencoder and move it to the selected device
model = AutoencoderCNN(*AE_PARAMS)
model = model.to(DEVICE)

# Adam optimizer over all model parameters
optimizer = torch.optim.Adam(model.parameters(), LEARNING_RATE)

In [None]:
# Train for EPOCHS epochs with at most MAX_ITERS batches each. The second list
# element (validation loader) is None, so fit() skips validation.
model.fit([dataloader, None], optimizer, EPOCHS, DEVICE, MAX_ITERS)

Epoch 0: 100%|██████████| 1000/1000 [35:14<00:00,  2.11s/it, loss=0.219, stage=train]
Epoch 1: 100%|██████████| 1000/1000 [31:58<00:00,  1.92s/it, loss=0.171, stage=train]
Epoch 2: 100%|██████████| 1000/1000 [28:18<00:00,  1.70s/it, loss=0.164, stage=train]
Epoch 3:  76%|███████▌  | 760/1000 [20:17<06:51,  1.71s/it, loss=0.155, stage=train]

# === UNUSED STUFF === 

In [91]:
# Restore the most recent checkpoint written by save_checkpoint().
# map_location lets a checkpoint saved on the GPU be opened on a CPU-only
# runtime (and vice versa).
# NOTE(review): the checkpoint holds a pickled model object. Unpickling runs
# arbitrary code, so only load files you created yourself. Recent PyTorch
# releases may additionally need weights_only=False here; confirm against the
# installed version.
loaded = torch.load(
    os.path.join(CHECKPOINTS_PATH, 'latest_checkpoint.pt'),
    map_location=DEVICE,
)
loaded_model = loaded['model'].to(DEVICE)


In [110]:
# Inference on one batch: switch to eval mode so dropout is disabled and batch
# norm uses its running statistics, and skip autograd bookkeeping.
loaded_model.eval()
with torch.no_grad():
    pred = loaded_model(next(iter(dataloader)).to(DEVICE))

In [111]:
# Shape of the bottleneck encoding produced for the batch above.
pred['encoding'].shape

torch.Size([4, 10, 16, 16])

In [None]:
class ECOLDatasetRaw(Dataset):
    """Dataset over image files found in the direct subdirectories of `root_dir`.

    Files are selected by extension (case-insensitive). `.nef` RAW files are
    decoded with rawpy; everything else is read with OpenCV. Items are RGB
    images passed through `transform`.

    Args:
        root_dir: directory whose immediate subdirectories hold the images.
        transform: callable applied to each RGB image array. Defaults to
            ToTensor -> Resize(img_size) -> Normalize(MEANS, STDS).
        extensions: file extensions (with leading dot) to include; a single
            string is accepted as well.
        img_size: target (height, width) used by the default transform.
    """

    def __init__(
        self, 
        root_dir, 
        transform=None, 
        extensions=('.jpg', '.nef'),
        img_size=(512, 512)
        ):
        super(ECOLDatasetRaw, self).__init__()

        self.root_dir = os.path.abspath(root_dir)
        # A bare string would otherwise be iterated character by character.
        if isinstance(extensions, str):
            extensions = (extensions,)
        # Private, lower-cased copy: the caller's sequence is never shared and
        # matching ignores case ('.NEF' == '.nef').
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.MEANS = (0.485, 0.456, 0.406)
        self.STDS = (0.229, 0.224, 0.225)
        self.img_size = img_size

        # Collect matching files one directory level below the root, in
        # sorted order so indices are stable across runs.
        self.img_paths = []
        for subdir in sorted(os.listdir(self.root_dir)):
            subdir_path = os.path.join(self.root_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue
            for img_name in sorted(os.listdir(subdir_path)):
                _, extension = os.path.splitext(img_name)
                if extension.lower() in self.extensions:
                    self.img_paths.append(os.path.join(subdir_path, img_name))

        self.len = len(self.img_paths)

        if transform is not None:
            self.transform = transform 
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize(self.img_size),
                transforms.Normalize(mean=self.MEANS, std=self.STDS),
                ])

    def __len__(self):
        # Has to exist in order to inherit the parent class properly!
        return self.len

    def __getitem__(self, idx):
        """Load image `idx` as RGB and apply the transform.

        Raises:
            IndexError: if `idx` is out of range.
            OSError: if a non-RAW file cannot be decoded by OpenCV.
        """
        img_path = self.img_paths[idx]
        _, extension = os.path.splitext(img_path)

        if extension.lower() == '.nef':
            # cv2.imread does not decode the RAW sensor data; develop the
            # file with rawpy instead (postprocess() already returns RGB).
            with rawpy.imread(img_path) as raw:
                image = raw.postprocess()
        else:
            image = cv2.imread(img_path)
            if image is None:
                raise OSError(f'Could not read image: {img_path}')
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return self.transform(image)



def load_dataset_ECOL_raw(
    root_dir, 
    img_size, 
    batch_size, 
    num_workers=0, 
    shuffle=True,
    transform=None,
    extensions=['.jpg']
    ):
    """Build a `DataLoader` over an `ECOLDatasetRaw` rooted at `root_dir`.

    Args:
        root_dir: root directory handed to the dataset.
        img_size: target image size handed to the dataset.
        batch_size: number of samples per batch.
        num_workers: number of loader worker processes.
        shuffle: whether the loader shuffles the samples.
        transform: optional transform handed to the dataset.
        extensions: file extensions the dataset should include.

    Returns:
        The configured `DataLoader`.
    """
    dataset_options = {
        'root_dir': root_dir,
        'img_size': img_size,
        'transform': transform,
        'extensions': extensions,
    }
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
    }
    return DataLoader(ECOLDatasetRaw(**dataset_options), **loader_options)