In [1]:
import sys
import os

# Absolute path of the parent of the current working directory.
# A notebook has no `__file__`, so the path is resolved relative to the
# working directory (os.pardir is "..").
parent_dir = os.path.abspath(os.pardir)

# Add the parent directory to the system path to be able to import modules from 'lib.'
# The membership check keeps sys.path free of duplicates when the cell is re-run.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
import datasets 

from IPython.display import HTML, Markdown as md
import itertools

from lib.memory import DSDM
from lib.utils import cleanup, configs, inference, learning, preprocess, utils 

import math
import matplotlib
import matplotlib.pyplot as plt
import numpy
import numpy as np
import random

import pandas as pd
import pathlib

import torch
import torchhd as thd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F 

from tqdm import tqdm
# Type checking
from typing import List 

[nltk_data] Downloading package punkt to
[nltk_data]     /nfs/home/dfichiu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /nfs/home/dfichiu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Load the "20220301.en" configuration of the Wikipedia dataset and keep its
# 'train' split.
#
# Server vs. local: the cache directory defaults to the shared server path.
# Set the WIKI_CACHE_DIR environment variable to use another directory, or set
# it to an empty string to let `datasets` use its own default cache location.
wiki_cache_dir = os.environ.get("WIKI_CACHE_DIR", "/nfs/data/projects/daniela")

wiki_dataset = datasets.load_dataset(
    "wikipedia",
    "20220301.en",
    cache_dir=wiki_cache_dir or None,
)['train']

