In [None]:
# Use IPython's built-in completer instead of Jedi for tab completion.
%config Completer.use_jedi = False
# Re-import edited modules (e.g. the `fgvc` package imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from PIL import Image
import timm
import torch
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import csv
import math
import os
import csv
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import pickle

import hydra
import omegaconf
import pyrootutils

import matplotlib.pyplot as plt
from fgvc.data.plant_clef_data import PlantCLEFDataset, PlantSPECIESDataset

In [None]:
# All PlantCLEF artefacts live under one root. Set the PLANTCLEF_ROOT environment
# variable to run the notebook on another machine; the default is unchanged.
PLANTCLEF_ROOT = os.environ.get("PLANTCLEF_ROOT", "/home/ubuntu/FGVC11/data/PlantClef")
PRETRAINED_DIR = f"{PLANTCLEF_ROOT}/pretrained_models"

# Text file: one class entry per line, line index == model output index.
class_mapping = f"{PRETRAINED_DIR}/class_mapping.txt"
# ';'-separated CSV with `species_id` and `species` columns.
species_mapping = f"{PRETRAINED_DIR}/species_id_to_name.txt"
# Fine-tuned checkpoint loaded into the timm model below.
pretrained_path = f"{PRETRAINED_DIR}/vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all/model_best.pth.tar"

In [None]:
def load_class_mapping(class_list_file):
    """Map each 0-based line number of *class_list_file* to that line's stripped text.

    In this notebook the result turns a model output index into a species id string.
    """
    mapping = {}
    with open(class_list_file) as handle:
        for line_no, text in enumerate(handle):
            mapping[line_no] = text.strip()
    return mapping


def load_species_mapping(species_map_file):
    """Read a ';'-separated, quoted CSV and map `species_id` (kept as str) to `species`.

    Ids are read as strings so leading zeros survive; a duplicated id keeps its last row.
    """
    species_table = pd.read_csv(
        species_map_file,
        sep=';',
        quoting=csv.QUOTE_ALL,  # == 1, the value previously passed literally
        dtype={'species_id': str},
    )
    return dict(zip(species_table['species_id'], species_table['species']))

In [None]:
cid_to_spid = load_class_mapping(class_mapping)
spid_to_sp = load_species_mapping(species_mapping)

# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the architecture without hub weights (pretrained=False) and load the fine-tuned
# checkpoint instead; the classifier head is sized to the class-mapping file.
model = timm.create_model(
    'vit_base_patch14_reg4_dinov2.lvd142m',
    pretrained=False,
    num_classes=len(cid_to_spid),
    checkpoint_path=pretrained_path,
)
model = model.to(device).eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

### Feature Extraction on the Train Set

In [None]:
# Training metadata read from PlantCLEFTrainLQ.csv; the dataset below reads its
# `path` and `image_name` columns.
# NOTE(review): with escapechar="/" the parser presumably drops each "/" and keeps the
# following character literally, which would mangle any unquoted value containing "/"
# (e.g. file paths) — confirm this is intended for this CSV.
df = pd.read_csv("/home/ubuntu/FGVC11/data/PlantClef/PlantCLEFTrainLQ.csv", delimiter=";", escapechar="/")

In [None]:
from torch.utils.data import Dataset, DataLoader
class PlantClefDataset(Dataset):
    """Single-image dataset over a metadata DataFrame.

    Each item is ``(transforms(image), name)``, where ``name`` is the row's
    ``image_name`` with ``".jpg"`` removed.

    NOTE: the tiling dataset defined later in this notebook reuses the class name
    ``PlantClefDataset`` and shadows this one from that point on.
    """

    def __init__(self, df, transforms):
        """
        Args:
            df: DataFrame with ``path`` (image file) and ``image_name`` columns.
            transforms: callable applied to the RGB PIL image.
        """
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Open in a context manager so the file handle is released immediately;
        # convert() returns an independent converted copy, so it outlives the file.
        with Image.open(row["path"]) as src:
            img = src.convert("RGB")
        img = self.transforms(img)
        # Note: replace() removes every ".jpg" occurrence, not only a trailing one.
        name = row["image_name"].replace(".jpg", "")
        return img, name

