Skip to content

Commit

Permalink
return full text if not using tokenizer. cleanup type hinting etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joey Legere committed Jul 7, 2022
1 parent 8e2c7e4 commit 046a0d2
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions bittensor/_dataset/dataset_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import json
import os
import random
import time
from typing import Union

from torch.utils.data.dataloader import DataLoader
from torch.utils.data import Subset
import requests
import torch

from loguru import logger
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import requests
from torch.utils.data.dataloader import DataLoader

from loguru import logger
import bittensor

from .thread_queue import ThreadQueue
import time
import json

logger = logger.opt(colors=True)

Expand Down Expand Up @@ -631,7 +631,7 @@ def __len__(self):
return 0
return round( len(self.data) / self.block_size )

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Union[str, torch.tensor]:
""" Returns a block of sentences from text dataset.
Args:
Expand All @@ -642,12 +642,13 @@ def __getitem__(self, idx):
"""
start_idx = (idx * self.block_size) % len(self.data)
end_idx = start_idx + self.block_size
if self.no_tokenizer == False:
tokenized_text = torch.tensor(self.tokenizer(" ".join(self.data[start_idx:end_idx]), padding=True, truncation=True)['input_ids'], dtype=torch.long)
elif self.no_tokenizer == True:
tokenized_text = " ".join(self.data[start_idx:end_idx])
text = " ".join(self.data[start_idx:end_idx])

return tokenized_text[:self.block_size]
if self.no_tokenizer is True:
return text
else:
tokens = self.tokenizer(text, padding=True, truncation=True)["input_ids"]
return torch.tensor(tokens, dtype=torch.long)[:self.block_size]

def build_hash_table(self):
self.IPFS_fails = 0
Expand Down

0 comments on commit 046a0d2

Please sign in to comment.