In [1]:
from train_sae import SAE
import torch
import transformers
import pandas as pd
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [52]:
df1 = pd.read_csv('/home/ubuntu/clip-text-directions/cc12m.tsv', sep='\t', header=None)
df2 = pd.read_parquet('/home/ubuntu/clip-text-directions/ori_prompts_df.parquet')
df2 = df2[['clean_prompts']]
df2.columns = ["text"]
df1=df1[[1]]
df1.columns = ["text"]
df = pd.concat([df1, df2], axis=0)
df=df.drop_duplicates()

In [66]:
df.to_parquet("full.parquet", index=False)

In [2]:
sae = SAE(768, 768*32)
sae.load_state_dict(torch.load('sae.pt'))
sae = sae.to('cuda').to(torch.float16).requires_grad_(False)

In [3]:
df = pd.read_parquet("full.parquet")

In [12]:
clip = transformers.CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda").to(torch.float16).requires_grad_(False)
tokenizer = transformers.AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def tokenize(x):
    return tokenizer(x, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids




In [5]:
texts = df.text.to_list()

In [13]:
most_active_features = []
batch_size = 4096

def encode(embs):
    return sae.relu(sae.encoder(embs))

clip_forward = torch.compile(clip.forward)
encode = torch.compile(encode)

with torch.no_grad():
    for i in tqdm(range(0, len(texts), batch_size)):
        input_ids = tokenize(texts[i:i+batch_size]).to(clip.device)
        embeds = clip_forward(input_ids).pooler_output
        feats = encode(embeds).argmax(dim=-1).tolist()
        most_active_features.extend(feats)

df['most_active_feature'] = most_active_features

100%|██████████| 5494/5494 [1:42:53<00:00,  1.12s/it]  


In [18]:
# get counts for each
counts = df.most_active_feature.value_counts()

In [27]:
counts.to_dict()

{22335: 2002504,
 11177: 1096528,
 21491: 584349,
 12174: 417642,
 15360: 412136,
 8411: 357043,
 19158: 325261,
 2456: 312986,
 14634: 309433,
 1847: 309061,
 10210: 258361,
 20763: 245717,
 21862: 225187,
 22169: 213434,
 13571: 209434,
 23691: 188072,
 10767: 169645,
 7685: 166775,
 7755: 160121,
 834: 155217,
 18750: 150476,
 21038: 150270,
 19258: 147915,
 19248: 144743,
 18520: 142748,
 1974: 139383,
 1193: 139205,
 389: 139114,
 7963: 136267,
 20840: 132151,
 15780: 131432,
 1661: 129742,
 17625: 126219,
 16044: 123767,
 1259: 122487,
 2509: 118169,
 10156: 117867,
 15194: 116462,
 1799: 115760,
 9672: 110813,
 17676: 107879,
 1730: 102000,
 19894: 101876,
 21207: 101714,
 10224: 100439,
 19585: 99679,
 22893: 99561,
 14023: 98926,
 19532: 98786,
 9351: 97630,
 7004: 96203,
 828: 94882,
 20403: 91380,
 8998: 89805,
 6892: 88234,
 7819: 88170,
 6741: 87268,
 22238: 87222,
 14193: 85603,
 171: 85035,
 22715: 83499,
 4826: 83105,
 17953: 82674,
 14539: 80432,
 20791: 79380,
 15102:

In [34]:
df[df.most_active_feature == 24301].sample(5).text.to_list()

['Rework Coverup Thiscanbedone Tattoo Colorblasted Color Tattoo Somuchbetter Studio13tattoomg Tattoos Skull Tattoo Cover Up',
 'Tattoos - Puerto Rico (cover up and fixer upper) - 140385',
 'Start of a cover-up tattoo. Lilly tattoo, botanical tattoo',
 'Scottish Tattoo Cover-Up by Chameleon Tattoo',
 'Tattoos - Cover up - 134353']