In [1]:
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
from llm_transparency_tool.models.transparent_llm import TransparentLlm
import llm_transparency_tool.routes.graph as lmttg

# Index 0 used wherever this notebook selects one element of a model's output,
# e.g. `model.tokens()[B0]`; it is also forwarded to the `lmttg` graph builders
# (presumably as the batch index -- confirm against `lmttg.build_full_graph`).
B0 = 0



In [2]:
import torch
from torch.amp import autocast
from typing import Dict, List, Optional, Tuple
import networkx as nx

def cached_run_inference_and_populate_state(
    stateless_model,
    sentences,
):
    """
    Run `sentences` through a copy of `stateless_model` and return that copy.

    `run` is invoked on the object returned by `stateless_model.copy()`, never
    on the object that was passed in. Despite the `cached_` prefix, nothing is
    memoised here: every call copies and runs again.

    Args:
        stateless_model: Any object exposing `copy()`; the copy must expose `run(sentences)`.
        sentences: Forwarded unchanged to `run`.

    Returns:
        The copy, after `run(sentences)` has been called on it.
    """
    populated_model = stateless_model.copy()
    populated_model.run(sentences)
    return populated_model


def get_contribution_graph(
    model: TransparentLlm,
    threshold: float,
) -> nx.Graph:
    """
    Build the full contribution graph for `model` via `lmttg.build_full_graph`.

    Args:
        model (TransparentLlm): Model handed to the graph builder (presumably
            it must already hold the state of an inference run -- confirm).
        threshold (float): Forwarded unchanged as the builder's third argument.

    Returns:
        nx.Graph: Whatever `lmttg.build_full_graph(model, B0, threshold)` returns.
    """
    full_graph = lmttg.build_full_graph(model, B0, threshold)
    return full_graph

def get_contribution_graph_contrast(
    base_model: TransparentLlm,
    contrast_model: TransparentLlm,
    threshold: float,
) -> nx.Graph:
    """Build a contribution graph from the contrast of two models.

    Thin wrapper around `lmttg.build_full_graph_with_contrast`; the module
    constant `B0` is passed as its third argument.

    Args:
        base_model (TransparentLlm): Model 1, the one to be contrasted.
        contrast_model (TransparentLlm): Model 2, the one to compare against.
        threshold (float): Threshold to keep an edge; forwarded unchanged.

    Returns:
        nx.Graph: Resulting graph.
    """
    contrast_graph = lmttg.build_full_graph_with_contrast(
        base_model, contrast_model, B0, threshold
    )
    return contrast_graph