# Create the dataset over the training CSV. Items are (transformed image, image name
# without ".jpg"); the attention demo and feature-cache cells below iterate it.
dataset = PlantClefDataset(df, transforms)

In [None]:
%%time
# Overlay the class-token attention map (averaged over heads) on ten training images.
# NOTE(review): stock timm `forward_features` has no `return_attn` argument; this
# assumes a patched model that returns (tokens, attn) — confirm.
for i in range(10):
    crops, plot_id = dataset[i + 10000]
    crops = crops.unsqueeze(0)
    with torch.no_grad():
        out, attn = model.forward_features(crops.to(device), return_attn=True)
        # Attention row of token 0 restricted to the patch tokens, averaged over heads.
        class_attention = attn[:, :, 0, model.num_prefix_tokens:].mean(1)
        # Patch tokens lie on a square grid (37x37 for 518px inputs with 14px patches).
        num_patches = class_attention.shape[-1]
        side = math.isqrt(num_patches)
        assert side * side == num_patches, f"{num_patches} patch tokens do not form a square grid"
        attention_map = class_attention.reshape((-1, side, side))
        # Upsample to the input resolution so the map lines up with the image.
        attention_upsampled = torch.nn.functional.interpolate(
            attention_map.unsqueeze(1),
            size=tuple(crops.shape[-2:]),
            mode='bilinear',
            align_corners=False,
        ).squeeze(1).cpu()

    # NOTE: the model transforms normalize pixels, so matplotlib clips values outside [0, 1].
    fig, ax = plt.subplots()
    ax.imshow(crops[0].permute(1, 2, 0).numpy(), alpha=0.9)
    ax.imshow(attention_upsampled[0].numpy(), cmap='hot', alpha=0.6)  # overlay attention
    ax.axis("off")
    plt.show()

In [None]:
# Inspect the model's configured pooling mode. The feature cache below does not use
# it: it saves token 0 of forward_features directly.
model.global_pool

In [None]:
%%time
# Cache the token-0 embedding of every training image as <dirpath>/<image name>.pt.
# Existing files are skipped so the cell can be resumed after an interruption
# (each image is still loaded by the dataset before the existence check).
dirpath = "/home/ubuntu/FGVC11/data/PlantClef/lq_feats"
os.makedirs(dirpath, exist_ok=True)
for data, name in tqdm(dataset, total=len(dataset)):
    save_path = f"{dirpath}/{name}.pt"
    if os.path.exists(save_path):
        continue
    with torch.no_grad():
        tokens = model.forward_features(data.unsqueeze(0).to(device))
    # clone() gives the saved vector its own compact storage. torch.save serializes a
    # view's whole underlying storage, so when .cpu() is a no-op (CPU device) every
    # patch token would otherwise be written to disk as well.
    feats = tokens[0, 0].cpu().clone()
    torch.save(feats, save_path)

### Test Set: Crop-based Inference and Submission

In [None]:
# One row per test image; `plot_id` is the file name up to its first ".".
# glob() returns files in arbitrary filesystem order, so sort the paths to make the
# row order (and hence the submission file) reproducible across runs and machines.
submission_df = pd.DataFrame(columns=["path", "plot_id", "species_ids"])
submission_df["path"] = sorted(glob("/home/ubuntu/FGVC11/data/PlantClef/images/*.jpg"))
submission_df["plot_id"] = submission_df["path"].apply(lambda p: os.path.basename(p).split(".")[0])

