In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from sklearn import utils
from torch.utils.data import Dataset, DataLoader
import os
import glob
import matplotlib as plt

from copy import deepcopy, copy


In [3]:
# TODO
# define hyperparameters
input_dim = None
output_dim = None
emb_dim = None
hid_dim = None
n_layers = None
dropout = 0.2

Basic LSTM Outline

In [4]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.rnn(embedded)
        return hidden, cell


In [5]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden, cell


In [6]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        trg_len = trg.shape[0]
        batch_size = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)

        hidden, cell = self.encoder(src)

        input = trg[0, :]

        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[t] = output
            teacher_force = np.random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1

        return outputs


In [8]:
def train(model, iterator, optimizer, criterion, clip):
    model.train()

    epoch_loss = 0

    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()

        output = model(src, trg)

        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)

        loss = criterion(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


In [9]:
def evaluate(model, iterator, criterion):
    model.eval()

    epoch_loss = 0

    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg

            output = model(src, trg, 0)  # turn off teacher forcing

            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)


Creating DataLoader

In [10]:
def readfile(fn_in):
    # Read episode from text file
    #
    # Input
    #  fn_in : filename to read
    #
    # Output
    #   Parsed version of the file, with struct
    #   
    fid = open(os.path.join(fn_in),'r')
    lines = fid.readlines()
    lines = [line.rstrip('\n') for line in lines]
    lines = [line for line in lines if line != '']
    idx_support = lines.index('*SUPPORT*')
    idx_query = lines.index('*QUERY*')
    idx_grammar = lines.index('*GRAMMAR*')
    x_support, y_support = parse_commands(lines[idx_support+1:idx_query])
    x_query, y_query = parse_commands(lines[idx_query+1:idx_grammar])
    grammar_str = lines[idx_grammar+1:]
    grammar_str = '\n'.join(grammar_str)
    fid.close()
    return {'xs':x_support, 'ys':y_support, 'xq':x_query, 'yq':y_query, 'grammar_str':grammar_str}

In [11]:
def parse_commands(lines):
    # Parse lines from input files into command sequence and output sequence
    #
    # Input
    #  lines: [list of strings], each of format "IN: a b c OUT: d e f""
    lines = [l.strip() for l in lines]
    lines = [l.lstrip('IN: ') for l in lines]
    D = [l.split(' OUT: ') for l in lines]
    x = [d[0].split(' ') for d in D]
    y = [d[1].split(' ') for d in D]
    return x, y

In [12]:
SOS_token = "SOS" # start of sentence
EOS_token = "EOS" # end of sentence
PAD_token = "PAD" # padding symbol
IO_SEP = 'IO' # separator '->' between input/outputs in support examples
ITEM_SEP  = SOS_token # separator '|' between support examples in input sequence


input_symbols_list_default = ['dax', 'lug', 'wif', 'zup', 'fep', 'blicket', 'kiki', 'tufa', 'gazzer']
output_symbols_list_default = ['RED', 'YELLOW', 'GREEN', 'BLUE', 'PURPLE', 'PINK']

def combine_input_output_symb(list_input_symb,list_output_symb):
    # Make new source vocabulary that combines list of input and output symbols.
    #  Include input/output and item separators, IO_SEP,ITEM_SEP.
    #  Exclude EOS_token,SOS_token,PAD_token, which will be added automatically by Lang constructor/
    #
    # Input
    #   list_input_symb : list of token symbols (strings)
    #   list_output_symb : list of token symbols (strings)
    # Output
    #   comb : combined list of tokens as strings
    additional_symb = sorted(set([IO_SEP,ITEM_SEP])-set([EOS_token,SOS_token,PAD_token]))
    comb = sorted(set(list_input_symb + list_output_symb + additional_symb))
    return comb

