In [None]:
# Fetch the SCAN benchmark; the experiment folders used later are paths inside this checkout ('SCAN/...').
!git clone https://github.com/brendenlake/SCAN.git
# Install the project's dependencies via the %pip magic, quietly.
%pip install -r requirements.txt --quiet

In [None]:
import wandb
import numpy as np
import torch
from torch.nn import NLLLoss
from torch.optim import Adam
from Dataloader import Dataloader, get_dataset_path
from Encoder import EncoderRNN
from Decoder import DecoderRNN
from Evaluate import get_accuracy, get_accuracy_across_length
from Train import train
import os
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Name of this notebook as reported to Weights & Biases, then authenticate.
os.environ['WANDB_NOTEBOOK_NAME'] = 'main.ipynb'
wandb.login()

In [None]:
# Training settings shared by every experiment, independent of the model chosen.
global_config = dict(
    learning_rate=0.001,        # passed to both Adam optimizers
    teacher_forcing_ratio=0.5,  # forwarded to train()
    trials=100_000,             # number of training steps
    gradient_clip_norm=5,       # forwarded to train()
)
# Model presets, keyed by the name an experiment refers to in its 'model' field.
# Every preset carries the same five keys, which are later merged into the run
# config and used to build the encoder and decoder.
paper_models = {
    "overall-best": {
        "rnn-type": "LSTM",
        "layers": 2,
        "hidden_units": 200,
        "dropout": 0.5,
        "attention": False,
    },
    "simple-best": {
        "rnn-type": "LSTM",
        "layers": 2,
        "hidden_units": 200,
        "dropout": 0,
        "attention": False,
    },
    "length-best": {
        "rnn-type": "GRU",
        "layers": 1,
        "hidden_units": 50,
        "dropout": 0.5,
        "attention": True,
    },
    "add-prim-jump-best": {
        "rnn-type": "LSTM",
        "layers": 1,
        "hidden_units": 100,
        "dropout": 0.1,
        "attention": True,
    },
    # The 'addprim_turn_left' experiments refer to this preset by name; without
    # it, selecting that experiment fails with a KeyError when the config is built.
    # NOTE(review): values believed to follow Lake & Baroni (2018), Experiment 3
    # (GRU, attention, 1 layer, 100 hidden units, dropout 0.1) - TODO confirm.
    "add-prim-left-turn-best": {
        "rnn-type": "GRU",
        "layers": 1,
        "hidden_units": 100,
        "dropout": 0.1,
        "attention": True,
    },
}

In [None]:
# Catalogue of runs. Each entry names the dataset folder, the dataset file stem,
# the model preset to look up in paper_models, and two evaluation switches.
experiments = [
    # Simple split, two model presets.
    *(
        dict(
            _folder='SCAN/simple_split',
            dataset='simple',
            model=preset,
            test_across_length=False,
            _eval_during_training=False,
        )
        for preset in ('overall-best', 'simple-best')
    ),
    # Simple split, training-set size variations.
    *(
        dict(
            _folder='SCAN/simple_split/size_variations',
            dataset=f'simple_p{percent}',
            model='overall-best',
            test_across_length=False,
            _eval_during_training=False,
        )
        for percent in (1, 2, 4, 8, 16, 32, 64)
    ),
    # Length split; only the first run also reports accuracy across lengths.
    dict(
        _folder='SCAN/length_split',
        dataset='length',
        model='overall-best',
        test_across_length=True,
        _eval_during_training=False,
    ),
    dict(
        _folder='SCAN/length_split',
        dataset='length',
        model='length-best',
        test_across_length=False,
        _eval_during_training=False,
    ),
    # Add-primitive split: "jump".
    *(
        dict(
            _folder='SCAN/add_prim_split',
            dataset='addprim_jump',
            model=preset,
            test_across_length=False,
            _eval_during_training=False,
        )
        for preset in ('overall-best', 'add-prim-jump-best')
    ),
    # Add-primitive split: "turn left".
    *(
        dict(
            _folder='SCAN/add_prim_split',
            dataset='addprim_turn_left',
            model=preset,
            test_across_length=False,
            _eval_during_training=False,
        )
        for preset in ('overall-best', 'add-prim-left-turn-best')
    ),
    # "jump" with additional examples: every num/rep combination, rep varying fastest.
    *(
        dict(
            _folder='SCAN/add_prim_split/with_additional_examples',
            dataset=f'addprim_complex_jump_num{num}_rep{rep}',
            model='add-prim-jump-best',
            test_across_length=False,
            _eval_during_training=False,
        )
        for num in (1, 2, 4, 8, 16, 32)
        for rep in (1, 2, 3, 4, 5)
    ),
]


In [None]:
# Index into `experiments` of the run to execute (negative values count from the end).
experiment_num = -2

# Layer the settings: global defaults first, then the model preset named by the
# experiment, then the experiment's own fields; later layers win on key clashes.
config = dict(global_config)
config.update(paper_models[experiments[experiment_num]['model']])
config.update(experiments[experiment_num])

In [None]:
dataloader = Dataloader()

# The training file goes through fit_transform and the test file through
# transform, in that order, on the same dataloader instance.
train_path = get_dataset_path(config['_folder'], config['dataset'], 'train')
train_X, train_Y = dataloader.fit_transform(train_path)

test_path = get_dataset_path(config['_folder'], config['dataset'], 'test')
test_X, test_Y = dataloader.transform(test_path)

