Skip to content

Commit

Permalink
Add some docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
iamgroot42 committed Feb 24, 2024
1 parent 8ed6886 commit c36e8bc
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 17 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

![MIMIR logo](assets/logo.png)

MIMIR - Python package for measuring memorization in LLMs.
MIMIR - Python package for measuring memorization in LLMs.

Documentation is available [here](https://iamgroot42.github.io/mimir.github.io).

[![Tests](https://github.com/iamgroot42/mimir/actions/workflows/test.yml/badge.svg)](https://github.com/iamgroot42/mimir/actions/workflows/test.yml)
[![Documentation](https://github.com/iamgroot42/mimir/actions/workflows/documentation.yml/badge.svg)](https://github.com/iamgroot42/mimir/actions/workflows/documentation.yml)
Expand Down Expand Up @@ -35,7 +37,7 @@ MIMIR_DATA_SOURCE: Path to data directory

## Using cached data

The data we used for our experiments is available on [Huggingface Datasets](https://huggingface.co/datasets/iamgroot42/mimir). You can either choose to either load the data directly from Huggingface with the `load_from_hf` flag in the config (preferred), or download the `cache_100_200_....` folders into your `MIMIR_CACHE_PATH` directory.
The data we used for our experiments is available on [Hugging Face Datasets](https://huggingface.co/datasets/iamgroot42/mimir). You can either choose to either load the data directly from Hugging Face with the `load_from_hf` flag in the config (preferred), or download the `cache_100_200_....` folders into your `MIMIR_CACHE_PATH` directory.

## MIA experiments how to run

Expand Down
36 changes: 28 additions & 8 deletions mimir/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
import transformers
import time
from collections import defaultdict
from tqdm import tqdm
from multiprocessing.pool import ThreadPool
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import zlib
from hf_olmo import *

from mimir.config import ExperimentConfig
Expand Down Expand Up @@ -70,9 +68,21 @@ def unload(self):
print(f'DONE ({time.time() - start:.2f}s)')

@torch.no_grad()
def get_probabilities(self, text: str, tokens=None):
def get_probabilities(self,
text: str,
tokens: np.ndarray = None):
"""
Get the probabilities or log-softmaxed logits for a text under the current model
Get the probabilities or log-softmaxed logits for a text under the current model.
Args:
text (str): The input text for which to calculate probabilities.
tokens (numpy.ndarray, optional): An optional array of token ids. If provided, these tokens
are used instead of tokenizing the input text. Defaults to None.
Raises:
ValueError: If the device or name attributes of the instance are not set.
Returns:
list: A list of probabilities.
"""
if self.device is None or self.name is None:
raise ValueError("Please set self.device and self.name in child class")
Expand Down Expand Up @@ -105,7 +115,6 @@ def get_probabilities(self, text: str, tokens=None):
del input_ids
del target_ids


for i, token_id in enumerate(labels_processed):
if token_id != -100:
probability = probabilities[0, i, token_id].item()
Expand All @@ -116,9 +125,19 @@ def get_probabilities(self, text: str, tokens=None):
return all_prob

@torch.no_grad()
def get_ll(self, text: str, tokens=None, probs=None):
def get_ll(self,
text: str,
tokens: np.ndarray=None,
probs = None):
"""
Get the log likelihood of each text under the base_model
Get the log likelihood of each text under the base_model.
Args:
text (str): The input text for which to calculate the log likelihood.
tokens (numpy.ndarray, optional): An optional array of token ids. If provided, these tokens
are used instead of tokenizing the input text. Defaults to None.
probs (list, optional): An optional list of probabilities. If provided, these probabilities
are used instead of calling the `get_probabilities` method. Defaults to None.
"""
all_prob = probs if probs is not None else self.get_probabilities(text, tokens=tokens)
return -np.mean(all_prob)
Expand Down Expand Up @@ -176,7 +195,8 @@ def load_base_model_and_tokenizer(self, model_kwargs):
"stanford-crfm/BioMedLM", **optional_tok_kwargs, cache_dir=self.cache_dir)
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(
self.name, **optional_tok_kwargs, cache_dir=self.cache_dir)
self.name, **optional_tok_kwargs, cache_dir=self.cache_dir,
trust_remote_code=True if "olmo" in self.name.lower() else False)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

return model, tokenizer
Expand Down
29 changes: 22 additions & 7 deletions mimir/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""
Misc utils
utils.py
This module provides utility functions.
Environment Variables:
MIMIR_CACHE_PATH: The path to the cache directory. This should be set in the environment.
MIMIR_DATA_SOURCE: The data source for the MIMIR project. This should be set in the environment.
"""

import os
import random
import torch as ch
Expand All @@ -14,6 +20,9 @@
def fix_seed(seed: int = 0):
"""
Fix seed for reproducibility.
Parameters:
seed (int): The seed to set. Default is 0.
"""
ch.manual_seed(seed)
np.random.seed(seed)
Expand All @@ -22,9 +31,12 @@ def fix_seed(seed: int = 0):

def get_cache_path():
"""
Get path to cache directory.
Returns:
str: path to cache directory
Get path to cache directory.
Returns:
str: path to cache directory
Raises:
ValueError: If the MIMIR_CACHE_PATH environment variable is not set.
"""
if CACHE_PATH is None:
raise ValueError('MIMIR_CACHE_PATH environment variable not set')
Expand All @@ -33,9 +45,12 @@ def get_cache_path():

def get_data_source():
"""
Get path to data source directory.
Returns:
str: path to data source directory
Get path to data source directory.
Returns:
str: path to data source directory
Raises:
ValueError: If the MIMIR_DATA_SOURCE environment variable is not set.
"""
if DATA_SOURCE is None:
raise ValueError('MIMIR_DATA_SOURCE environment variable not set')
Expand Down
5 changes: 5 additions & 0 deletions templates/logo.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<header>
<a class="homelink" rel="home" title="MIMIR Home" href="https://iamgroot42.github.io/mimir/">
<img src="https://raw.githubusercontent.com/iamgroot42/mimir/8ed6886fb6df7a72f2f0f398688f48b68c5f48b0/assets/logo.png" alt="MIMIR">
</a>
</header>

0 comments on commit c36e8bc

Please sign in to comment.