In [3]:
import pandas as pd
import torch
import os

In [4]:
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score

from tqdm import tqdm

In [33]:
def _load_token_features(path, token_idx):
    """Load a saved activation tensor and return one token position per sample.

    Args:
        path: Path to a ``.pt`` file written with ``torch.save``.
        token_idx: Index taken along the second tensor axis for every sample.

    Returns:
        A float32 NumPy array with one row per sample.
    """
    # NOTE(security): torch.load unpickles the file, so only point it at
    # trusted tensor dumps.
    # map_location='cpu' lets tensors that were saved on a GPU be loaded on a
    # CPU-only host; NumPy conversion requires CPU tensors anyway.
    activations = torch.load(path, map_location='cpu').detach().type(torch.float32)
    # assumes the tensor is laid out as (sample, token, hidden) — TODO confirm
    return activations[:, token_idx].numpy()


def get_res(name, split, tensor_dir='../LID-HallucinationDetection/output_tensors',
            n_train=25, n_test=102):
    """Fit a logistic-regression probe per layer and score it on held-out tensors.

    For every layer the probe is trained on ``{tensor_dir}/{name}_ru`` and
    evaluated on ``{tensor_dir}/{name}``. Labels are not stored with the
    tensors: the first ``n`` rows of each file are treated as class 1 and the
    following ``n`` rows as class 0.

    Args:
        name: Model sub-directory name. ``'bert'`` uses token index 0 and
            layers 1..24; any other name uses the last token and layers 1..32.
        split: Suffix of the tensor file name (``all_layer_{layer}_{split}.pt``).
        tensor_dir: Directory containing the ``{name}`` and ``{name}_ru``
            sub-directories.
        n_train: Number of samples per class in the training tensor.
        n_test: Number of samples per class in the evaluation tensor.

    Returns:
        A DataFrame with columns ``acc``, ``roc`` and ``train_acc``; row ``k``
        holds the scores for layer ``k + 1``.

    Raises:
        ValueError: If a tensor does not contain exactly ``2 * n_train``
            (respectively ``2 * n_test``) samples.

    Note:
        The printed maxima are picked over layers using the evaluation set
        itself, so they are optimistic estimates.
    """
    res = {'acc': [], 'roc': [], 'train_acc': []}
    token_idx = 0 if name == 'bert' else -1
    layer_stop = 25 if name == 'bert' else 33

    # Labels and the search grid do not depend on the layer, so build them once.
    y_train = np.array([1] * n_train + [0] * n_train)
    y_test = np.array([1] * n_test + [0] * n_test)
    # Parameters of pipelines can be set using '__' separated parameter names.
    param_grid = {"logistic__C": np.logspace(-4, 4, 4)}

    for layer in tqdm(range(1, layer_stop)):
        file_name = f"all_layer_{layer}_{split}.pt"
        X_test = _load_token_features(os.path.join(tensor_dir, name, file_name), token_idx)
        X_train = _load_token_features(os.path.join(tensor_dir, f"{name}_ru", file_name), token_idx)

        # Fail early with a readable message instead of an opaque sklearn
        # length-mismatch error deep inside fit/score.
        if len(X_train) != len(y_train) or len(X_test) != len(y_test):
            raise ValueError(
                f"layer {layer}: expected {len(y_train)} train / {len(y_test)} test samples, "
                f"got {len(X_train)} / {len(X_test)}"
            )

        # Standardise inputs, then fit the classifier. The PCA step is
        # intentionally left out of the pipeline.
        # tol=0.1 is a deliberately loose stopping tolerance chosen for speed.
        pipe = Pipeline(steps=[
            ("scaler", StandardScaler()),
            ("logistic", LogisticRegression(max_iter=10000, tol=0.1)),
        ])
        search = GridSearchCV(pipe, param_grid, n_jobs=4, cv=3)
        search.fit(X_train, y_train)

        preds = search.predict(X_test)
        # Probability of the positive class, used as the ranking score for ROC AUC.
        pos_proba = search.predict_proba(X_test)[:, 1]

        res['train_acc'].append(accuracy_score(y_train, search.predict(X_train)))
        res['acc'].append(accuracy_score(y_test, preds))
        res['roc'].append(roc_auc_score(y_test, pos_proba))

    df = pd.DataFrame(res)
    print(df['acc'].max(), df['roc'].max())
    return df

