In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))

import unireps
import torch
import datasets

# project_dir = '/Users/christopher/Documents/unireps'
project_dir = '/net/scratch2/chriswolfram/unireps'
unireps.set_hf_cache_directory(os.path.join(project_dir, 'hf_cache'))
unireps.set_datasets_directory(os.path.join(project_dir, 'datasets'))
unireps.set_outputs_directory(os.path.join(project_dir, 'outputs'))

datasets.disable_caching()

### Models

In [2]:
model_names = [
    "openai-community/gpt2",
    "google/gemma-2b",
    "google/gemma-7b",
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-9b-it",
    "google/gemma-2-27b",
    "meta-llama/Meta-Llama-3.1-8B",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-1B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.2-11B-Vision",
    "mistralai/Mistral-7B-v0.3",
    "mistralai/Mistral-Nemo-Base-2407",
    "mistralai/Mixtral-8x7B-v0.1",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/Phi-3-medium-4k-instruct",
    "microsoft/Phi-3.5-mini-instruct",
    "tiiuae/falcon-40b",
    "tiiuae/falcon-11B",
    "meta-llama/Llama-3.1-70B",
    "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "tiiuae/falcon-mamba-7b"
]

chat_models = [
    "google/gemma-2-9b-it",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/Phi-3-medium-4k-instruct",
    "microsoft/Phi-3.5-mini-instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct"
]

dataset_names = [
    "web_text",
    "web_text_caesar",
    "imdb",
    "random_strings",
    "book_translations_en",
    "book_translations_de",
    "common_words"
]

## Scratch space

In [3]:
model_names

['openai-community/gpt2',
 'google/gemma-2b',
 'google/gemma-7b',
 'google/gemma-2-2b',
 'google/gemma-2-9b',
 'google/gemma-2-9b-it',
 'google/gemma-2-27b',
 'meta-llama/Meta-Llama-3.1-8B',
 'meta-llama/Meta-Llama-3.1-8B-Instruct',
 'meta-llama/Llama-3.2-1B',
 'meta-llama/Llama-3.2-3B',
 'meta-llama/Llama-3.2-3B-Instruct',
 'meta-llama/Llama-3.2-11B-Vision',
 'mistralai/Mistral-7B-v0.3',
 'mistralai/Mistral-Nemo-Base-2407',
 'mistralai/Mixtral-8x7B-v0.1',
 'microsoft/Phi-3-mini-4k-instruct',
 'microsoft/Phi-3-medium-4k-instruct',
 'microsoft/Phi-3.5-mini-instruct',
 'tiiuae/falcon-40b',
 'tiiuae/falcon-11B',
 'meta-llama/Llama-3.1-70B',
 'meta-llama/Llama-3.1-70B-Instruct',
 'meta-llama/Llama-3.3-70B-Instruct',
 'tiiuae/falcon-mamba-7b']

In [11]:
unireps.dataset_embs(unireps.get_dataset('google/gemma-2-9b', 'web_text'), layer=None).flatten().shape[0]

315621376

In [10]:
unireps.dataset_embs(unireps.get_dataset('google/gemma-2-9b', 'web_text'), layer=None).dtype

torch.float32

In [6]:
unireps.get_dataset('google/gemma-2-9b', 'web_text')['at_max_length'].sum()

tensor(0)

In [3]:
ds1 = unireps.get_dataset('google/gemma-2-9b', 'web_text')
knn_1 = unireps.embs_knn(unireps.dataset_embs(ds1), k=10)

In [4]:
ds2 = unireps.get_dataset('meta-llama/Meta-Llama-3.1-8B', 'web_text')
knn_2 = unireps.embs_knn(unireps.dataset_embs(ds2), k=10)

In [5]:
unireps.mutual_knn(knn_1[10], knn_2[10])

0.45341795682907104

In [7]:
unireps.mutual_knn(knn_1, knn_2[0])

TypeError: expected Tensor as element 0 in argument 0, but got float

In [None]:
unireps.mutual_knn(knn_1, knn_2[0])

In [2]:
embs_1 = unireps.dataset_embs(unireps.get_dataset('google/gemma-2-9b', 'web_text'), layer=None)

In [3]:
knn = unireps.embs_knn(embs_1, k=10)

In [5]:
knn_1 = unireps.embs_knn(embs_1[0], k=10)

In [10]:
knn[0] == knn_1

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [4]:
knn.shape

torch.Size([43, 2048, 10])

In [6]:
knn_1.shape

torch.Size([2048, 10])

In [7]:
embs_1.shape

torch.Size([43, 2048, 3584])

In [14]:
knn.dim()

3