Computing Projection - sanity check

In [157]:
import numpy as np
from scipy.linalg import lstsq

images = np.array([[1,1,1], [0,0,0]])

text = np.array([[2,2,2], [0,0,0]])

proj, _, _, _ = lstsq(text, images)

proj
# np.dot(text,proj)

array([[0.16666667, 0.16666667, 0.16666667],
       [0.16666667, 0.16666667, 0.16666667],
       [0.16666667, 0.16666667, 0.16666667]])

In [158]:
from data_loader import dataloaders as dataloader

In [165]:
# Load Image Data

d = dataloader.load_data(
    data_root="./datasets/ImageNet/",
    dataset="ImageNet_LT",
    phase="train",
    batch_size=128,
#     batch_size=1,
    sampler_dic=None,
    num_workers=12,
    type="random_prompts",
    prompt_set="ImageNet",
)
data = d[0]

Loading data from /nethome/bdevnani3/flash1/long_tail_lang/data/ImageNet_LT/ImageNet_LT_train.txt
Use data transformation: Compose(
    RandomResizedCrop(size=(224, 224), scale=(0.5, 1), ratio=(0.75, 1.3333), interpolation=bicubic)
    RandomHorizontalFlip(p=0.5)
    ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=None)
    ToTensor()
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
)
***********************DATASET: train random_prompts
************************* dict_keys(['default', 'ImageNet', 'bestImageNet'])
train 115846
No sampler.
Shuffle is True.


In [166]:
# import json
# per_class_frequencies = {}

# for im, label, _, path in tqdm(data):
#     if label.item() not in per_class_frequencies:
#         per_class_frequencies[label.item()] = 0
#     per_class_frequencies[label.item()] +=1
    
# json.dump("per_class_frequencies.json", per_class_frequencies)

In [167]:
import torch.nn as nn
from clip import clip
import os
import torch


# Initialize CLIP models 
class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype
        self.token_embedding = clip_model.token_embedding

    def forward(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

def load_clip_to_cpu(visual_backbone):
    backbone_name = visual_backbone
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url, os.path.expanduser("~/.cache/clip"))

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    return model

clip_model = load_clip_to_cpu("RN50")

visual_model = torch.nn.DataParallel(clip_model.visual).cuda()

text_model = TextEncoder(clip_model)
text_model = torch.nn.DataParallel(text_model).cuda()

In [168]:
from classes import CLASSES, CUSTOM_TEMPLATES, GENERIC_PROMPT_COLLECTIONS
from tqdm.notebook import trange, tqdm

final_images, final_texts = [], []
final_labels = []

# count = {}

with torch.no_grad():
    for im, label, _, path in tqdm(data):
        x = visual_model(im.half()).float()
        x = x / x.norm(dim=-1, keepdim=True)
        final_images.append(x)

        templates = np.array(GENERIC_PROMPT_COLLECTIONS["ImageNet"])[path.cpu()]
        classnames_for_labels = np.array(CLASSES)[label.cpu()]

        texts = clip.tokenize(t.format(c) for t,c in zip(templates, classnames_for_labels))
        texts = texts.cuda()
        zeroshot_weights = text_model(texts).float()
        zeroshot_weights = zeroshot_weights / zeroshot_weights.norm(
            dim=-1, keepdim=True
        )
        final_texts.append(zeroshot_weights)
        final_labels.append(label)
#         count[label.item] +=1

  0%|          | 0/906 [00:00<?, ?it/s]

In [169]:
final_images = torch.cat(final_images, dim=0)
final_texts = torch.cat(final_texts, dim=0)
final_labels = torch.cat(final_labels, dim=0)

In [170]:
final_images.shape

torch.Size([115846, 1024])

In [171]:
proj, _, _, _ = lstsq(final_texts.cpu(), final_images.cpu())

In [172]:
proj.shape

(1024, 1024)

In [174]:
np.save("imagenet_text2img_proj.npy", proj)

Balanced projection matrix - 1 instance per class

In [175]:
balanced_images = np.zeros((1000, 1024))
balanced_texts = np.zeros((1000, 1024))

for i, label in tqdm(enumerate(final_labels)):
    balanced_images[label,:] = final_images[i,:].cpu()
    balanced_texts[label,:] = final_texts[i,:].cpu()

0it [00:00, ?it/s]

In [176]:
balanced_proj, _, _, _ = lstsq(balanced_texts, balanced_images)
balanced_proj.shape

(1024, 1024)

In [177]:
np.save("imagenet_text2img_balanced_proj.npy", balanced_proj.astype(float))

Balanced projection matrix - 100 instances per class