# Record the word counts of both languages in the run config; the encoder and
# decoder are sized from these two entries.
config.update(
    input_size=dataloader.input_lang.n_words,
    output_size=dataloader.output_lang.n_words,
)

In [None]:
# Start the Weights & Biases run, labelled with the chosen model and dataset and
# carrying the merged config.
wandb.init(
    entity="project-group-1",
    project="Paper-Implemenation",
    name=f"Model: {config['model']}, Dataset: {config['dataset']}",
    tags=["test"],
    config=config,
)

# Save the dataloader into the run's directory (presumably so it is kept with
# the run's files - confirm in Dataloader.save).
dataloader.save(wandb.run.dir)

In [None]:
# Encoder: sized by the input vocabulary and the preset's recurrent settings.
encoder = EncoderRNN(
    input_size=config['input_size'],
    hidden_size=config['hidden_units'],
    hidden_layers=config['layers'],
    RNN_type=config['rnn-type'],
    dropout=config['dropout'],
)
encoder = encoder.to(device)

# Decoder: same recurrent settings; its input size equals the hidden size, and
# it additionally takes the attention switch and the output vocabulary size.
decoder = DecoderRNN(
    input_size=config['hidden_units'],
    hidden_size=config['hidden_units'],
    hidden_layers=config['layers'],
    output_size=config['output_size'],
    RNN_type=config['rnn-type'],
    dropout=config['dropout'],
    attention=config['attention'],
)
decoder = decoder.to(device)

In [None]:
# One Adam optimizer per network, both using the configured learning rate.
encoder_optimizer, decoder_optimizer = (
    Adam(network.parameters(), lr=config['learning_rate'])
    for network in (encoder, decoder)
)
criterion = NLLLoss()

# Log the averaged loss every `plot_every` steps. Evaluate every `evaluate_every`
# steps when enabled: an arbitrary choice, 100,000 / 4,000 = 25 evaluations.
plot_every, evaluate_every = 100, 4000

In [None]:
# Training loop: one uniformly sampled training pair per step, for
# config['trials'] steps. The loss is accumulated and its mean logged to W&B
# every `plot_every` steps; the test set is optionally scored every
# `evaluate_every` steps.
plot_loss_total = 0
# The counter is called `step` so that the builtin iter() is not shadowed in the
# notebook's global namespace for cells that run afterwards.
for step in range(1, config['trials'] + 1):
    index = np.random.randint(0, len(train_X))
    input_tensor = train_X[index]
    target_tensor = train_Y[index]
    loss = train(
        input_tensor=input_tensor,
        target_tensor=target_tensor,
        encoder=encoder,
        decoder=decoder,
        encoder_optimizer=encoder_optimizer,
        decoder_optimizer=decoder_optimizer,
        criterion=criterion,
        input_max_length=dataloader.input_max_length,
        teacher_forcing_ratio=config['teacher_forcing_ratio'],
        gradient_clip_norm=config['gradient_clip_norm'],
    )
    plot_loss_total += loss

    if step % plot_every == 0:
        # Mean loss over the last `plot_every` steps, plus the fraction of training done.
        plot_loss_avg = plot_loss_total / plot_every
        wandb.log({"loss": plot_loss_avg, "progress": step / config['trials']})
        plot_loss_total = 0

    if config['_eval_during_training'] and step % evaluate_every == 0:
        # The logged key stays "iter" so existing W&B charts keep working.
        wandb.log({
            "iter": step,
            "eval_during_training": get_accuracy(
                test_X=test_X,
                test_Y=test_Y,
                encoder=encoder,
                decoder=decoder,
                input_max_length=dataloader.input_max_length
            )
        })


In [None]:
# Final metrics written to the run summary: accuracy on the test split, then on
# the training data itself.
# NOTE(review): nothing in this cell switches the networks to eval mode;
# presumably get_accuracy handles that itself - confirm in Evaluate.py.
wandb.summary['test_accuracy'] = get_accuracy(
    encoder=encoder,
    decoder=decoder,
    test_X=test_X,
    test_Y=test_Y,
    input_max_length=dataloader.input_max_length,
)
wandb.summary['training_accuracy'] = get_accuracy(
    encoder=encoder,
    decoder=decoder,
    test_X=train_X,
    test_Y=train_Y,
    input_max_length=dataloader.input_max_length,
)

if config['test_across_length']:
    # Two length/accuracy tables over the test set: the first filtered by
    # test_X (commands), the second by test_Y (action sequences).
    input_table, output_table = (
        wandb.Table(
            columns=["length", "accuracy"],
            data=get_accuracy_across_length(
                filter=grouping,
                test_X=test_X,
                test_Y=test_Y,
                encoder=encoder,
                decoder=decoder,
                input_max_length=dataloader.input_max_length,
            ),
        )
        for grouping in (test_X, test_Y)
    )

    # One bar chart per table, logged together in a single call.
    wandb.log({
        chart_key: wandb.plot.bar(
            table=table,
            label="length",
            value="accuracy",
            title=chart_title,
        )
        for chart_key, table, chart_title in (
            ("input-zero-shot", input_table, "Command length vs. accuracy"),
            ("output-zero-shot", output_table, "Action sequence length vs. accuracy"),
        )
    })


In [None]:
# Write both networks' weights into the W&B run directory, then close the run.
torch.save(obj=encoder.state_dict(), f=os.path.join(wandb.run.dir, "encoder.pt"))
torch.save(obj=decoder.state_dict(), f=os.path.join(wandb.run.dir, "decoder.pt"))

wandb.finish()