In [5]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger # Comment out if not using wandb
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks import StochasticWeightAveraging
import torch
# torch.multiprocessing.set_start_method('spawn', force=True)
from q2l_labeller.data.coco_data_module import COCODataModule
from q2l_labeller.pl_modules.query2label_train_module import Query2LabelTrainModule
from q2l_labeller.data.dataset import SeaThruAugmentation
pl.seed_everything(42)

# Settings shared by the model hyperparameters and the data module. Defining them once keeps
# the hyperparameters saved with the model in sync with what the dataloaders actually use
# (previously param_dict recorded batch_size=64 while COCODataModule was built with 128).
BATCH_SIZE = 128  # batch size the COCODataModule dataloaders are built with
IMAGE_DIM = 384
USE_CUTMIX = True
USE_SEATHRU = True  # enable or disable SeaThru augmentation

param_dict = {
    "backbone_desc": "resnest101e",
    "conv_out_dim": 2048,
    "hidden_dim": 256,
    "num_encoders": 1,
    "num_decoders": 2,
    "num_heads": 4,
    "batch_size": BATCH_SIZE,
    "image_dim": IMAGE_DIM,
    "learning_rate": 1e-4,
    "momentum": 0.9,
    "weight_decay": 0.01,
    "n_classes": 290,
    "thresh": 0.5,
    "use_cutmix": USE_CUTMIX,
    "use_pos_encoding": True,
    "loss": "mll",
    "use_seathru": USE_SEATHRU,
}

# NOTE(review): this class subset is not used below — COCODataModule receives
# train_classes=None, so all classes are trained. Pass it explicitly if a subset was intended.
train_classes = [160, 51, 119, 37, 52, 10, 88, 146, 125, 1, 260, 133, 9, 214, 70, 120, 111, 142, 274, 105, 69, 174, 203, 103, 228, 259, 205, 104, 116, 242, 16, 219, 81, 61, 100, 11, 224, 202, 82, 108, 255, 3, 54, 162, 85, 256, 8, 67, 71, 75, 173, 201, 93, 243, 218, 131, 99, 43, 36, 283]

# Machine-specific root; all data paths below are built from it.
THESIS_ROOT = '/home/mundus/mrahman528/thesis'
image_folder = f'{THESIS_ROOT}/query2label/train/'
depth_image_folder = f'{THESIS_ROOT}/Depth-Anything-V2/depth_vis_train'
# NOTE(review): depth images and .npy depth maps point at the same folder — confirm intended.
depth_npy_folder = f'{THESIS_ROOT}/Depth-Anything-V2/depth_vis_train'
seathru_parameters_path = f'{THESIS_ROOT}/sucre/src/output/parameters_train.json'
seathru_transform = SeaThruAugmentation(image_folder, depth_image_folder, depth_npy_folder, seathru_parameters_path)

coco = COCODataModule(
    data_dir=f'{THESIS_ROOT}/query2label/',
    img_size=IMAGE_DIM,
    batch_size=BATCH_SIZE,
    num_workers=8,
    use_cutmix=USE_CUTMIX,
    cutmix_alpha=1.0,
    train_classes=None,
    use_seathru=USE_SEATHRU,
    seathru_transform=seathru_transform,
)

# The Lightning module receives the data module through its `data` argument.
param_dict["data"] = coco

pl_model = Query2LabelTrainModule(**param_dict)

[rank: 0] Seed set to 42


In [6]:
# Weights & Biases logging is currently disabled. To enable it, uncomment these lines
# and the `logger=` argument below.
# wandb_logger = WandbLogger(
#     project="fathomnet_osd",
#     save_dir="training/logs/fathomnet_with_all_seathru",
#     log_model=True)
# wandb_logger.watch(pl_model, log="all")