In [13]:
class Lang:
    #  Class for converting tokens strings to token index, and vice versa.
    #   Use separate class for input and output languages
    #
    def __init__(self, symbols):
        # symbols : list of all possible symbols besides special tokens SOS, EOS, and PAD
        n = len(symbols)
        assert(SOS_token not in symbols)
        assert(EOS_token not in symbols)
        assert(PAD_token not in symbols)
        assert(PAD_token != SOS_token)
        self.symbols = symbols # list of non-special symbols
        self.index2symbol = {n: SOS_token, n+1: EOS_token, n+2: PAD_token}
        self.symbol2index = {SOS_token : n, EOS_token : n+1, PAD_token : n+2}
        for idx,s in enumerate(symbols):
            self.index2symbol[idx] = s
            self.symbol2index[s] = idx
        self.n_symbols = len(self.index2symbol)
        self.PAD_idx = self.symbol2index[PAD_token]
        self.PAD_token = PAD_token
        assert(len(self.index2symbol)==len(self.symbol2index))

    def symbols_to_tensor(self, mylist, add_eos=True):
        # Convert a list of token strings to token index (adding a EOS token at end)
        # 
        # Input
        #  mylist  : list of m symbols as strings
        #  add_eos : true/false, if true add the EOS symbol at end
        #
        # Output
        #  output : [m or m+1 LongTensor] token index for each symbol (plus EOS if appropriate)
        mylist = copy(mylist)
        if add_eos: mylist.append(EOS_token)
        indices = [self.symbol2index[s] for s in mylist]
        output = torch.LongTensor(indices) # keep on CPU since this occurs inside Dataset getitem..
        return output

    def tensor_to_symbols(self, v):
        # Convert tensor of token index to token strings, breaking where we get a EOS token.
        #   The EOS token is not included at the end in the result string list.
        # 
        # Input
        #  v : python list of m indices, or 1D tensor
        #   
        # Output
        #  mylist : list of symbols (excluding EOS)
        if torch.is_tensor(v):
            assert v.dim()==1
            v = v.tolist()
        assert isinstance(v, list)
        mylist = []
        for x in v:
            s = self.index2symbol[x]
            if s == EOS_token:
                break
            mylist.append(s)
        return mylist

In [14]:
def make_hashable(G):
    # Create unique identifier for episodes defined by a grammar.
    #  Separate and sort rules in string format.
    #
    # Input
    #   G : string of elements separated by \n specifying the structure of an episode 
    G_str = str(G).split('\n')
    G_str = [s.strip() for s in G_str]
    G_str = [s for s in G_str if s != ''] # remove empty strings
    G_str.sort()
    out = '\n'.join(G_str)
    return out.strip()

def bundle_biml_episode(x_support,y_support,x_query,y_query,myhash,aux={}):
    # Bundle components for an episode suitable for optimizing BIML
    # 
    # Input
    #  x_support [length ns list of lists] : input sequences (each a python list of words/symbols)
    #  y_support [length ns list of lists] : output sequences (each a python list of words/symbols)
    #  x_query [length nq list of lists] : input sequences (each a python list of words/symbols)
    #  x_query [length nq list of lists] : output sequences (each a python list of words/symbols)
    #  myhash : unique string identifier for this episode (should be order invariant for examples)
    #  aux [dict] : any misc information that we want to pass along with the episode
    #
    # Output
    #  sample : dict that stores episode information
    ns = len(x_support)
    xy_support = [ITEM_SEP]
    for j in range(ns):
        xy_support += x_support[j] + [IO_SEP] + y_support[j] + [ITEM_SEP]
    x_query_context = [item + xy_support for item in x_query] # Create the combined source sequence for every query
    sample = {}
    sample['identifier'] = myhash # unique identifying string for this episode (order invariant)
    sample['xs'] = x_support # support 
    sample['ys'] = y_support
    sample['xq'] = x_query # query
    sample['yq'] = y_query
    sample['xq_context'] = x_query_context
    if aux: sample['aux'] = aux
    return sample