In [178]:
label_frequencies = {}
for label in final_labels:
    if label.item() not in label_frequencies:
        label_frequencies[label.item()] = 0
    label_frequencies[label.item()] += 1

indices_by_label = {}

for i, label in enumerate(final_labels):
    if label.item() not in indices_by_label:
        indices_by_label[label.item()] = []
    indices_by_label[label.item()].append(i)

In [179]:
import random 

upsampled_images = np.zeros((100000, 1024))
upsampled_texts = np.zeros((100000, 1024))

for i in tqdm(range(100000)):
    label = random.choice(range(1000))
    
    idx = random.choice(indices_by_label[label])
    upsampled_images[i] = final_images[idx].cpu()
    upsampled_texts[i] = final_texts[idx].cpu()


  0%|          | 0/100000 [00:00<?, ?it/s]

In [180]:
upsampled_balanced_proj, _, _, _ = lstsq(upsampled_texts, upsampled_images)
upsampled_balanced_proj.shape

(1024, 1024)

In [181]:
np.save("imagenet_text2img_upsampled_balanced_proj.npy", upsampled_balanced_proj.astype(float))

In [222]:
import random 

upsampled_images = np.zeros((400000, 1024))
upsampled_texts = np.zeros((400000, 1024))

for i in tqdm(range(400000)):
    label = random.choice(range(1000))
    
    idx = random.choice(indices_by_label[label])
    upsampled_images[i] = final_images[idx].cpu()
    upsampled_texts[i] = final_texts[idx].cpu()

upsampled_balanced_proj, _, _, _ = lstsq(upsampled_texts, upsampled_images)
upsampled_balanced_proj.shape

np.save("imagenet_text2img_upsampled_balanced_proj400.npy", upsampled_balanced_proj.astype(float))

  0%|          | 0/400000 [00:00<?, ?it/s]

Balanced projection matrix - 100 instances/pooling of labels 

In [182]:
all_labels_text = {}


with torch.no_grad():
    for label in tqdm(range(1000)):
        all_labels_text[label] = []

        templates = np.array(GENERIC_PROMPT_COLLECTIONS["ImageNet"])
        c = np.array(CLASSES)[label]
            
        texts = clip.tokenize([template.format(c) for template in templates]) 
        texts = texts.cuda()
        zeroshot_weights = text_model(texts).float()
        zeroshot_weights = zeroshot_weights / zeroshot_weights.norm(
            dim=-1, keepdim=True
        )
        all_labels_text[label].append(zeroshot_weights)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [183]:
import random 

upsampled2_images = np.zeros((100000, 1024))
upsampled2_texts = np.zeros((100000, 1024))

for i in tqdm(range(100000)):
    label = random.choice(range(1000))
    
    idx = random.choice(indices_by_label[label])
    upsampled2_images[i] = final_images[idx].cpu()
    idx_2 = random.choice(range(82))
    upsampled2_texts[i] = all_labels_text[label][0][idx_2].cpu()

  0%|          | 0/100000 [00:00<?, ?it/s]

In [184]:
upsampled2_texts.shape

(100000, 1024)

In [185]:
upsampled2_balanced_proj, _, _, _ = lstsq(upsampled2_texts, upsampled2_images)
upsampled2_balanced_proj.shape

(1024, 1024)

In [186]:
np.save("imagenet_text2img_upsampled2_balanced_proj.npy", upsampled2_balanced_proj.astype(float))

In [212]:
import random 

upsampled2_images = np.zeros((200000, 1024))
upsampled2_texts = np.zeros((200000, 1024))

for i in tqdm(range(200000)):
    label = random.choice(range(1000))
    
    idx = random.choice(indices_by_label[label])
    upsampled2_images[i] = final_images[idx].cpu()
    idx_2 = random.choice(range(82))
    upsampled2_texts[i] = all_labels_text[label][0][idx_2].cpu()
    
upsampled2_balanced_proj, _, _, _ = lstsq(upsampled2_texts, upsampled2_images)

np.save("imagenet_text2img_upsampled2_balanced_proj200.npy", upsampled2_balanced_proj.astype(float))

  0%|          | 0/200000 [00:00<?, ?it/s]

In [223]:
import random 

upsampled2_images = np.zeros((400000, 1024))
upsampled2_texts = np.zeros((400000, 1024))

for i in tqdm(range(400000)):
    label = random.choice(range(1000))
    
    idx = random.choice(indices_by_label[label])
    upsampled2_images[i] = final_images[idx].cpu()
    idx_2 = random.choice(range(82))
    upsampled2_texts[i] = all_labels_text[label][0][idx_2].cpu()
    
upsampled2_balanced_proj, _, _, _ = lstsq(upsampled2_texts, upsampled2_images)