class GraphGen():
    """Headless driver that loads a model, runs it on a sentence and extracts graphs.

    Expected call order::

        gg = GraphGen()
        gg.load_model(model_name, revision)
        gg.set_sentence(sentence)
        gg.set_contribution_threshold(threshold)
        gg.run_inference()
        graphs, edge_weights = gg.build_graph()

    A sentence containing ``" ||| "`` is treated as a pair: the text before the
    separator is the base sentence, the text after it is the contrast sentence,
    and the graph is built from both runs.
    """

    # Marker that switches `run_inference` into pair (contrast) mode.
    _PAIR_SEPARATOR = " ||| "

    _stateful_model: "Optional[TransparentLlm]" = None
    _graph: "Optional[nx.Graph]" = None
    _contribution_threshold: float = 0.0
    # True  -> the threshold is given to the full-graph builder in
    #          `run_inference` and `build_graph` extracts paths with 0.0.
    # False -> the roles are swapped.
    _renormalize_after_threshold: bool = True
    # Not read anywhere in this class.
    _normalize_before_unembedding: bool = True

    def __init__(self):
        self.dtype = torch.float16
        # Autocast is only needed when computing in reduced precision.
        self.amp_enabled = self.dtype != torch.float32
        # Initialised here so that out-of-order calls fail with a clear message
        # instead of an AttributeError.
        self.sentence = None
        self.model_key = None

    def set_sentence(self, sentence):
        """Store the sentence (or ``"base ||| contrast"`` pair) for the next `run_inference`."""
        self.sentence = sentence

    def set_contribution_threshold(self, threshold: float):
        """Store the edge-contribution threshold used by `run_inference` / `build_graph`."""
        self._contribution_threshold = threshold

    def load_model(self, model_name, revision = None):
        """Instantiate the transparent LLM wrapper for `model_name` at `revision`.

        The model is created with `self.dtype`, so it always matches the dtype
        that `run_inference` passes to autocast.
        """
        self._stateful_model = TransformerLensTransparentLlm(
            model_name=model_name,
            device="gpu",
            dtype=self.dtype,
            model_revision=revision,
        )

        self.model_key = model_name

    @property
    def stateful_model(self) -> "TransparentLlm":
        """The currently loaded model (replaced by a populated copy after `run_inference`)."""
        return self._stateful_model

    def _autocast(self):
        """Return the autocast context shared by inference and graph construction."""
        return autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype)

    def build_graph(self):
        """Extract per-token path graphs and the edge weights of the full graph.

        Returns:
            tuple: ``(graphs, edge_weights)`` where ``graphs`` is the result of
            `lmttg.build_paths_to_predictions` started from every token index
            and ``edge_weights`` maps ``(u, v)`` to the ``"weight"`` attribute
            of every edge of the full graph built by `run_inference`.

        Raises:
            RuntimeError: If `run_inference` has not produced a graph yet.
        """
        if self._graph is None:
            raise RuntimeError("No graph available; call run_inference() before build_graph().")

        threshold = self._contribution_threshold if not self._renormalize_after_threshold else 0.0

        tokens = self.stateful_model.tokens()[B0]
        n_tokens = tokens.shape[0]
        model_info = self.stateful_model.model_info()

        graphs = lmttg.build_paths_to_predictions(
            self._graph,
            model_info.n_layers,
            n_tokens,
            range(n_tokens),
            threshold,
        )

        edge_weights = {
            (u, v): weight for u, v, weight in self._graph.edges(data="weight")
        }

        return graphs, edge_weights

    def run_inference(self):
        """Run the model on the stored sentence and build the full contribution graph.

        In pair mode (sentence contains ``" ||| "``) both halves are run and the
        contrast graph is built; otherwise the plain graph is built. Afterwards
        `self._stateful_model` holds the populated (base) model and
        `self._graph` the resulting graph.

        Raises:
            RuntimeError: If no model was loaded or no sentence was set.
            ValueError: If a pair sentence does not split into exactly two parts.
            AssertionError: If base and contrast sentences differ in token count.
        """
        if self._stateful_model is None:
            raise RuntimeError("No model loaded; call load_model() before run_inference().")
        if self.sentence is None:
            raise RuntimeError("No sentence set; call set_sentence() before run_inference().")

        renormalizing_threshold = (
            self._contribution_threshold if self._renormalize_after_threshold else 0.0
        )
        # We added pair mode to contrast results of two sentences.
        contrast_state = None

        with self._autocast():
            if self._PAIR_SEPARATOR in self.sentence:
                parts = self.sentence.split(self._PAIR_SEPARATOR)
                if len(parts) != 2:
                    raise ValueError(
                        "Pair mode expects exactly one ' ||| ' separator, "
                        f"got {len(parts) - 1}."
                    )
                base_sent, contrast_sent = parts

                # self._stateful_model becomes the base model; sentences are
                # wrapped in a list exactly like the single-sentence branch.
                self._stateful_model = cached_run_inference_and_populate_state(
                    self.stateful_model, [base_sent]
                )
                contrast_state = cached_run_inference_and_populate_state(
                    self.stateful_model, [contrast_sent]
                )

                n_tokens_base = self._stateful_model.tokens()[B0].shape[0]
                n_tokens_contrast = contrast_state.tokens()[B0].shape[0]

                assert n_tokens_base == n_tokens_contrast, (
                    f"Base and contrast sentences must have the same number of tokens "
                    f"({n_tokens_base} != {n_tokens_contrast})."
                )
            else:
                self._stateful_model = cached_run_inference_and_populate_state(
                    self.stateful_model, [self.sentence]
                )

        with self._autocast():
            if contrast_state is not None:
                # The graph is anchored on the base state.
                self._graph = get_contribution_graph_contrast(
                    self._stateful_model,
                    contrast_state,
                    renormalizing_threshold,
                )
            else:
                self._graph = get_contribution_graph(
                    self.stateful_model,
                    renormalizing_threshold,
                )

