/
logistic_regression.py
117 lines (97 loc) · 3.43 KB
/
logistic_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Logistic regression based probes.
Methodology for ungrouped logistic regression:
1. Given a set of activations, and whether each activation is from a true or false
statement.
2. Fit a linear probe using scikit's LogisticRegression implementation. The probe takes
in the activations and predicts the label.
Methodology for grouped logistic regression:
1. Given a set of activations, and whether each activation is from a true or false
statement.
2. Subtract the average activation of each group from each group member.
3. Fit a linear probe using scikit's LogisticRegression implementation. The probe takes
in the group-normalized activations and predicts the label.
Regularization: C=1.
"""
from dataclasses import dataclass
import numpy as np
import pandas as pd
from jaxtyping import Bool, Float, Int64
from sklearn.linear_model import LogisticRegression
from typing_extensions import override
from repeng.probes.base import BaseGroupedProbe, BaseProbe, PredictResult
@dataclass
class LrConfig:
    """Hyperparameters for fitting a logistic regression probe."""

    # Passed as `C` to sklearn's LogisticRegression (inverse regularization
    # strength); the module default of 1.0 matches the module docstring.
    c: float = 1.0
    # We go for newton-cg as we've found it to be the fastest, see
    # experiments/scratch/lr_speed.py.
    solver: str = "newton-cg"
    # Generous iteration cap so the solver can run to convergence.
    max_iter: int = 10_000
@dataclass
class LogisticRegressionProbe(BaseProbe):
    """Linear probe backed by a fitted sklearn LogisticRegression model."""

    model: LogisticRegression

    @override
    def predict(
        self,
        activations: Float[np.ndarray, "n d"],  # noqa: F722
    ) -> PredictResult:
        # decision_function gives the signed distance to the decision
        # hyperplane, which we expose directly as the probe's logits.
        scores = self.model.decision_function(activations)
        return PredictResult(logits=scores)
@dataclass
class LogisticRegressionGroupedProbe(BaseGroupedProbe, LogisticRegressionProbe):
    """Logistic regression probe that mean-centers activations within each
    group before scoring, mirroring the normalization used at training time."""

    model: LogisticRegression

    @override
    def predict_grouped(
        self,
        activations: Float[np.ndarray, "n d"],  # noqa: F722
        pairs: Int64[np.ndarray, "n"],  # noqa: F821
    ) -> PredictResult:
        # Apply the same per-group mean subtraction as train_grouped_lr_probe.
        centered = _center_pairs(activations, pairs)
        return PredictResult(logits=self.model.decision_function(centered))
def train_lr_probe(
    config: LrConfig,
    *,
    activations: Float[np.ndarray, "n d"],  # noqa: F722
    labels: Bool[np.ndarray, "n"],  # noqa: F821
) -> LogisticRegressionProbe:
    """Fit an (ungrouped) logistic regression probe.

    Trains an sklearn LogisticRegression classifier mapping each activation
    row to its boolean label, using the solver/regularization settings from
    `config`, and wraps the fitted model in a LogisticRegressionProbe.
    """
    classifier = LogisticRegression(
        fit_intercept=True,
        solver=config.solver,
        C=config.c,
        max_iter=config.max_iter,
    )
    classifier.fit(activations, labels)
    return LogisticRegressionProbe(classifier)
def train_grouped_lr_probe(
    config: LrConfig,
    *,
    activations: Float[np.ndarray, "n d"],  # noqa: F722
    groups: Int64[np.ndarray, "n"],  # noqa: F821
    labels: Bool[np.ndarray, "n"],  # noqa: F821
) -> LogisticRegressionGroupedProbe:
    """Fit a logistic regression probe on group-centered activations.

    Each row has its group's mean activation subtracted (see _center_pairs)
    before the probe is fit, so the probe discriminates within groups rather
    than between them. `groups` is a 1-D array with one group ID per row —
    its annotation was previously mislabelled as a 2-D "n d" array.
    """
    probe = train_lr_probe(
        config,
        activations=_center_pairs(activations, groups),
        labels=labels,
    )
    return LogisticRegressionGroupedProbe(model=probe.model)
# TODO: Double check this preserves order.
def _center_pairs(
activations: Float[np.ndarray, "n d"], # noqa: F722
pairs: Int64[np.ndarray, "n"], # noqa: F821
) -> Float[np.ndarray, "n d"]: # noqa: F722
df = pd.DataFrame(
{
"activations": list(activations),
"pairs": pairs,
}
)
pair_means = (
df.groupby(["pairs"])["activations"]
.apply(lambda a: np.mean(a, axis=0))
.rename("pair_mean") # type: ignore
)
df = df.join(pair_means, on="pairs")
df["activations"] = df["activations"] - df["pair_mean"]
return np.stack(df["activations"].to_list())