In [None]:
!pip install pytorch_lightning -q

In [None]:
#=============================
#Import
#=============================
import numpy as np
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import random_split, DataLoader, Dataset
import torch
from pytorch_lightning import LightningDataModule
import os
#-----------------------------MNIST Dataset-----------------------------#
#=============================
#Lightning Data Module Wrapper
#=============================
class MNISTDataModule(LightningDataModule):
    """Lightning data module for MNIST, optionally with a fixed pixel permutation.

    Args:
        sequence_length: number of time steps each image is reshaped into
            (applied only when a permutation is active).
        input_size: features per time step. ``sequence_length * input_size``
            is expected to equal 784 (28 * 28); this is not checked here.
        batch_size: batch size of the training DataLoader.
        val_size: number of training images held out for validation. When 0,
            the official MNIST test split is also used as the validation set.
        permute_seed: seed of a fixed random permutation of the 784 pixels
            (permuted MNIST). ``None`` disables the permutation.
        root: directory where torchvision stores / downloads MNIST.
    """

    def __init__(self, sequence_length, input_size, batch_size, val_size=0,
                 permute_seed=None,
                 root='/content/drive/My Drive/Colab Notebooks/EXP/Dataset/'):
        super().__init__()
        self.sequence_length = sequence_length
        self.input_size = input_size
        self.batch_size = batch_size
        self.val_size = val_size
        # The MNIST training split contains 60,000 images.
        self.train_size = 60000 - self.val_size
        self.root = root
        self.idx_permute = None

        # ``is not None`` so that a seed of 0 still produces a permutation.
        if permute_seed is not None:
            rng_permute = np.random.RandomState(permute_seed)
            self.idx_permute = torch.from_numpy(rng_permute.permutation(784))

    def setup(self, stage=None):
        """Build the train / validation / test datasets (downloading if needed).

        ``stage`` is accepted for Lightning compatibility but ignored: all
        three splits are rebuilt on every call.
        """
        # Compare against None explicitly: truth-testing a multi-element
        # tensor raises "Boolean value of Tensor ... is ambiguous".
        if self.idx_permute is not None:
            idx = self.idx_permute
            shape = (self.sequence_length, self.input_size)
            transform = transforms.Compose([
                transforms.ToTensor(),
                # Flatten the image, apply the fixed pixel permutation, then
                # reshape to (sequence_length, input_size).
                transforms.Lambda(lambda x: x.view(-1)[idx].view(*shape)),
            ])
        else:
            # NOTE(review): without a permutation samples keep ToTensor's
            # (C, H, W) layout rather than (sequence_length, input_size).
            transform = transforms.ToTensor()

        self.mnist_train = MNIST(root=self.root, transform=transform, download=True)
        if self.val_size > 0:
            self.mnist_train, self.mnist_valid = random_split(
                self.mnist_train, [self.train_size, self.val_size])
        else:
            # No held-out split requested: validate on the official test set.
            self.mnist_valid = MNIST(root=self.root, train=False,
                                     transform=transform, download=True)
        self.mnist_test = MNIST(root=self.root, train=False,
                                transform=transform, download=True)

    def train_dataloader(self):
        """Shuffled training batches of ``batch_size``."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """The whole validation set as a single batch."""
        return DataLoader(self.mnist_valid, batch_size=len(self.mnist_valid))

    def test_dataloader(self):
        """The whole test set as a single batch."""
        return DataLoader(self.mnist_test, batch_size=len(self.mnist_test))

#-----------------------------Copy Memory Task Dataset-----------------------------#
#=============================
#Sample Generator
#=============================
def generate_copying_sequence(recall_length, delay_length):
    """Generate one (input, target) pair for the copy-memory task.

    Input layout (9 features per step): ``recall_length`` random symbols,
    each one-hot over the first 8 features; ``delay_length`` blank steps;
    one marker step (feature 8 set); then ``recall_length - 1`` blank steps.
    The target (8 features per step) is blank for the first
    ``recall_length + delay_length`` steps and then repeats the symbols.

    Args:
        recall_length: number of symbols to memorise (must be >= 1).
        delay_length: blank steps between the symbols and the marker (>= 0).

    Returns:
        ``(x, y)`` float32 tensors of shapes ``(T, 9)`` and ``(T, 8)`` with
        ``T = 2 * recall_length + delay_length``.

    Raises:
        ValueError: if ``recall_length < 1`` or ``delay_length < 0``.
    """
    if recall_length < 1:
        raise ValueError(f"recall_length must be >= 1, got {recall_length}")
    if delay_length < 0:
        raise ValueError(f"delay_length must be >= 0, got {delay_length}")

    num_symbols = 8
    total_length = 2 * recall_length + delay_length
    marker_step = recall_length + delay_length

    # Same call on the global NumPy RNG as before, so seeded runs reproduce.
    choices_idx = np.random.choice(num_symbols, size=recall_length, replace=True)

    # Preallocating avoids np.vstack, which fails on empty blank segments
    # (e.g. recall_length == 1).
    x = np.zeros((total_length, num_symbols + 1), dtype=np.float32)
    y = np.zeros((total_length, num_symbols), dtype=np.float32)
    x[np.arange(recall_length), choices_idx] = 1.0
    x[marker_step, num_symbols] = 1.0  # "start recalling" marker
    y[marker_step:] = x[:recall_length, :num_symbols]
    return torch.from_numpy(x), torch.from_numpy(y)
#=============================
#Batch Generator
#=============================
def create_copying_batch(recall_length, delay_length, batch_size):
    """Stack ``batch_size`` independent copy-task samples into one batch.

    Each sample comes from ``generate_copying_sequence``; inputs and targets
    are stacked along a new leading batch dimension.
    """
    samples = [generate_copying_sequence(recall_length, delay_length)
               for _ in range(batch_size)]
    inputs = torch.stack([sample_x for sample_x, _ in samples], dim=0)
    targets = torch.stack([sample_y for _, sample_y in samples], dim=0)
    return inputs, targets
#-----------------------------PTB Dataset-----------------------------#
#=============================
#Dictionary
#=============================
class Dictionary:
    """Two-way vocabulary between words and dense integer ids.

    Ids are handed out in first-seen order, starting at 0.
    """

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = []  # id -> word (the list position is the id)

    def add_word(self, word):
        """Register ``word`` if it is new and return its id."""
        idx = self.word2idx.get(word)
        if idx is None:
            idx = len(self.idx2word)
            self.word2idx[word] = idx
            self.idx2word.append(word)
        return idx

    def __len__(self):
        """Number of distinct words registered so far."""
        return len(self.idx2word)
#=============================
#Tokenization, Word2Idx (word -> integer id)
#=============================
class PTBCorpus:
    """Penn Treebank-style corpus: shared vocabulary plus id tensors per split.

    Reads ``train.txt``, ``valid.txt`` and ``test.txt`` from ``path``. Each
    line is split on whitespace and an ``'<eos>'`` token is appended to it.

    Attributes:
        dictionary: vocabulary shared by all three splits.
        train, valid, test: 1-D ``torch.long`` tensors of token ids.
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenize a text file into a 1-D LongTensor of word ids.

        New words are added to ``self.dictionary`` in first-seen order.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        # Explicit raise instead of ``assert`` (asserts vanish under -O).
        if not os.path.exists(path):
            raise FileNotFoundError(f"No such file: {path!r}")
        ids = []
        # Single pass: add_word already returns the id. This avoids a second
        # read and slow per-element tensor writes.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                for word in line.split() + ['<eos>']:
                    ids.append(self.dictionary.add_word(word))
        return torch.tensor(ids, dtype=torch.long)

    def trim_batch(self, data, bsz):
        """Flatten ``data``, drop the tail that doesn't fill ``bsz`` rows, reshape.

        Returns a ``(bsz, N // bsz)`` tensor where row ``i`` holds the ``i``-th
        contiguous chunk of the flattened token stream.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        data = data.reshape(-1)
        nbatch = data.size(0) // bsz
        return data[:nbatch * bsz].view(bsz, -1)

    def get_ith_sequence(self, source, ith, sequence_length):
        """Return columns ``[ith * L, (ith + 1) * L)`` of ``source`` for BPTT.

        With ``source`` shaped ``(batch_size, N)`` (as from ``trim_batch``),
        the result is ``(batch_size, <= sequence_length)``; the last slice may
        be shorter.
        """
        start = ith * sequence_length
        return source[:, start:start + sequence_length]

In [None]:
# Sanity check: a single copy-task sample with recall 1000 and no delay.
x, y = create_copying_batch(recall_length=1000, delay_length=0, batch_size=1)
print(x.shape, y.shape)

torch.Size([1, 2000, 9]) torch.Size([1, 2000, 8])