np.save("imagenet_text2img_upsampled2_balanced_proj400.npy", upsampled2_balanced_proj.astype(float))

  0%|          | 0/400000 [00:00<?, ?it/s]

Balanced projection matrix - 100 instances/pooling of labels/ VL-LTR text pool

In [141]:
desc_path = "/nethome/bdevnani3/flash1/long_tail_lang/data_loader/imagenet/wiki/desc_{}.txt"

all_labels_wiki_text = {}
all_labels_wiki_text_embs = {}

with torch.no_grad():
    for label in range(1000):
        all_labels_wiki_text[label] = []
        label_desc_path = desc_path.format(label)
        f = open(label_desc_path)
        for line in f:
            line = line.strip()
            if "==" in line:
                continue
            if len(line) == 0:
                continue
            all_labels_wiki_text[label].append(line[:76])
        texts = clip.tokenize(all_labels_wiki_text[label])
        texts = texts.cuda()
        zeroshot_weights = text_model(texts).float()
        zeroshot_weights = zeroshot_weights / zeroshot_weights.norm(
            dim=-1, keepdim=True
        )
        all_labels_wiki_text_embs[label] = zeroshot_weights
        f.close()
    

In [142]:
import random 

upsampled3_images = np.zeros((100000, 1024))
upsampled3_texts = np.zeros((100000, 1024))

for i in tqdm(range(100000)):
    label = random.choice(range(1000))
    
    idx = random.choice(indices_by_label[label])
    upsampled3_images[i] = final_images[idx].cpu()
    idx_2 = random.choice(range(len(all_labels_wiki_text[label])))
    upsampled3_texts[i] = all_labels_wiki_text_embs[label][idx_2].cpu()

  0%|          | 0/100000 [00:00<?, ?it/s]

In [143]:
upsampled3_balanced_proj, _, _, _ = lstsq(upsampled3_texts, upsampled3_images)
upsampled3_balanced_proj.shape

(1024, 1024)

In [144]:
np.save("imagenet_text2img_upsampled3_balanced_proj.npy", upsampled3_balanced_proj.astype(float))

Balanced projection matrix - 100 instances/pooling of labels/ VL-LTR text pool + templates

In [145]:
all_labels_wiki_and_template = {}

for label in all_labels_wiki_text_embs:
    all_labels_wiki_and_template[label] = torch.cat([all_labels_wiki_text_embs[label], all_labels_text[label][0]])

In [146]:
import random 

upsampled4_images = np.zeros((100000, 1024))
upsampled4_texts = np.zeros((100000, 1024))

for i in tqdm(range(100000)):
    label = random.choice(range(1000))
    
    idx = random.choice(indices_by_label[label])
    upsampled4_images[i] = final_images[idx].cpu()
    idx_2 = random.choice(range(len(all_labels_wiki_and_template[label])))
    upsampled4_texts[i] = all_labels_wiki_and_template[label][idx_2].cpu()

  0%|          | 0/100000 [00:00<?, ?it/s]

In [147]:
upsampled4_balanced_proj, _, _, _ = lstsq(upsampled4_texts, upsampled4_images)
upsampled4_balanced_proj.shape

(1024, 1024)

In [148]:
np.save("imagenet_text2img_upsampled4_balanced_proj.npy", upsampled4_balanced_proj.astype(float))

Project to a combination of both text and image

In [116]:
import random 

upsampled5_images = np.zeros((100000, 1024))
upsampled5_texts = np.zeros((100000, 1024))

for i in tqdm(range(100000)):
    label = random.choice(range(1000))
    
    idx_2 = random.choice(range(len(all_labels_wiki_and_template[label])))
    upsampled5_texts[i] = all_labels_wiki_and_template[label][idx_2].cpu()
    
    cointoss = random.choice(range(2))
    print(cointoss)
    
    if cointoss == 0:
        idx = random.choice(indices_by_label[label])
        upsampled5_images[i] = final_images[idx].cpu()
    else:
        idx = random.choice(range(len(all_labels_wiki_and_template[label])))
        upsampled5_images[i] = all_labels_wiki_and_template[label][idx].cpu()

  0%|          | 0/100000 [00:00<?, ?it/s]

In [117]:
upsampled5_balanced_proj, _, _, _ = lstsq(upsampled5_texts, upsampled5_images)
upsampled5_balanced_proj.shape

(1024, 1024)

In [118]:
np.save("imagenet_text2img_upsampled5_balanced_proj.npy", upsampled4_balanced_proj.astype(float))

Analysis of content of the projection matrices

In [187]:
upsampled2_balanced_proj