## GT

In [34]:
# Per-layer probes: trained on the 'bert_ru' tensors, scored on the 'bert' tensors, 'gt' split.
get_res('bert', 'gt')

100%|██████████| 24/24 [00:01<00:00, 17.86it/s]

0.5833333333333334 0.6221645520953479





Unnamed: 0,acc,roc,train_acc
0,0.553922,0.563149,0.8
1,0.5,0.539312,0.8
2,0.509804,0.543156,0.88
3,0.558824,0.560746,0.92
4,0.54902,0.570069,0.68
5,0.558824,0.568243,0.74
6,0.553922,0.584391,0.86
7,0.553922,0.58862,0.78
8,0.544118,0.572857,0.94
9,0.583333,0.595636,0.88


In [35]:
# Per-layer probes: trained on the 'llama_ru' tensors, scored on the 'llama' tensors, 'gt' split.
get_res('llama', 'gt')

100%|██████████| 32/32 [00:05<00:00,  6.04it/s]

0.6470588235294118 0.7135717031910802





Unnamed: 0,acc,roc,train_acc
0,0.544118,0.569973,0.82
1,0.588235,0.60371,0.8
2,0.578431,0.598039,0.86
3,0.558824,0.584679,0.8
4,0.573529,0.595348,0.86
5,0.593137,0.593522,0.82
6,0.573529,0.608612,0.96
7,0.578431,0.617647,0.82
8,0.622549,0.636005,0.84
9,0.602941,0.634468,0.86


In [36]:
# Per-layer probes: trained on the 'mistral_ru' tensors, scored on the 'mistral' tensors, 'gt' split.
get_res('mistral', 'gt')

100%|██████████| 32/32 [00:03<00:00,  8.15it/s]

0.6617647058823529 0.7181853133410226





Unnamed: 0,acc,roc,train_acc
0,0.519608,0.530181,0.82
1,0.563725,0.586697,0.84
2,0.563725,0.564302,0.8
3,0.617647,0.614091,0.84
4,0.558824,0.593041,0.86
5,0.602941,0.604479,0.88
6,0.583333,0.623799,0.86
7,0.544118,0.59554,0.86
8,0.553922,0.59679,0.9
9,0.612745,0.640715,0.92


## PRED

In [37]:
# Per-layer probes: trained on the 'bert_ru' tensors, scored on the 'bert' tensors, 'pred' split.
get_res('bert', 'pred')

100%|██████████| 24/24 [00:01<00:00, 16.80it/s]

0.6127450980392157 0.6141868512110727





Unnamed: 0,acc,roc,train_acc
0,0.553922,0.555652,0.72
1,0.544118,0.54902,0.9
2,0.509804,0.520088,0.96
3,0.5,0.541811,0.96
4,0.529412,0.569781,1.0
5,0.52451,0.562476,0.92
6,0.612745,0.604479,0.96
7,0.568627,0.591503,0.86
8,0.578431,0.586505,0.84
9,0.573529,0.592272,0.86


In [38]:
# Per-layer probes: trained on the 'llama_ru' tensors, scored on the 'llama' tensors, 'pred' split.
get_res('llama', 'pred')

100%|██████████| 32/32 [00:05<00:00,  5.65it/s]

0.5490196078431373 0.5450788158400616





Unnamed: 0,acc,roc,train_acc
0,0.509804,0.544118,0.9
1,0.529412,0.545079,0.84
2,0.54902,0.525375,0.9
3,0.534314,0.541234,0.94
4,0.480392,0.531719,0.92
5,0.5,0.519608,0.92
6,0.514706,0.528162,0.54
7,0.495098,0.535179,0.94
8,0.514706,0.535659,0.54
9,0.5,0.529892,0.98


In [39]:
# Per-layer probes: trained on the 'mistral_ru' tensors, scored on the 'mistral' tensors, 'pred' split.
get_res('mistral', 'pred')

100%|██████████| 32/32 [00:03<00:00,  8.29it/s]

0.5735294117647058 0.5509419454056133





Unnamed: 0,acc,roc,train_acc
0,0.514706,0.503845,0.92
1,0.495098,0.502307,0.96
2,0.5,0.506824,0.96
3,0.480392,0.478181,0.96
4,0.514706,0.478374,0.92
5,0.470588,0.468281,0.96
6,0.480392,0.473952,0.96
7,0.504902,0.511534,0.96
8,0.529412,0.529796,0.96
9,0.534314,0.535179,0.96