In [15]:
class DataAlg(Dataset):
    # Meta-training for few-shot grammar learning (fully algebraic)

    def __init__(self, mode, mydir, p_noise=0., inc_support_in_query=True, min_ns=0):
        # Each episode has different latent (algebraic) grammar. 
        #  The number of support items picked uniformly from min_ns...max_ns
        #
        # Input
        #  mode: 'train' or 'val'
        #  mydir : directory where data is stored
        #  p_noise : for a given symbol emission, probability that it will be from uniform distribution
        #  inc_support_in_query : default=True. Boolean. Should support items also be query items?
        #  min_ns : min number of support items in episode
        assert mode in ['train','val']
        self.mode = mode
        self.train = mode == 'train'
        self.p_noise = p_noise
        self.mydir_items = os.path.join(mydir,self.mode)
        self.list_items = glob.glob(self.mydir_items+"/*.txt") # all episode files
        self.input_symbols = input_symbols_list_default
        self.output_symbols = output_symbols_list_default
        comb = combine_input_output_symb(self.input_symbols,self.output_symbols)
        self.langs = {'input' : Lang(comb), 'output': Lang(self.output_symbols)}
        self.min_ns = min_ns # min number of support items in episode
        self.inc_support_in_query = inc_support_in_query

    def __len__(self):
        return len(self.list_items)

    def __getitem__(self, idx):
        S = readfile(self.list_items[idx])
        max_ns = len(S['xs'])
        if self.train:
            S['xs'], S['ys'] = utils.shuffle(S['xs'], S['ys'])
            ns = random.randint(self.min_ns,max_ns)
            S['xs'] = S['xs'][:ns]
            S['ys'] = S['ys'][:ns]
        # if self.p_noise > 0: # emission noise
            # for i in range(len(S['yq'])):
                # S['yq'][i] = add_response_noise(S['yq'][i],self.p_noise,self.langs)
        if self.inc_support_in_query:
            S['xq'] = S['xs'] + S['xq']
            S['yq'] = S['ys'] + S['yq']
        myhash = make_hashable(S['grammar_str'])
        aux = {'grammar_str':S['grammar_str']}        
        return bundle_biml_episode(S['xs'],S['ys'],S['xq'],S['yq'],myhash,aux=aux)

In [16]:
class DataRetrieve(DataAlg):
    # Copy task of retrieving study examples

    def __init__(self, mode, mydir, min_ns=2, max_ns=2):
        # Each episode has a set of support strings, identical to DataAlg class
        #  Number of support items picked uniformly from min_ns...max_ns
        #
        # Input
        #  mode: 'train' or 'val'
        #  mydir : directory where data is stored
        super().__init__(mode, mydir, inc_support_in_query=True, min_ns=min_ns)
        self.my_max_ns = max_ns

    def __getitem__(self, idx):
        S = readfile(self.list_items[idx])
        assert(len(S['xs'])>=self.my_max_ns)
        if self.train:
            S['xs'], S['ys'] = utils.shuffle(S['xs'], S['ys'])
        ns = random.randint(self.min_ns,self.my_max_ns)
        S['xs'] = S['xs'][:ns]
        S['ys'] = S['ys'][:ns]        
        S['xq'] = S['xs']
        S['yq'] = S['ys']
        myhash = make_hashable(S['grammar_str'])
        aux = {'grammar_str':S['grammar_str']}
        return bundle_biml_episode(S['xs'],S['ys'],S['xq'],S['yq'],myhash,aux=aux)

In [17]:
# def __getitem__(self, idx):
#     S = readfile(self.list_items[idx])
#     assert(len(S['xs'])>=self.my_max_ns)
#     if self.train:
#         S['xs'], S['ys'] = utils.shuffle(S['xs'], S['ys'])
#     ns = random.randint(self.min_ns,self.my_max_ns)
#     S['xs'] = S['xs'][:ns]
#     S['ys'] = S['ys'][:ns]        
#     S['xq'] = S['xs']
#     S['yq'] = S['ys']
#     myhash = make_hashable(S['grammar_str'])
#     aux = {'grammar_str':S['grammar_str']}
#     return bundle_biml_episode(S['xs'],S['ys'],S['xq'],S['yq'],myhash,aux=aux)

In [18]:
D_train = DataRetrieve('train',mydir='data_algebraic', min_ns=14, max_ns=14)
D_val = DataRetrieve('val',mydir='data_algebraic', min_ns=14, max_ns=14)

