In [1]:
import os
import sys
import json

import numpy as np

from utils import *
from networks import *
from train_probVLM import *
from probvlm_data_loader import get_data as probvlm_get_data

import matplotlib.pyplot as plt

import itertools
sys.path.append('/vol/tensusers4/nhollain/thesis2023-2024/s_clip_scripts')
from data_loader import get_data
from params import parse_args
from open_clip import create_model_and_transforms, get_tokenizer, create_loss

import gc

[nltk_data] Downloading package punkt to /home/nhollain/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/nhollain/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [2]:
def prep_str_args(str_args): # Code to parse the string style arguments, as shown below
    str_args = str_args.split('\n') # Split on newline
    str_args = [s.strip() for s in str_args] # Remove any whitespaces from the start and end of the strings
    # Split on the space between the parameter name and the value, e.g. '--name x' becomes ['--name', 'x']
    str_args = [s.split(' ') for s in str_args] 
    str_args = list(itertools.chain(*str_args)) # Flatten the resulting list of lists
    str_args = [s for s in str_args if len(s) > 0] # Remove arguments that are empty
    return str_args

str_args = ''' 
    --train-data RS-ALL
    --val-data RS-ALL
    --imagenet-val RSICD-CLS 
    --label-ratio 1.0
    '''

args = parse_args(prep_str_args(str_args))

with torch.no_grad():
    model, preprocess_train, preprocess_val = create_model_and_transforms(args.model, args.pretrained, precision=args.precision, 
                                                                          device=args.device, output_dict=True, aug_cfg = args.aug_cfg, )
    del model
    gc.collect()

In [3]:
# Getting the probvlm data (COCO dataset)
probvlm_data = probvlm_get_data((preprocess_train, preprocess_val), tokenizer=get_tokenizer(args.model), batch_size = 128)
train_loader, valid_loader = probvlm_data['train'].dataloader, probvlm_data['val'].dataloader


Coco data (split: val)	Coco data (split: train)	

In [4]:
CLIP_Net = load_model(device='cuda', model_path=None)
ProbVLM_Net = BayesCap_for_CLIP(inp_dim=512, out_dim=512, hid_dim=256, num_layers=3, p_drop=0.05,)

In [5]:
# train_ProbVLM(CLIP_Net, ProbVLM_Net, train_loader, valid_loader, Cri = TempCombLoss(), device='cuda', dtype=torch.float, init_lr=8e-5,
     # num_epochs=100, eval_every=5, ckpt_path='../ckpt/ProbVLM_Net', T1=1e0, T2=1e-4, resume_path = '../ckpt/ProbVLM_Net_last.pth') 
#'../ckpt/ProbVLM_Net_last.pth'

In [6]:
import torch
optimizer = torch.optim.Adam(list(ProbVLM_Net.img_BayesCap.parameters())+list(ProbVLM_Net.txt_BayesCap.parameters()), lr=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 0)
ProbVLM_Net, _, _, epoch = load_checkpoint(ProbVLM_Net, optimizer, scheduler, resume_path = '../ckpt/ProbVLM_Net_last.pth')
ProbVLM_Net.to('cuda')
CLIP_Net.to('cuda')
print('Epochs trained for:', epoch)

=> loading checkpoint '../ckpt/ProbVLM_Net_last.pth'
=> loaded checkpoint '../ckpt/ProbVLM_Net_last.pth' (epoch 99)
Epochs trained for: 100


In [7]:
rs_data = get_data(args, (preprocess_train, preprocess_val), iter=0, tokenizer=get_tokenizer(args.model), model = ProbVLM_Net)
rs_train_loader, rs_valid_loader = rs_data['train'].dataloader, rs_data['val'].dataloader

In [8]:
with torch.no_grad():
    for image, caption in rs_train_loader:
        image, caption = image.to('cuda'), caption.to('cuda')
        z_I, z_T = CLIP_Net(image, caption)
        ProbVLM_z_I, ProbVLM_z_T = ProbVLM_Net(z_I, z_T)
        # print(result)
        break

In [9]:
(img_mu, img_1alpha, img_beta) = ProbVLM_z_I
(txt_mu, txt_1alpha, txt_beta) = ProbVLM_z_T

In [10]:
img_mu, txt_mu

