In [2]:
import sys
import os

root_path = ""
for path in os.getcwd().split("\\")[:-2]:
    root_path += f"{path}/"
sys.path.insert(1, root_path)

In [3]:
import config

In [4]:
wandb_params = {
    "project_name": "project_name",
    "job_name": "job_name",
    "wandb_flag": False,
}

In [5]:
import torch

# set_up_env
env_params = {"device": torch.device("cpu")}

In [6]:
trainer_params = {
    "resume": False,
    "batch_size": 96,
    "checkpoint_path": "checkpoint/directory/smoe.pt",
    "full_eval_mode": False,
    "nb_batches_per_iter": 1000,
    "batch_split": 2,
}

In [19]:
data_params = {"data_path": "data/text8", "data_name": "text8"}

In [8]:
def _tokenize(ds, type_ds, dictionary_to_update):
    nb_tokens_in_dictionary = len(dictionary_to_update)

    # Count nb of tokens in text and update the dictionary
    for (
        i,
        line,
    ) in enumerate(ds[type_ds]["text"]):
        if i == 10:
            break
        tokens = line.split() + ["<eos>"]
        for token in tokens:
            if token not in dictionary_to_update:
                dictionary_to_update[token] = nb_tokens_in_dictionary
                nb_tokens_in_dictionary += 1

    # Assign to each token its identifier
    ids = []
    for (
        i,
        line,
    ) in enumerate(ds[type_ds]["text"]):
        if i == 10:
            break
        tokens = line.split() + ["<eos>"]
        for token in tokens:
            ids.append(dictionary_to_update[token])
    ids = torch.LongTensor(ids)
    return ids

In [9]:
from datasets import load_dataset

ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
class Corpus:
    def __init__(self, ds):
        self._dictionary = {}
        self.train = _tokenize(
            ds=ds, type_ds="train", dictionary_to_update=self._dictionary
        )
        self.valid = _tokenize(
            ds=ds, type_ds="validation", dictionary_to_update=self._dictionary
        )
        self.test = _tokenize(
            ds=ds, type_ds="test", dictionary_to_update=self._dictionary
        )

    @property
    def vocab_size(self):
        return len(self._dictionary)

In [11]:
corpus_path = os.path.join(config.ROOT_PATH, "data/raw/wikitext-103.pt")
corpus_path

'c:/Users/Admin/OneDrive - Hanoi University of Science and Technology/DANC/source_code/my_source/data/raw/wikitext-103.pt'

In [12]:
corpus = Corpus(ds=ds)

In [20]:
data_params["vocab_size"] = corpus.vocab_size

In [21]:
torch.save(corpus, corpus_path)

In [22]:
# Tạo ra khối token với hai bước
# Bước 1: là cắt bỏ các phần tử cuối của array sao cho độ dài array chia hết cho batch_size
# Bước 2: chuyển từng khối batch thành vector cột với độ dài mỗi cột là batch size


def _batchify(data_tensor, batch_size):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nb_batches = data_tensor.size(0) // batch_size
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data_tensor = data_tensor.narrow(0, 0, nb_batches * batch_size)
    # Evenly divide the data across the bsz batches.
    data_tensor = data_tensor.view(batch_size, -1).contiguous()
    return data_tensor


def _get_train_val_test_data(corpus: Corpus, batch_size: int) -> torch.Tensor:
    return [
        _batchify(corpus.train, batch_size),
        _batchify(corpus.valid, batch_size),
        _batchify(corpus.test, batch_size),
    ]

In [23]:
batch_size = trainer_params["batch_size"]
batch_size

96

In [24]:
train_data, val_data, test_data = _get_train_val_test_data(
    corpus=corpus, batch_size=batch_size
)
print(train_data)
print(val_data)
print(test_data)

tensor([[  0,   1,   2,   3,   4],
        [  1,   0,   0,   5,   6],
        [  2,   7,   8,   9,   3],
        [ 10,  11,   8,  12,  13],
        [ 14,  15,   2,  16,  17],
        [ 18,   7,  19,  13,  20],
        [ 21,  22,  23,   2,   3],
        [  4,  24,  25,  13,  26],
        [ 27,  28,  29,  30,  31],
        [ 32,  33,  34,  35,  36],
        [ 37,  38,  39,  17,  40],
        [ 41,  15,  42,  43,  44],
        [ 45,  43,  25,  13,  46],
        [ 26,  17,  47,  33,  43],
        [ 17,   2,  48,  15,  49],
        [ 17,  50,  51,  16,  28],
        [ 37,  52,  30,  53,  54],
        [ 23,  55,  56,  13,  17],
        [ 57,  58,  59,  22,  17],
        [ 60,  33,  37,  61,  17],
        [ 62,  63,  62,  13,  27],
        [ 64,  65,  66,  67,  17],
        [ 68,  16,  69,  70,  17],
        [ 71,  72,  73,  74,  75],
        [ 76,  77,  78,  37,  79],
        [ 80,  81,  17,  82,  66],
        [ 62,  83,  84,  62,  15],
        [  0,  85,  33,  86,  87],
        [ 43,  88,  