### Extract Embeddings

Based on https://github.com/rom1504/clip-retrieval

First, install the dependencies:

`pip install clip-retrieval`

`pip install git+https://github.com/openai/CLIP.git`

TODO:
- Labels fed to Kairos should be the species/insect type, not "noisy" vs. "clean".
- Save the list of indices of the noisy images in iNat.

In [1]:
import os
import torch
import clip
from PIL import Image
from tqdm import tqdm
import numpy as np
from datasets import load_dataset

import pandas as pd
from utils.label_mappings import *

# Output directory for the embedding, label and id files saved by this notebook;
# created up front (no error if it already exists).
OUT_DIR = 'data/embs'
os.makedirs(name=OUT_DIR, exist_ok=True)

In [2]:
# Load the CLIP ViT-B/32 model and its preprocessing transform,
# placed on the GPU when CUDA is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

#### iNaturalist Embeddings
We need embeddings for all 36,355 images (rows) in the iNat dataset (3.3 GB) so that Kairos can separate the insect images from the rest (the noisy images).

In [3]:
# Load the iNat training split (~36k rows, ~3.3 GB) and tag every row with its
# original position as an 'id' column so rows can be traced back later.
iNat36 = load_dataset("sxj1215/inaturalist", split='train')
ids = [*range(len(iNat36))]
# add_column is not idempotent: re-run this whole cell (which reloads the
# dataset) rather than this line on its own.
iNat36 = iNat36.add_column("id", ids)
iNat36

Dataset({
    features: ['messages', 'images', 'id'],
    num_rows: 36355
})

In [4]:
# One row per dataset example: its raw chat messages and the 'id' added above.
iNat36_label_df = pd.DataFrame(dict(
    messages=iNat36['messages'],
    id=iNat36['id'],
))

def get_iNat_label(messages):
    """Return the 'content' of the second message in an iNat example.

    Presumably this is the answer turn that names the species (it is stored
    in the 'species' column) - confirm against the dataset card.
    """
    second_message = messages[1]
    return second_message['content']
    
# Pull the species name out of each example's messages.
iNat36_label_df['species'] = iNat36_label_df['messages'].map(get_iNat_label)

def map_inat_to_clean_label(label, mapping=None, default='noise'):
    """Map an iNat species label to its clean-dataset class.

    label: species string extracted from the iNat messages.
    mapping: dict-like of iNat species -> clean class. Defaults to the
        module-level ``iNat_to_clean_map`` (presumably provided by
        ``from utils.label_mappings import *``), resolved at call time.
    default: value returned for species that have no clean counterpart.

    Returns the mapped class, or ``default`` ('noise') when the species is
    not in the mapping.
    """
    if mapping is None:
        mapping = iNat_to_clean_map
    return mapping.get(label, default)
        
# Translate each species to its clean class; unmapped species become 'noise'.
iNat36_label_df['clean_label'] = iNat36_label_df['species'].map(map_inat_to_clean_label)

In [5]:
# Row labels of every example mapped to 'noise'. The frame was built in dataset
# order, so these labels coincide with the dataset 'id' values.
noisy_idxs = iNat36_label_df.loc[iNat36_label_df['clean_label'].eq('noise')].index
np.save(os.path.join(OUT_DIR, "inat_noisy_indexes.npy"), noisy_idxs)

In [6]:
def generate_inat_embs(inat_split_ds, file_prefix, out_dir, label_df=None, emb_dim=512):
    '''
    Encode every image of an iNat split with the global CLIP `model` and save the results.

    inat_split_ds is a split (train, test) of the iNat dataset; every row needs an
    'images' list and the 'id' column added when the dataset was loaded

    file_prefix is a string to clearly label files and distinguish different groups of embeddings
    such as 'train_inat' or 'test_inat'

    out_dir is the directory the data will be saved to (created if it does not exist)

    label_df supplies the labels; it must have 'id' and 'clean_label' columns and
    defaults to the notebook-level iNat36_label_df

    emb_dim is the width of the zero placeholder used for unreadable rows
    (512, the embedding width of the ViT-B/32 model loaded above)

    Returns (embs, labels, row_ids): three lists with exactly one entry per row of
    inat_split_ds, in order. A row whose image cannot be read or preprocessed gets a
    zero embedding and the label 'noise', so the lists stay aligned with the dataset.
    Saves <file_prefix>_embeddings.npy, <file_prefix>_labels.npy and
    <file_prefix>_row_ids.txt into out_dir.
    '''
    if label_df is None:
        label_df = iNat36_label_df
    # Look labels up by the row's dataset 'id' rather than its position in this split,
    # so the function is also correct for subsets/splits of the dataset.
    id_to_label = dict(zip(label_df['id'], label_df['clean_label']))
    # Read the ids up front so a row's id is known even when reading that row fails.
    split_ids = list(inat_split_ds['id'])

    embs = []
    labels = []
    row_ids = []
    n_failed = 0

    for i in tqdm(range(len(inat_split_ds))):
        row_id = split_ids[i]
        # Only data loading/preprocessing is guarded. A corrupt row becomes a
        # placeholder, but model/device errors (e.g. CUDA OOM) still stop the run
        # instead of silently zero-filling every remaining row.
        try:
            img = inat_split_ds[i]["images"][0]
            img = preprocess(img).unsqueeze(0).to(device)
        except Exception as e:
            print(f"Skipping unreadable row {i} (id={row_id}): {e}")
            embs.append(np.zeros((1, emb_dim)))  # placeholder so entries line up later
            labels.append('noise')
            row_ids.append(row_id)
            n_failed += 1
            continue

        with torch.no_grad():
            feats = model.encode_image(img)
            feats /= feats.norm(dim=-1, keepdim=True)  # normalize for cosine similarity

        # Append all three only once the row has fully succeeded, so the lists never drift.
        embs.append(feats.cpu().numpy())
        # can use species as label for now, but need clean_label later for resnet
        labels.append(id_to_label.get(row_id, 'noise'))
        row_ids.append(row_id)

    print(f"Successfully processed {len(embs) - n_failed} examples ({n_failed} unreadable, zero-filled)")

    if embs:
        os.makedirs(out_dir, exist_ok=True)
        emb_matrix = np.vstack(embs)

        np.save(os.path.join(out_dir, f"{file_prefix}_embeddings.npy"), emb_matrix)
        np.save(os.path.join(out_dir, f"{file_prefix}_labels.npy"), labels)

        with open(os.path.join(out_dir, f"{file_prefix}_row_ids.txt"), "w") as f:
            f.write(str(row_ids))

        print(f"Success! Saved {emb_matrix.shape} matrix to {out_dir}")

    return embs, labels, row_ids

