In [1]:
import sys
import os

# Get the absolute path of the parent directory.
parent_dir = os.path.abspath(os.path.join(os.path.dirname("__file__"), ".."))

# Add the parent directory to the system path to be able to import modules from 'lib.'
sys.path.append(parent_dir)

In [2]:
C

from IPython.display import HTML, Markdown as md
import itertools

from lib.memory import DSDM
from lib.utils import cleanup, configs, inference, learning, preprocess, utils 

import math
import matplotlib
import matplotlib.pyplot as plt
import numpy
import numpy as np
import random

import pandas as pd
import pathlib

import torch
import torchhd as thd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F 

from tqdm import tqdm
# Type checking
import typing

[nltk_data] Downloading package punkt to
[nltk_data]     /nfs/home/dfichiu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /nfs/home/dfichiu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Load Wikipedia dataset.
# TODO: Split between server and local.
#wiki_dataset = datasets.load_dataset("wikipedia", "20220301.en")['train']
wiki_dataset = datasets.load_dataset(
    "wikipedia",
    "20220301.en",
    cache_dir="/nfs/data/projects/daniela")['train']

Found cached dataset wikipedia (/nfs/data/projects/daniela/wikipedia/20220301.en/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Set device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set seed.
utils.fix_seed(41)

Using seed: 41

In [5]:
# Set DSDM hyperparameters.
address_size = 1000
ema_time_period = 5000
learning_rate_update = 0.5

temperature = 0.05

normalize = False

chunk_sizes = [5]

prune_mode = "fixed-size"
max_size_address_space = 4000

In [6]:
cleanup = cleanup.Cleanup(address_size)

In [7]:
# Initialize memory.
memory = DSDM.DSDM(
    address_size=address_size,
    ema_time_period=ema_time_period,
    learning_rate_update=learning_rate_update,
    temperature=temperature,
    normalize=normalize,
    prune_mode=prune_mode,
    max_size_address_space=max_size_address_space
) 

In [8]:
# Construct train set (texts) and inference set (sentences; in and out of train set text).
train_size = 250
test_size = 10

# Text indeces.
train_idx = np.random.randint(0, len(wiki_dataset), size=train_size)

# Caclulate chosen text statistics.
# TODO

# Text indeces from which we extract sentences.
intest_idx = np.random.choice(train_idx, test_size)
outtest_idx = np.random.choice(np.setdiff1d(np.arange(len(wiki_dataset)), train_idx), test_size)

In [9]:
inference_sentences_in = []
inference_sentences_out = []

for idx_in, idx_out in zip(intest_idx, outtest_idx):
    # Get sentences.
    sentences_in = utils.preprocess.split_text_into_sentences(wiki_dataset[int(idx_in)]['text'])
    sentences_out = utils.preprocess.split_text_into_sentences(wiki_dataset[int(idx_out)]['text'])
    
    # Get sentence index.
    sentence_idx_in = int(
        np.random.randint(
            0,
            len(sentences_in),
            size=1
        )
    )
    sentence_idx_out = int(
        np.random.randint(
            0,
            len(sentences_out),
            size=1
        )
    )

    # Append sentence to list.
    inference_sentences_in.append(sentences_in[sentence_idx_in])
    inference_sentences_out.append(sentences_out[sentence_idx_out])

In [10]:
# Training
for i in tqdm(train_idx):
    text = wiki_dataset[int(i)]['text']
    
    # Preprocess data. 
    sentences_tokens = preprocess.preprocess_text(text)
    
    for sentence_tokens in sentences_tokens:
        # Generate atomic HVs for unknown tokens.
        learning.generate_atomic_HVs_from_tokens_and_add_them_to_cleanup(
            memory.address_size,
            cleanup,
            sentence_tokens
        )
        
        # Learning: Construct the chunks of each sentence and save them to memory.
        learning.generate_chunk_representations_and_save_them_to_memory(
            memory.address_size,
            cleanup,
            memory,
            sentence_tokens,
            chunk_sizes=chunk_sizes
        )

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [01:32<00:00,  2.69it/s]


In [11]:
# inference_sentences_in = ['Dagored', 'is an Italian', 'record labels', 'based in Firenze', 'formed', 'in 1998.'] 250, 0.05 temperature
# 'record labels' also caught by transformer attention.

In [12]:
# def score_partition(input_partition, output_partition):
#     # Note: What if a sentence contains the same word multiple times? This is why using 'set' is  bad idea!
#     set_query = set(preprocess.remove_stopwords(tokens)[0]) 
#     set_content = inference.get_most_similar_HVs(sentence_sims_df, delta_threshold=0.1)

#     set_input = set(input_partition)
#     set_output = set(output_partition)
    
#     score = len(set_input.intersection(set_output)) / len(set_input)

#     return score




# def divide_and_conquer(token_partitions: typing.List[typing.List[str]]):
#     retrieve_mode = "pooling"
    
#     for tp in token_partitions:
#         retrieved_content = inference.infer(
#             memory.address_size,
#             cleanup,
#             memory,
#             [tp],
#             retrieve_mode=retrieve_mode
#         )
#         output_tokens = inference.get_most_similar_HVs(
#             inference.get_similarities_to_atomic_set(
#                 retrieved_contents[0],
#                 cleanup,
#             ),
#             delta_threshold=0.1
#         )
#         score = score_partition(tp, output_tokens)
    

    

#     display(score)
#     if score == 1:
#         return tokens
#     else:
#         return max(score, divide_and_conquer())
    

In [13]:
# divide_and_conquer("Record labels from all over the world.")

In [14]:
inference_sentences_in = ["Dagored is an Italian record label based in Firenze, formed in 1998."]

In [24]:
retrieve_mode = "top_k"

# Get table with token similarities for each "out-of-train" sentence.
retrieved_contents = inference.infer(
    memory.address_size,
    cleanup,
    memory,
    inference_sentences_in,
    retrieve_mode=retrieve_mode,
    k=3, #TODO: What if index is out of range?
)

if retrieve_mode == "top_k":
    sims_df = pd.DataFrame(columns=['sentence', 'token', 'similarity']) 
    
    for s, addresses in zip(inference_sentences_in, retrieved_contents):
        display(s)
        for a in addresses:
            address_sims_df = inference.get_similarities_to_atomic_set(
                a, cleanup)
            display(address_sims_df)
elif retrieve_mode == "pooling":  
    sims_df = pd.DataFrame(columns=['sentence', 'token', 'similarity']) 
      
    for s, c in zip(inference_sentences_in, retrieved_contents):
        sentence_sims_df = inference.get_similarities_to_atomic_set(
            c, cleanup)
        sentence_sims_df['sentence'] = [s] * len(sentence_sims_df)
        sims_df = pd.concat([sims_df, sentence_sims_df])

    sims_df = sims_df.sort_values(['sentence', 'similarity'], ascending=False) \
                     .set_index(['sentence', 'token'])
    
    display(sims_df)
else:  # unrecognized
    pass

'Dagored is an Italian record label based in Firenze, formed in 1998.'

Unnamed: 0,token,similarity
0,1998,0.635719
1,sindurer,0.499019
2,adhikar,0.423214
3,pabitra,0.347793
4,papi,0.223909
5,badhu,0.13353
6,meeting,0.120554
7,pacified,0.120221
8,mierow,0.116307
9,ridership,0.115374


Unnamed: 0,token,similarity
0,italian,0.670658
1,psychiatrists,0.435922
2,psychotherapists,0.408733
3,family,0.394052
4,therapists,0.234842
5,canada,0.149773
6,catherines,0.119863
7,deduct,0.118975
8,1613,0.112249
9,disperse,0.111707


Unnamed: 0,token,similarity
0,organizations,0.50447
1,1998,0.496752
2,based,0.474001
3,providence,0.394798
4,completed,0.267846
5,rhode,0.256236
6,buildings,0.123809
7,labriolle,0.120112
8,hemlock,0.119111
9,video,0.11553


In [16]:
addresses = np.random.randint(0, len(memory.addresses), size=30)

for address in addresses:
    display(md(f"### Address {address}"))
    address_sims_df = inference.get_similarities_to_atomic_set(
            memory.addresses[address],
            cleanup,
    )
    display(address_sims_df)

### Address 2466

Unnamed: 0,token,similarity
0,university,0.590068
1,fasa,0.529645
2,sciences,0.39771
3,medical,0.351519
4,islamic,0.271129
5,azad,0.136349
6,improved,0.128995
7,bsf,0.111575
8,moonocean,0.109594
9,higher,0.106022


### Address 255

Unnamed: 0,token,similarity
0,process,0.788678
1,casting,0.490115
2,difficult,0.282021
3,auditioning,0.226725
4,producers,0.183491
5,shizao,0.134281
6,waverley,0.126205
7,barring,0.125122
8,jacobsen,0.12473
9,akb48s,0.108952


### Address 2935

Unnamed: 0,token,similarity
0,h,0.529361
1,fame,0.51185
2,inductee,0.476655
3,allen,0.412598
4,hall,0.264647
5,jerkens,0.236938
6,hand,0.124224
7,0304396,0.121562
8,hk,0.121125
9,soninlaw,0.114231


### Address 30

Unnamed: 0,token,similarity
0,people,0.778836
1,living,0.518935
2,births,0.283155
3,styriato,0.143169
4,corpse,0.10887
5,copenhagen,0.107863
6,walker,0.106828
7,nunci,0.104995
8,tottenham,0.10383
9,楚武穆王,0.102841


### Address 949

Unnamed: 0,token,similarity
0,3,0.509248
1,march,0.497078
2,2006,0.463739
3,westminster,0.36979
4,held,0.266034
5,abbey,0.239924
6,marvel,0.115737
7,jumpstyle,0.11031
8,scheduled,0.109382
9,flooding,0.106326


### Address 1991

Unnamed: 0,token,similarity
0,housing,0.51553
1,developments,0.505885
2,received,0.455971
3,considerable,0.362386
4,scale,0.277966
5,attention,0.243432
6,larger,0.163831
7,presidential,0.110435
8,mick,0.110415
9,amitabh,0.106246


### Address 2625

Unnamed: 0,token,similarity
0,saltzmans,0.505584
1,woodfall,0.501441
2,film,0.482124
3,productions,0.374324
4,city,0.278914
5,harry,0.224766
6,shubha,0.129999
7,feel,0.125116
8,22nd,0.11883
9,proportion,0.1134


### Address 517

Unnamed: 0,token,similarity
0,glb,0.494789
1,gltf,0.492303
2,ply,0.486397
3,3mf,0.391226
4,highquality,0.259583
5,obj,0.211145
6,hear,0.133194
7,xuanzhou,0.116486
8,immunotherapy,0.115065
9,shiyi,0.106186


### Address 1737

Unnamed: 0,token,similarity
0,–,0.62758
1,1941,0.472513
2,fleet,0.372856
3,alsab,0.365468
4,1940,0.298644
5,count,0.167367
6,hufford,0.119258
7,ranking,0.113122
8,immediate,0.112081
9,motor,0.109926


### Address 1775

Unnamed: 0,token,similarity
0,groups,0.520482
1,utilized,0.49494
2,parts,0.460224
3,plant,0.418673
4,american,0.273088
5,number,0.229436
6,recognition,0.137801
7,translating,0.115623
8,saudi,0.115119
9,constituency,0.114001


### Address 3420

Unnamed: 0,token,similarity
0,systems,0.594934
1,media,0.481561
2,firerescue,0.430994
3,administration,0.320938
4,digital,0.256527
5,geographic,0.219909
6,ambulatory,0.129032
7,starts,0.112467
8,recognising,0.111133
9,lanthanidebinol,0.111009


### Address 1108

Unnamed: 0,token,similarity
0,bokn,0.576826
1,mayor,0.461129
2,municipality,0.43162
3,council,0.334123
4,served,0.298377
5,1925,0.277577
6,born,0.138724
7,carthaginian,0.135001
8,certification,0.131879
9,rome,0.117021


### Address 2148

Unnamed: 0,token,similarity
0,sister,0.545884
1,chinese,0.489906
2,group,0.481973
3,akb48,0.335481
4,former,0.23889
5,since,0.201571
6,snh48,0.117119
7,yielding,0.113668
8,min,0.112294
9,nashist,0.11161


### Address 216

Unnamed: 0,token,similarity
0,uranium,0.64955
1,exploration,0.428646
2,mining,0.407744
3,refining,0.347312
4,enrichment,0.230238
5,proceedings,0.11621
6,band—the,0.115257
7,want,0.113382
8,occurs,0.112457
9,ridge,0.111716


### Address 1634

Unnamed: 0,token,similarity
0,dutch,0.518396
1,nance,0.479148
2,coolen,0.47526
3,rapper,0.385933
4,ricardo,0.237821
5,nancy,0.233103
6,singer,0.146661
7,train,0.123981
8,drowned,0.122266
9,planar,0.121645


### Address 3669

Unnamed: 0,token,similarity
0,final,0.650317
1,allireland,0.452489
2,45th,0.433817
3,deciding,0.347494
4,match,0.178877
5,intermittent,0.123098
6,481,0.122309
7,championship,0.120996
8,criminology,0.117965
9,belle,0.112287


### Address 3738

Unnamed: 0,token,similarity
0,des,0.556985
1,fortifications,0.492129
2,du,0.43159
3,moyen,0.384815
4,âge,0.264692
5,et,0.220117
6,châteaux,0.130039
7,heywood,0.112256
8,mayer,0.106084
9,wickettaker,0.105485


### Address 853

Unnamed: 0,token,similarity
0,actors,0.526622
1,initially,0.515272
2,first,0.480382
3,considered,0.359521
4,role,0.256289
5,one,0.212021
6,barker,0.127574
7,yearbook,0.124601
8,indecent,0.116477
9,jouissance,0.116385


### Address 699

Unnamed: 0,token,similarity
0,wileycredited,0.499985
1,submitted,0.488424
2,pseudonym,0.474553
3,sketch,0.343978
4,ventriloquist,0.265351
5,material,0.260286
6,según,0.138985
7,tall,0.127038
8,bb,0.12256
9,chapman,0.12157


### Address 2460

Unnamed: 0,token,similarity
0,university,0.614173
1,shiraz,0.446631
2,sciences,0.436749
3,technology,0.320421
4,islamic,0.210505
5,medical,0.180481
6,boccato,0.112069
7,airports,0.110112
8,honoured,0.108134
9,interros,0.106424


### Address 2535

Unnamed: 0,token,similarity
0,report,0.522177
1,12,0.521694
2,herald,0.50243
3,august,0.352175
4,1880,0.254541
5,morning,0.232108
6,ricans,0.14004
7,sydney,0.121378
8,entire,0.121323
9,laino,0.119833


### Address 3816

Unnamed: 0,token,similarity
0,model,0.492682
1,fits,0.479755
2,regression,0.461284
3,data,0.402848
4,well,0.254615
5,proposed,0.232073
6,auger,0.114998
7,competing,0.112911
8,green,0.112593
9,byhippocrates,0.107066


### Address 1031

Unnamed: 0,token,similarity
0,wap1,0.531632
1,asansol,0.482486
2,wag2,0.480659
3,locosin,0.33806
4,receiving,0.282697
5,nov,0.220769
6,wilma,0.129125
7,started,0.1252
8,transition,0.117666
9,macedonians,0.113764


### Address 1923

Unnamed: 0,token,similarity
0,performance,0.490947
1,cressida,0.475813
2,biggest,0.473417
3,theatrical,0.429695
4,troilus,0.305376
5,recorded,0.217185
6,subsequently,0.116292
7,chas,0.109316
8,madeleine,0.106692
9,aggression,0.106078


### Address 2370

Unnamed: 0,token,similarity
0,formation,0.533432
1,associated,0.517414
2,large,0.477483
3,amounts,0.339458
4,highquality,0.280287
5,chunk,0.275223
6,cinéma,0.128173
7,negate,0.121006
8,doom,0.113377
9,fingerstyle,0.111915


### Address 3618

Unnamed: 0,token,similarity
0,members,0.672278
1,parliament,0.438658
2,victoria,0.431457
3,victorian,0.320891
4,legislative,0.223578
5,party,0.13069
6,shizhou,0.127247
7,keyboardist,0.123186
8,verdes,0.119969
9,villa,0.119946


### Address 3745

Unnamed: 0,token,similarity
0,beth,0.530983
1,strong,0.51794
2,darwin,0.429782
3,sis,0.358427
4,matthews,0.302715
5,margaret,0.243224
6,burns,0.145337
7,okay,0.129893
8,fencer,0.113251
9,verbal,0.108698


### Address 2389

Unnamed: 0,token,similarity
0,8,0.528323
1,slope,0.493277
2,patches,0.466685
3,15,0.329497
4,hills,0.264017
5,percent,0.229347
6,typical,0.131368
7,favorites,0.128667
8,proceeding,0.126293
9,monuments,0.125777


### Address 1576

Unnamed: 0,token,similarity
0,later,0.521764
1,purchased,0.488822
2,1994,0.476134
3,numerous,0.398644
4,collection,0.280024
5,donations,0.271636
6,tyler,0.14787
7,viscous,0.12363
8,francesc,0.114293
9,fantasy,0.112593


### Address 2637

Unnamed: 0,token,similarity
0,project,0.524455
1,elmo,0.513646
2,tio,0.471745
3,roper,0.36254
4,survey,0.273737
5,first,0.264457
6,approaches,0.136051
7,ambient,0.12928
8,4821,0.120189
9,widow,0.117719


In [17]:
memory.n_updates / (memory.n_updates + memory.n_expansions)

0.5672733084283916

In [29]:
memory.n_updates

28396

In [28]:
memory.n_expansions

21661

In [27]:
len(memory.addresses)

4000

In [26]:
memory.n_deletions

8000

In [22]:
inference_sentences_in

['Dagored is an Italian record label based in Firenze, formed in 1998.']