In [25]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
from transformers import FlaxCLIPModel
# from src.step_utils import CLIPProcessor
import jax
from flax.training import common_utils, train_state

# CSV creation (with image names)

In [26]:
import os
import csv
import pandas as pd

# Directory that holds the Notre-Dame phototourism images to be indexed.
path = './putting-nerf-on-a-diet/data/phototourism/notre/images'

# Collect the name of every file found under `path` (sub-directories included).
# os.walk yields entries in an arbitrary, filesystem-dependent order, so both
# the directory list and the file list are sorted to make the CSV reproducible.
data = {
    'image_name': []
}
for root, dirs, files in os.walk(path):
    dirs.sort()  # sorting in place steers the order in which os.walk descends
    for filename in sorted(files):
        data['image_name'].append(str(filename))

df = pd.DataFrame(data)

# The row index is still written as the first column, so 'image_name' stays the
# second column (position 1). Giving the index an explicit header stops pandas
# from inventing 'Unnamed: 0' columns when the file is read back.
# NOTE(review): the following cells read
# 'putting-nerf-on-a-diet/data/phototourism/notre/phototourism_notre.csv',
# while this cell writes into the current working directory - confirm how the
# file is expected to get there.
df.to_csv('phototourism_notre.csv', index_label='index')

In [29]:
# Show the first rows of the image-name index that the dataset below reads.
notre_index_csv = 'putting-nerf-on-a-diet/data/phototourism/notre/phototourism_notre.csv'
pd.read_csv(notre_index_csv).head()

Unnamed: 0.1,Unnamed: 0,image_name
0,0,omaromar_11727516.jpg
1,1,jtriefen_370744178.rd.jpg
2,2,m500_2346149454.rd.jpg
3,3,algreen_2068599265.jpg
4,4,13284673@N00_416967273.jpg


In [30]:
class NeRF_Dataset(Dataset):
    """Image dataset backed by a CSV file that lists image file names.

    Every row of the CSV names one image file located in ``root_dir``; an item
    of the dataset is that image opened with PIL and converted to RGB.
    """

    # Header of the CSV column that stores the image file names.
    NAME_COLUMN = 'image_name'

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file listing the image file names.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on the loaded PIL image.
        """
        self.img_names = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV, i.e. the number of images."""
        return len(self.img_names)

    def _image_path(self, idx):
        """Return the path of the image listed in row ``idx`` of the CSV.

        The file name is looked up by column header so that extra index columns
        written by ``DataFrame.to_csv`` (the 'Unnamed: 0' columns) cannot shift
        it to a different position. CSV files without an 'image_name' header
        fall back to the original positional lookup (second column).
        """
        table = self.img_names
        if self.NAME_COLUMN in table.columns:
            file_name = table[self.NAME_COLUMN].iloc[idx]
        else:
            file_name = table.iloc[idx, 1]
        return os.path.join(self.root_dir, file_name)

    def __getitem__(self, idx):
        """Load the image for row ``idx``.

        Returns:
            ``transform(image)`` when a transform was supplied, otherwise the
            dict ``{'image': image}`` holding the untransformed RGB PIL image.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = Image.open(self._image_path(idx)).convert('RGB')

        sample = {'image': image}

        if self.transform:
            # The transform receives the bare image, and its result replaces the dict.
            sample = self.transform(sample['image'])

        return sample

# Every image is resized to one common shape so that the DataLoader can stack
# the samples of a batch into a single tensor.
img_size = [300, 300]
_transform_steps = [
    transforms.Resize(img_size),
    transforms.ToTensor(),
]
data_transform = transforms.Compose(_transform_steps)

dataset = NeRF_Dataset(
    csv_file='putting-nerf-on-a-diet/data/phototourism/notre/phototourism_notre.csv',
    root_dir='putting-nerf-on-a-diet/data/phototourism/notre/images',
    transform=data_transform,
)

In [31]:
# Shuffled mini-batches of 16 images, loaded with num_workers=8.
_loader_options = dict(batch_size=16, shuffle=True, num_workers=8)
dataloader = DataLoader(dataset, **_loader_options)


In [33]:
# Pretrained CLIP checkpoint loaded through its Flax implementation with a
# float16 dtype requested.
_clip_checkpoint = "openai/clip-vit-base-patch32"
CLIP_model = FlaxCLIPModel.from_pretrained(_clip_checkpoint, dtype=np.float16)

# Accumulator filled by embed_images: the images it processed and, in the same
# order, the embedding computed for each of them.
state = dict(imgs=[], embeded_imgs=[])

# Per-channel RGB mean / std used to normalise CLIP inputs, shaped so that they
# broadcast over a [B, 3, H, W] batch.
_CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
_CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)


def _clip_preprocess(image):
    '''
        jax-based preprocessing for CLIP

        image  [B, 3, H, W]: batch image
        return [B, 3, 224, 224]: pre-processed image for CLIP
    '''
    batch, channels = image.shape[:2]
    # The resize ignores the aspect ratio, so non-square inputs get distorted.
    image = jax.image.resize(image, (batch, channels, 224, 224), 'bicubic')
    return (image - _CLIP_MEAN.astype(image.dtype)) / _CLIP_STD.astype(image.dtype)


def embed_images(state, images):
    """Embed a batch of images with CLIP and record the results in ``state``.

    Relies on the module-level ``CLIP_model``.

    Args:
        state (dict): Accumulator with the list entries 'imgs' and
            'embeded_imgs'. It is extended in place.
        images: Iterable of channel-first (C, H, W) arrays. Pixel values may be
            in [0, 1] (e.g. the output of ``transforms.ToTensor``) or in
            [0, 255]; see the rescaling note below.

    Returns:
        dict: The same ``state`` object. For every input image the H x W x C
        image is appended to ``state['imgs']`` and its L2-normalised CLIP
        embedding to ``state['embeded_imgs']``.
    """
    for img in images:
        img = img.transpose(1, 2, 0)  # changing shape from CxHxW to HxWxC
        state['imgs'].append(img)

        # Keep every 4th pixel along both spatial axes before handing the image to CLIP.
        pixels = img[::4, ::4]

        # Bring the pixels to [0, 1]. Data that is already in that range must
        # not be divided by 255 a second time, otherwise CLIP would see an
        # almost black image. Values above 1 are taken as a sign of 0-255 data.
        # The division is not in place: `pixels` is a view of the stored image,
        # and an in-place division would also fail for integer arrays.
        if pixels.size and pixels.max() > 1.0:
            pixels = pixels / 255.

        clip_input = _clip_preprocess(np.expand_dims(pixels, 0).transpose(0, 3, 1, 2))
        target_emb = CLIP_model.get_image_features(pixel_values=clip_input)
        target_emb = target_emb / np.linalg.norm(target_emb, axis=-1, keepdims=True)
        state['embeded_imgs'].append(target_emb)

    return state

# p_embed_images = jax.pmap(embed_images)
# j_embed_images = jax.jit(embed_images)


    

In [None]:
%%time
# Walk over the whole dataset batch by batch; each batch is converted to a
# NumPy array and its embeddings are accumulated in `state`.
for images in dataloader:
    batch_array = np.array(images)
    embed_images(state, batch_array)