In [None]:
import torch
import open_clip
import pandas as pd
from pathlib import Path

In [None]:
# Destination for every artifact this notebook saves (tokens.pt, features.pt).
# Created up front; an existing directory is accepted as-is.
output_dir = Path('./data/')
output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Force to the last visible GPU (Adam - for now); fall back to CPU when CUDA
# is unavailable. Deriving the index from device_count() instead of
# hard-coding 'cuda:3' avoids an "invalid device ordinal" error on machines
# with fewer than four GPUs, while still selecting cuda:3 on a 4-GPU box.
DEVICE = f'cuda:{torch.cuda.device_count() - 1}' if torch.cuda.is_available() else 'cpu'
print(f'Using {DEVICE}')

In [None]:
# Load ViT-H-14 with the LAION-2B (laion2b_s32b_b79k) weights directly onto
# DEVICE, along with the tokenizer that matches this architecture. The two
# image transforms returned alongside the model are not needed for text encoding.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k', device=DEVICE)
# open_clip returns the model in train mode; switch to eval mode so any
# train-only behaviour (dropout, stochastic depth, etc.) is disabled for inference.
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-H-14')

In [None]:
# Only the caption column is used downstream, so read just that column
# (parquet is columnar, so the other columns are never loaded). Reset the
# index so later cells can index captions 0..n-1 regardless of whatever
# index the parquet file stored.
text = pd.read_parquet('mscoco.parquet', columns=['TEXT'])['TEXT'].reset_index(drop=True)

In [None]:
# Sanity-check the pipeline on a single caption before the batch job:
# inspect the token tensor and the embedding shape/values.
first_caption = text.iloc[0]  # positional, so it works for any stored index
print(first_caption)

test_toks = tokenizer(first_caption).to(DEVICE)
print(test_toks.shape)
print(test_toks)

# torch.autocast(device_type='cuda') is the non-deprecated equivalent of
# torch.cuda.amp.autocast(); it disables itself (with a warning) on CPU.
with torch.no_grad(), torch.autocast(device_type='cuda'):
    features = model.encode_text(test_toks)
    # L2-normalise each embedding row so dot products are cosine similarities
    features /= features.norm(dim=-1, keepdim=True)

print(features.shape)
print(features)

In [None]:
# Encode a long, multi-paragraph passage to see how the tokenizer and text
# encoder handle input far longer than a caption.
example = '''
The White House announced on Saturday that Joe Biden was returning to Washington from out of town “to consult with his national security team about events in the Middle East” amid heightened tension between Israel and Iran.

A military helicopter hovering over a cargo vessel
World waits anxiously for Iranian response to Israel’s killing of top general
Read more
The US president had been due to spend the weekend in Delaware at his residence in Rehoboth Beach but early on Saturday afternoon set off at short notice to return to the White House.

This followed Biden saying on Friday that he expects an Iranian attack on Israel “sooner rather than later” and issued a last-ditch message to Tehran, saying: “Don’t.”

Earlier, John Kirby, the White House national security spokesperson, had warned that the threat of a significant Iranian attack on Israel remained “viable” despite Washington-led efforts, including calls to Tehran from the UK and Germany, to deter a serious escalation in the conflict in the Middle East.

On Saturday, Iran’s paramilitary Revolutionary Guard Corps in the strait of Hormuz, 50 nautical miles off the coast of the United Arab Emirates, seized an Israeli-affiliated container ship.
'''
print(example)

# Tokenize the whole passage as one string. NOTE(review): the tokenizer emits a
# fixed-width row (see test_toks.shape), so this passage is presumably
# truncated and only its opening words reach the encoder — confirm by
# checking that no column of test_toks is zero-padded.
test_toks = tokenizer(example).to(DEVICE)
print(test_toks.shape)
print(test_toks)

# torch.autocast(device_type='cuda') replaces the deprecated
# torch.cuda.amp.autocast(); it disables itself (with a warning) on CPU.
with torch.no_grad(), torch.autocast(device_type='cuda'):
    features = model.encode_text(test_toks)
    # L2-normalise each embedding row so dot products are cosine similarities
    features /= features.norm(dim=-1, keepdim=True)

print(features.shape)
print(features)

In [None]:
# Encode every caption in batches. The single-caption check showed each
# embedding is 1024 wide, so the outputs are n x CONTEXT_LENGTH token ids and
# n x EMBED_DIM normalised features. Both tensors are preallocated and each
# batch is written into its [start, end) row slice.
import math
from tqdm import tqdm

BATCH_SIZE = 1024
CONTEXT_LENGTH = 77   # width of each tokenizer output row (must match tokenizer)
EMBED_DIM = 1024      # width of each text embedding for this model

n_texts = text.shape[0]
num_batches = math.ceil(n_texts / BATCH_SIZE)

# Token ids are integers: keep them as int64 (the tokenizer's dtype) instead
# of float32 so the saved tensor can be fed back to the model without a cast.
out_toks = torch.zeros(n_texts, CONTEXT_LENGTH, dtype=torch.long)
out_feats = torch.zeros(n_texts, EMBED_DIM)

for bn in tqdm(range(num_batches)):
    start = bn * BATCH_SIZE
    end = min(start + BATCH_SIZE, n_texts)

    # .iloc makes the slice explicitly positional; tolist() hands the
    # tokenizer a plain list of strings.
    tokens = tokenizer(text.iloc[start:end].tolist()).to(DEVICE)
    out_toks[start:end] = tokens.cpu()

    # torch.autocast(device_type='cuda') replaces the deprecated
    # torch.cuda.amp.autocast(); it disables itself (with a warning) on CPU.
    with torch.no_grad(), torch.autocast(device_type='cuda'):
        features = model.encode_text(tokens)
        # L2-normalise each embedding row so dot products are cosine similarities
        features /= features.norm(dim=-1, keepdim=True)

    # Features may be half precision under autocast; store them as float32 on CPU.
    out_feats[start:end] = features.float().cpu()

torch.save(out_toks, output_dir / 'tokens.pt')
torch.save(out_feats, output_dir / 'features.pt')