In [2]:
import math, copy, sys, logging, json, time, random, os, string, pickle, re

import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt

from modules.TransformerComponents import Transformer
from modules.Vocabulary import Vocab
from modules.MetaLearnNeuralMemory import MNMp

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) from one knob
# so reruns are reproducible.
SEED = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Report the PyTorch build and visible GPUs for the record.
print('torch.version', torch.__version__)
print('torch.cuda.is_available()', torch.cuda.is_available())
print('torch.cuda.device_count()', torch.cuda.device_count())

torch.version 1.7.0
torch.cuda.is_available() True
torch.cuda.device_count() 2


In [3]:
EMB_DIM = 32  # passed to Vocab as emb_dim (presumably the embedding width)
vocab = Vocab(emb_dim=EMB_DIM)

In [4]:
class Teacher:
    """Generate toy "introduce yourself / recall your name" dialogue batches.

    Each batch pairs one random name per row with three fixed phrases:
    "my name is" (prompt), "hi" (reply prefix) and "what is my name?" (query).
    All tokenisation goes through ``vocab.string2tensor`` and
    ``vocab.string2embedding``.

    Call :meth:`add_vocab` once before drawing batches so the fixed phrases and
    the letters a-z are registered in the vocabulary.
    """

    def __init__(self, vocab):
        """Store ``vocab`` and tokenise the fixed phrases once.

        Args:
            vocab: project vocabulary exposing ``string2tensor`` and
                ``string2embedding``.
        """
        self.vocab = vocab
        # NOTE(review): when this runs before add_vocab() the phrase words are
        # not yet in the vocabulary, so these ids are placeholders (presumably
        # <UNK>). repeat() rebuilds them from the current vocab before use.
        self.mynameis = vocab.string2tensor("my name is")
        self.hi = vocab.string2tensor("hi")
        self.whatmyname = vocab.string2tensor("what is my name?")
        # Batch size the phrase tensors were last tiled for; None forces
        # repeat() to rebuild them on its next call.
        self._prompt_batch_size = None

    def add_vocab(self):
        """Register the fixed phrases and the letters a-z in the vocabulary."""
        self.vocab.string2embedding("my name is, hi. what is my name?")
        self.vocab.string2embedding("a b c d e f g h i j k l m n o p q r s t u v w x y z")
        # Word ids may have just been assigned, so any cached phrase tensors
        # built earlier are stale.
        self._prompt_batch_size = None

    def randomString(self, stringLength):
        """Return ``stringLength`` random lowercase ASCII letters.

        Draws from the module-level ``random`` generator, so output follows
        ``random.seed``.
        """
        letters = string.ascii_lowercase
        return ''.join(random.choice(letters) for _ in range(stringLength))

    def repeat(self, batch_size):
        """Tile the fixed-phrase tensors to ``batch_size`` rows.

        The tensors are rebuilt from the current vocabulary whenever the
        requested batch size differs from the one they were last built for,
        or after :meth:`add_vocab`. Otherwise the cached tensors are reused.
        """
        # getattr keeps instances created without __init__ (e.g. unpickled
        # from an older version) working: they simply rebuild once.
        if getattr(self, "_prompt_batch_size", None) == batch_size:
            return
        self.mynameis = self.vocab.string2tensor("my name is").repeat(batch_size, 1)
        self.hi = self.vocab.string2tensor("hi").repeat(batch_size, 1)
        self.whatmyname = self.vocab.string2tensor("what is my name?").repeat(batch_size, 1)
        self._prompt_batch_size = batch_size

    def get_batch(self, batch_size, name_size):
        """Draw ``batch_size`` random names and build the dialogue tensors.

        New names are registered in the vocabulary as a side effect. The
        results are also stored on ``self.names``, ``self.intro`` and
        ``self.introtarget``.

        Args:
            batch_size: number of rows (one random name each); must be >= 1.
            name_size: letters per name; must be >= 1.

        Returns:
            ``(intro, introtarget, whatmyname, names)``:
              intro       -- "my name is" ids followed by each row's name id(s)
              introtarget -- "hi" id followed by each row's name id(s)
              whatmyname  -- "what is my name?" ids, one copy per row
              names       -- name ids, one row per batch element (shown as
                             (batch_size, 1) for name_size=1; presumably each
                             name is one word token -- TODO confirm)

        Raises:
            ValueError: if ``batch_size`` or ``name_size`` is < 1. An empty
                name or batch would misalign the name column with the prompt
                rows.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if name_size < 1:
            raise ValueError(f"name_size must be >= 1, got {name_size}")

        self.repeat(batch_size)

        # Same number and order of random draws as a sequential loop, so
        # seeded runs produce the same names.
        newnames = " ".join(self.randomString(name_size) for _ in range(batch_size))
        # Register unseen names before converting them to ids.
        self.vocab.string2embedding(newnames)
        # string2tensor appears to return one row of ids; transpose so each
        # name lands on its own row, aligned with the tiled phrase tensors.
        self.names = self.vocab.string2tensor(newnames).T

        self.intro = torch.cat((self.mynameis, self.names), dim=1)
        self.introtarget = torch.cat((self.hi, self.names), dim=1)
        return self.intro, self.introtarget, self.whatmyname, self.names

In [6]:
# Build the batch generator and register its phrases plus the letters a-z,
# then show the resulting word -> index table.
teacher = Teacher(vocab)
teacher.add_vocab()
print(teacher.vocab.word2index)

{'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3, 'my': 4, 'name': 5, 'is': 6, ',': 7, 'hi': 8, '.': 9, 'what': 10, '?': 11, 'a': 12, 'b': 13, 'c': 14, 'd': 15, 'e': 16, 'f': 17, 'g': 18, 'h': 19, 'i': 20, 'j': 21, 'k': 22, 'l': 23, 'm': 24, 'n': 25, 'o': 26, 'p': 27, 'q': 28, 'r': 29, 's': 30, 't': 31, 'u': 32, 'v': 33, 'w': 34, 'x': 35, 'y': 36, 'z': 37}


In [12]:
# Four rows, each with a one-letter random name.
batch_size, name_size = 4, 1
intro, introtarget, whatmyname, names = teacher.get_batch(
    batch_size=batch_size,
    name_size=name_size,
)

In [13]:
# Prompt batch: "my name is" token ids followed by each row's name id.
intro

tensor([[ 4,  5,  6, 37],
        [ 4,  5,  6, 21],
        [ 4,  5,  6, 27],
        [ 4,  5,  6, 23]])

In [14]:
# Target batch: the "hi" token id followed by the same name id per row.
introtarget

tensor([[ 8, 37],
        [ 8, 21],
        [ 8, 27],
        [ 8, 23]])

In [15]:
# Query batch: "what is my name?" token ids, repeated once per row.
whatmyname

tensor([[10,  6,  4,  5, 11],
        [10,  6,  4,  5, 11],
        [10,  6,  4,  5, 11],
        [10,  6,  4,  5, 11]])

In [16]:
# The sampled name ids, one row per batch element.
names

tensor([[37],
        [21],
        [27],
        [23]])