In [7]:
# Embed the whole iNat dataset (slow: one image per forward pass, ~9 it/s above).
inat_embeddings, inat_labels, inat_row_ids = generate_inat_embs(
    iNat36,
    'inat',
    OUT_DIR,
)

  2%|▏         | 865/36355 [01:38<1:07:19,  8.79it/s]


KeyboardInterrupt: 

#### Kaggle clean embeddings

In [None]:
# kairos_clean_data: stratified random sample of the clean images;
# test_clean_data: the rest of the clean images.
from utils.sample_clean_data import (
    kairos_clean_data,
    test_clean_data,
)

# Position of the class-name directory in a '/'-split image path,
# e.g. 'data/clean_insect_images/Ant/Ant_283.jpg' -> 'Ant'.
label_pos_in_path = 2

def generate_clean_embs(images_var, file_prefix, out_dir):
    '''
    Encode each clean image with the global CLIP `model` and save the results.

    images var is a list of filepaths and is defined in sample_clean_data.py,
    e.g. 'data/clean_insect_images/Ant/Ant_283.jpg'; the label is the path
    component at index `label_pos_in_path` after splitting on '/' ('Ant' here)

    file_prefix is a string to clearly label files and distinguish different groups of embeddings
    such as 'kairos_clean' or 'test_clean'

    out_dir is the directory the data will be saved to (created if it does not exist)

    Returns (embs, labels, filepaths): parallel lists with one entry per image that
    could be read. Unreadable images are reported and skipped.
    Saves <file_prefix>_embeddings.npy, <file_prefix>_labels.npy and
    <file_prefix>_filepaths.txt into out_dir.
    '''
    embs = []
    labels = []
    filepaths = []

    for path in tqdm(images_var): # path = data/clean_insect_images/Ant/Ant_283.jpg
        # Only label extraction and image loading/preprocessing are guarded. A bad
        # file is skipped before anything is appended (so the lists stay aligned),
        # while model/device errors (e.g. CUDA OOM) still surface.
        try:
            label = path.split('/')[label_pos_in_path]
            # The context manager closes the file handle that Image.open keeps open.
            with Image.open(path) as pil_image:
                image = preprocess(pil_image).unsqueeze(0).to(device)
        except Exception as e:
            print(f"Skipping corrupt image {path}: {e}")
            continue

        with torch.no_grad():
            features = model.encode_image(image)
            features /= features.norm(dim=-1, keepdim=True)    # normalize for cosine similarity

        embs.append(features.cpu().numpy())
        labels.append(label)
        filepaths.append(path)

    print(f"Successfully processed {len(embs)} examples")

    if embs:
        os.makedirs(out_dir, exist_ok=True)
        emb_matrix = np.vstack(embs)

        np.save(os.path.join(out_dir, f"{file_prefix}_embeddings.npy"), emb_matrix)
        np.save(os.path.join(out_dir, f"{file_prefix}_labels.npy"), labels)

        with open(os.path.join(out_dir, f"{file_prefix}_filepaths.txt"), "w") as f:
            f.write(str(filepaths))

        print(f"Success! Saved {emb_matrix.shape} matrix to {out_dir}")

    return embs, labels, filepaths

In [None]:
# Embed the stratified Kairos sample, then the remaining clean images.
kairos_clean_embeddings, kairos_clean_labels, kairos_clean_file_paths = generate_clean_embs(
    kairos_clean_data, 'kairos_clean', OUT_DIR
)
test_clean_embeddings, test_clean_labels, test_clean_file_paths = generate_clean_embs(
    test_clean_data, 'test_clean', OUT_DIR
)