## Test

In [88]:
# Per-layer probe scores for 'bert' on the 'gt' split.
# NOTE(review): the saved output below has no train_acc column, so it presumably
# predates the current get_res definition — re-run to refresh.
get_res('bert', 'gt')

100%|██████████| 24/24 [00:13<00:00,  1.83it/s]

         acc       roc
0   0.676471  0.719627
1   0.558824  0.613706
2   0.602941  0.621011
3   0.602941  0.641676
4   0.578431  0.617551
5   0.588235  0.643887
6   0.602941  0.653691
7   0.607843  0.669070
8   0.573529  0.609765
9   0.588235  0.617935
10  0.573529  0.633218
11  0.539216  0.565215
12  0.539216  0.559592
13  0.617647  0.681276
14  0.632353  0.691657
15  0.544118  0.551182
16  0.617647  0.684352
17  0.598039  0.654460
18  0.627451  0.677143
19  0.632353  0.691176
20  0.617647  0.660996
21  0.647059  0.694637
22  0.642157  0.698385
23  0.656863  0.678489
0.6764705882352942 0.7196270665128797





In [89]:
# Per-layer probe scores for 'llama' on the 'gt' split.
# NOTE(review): the saved output below has no train_acc column, so it presumably
# predates the current get_res definition — re-run to refresh.
get_res('llama', 'gt')

100%|██████████| 32/32 [00:39<00:00,  1.23s/it]

         acc       roc
0   0.617647  0.675221
1   0.627451  0.692714
2   0.593137  0.618849
3   0.627451  0.689735
4   0.602941  0.639946
5   0.583333  0.627067
6   0.612745  0.650807
7   0.622549  0.658304
8   0.642157  0.705594
9   0.602941  0.663879
10  0.666667  0.749904
11  0.607843  0.698770
12  0.666667  0.716551
13  0.696078  0.778066
14  0.642157  0.697232
15  0.705882  0.773645
16  0.661765  0.721646
17  0.583333  0.657728
18  0.578431  0.658401
19  0.578431  0.650231
20  0.568627  0.653210
21  0.642157  0.703864
22  0.568627  0.577518
23  0.602941  0.657151
24  0.622549  0.643118
25  0.617647  0.659458
26  0.578431  0.632834
27  0.583333  0.629758
28  0.607843  0.647732
29  0.607843  0.645521
30  0.656863  0.703768
31  0.666667  0.696078
0.7058823529411765 0.7780661284121493





In [90]:
# Per-layer probe scores for 'mistral' on the 'gt' split.
# NOTE(review): the saved output below has no train_acc column, so it presumably
# predates the current get_res definition — re-run to refresh.
get_res('mistral', 'gt')

100%|██████████| 32/32 [00:40<00:00,  1.28s/it]

         acc       roc
0   0.495098  0.558247
1   0.593137  0.660419
2   0.617647  0.680219
3   0.622549  0.668877
4   0.622549  0.658208
5   0.578431  0.613706
6   0.612745  0.665129
7   0.622549  0.641003
8   0.558824  0.598568
9   0.686275  0.728950
10  0.696078  0.788927
11  0.681373  0.762111
12  0.705882  0.791522
13  0.632353  0.717609
14  0.740196  0.806901
15  0.661765  0.723376
16  0.651961  0.714341
17  0.691176  0.761822
18  0.691176  0.769319
19  0.656863  0.735198
20  0.681373  0.750577
21  0.661765  0.738082
22  0.676471  0.738658
23  0.671569  0.739523
24  0.671569  0.730777
25  0.656863  0.713956
26  0.666667  0.742791
27  0.710784  0.752691
28  0.681373  0.741830
29  0.696078  0.754421
30  0.700980  0.753941
31  0.647059  0.737889
0.7401960784313726 0.8069011918492887





In [85]:
# Per-layer probe scores for 'bert' on the 'pred' split.
# NOTE(review): the saved output below has no train_acc column, so it presumably
# predates the current get_res definition — re-run to refresh.
get_res('bert', 'pred')

100%|██████████| 24/24 [00:15<00:00,  1.58it/s]

         acc       roc
