In [1]:
import torch
import os
os.chdir('C:/Users/weckbecker/DualView/')

from sklearn.linear_model import LogisticRegression as LR
from random import sample as samp
from random import seed

In [2]:
# standard dataset
sample = torch.load('explanations/MNIST/std/basic_conv_std/dualview_0.001/samples_tensor')
label = torch.load('explanations/MNIST/std/basic_conv_std/dualview_0.001/labels_tensor')

# corrupted dataset
#sample = torch.load('explanations/MNIST/corrupt/basic_conv_corrupt/dualview_0.001/samples_tensor')
#label = torch.load('explanations/MNIST/corrupt/basic_conv_corrupt/dualview_0.001/labels_tensor')

In [3]:
sample.shape, label.shape

(torch.Size([60000, 100]), torch.Size([60000]))

In [4]:
train_size = int(0.8 * len(sample))
test_size = len(sample) - train_size

sample_train = sample[:train_size]
sample_test = sample[train_size:]
label_train = label[:train_size]
label_test = label[train_size:]

In [5]:
lr = LR(penalty='l2', tol=0.0001, C=1.0)

In [6]:
lr.fit(sample_train, label_train)

In [7]:
pred = lr.predict(sample_test)
print('Correct classification rate is {:.2%}.'.format(sum(pred == label_test.numpy()) / len(sample_test)))

Correct classification rate is 99.28%.


In [8]:
prob_train = lr.predict_proba(sample_train)

lev_tensor = torch.empty((10,len(sample_train)))

for c in range(10):
    V_vector = torch.Tensor(prob_train[:,c])
    inv = torch.inverse((V_vector[:,None] * sample_train).T @ sample_train)
    for i in range(len(sample_train)):
        t = sample[i]
        v = V_vector[i]
        lev = v * t.T @ inv @ t
        lev_tensor[c,i] = lev

  lev = v * t.T @ inv @ t


In [9]:
class_tensor = torch.empty((10,len(sample_train)), dtype=torch.bool)
anti_class_tensor = torch.empty((10,len(sample_train)), dtype=torch.bool)
for c in range(10):
    class_tensor[c,:] = (label_train == c)
    anti_class_tensor[c,:] = (label_train != c)

In [10]:
class_tensor.shape, lev_tensor.shape

(torch.Size([10, 48000]), torch.Size([10, 48000]))

In [11]:
for c in range(10):
    anti_class_lev = lev_tensor[c,:] * anti_class_tensor[c,:].to(torch.float)
    print(f'Class {c}')
    idx_max = torch.argmax(anti_class_lev)
    print(idx_max)
    print(torch.max(anti_class_lev))
    print(prob_train[idx_max,c])
    print()

Class 0
tensor(16376)
tensor(0.3254)
0.27935594572205313

Class 1
tensor(33612)
tensor(0.2202)
0.41310704875540394

Class 2
tensor(37147)
tensor(0.0468)
0.4197987605857325

Class 3
tensor(34439)
tensor(0.0490)
0.385537054683199

Class 4
tensor(41949)
tensor(0.0627)
0.22163206747759678

Class 5
tensor(11292)
tensor(0.0401)
0.18903373744500074

Class 6
tensor(20709)
tensor(0.1263)
0.41317759438540974

Class 7
tensor(18487)
tensor(0.0974)
0.16971788956135797

Class 8
tensor(30692)
tensor(0.0821)
0.34036858457328845

Class 9
tensor(26544)
tensor(0.0497)
0.06642422022012999



In [12]:
for c in range(10):
    class_lev = lev_tensor[c,:] * class_tensor[c,:].to(torch.float)
    print(f'Class {c}')
    idx_max = torch.argmax(class_lev)
    print(idx_max)
    print(torch.max(class_lev))
    print(prob_train[idx_max,c])
    print()

Class 0
tensor(40240)
tensor(0.4883)
0.7700803891515025

Class 1
tensor(20672)
tensor(0.4858)
0.8209428098742453

Class 2
tensor(42309)
tensor(0.2214)
0.9999763969279161

Class 3
tensor(44101)
tensor(0.2812)
0.9991536097562138

Class 4
tensor(45069)
tensor(0.9737)
0.9827588066103807

Class 5
tensor(24614)
tensor(0.2287)
0.9684137399268452

Class 6
tensor(38050)
tensor(0.3032)
0.9983060085343369

