In [1]:
import pandas as pd
import numpy as np
import pickle

In [2]:
%%capture
!pip install transformers
!pip install datasets
!pip install torch

In [3]:
from transformers import BertTokenizer, BertModel
import torch

In [4]:
model_name = "bert-base-cased"
bert_model = BertModel.from_pretrained(model_name)
bert_tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
e = bert_model.eval()
z = bert_model.zero_grad()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
file_name = "preprocess_docs.xlsx"
df = pd.read_excel(file_name, index_col=0)

In [52]:
%%time
token_embd = {}
for ix in df.index:
  preprocess = eval(df.loc[ix]["preprocess"])
  doc_clean = preprocess["doc_clean"]
  matching = preprocess["matching"]
  uni_stop = max([i for i, v in enumerate(matching) if type(v[0]) == int])+1
  doc = " ".join(doc_clean[:uni_stop])
  label = df.loc[ix]["label"]

  tokens = bert_tokenizer.tokenize(doc)
  if len(tokens) > 512:
    tokens1 = tokens[:512]
    tokens2 = tokens[512:512*2]

    tokens_ids1 = bert_tokenizer.convert_tokens_to_ids(tokens1)
    tokens_ids1_tensor = torch.tensor(tokens_ids1)
    attn_mask1 = (tokens_ids1_tensor != 1).long() # [PAD] => 1

    print(ix, len(tokens_ids1))

    cont1 = bert_model(tokens_ids1_tensor.unsqueeze(0), attention_mask=attn_mask1.unsqueeze(0))

    token_embd_per_doc = []
    for i, token in enumerate(tokens1):
      embd = cont1.last_hidden_state[0][i].detach().numpy()
      token_embd_per_doc.append(embd)

    tokens_ids2 = bert_tokenizer.convert_tokens_to_ids(tokens2)
    tokens_ids2_tensor = torch.tensor(tokens_ids2)
    attn_mask2 = (tokens_ids2_tensor != 1).long() # [PAD] => 1

    print(ix, len(tokens_ids2))

    cont2 = bert_model(tokens_ids2_tensor.unsqueeze(0), attention_mask=attn_mask2.unsqueeze(0))

    for i, token in enumerate(tokens2):
      embd = cont2.last_hidden_state[0][i].detach().numpy()
      token_embd_per_doc.append(embd)

    token_embd[label] = (tokens, token_embd_per_doc)
    
  else:
    tokens_ids = bert_tokenizer.convert_tokens_to_ids(tokens)
    tokens_ids_tensor = torch.tensor(tokens_ids)
    attn_mask = (tokens_ids_tensor != 1).long() # [PAD] => 1

    print(ix, len(tokens_ids))

    cont = bert_model(tokens_ids_tensor.unsqueeze(0), attention_mask=attn_mask.unsqueeze(0))


    token_embd_per_doc = []
    for i, token in enumerate(tokens):
      embd = cont.last_hidden_state[0][i].detach().numpy()
      token_embd_per_doc.append(embd)

    token_embd[label] = (tokens, token_embd_per_doc)

0 238
1 137
2 156
3 250
4 163
5 212
6 125
7 124
8 195
9 180
10 165
11 116
12 182
13 165
14 234
15 90
16 132
17 381
18 152
19 143
20 204
21 237
22 184
23 160
24 202
25 216
26 207
27 140
28 126
29 328
30 161
31 200
32 148
33 196
34 162
35 90
36 293
37 235
38 135
39 150
40 211
41 140
42 197
43 114
44 161
45 173
46 172
47 166
48 209
49 143
50 175
51 160
52 128
53 160
54 185
55 123
56 193
57 222
58 230
59 175
60 188
61 166
62 201
63 209
64 170
65 176
66 198
67 244
68 178
69 181
70 139
71 170
72 180
73 130
74 194
75 202
76 130
77 189
78 160
79 172
80 206
81 112
82 248
83 310
84 120
85 141
86 205
87 166
88 201
89 129
90 265
91 116
92 113
93 194
94 161
95 140
96 230
97 202
98 109
99 136
100 223
101 116
102 150
103 195
104 125
105 102
106 154
107 321
108 101
109 225
110 111
111 146
112 208
113 79
114 82
115 149
116 181
117 145
118 182
119 140
120 111
121 128
122 111
123 159
124 127
125 156
126 165
127 155
128 204
129 200
130 136
131 145
132 132
133 126
134 33
135 236
136 140
137 117
138 106
139