0   0.558824  0.605248
1   0.539216  0.562860
2   0.558824  0.558247
3   0.583333  0.636582
4   0.553922  0.563822
5   0.602941  0.663783
6   0.627451  0.683872
7   0.612745  0.677624
8   0.583333  0.638120
9   0.598039  0.660900
10  0.602941  0.696847
11  0.666667  0.732314
12  0.593137  0.641292
13  0.568627  0.566705
14  0.612745  0.614860
15  0.607843  0.625240
16  0.656863  0.726163
17  0.661765  0.741734
18  0.602941  0.612649
19  0.568627  0.598904
20  0.573529  0.595252
21  0.549020  0.562764
22  0.578431  0.609765
23  0.578431  0.604575
0.6666666666666666 0.7417339484813532





In [86]:
# Per-layer probe scores for 'llama' on the 'pred' split.
# NOTE(review): the saved output below has no train_acc column, so it presumably
# predates the current get_res definition — re-run to refresh.
get_res('llama', 'pred')

100%|██████████| 32/32 [00:38<00:00,  1.21s/it]

         acc       roc
0   0.568627  0.599097
1   0.602941  0.702518
2   0.598039  0.650519
3   0.666667  0.738947
4   0.656863  0.711457
5   0.549020  0.606594
6   0.612745  0.685986
7   0.563725  0.643695
8   0.573529  0.655037
9   0.627451  0.730008
10  0.539216  0.566609
11  0.544118  0.574491
12  0.549020  0.580738
13  0.549020  0.600442
14  0.583333  0.605440
15  0.588235  0.590830
16  0.588235  0.620242
17  0.598039  0.651480
18  0.573529  0.642349
19  0.588235  0.629566
20  0.637255  0.705210
21  0.593137  0.601403
22  0.622549  0.688581
23  0.637255  0.729431
24  0.651961  0.697040
25  0.588235  0.623606
26  0.671569  0.685506
27  0.583333  0.581699
28  0.583333  0.582949
29  0.573529  0.587274
30  0.656863  0.752980
31  0.681373  0.768358
0.6813725490196079 0.7683583237216455





In [87]:
# Per-layer probe scores for 'mistral' on the 'pred' split.
# NOTE(review): the saved output below has no train_acc column, so it presumably
# predates the current get_res definition — re-run to refresh.
get_res('mistral', 'pred')

100%|██████████| 32/32 [00:40<00:00,  1.26s/it]

         acc       roc
0   0.617647  0.642637
1   0.549020  0.556805
2   0.553922  0.574683
3   0.553922  0.554114
4   0.544118  0.574971
5   0.549020  0.575644
6   0.563725  0.576605
7   0.563725  0.585736
8   0.558824  0.559400
9   0.656863  0.700692
10  0.553922  0.608612
11  0.568627  0.609093
12  0.583333  0.616398
13  0.612745  0.647155
14  0.710784  0.755383
15  0.666667  0.692907
16  0.602941  0.668685
17  0.568627  0.611496
18  0.607843  0.656863
19  0.563725  0.607266
20  0.607843  0.644848
21  0.578431  0.615629
22  0.622549  0.672722
23  0.588235  0.634179
24  0.627451  0.657247
25  0.612745  0.658593
26  0.593137  0.655998
27  0.612745  0.660611
28  0.607843  0.664648
29  0.612745  0.691945
30  0.578431  0.653595
31  0.607843  0.660900
0.7107843137254902 0.7553825451749328





# Without PCA

In [20]:
# Per-layer probe scores for 'bert' on the 'gt' split.
# NOTE(review): this cell's execution count (In [20]) is lower than that of the
# get_res definition cell (In [33]), so the saved output presumably comes from
# an earlier definition — re-run to confirm.
get_res('bert', 'gt')
# res.train_acc.argmax

100%|██████████| 24/24 [00:02<00:00, 12.00it/s]

0.7401960784313726 0.8370818915801614





Unnamed: 0,acc,roc,train_acc
0,0.671569,0.731738,0.62
1,0.5,0.5,0.5
2,0.5,0.5,0.5
3,0.632353,0.703672,0.6
4,0.5,0.5,0.5
5,0.5,0.5,0.5
6,0.5,0.5,0.5
7,0.671569,0.73491,0.62
8,0.5,0.5,0.5
9,0.5,0.5,0.5


