In [1]:
import os

In [2]:

%pip install pycocotools






In [3]:
%pip install pycocoevalcap

Note: you may need to restart the kernel to use updated packages.




In [4]:
pip install fvcore






In [5]:

import sys
import pprint
import random
import time
import tqdm
import logging
import argparse
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist

import losses
import models
import datasets
import lib.utils as utils
from lib.utils import AverageMeter
from optimizer.optimizer import Optimizer
from evaluation.evaler import Evaler
from scorer.scorer import Scorer
from lib.config import cfg, cfg_from_file


from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# from timm.data.transforms import _pil_interp
import cv2
from PIL import Image
from fvcore.nn import FlopCountAnalysis

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [43]:
class Tester(object):
    def __init__(self, args):
        super(Tester, self).__init__()
        self.args = args
        self.device = torch.device("cuda")
        self.vocab = utils.load_vocab(args.vocab)
        
        self.transform = transforms.Compose([
            transforms.Resize((384, 384), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]
        )

        self.setup_network()

    def setup_network(self):
        model = models.create(cfg.MODEL.TYPE)
        print(model)
        self.model = torch.nn.DataParallel(model).cuda()
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Total Parameters: {total_params / 1e6:.2f}M")   
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Trainable Parameters: {trainable_params / 1e6:.2f}M")
        # dummy_input = torch.randn(1, 3, 384, 384).cuda()
        # flops = FlopCountAnalysis(model, dummy_input)
        # print(f"GFLOPs: {flops.total() / 1e9:.2f}")  
        if self.args.resume > 0:
            self.model.load_state_dict(
                torch.load(self.snapshot_path("caption_model", self.args.resume),
                    map_location=lambda storage, loc: storage)
            )
            
    def make_kwargs(self, indices, ids, gv_feat, att_feats, att_mask):
        kwargs = {}
        kwargs[cfg.PARAM.INDICES] = indices
        kwargs[cfg.PARAM.GLOBAL_FEAT] = gv_feat
        kwargs[cfg.PARAM.ATT_FEATS] = att_feats
        kwargs[cfg.PARAM.ATT_FEATS_MASK] = att_mask
        kwargs['BEAM_SIZE'] = 5
        kwargs['GREEDY_DECODE'] = True
        return kwargs
    
    def read_img(self, image):
        img = cv2.imread(image)
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        att_feats = self.transform(img)[None].cuda()
        return att_feats
    
    def inference_img(self, image):
        with torch.no_grad():
            indices = 0
            ids = image
            gv_feat = None
            att_feats = self.read_img(image)
            att_mask = torch.ones(1, 12*12).cuda()
            
            kwargs = self.make_kwargs(indices, ids, gv_feat, att_feats, att_mask)
            if kwargs['BEAM_SIZE'] > 1:
                seq, _ = self.model.module.decode_beam(**kwargs)
            else:
                seq, _ = self.model.module.decode(**kwargs)
                
            sents = utils.decode_sequence(self.vocab, seq.data)
            # print(ids, ''.join(sents[0].split(' ')))
            return ' '.join(sents[0].split(' '))
        
    def eval(self, epoch, images):
        self.model.eval()
        
        caps = []
        for image in images:
            cap = self.inference_img(image)
            caps.append(cap)
        return caps
            

    def snapshot_path(self, name, epoch):
        snapshot_folder = os.path.join(cfg.ROOT_DIR, 'snapshot')
        return os.path.join(snapshot_folder, name + "_" + str(epoch) + ".pth")

def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Image Captioning')
    parser.add_argument('--folder', dest='folder', default=None, type=str)
    parser.add_argument("--resume", type=int, default=-1)
    parser.add_argument("--vocab", type=str, 
                        default=r'C:\Users\dhair\Documents\VS-Code-Practice-Files-main\CWNU Intern Work\PureT\coco_vocabulary.txt')
    # parser.add_argument("--images", type=str, nargs='+', default='')

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args(args=['--folder', './experiments_PureT/PureT_SCST/', '--resume', '27'])
    return args

In [44]:
args = parse_args()
    
print('Called with args:')
print(args)

if args.folder is not None:
    cfg_from_file(os.path.join(args.folder, 'config.yml'))
cfg.ROOT_DIR = args.folder

tester = Tester(args)

Called with args:
Namespace(folder='./experiments_PureT/PureT_SCST/', resume=27, vocab='C:\\Users\\dhair\\Documents\\VS-Code-Practice-Files-main\\CWNU Intern Work\\PureT\\coco_vocabulary.txt')
load pretrained weights!
PureT(
  (backbone): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0): BasicLayer(
        dim=192, input_resolution=(96, 96), depth=2
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            dim=192, input_resolution=(96, 96), num_heads=6, window_size=12, shift_size=0, mlp_ratio=4.0
            (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              dim=192, window_size=(12, 12), num_heads=6
              (qkv): Linear(in_features=192, out_features=576, bias=True)
             

  torch.load(self.snapshot_path("caption_model", self.args.resume),


In [11]:
%matplotlib inline
import matplotlib.pyplot as plt
import os

def vis_img_cap(img_files, caps):
    assert len(img_files) == len(caps), 'error'
    for i in range(len(img_files)):
        img_file = img_files[i]
        cap = caps[i]
        img = cv2.imread(img_file)
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.imshow(img)
        plt.show()
        print(img_file, cap)

# Inference Code

In [30]:
import os
import json
import time
from tqdm import tqdm
import torch
from PIL import Image
import cv2
from torchvision import transforms

# Paths
# Paths (assuming RSICD dataset is in the current directory)
ann_dir = r'C:\Users\dhair\Documents\VS-Code-Practice-Files-main\CWNU Intern Work\RSICD\annotations'
img_dir = r'C:\Users\dhair\Documents\VS-Code-Practice-Files-main\CWNU Intern Work\RSICD\images'
json_files = ['RSICD_test.json']

# Check if directories exist
if not os.path.exists(ann_dir):
    raise FileNotFoundError(f"Annotations directory not found at {ann_dir}")
if not os.path.exists(img_dir):
    raise FileNotFoundError(f"Images directory not found at {img_dir}")

# Step 1: Build image-to-captions mapping
gt_dict = {}
for file in json_files:
    with open(os.path.join(ann_dir, file)) as f:
        data = json.load(f)
        for entry in data:
            fname = entry['image']
            caption = entry['caption']
            gt_dict.setdefault(fname, []).append(caption)   

# Step 2: Build image path list
# img_files = [os.path.join(img_dir, fname) for fname in gt_dict]
img_files = [os.path.join(img_dir, fname) for fname in list(gt_dict)[:]] 

In [31]:
len(img_files)

1094

In [None]:
# Step 3: Inference

caps = tester.eval(args.resume, img_files)  # preds
refs = [gt_dict[os.path.basename(p)] for p in img_files]

In [None]:
from PIL import Image
import os

def load_image_and_captions(index, img_files, caps, gt_dict):
    """
    Loads image, 5 reference captions, and predicted caption for a given index.
    
    Args:
        index (int): Index of the image in img_files.
        img_files (List[str]): List of image file paths.
        caps (List[str]): List of predicted captions (aligned with img_files).
        gt_dict (Dict[str, List[str]]): Mapping from image filename to reference captions.

    Returns:
        image (PIL.Image): The loaded image.
        references (List[str]): List of 5 ground truth captions.
        prediction (str): Predicted caption.
    """
    img_path = img_files[index]
    image = Image.open(img_path).convert("RGB")

    fname = os.path.basename(img_path)
    references = gt_dict.get(fname, [])
    prediction = caps[index]

    return image, references, prediction


# Example: View image and captions at index 7
index = 7
image, references, prediction = load_image_and_captions(index, img_files, caps, gt_dict)

print(f"Prediction: {prediction}")
print("References:")
for i, ref in enumerate(references):
    print(f"  {i+1}. {ref}")

# If using Jupyter or IPython
image.show()


Prediction: a view of a city from an airport
References:
  1. ['there are a lot of buildings at the airport .', 'there are several square at the airport .', 'a plane is near a building and a runway in an airport .', 'a plane is near a building and a runway in an airport .', 'there are a lot of buildings at the airport .']


In [42]:
references_dict = {str(i): refs[i] for i in range(len(refs))}
predictions_dict = {str(i): [caps[i]] for i in range(len(caps))} 

In [43]:
import json

with open("preds_dict.json", "w") as f:
    json.dump(predictions_dict, f, indent=2)

with open("refs_dict.json", "w") as f:
    json.dump(references_dict, f, indent=2)

print("Saved refs_dict.json and preds_dict.json")

Saved refs_dict.json and preds_dict.json


In [44]:
references_dict

{'0': [['the tarmac and airport runways divide the field into several orderly arranged rounded rectangles next to which is buildings and a road.',
   'the tarmac and airport runways divide the field into several orderly arranged rounded rectangles next to which is buildings and a road.',
   'a brown ground divided by the grey runway .',
   'we can see a simple terminal building and an apron connected with runways',
   'some building with a parking lot are near an airport with several runways .']],
 '1': [['many white planes are parked at the airport .',
   'a highway is built next to the airport .',
   'a highway is built next to the airport .',
   'many white planes are parked at the airport .',
   'many white planes are parked at the airport .']],
 '2': [['a parking apron with a plane parked on and connected to a runway is lying on the bare land near which there are some square buildings.',
   'a parking apron with a plane parked on and connected to a runway is lying on the bare land

# Inference Time

In [26]:
import time

num_runs = 100
total_time = 0

print("Measuring average inference time...")
for i in range(num_runs):
    img_path = random.choice(img_files)
    start = time.time()
    _ = tester.inference_img(img_path)
    total_time += (time.time() - start)

avg_time = total_time / num_runs
print(f"Avg inference time: {avg_time * 1000:.2f} ms")

Measuring average inference time...
Avg inference time: 416.52 ms


# GFlops

In [58]:
from fvcore.nn import FlopCountAnalysis

puret_model = tester.model.module

# Calculate backbone FLOPs
dummy = torch.randn(1, 3, 384, 384).cuda()
backbone_flops = FlopCountAnalysis(puret_model.backbone, dummy)
print(f"Backbone FLOPs: {backbone_flops.total() / 1e9:.2f} GFLOPs")

# Get features from backbone
att_feats = puret_model.backbone(dummy)
# print(f"att_feats shape: {att_feats.shape}")  # Debug shape

# Process features through att_embed
att_feats = puret_model.att_embed(att_feats)
# print(f"embedded att_feats shape: {att_feats.shape}")  # Debug shape

# Calculate encoder FLOPs
att_mask = torch.ones(1, att_feats.shape[1]).cuda()
encoder_flops = FlopCountAnalysis(puret_model.encoder, (att_feats, att_mask))
print(f"Encoder FLOPs:  {encoder_flops.total() / 1e9:.2f} GFLOPs")

total = backbone_flops.total() + encoder_flops.total()
print(f"Total (backbone + encoder): {total / 1e9:.2f} GFLOPs")

Unsupported operator aten::mul encountered 73 time(s)
Unsupported operator aten::add encountered 83 time(s)
Unsupported operator aten::softmax encountered 24 time(s)
Unsupported operator aten::gelu encountered 24 time(s)
Unsupported operator aten::bernoulli_ encountered 46 time(s)
Unsupported operator aten::div_ encountered 46 time(s)


Backbone FLOPs: 104.08 GFLOPs


Unsupported operator aten::mul encountered 4 time(s)
Unsupported operator aten::sum encountered 5 time(s)
Unsupported operator aten::div encountered 1 time(s)
Unsupported operator aten::repeat encountered 9 time(s)
Unsupported operator aten::add encountered 10 time(s)
Unsupported operator aten::mean encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::pad encountered 1 time(s)


Encoder FLOPs:  1.40 GFLOPs
Total (backbone + encoder): 105.48 GFLOPs


# Parameters

In [53]:
total_params = sum(p.numel() for p in puret_model.parameters())
print(f"Total Parameters: {total_params / 1e6:.2f}M")   
trainable_params = sum(p.numel() for p in puret_model.parameters() if p.requires_grad)
print(f"Trainable Parameters: {trainable_params / 1e6:.2f}M")

Total Parameters: 229.41M
Trainable Parameters: 34.16M