array([[-8.22019977e+00, -4.05813132e+00,  7.99565807e+00, ...,
         3.68203487e+00,  9.65279449e+00,  9.00989232e+00],
       [-2.79999779e-01,  2.55104054e+00, -9.46708333e+00, ...,
         1.99724974e+00,  1.05673761e+01,  2.01435400e+00],
       [ 1.73314683e+00, -4.13588820e+00,  9.67940859e-01, ...,
         5.86499522e+00,  1.03482521e+00,  7.07844417e+00],
       ...,
       [ 4.50658966e-01,  1.25472663e+01, -2.94527356e+00, ...,
         2.06302891e+00,  8.13378075e+00,  1.48842194e+01],
       [-2.49041376e+00,  6.27473946e+00,  2.70703154e+00, ...,
         4.94021511e+00, -4.20177749e+00, -1.76725849e+01],
       [-1.51124219e+00,  1.24903615e+00,  3.76061880e-03, ...,
         1.99832544e+00,  2.66405994e+00, -9.15591247e+00]])

In [189]:
eigs = np.linalg.eig(upsampled2_balanced_proj)

In [191]:
w, v = eigs

In [209]:
for i in w:
    print(round(i,5))

(-191.86758+227.50055j)
(-191.86758-227.50055j)
(-271.64113+7.71438j)
(-271.64113-7.71438j)
(229.25223+147.61494j)
(229.25223-147.61494j)
(-108.70222+211.01767j)
(-108.70222-211.01767j)
(-17.55595+240.04112j)
(-17.55595-240.04112j)
(-201.09813+87.59422j)
(-201.09813-87.59422j)
(-185.31961+103.46369j)
(-185.31961-103.46369j)
(-195.83311+7.4j)
(-195.83311-7.4j)
(-114.2374+172.51019j)
(-114.2374-172.51019j)
(-83.38978+193.52439j)
(-83.38978-193.52439j)
(-177.63156+46.29117j)
(-177.63156-46.29117j)
(260.30649+0j)
(244.26998+65.76284j)
(244.26998-65.76284j)
(126.80332+192.02814j)
(126.80332-192.02814j)
(-8.97593+205.56686j)
(-8.97593-205.56686j)
(155.41452+138.12063j)
(155.41452-138.12063j)
(12.40778+181.86152j)
(12.40778-181.86152j)
(212.04363+0j)
(185.71186+79.73079j)
(185.71186-79.73079j)
(50.21728+166.94153j)
(50.21728-166.94153j)
(7.20089+163.93132j)
(7.20089-163.93132j)
(93.69756+151.73551j)
(93.69756-151.73551j)
(187.01059+0j)
(134.99519+108.28502j)
(134.99519-108.28502j)
(117.68911+

  print(round(i,5))


In [195]:
print(min(w), max(w), np.std(w))

(-271.6411318751175-7.714377189984287j) (260.30649003419984+0j) 56.589194459364066


In [219]:
np.mean(w)

(0.09238127581293652-3.9619119074872644e-17j)

In [203]:
final_images.shape

torch.Size([115846, 1024])

In [202]:
final_texts.shape

torch.Size([115846, 1024])

In [204]:
final_labels

tensor([715, 246, 247,  ..., 986, 968, 674])

In [216]:
projected_texts = np.matmul(final_texts.cpu(), upsampled2_balanced_proj)

In [217]:
np.linalg.norm(projected_texts)

268.58209378326274

In [218]:
np.linalg.norm(final_images.cpu())

339.276

In [213]:
np.linalg.norm(final_texts.cpu())

339.25485

In [207]:
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(random_state=0,verbose=1, n_jobs=-1).fit(projected_texts, final_labels)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 10 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =      1025000     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  8.00236D+05    |proj g|=  1.16415D+03


 This problem is unconstrained.



At iterate   50    f=  2.66757D+05    |proj g|=  5.40155D+01

At iterate  100    f=  2.66095D+05    |proj g|=  1.68676D+00

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
*****    100    106      1     0     0   1.687D+00   2.661D+05
  F =   266095.48310057109     

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT                 


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed: 36.5min finished


In [208]:
clf.score(projected_texts, final_labels)

0.894748200196813

In [221]:
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(random_state=0,verbose=1, n_jobs=-1).fit(final_texts.cpu(), final_labels)
clf.score(projected_texts, final_labels)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 10 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =      1025000     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  8.00236D+05    |proj g|=  1.16415D+03


 This problem is unconstrained.



At iterate   50    f=  1.31944D+05    |proj g|=  4.53975D+00

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
*****     97    100      1     0     0   3.988D-02   1.319D+05
  F =   131928.92824115991     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             


[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed: 49.7min finished


0.7241423959394369