In [1]:
import os
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from transformers import AutoTokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained('codeparrot/codeparrot')
# Register a dedicated pad token so variable-length sequences can be batched
# (collate_fn pads with PAD_IDX).
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

PAD_IDX = tokenizer.pad_token_id
# `tokenizer.vocab_size` is the size of the *base* vocabulary only and does not
# count tokens added via `add_special_tokens`, so the new [PAD] id would equal
# `vocab_size` and fall outside an embedding table of that size.
# `len(tokenizer)` includes the added tokens.
VOCAB_SIZE = len(tokenizer)
assert PAD_IDX < VOCAB_SIZE, "pad token id must index into the vocabulary"

In [58]:
class PreProcessingTransform(nn.Module):
  """Convert a raw text string into a 1-D tensor of token ids.

  Wraps a tokenizer exposing ``encode(text, return_tensors='pt')`` in an
  ``nn.Module`` so it can be passed as a dataset ``transform``.
  """

  def __init__(self, tokenizer):
    super().__init__()
    self.tokenizer = tokenizer

  def forward(self, text):
    """Tokenize ``text`` and return the ids of its single encoded sequence."""
    encoded = self.tokenizer.encode(text, return_tensors='pt')
    # `encode(..., return_tensors='pt')` yields a batch holding one sequence;
    # drop that leading batch axis.
    return encoded[0]

In [59]:
class DataPy(Dataset):
  """Map-style dataset of source files listed in a CSV index.

  The CSV must have a ``filename`` column; each entry is joined onto ``path``
  and the file is read as text. Rows are shuffled once at construction time.

  Args:
    csv_filename: Path to the CSV index file.
    path: Directory the ``filename`` entries are relative to.
    transform: Optional callable applied to each file's text; when ``None``
      the raw string is returned.
    random_state: Optional seed forwarded to ``DataFrame.sample`` so the
      construction-time shuffle is reproducible. ``None`` (default) keeps the
      unseeded shuffle.
    encoding: Text encoding used to read the files (default ``'utf-8'``).
  """

  def __init__(self, csv_filename, path, transform=None, random_state=None, encoding='utf-8'):
    super().__init__()
    self.df = (
      pd.read_csv(csv_filename)
      .sample(frac=1, random_state=random_state)
      .reset_index(drop=True)
    )
    self.path = path
    self.transform = transform
    self.encoding = encoding

  def __len__(self):
    return len(self.df)

  def _get_path(self, index):
    """Return the on-disk path of the file referenced by row ``index``."""
    # Column-first lookup avoids materialising a whole row Series per item and
    # cannot collide with Series attribute names the way `row.filename` can.
    filename = self.df['filename'].iloc[index]
    return os.path.join(self.path, filename)

  def _read_file(self, path):
    """Read ``path`` as text using the dataset's configured encoding."""
    # An explicit encoding makes reads independent of the platform/locale
    # default (e.g. cp1252 on Windows), which can garble or reject UTF-8 files.
    with open(path, 'r', encoding=self.encoding) as f:
      return f.read()

  def __getitem__(self, index):
    """Return the (optionally transformed) contents of the file at ``index``."""
    data = self._read_file(self._get_path(index))
    if self.transform is not None:
      data = self.transform(data)
    return data

  def __repr__(self):
    return f"<DataPy len:{len(self)} path:{self.path}>"

# CSV index plus the directory its `filename` entries live in; each item is
# tokenized on access by PreProcessingTransform.
dataset = DataPy(
  csv_filename='./data/dataset.csv',
  path='./data/',
  transform=PreProcessingTransform(tokenizer),
)

In [60]:
def collate_fn(batch, padding_value=None):
  """Right-pad a list of token-id tensors into a single batch tensor.

  Args:
    batch: Sequence of tensors whose first dimension is the sequence length
      (1-D token-id tensors when produced by PreProcessingTransform).
    padding_value: Value used to pad shorter sequences. When ``None`` the
      notebook-level ``PAD_IDX`` is used; ``DataLoader`` calls
      ``collate_fn(batch)`` and so always gets that default.

  Returns:
    ``(X, length)``: ``X`` is the padded batch with the batch axis first
    (``(len(batch), max_len)`` for 1-D inputs), and ``length`` is a
    ``torch.long`` tensor holding each sequence's original length.
  """
  if padding_value is None:
    # Resolved at call time so the default follows the tokenizer cell.
    padding_value = PAD_IDX
  X = nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=padding_value)
  # `torch.LongTensor(...)` is a legacy constructor; `torch.tensor` with an
  # explicit dtype is the documented way to build a tensor from Python data.
  length = torch.tensor([seq.shape[0] for seq in batch], dtype=torch.long)
  return X, length

# Shuffled loader whose batches are right-padded by collate_fn.
dl = DataLoader(
  dataset=dataset,
  batch_size=3,
  shuffle=True,
  collate_fn=collate_fn,
)

# Pull a single batch to sanity-check the padded output and lengths.
X_batch, length = next(iter(dl))

In [62]:
# Padded batch from collate_fn: batch axis first, shorter rows right-padded with PAD_IDX.
X_batch

tensor([[  646, 21416,   199,   504, 21416,   492, 11542,   199,   504,   570,
          1690,   492, 17581,   379, 12984, 10746, 12444,   199,   199,   533,
         24297,  1098,  3142,  7147,     8,  4411,    14,  2377,   304,   523,
           347,   636,   826,   721,   277,    12, 11181,    63,   890,    12,
           366,    63,  1238,    12,  6567,    63,  2996,    12,  1390,    63,
          5895,    12,  2152,    29,   403,    12,  2243,    29,   403,   304,
           272,  1022,   589,   275,   469,   283,  2271,   356,  2152,    12,
           283,  1782,   356,  2243,   789,   272,  1613,     8, 27907,  3142,
          7147,    12,   291,  2843,   826,  4533,   272,   291,    14,  1027,
            63, 14275,   275, 11542,    14, 27907,     8,  9953,    63,   890,
            12,   366,    63,  1238,    12,  6567,    63,  2996,    12,  1011,
         24459,     9,   272,   291,    14, 11018,    63,  6688,    63,  3028,
         12444,   275, 17581,   379, 12984, 10746, 1