In [17]:
import torch
import pickle

# Load sharding.pkl
with open('sharding.pkl', 'rb') as f:
    interval_dict = pickle.load(f)

In [33]:
import pickle
import io

class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)

In [34]:
models = {}
interval = 25

for k in interval_dict.keys():
    start = k
    end = start + interval
    
    if torch.cuda.is_available():
        shard = CPU_Unpickler(open(f'models_{start}_{end}.pkl', 'rb')).load()
    else:
        shard = pickle.load(open(f'models_{start}_{end}.pkl', 'rb'))
    
    for key in shard.keys():
        if shard[key] is not None:
            models[key] = shard[key].to('cpu')
            
# Save the models dictionary to a file
with open('all_models.pkl', 'wb') as f:
    pickle.dump(models, f)

In [22]:
# Set the seed
seed = 2
torch.manual_seed(seed)
batch_size = 64

# First thing to do is load the data
_, data = pickle.load(open("./glove/kmeans_clusters_500.pkl", "rb"))

# Checks if GPU(s) are available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

X, labels = None, None
for k in data.keys():
    if X is None:
        X = torch.Tensor(data[k])
        labels = torch.reshape(torch.Tensor([k] * len(data[k])), (-1, 1))
    else:
        X = torch.cat([X, torch.Tensor(data[k])])
        temp = torch.reshape(torch.Tensor([k] * len(data[k])), (-1, 1))
        labels = torch.cat([labels, temp])

# Construct the set of unique labels
targets = [x.item() for x in labels]
unique_labels = list(set(targets))



In [13]:
X.requires_grad = True
explanations = {}

# k = unique_labels[0]
for k in unique_labels:
    print("Explaining model for class", k)

    # Move both the model and data onto `device`
    model = models[k].to(device)
    X = X.to(device)

    predictions = model.forward(X, explain=True, rule="alpha2beta1")
    predictions = predictions.sum()
    predictions.backward()
    
    print(predictions)
    print(X.grad)

    explanation = X.grad
    explanations[k] = explanation

# Save the explanations
pickle.dump(explanations, open("explanations.pkl", "wb"))

Explaining model for class 0.0
tensor(338.2522, grad_fn=<SumBackward0>)
tensor([[-3.9901e-05,  4.2623e-04,  1.0068e-03,  ...,  7.7195e-06,
          1.7169e-05,  6.6941e-04],
        [ 1.3709e-03, -2.5898e-04, -6.3287e-04,  ...,  1.5772e-03,
         -2.7916e-03,  1.4479e-04],
        [-9.3946e-04, -1.0292e-04,  2.2466e-03,  ..., -2.0440e-03,
          1.3429e-04, -2.2135e-04],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])