Class 7
tensor(26882)
tensor(0.6576)
0.9225826283293516

Class 8
tensor(14690)
tensor(0.2119)
0.9998806287034755

Class 9
tensor(20735)
tensor(0.5214)
0.989560144012716



## TODO Tomorrow

### Randomly select a few labels to change and see whether these changed labels will have a high leverage 

In [13]:
poison_idx = range(int(0.5*train_size))
label_train_poison = label_train.clone()
print(label_train_poison[poison_idx])
label_train_poison[poison_idx] = label_train_poison[poison_idx] + 1
print(label_train_poison[poison_idx])
print(label_train[poison_idx])

tensor([5, 0, 4,  ..., 3, 9, 4], dtype=torch.int32)
tensor([ 6,  1,  5,  ...,  4, 10,  5], dtype=torch.int32)
tensor([5, 0, 4,  ..., 3, 9, 4], dtype=torch.int32)


In [14]:
lr_poison = LR(penalty='l2', tol=0.0001, C=1.0, max_iter=1000)
lr_poison.fit(sample_train, label_train_poison)

pred = lr_poison.predict(sample_test)
print('Correct classificaiton rate is {:.2%}.'.format(sum(pred == label_test.numpy()) / len(sample_test)))

Correct classificaiton rate is 49.35%.


In [15]:
prob_train_poison = lr_poison.predict_proba(sample_train)

lev_tensor = torch.empty((10,len(sample_train)))

for c in range(10):
    V_vector = torch.Tensor(prob_train[:,c])
    inv = torch.inverse((V_vector[:,None] * sample_train).T @ sample_train)
    for i in range(len(sample_train)):
        t = sample[i]
        v = V_vector[i]
        lev = v * t.T @ inv @ t
        lev_tensor[c,i] = lev

lev_summed = torch.sum(lev_tensor, dim=0)

In [16]:
lev_summed
lev_summed_poison = lev_summed[poison_idx]

sorted, indices = torch.sort(lev_summed_poison, descending=True)
sorted, indices

(tensor([0.5354, 0.5221, 0.5030,  ..., 0.0032, 0.0031, 0.0030]),
 tensor([ 4205, 20735,  2676,  ..., 21127, 20848,  3380]))

In [20]:
lr_repair = LR(penalty='l2', tol=0.0001, C=1.0, max_iter=1000)

for i in range(25):
    repair_idx = indices[:(i+1)*2000]
    label_train_repaired = label_train_poison.clone()
    label_train_repaired[repair_idx] = label_train_repaired[repair_idx] - 1

    lr_repair.fit(sample_train, label_train_repaired)
    pred = lr_repair.predict(sample_test)
    print('Correct classificaiton rate is {:.2%}.'.format(sum(pred == label_test.numpy()) / len(sample_test)))

    repair_idx = indices[-(i+1)*2000:]
    label_train_repaired = label_train_poison.clone()
    label_train_repaired[repair_idx] = label_train_repaired[repair_idx] - 1

    lr_repair.fit(sample_train, label_train_repaired)
    pred = lr_repair.predict(sample_test)
    print('Correct classificaiton rate is {:.2%}.'.format(sum(pred == label_test.numpy()) / len(sample_test)))
    print()

Correct classificaiton rate is 54.62%.
Correct classificaiton rate is 59.89%.

Correct classificaiton rate is 61.45%.
Correct classificaiton rate is 67.50%.

Correct classificaiton rate is 67.58%.
Correct classificaiton rate is 73.58%.

Correct classificaiton rate is 73.09%.
Correct classificaiton rate is 78.57%.

Correct classificaiton rate is 78.32%.
Correct classificaiton rate is 82.31%.

Correct classificaiton rate is 82.87%.
Correct classificaiton rate is 85.62%.

Correct classificaiton rate is 86.48%.
Correct classificaiton rate is 88.29%.

Correct classificaiton rate is 89.93%.
Correct classificaiton rate is 90.90%.

Correct classificaiton rate is 93.00%.
Correct classificaiton rate is 93.11%.

Correct classificaiton rate is 95.83%.
Correct classificaiton rate is 94.85%.

Correct classificaiton rate is 97.62%.
Correct classificaiton rate is 96.64%.

Correct classificaiton rate is 99.28%.
Correct classificaiton rate is 99.28%.

Correct classificaiton rate is 99.28%.
Correct class