In [3]:
# --- Configuration ---------------------------------------------------------
MODEL_NAME = "LLM360/Amber"
FINAL_REVISION = "main"
CONTRIBUTION_THRESHOLD = 0.02
# Intermediate checkpoints to trace: ckpt_008, ckpt_017, ..., ckpt_350.
# Use e.g. range(8, 36, 9) for a quick smoke run.
CHECKPOINT_IDS = range(8, 351, 9)

# `main` goes last; the similarity export further down uses it as the
# reference every other revision is compared against.
revisions = [f"ckpt_{n:03d}" for n in CHECKPOINT_IDS]
revisions.append(FINAL_REVISION)

# Consecutive sentences form pairs that differ in a single word.
sents = [
    "Sarah was a much better surgeon than Maria, so the harder cases always went to",
    "Sarah was a much better surgeon than Maria, so the easier cases always went to",
    "Keeping the doors closed and the windows opened kept the apartment cool , because the heat was let out by the",
    "Keeping the doors closed and the windows opened kept the apartment cool , because the heat was kept out by the",
    "In the hotel laundry room, Emma burned Mary's shirt while ironing it, so the manager scolded",
    "In the hotel laundry room, Emma burned Mary's shirt while ironing it, so the manager refunded",
    "They had to eat a lot to gain the strength they had lost and be able to work, they had too much",
    "They had to eat a lot to gain the strength they had lost and be able to work, they had too little",
]

# revision -> list (aligned with `sents`) of (graphs, edge_weights) tuples
graph_timeline = {}

for rev in revisions:
    print("Working on revision ", rev)
    gg = GraphGen()
    gg.load_model(MODEL_NAME, rev)
    # The threshold is stored on the generator, so set it once per model
    # instead of once per sentence.
    gg.set_contribution_threshold(CONTRIBUTION_THRESHOLD)

    graph_timeline[rev] = []
    for s in sents:
        gg.set_sentence(s)
        gg.run_inference()
        graphs, edge_weights = gg.build_graph()

        graph_timeline[rev].append((graphs, edge_weights))



2024-07-26 07:58:34.751 
  command:

    streamlit run /opt/conda/envs/llmtt/lib/python3.12/site-packages/ipykernel_launcher.py [ARGUMENTS]


Working on revision  ckpt_008


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_017


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_026


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_035


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_044


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_053


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_062


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_071


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_080


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_089


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_098


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_107


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_116


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_125


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_134


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_143


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_152


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_161


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_170


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_179


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_188


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_197


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_206


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_215


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_224


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_233


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_242


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_251


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_260


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_269


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_278


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_287


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_296


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_305


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_314


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_323


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_332


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_341


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  ckpt_350


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer
Working on revision  main


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model LLM360/Amber into HookedTransformer


In [4]:
import csv

def jaccard_similarity(edges1, edges2):
    """
    Jaccard similarity |A & B| / |A | B| of two edge collections.

    Duplicates within a collection are ignored because both inputs are
    converted to sets first.

    Args:
    edges1 (list of tuples): First list of edges.
    edges2 (list of tuples): Second list of edges.

    Returns:
    float: Jaccard similarity between edges1 and edges2; 0.0 when both are empty.
    """
    first, second = set(edges1), set(edges2)
    combined = first | second

    # Two empty edge sets: defined as 0.0 rather than dividing by zero.
    if len(combined) == 0:
        return 0.0

    shared = first & second
    return len(shared) / len(combined)
 