Found cached dataset wikipedia (/nfs/data/projects/daniela/wikipedia/20220301.en/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fix the seed through the project helper.
utils.fix_seed(41)

Using seed: 41

In [5]:
# DSDM hyperparameters.
# Every value below except `chunk_sizes` is passed to the DSDM constructor;
# `chunk_sizes` is handed to the chunk-learning helper in the training loop.

# Constructor arguments: sizes and pruning.
address_size = 1_000
max_size_address_space = 4_000
prune_mode = "fixed-size"

# Constructor arguments: numeric settings.
ema_time_period = 5_000
learning_rate_update = 0.5
temperature = 0.05

# Constructor argument: flag.
normalize = False

# Training-loop argument.
chunk_sizes = [5]

In [6]:
# Instantiate Cleanup with `address_size`.
# The instance stays bound to the name `cleanup` because later cells pass it
# around under that name. The module is imported under an alias so the cell
# can be re-run: once `cleanup` refers to the instance, `cleanup.Cleanup`
# would no longer resolve to the class.
from lib.utils import cleanup as cleanup_module

cleanup = cleanup_module.Cleanup(address_size)

In [7]:
# Initialize memory.
memory = DSDM.DSDM(
    address_size=address_size,
    ema_time_period=ema_time_period,
    learning_rate_update=learning_rate_update,
    temperature=temperature,
    normalize=normalize,
    prune_mode=prune_mode,
    max_size_address_space=max_size_address_space
) 

In [8]:
# Construct train set (texts) and inference set (sentences; in and out of train set text).
train_size = 500
test_size = 10

# Text indeces.
train_idx = np.random.randint(0, len(wiki_dataset), size=train_size)

# Caclulate chosen text statistics.
# TODO

# Text indeces from which we extract sentences.
intest_idx = np.random.choice(train_idx, test_size)
outtest_idx = np.random.choice(np.setdiff1d(np.arange(len(wiki_dataset)), train_idx), test_size)

In [9]:
# Pick one random sentence from every "in-train" and "out-of-train" text.
inference_sentences_in = []
inference_sentences_out = []

for idx_in, idx_out in zip(intest_idx, outtest_idx):
    # Get sentences. `int(...)` turns the NumPy index into a plain Python int
    # before it is used to index the dataset.
    sentences_in = utils.preprocess.split_text_into_sentences(wiki_dataset[int(idx_in)]['text'])
    sentences_out = utils.preprocess.split_text_into_sentences(wiki_dataset[int(idx_out)]['text'])

    # Get sentence index. Calling randint without `size` returns a scalar, so
    # no 1-element array has to be converted with int() (that conversion is
    # deprecated since NumPy 1.25).
    # NOTE: randint raises ValueError when a text yields no sentences (low >= high).
    sentence_idx_in = int(np.random.randint(0, len(sentences_in)))
    sentence_idx_out = int(np.random.randint(0, len(sentences_out)))

    # Append sentence to list.
    inference_sentences_in.append(sentences_in[sentence_idx_in])
    inference_sentences_out.append(sentences_out[sentence_idx_out])

In [10]:
# Training
for i in tqdm(train_idx):
    text = wiki_dataset[int(i)]['text']
    
    # Preprocess data. 
    sentences_tokens = preprocess.preprocess_text(text)
    
    for sentence_tokens in sentences_tokens:
        # Generate atomic HVs for unknown tokens.
        learning.generate_atomic_HVs_from_tokens_and_add_them_to_cleanup(
            memory.address_size,
            cleanup,
            sentence_tokens
        )
        
        # Learning: Construct the chunks of each sentence and save them to memory.
        learning.generate_chunk_representations_and_save_them_to_memory(
            memory.address_size,
            cleanup,
            memory,
            sentence_tokens,
            chunk_sizes=chunk_sizes
        )

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [04:55<00:00,  1.69it/s]


In [12]:
# Show the `temperature` attribute currently held by the memory object.
memory.temperature

0.05

In [13]:
# inference_sentences_in = ['Dagored', 'is an Italian', 'record labels', 'based in Firenze', 'formed', 'in 1998.'] 250, 0.05 temperature
# 'record labels' also caught by transformer attention.

In [27]:
def divide_and_conquer(sentence: str):
    """Unfinished placeholder.

    Currently only runs `preprocess.preprocess_text` on `sentence`; the
    result is not used and the function implicitly returns None.

    TODO: implement the remaining steps.
    """
    preprocess.preprocess_text(sentence)

    
    

In [29]:
retrieve_mode = "pooling"

# Retrieve memory content for each "in-train" inference sentence.
retrieved_contents = inference.infer(
    memory.address_size,
    cleanup,
    memory,
    inference_sentences_in,
    retrieve_mode=retrieve_mode,
    k=3,  # TODO: What if index is out of range?
)

# Defined up front so `sims_df` exists after the cell regardless of the mode.
sims_df = pd.DataFrame(columns=['sentence', 'token', 'similarity'])

if retrieve_mode == "top_k":
    # Each retrieved entry is iterated as a collection of addresses; show one
    # similarity table per address under its sentence.
    for s, addresses in zip(inference_sentences_in, retrieved_contents):
        display(s)
        for a in addresses:
            address_sims_df = inference.get_similarities_to_atomic_set(
                a,
                cleanup,
            )
            display(address_sims_df)
elif retrieve_mode == "pooling":
    # One retrieved content per sentence: collect the per-sentence tables and
    # concatenate them once instead of growing a frame inside the loop.
    per_sentence_frames = []
    for s, c in zip(inference_sentences_in, retrieved_contents):
        sentence_sims_df = inference.get_similarities_to_atomic_set(
            c,
            cleanup,
        )
        sentence_sims_df['sentence'] = [s] * len(sentence_sims_df)
        per_sentence_frames.append(sentence_sims_df)

    # pd.concat raises on an empty list, so keep the empty frame in that case.
    if per_sentence_frames:
        sims_df = pd.concat(per_sentence_frames)

    sims_df = sims_df.sort_values(['sentence', 'similarity'], ascending=False) \
                     .set_index(['sentence', 'token'])

    display(sims_df)
else:
    # Fail loudly instead of silently displaying nothing.
    raise ValueError(
        f"Unrecognized retrieve_mode: {retrieve_mode!r}; expected 'top_k' or 'pooling'."
    )

Unnamed: 0_level_0,Unnamed: 1_level_0,similarity
sentence,token,Unnamed: 2_level_1
products are,confectionery,0.519852
products are,australian,0.512149
products are,products,0.438105
products are,introduced,0.344476
products are,bars,0.316658
products are,1985,0.256861
products are,putting,0.151609
products are,chocolate,0.115057
products are,morality,0.113854
products are,countdown,0.11271


In [21]:
# Draw 20 random positions into `memory.addresses` (with replacement) and, for
# each one, show a heading followed by its similarity table.
addresses = np.random.randint(0, len(memory.addresses), size=20)

for address in addresses:
    heading = md(f"### Address {address}")
    display(heading)

    stored_address = memory.addresses[address]
    address_sims_df = inference.get_similarities_to_atomic_set(stored_address, cleanup)
    display(address_sims_df)

### Address 3940

Unnamed: 0,token,similarity
0,standing,0.680146
1,attention,0.490856
2,ease,0.417652
3,whenever,0.361179
4,mainwaring,0.240785
5,tokyo,0.12956
6,clubman,0.124482
7,lefthanded,0.122425
8,dibb,0.117911
9,customs,0.115278


### Address 3242

Unnamed: 0,token,similarity
0,aa,0.691961
1,bowling,0.430642
2,cross,0.361983
3,country,0.349913
4,aaa,0.182991
5,landmark,0.133749
6,headliners,0.122129
7,different,0.121807
8,exist,0.118336
9,vladislaus,0.114504


### Address 1856

Unnamed: 0,token,similarity
0,version,0.537682
1,album,0.519935
2,carphology,0.468902
3,527,0.344818
4,307,0.254989
5,like,0.248635
6,living,0.133116
7,stones,0.132481
8,michelle,0.128916
9,coast,0.121306


### Address 2518

Unnamed: 0,token,similarity
0,law,0.681556
1,graduating,0.434303
2,career,0.400845
3,school,0.367369
4,1975,0.217852
5,began,0.158093
6,safa,0.122945
7,smoother,0.121148
8,berline,0.113197
9,ricans,0.11291


### Address 2983

Unnamed: 0,token,similarity
0,1884–1886,0.557167
1,reutlingen,0.525487
2,ohmenhausen,0.523903
3,evangelical,0.322548
4,martinskirche,0.274826
5,kirchestgallus,0.175586
6,chores,0.132255
7,jamiah,0.124485
8,logic,0.124188
9,mueang,0.116524


### Address 1434

Unnamed: 0,token,similarity
0,directed,0.654266
1,agnès,0.453076
2,vardasubway,0.432646
3,luc,0.370039
4,besson,0.207205
5,loi,0.135378
6,argentinas,0.128809
7,heather,0.124346
8,high,0.117921
9,smoking,0.116866


### Address 1453

Unnamed: 0,token,similarity
0,avant,0.506231
1,directed,0.468676
2,le,0.436923
3,mariage,0.436066
4,nauerjuste,0.295066
5,jacques,0.250657
6,bernard,0.182741
7,vehicles,0.133522
8,silken,0.131294
9,ragged,0.129945


### Address 2880

Unnamed: 0,token,similarity
0,bobby,0.575261
1,organ,0.498433
2,wood,0.424626
3,electric,0.391925
4,emmons,0.240832
5,piano,0.22981
6,suite,0.12655
7,dit,0.123887
8,fun,0.121332
9,edberg,0.120899


### Address 2735

Unnamed: 0,token,similarity
0,official,0.507065
1,agency,0.502259
2,confirmed,0.476796
3,management,0.38934
4,report,0.355735
5,disaster,0.158402
6,pledges,0.138678
7,bakelite,0.129547
8,freeman,0.116565
9,cain,0.11605


### Address 1280

Unnamed: 0,token,similarity
0,local,0.491613
1,tile,0.478777
2,bands,0.440861
3,stations,0.436929
4,except,0.273875
5,wider,0.234173
6,talked,0.118266
7,try,0.115417
8,exclusively,0.113406
9,recapturing,0.110546


### Address 617

Unnamed: 0,token,similarity
0,england,0.532432
1,sent,0.490095
2,devoted,0.451208
3,wife,0.398617
4,audubon,0.307169
5,lucy,0.214404
6,daniel,0.130535
7,ryōhachi,0.120887
8,freehill,0.117137
9,geary,0.113784


### Address 1869

Unnamed: 0,token,similarity
0,engine,0.541552
1,powerpluss,0.537493
2,first,0.438895
3,indians,0.420394
4,gustafson,0.290643
5,flathead,0.256272
6,charles,0.154135
7,suited,0.136658
8,998,0.123145
9,zâtî,0.122231


### Address 813

Unnamed: 0,token,similarity
0,time,0.520701
1,study,0.520128
2,night,0.397367
3,allowed,0.370988
4,california,0.304774
5,teach,0.225397
6,park,0.142362
7,illicit,0.135775
8,jin,0.135116
9,menlo,0.126321


### Address 3605

Unnamed: 0,token,similarity
0,interchange,0.516055
1,cloverleaf,0.499563
2,road,0.48565
3,lynnhaven,0.350772
4,rosemont,0.249469
5,parkway,0.194421
6,sacrificed,0.110791
7,dewellers,0.109503
8,aureus,0.108956
9,francesco,0.108809


### Address 245

Unnamed: 0,token,similarity
0,natural,0.539011
1,sciences,0.490791
2,go,0.441658
3,cambridge,0.378246
4,read,0.283168
5,1919,0.246764
6,bartolomeo,0.120155
7,zachary,0.117181
8,n−p2,0.116961
9,lack,0.114238


### Address 3281

Unnamed: 0,token,similarity
0,cup,0.509863
1,schools,0.508682
2,twice,0.486204
3,beating,0.359708
4,medical,0.288773
5,university,0.224935
6,plowed,0.138405
7,national,0.12791
8,mediate,0.125041
9,biometric,0.114073


### Address 1625

Unnamed: 0,token,similarity
0,sv,0.535893
1,darmstadt,0.516734
2,98,0.47001
3,commerzbank,0.378257
4,team,0.259631
5,arena,0.237593
6,hesse,0.192324
7,2517,0.145439
8,pirates,0.136953
9,validity,0.126133


### Address 3745

Unnamed: 0,token,similarity
0,plane,0.740757
1,image,0.448466
2,imaged,0.40111
3,via,0.231141
4,proper,0.186185
5,supplemented,0.124011
6,fallen,0.11951
7,saare,0.117218
8,dreamqueen,0.116443
9,mydaughter,0.11508


### Address 2017

Unnamed: 0,token,similarity
0,avenue,0.58406
1,east,0.479045
2,franklin,0.451741
3,station,0.348111
4,branching,0.247168
5,constructed,0.235902
6,threaten,0.125036
7,popes,0.117704
8,antioxidant,0.114744
9,released,0.114244


### Address 1177

Unnamed: 0,token,similarity
0,buildings,0.527661
1,completed,0.517606
2,government,0.491631
3,1998,0.355587
4,organizations,0.251592
5,1977,0.234012
6,australian,0.120377
7,residents,0.115398
8,removing,0.109611
9,nearest,0.108853


In [16]:
# Ratio of the `n_updates` counter to the sum of `n_updates` and `n_expansions`.
memory.n_updates / (memory.n_updates + memory.n_expansions)

0.5916405413852871

In [17]:
# Show the memory's `n_updates` counter.
memory.n_updates

62466

In [18]:
# Show the memory's `n_expansions` counter.
memory.n_expansions

43115

In [19]:
# Number of entries currently in `memory.addresses`.
len(memory.addresses)

4000

In [20]:
# Show the memory's `n_deletions` counter.
memory.n_deletions

40000