In [53]:
pickle.dump(token_embd, open("token_embd.pickle", "wb"))

In [54]:
cluster_graph = {df.loc[ix]["label"]: df.loc[ix]["cluster"] for ix in df.index}

In [64]:
%%time
new_token_embd = {}
for label in token_embd.keys():
    tokens, token_embd_per_doc = token_embd[label]

    new_tokens = []
    new_token_embd_per_doc = []

    token_embd_i = token_embd_per_doc[0]
    token = tokens[0]
    lenght = 1
    for i in range(1, len(tokens)):
        token_i = tokens[i]
        if token_i[:2] == "##":
            lenght += 1
            token_embd_i += token_embd_per_doc[i]
            token += token_i[2:]
        else:
            token_embd_i = token_embd_i / lenght
            new_token_embd_per_doc.append(token_embd_i)
            new_tokens.append(token)

            lenght = 1
            token_embd_i = token_embd_per_doc[i]
            token = token_i
    new_token_embd[label] = (new_tokens, new_token_embd_per_doc)

CPU times: total: 78.1 ms
Wall time: 87.1 ms


In [65]:
term_matrix = {}
count_term = {}
cluster_id = 1
for label, (tokens, token_embd_per_doc) in new_token_embd.items():
    if cluster_graph[label] == cluster_id:
        for i, token in enumerate(tokens):
            if token not in term_matrix.keys():
                term_matrix[token] = np.zeros(768)
                count_term[token] = 0
            term_matrix[token] += token_embd_per_doc[i]
            count_term[token] += 1
term_matrix = {k: v / count_term[k] for k, v in term_matrix.items()}

In [68]:
terms = pd.DataFrame(term_matrix).T

In [33]:
from sklearn.cluster import KMeans

In [69]:
kmeans =  KMeans(n_clusters=6, random_state=2022)

In [70]:
kmeans.fit(terms)

KMeans(n_clusters=6, random_state=2022)

In [71]:
terms

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
student,0.167816,-0.520817,0.692566,-0.205275,0.369287,0.235023,-0.521668,0.535237,0.769653,0.163537,...,-0.207278,0.346287,-0.331832,0.261883,-0.232969,-0.492360,-0.213416,-1.025132,0.204038,0.253584
learn,0.385292,-0.763111,0.686116,-0.286982,0.180673,0.072173,-0.328540,0.614519,0.800157,0.168622,...,-0.170324,0.460309,-0.363158,0.397351,0.037847,-0.599525,-0.288070,-0.657134,0.269817,0.321025
implicit,0.197968,-0.484868,0.369245,-0.003862,0.354030,-0.085893,-0.587141,0.796825,1.197275,0.042180,...,-0.069545,0.540698,-0.429407,-0.211454,0.153676,-0.408632,0.034845,-0.626570,0.292982,0.455436
explicit,0.197058,-0.593117,0.746997,0.024799,0.151934,0.039749,-0.458751,0.434262,1.052656,-0.032281,...,-0.555597,0.553590,-0.638588,0.104476,-0.023907,-0.385238,0.104398,-0.845130,0.315902,0.062396
non,0.327843,-0.650699,0.397550,0.037959,0.082697,-0.187926,-0.602387,0.733930,0.941444,0.092892,...,-0.198557,0.267545,-0.785607,0.144077,-0.299918,-0.359611,0.088504,-0.612101,0.130824,0.399633
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
capture,0.392429,-0.548473,0.499759,-0.135191,-0.199769,-0.343124,-0.297043,0.192482,0.584212,-0.064458,...,0.046630,0.106532,-0.376058,-0.199232,-0.358138,-0.607307,-0.661267,-0.566535,0.047974,0.313740
inappropriateness,0.271552,-0.096745,0.069562,0.012946,-0.161574,-0.614043,-0.125938,0.510729,0.364671,0.064200,...,-0.211386,0.241241,-0.543218,-0.157115,-0.411815,-0.646144,-0.094903,-0.640485,-0.008809,0.149886
seriousness,0.616778,-0.003403,0.231534,0.238645,-0.123085,-0.988227,-0.315537,0.838340,0.607151,-0.266208,...,-0.309763,-0.163556,-0.284659,0.055678,-0.722849,-0.618821,-0.155228,-0.596301,0.440540,-0.062036
genuine,0.469318,-0.205466,0.460433,-0.260546,-0.135559,-0.748442,-0.265260,0.612635,0.489478,-0.348593,...,0.081604,0.364843,-0.322390,0.573405,-0.560411,-0.559085,-0.029280,-0.407410,0.661643,-0.049528