def weighted_jaccard_similarity(edges1, weights1, edges2, weights2):
    """
    Compute the weighted Jaccard similarity between two sets of edges with weights.

    The score is sum(min(w1, w2)) over the shared edges divided by
    sum(max(w1, w2)) over all edges, where an edge that is absent from an edge
    list counts with weight 0 for that side -- even if the corresponding
    weights dict happens to contain an entry for it (the dicts may describe a
    larger graph than the edge lists).

    Args:
    edges1 (list of tuples): First list of edges.
    weights1 (dict): Weights keyed by edge; must cover every shared edge.
    edges2 (list of tuples): Second list of edges.
    weights2 (dict): Weights keyed by edge; must cover every shared edge.

    Returns:
    float: Weighted Jaccard similarity between edges1 and edges2; 0.0 when
    there are no edges at all or when the total union weight is zero.

    Raises:
    KeyError: If a shared edge has no entry in one of the weight dicts.
    """
    set1 = set(edges1)
    set2 = set(edges2)

    union = set1 | set2
    if not union:
        return 0.0

    def side_weight(edge, members, weights):
        # Only edges that belong to this side's edge set may contribute.
        return weights.get(edge, 0) if edge in members else 0

    union_sum = sum(
        max(side_weight(e, set1, weights1), side_weight(e, set2, weights2))
        for e in union
    )
    # All-zero weights would otherwise raise ZeroDivisionError.
    if union_sum == 0:
        return 0.0

    intersection_sum = sum(min(weights1[e], weights2[e]) for e in set1 & set2)

    return intersection_sum / union_sum
 
def write_jaccard_similarities_to_csv(sents, revisions, graph_timeline, output_file, reference_revision='main'):
    """
    Write one CSV row per (sentence, revision) with the plain and weighted
    Jaccard similarity between that revision's graph and the reference
    revision's graph.

    For each sentence only the last graph of the per-sentence graph list is
    compared. The reference revision is skipped by name, so it may appear
    anywhere in `revisions`, or not at all.

    Args:
    sents (list of str): Sentences, aligned with the lists in graph_timeline.
    revisions (list of str): Revisions to report on.
    graph_timeline (dict): revision -> list of (graphs, edge_weights) tuples.
    output_file (str): Path of the CSV file to (over)write.
    reference_revision (str): Key in graph_timeline to compare against.

    Raises:
    KeyError: If reference_revision or a listed revision is missing from graph_timeline.
    AssertionError: If a revision and the reference have a different number of graphs.
    """
    reference_graphs = graph_timeline[reference_revision]
    # Select by name rather than assuming the reference is the last entry.
    compared_revisions = [rev for rev in revisions if rev != reference_revision]

    with open(output_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Sentence', 'Revision', 'Jaccard', 'Weighted Jaccard'])

        for i, sent in enumerate(sents):
            for rev in compared_revisions:
                rev_graphs, rev_weights = graph_timeline[rev][i]
                ref_graphs, ref_weights = reference_graphs[i]

                assert len(rev_graphs) == len(ref_graphs), (
                    f"Revision {rev!r} has {len(rev_graphs)} graphs for sentence {i}, "
                    f"reference {reference_revision!r} has {len(ref_graphs)}."
                )

                rev_edges = rev_graphs[-1].edges()
                ref_edges = ref_graphs[-1].edges()

                jaccard = jaccard_similarity(rev_edges, ref_edges)
                wj = weighted_jaccard_similarity(rev_edges, rev_weights, ref_edges, ref_weights)

                writer.writerow([sent, rev, jaccard, wj])

# Export the similarity of every checkpoint to the final revision, per sentence.
output_file = 'jaccard_similarities.csv'
write_jaccard_similarities_to_csv(
    sents=sents,
    revisions=revisions,
    graph_timeline=graph_timeline,
    output_file=output_file,
)
