### This notebook shows how to visualize image embeddings with TensorBoard. The embeddings are computed by a Vision Transformer Masked Autoencoder (ViT-MAE) trained on ImageNet at 64 x 64 resolution. TensorBoard's Embedding Projector can display them with PCA, t-SNE or UMAP projections. The notebook also shows how to attach each image's thumbnail and label to its point in TensorBoard.

In this notebook we compute embeddings for the Weather Image Recognition dataset, available <a href="https://www.kaggle.com/datasets/fceb22ab5e1d5288200c0f3016ccd626276983ca1fe8705ae2c32f7064d719de">here</a> under the CC0 license.




# Imports 

In [1]:
import csv
import cv2


import tensorflow as tf
from collections import defaultdict
from PIL import Image
import torch
import torchvision.datasets as datasets
from torchvision import transforms
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
import numpy as np
import pickle
from tqdm import tqdm
import torch.nn.functional as F
import random

from models_mae import MaskedAutoencoderViT

from functools import partial
from torch import nn

from utils.projector_utils import DATA, images_to_sprite,extractor
from utils.load_models import load_mae
from utils.shape_filtering import shape_filter
import os 


import datetime


# Per-channel normalization statistics passed to DATA (the standard ImageNet mean / std values).
inet_mean, inet_std = (
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from functools import partial
import torch.nn as nn
import timm.optim.optim_factory as optim_factory
from models_mae import MaskedAutoencoderViT

# Create the TensorBoard log folder and extract the dataset


In [2]:

# Paths are built with os.path.join so the notebook also works outside Windows
# (hard-coded '\\' separators only form valid paths on Windows).
current = os.getcwd()
DATA_PATH = os.path.join(current, 'dataset')  # extracted images, one sub-folder per class
LOG_DIR = os.path.join(current, 'logs')       # TensorBoard log dir: sprite, TSV files, projector config

# Creates the folder (and any missing parents); does nothing if it already exists.
os.makedirs(LOG_DIR, exist_ok=True)

if not os.path.exists(DATA_PATH):
    from zipfile import ZipFile
    # "archive.zip" is the file Kaggle provides when the weather image recognition
    # dataset is downloaded; it is extracted into the current working directory.
    # NOTE(review): assumes the archive's top-level folder is named "dataset" — confirm.
    with ZipFile("archive.zip", 'r') as zObject:
        zObject.extractall()

### Check whether CUDA is available (a GPU makes embedding extraction much faster)

In [3]:
import torch

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

cuda


# Load the model and the data

In [4]:
# Load the ViT-MAE model pretrained on ImageNet at 64 x 64 and move it to `device`.
# The model is only used for inference here, so switch it to eval mode to disable any
# training-only behaviour (dropout, normalization-statistics updates) if present;
# this is a no-op if load_mae() already returns the model in eval mode.
# eval() returns the module itself, so chaining keeps this an assignment (no repr output).
model = load_mae().to(device).eval()

Model loaded.


In [5]:
# Build imgs_list: one [image_path, class_index] pair per .jpg file.
# Each sub-folder of DATA_PATH is one weather class; its position in the sorted
# folder list is used as the integer label.
import glob

data_path_length = len(DATA_PATH)  # not used below; kept in case later cells rely on it

# glob() returns entries in filesystem order, which differs between platforms;
# sorting makes the folder -> label mapping reproducible. Entries that are not
# directories (e.g. stray files next to the class folders) are skipped.
classes_list = sorted(
    path for path in glob.glob(os.path.join(DATA_PATH, '*')) if os.path.isdir(path)
)

classes = {}    # str(class_index) -> class folder name (e.g. without the leading path separator)
imgs_list = []  # [image_path, class_index] pairs, appended class by class
for k, class_dir in enumerate(classes_list):
    classes[str(k)] = os.path.basename(class_dir)
    for img_path in sorted(glob.glob(os.path.join(class_dir, '*.jpg'))):
        imgs_list.append([img_path, k])

In [6]:
# Visualize a random subset of images to keep the TensorBoard projector usable.
# NOTE(review): the original note says TensorBoard cannot display more than ~1500
# images and that imgs_list[:2662] covers the first 3 classes — both depend on the
# TensorBoard version and per-class image counts; confirm.
N_CANDIDATES = 2662  # sample only from imgs_list[:N_CANDIDATES] (imgs_list is grouped class by class)
N_SAMPLES = 1500     # number of images to show in the projector
SEED = 0             # fixed seed so every run draws the same subset

rng = random.Random(SEED)  # private generator: does not touch the global `random` state
candidates = imgs_list[:N_CANDIDATES]
# min() avoids the ValueError random.sample raises when there are fewer candidates than requested.
imgs_list = rng.sample(candidates, min(N_SAMPLES, len(candidates)))

In [7]:
# Drop images with the wrong number of channels before building the dataset.
imgs_list = shape_filter(imgs_list)

# Dataset wrapper over the remaining images: img_size=64, no extra transform, and
# the inet_mean / inet_std statistics (presumably used to normalize each channel).
weather = DATA(
    imgs_list,
    transform=None,
    mean=inet_mean,
    std=inet_std,
    img_size=64,
)

# Helper used below to fetch, for a dataset index, the image, its embedding or its label.
e = extractor(model, weather)

100%|██████████| 1500/1500 [00:09<00:00, 164.44it/s]

Number of bad images :  39  /  1500





# Generate embeddings and sprites (this may take a while, depending on your GPU)

In [8]:
# Run every (filtered) dataset item through the extractor once, collecting:
#   images_pil        - the image as a numpy array (used to build the sprite)
#   images_embeddings - the model embedding as a numpy array (one row of feature_vecs.tsv each)
#   labels            - the item's label (one line of metadata.tsv each); presumably the
#                       integer class index stored in imgs_list — confirm
images_pil = []
images_embeddings = []
labels = []
# Inference only: no_grad() stops autograd from recording a graph for every forward
# pass, which saves memory and time. The computed values are unchanged.
with torch.no_grad():
    for x in tqdm(range(len(weather))):
        img_pil = e.get_img(x)
        img_embedding = e.get_embed(x)
        images_embeddings.append(img_embedding.cpu().detach().numpy())
        images_pil.append(np.array(img_pil))
        labels.append(e.get_label(x))

100%|██████████| 1461/1461 [00:32<00:00, 45.15it/s]


In [9]:
# Save the sprite sheet (presumably all thumbnails tiled into one image) that the
# TensorBoard projector uses to draw each point as its image.
# NOTE(review): cv2.imwrite expects BGR channel order. If get_img returns RGB PIL images
# and images_to_sprite does not convert them, red and blue are swapped in the sprite — confirm.
sprite = images_to_sprite(np.array(images_pil))
sprite_path = os.path.join(LOG_DIR, 'sprite.jpg')
# cv2.imwrite reports failure by returning False instead of raising, so check the result.
if not cv2.imwrite(sprite_path, sprite):
    raise IOError(f"could not write sprite image to {sprite_path}")

True

In [10]:
# Tensor file for the projector: one row per image, one tab-separated value per
# embedding dimension, in the same order as metadata.tsv.
# newline='' is required by the csv module (without it Windows writes "\r\r\n");
# lineterminator='\n' gives plain line endings on every platform.
with open(os.path.join(LOG_DIR, 'feature_vecs.tsv'), 'w', newline='') as fw:
    csv_writer = csv.writer(fw, delimiter='\t', lineterminator='\n')
    # ravel() guarantees one value per column even if an embedding carries a leading
    # batch axis (e.g. shape (1, D)); it is a no-op for 1-D vectors.
    csv_writer.writerows(np.ravel(vec) for vec in images_embeddings)

In [11]:
# Metadata file for the projector: one label per line, in the same order as the rows
# of feature_vecs.tsv (both lists are filled by the same loop over the dataset).
# To show class names instead of indices, write classes[str(label)] instead; this only
# works if get_label() returns the integer index used as a key in `classes`.
with open(os.path.join(LOG_DIR, 'metadata.tsv'), 'w') as file:
    for label in labels:
        file.write(f"{label}\n")

In [12]:
# Number of thumbnails per row in the sprite: ceil(sqrt(N)) for N images
# (presumably images_to_sprite lays them out in a square grid).
# len() avoids stacking every image into one large array just to count them.
int(np.ceil(np.sqrt(len(images_pil))))

39

# Write the projector config file for TensorBoard, then visualize the embeddings


In [13]:
# TensorBoard projector config, written next to the TSV files and the sprite.
# File names are resolved relative to the log directory. The two single_image_dim
# entries give the thumbnail width and height (64 x 64, matching img_size=64 used
# for the dataset).
PROJECTOR_CONFIG = """embeddings {
  metadata_path: "metadata.tsv"
  tensor_path: "feature_vecs.tsv"
  sprite {
    image_path: "sprite.jpg"
    single_image_dim: 64
    single_image_dim: 64
  }
}"""

# `with` closes the file even if the write fails.
with open(os.path.join(LOG_DIR, 'projector_config.pbtxt'), "w") as text_file:
    text_file.write(PROJECTOR_CONFIG)

In [14]:
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
    pass

%load_ext tensorboard

In [15]:
# %reload_ext tensorboard  # Uncomment to reload the TensorBoard extension, e.g. after experimenting, to display a different TensorBoard instance.

In [16]:
%tensorboard --logdir ./logs #Go to projector


In [17]:
# On Windows, you might not be able to kill the TensorBoard process. In that case, clear its state by deleting this folder: C:\Users\USER\AppData\Local\Temp\.tensorboard-info