In [77]:
labels = kmeans.predict(terms)
centroids = kmeans.cluster_centers_

array([[ 0.22558941, -0.56518902,  0.47662052, ..., -0.78617984,
         0.12220512,  0.327748  ],
       [ 0.20553102, -0.45288328,  0.41846354, ..., -0.61437471,
         0.29275487,  0.38887632],
       [ 0.27462403, -0.55606979,  0.19939457, ..., -0.66931545,
        -0.00711895,  0.18437753],
       [ 0.22037362, -0.55001439,  0.45040705, ..., -0.84628363,
         0.23261873,  0.42586109],
       [ 0.25656757, -0.53324803,  0.37242352, ..., -0.7840618 ,
         0.285144  ,  0.27965373],
       [ 0.21208382, -0.52290966,  0.30927924, ..., -0.38559376,
        -0.01507883,  0.35281948]])

In [86]:
centroid = centroids[0]
dist = np.linalg.norm(centroid - embd)
dist

9.532549979170748

In [87]:
%%time
dists = {}
for term in terms.index:
    dists[term] = []
    embd = terms.loc[term].values
    for i, centroid in enumerate(centroids):
        dist = np.linalg.norm(centroid - embd)
        dists[term].append(dist)

CPU times: total: 219 ms
Wall time: 199 ms


In [89]:
df_dist = pd.DataFrame(dists).T
df_dist

Unnamed: 0,0,1,2,3,4,5
student,4.461106,5.245509,6.531579,3.803008,4.376983,6.363828
learn,4.556132,4.809060,6.977041,3.949124,4.278583,6.672489
implicit,6.331229,7.343241,7.000492,6.470834,6.698545,7.979024
explicit,5.783002,6.684466,8.166710,6.198750,6.228590,8.303774
non,4.376479,5.877950,7.075219,5.137199,5.576512,6.939883
...,...,...,...,...,...,...
capture,7.165847,6.914288,9.295683,7.328633,7.827425,8.653174
inappropriateness,7.548584,7.927481,8.994511,8.002313,8.295702,9.165384
seriousness,7.954154,8.521012,9.016651,8.596704,8.011523,9.628037
genuine,8.370756,7.436642,10.202690,8.181618,8.681155,9.368878


In [93]:
topics = {}
for i, l in enumerate(labels):
    term = terms.index[i]
    if l not in topics.keys():
        topics[l] = []
    topics[l].append(term)

In [94]:
{k: len(v) for k, v in topics.items()}

{3: 606, 0: 337, 4: 369, 2: 204, 1: 387, 5: 211}

In [103]:
for topic_id in topics.keys():
    topic_terms = topics[topic_id]
    print(topic_id)
    print(df_dist.loc[topic_terms][topic_id].sort_values()[:10].index)

3
Index(['provide', 'strategy', 'identify', 'develop', 'practice', 'promote',
       'guide', 'aspect', 'student', 'consider'],
      dtype='object')
0
Index(['enable', 'obtain', 'reflect', 'non', 'select', 'acquire', 'meaningful',
       'test', 'employ', 'examination'],
      dtype='object')
4
Index(['society', 'critical', 'knowledge', 'skill', 'implementation', 'change',
       'challenge', 'need', 'system', 'individual'],
      dtype='object')
2
Index(['normative', 'competency', 'implication', 'methodological', 'analyse',
       'qualitative', 'motivate', 'align', 'transdisciplinary', 'curricular'],
      dtype='object')
1
Index(['argue', 'include', 'examine', 'involve', 'propose', 'carry', 'suggest',
       'incorporate', 'conclude', 'build'],
      dtype='object')
5
Index(['germany', 'springer', 'academia', 'poland', 'asu', 'switzerland',
       'ict', 'spain', 'taylor', 'central'],
      dtype='object')