In [None]:
# Image.open(submission_df["path"][0]).convert("RGB")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualize_crops(crops, coords, crop_size, figsize=(10, 10)):
    """
    Draw image crops at their positions on a shared canvas, each outlined in red.

    Args:
    - crops (list of PIL Images or Tensors, or a stacked CHW tensor batch): Image crops.
    - coords (list of tuples): Top-left (x, y) of each crop in canvas units.
    - crop_size (int): Side length, in canvas units, at which each crop is drawn; every
      crop is stretched to this extent regardless of its own pixel dimensions.
    - figsize (tuple): Figure size for the plot.
    """
    # Canvas just large enough to hold the right-most / bottom-most crop.
    max_x = max(x for x, y in coords) + crop_size
    max_y = max(y for x, y in coords) + crop_size
    _, ax = plt.subplots(figsize=figsize)

    for crop, (x, y) in zip(crops, coords):
        if isinstance(crop, torch.Tensor):
            # CHW -> HWC; detach/cpu so CUDA or grad-tracking tensors convert cleanly.
            crop = crop.detach().cpu().permute(1, 2, 0).numpy()
        ax.imshow(crop, extent=(x, x + crop_size, y + crop_size, y), origin='upper')
        # Create a red rectangle around the crop
        rect = patches.Rectangle((x, y), crop_size, crop_size, linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(rect)

    # imshow() rescales the axes on every call, so fix the limits once at the end.
    ax.set_xlim(0, max_x)
    ax.set_ylim(max_y, 0)
    ax.axis('off')  # Hide axes
    plt.show()

def visualize_attn(crops, attn_maps, coords, crop_size, figsize=(10, 10)):
    """
    Draw each crop at its position with its attention map overlaid as a 'hot' heatmap.

    Args:
    - crops (list of PIL Images or Tensors, or a stacked CHW tensor batch): Image crops.
    - attn_maps (list/stack of 2-D Tensors or arrays): Attention map per crop.
    - coords (list of tuples): Top-left (x, y) of each crop in canvas units.
    - crop_size (int): Side length, in canvas units, at which each crop and its map are
      drawn (both are stretched to this extent).
    - figsize (tuple): Figure size for the plot.
    """
    max_x = max(x for x, y in coords) + crop_size
    max_y = max(y for x, y in coords) + crop_size
    _, ax = plt.subplots(figsize=figsize)

    for crop, attn, (x, y) in zip(crops, attn_maps, coords):
        if isinstance(crop, torch.Tensor):
            # CHW -> HWC; detach/cpu so CUDA or grad-tracking tensors convert cleanly.
            crop = crop.detach().cpu().permute(1, 2, 0).numpy()
        if isinstance(attn, torch.Tensor):
            # matplotlib cannot convert CUDA / grad-tracking tensors itself.
            attn = attn.detach().cpu().numpy()
        extent = (x, x + crop_size, y + crop_size, y)
        ax.imshow(crop, extent=extent, origin='upper', alpha=1)
        ax.imshow(attn, extent=extent, cmap='hot', alpha=0.7)

    ax.set_xlim(0, max_x)
    ax.set_ylim(max_y, 0)
    ax.axis('off')  # Hide axes
    plt.show()

def visualize_global_attn(crops, attn_maps, coords, crop_size, figsize=(15, 15)):
    """
    Visualize a global attention map created by merging individual crop attention maps
    onto a composite image reconstructed from all crops.

    Crops are pasted without resizing, so every crop and attention map must be exactly
    ``crop_size`` x ``crop_size`` pixels and ``coords`` must be in that same pixel space.
    Where crops overlap, the later one in the sequence wins.

    Args:
    - crops (list of PIL Images or Tensors, or a stacked CHW tensor batch): Image crops.
    - attn_maps (list/stack of Tensors or arrays): Attention map per crop, shape
      (crop_size, crop_size).
    - coords (list of tuples): Top-left (x, y) of each crop on the composite canvas.
    - crop_size (int): Side length of every crop in canvas pixels.
    - figsize (tuple): Figure size for the plot.
    """
    # Determine the dimensions of the full image
    max_x = max(x + crop_size for x, y in coords)
    max_y = max(y + crop_size for x, y in coords)

    # Empty canvases for the composite image and the merged attention
    full_image = np.zeros((max_y, max_x, 3), dtype=np.float32)
    full_attention = np.zeros((max_y, max_x), dtype=np.float32)

    for crop, attn, (x, y) in zip(crops, attn_maps, coords):
        if isinstance(crop, torch.Tensor):
            # CHW -> HWC; detach/cpu so CUDA or grad-tracking tensors convert cleanly.
            crop = crop.detach().cpu().permute(1, 2, 0).numpy()
        if isinstance(attn, torch.Tensor):
            attn = attn.detach().cpu().numpy()
        # numpy arrays (and anything array-like) are pasted as-is.
        full_image[y:y + crop_size, x:x + crop_size, :] = crop
        full_attention[y:y + crop_size, x:x + crop_size] = attn

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(full_image, extent=(0, max_x, max_y, 0))
    attention_image = ax.imshow(full_attention, extent=(0, max_x, max_y, 0), cmap='hot', alpha=0.7)

    cbar = fig.colorbar(attention_image, ax=ax, orientation='vertical')
    cbar.set_label('Attention Intensity')

    ax.set_xlim(0, max_x)
    ax.set_ylim(max_y, 0)
    ax.axis('off')  # Hide axes
    plt.show()

In [None]:
# from PIL import Image
# import torch
# from torch.utils.data import Dataset
# import torchvision.transforms as trfs

# class PlantClefDataset(Dataset):
#     def __init__(self, df, transforms, crop_size=800):
#         self.df = df
#         self.transforms = transforms
#         self.crop_size = crop_size

#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         row = self.df.iloc[idx]
#         img = Image.open(row["path"]).convert("RGB")

#         # Initialize list to hold the crops and their coordinates
#         crops = []
#         coords = []

#         # Determine the starting point for the last crop in each dimension
#         last_crop_start_x = max(0, img.width - self.crop_size)
#         last_crop_start_y = max(0, img.height - self.crop_size)

#         # Number of crops horizontally and vertically
#         num_crops_x = (img.width + self.crop_size - 1) // self.crop_size
#         num_crops_y = (img.height + self.crop_size - 1) // self.crop_size

#         for i in range(num_crops_y):
#             for j in range(num_crops_x):
#                 # Calculate the starting points for the crops
#                 start_x = j * self.crop_size if j != num_crops_x - 1 else last_crop_start_x
#                 start_y = i * self.crop_size if i != num_crops_y - 1 else last_crop_start_y

#                 # Crop the image
#                 crop = img.crop((start_x, start_y, start_x + self.crop_size, start_y + self.crop_size))
#                 crop = self.transforms(crop)
#                 crops.append(crop)
#                 coords.append((start_x, start_y))

#         # Stack all crops into a tensor
#         crops_tensor = torch.stack(crops)
#         return crops_tensor, coords, row["plot_id"]

# # Example usage
# # Assuming 'submission_df' and 'transforms' are defined elsewhere in your code
# dataset = PlantClefDataset(submission_df, transforms)


from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as trfs

class PlantClefDataset(Dataset):
    """Tile each image into a non-overlapping grid of square crops.

    Each item is ``(crops_tensor, coords, plot_id, crop_size)``:

    - ``crops_tensor``: ``torch.stack`` of ``transforms(crop)`` for every grid cell,
      row by row (top row first, left to right within a row).
    - ``coords``: top-left ``(x, y)`` of each crop, in original-image pixels.
    - ``plot_id``: the row's ``plot_id`` value.
    - ``crop_size``: side length in pixels chosen for this image by
      ``find_best_crop_size``.

    NOTE: this reuses the name of the single-image ``PlantClefDataset`` defined
    earlier in the notebook and shadows it from here on.
    """

    def __init__(self, df, transforms, min_crop_size=500, max_crop_size=600):
        """
        Args:
            df: DataFrame with ``path`` (image file) and ``plot_id`` columns.
            transforms: callable applied to each PIL crop; its outputs must be
                stackable with ``torch.stack``.
            min_crop_size: smallest candidate crop side length in pixels (inclusive).
            max_crop_size: largest candidate crop side length in pixels (inclusive).
        """
        self.df = df
        self.transforms = transforms
        self.min_crop_size = min_crop_size
        self.max_crop_size = max_crop_size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Release the file handle immediately; convert() returns an independent copy.
        with Image.open(row["path"]) as src:
            img = src.convert("RGB")

        crop_size = self.find_best_crop_size(img.width, img.height)

        # Grid of whole crops. The right/bottom remainder narrower than crop_size is
        # not covered. At least one crop per axis is always taken, so an image smaller
        # than crop_size still yields a tile (PIL pads the out-of-bounds area) instead
        # of an empty list, which torch.stack rejects.
        num_crops_x = max(1, img.width // crop_size)
        num_crops_y = max(1, img.height // crop_size)

        crops = []
        coords = []
        for i in range(num_crops_y):
            for j in range(num_crops_x):
                start_x = j * crop_size
                start_y = i * crop_size
                crop = img.crop((start_x, start_y, start_x + crop_size, start_y + crop_size))
                crops.append(self.transforms(crop))
                coords.append((start_x, start_y))

        return torch.stack(crops), coords, row["plot_id"], crop_size

    def find_best_crop_size(self, width, height):
        """Pick the crop side length with the smallest leftover margin.

        Candidates are ``min_crop_size..max_crop_size`` (inclusive), scored by
        ``width % size + height % size``. Ties go to the smallest size. If
        ``max_crop_size < min_crop_size`` the result is ``min_crop_size``.
        """
        candidates = range(self.min_crop_size, max(self.min_crop_size, self.max_crop_size) + 1)
        # min() returns the first minimal candidate, i.e. the smallest size on ties.
        return min(candidates, key=lambda size: width % size + height % size)

# Tile every test image listed in `submission_df` using the model's eval transforms.
# Items are (crops, coords, plot_id, crop_size).
dataset = PlantClefDataset(submission_df, transforms)


In [None]:
# Example usage: tile one test image and draw its crops in place.
# Pass the image's own crop size: `coords` are spaced by it in original-image pixels,
# and visualize_crops stretches each (already resized) crop tensor to that extent.
crops, coords, plot_id, crop_size = dataset[110]
visualize_crops(crops, coords, crop_size)

In [None]:
# Shape of the stacked crop tensor — presumably (num_crops, C, H, W); confirm.
crops.shape

In [None]:
%%time
# Stitch the per-crop class-token attention maps of one test image into a global heatmap.
crops, coords, plot_id, crop_size = dataset[100]
with torch.no_grad():
    out, attn = model.forward_features(crops.to(device), return_attn=True)
    # Attention row of token 0 restricted to the patch tokens, averaged over heads.
    class_attention = attn[:, :, 0, model.num_prefix_tokens:].mean(1)
    num_patches = class_attention.shape[-1]
    side = math.isqrt(num_patches)
    assert side * side == num_patches, f"{num_patches} patch tokens do not form a square grid"
    attention_map = class_attention.reshape((-1, side, side))
    attention_upsampled = torch.nn.functional.interpolate(
        attention_map.unsqueeze(1),
        size=tuple(crops.shape[-2:]),
        mode='nearest-exact',
    ).squeeze(1).cpu()

# `coords` are in original-image pixels spaced by `crop_size`, but every crop tensor is
# `crop_px` wide (518 for this model). Rescale so tiles abut on the canvas instead of
# overlapping or leaving gaps.
crop_px = crops.shape[-1]
coords_px = [(x * crop_px // crop_size, y * crop_px // crop_size) for x, y in coords]
visualize_global_attn(crops, attention_upsampled, coords_px, crop_px)

In [None]:
%%time
# Predict one class per crop; each plot's label set is the distinct predicted species.
# Labels are appended in dataset order, which matches the row order of `submission_df`.
test_labels = []
for crops, coords, name, cs in tqdm(dataset, total=len(dataset)):
    with torch.no_grad():
        logits = model(crops.to(device))
    # Predicted class index per crop (name kept because a later cell reads `max_prob`).
    max_prob = torch.argmax(logits, dim=1).cpu().numpy()
    unique_idx = np.unique(max_prob)
    # Map class indices to species ids; str(list) yields e.g. "[1361281, 1395807]".
    test_labels.append(str([int(cid_to_spid[int(j)]) for j in unique_idx]))

In [None]:
# Peek at the predictions instead of dumping one entry per test plot.
len(test_labels), test_labels[:5]

In [None]:
# Attach the predictions and write `plot_id;species_ids` with the label lists unquoted.
submission_df["species_ids"] = test_labels
(
    submission_df[["plot_id", "species_ids"]]
    .to_csv("my_run_4.csv", sep=";", index=False, quoting=csv.QUOTE_NONE)
)

In [None]:
# NOTE(review): relies on `max_prob` left over from the last iteration of the inference
# loop, so it reflects only the final test image and fails on a fresh kernel unless
# that cell has run first.
unique_idx = np.unique(max_prob)

In [None]:
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
import requests
from PIL import Image
from io import BytesIO
import re
from llava.utils import disable_torch_init