trainer = pl.Trainer(
    max_epochs=24,
    # Automatic mixed precision. "16-mixed" is the explicit Lightning 2.x spelling of the
    # legacy `precision=16`, which still works but emits a discouragement warning.
    precision="16-mixed",
    accelerator='gpu',
    devices=1,  # a single GPU; use devices=-1 together with strategy='ddp' for all GPUs
    # strategy='ddp',  # Distributed Data Parallel (only useful with more than one device)
    # logger=wandb_logger,  # uncomment together with the wandb block above
    default_root_dir="training/checkpoints/depth_jitter",
    callbacks=[TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(pl_model, param_dict["data"])

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


loading annotations into memory...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type        | Params
-----------------------------------------------
0 | model          | Query2Label | 72.8 M
1 | base_criterion | TwoWayLoss  | 0     
-----------------------------------------------
72.8 M    Trainable params
0         Non-trainable params
72.8 M    Total params
291.313   Total estimated model params size (MB)


Done (t=0.41s)
creating index...
index created!
Seathru Transform Initialized: True
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
Seathru Transform Initialized: True


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

In [2]:
import os
import numpy as np
import scipy.ndimage
import matplotlib.pyplot as plt
import torch

def visualize_attention_maps(attn_maps, image_tensor, output_dir, image_id, threshold=0.3, sigma=1,
                             mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), show=False):
    """Save a three-panel figure: raw image, attention heat map, and attention-weighted image.

    Args:
        attn_maps: Attention weights as a torch tensor. Element ``[0]`` is taken
            (presumably the first batch element) and averaged over its first
            dimension (presumably the attention heads — confirm against the model).
            The averaged map must be 2-D, or 1-D with a perfect-square length so it
            can be reshaped into a square grid.
        image_tensor: Normalized image tensor whose squeezed form is (3, H, W),
            e.g. the (1, 3, H, W) output of ``preprocess_image``.
        output_dir: Directory for the PNG; created if missing.
        image_id: Used in the output file name ``attention_map_{image_id}.png``.
        threshold: Minimum opacity applied to the image in the highlighted panel, so
            low-attention regions are dimmed rather than blacked out.
        sigma: Gaussian smoothing applied to the attention grid before upsampling.
        mean, std: Per-channel normalization used to build ``image_tensor``; undone
            for display. Defaults are the ImageNet statistics.
        show: If True, display the figure (e.g. inline in Jupyter) before closing it.

    Returns:
        Path of the saved PNG, or None when there was nothing to visualize.

    Raises:
        ValueError: If ``attn_maps`` is not a tensor or cannot be reshaped to 2-D.
    """
    os.makedirs(output_dir, exist_ok=True)

    # numel() instead of len(): len() raises on 0-d tensors.
    is_empty = attn_maps is None or (
        attn_maps.numel() == 0 if torch.is_tensor(attn_maps) else len(attn_maps) == 0
    )
    if is_empty:
        print(f"No attention maps to visualize for {image_id}")
        return None

    if not torch.is_tensor(attn_maps):
        raise ValueError("Attention maps should be a torch tensor")

    # Average the first element of attn_maps over its leading dimension.
    combined_attn_map = attn_maps[0].mean(dim=0)
    # detach(): .numpy() fails on tensors that require grad.
    # float(): scipy.ndimage does not accept float16, which AMP can produce.
    attn_grid = combined_attn_map.detach().float().cpu().numpy()

    img_h, img_w = image_tensor.size()[-2:]

    # A flat token sequence is reshaped into a square spatial grid.
    if attn_grid.ndim != 2:
        num_tokens = attn_grid.shape[0]
        side = int(np.sqrt(num_tokens))
        if side * side != num_tokens:
            raise ValueError(f"Cannot reshape attention map of shape {attn_grid.shape} to 2D grid")
        attn_grid = attn_grid.reshape(side, side)

    # Smooth, then upsample to the image resolution with cubic spline interpolation (order=3).
    smoothed = scipy.ndimage.gaussian_filter(attn_grid, sigma=sigma)
    zoom_factors = (img_h / smoothed.shape[0], img_w / smoothed.shape[1])
    resized_attn_map = scipy.ndimage.zoom(smoothed, zoom_factors, order=3)

    # Scale into [0, 1]. Cubic interpolation can overshoot below zero, hence the clip.
    # An all-zero (or non-positive) map is left at zero instead of dividing by zero (NaNs).
    peak = resized_attn_map.max()
    if peak > 0:
        resized_attn_map = np.clip(resized_attn_map / peak, 0, 1)
    else:
        resized_attn_map = np.zeros_like(resized_attn_map)

    # Opacity mask with `threshold` as its floor.
    alpha_mask = np.clip(resized_attn_map, threshold, 1)

    # (3, H, W) normalized tensor -> (H, W, 3) image in [0, 1].
    image_np = image_tensor.squeeze().cpu().permute(1, 2, 0).numpy()
    image_np = np.clip(image_np * np.asarray(std) + np.asarray(mean), 0, 1)

    # Blending with a black background reduces to scaling the image by the mask.
    highlighted_image = image_np * alpha_mask[:, :, np.newaxis]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (image_np, 'Raw Image', None),
        (resized_attn_map, 'Attention Map', 'viridis'),
        (highlighted_image, 'Highlighted Image', None),
    ]
    for ax, (data, title, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    out_path = os.path.join(output_dir, f'attention_map_{image_id}.png')
    fig.savefig(out_path)
    if show:
        plt.show()
    # Always close: this function runs once per image, and open figures accumulate in memory.
    plt.close(fig)
    return out_path



In [None]:
import os
import torch
from torchvision import transforms
from PIL import Image
import pandas as pd

# Define a function to preprocess the image
def preprocess_image(image_path, img_size=640,
                     mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Load an image file as a normalized (1, 3, img_size, img_size) float tensor.

    Args:
        image_path: Path to the image file; any PIL-readable format. Converted to RGB.
        img_size: Side length of the square resize. Defaults to 640.
            NOTE(review): the training data module above uses img_size=384
            (param_dict["image_dim"]); confirm whether inference should match it.
        mean, std: Per-channel normalization; defaults are the ImageNet statistics.

    Returns:
        Tensor of shape (1, 3, img_size, img_size) with a leading batch dimension.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    # Context manager closes the underlying file. convert() returns a fully loaded copy,
    # so the image stays usable after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)

# Load the trained model for inference; `param_dict` from the setup cell supplies the
# constructor arguments, including the `data` module.
CHECKPOINT_PATH = '/home/mundus/mrahman528/thesis/query2label/training/checkpoints/depth_jitter/lightning_logs/version_54324/checkpoints/epoch=23-step=912.ckpt'
# Fall back to CPU when no GPU is present. map_location lets a checkpoint whose tensors
# were saved on GPU load onto the selected device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Query2LabelTrainModule.load_from_checkpoint(
    checkpoint_path=CHECKPOINT_PATH, map_location=device, **param_dict
)
model.eval().to(device)

# Images to label, and the directory the (currently disabled) attention-map
# visualization below writes to.
folder_path = '/home/mundus/mrahman528/thesis/fgvc-comp-2023/eval_images'
output_dir = './attention_maps'
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
PRED_THRESHOLD = param_dict["thresh"]  # 0.5, the threshold the model was configured with
results = []

# os.listdir returns entries in arbitrary order; sorting makes the run and its log reproducible.
for filename in sorted(os.listdir(folder_path)):
    # Case-insensitive match so files like "X.JPG" are not silently skipped.
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue
    image_id = os.path.splitext(filename)[0]  # file name without extension
    image_path = os.path.join(folder_path, filename)
    input_tensor = preprocess_image(image_path).to(device)

    with torch.no_grad():
        # output, attn_maps = model(input_tensor, return_attention=True)
        output = model(input_tensor)
    probabilities = torch.sigmoid(output)

    # nonzero() yields (batch_idx, class_idx) rows; the batch holds a single image, so
    # column 1 is the class index. tolist() gives plain Python ints, so the printed log
    # and the CSV show "[160]" rather than numpy scalar reprs (np.int64(160) on NumPy >= 2).
    predicted_indices = (probabilities >= PRED_THRESHOLD).nonzero()[:, 1].tolist()
    # NOTE(review): model indices are shifted by -1 to obtain category ids; confirm this
    # matches the dataset's category-id -> index mapping.
    predicted_classes_list = [idx - 1 for idx in predicted_indices]

    # Out-of-sample score: 1.0 when nothing clears the threshold, otherwise
    # 1 - (highest class probability).
    osd = 1.0 if not predicted_classes_list else 1 - probabilities.max().item()
    print(f'{image_id}, categories: {predicted_classes_list}, osd: {osd}')
    results.append({'id': image_id, 'categories': predicted_classes_list, 'osd': osd})

    # Visualize attention maps (requires the return_attention call above):
    # visualize_attention_maps(attn_maps[-1], input_tensor, output_dir, image_id)


# One row per image; written once after the loop.
df = pd.DataFrame(results)
df.to_csv('predictions.csv', index=False)


00b4467c-3a2d-4b9c-9b75-6a44d7361c3c, categories: [160], osd: 0.01696789264678955
00068d34-25b1-4eb9-9898-3e984cfb1e50, categories: [160], osd: 0.38564181327819824
00ba9786-8b7a-424b-9faa-18cc7d05223d, categories: [160], osd: 0.4023429751396179
000caf3b-c3db-4ccb-af06-fd1f13b56836, categories: [51, 160], osd: 0.47071200609207153
00bbce9f-0205-436a-adfb-80b686b269a2, categories: [], osd: 1.0
0011aadc-7852-4914-a288-471728e55cb5, categories: [160], osd: 0.44535964727401733
00c09207-0c99-4a11-aa3c-b7294050c19e, categories: [160], osd: 0.40240931510925293
002ee02b-d561-44a4-8a2d-fc093cdad3ea, categories: [51, 160], osd: 0.39995884895324707
00c4da4c-5cba-4f4f-9c7a-46b73ae203ce, categories: [119, 160], osd: 0.09793853759765625
00388f81-cf38-4ddf-bf7c-3892455fcd97, categories: [160], osd: 0.007276356220245361
00cacc9c-f537-4f44-84bc-d2053f48c120, categories: [160], osd: 0.036675333976745605
003ad516-aacb-4b9b-bc86-0e17b2d83079, categories: [160], osd: 0.47597140073776245
00da28ca-085e-4a69-a2