Explaining model for class 1.0
tensor(140.1828, grad_fn=<SumBackward0>)
tensor([[-3.9901e-05,  4.2623e-04,  1.0068e-03,  ...,  7.7195e-06,
          1.7169e-05,  6.6941e-04],
        [ 1.3709e-03, -2.5898e-04, -6.3287e-04,  ...,  1.5772e-03,
         -2.7916e-03,  1.4479e-04],
        [-9.3946e-

In [49]:
from PCA import extract_pca_features
import pickle

# Load the explanations
explanations = pickle.load(open("explanations.pkl", "rb"))

# Extract the PCA features from each pickle and save as pickles
pca_features = {}
for k in explanations.keys():
    explanations[k] = explanations[k].to('cpu')
    pca_features[k] = extract_pca_features(explanations[k].detach(), threshold=2)
    
pickle.dump(pca_features, open("pca_features.pkl", "wb"))

In [60]:
for i in range(500):
    print((explanations[0] == explanations[i]).all())

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)

In [50]:
import pickle
import numpy as np

# Load the PCA features
pca_features = pickle.load(open("pca_features.pkl", "rb"))

for k in pca_features.keys():
    print(pca_features[k])

[[-0.02954328 -0.00848463 -0.0243983  -0.00630257 -0.01534838 -0.04090177
  -0.01711915 -0.01247065 -0.00639554 -0.02737444 -0.04047353 -0.02226457
  -0.016213   -0.02480644 -0.02328082 -0.02691017 -0.01490191 -0.00853928
  -0.02283185 -0.01998735 -0.04018471 -0.01646969 -0.03040175 -0.03414013
   0.0075054   0.14314357 -0.01951948 -0.02494753 -0.03641728 -0.03200348
   0.97222221 -0.02632682 -0.03727819  0.02295776 -0.02417602 -0.03086798
  -0.01847303 -0.02980192 -0.04092811 -0.01950156 -0.01772566 -0.00361857
  -0.03483774  0.02873888 -0.03285628 -0.01529747  0.02385252 -0.06655666
  -0.02217541  0.0091568 ]
 [-0.06676131  0.31950874 -0.05210661  0.11065716 -0.00713533 -0.08031025
   0.32823278 -0.06932088 -0.08746529  0.14690493 -0.11505246 -0.08798444
   0.24969676  0.06175984  0.00475179 -0.09256548 -0.05987345 -0.06029823
   0.29601588 -0.01541445 -0.07327114 -0.08163641 -0.06691896 -0.08630629
  -0.022349    0.39842285 -0.09253141 -0.03990879  0.10822979 -0.11643464
  -0.094470

In [11]:
# load 50d glove embeddings
with open('glove/glove_50d.pkl', 'rb') as f:
    glove_50d = pickle.load(f)

print(len(glove_50d))

# reverse the mapping
glove_50d_inv = {tuple(v): k for k, v in glove_50d.items()}

62143


In [41]:
from sklearn.metrics import pairwise_distances
# from sklearn.metrics.pairwise import cosine_similarity

glove_50d_data = np.array(list(glove_50d.values()))
glove_50d_labels = np.array(list(glove_50d.keys()))

for k in pca_features.keys():
    # print the words in the cluster
    # print("Words in cluster", k)
    # for i, embedding in enumerate(data[k]):
    #     print(glove_50d_inv[tuple(embedding)])
    #     # if i == 10:
    #     #     break
    # print()

    pca_1 = pca_features[k][0]
    pca_2 = pca_features[k][1]

    print(pca_1)
    print(pca_2)
    # get the 10 closest words to pca_1
    distances = pairwise_distances([pca_1], glove_50d_data, metric='cosine')
    closest_words = np.argsort(distances)[0][:10]
    print("Closest words to pca_1 for class", k)
    # for idx in closest_words:
        # print(glove_50d_labels[idx])

    print()
    # # get the 10 closest words to pca_2
    # distances = pairwise_distances([pca_2], glove_50d_data, metric='cosine')
    # closest_words = np.argsort(distances)[0][:10]
    # print("Closest words to pca_2 for class", k)
    # for idx in closest_words:
    #     print(glove_50d_labels[idx])

    if k == 10:
        break


[-0.02953886 -0.00848377 -0.02439493 -0.00630113 -0.01534669 -0.04089515
 -0.01711577 -0.01246905 -0.00639424 -0.0273684  -0.04046675 -0.02226137
 -0.01621027 -0.02480159 -0.02327697 -0.02690564 -0.01489918 -0.00853733
 -0.02282477 -0.01998447 -0.04017761 -0.01646475 -0.03039579 -0.03413315
  0.007509    0.14311852 -0.01951193 -0.02494017 -0.03640844 -0.03199814
  0.9722324  -0.02631912 -0.03727179  0.02295949 -0.02417167 -0.03086121
 -0.01846827 -0.02979488 -0.04091777 -0.0194977  -0.01772214 -0.0036165
 -0.03483149  0.0287429  -0.032849   -0.01529212  0.02385296 -0.06654574
 -0.02217093  0.00916007]
[-0.06676633  0.31961215 -0.05209441  0.11068875 -0.00713956 -0.08031568
  0.32821947 -0.06933309 -0.08746207  0.14689994 -0.11503708 -0.08799721
  0.24961434  0.06178551  0.00474404 -0.09257188 -0.05988754 -0.06029548
  0.29592294 -0.01542421 -0.07327091 -0.08163735 -0.06692673 -0.08630092
 -0.0223344   0.39817232 -0.09250803 -0.03987113  0.10821833 -0.11644278
 -0.09443071 -0.08359293 -