In [19]:
def make_biml_batch(samples, langs):
    # Batch episodes into a series of padded input and target tensors
    # 
    # Input
    #  samples : list of dicts from bundle_biml_episode
    #  langs : input and output version of Lang class
    assert isinstance(samples,list)
    m = len(samples)
    mybatch = {}
    mybatch['list_samples'] = samples
    mybatch['batch_size'] = m
    mybatch['xq_context'] = [] # list of source sequences (as lists) across all episodes
    mybatch['xq'] = []  # list of queries (as lists) across all episodes
    mybatch['yq'] = [] # list of query outputs (as lists) across all episodes
    mybatch['q_idx'] = [] # index of which episode each query belongs to
    mybatch['in_support'] = [] # bool list indicating whether each query is in its corresponding support set, or not
    for idx in range(m): # each episode
        sample = samples[idx]
        nq = len(sample['xq'])
        assert(nq == len(sample['yq']))
        mybatch['xq_context'] += sample['xq_context']
        mybatch['xq'] += sample['xq']
        mybatch['yq'] += sample['yq']
        mybatch['q_idx'] += [idx*torch.ones(nq, dtype=torch.int)]
        mybatch['in_support'] += [x in sample['xs'] for x in sample['xq']]
    mybatch['q_idx'] = torch.cat(mybatch['q_idx'], dim=0)
    mybatch['xq_context_padded'],mybatch['xq_context_lengths'] = build_padded_tensor(mybatch['xq_context'], langs['input'])
    mybatch['yq_padded'],mybatch['yq_lengths'] = build_padded_tensor(mybatch['yq'], langs['output'])
    mybatch['yq_sos_padded'],mybatch['yq_sos_lengths'] = build_padded_tensor(mybatch['yq'],langs['output'],add_eos=False,add_sos=True)
    return mybatch

In [20]:
D_train

batch_size = 25

In [21]:
train_dataloader = DataLoader(D_train,batch_size=batch_size,collate_fn=lambda x:make_biml_batch(x,D_train.langs),
                                shuffle=True)
val_dataloader = DataLoader(D_val,batch_size=batch_size,collate_fn=lambda x:make_biml_batch(x,D_val.langs),
                                shuffle=False)

In [29]:
D_train.langs

{'input': <__main__.Lang at 0x7efa79afbaa0>,
 'output': <__main__.Lang at 0x7efa7b3830e0>}

In [23]:
def pad_seq(seq, max_length):
    # Pad token string sequence with the PAD_token symbol to achieve max_length
    #
    # Input
    #  seq : list of symbols (as strings)
    #
    # Output
    #  seq : padded list now extended to length max_length
    seq += (max_length - len(seq)) * [PAD_token]
    return seq

def build_padded_tensor(list_seq, lang, add_eos=True, add_sos=False):
    # Transform list of python lists to a padded torch tensors
    # 
    # Input
    #  list_seq : list of n sequences (each sequence is a python list of token srings)
    #  lang : language object for translation of token string to token index
    #  add_eos : add end-of-sequence token at the end?
    #  add_sos : add start-of-sequence token at the beginning?
    #
    # Output
    #  z_padded : LongTensor (n x max_len)
    #  z_lengths : python list of sequence lengths (n-length list of scalars)
    n = len(list_seq)
    if n==0: return [],[]
    z_eos = list_seq
    if add_sos: 
        z_eos = [[SOS_token]+z for z in z_eos]
    if add_eos:
        z_eos = [z+[EOS_token] for z in z_eos]    
    z_lengths = [len(z) for z in z_eos]
    max_len = max(z_lengths) # maximum length in this episode
    z_padded = [pad_seq(z, max_len) for z in z_eos]
    z_padded = [lang.symbols_to_tensor(z, add_eos=False).unsqueeze(0) for z in z_padded]
    z_padded = torch.cat(z_padded, dim=0) # n x max_len
    return z_padded,z_lengths

In [33]:
type(train_dataloader)

torch.utils.data.dataloader.DataLoader

In [38]:
train_features = next(iter(train_dataloader))
train_features

