In [1]:
import sys
import os

# Get the absolute path of the parent directory.
parent_dir = os.path.abspath(os.path.join(os.path.dirname("__file__"), ".."))

# Add the parent directory to the system path to be able to import modules from 'lib.'
sys.path.append(parent_dir)

In [2]:
import datasets

from IPython.display import HTML, Markdown as md
import itertools

from lib.memory import DSDM
from lib.utils import cleanup, configs, inference, learning, preprocess, utils 

import math
import matplotlib
import matplotlib.pyplot as plt
import numpy
import numpy as np
import random

import pandas as pd
import pathlib

import torch
import torchhd as thd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F 

from tqdm import tqdm
# Type checking
import typing

[nltk_data] Downloading package punkt to
[nltk_data]     /nfs/home/dfichiu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /nfs/home/dfichiu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Load Wikipedia dataset.
# TODO: Split between server and local.
#wiki_dataset = datasets.load_dataset("wikipedia", "20220301.en")['train']
wiki_dataset = datasets.load_dataset(
    "wikipedia",
    "20220301.en",
    cache_dir="/nfs/data/projects/daniela")['train']

Found cached dataset wikipedia (/nfs/data/projects/daniela/wikipedia/20220301.en/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Set device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set seed.
utils.fix_seed(41)

Using seed: 41

In [5]:
# Set DSDM hyperparameters.
address_size = 1000
ema_time_period = 5000
learning_rate_update = 0.5

temperature = 0.05

normalize = False

chunk_sizes = [5]

prune_mode = "fixed-size"
max_size_address_space = 4000

In [6]:
cleanup = cleanup.Cleanup(address_size)

In [7]:
# Initialize memory.
memory = DSDM.DSDM(
    address_size=address_size,
    ema_time_period=ema_time_period,
    learning_rate_update=learning_rate_update,
    temperature=temperature,
    normalize=normalize,
    prune_mode=prune_mode,
    max_size_address_space=max_size_address_space
) 

In [8]:
# Construct train set (texts) and inference set (sentences; in and out of train set text).
train_size = 1250
test_size = 10

# Text indeces.
train_idx = np.random.randint(0, len(wiki_dataset), size=train_size)

# Caclulate chosen text statistics.
# TODO

# Text indeces from which we extract sentences.
intest_idx = np.random.choice(train_idx, test_size)
outtest_idx = np.random.choice(np.setdiff1d(np.arange(len(wiki_dataset)), train_idx), test_size)

In [9]:
inference_sentences_in = []
inference_sentences_out = []

for idx_in, idx_out in zip(intest_idx, outtest_idx):
    # Get sentences.
    sentences_in = utils.preprocess.split_text_into_sentences(wiki_dataset[int(idx_in)]['text'])
    sentences_out = utils.preprocess.split_text_into_sentences(wiki_dataset[int(idx_out)]['text'])
    
    # Get sentence index.
    sentence_idx_in = int(
        np.random.randint(
            0,
            len(sentences_in),
            size=1
        )
    )
    sentence_idx_out = int(
        np.random.randint(
            0,
            len(sentences_out),
            size=1
        )
    )

    # Append sentence to list.
    inference_sentences_in.append(sentences_in[sentence_idx_in])
    inference_sentences_out.append(sentences_out[sentence_idx_out])

In [10]:
# Training
for i in tqdm(train_idx):
    text = wiki_dataset[int(i)]['text']
    
    # Preprocess data. 
    sentences_tokens = preprocess.preprocess_text(text)
    
    for sentence_tokens in sentences_tokens:
        # Generate atomic HVs for unknown tokens.
        learning.generate_atomic_HVs_from_tokens_and_add_them_to_cleanup(
            memory.address_size,
            cleanup,
            sentence_tokens
        )
        
        # Learning: Construct the chunks of each sentence and save them to memory.
        learning.generate_chunk_representations_and_save_them_to_memory(
            memory.address_size,
            cleanup,
            memory,
            sentence_tokens,
            chunk_sizes=chunk_sizes
        )

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [22:39<00:00,  1.09s/it]


In [11]:
# inference_sentences_in = ['Dagored', 'is an Italian', 'record labels', 'based in Firenze', 'formed', 'in 1998.'] 250, 0.05 temperature
# 'record labels' also caught by transformer attention.

In [12]:
# def score_partition(input_partition, output_partition):
#     # Note: What if a sentence contains the same word multiple times? This is why using 'set' is  bad idea!
#     set_query = set(preprocess.remove_stopwords(tokens)[0]) 
#     set_content = inference.get_most_similar_HVs(sentence_sims_df, delta_threshold=0.1)

#     set_input = set(input_partition)
#     set_output = set(output_partition)
    
#     score = len(set_input.intersection(set_output)) / len(set_input)

#     return score




# def divide_and_conquer(token_partitions: typing.List[typing.List[str]]):
#     retrieve_mode = "pooling"
    
#     for tp in token_partitions:
#         retrieved_content = inference.infer(
#             memory.address_size,
#             cleanup,
#             memory,
#             [tp],
#             retrieve_mode=retrieve_mode
#         )
#         output_tokens = inference.get_most_similar_HVs(
#             inference.get_similarities_to_atomic_set(
#                 retrieved_contents[0],
#                 cleanup,
#             ),
#             delta_threshold=0.1
#         )
#         score = score_partition(tp, output_tokens)
    

    

#     display(score)
#     if score == 1:
#         return tokens
#     else:
#         return max(score, divide_and_conquer())
    

In [13]:
# divide_and_conquer("Record labels from all over the world.")

In [14]:
inference_sentences_in = ["Dagored is an Italian record label based in Firenze, formed in 1998."]

In [15]:
retrieve_mode = "top_k"

# Get table with token similarities for each "out-of-train" sentence.
retrieved_contents = inference.infer(
    memory.address_size,
    cleanup,
    memory,
    inference_sentences_in,
    retrieve_mode=retrieve_mode,
    k=3, #TODO: What if index is out of range?
)

if retrieve_mode == "top_k":
    sims_df = pd.DataFrame(columns=['sentence', 'token', 'similarity']) 
    
    for s, addresses in zip(inference_sentences_in, retrieved_contents):
        display(s)
        for a in addresses:
            address_sims_df = inference.get_similarities_to_atomic_set(
                a, cleanup)
            display(address_sims_df)
elif retrieve_mode == "pooling":  
    sims_df = pd.DataFrame(columns=['sentence', 'token', 'similarity']) 
      
    for s, c in zip(inference_sentences_in, retrieved_contents):
        sentence_sims_df = inference.get_similarities_to_atomic_set(
            c, cleanup)
        sentence_sims_df['sentence'] = [s] * len(sentence_sims_df)
        sims_df = pd.concat([sims_df, sentence_sims_df])

    sims_df = sims_df.sort_values(['sentence', 'similarity'], ascending=False) \
                     .set_index(['sentence', 'token'])
    
    display(sims_df)
else:  # unrecognized
    pass

'Dagored is an Italian record label based in Firenze, formed in 1998.'

Unnamed: 0,token,similarity
0,italian,0.674341
1,psychiatrists,0.433697
2,psychotherapists,0.406644
3,family,0.393137
4,therapists,0.233589
5,canada,0.149025
6,counterfeit,0.143865
7,selector,0.139089
8,catherines,0.119969
9,coeducational,0.119558


Unnamed: 0,token,similarity
0,quebec,0.751293
1,record,0.508076
2,labels,0.30071
3,royal,0.271477
4,sixpence,0.121176
5,contraband,0.11625
6,dantes,0.114912
7,sul,0.114253
8,mount,0.112469
9,aflame,0.111373


Unnamed: 0,token,similarity
0,sculptors,0.616233
1,italian,0.513351
2,male,0.371664
3,people,0.305736
4,siena,0.211192
5,potency,0.136996
6,felt,0.126381
7,pentavarit,0.12098
8,ellebækken,0.1206
9,converted,0.119918


In [16]:
addresses = np.random.randint(0, len(memory.addresses), size=30)

for address in addresses:
    display(md(f"### Address {address}"))
    address_sims_df = inference.get_similarities_to_atomic_set(
            memory.addresses[address],
            cleanup,
    )
    display(address_sims_df)

### Address 2470

Unnamed: 0,token,similarity
0,faculty,0.695822
1,students,0.434468
2,studying,0.421347
3,political,0.318213
4,economics,0.185277
5,syllable,0.138559
6,simulated,0.123408
7,provinciale,0.121061
8,spicules,0.119595
9,hobbies,0.114862


### Address 2795

Unnamed: 0,token,similarity
0,nominee,0.508709
1,spirit,0.475365
2,independent,0.460352
3,award,0.41881
4,best,0.351459
5,screenplay,0.204918
6,faulkner,0.135284
7,lyrics,0.129944
8,alys,0.129615
9,1884–1961,0.128215


### Address 685

Unnamed: 0,token,similarity
0,child,0.678872
1,psychiatrists,0.487436
2,psychologists,0.399858
3,anglophone,0.347805
4,quebec,0.234401
5,spalletti,0.127006
6,bulging,0.126429
7,hospitaller,0.126358
8,biodiversity,0.124061
9,demilitarized,0.121547


### Address 3346

Unnamed: 0,token,similarity
0,fiction,0.5498
1,fantasy,0.526933
2,science,0.489247
3,magazine,0.389649
4,may–june,0.192439
5,scalar,0.129456
6,syncope,0.119824
7,viggen,0.118093
8,kamacite,0.116904
9,slacker,0.113402


### Address 3831

Unnamed: 0,token,similarity
0,hungarian,0.680034
1,political,0.440582
2,writers,0.376967
3,stock,0.300838
4,traders,0.237553
5,javier,0.127994
6,philanthropists,0.125888
7,mayan,0.125511
8,mtdna,0.125465
9,go,0.121912


### Address 3270

Unnamed: 0,token,similarity
0,1,0.586696
1,2005,0.513745
2,apr,0.387139
3,2008–15,0.365662
4,may,0.325904
5,jul,0.219358
6,polymath,0.141628
7,upmynster,0.124347
8,framing,0.121313
9,wipe,0.11569


### Address 3234

Unnamed: 0,token,similarity
0,film,0.487621
1,male,0.472559
2,composers,0.446338
3,score,0.4402
4,1973,0.244208
5,british,0.238656
6,poccioni,0.127896
7,malabar,0.124445
8,tactically,0.124346
9,torrential,0.118219


### Address 0

Unnamed: 0,token,similarity
0,species,0.466706
1,marine,0.46018
2,snail,0.452732
3,sea,0.438723
4,gastropod,0.393568
5,primal,0.132395
6,boating,0.128613
7,perdido,0.118755
8,tronada,0.118633
9,citizenship,0.117283


### Address 1998

Unnamed: 0,token,similarity
0,living,0.571209
1,married,0.491132
2,18,0.485189
3,670,0.271035
4,age,0.258695
5,couples,0.244922
6,427,0.13375
7,killmeckesvillecalvin,0.130664
8,manors,0.126362
9,sandburg,0.11895


### Address 2369

Unnamed: 0,token,similarity
0,hozier,0.54007
1,written,0.504107
2,musician,0.470016
3,island,0.3267
4,records,0.277042
5,songs,0.2297
6,eps,0.141069
7,italics,0.135594
8,sevens,0.129362
9,shweder,0.120917


### Address 3325

Unnamed: 0,token,similarity
0,male,0.589315
1,spanish,0.435625
2,essayists,0.423208
3,21stcentury,0.392693
4,poets,0.373344
5,crimes,0.13571
6,recognizes,0.129155
7,cease,0.121686
8,arithmetic,0.120719
9,common,0.118553


### Address 786

Unnamed: 0,token,similarity
0,mechanism,0.705631
1,nucleophilic,0.451603
2,begins,0.388114
3,epoxidation,0.336512
4,conjugate,0.197474
5,prevailing,0.14335
6,baden,0.1259
7,backing,0.125373
8,pomace,0.122918
9,fluctuate,0.115173


### Address 2713

Unnamed: 0,token,similarity
0,players,0.508426
1,york,0.47506
2,giants,0.448794
3,people,0.417408
4,new,0.292753
5,monicelli,0.14324
6,branching,0.131454
7,barbour,0.131371
8,reptiloid,0.123966
9,squad,0.111476


### Address 3266

Unnamed: 0,token,similarity
0,host,0.59301
1,television,0.435613
2,syndicated,0.428382
3,shortlived,0.422168
4,show,0.229103
5,named,0.222923
6,choose,0.117794
7,27445,0.113538
8,toya,0.112873
9,middleschool,0.111776


### Address 675

Unnamed: 0,token,similarity
0,states,0.703223
1,exception,0.471688
2,dissociation,0.366786
3,cyranoids,0.334663
4,zombies,0.249374
5,karens,0.123948
6,1851,0.120032
7,sammi,0.11917
8,anto,0.115053
9,deliberately,0.114292


### Address 3267

Unnamed: 0,token,similarity
0,bailey,0.503827
1,maine,0.49209
2,yarmouth,0.457993
3,partner,0.425687
4,moved,0.279763
5,elliott,0.242496
6,lune,0.142661
7,giambrone,0.135325
8,alphabet,0.134668
9,shukrani,0.119383


### Address 2830

Unnamed: 0,token,similarity
0,slovenia,0.824636
1,opera,0.390824
2,theatres,0.296809
3,ballet,0.207645
4,ljubljana,0.20111
5,branched,0.121958
6,harmless,0.119285
7,bridgets,0.119073
8,galloped,0.117147
9,10101999,0.115557


### Address 1454

Unnamed: 0,token,similarity
0,mps,0.58457
1,english,0.537195
2,1698–1700,0.415458
3,fellows,0.352062
4,royal,0.239288
5,1695–1698,0.235166
6,vibhushan,0.128531
7,invasions,0.120209
8,lysvet,0.119838
9,mochizuki,0.117342


### Address 3760

Unnamed: 0,token,similarity
0,orthogonal,0.511305
1,orbital,0.495366
2,plane,0.454587
3,must,0.408824
4,burns,0.283282
5,executed,0.200976
6,thruster,0.158623
7,moore,0.127508
8,4655,0.124494
9,prospected,0.122046


### Address 2768

Unnamed: 0,token,similarity
0,united,0.570095
1,states,0.519107
2,also,0.465178
3,see,0.312627
4,1911,0.196849
5,external,0.153279
6,list,0.142257
7,matteo,0.126998
8,doorman,0.126832
9,haigler,0.122506


### Address 2304

Unnamed: 0,token,similarity
0,life,0.480402
1,gudina,0.475206
2,tumsa,0.446538
3,foundation,0.442565
4,tumsas,0.255838
5,founded,0.228591
6,inspired,0.138882
7,synopses,0.128208
8,olavi,0.12477
9,boatbuilders,0.119905


### Address 301

Unnamed: 0,token,similarity
0,also,0.508023
1,seen,0.506958
2,frequently,0.431886
3,part,0.386819
4,could,0.307394
5,dutch,0.279236
6,tasks,0.155105
7,rajasthan,0.133391
8,sour,0.120642
9,brotherhoods,0.119186


### Address 3818

Unnamed: 0,token,similarity
0,open,0.533057
1,society,0.524926
2,institute,0.462837
3,foundations,0.356927
4,osf,0.179276
5,osi,0.165107
6,seldes,0.135247
7,duckling,0.12631
8,leakage,0.120365
9,space,0.118448


### Address 2354

Unnamed: 0,token,similarity
0,state,0.516879
1,football,0.473139
2,nc,0.446949
3,wolfpack,0.405274
4,seasons,0.383905
5,orchard,0.133194
6,nonleague,0.132386
7,collura,0.129934
8,2439m,0.122878
9,bou,0.12028


### Address 2384

Unnamed: 0,token,similarity
0,kancabchén,0.65096
1,hacienda,0.549699
2,ucí,0.43431
3,de,0.31549
4,valencia,0.281239
5,corbetts,0.137645
6,compact,0.13677
7,télégraphiques,0.135048
8,1735,0.128999
9,saturniinae,0.120677


### Address 1657

Unnamed: 0,token,similarity
0,football,0.535237
1,managers,0.531344
2,bulgaria,0.418016
3,fc,0.368267
4,ska,0.275398
5,expatriate,0.248053
6,ibdu,0.131326
7,praetorship,0.124149
8,inability,0.121465
9,extracted,0.119774


### Address 677

Unnamed: 0,token,similarity
0,thought,0.561017
1,see,0.526374
2,philosophy,0.432041
3,section,0.36522
4,slow,0.268425
5,manifesto,0.258478
6,lifted,0.143235
7,barlows,0.130673
8,cymru,0.126098
9,ned,0.124527


### Address 683

Unnamed: 0,token,similarity
0,canadian,0.632689
1,descent,0.507022
2,people,0.486878
3,croatian,0.331495
4,players,0.200638
5,jesse,0.146846
6,198687,0.139483
7,1902–1986,0.120963
8,624626,0.11628
9,mixing,0.115327


### Address 1139

Unnamed: 0,token,similarity
0,caracas,0.68683
1,la,0.44721
2,castellana,0.406218
3,references,0.290513
4,external,0.253783
5,raull,0.12824
6,360°,0.123763
7,yalta,0.123537
8,racing,0.123232
9,wilmington,0.122914


### Address 1857

Unnamed: 0,token,similarity
0,cyclists,0.780678
1,france,0.429121
2,1968,0.339541
3,summer,0.302898
4,olympics,0.20178
5,olympic,0.135561
6,instead,0.135305
7,sopranos,0.13192
8,idaho,0.124753
9,overweight,0.12124


In [17]:
memory.n_updates / (memory.n_updates + memory.n_expansions)

0.5868752165286133

In [18]:
memory.n_updates

149071

In [19]:
memory.n_expansions

104937

In [20]:
len(memory.addresses)

4000

In [23]:
memory.n_deletions

1

In [22]:
inference_sentences_in

['Dagored is an Italian record label based in Firenze, formed in 1998.']