In [None]:
import os

import numpy as np
import torch
import transformers
from tqdm import tqdm

import linear_rep_geometry as lrg

# Prefer the first GPU, but fall back to the CPU so this cell does not fail
# outright on a host without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### load unembedding vectors ###
# Rows of the model's output head (presumably one row per vocabulary token --
# confirm against the model class). Cast to float32 before the covariance and
# eigendecomposition below.
gamma = lrg.model.lm_head.weight.detach().to(device)
gamma = gamma.to(torch.float32)

W, d = gamma.shape  # W = number of rows, d = number of columns of gamma
gamma_bar = torch.mean(gamma, dim=0)
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
# Covariance of the rows of gamma around their mean (normalised by W).
Cov_gamma = centered_gamma.T @ centered_gamma / W

# Cov_gamma is symmetric, so eigh applies. Build Cov^{1/2} and Cov^{-1/2}
# from the spectral decomposition V diag(lambda^{+/-1/2}) V^T.
eigenvalues, eigenvectors = torch.linalg.eigh(Cov_gamma)

# A zero, negative or NaN eigenvalue would silently put inf/NaN into
# inv_sqrt_Cov_gamma and g (and into every file saved from them), so fail
# loudly instead.
if not bool(torch.all(eigenvalues > 0)):
    raise ValueError(
        "Cov(gamma) is not numerically positive definite "
        f"(smallest eigenvalue: {eigenvalues.min().item()}); "
        "cannot compute its inverse square root."
    )

sqrt_eigenvalues = torch.sqrt(eigenvalues)
inv_sqrt_Cov_gamma = eigenvectors @ torch.diag(1 / sqrt_eigenvalues) @ eigenvectors.T
sqrt_Cov_gamma = eigenvectors @ torch.diag(sqrt_eigenvalues) @ eigenvectors.T

# Note: the *uncentered* gamma is transformed here; gamma_bar is only used to
# build the covariance above.
g = gamma @ inv_sqrt_Cov_gamma

### compute concept directions ###

# Counterfactual word-pair files. Every entry has the form
# 'word_pairs/[<base> - <target>].txt'; the order of `filenames` fixes the
# order of everything derived from it.
core_concept_files = [
    'word_pairs/[male - female].txt',
    'word_pairs/[lower - upper].txt',
    'word_pairs/[country - capital].txt',
]

bilingual_pair_files = [
    'word_pairs/[German - English].txt',
    'word_pairs/[English - Dutch].txt',
    'word_pairs/[Swedish - English].txt',
    'word_pairs/[English - Italian].txt',
    'word_pairs/[Portuguese - English].txt',
    'word_pairs/[Spanish - English].txt',
    'word_pairs/[English - French].txt',
    'word_pairs/[French - German].txt',
    'word_pairs/[Spanish - French].txt',
    'word_pairs/[French - Portuguese].txt',
    'word_pairs/[Italian - French].txt',
    'word_pairs/[French - Swedish].txt',
    'word_pairs/[French - Dutch].txt',
    'word_pairs/[German - Dutch].txt',
    'word_pairs/[German - Spanish].txt',
    'word_pairs/[Swedish - German].txt',
    'word_pairs/[German - Italian].txt',
    'word_pairs/[Portuguese - German].txt',
    'word_pairs/[Dutch - Spanish].txt',
    'word_pairs/[Spanish - Portuguese].txt',
    'word_pairs/[Italian - Spanish].txt',
    'word_pairs/[Spanish - Swedish].txt',
    'word_pairs/[Dutch - Swedish].txt',
    'word_pairs/[Dutch - Portuguese].txt',
    'word_pairs/[Italian - Dutch].txt',
    'word_pairs/[Swedish - Italian].txt',
    'word_pairs/[Italian - Portuguese].txt',
    'word_pairs/[Portuguese - Swedish].txt',
    # Disabled: 'word_pairs/[Portuguese - Italian].txt'
]

filenames = core_concept_files + bilingual_pair_files


def _concept_label(path):
    """Turn 'word_pairs/[A - B].txt' into the LaTeX label '$A \\Rightarrow B$'."""
    bracketed = path.split("/")[1].split(".")[0]  # '[A - B]'
    parts = bracketed[1:-1].split(" - ")          # drop the brackets, split the pair
    return r'${} \Rightarrow {}$'.format(parts[0], parts[1])


# One label per entry of `filenames`, in the same order.
concept_names = [_concept_label(path) for path in filenames]

# One row per concept file, in the order of `filenames`: row i holds the
# first value returned by lrg.concept_direction for file i, computed once on
# the raw unembeddings (gamma) and once on the transformed ones (g).
num_concepts = len(filenames)
concept_gamma = torch.zeros(num_concepts, d)
concept_g = torch.zeros(num_concepts, d)

for row, pair_file in enumerate(filenames):
    # The two name values returned alongside the indices are not needed here.
    base_ind, target_ind, _base_name, _target_name = lrg.get_counterfactual_pairs(pair_file)

    # Only the first return value of concept_direction is stored; the second
    # is discarded.
    direction_gamma, _ = lrg.concept_direction(base_ind, target_ind, gamma)
    concept_gamma[row] = direction_gamma

    direction_g, _ = lrg.concept_direction(base_ind, target_ind, g)
    concept_g[row] = direction_g

### save everything ###
# Neither torch.save nor open(..., 'w') creates missing parent directories,
# so make sure the output folder exists before writing into it.
os.makedirs("matrices", exist_ok=True)

torch.save(gamma, "matrices/gamma.pt")
torch.save(g, "matrices/g.pt")
torch.save(sqrt_Cov_gamma, "matrices/sqrt_Cov_gamma.pt")
torch.save(concept_gamma, "matrices/concept_gamma.pt")
torch.save(concept_g, "matrices/concept_g.pt")

# Plain-text companions, one entry per line, in the same order as the rows of
# concept_gamma / concept_g.
with open('matrices/concept_names.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{item}\n" for item in concept_names)

with open('matrices/filenames.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{item}\n" for item in filenames)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]