{'list_samples': [{'identifier': 'gazzer -> YELLOW\ntufa -> BLUE\nu1 x1 -> [u1] [x1]\nwif -> GREEN\nx1 blicket -> [x1] [x1] [x1]\nx1 dax u1 -> [x1] [u1]\nx1 lug -> [x1] [x1] [x1]\nzup -> PURPLE',
   'xs': [['zup', 'gazzer', 'zup', 'dax', 'gazzer', 'dax', 'gazzer'],
    ['tufa'],
    ['wif', 'blicket'],
    ['wif', 'dax', 'gazzer'],
    ['tufa', 'blicket'],
    ['zup', 'blicket'],
    ['tufa', 'tufa'],
    ['wif', 'lug'],
    ['tufa', 'tufa', 'tufa'],
    ['gazzer', 'lug'],
    ['tufa', 'zup'],
    ['gazzer'],
    ['wif', 'dax', 'wif', 'dax', 'zup'],
    ['zup', 'zup', 'dax', 'gazzer']],
   'ys': [['PURPLE', 'YELLOW', 'PURPLE', 'YELLOW', 'YELLOW'],
    ['BLUE'],
    ['GREEN', 'GREEN', 'GREEN'],
    ['GREEN', 'YELLOW'],
    ['BLUE', 'BLUE', 'BLUE'],
    ['PURPLE', 'PURPLE', 'PURPLE'],
    ['BLUE', 'BLUE'],
    ['GREEN', 'GREEN', 'GREEN'],
    ['BLUE', 'BLUE', 'BLUE'],
    ['YELLOW', 'YELLOW', 'YELLOW'],
    ['BLUE', 'PURPLE'],
    ['YELLOW'],
    ['GREEN', 'GREEN', 'PURPLE'],
    ['PURPL

In [36]:
train_features = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

ValueError: too many values to unpack (expected 2)

In [39]:
train_features.size()

AttributeError: 'dict' object has no attribute 'size'

In [24]:
# Function to view the contents of the dataloader
def view_dataloader_contents(dataloader, num_batches=1):
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        print(f"Batch {i+1}:")
        print(batch)
        print("\n")

# View contents of train_dataloader
print("Train DataLoader contents:")
view_dataloader_contents(train_dataloader, num_batches=2)

# View contents of val_dataloader
print("Validation DataLoader contents:")
view_dataloader_contents(val_dataloader, num_batches=2)

Train DataLoader contents:
Batch 1:
{'list_samples': [{'identifier': 'blicket -> YELLOW\ndax -> GREEN\nlug -> PURPLE\ntufa -> PINK\nu1 fep -> [u1] [u1] [u1] [u1]\nu1 wif x1 -> [u1] [x1]\nu1 x1 -> [u1] [x1]\nx1 zup -> [x1] [x1] [x1] [x1]', 'xs': [['blicket'], ['dax', 'wif', 'tufa', 'wif', 'tufa'], ['lug', 'lug', 'fep'], ['lug', 'lug', 'zup'], ['tufa', 'tufa', 'fep'], ['lug'], ['tufa', 'dax', 'fep'], ['dax', 'wif', 'blicket'], ['dax', 'fep'], ['blicket', 'lug', 'wif', 'tufa', 'dax', 'wif', 'lug'], ['tufa', 'lug'], ['dax', 'wif', 'tufa', 'wif', 'blicket'], ['lug', 'wif', 'blicket', 'dax', 'fep'], ['tufa', 'fep']], 'ys': [['YELLOW'], ['GREEN', 'PINK', 'PINK'], ['PURPLE', 'PURPLE', 'PURPLE', 'PURPLE', 'PURPLE'], ['PURPLE', 'PURPLE', 'PURPLE', 'PURPLE', 'PURPLE', 'PURPLE', 'PURPLE', 'PURPLE'], ['PINK', 'PINK', 'PINK', 'PINK', 'PINK'], ['PURPLE'], ['PINK', 'GREEN', 'GREEN', 'GREEN', 'GREEN'], ['GREEN', 'YELLOW'], ['GREEN', 'GREEN', 'GREEN', 'GREEN'], ['YELLOW', 'PURPLE', 'PINK', 'GREEN', 'PUR

In [25]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7efa79383590>