In [11]:
# Per-layer probe scores for 'llama' on the 'gt' split.
# NOTE(review): this cell's execution count (In [11]) is lower than that of the
# get_res definition cell (In [33]), so the saved output presumably comes from
# an earlier definition — re-run to confirm.
get_res('llama', 'gt')

100%|██████████| 32/32 [00:10<00:00,  3.06it/s]

0.8725490196078431 0.9351211072664359





Unnamed: 0,acc,roc,train_acc
0,0.789216,0.906478,0.58
1,0.789216,0.85025,0.6
2,0.808824,0.869185,0.6
3,0.852941,0.935121,0.64
4,0.789216,0.860054,0.6
5,0.779412,0.850827,0.64
6,0.784314,0.849193,0.62
7,0.764706,0.851499,0.68
8,0.872549,0.930027,0.58
9,0.852941,0.923106,0.64


In [26]:
# Per-layer probe scores for 'mistral' on the 'gt' split.
# NOTE(review): this cell's execution count (In [26]) is lower than that of the
# get_res definition cell (In [33]), so the saved output presumably comes from
# an earlier definition — re-run to confirm.
get_res('mistral', 'gt')

100%|██████████| 32/32 [00:07<00:00,  4.32it/s]

0.8676470588235294 0.9425221068819685





Unnamed: 0,acc,roc,train_acc
0,0.828431,0.920607,0.58
1,0.794118,0.863802,0.6
2,0.779412,0.846501,0.56
3,0.789216,0.857074,0.6
4,0.803922,0.852653,0.62
5,0.818627,0.867935,0.66
6,0.852941,0.936275,0.6
7,0.784314,0.870531,0.62
8,0.784314,0.868608,0.6
9,0.862745,0.933103,0.68


In [27]:
# Per-layer probe scores for 'bert' on the 'pred' split.
# NOTE(review): this cell's execution count (In [27]) is lower than that of the
# get_res definition cell (In [33]), so the saved output presumably comes from
# an earlier definition — re-run to confirm.
get_res('bert', 'pred')

100%|██████████| 24/24 [00:01<00:00, 15.00it/s]

0.7745098039215687 0.8686082276047674





Unnamed: 0,acc,roc,train_acc
0,0.588235,0.657247,0.6
1,0.651961,0.694637,0.62
2,0.5,0.5,0.5
3,0.5,0.5,0.5
4,0.5,0.5,0.5
5,0.5,0.5,0.5
6,0.5,0.5,0.5
7,0.70098,0.774414,0.6
8,0.77451,0.868608,0.6
9,0.5,0.5,0.5


In [12]:
# Per-layer probe scores for 'llama' on the 'pred' split.
# NOTE(review): this cell's execution count (In [12]) is lower than that of the
# get_res definition cell (In [33]), so the saved output presumably comes from
# an earlier definition — re-run to confirm.
get_res('llama', 'pred')

100%|██████████| 32/32 [00:09<00:00,  3.42it/s]

0.9166666666666666 0.9862552864282969





Unnamed: 0,acc,roc,train_acc
0,0.784314,0.905902,0.5
1,0.5,0.5,0.5
2,0.5,0.5,0.5
3,0.5,0.5,0.5
4,0.5,0.5,0.5
5,0.5,0.5,0.5
6,0.514706,0.783929,0.54
7,0.754902,0.8094,0.54
8,0.862745,0.945309,0.48
9,0.872549,0.970396,0.5


In [13]:
# Per-layer probe scores for 'mistral' on the 'pred' split.
# NOTE(review): this cell's execution count (In [13]) is lower than that of the
# get_res definition cell (In [33]), so the saved output presumably comes from
# an earlier definition — re-run to confirm.
get_res('mistral', 'pred')

100%|██████████| 32/32 [00:07<00:00,  4.12it/s]

0.9215686274509803 0.9773164167627836





Unnamed: 0,acc,roc,train_acc
0,0.784314,0.940119,0.46
1,0.862745,0.962226,0.5
2,0.784314,0.856978,0.52
3,0.911765,0.972607,0.58
4,0.921569,0.977316,0.44
5,0.75,0.830834,0.48
6,0.745098,0.819204,0.46
7,0.735294,0.813053,0.54
8,0.715686,0.796521,0.56
9,0.862745,0.94752,0.56
