In [None]:
! pip install datasets
! pip install livelossplot
! pip install einops

In [None]:
import torch

# Sanity-check the runtime: report whether CUDA is usable, then list the
# name of every visible GPU (one per line).
print("gpu?", torch.cuda.is_available())

gpu_count = torch.cuda.device_count()
for gpu_idx in range(gpu_count):
    gpu_props = torch.cuda.get_device_properties(gpu_idx)
    print(gpu_props.name)

# Choosing architectures for the task

We fix each backbone's parameter count at roughly 0.1M by using a hidden dimension of 64 and varying only the depth (number of layers).

In [None]:
from generic_classifier import GenericClassifier
from utils import print_num_params
from models.lstm import LSTMEncoder
from models.transformer import TransformerEncoder
from models.s4 import S4Encoder

# Backbone configurations. All backbones share the same hidden width; only the
# depth differs, chosen per architecture to target roughly 0.1M parameters.
# `encoder_module` selects the encoder class (presumably consumed by
# GenericClassifier, which receives these dicts as `model_kwargs`).
HIDDEN_DIM = 64

lstm_cfg = {'hidden_dim': HIDDEN_DIM, 'num_layers': 3, 'encoder_module': LSTMEncoder}
transformer_cfg = {'hidden_dim': HIDDEN_DIM, 'num_layers': 2, 'encoder_module': TransformerEncoder}
s4_cfg = {'hidden_dim': HIDDEN_DIM, 'num_layers': 6, 'encoder_module': S4Encoder}

# Set to True to print each config and its encoder's parameter count
# (this instantiates every encoder once).
SHOW_PARAM_COUNTS = False
if SHOW_PARAM_COUNTS:
    for cfg in (lstm_cfg, transformer_cfg, s4_cfg):
        print(cfg)
        # `encoder_module` is the class itself, presumably not a constructor
        # argument of the encoder -- TODO confirm; strip it before building.
        encoder_kwargs = {k: v for k, v in cfg.items() if k != 'encoder_module'}
        print_num_params(cfg['encoder_module'](**encoder_kwargs))

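We fix the training configuration across experiments so that every training run gets the same wall-clock budget of 4000 seconds (a little over an hour), regardless of how many epochs that amounts to. In the pretraining→fine-tuning settings, this budget is split into 2800 s of pretraining and 1200 s of fine-tuning. The batch size is also identical across models, although memory consumption naturally differs between them.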

In [None]:
# Hyperparameters shared by every experiment below.

BATCH_SIZE = 512  # batch size at which every model's training fits in <80GB VRAM (A100)

# Single-stage training (Experiment 1):
TRAIN_TIME_LIMIT = 4000  # wall-clock budget per training run, in seconds (kept short due to Colab limits)
NUM_EPOCHS = 1000  # upper bound on epochs; in practice the time limit is reached first

# Pretraining -> fine-tuning (Experiments 2 and 3): split the SAME total budget
# between the two stages, so every setting gets equal wall-clock training time.
TIMELIMIT_PT = 2800
TIMELIMIT_FT = TRAIN_TIME_LIMIT - TIMELIMIT_PT  # 1200 s with the values above
assert TIMELIMIT_FT > 0, "pretraining budget must leave time for fine-tuning"
NUM_EPOCHS_PT, NUM_EPOCHS_FT = NUM_EPOCHS, NUM_EPOCHS  # upper bounds only; the time limits normally stop training first

# Experiment 1: Train directly on ListOps (no pretraining)

In [None]:
from experiments import setting_1__directly_on_listops


# Backbone name -> ListOps test accuracy, filled in by the Experiment 1 cells.
exp1_model2metric = dict()

## LSTM

In [None]:
# Experiment 1 / LSTM: train the LSTM backbone on ListOps classification only.
model, test_acc = setting_1__directly_on_listops(
    model_cls=GenericClassifier,
    model_kwargs=lstm_cfg,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    train_time_limit_secs=TRAIN_TIME_LIMIT,
)
exp1_model2metric['lstm'] = test_acc

# Optional: drop the trained model and release PyTorch's cached CUDA memory,
# so the next experiment does not run alongside this one's leftovers.
del model
torch.cuda.empty_cache()

## Transformer

In [None]:
# Experiment 1 / Transformer: train the Transformer backbone on ListOps classification only.
model, test_acc = setting_1__directly_on_listops(
    model_cls=GenericClassifier,
    model_kwargs=transformer_cfg,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    train_time_limit_secs=TRAIN_TIME_LIMIT,
)

exp1_model2metric['transformer'] = test_acc

# Free the trained model and release cached CUDA memory. Otherwise the global
# `model` keeps this network alive until the next cell's training call returns.
del model
torch.cuda.empty_cache()

## S4

In [None]:
# Experiment 1 / S4: train the S4 backbone on ListOps classification only.
model, test_acc = setting_1__directly_on_listops(
    model_cls=GenericClassifier,
    model_kwargs=s4_cfg,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    train_time_limit_secs=TRAIN_TIME_LIMIT,
)

exp1_model2metric['s4'] = test_acc

# Free the trained model and release cached CUDA memory. Otherwise the global
# `model` keeps this network alive until the next cell's training call returns.
del model
torch.cuda.empty_cache()

In [None]:
# Show Experiment 1 results (backbone name -> ListOps test accuracy).
dict(exp1_model2metric)

# Experiment 2: Pretrain (CLM) on WikiText -> fine-tune on ListOps

In [None]:
from experiments import setting_2__clm_pretrain_text_then_listops

# Backbone name -> ListOps test accuracy, filled in by the Experiment 2 cells.
exp2_model2metric = dict()

## LSTM

In [None]:
# Experiment 2 / LSTM: causal-LM pretraining on text, then fine-tuning on ListOps.
model, test_acc = setting_2__clm_pretrain_text_then_listops(
    model_cls=GenericClassifier,
    model_kwargs=lstm_cfg,
    batch_size=BATCH_SIZE,
    num_epochs_pt=NUM_EPOCHS_PT,
    num_epochs_ft=NUM_EPOCHS_FT,
    timelimit_pt=TIMELIMIT_PT,
    timelimit_ft=TIMELIMIT_FT,
)

exp2_model2metric['lstm'] = test_acc

# Free the trained model and release cached CUDA memory. Otherwise the global
# `model` keeps this network alive until the next cell's training call returns.
del model
torch.cuda.empty_cache()

## Transformer

In [None]:
# Experiment 2 / Transformer: causal-LM pretraining on text, then fine-tuning on ListOps.
model, test_acc = setting_2__clm_pretrain_text_then_listops(
    model_cls=GenericClassifier,
    model_kwargs=transformer_cfg,
    batch_size=BATCH_SIZE,
    num_epochs_pt=NUM_EPOCHS_PT,
    num_epochs_ft=NUM_EPOCHS_FT,
    timelimit_pt=TIMELIMIT_PT,
    timelimit_ft=TIMELIMIT_FT,
)

exp2_model2metric['transformer'] = test_acc

# Free the trained model and release cached CUDA memory. Otherwise the global
# `model` keeps this network alive until the next cell's training call returns.
del model
torch.cuda.empty_cache()

## S4

In [None]:
# Experiment 2 / S4: causal-LM pretraining on text, then fine-tuning on ListOps.
model, test_acc = setting_2__clm_pretrain_text_then_listops(
    model_cls=GenericClassifier,
    model_kwargs=s4_cfg,
    batch_size=BATCH_SIZE,
    num_epochs_pt=NUM_EPOCHS_PT,
    num_epochs_ft=NUM_EPOCHS_FT,
    timelimit_pt=TIMELIMIT_PT,
    timelimit_ft=TIMELIMIT_FT,
)

exp2_model2metric['s4'] = test_acc

# Free the trained model and release cached CUDA memory. Otherwise the global
# `model` keeps this network alive until the next cell's training call returns.
del model
torch.cuda.empty_cache()

In [None]:
# Show Experiment 2 results (backbone name -> ListOps test accuracy).
dict(exp2_model2metric)

# Experiment 3: Pretrain (CLM) on ListOps -> fine-tune on ListOps

In [None]:
from experiments import setting_3__clm_pretrain_listops_then_listops

# Backbone name -> ListOps test accuracy, filled in by the Experiment 3 cells.
exp3_model2metric = dict()

## LSTM

In [None]:
# Experiment 3 / LSTM: causal-LM pretraining on ListOps, then fine-tuning on ListOps classification.
model, test_acc = setting_3__clm_pretrain_listops_then_listops(
    model_cls=GenericClassifier,
    model_kwargs=lstm_cfg,
    batch_size=BATCH_SIZE,
    num_epochs_pt=NUM_EPOCHS_PT,
    num_epochs_ft=NUM_EPOCHS_FT,
    timelimit_pt=TIMELIMIT_PT,
    timelimit_ft=TIMELIMIT_FT,
)

exp3_model2metric['lstm'] = test_acc

# Free the trained model and release cached CUDA memory. Otherwise the global
# `model` keeps this network alive until the next cell's training call returns.
del model
torch.cuda.empty_cache()

## Transformer

In [None]:
# Experiment 3 / Transformer: causal-LM pretraining on ListOps, then fine-tuning on ListOps classification.
model, test_acc = setting_3__clm_pretrain_listops_then_listops(
    model_cls=GenericClassifier,
    model_kwargs=transformer_cfg,
    batch_size=BATCH_SIZE,
    num_epochs_pt=NUM_EPOCHS_PT,
    num_epochs_ft=NUM_EPOCHS_FT,
    timelimit_pt=TIMELIMIT_PT,
    timelimit_ft=TIMELIMIT_FT,
)

exp3_model2metric['transformer'] = test_acc

# Free the trained model and release cached CUDA memory. Otherwise the global
# `model` keeps this network alive until the next cell's training call returns.
del model
torch.cuda.empty_cache()

## S4

In [None]:
# Experiment 3 / S4: causal-LM pretraining on ListOps, then fine-tuning on ListOps classification.
model, test_acc = setting_3__clm_pretrain_listops_then_listops(
    model_cls=GenericClassifier,
    model_kwargs=s4_cfg,
    batch_size=BATCH_SIZE,
    # Use the stage-specific epoch caps like every other pretrain->fine-tune cell;
    # passing NUM_EPOCHS here would silently ignore edits to NUM_EPOCHS_PT/_FT.
    num_epochs_pt=NUM_EPOCHS_PT,
    num_epochs_ft=NUM_EPOCHS_FT,
    timelimit_pt=TIMELIMIT_PT,
    timelimit_ft=TIMELIMIT_FT,
)

exp3_model2metric['s4'] = test_acc

# Free the trained model and release cached CUDA memory now that training is done.
del model
torch.cuda.empty_cache()

In [None]:
# Show Experiment 3 results (backbone name -> ListOps test accuracy).
dict(exp3_model2metric)

# Summary

**The following table presents the accuracy on the test-set of ListOps per model, per training setting.**

`ListOps_CLS` denotes training on ListOps classification. `ListOps_AUT` denotes training autoregressively (as a causal language model) on ListOps. `Wikitext_AUT` denotes training in the same autoregressive manner, but on the WikiText dataset. `{X}->{Y}` denotes pretraining on the first task (`X`) and then fine-tuning on the second (`Y`).



In [None]:
from tabulate import tabulate

# Training-setting label -> that experiment's {backbone name: test accuracy} dict.
setting2results = {
    'ListOps_CLS': exp1_model2metric,
    'Wikitext_AUT->ListOps_CLS': exp2_model2metric,
    'ListOps_AUT->ListOps_CLS': exp3_model2metric,
}

# One row per (training setting, backbone). The loop variable is `model_name`, not
# `model`, so it does not overwrite the notebook's `model` variable, which the
# experiment cells bind to the trained network.
rows = [
    [setting_name, model_name, metric]
    for setting_name, model2metric in setting2results.items()
    for model_name, metric in model2metric.items()
]

print(tabulate(rows, headers=['training', 'model', 'test_acc']))