(tensor([[-5.3302e+00,  3.0907e+00,  1.4049e-03,  ...,  7.5997e+00,
          -2.3168e+00, -3.1972e+00],
         [ 1.4710e-01,  3.4592e+00,  2.7596e+00,  ...,  3.7174e+00,
           2.9546e-01, -4.4244e+00],
         [-3.6406e+00,  6.4776e+00, -1.4187e+00,  ...,  8.8580e+00,
          -1.8328e+00, -4.2367e+00],
         ...,
         [-2.0518e+00,  3.5718e+00,  1.6814e+00,  ...,  3.5024e+00,
           1.4921e+00, -3.1210e+00],
         [-1.6561e+00,  3.5915e+00,  3.5659e+00,  ...,  4.4803e+00,
          -3.9206e+00, -3.0230e+00],
         [-2.7899e+00,  1.3528e+00,  4.0665e+00,  ...,  6.1188e+00,
           7.7097e-01, -2.2913e+00]], device='cuda:0'),
 tensor([[-0.0465,  0.0429,  0.0132,  ...,  0.0040, -0.0223,  0.0100],
         [-0.0109,  0.0005,  0.0172,  ...,  0.0257, -0.0363, -0.0195],
         [ 0.0208,  0.0379,  0.0225,  ...,  0.0860, -0.0275, -0.0260],
         ...,
         [-0.0163, -0.0033, -0.0091,  ..., -0.0106,  0.0291,  0.0177],
         [ 0.0086,  0.0129,  0.0282,  .

In [11]:
img_mu.shape

torch.Size([64, 512])

In [12]:
r_dict = get_features_uncer_ProbVLM(CLIP_Net, ProbVLM_Net, rs_train_loader)
sorted_uncertainty = sort_wrt_uncer(r_dict)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [08:05<00:00,  2.85s/it]


In [13]:
len(sorted_uncertainty)

2

In [14]:
sorted_uncertainty[0] # images?

[(2038, 14.934910774230957),
 (2216, 14.728181838989258),
 (6194, 13.362388610839844),
 (6905, 12.610465049743652),
 (9980, 12.368467330932617),
 (2563, 12.17802906036377),
 (9288, 11.760489463806152),
 (2065, 11.241554260253906),
 (3118, 10.805845260620117),
 (252, 10.65372371673584),
 (6535, 10.625394821166992),
 (2767, 10.402542114257812),
 (741, 10.350926399230957),
 (3396, 10.128325462341309),
 (4322, 10.066082000732422),
 (2783, 9.835289001464844),
 (3898, 9.78312873840332),
 (8031, 9.19912338256836),
 (5083, 8.951760292053223),
 (10207, 8.8982515335083),
 (10322, 8.800383567810059),
 (7719, 8.714487075805664),
 (1246, 8.666031837463379),
 (5155, 8.473018646240234),
 (1112, 8.273271560668945),
 (2636, 8.26350212097168),
 (7755, 8.218881607055664),
 (911, 8.21515941619873),
 (5930, 8.205717086791992),
 (4021, 8.124249458312988),
 (1212, 8.025969505310059),
 (71, 7.907519340515137),
 (2251, 7.9062910079956055),
 (7775, 7.82110595703125),
 (6275, 7.81758975982666),
 (3819, 7.7789597

In [15]:
sorted_uncertainty[1] # text?

[(3517, 0.004140013518720902),
 (6927, 0.004139933491705667),
 (9805, 0.004139628057552048),
 (10825, 0.004139575761702532),
 (2310, 0.0041394507799747),
 (2159, 0.0041394011031412985),
 (7388, 0.004139396396976806),
 (4519, 0.0041393665915168505),
 (1958, 0.004139323191351483),
 (7230, 0.004139218616378259),
 (5749, 0.004139209204889845),
 (10325, 0.004139204499161687),
 (1136, 0.0041391514297459955),
 (1597, 0.004139148815524234),
 (3602, 0.004139141757141971),
 (6641, 0.004139110648262563),
 (4396, 0.004139078494197989),
 (2325, 0.004139074311592161),
 (5286, 0.004139060195359903),
 (1987, 0.004139059933949102),
 (1677, 0.0041390565356117),
 (3837, 0.0041390426809092785),
 (2216, 0.004139040851049855),
 (7884, 0.00413895667926534),
 (6232, 0.004138954588084464),
 (6837, 0.004138949882935219),
 (7047, 0.004138900479514002),
 (644, 0.004138885318912241),
 (9891, 0.004138814483779265),
 (10072, 0.004138802983100856),
 (5397, 0.004138799323807493),
 (354, 0.004138797232785619),
 (911, 0